# Read a strategy JSON, then download all the data it needs.

import pandas as pd
import pyodbc
import ky
import datetime
import sys
import time
sdk = ky.Api("http://data-api.kuaiyutech.com/api.rpc")
import os
import logging

logger = logging.getLogger(__name__)


from .apiParent import ApiParent
#print(__file__)
#print(os.path.dirname(os.path.abspath(__file__)))
#print(os.path.dirname(os.path.abspath('initRobot.py')))
try:
    
    from ..BT.dataTable import DataTable
    logger.setLevel(logging.CRITICAL)
    from ..util.kySDK import *
    from .apiParent import ApiParent
except:
    import sys   
    #D:\Project\益盟\回测框架\KYBT\BT
    sys.path.insert(0, 'D:/Project/YM/BackTesting Frame/KYBT/BT') 
    sys.path.insert(0, 'D:/Project/YM/BackTesting Frame/KYBT/util')
    #print(sys.path)
    from emClasses import *
    from dataTable import DataTable
    from kySDK import *
    
    logger.setLevel(logging.WARNING)
    
class highwayAPI(ApiParent):

    '''
    Query the embase2 SQL Server tables and organize the rows into objects.

    Relies on ApiParent for the DB cursor (self.cursor) and for the helpers
    pandaData, pdToObjList, getYesterday and getNeededTable.
    '''

    # Benchmark index name -> six-digit index SecuCode.
    _INDEX_CODES = {'HS300': '000300', 'ZZ500': '000905', 'SZ50': '000016'}

    # Look-back window label -> years to move startDate back when downloading, so
    # the reports the window refers to are included (with a safety margin).
    # 最近三年 = last three years, 最近一年 = last year, 最近一期 = latest period.
    _LOOKBACK_YEARS = {'最近三年': 5, '最近一年': 3, '最近一期': 2}

    def __init__(self):
        super().__init__()

    @staticmethod
    def _sqlInList(codes):
        '''
        Render security codes as a quoted SQL IN list.

        ['600000'] -> "('600000')", ['600000', '000002'] -> "('600000','000002')".
        Every code is quoted, the same form the multi-code tuple path used to
        produce, so single-code pools are no longer sent as bare numbers.
        Embedded single quotes are doubled.

        Raises ValueError for an empty/None list, which would otherwise produce
        invalid SQL.
        '''
        if not codes:
            raise ValueError('stock code list is empty, nothing to query')
        return '(%s)' % ','.join("'%s'" % str(code).replace("'", "''") for code in codes)

    @staticmethod
    def _shiftYears(dateString, years):
        '''
        Shift a date string that starts with a 4-digit year ('YYYY-MM-DD...') by
        whole years. A Feb 29 that lands in a non-leap year becomes Feb 28, so
        the result is still a valid date for SQL Server.
        '''
        import calendar
        newYear = int(dateString[:4]) + years
        rest = dateString[4:]
        if rest.startswith('-02-29') and not calendar.isleap(newYear):
            rest = '-02-28' + rest[len('-02-29'):]
        return str(newYear) + rest

    def getAllStockCode(self,oneDay):
        '''
        Return the distinct SecuCodes in EM_QUOTE_STOCK_DAILY on oneDay,
        e.g. ['600000','000002',...].

        Codes starting with '9' or '2' are excluded (presumably B shares;
        confirm). oneDay is a date string to which ' 00:00:00.000' is
        appended, presumably 'YYYY-MM-DD'.
        Returns None (and logs) when the table has no rows for that day.
        '''
        itemString = 'SecuCode'
        # quoted '9'/'2': LEFT() returns text, so compare as text instead of
        # forcing SQL Server to convert every code to int
        query = '''
        SELECT DISTINCT %s
        
        FROM
            [dbo].[EM_QUOTE_STOCK_DAILY]
        WHERE
        TradeDate = '%s 00:00:00.000'
        AND 
        left(SecuCode,1) != '9' 
        AND
        left(SecuCode,1) != '2'
        '''%(itemString,oneDay)

        self.cursor.execute(query)

        raw = self.cursor.fetchall()
        if not raw:
            logger.critical( 'Daily Quote all has no data, return None')
            return None

        allStockPD = self.pandaData(raw,'SecuCode')
        allStockPD.fillna(value=0, inplace=True) # to replace None with value 0

        return allStockPD[itemString].tolist()

    def getIndexData(self,index,startDate,endDate):
        '''
        Download daily quotes of a benchmark index from EM_QUOTE_INDEX_DAILY.

        index: 'HS300', 'ZZ500' or 'SZ50'.
        startDate, endDate: date strings; endDate is capped at self.getYesterday().
        Returns self.pdToObjList(..., 'Index', ...), or None when no rows match.
        Raises ValueError for a non-str or unknown index name.
        '''
        if not isinstance(index,str):
            raise ValueError('Index input should be str')
        if index not in self._INDEX_CODES:
            raise ValueError('Index input is incorrect')
        indexCodes = self._sqlInList([self._INDEX_CODES[index]])

        yesterday = self.getYesterday()
        if endDate >= yesterday:
            endDate = yesterday

        itemID = '''SecuCode,SecuAbbr,TradeDate,TradeStatus,PreClosePrice,OpenPrice,ClosePrice,HighPrice,LowPrice,
                    TurnoverVolume,TurnoverValue,ChangeRatio'''

        # NOTE(review): dates are formatted into the SQL text; validate them if
        # inputJson can come from untrusted users
        query = '''
                SELECT
                    %s
                FROM
                    embase2.[dbo].[EM_QUOTE_INDEX_DAILY]
                WHERE
                    SecuCode in %s
                AND TradeDate >= '%s 00:00:00.000'
                AND TradeDate <= '%s 00:00:00.000'
                ORDER BY
                TradeDate DESC
                '''%(itemID,indexCodes,startDate,endDate)

        self.cursor.execute(query)
        raw = self.cursor.fetchall()
        if not raw:
            return None

        allStockPD = self.pandaData(raw,itemID)
        allStockPD.fillna(value=0, inplace=True) # to replace None with value 0

        return self.pdToObjList(allStockPD,'Index',itemID)

    @staticmethod
    def backwardStartDate(inputJson,startDate,tableName):
        '''
        Move startDate back far enough for the look-back windows of every
        criterion stored in tableName.

        Each entry of inputJson['criteriaDict'] looks like
        {'主营收入同比': ['最近三年', None, -100]}. The first element is the
        window label, and DataTable[criteria][1] names the table that holds
        the criterion. startDate is moved back 5 years for 最近三年, 3 years
        for 最近一年 and 2 years for 最近一期. When several criteria use the
        same table, the longest move wins, so every criterion has enough
        history.

        Returns the (possibly unchanged) date string.
        '''
        years = 0
        for criteria, spec in inputJson['criteriaDict'].items():
            lookback = highwayAPI._LOOKBACK_YEARS.get(spec[0])
            # check the window first so DataTable is only consulted for known windows
            if lookback is not None and DataTable[criteria][1] == tableName:
                years = max(years, lookback)
        if not years:
            return startDate
        return highwayAPI._shiftYears(startDate, -years)

    def getTableData(self,inputJson,tableName):
        '''
        Download rows of embase2.<tableName> for the strategy's stock pool and dates.

        inputJson: strategy dict. Reads 'portforlio', 'startDate', 'endDate',
            and, via backwardStartDate for the non-daily-quote tables,
            'criteriaDict'. 'portforlio' may be a list of six-digit codes, an
            index name ('HS300', 'ZZ500', 'SZ50') whose pool comes from
            getIndexPoolFromKy, or 'all' (every code quoted on startDate).
        tableName: EM_QUOTE_STOCK_DAILY, CY_FinacialRatio_RP,
            CY_FinacialIndicators_RP or CY_ValueDaily_Quote.
        Returns self.pdToObjList(...), or None when no rows match.
        Raises ValueError for endDate < startDate, an unsupported portfolio or
        table name, or an empty stock pool.
        '''
        stockList = inputJson['portforlio']
        startDate = inputJson['startDate']
        endDate = inputJson['endDate']

        if endDate < startDate:
            raise ValueError('endDate cannot be smaller than startDate')

        # stock pool
        if isinstance(stockList,str) and stockList in self._INDEX_CODES:
            # getIndexPoolFromKy yields codes like 'SHSE.600000'; keep the last 6 chars
            indexSymbol = 'SHSE.' + self._INDEX_CODES[stockList]
            stockList = [stockCode[-6:] for stockCode in getIndexPoolFromKy(indexSymbol)]

        elif isinstance(stockList,str) and stockList == 'all':
            # compare the pool itself: no name `portforlio` exists in this scope
            stockList = self.getAllStockCode(startDate)

        elif not isinstance(stockList,list):
            raise ValueError('stockList input is incorrect')

        # date column
        if tableName in ['EM_QUOTE_STOCK_DAILY','CY_ValueDaily_Quote']:
            date = 'TradeDate'
        else:
            date = 'EndDate'

        # itemID objName
        if tableName == 'EM_QUOTE_STOCK_DAILY':
            itemID = '''SecuCode,SecuAbbr,SecuMarket,TradeDate,TradeStatus,PreClosePrice,OpenPrice,ClosePrice,HighPrice,
                    LowPrice,TurnoverVolume,TurnoverValue,TurnoverDeals,Amplitude,ChangeRatio,
                    FlowShares,FlowMarketValue,TotalEquity,IssueEquity,TotalValue,LLimitUpTime,RiseOrDown'''
            objName = 'Stock'

        elif tableName == 'CY_FinacialRatio_RP':
            itemID = '''SecuCode,SecuAbbr,EndDate,PublDate,InventoryTurnover,AccountsReceivableTurnover,FixedAssetsTurnover,
                    OperatingRevenueYoY,DeductNetProfitYoY'''
            objName = 'FinancialRatioRP'
            startDate = self.backwardStartDate(inputJson,startDate,tableName)

        elif tableName == 'CY_FinacialIndicators_RP':
            itemID = '''SecuCode,SecuAbbr,EndDate,AdvanceReceipts,TotalAssets,EquityBelongedToPC,TotalLiability,TotalCurrentAssets,
                    TotalCurrentLiability,CashOrDepositInCentralBank,NetProfit,NetOperateCashFlow'''
            objName = 'FinancialIndicator'
            startDate = self.backwardStartDate(inputJson,startDate,tableName)

        elif tableName == 'CY_ValueDaily_Quote':
            itemID = '''SecuCode,SecuAbbr,TradeDate,ReportPeriod,EPS,PE1,PE2,PB,PS,Dividend1Y,MarketValue,ClosePrice,ValidShares,
                    RMBShares,Bshares'''
            objName ='ValueDaily'
            startDate = self.backwardStartDate(inputJson,startDate,tableName)
        else:
            raise ValueError('%s is not a defined Class'%tableName)

        # one consistent, quoted IN list for any pool size (raises on an empty pool)
        stockCodes = self._sqlInList(stockList)

        # NOTE(review): dates are formatted into the SQL text; validate them if
        # inputJson can come from untrusted users
        query = '''
                SELECT
                    %s
                FROM
                    embase2.[dbo].[%s]
                WHERE
                    SecuCode in %s
                AND %s >= '%s 00:00:00.000'
                AND %s <= '%s 00:00:00.000'
                ORDER BY
                SecuCode DESC,
                %s DESC
                '''%(itemID,tableName,stockCodes,date,startDate,date,endDate,date)

        self.cursor.execute(query)

        raw = self.cursor.fetchall()
        if not raw:
            return None

        allStockPD = self.pandaData(raw,itemID)
        allStockPD.fillna(value=0, inplace=True) # to replace None with value 0

        return self.pdToObjList(allStockPD,objName,itemID)


    def getAllNeeded(self,inputJson):
        """
        Download all data the strategy inputJson needs and return it as bulk:

        {'EM_QUOTE_STOCK_DAILY':{date:[stockObj]},
        'CY_FinacialRatio_RP':{date:[financialRatioObj]},
        ...

        'EM_QUOTE_INDEX_DAILY':{date:[indexObj]}}

        (Each value is whatever ApiParent.pdToObjList returns; the shape above
        is the author's description.)

        inputJson is the strategy dict itself: this method reads 'startDate',
        'endDate' and 'benchmark' directly, and passes the same dict on to
        getNeededTable and getTableData.
        """
        endDate = inputJson['endDate']
        startDate = inputJson['startDate']
        benchmark = inputJson['benchmark']
        bulk = dict()
        print('开始下载并整理数据...')
        startTime = time.time()
        for tableName in super().getNeededTable(inputJson):
            bulk[tableName] = self.getTableData(inputJson,tableName)
            print('已经完成%s所需的下载'%tableName)
        # the index table is fetched separately because it takes different inputs
        bulk['EM_QUOTE_INDEX_DAILY'] = self.getIndexData(benchmark,startDate,endDate)
        print('下载整理数据完毕，共耗时%0.1f秒'%(time.time() - startTime))
        print('包含的数据表有：',list(bulk.keys()))
        # sys.getsizeof is shallow (outer dict only), so it is only a debug hint
        logger.debug('bulk dict shallow size: %d bytes', sys.getsizeof(bulk))
        return bulk


if __name__ == '__main__':

    def extractBulk(bulk,portforlio):
        '''
        For every stock and criterion, pick the CY_FinacialRatio_RP values that
        the criterion's look-back window refers to.

        bulk: result of highwayAPI.getAllNeeded. bulk['CY_FinacialRatio_RP'] is
            read as {EndDate string: [financial-ratio objects]}, newest date
            first (the query orders by EndDate DESC). NOTE(review): confirm that
            ApiParent.pdToObjList keeps that order.
        portforlio: iterable of stock codes.
        Criteria come from the script-level inputJson[ID]['criteriaDict'].

        Windows: 最近三年 = last three year-end (-12-31) reports,
        最近一年 = latest year-end report, 最近一期 = latest report.

        Returns e.g.
        {'600485': {'主营收入同比': [Decimal('13.2225'), Decimal('33.8532'), Decimal('851.6680')]},
         '000002': {'主营收入同比': [Decimal('33.5828'), Decimal('8.1002'), Decimal('31.3263')]}}
        Raises ValueError for an unknown window label, or when a year-end
        window is requested but no -12-31 report was downloaded.
        '''
        criteriaDict = inputJson[ID]['criteriaDict']
        finBulk = bulk['CY_FinacialRatio_RP']
        reportDates = list(finBulk.keys())  # e.g. ['2015-12-31', '2015-09-30', ...]
        # newest year-end report date, e.g. '2015-12-31' (None if there is none)
        latestYearEnd = next((d for d in reportDates if '-12-31' in d), None)

        def windowDates(window):
            # report dates covered by a window label, newest first
            if window == '最近一期':
                return reportDates[:1]
            if window not in ('最近三年', '最近一年'):
                raise ValueError('输入要是最近一期，最近一年，最近三年')
            if latestYearEnd is None:
                raise ValueError('no year-end (-12-31) report in CY_FinacialRatio_RP')
            yearsBack = 3 if window == '最近三年' else 1
            return [str(int(latestYearEnd[:4]) - k) + latestYearEnd[4:] for k in range(yearsBack)]

        resultDict = dict()
        for stock in portforlio:
            for criteria, spec in criteriaDict.items():
                days = windowDates(spec[0])
                attr = DataTable[criteria][0]
                criteriaList = []
                for day in days:
                    # a missing report date simply contributes no values
                    for finObj in finBulk.get(day, []):
                        if finObj.SecuCode == stock and finObj.EndDate == day:
                            criteriaList.append(getattr(finObj, attr))
                # accumulate instead of overwrite: a stock can have several criteria
                resultDict.setdefault(stock, {})[criteria] = criteriaList

        print('resultDict: ',resultDict,)
        return resultDict


    def screenStock(inputJson,resultDict):
        """
        Decide for each stock whether every value of every criterion lies
        strictly between the criterion's bounds.

        inputJson[ID]['criteriaDict'][criteria] is [window, upperbound, lowerbound];
        a None bound is replaced by +1000000 / -1000000.
        resultDict: output of extractBulk, {stock: {criteria: [values]}}.
        A criterion with no values counts as failed (no data cannot prove a pass).

        Returns {stock: bool}, e.g. {'600485': True, '600000': False}.
        """
        criteriaDict = inputJson[ID]['criteriaDict']
        screenResult = dict()

        for stock,value in resultDict.items():
            multiple = []
            for criteria,valueList in value.items():
                upperbound = criteriaDict[criteria][1]
                lowerbound = criteriaDict[criteria][2]

                if upperbound is None:
                    upperbound = 1000000
                if lowerbound is None:
                    lowerbound = -1000000

                # true only if there is data and all of it is inside the bounds
                single = bool(valueList) and all(lowerbound < x < upperbound for x in valueList)
                multiple.append(single)

            screenResult[stock] = all(multiple)

        print('True or False result: ',screenResult)
        return screenResult

    def getTrueStock(screenResult):
        '''Return the stock codes whose screen result is True, in dict order.'''
        mylist = [stock for stock, passed in screenResult.items() if passed]
        print('getTrueStock: ',mylist)
        return mylist

    ID = '1'
    # criteriaDict: {criteria: [look-back window, upperbound, lowerbound]},
    # e.g. {'存货周转率': ['最近一年', None, -100]}
    strategy = \
    {
        "stockPicker":{'pick':False,'dateList':None},
        "criteriaDict" : {'主营收入同比':['最近三年',None,18]},
        "portforlio": ['600485','600000','000002'],
        "benchmark": 'HS300',
        "money": 5000000,
        "startDate": '2014-07-01',
        "endDate": '2016-02-01',
        "turnover": 10,
        "tradingHabit": "OpenPrice",
        "status":"init"
    }

    inputJson = \
    {
        ID:strategy
    }

    myAPI = highwayAPI()
    myAPI.init()

    portforlio = inputJson[ID]['portforlio']

    # getAllNeeded reads 'startDate'/'endDate'/'benchmark' directly, so it
    # needs the strategy dict, not the {ID: strategy} wrapper
    bulk = myAPI.getAllNeeded(inputJson[ID])

    # pick the stocks that satisfy inputJson[ID]['criteriaDict']; these are
    # script-level functions, not highwayAPI methods
    resultDict = extractBulk(bulk,portforlio)
    screenDict = screenStock(inputJson,resultDict)
    screenStocks = getTrueStock(screenDict)
    

    




    