import datetime
import functools
import logging
import sys
import warnings

import pandas as pd
import pyodbc

# Module logger. Its level ends up CRITICAL when imported inside the package
# and WARNING when the file is run directly (see the import fallback below).
logger = logging.getLogger(__name__)
logger.setLevel(logging.CRITICAL)

try:
    # Normal case: this module is imported as part of its package.
    from ..emClasses import *
    # Re-applied after the star import, keeping the original ordering
    # (a star import could rebind `logger`).
    logger.setLevel(logging.CRITICAL)
    from ..util.kySDK import *
    from ..dataTable import DataTable
except ImportError:
    # Fallback when the file is run directly: relative imports raise
    # ImportError outside a package, so put the project folders on sys.path
    # and import the modules by their top-level names. Only ImportError is
    # caught so that real errors inside the project modules still surface.
    # NOTE(review): developer-machine absolute paths; adjust for other setups.
    sys.path.insert(0, 'D:/Project/YM/BackTesting Frame/KYBT/BT')
    sys.path.insert(0, 'D:/Project/YM/BackTesting Frame/KYBT/util')
    from emClasses import *
    from dataTable import DataTable
    from kySDK import *
    logger.setLevel(logging.WARNING)

class ApiParent():
    '''
    Parent of all database APIs.

    Holds an ODBC connection (DSN 'BTMarketTrade' by default) and provides
    helpers to turn query results into pandas DataFrames / project objects
    and to look up trading days.

    Usage::

        myAPI = ApiParent()
        myAPI.init()

    or, as a context manager (the connection is closed on exit)::

        with ApiParent() as myAPI:
            do something
    '''
    # Quarter-end month -> 'MM-DD' of the financial report end date.
    # (Original note: "data from 2009 onward is not considered; it seems to have problems".)
    FinancialReportDate = {'Mar':'03-31','Jun':'06-30','Sep':'09-30','Dec':'12-31'}

    def __init__(self, DSN='BTMarketTrade'):
        '''
        :param DSN: ODBC data source name used by init(); defaults to 'BTMarketTrade'.
        '''
        self.DSN = DSN
        self.cnxn = None    # pyodbc connection, set by init()
        self.cursor = None  # cursor on self.cnxn, set by init()

    def init(self):
        '''
        Open the ODBC connection to self.DSN and create a cursor.

        :raises ConnectionError: if pyodbc fails to connect (the pyodbc error
            is chained as __cause__).
        '''
        try:
            # NOTE(review): credentials are hard-coded in source; consider
            # moving them into the DSN configuration or the environment.
            self.cnxn = pyodbc.connect(DSN = self.DSN,UID="SLReadOnly",PWD="gBNfZ9zPCzVD")
            self.cursor = self.cnxn.cursor()
        except pyodbc.Error as err:
            # Raising a plain string (as before) is itself a TypeError in Python 3.
            raise ConnectionError('connection failed: %s' % err) from err

    def __enter__(self):
        '''Open the connection and return this object, enabling `with ... as api`.'''
        self.init()
        return self

    def __exit__(self, *args):
        '''Close the connection; exceptions raised in the `with` body propagate.'''
        cnxn = getattr(self, 'cnxn', None)
        if cnxn is not None:
            cnxn.close()
        self.cnxn = None
        self.cursor = None

    @staticmethod
    def getNeededTable(inputJson):
        '''
        List the database tables a request needs.

        :param inputJson: dict with a 'criteriaDict' entry; every key of that
            entry must be present in DataTable.
        :return: ['EM_QUOTE_STOCK_DAILY'] followed by DataTable[key][1] for
            each key of inputJson['criteriaDict'], in iteration order.
        :raises ValueError: if a key is not in DataTable.
        '''
        tableList = ["EM_QUOTE_STOCK_DAILY"]
        for key in inputJson['criteriaDict']:
            if key not in DataTable:
                raise ValueError('%s is not in the available table'%key)
            tableList.append(DataTable[key][1])  # add the table this criterion needs
        return tableList

    def pdToObjList(self, panda, objName, itemID):
        '''
        Convert a DataFrame into lists of objects grouped by date.

        :param panda: DataFrame containing every column named in itemID plus
            the date column ('TradeDate' for Stock/ValueDaily/Index, otherwise
            'EndDate').
        :param objName: name of a class found in this module's globals (e.g.
            one star-imported from emClasses); it is called with no arguments.
        :param itemID: comma-separated column names (parsed by
            modifyItemString); each becomes an attribute of the created objects.
            TradeDate/EndDate/LLimitUpTime are stored as their first 10
            characters, SecuCode/SecuAbbr/SecuMarket/PublDate/TradeStatus as
            str, everything else as float.
        :return: dict 'YYYY-MM-DD' -> [objects], keys inserted newest first, e.g.
            {'2017-01-26': [<stock.Stock object at 0x0857FA70>, ...],
             '2017-01-25': [<stock.Stock object at 0x0857FB10>, ...], ...}
        :raises NameError: if objName is not defined in this module.
        '''
        dateColumn = 'TradeDate' if objName in ('Stock', 'ValueDaily', 'Index') else 'EndDate'

        # Look the class up directly instead of exec()-ing code strings:
        # exec cannot reliably create function locals (breaks with PEP 667).
        objClass = globals().get(objName)
        if objClass is None:
            raise NameError("name '%s' is not defined" % objName)

        itemList = self.modifyItemString(itemID)
        truncatedItems = ('TradeDate', 'EndDate', 'LLimitUpTime')
        textItems = ('SecuCode', 'SecuAbbr', 'SecuMarket', 'PublDate', 'TradeStatus')

        # Normalise every date to 'YYYY-MM-DD' once and group on that key, so
        # rows match whether the column holds Timestamps, datetimes or strings.
        dateKeys = panda[dateColumn].map(lambda value: str(value)[:10])

        dictObj = dict()
        for oneDate in sorted(set(dateKeys), reverse=True):
            objList = []
            for _, row in panda[dateKeys == oneDate].iterrows():
                oneObj = objClass()
                for item in itemList:
                    if item in truncatedItems:
                        value = str(row[item])[:10]
                    elif item in textItems:
                        value = str(row[item])
                    else:
                        value = float(row[item])
                    setattr(oneObj, item, value)
                objList.append(oneObj)
            dictObj[oneDate] = objList
        return dictObj

    def findTableColumns(self, tableName, itemString):
        '''
        Query sys.columns for the columns of table dbo.<tableName>.

        :param tableName: table name; interpolated into the SQL text, so pass
            trusted values only.
        :param itemString: sys.columns fields to SELECT; the second selected
            field of each row is taken as the column name.
        :return: list of those names (also stored on self.columnName), or None
            when the query returns no rows.
        '''
        query = "SELECT %s FROM sys.columns WHERE object_id = OBJECT_ID('dbo.%s')"%(itemString,tableName)
        # Lazy %-style argument; passing an extra argument without a
        # placeholder used to trigger a logging formatting error.
        logger.warning('findTableColumns:query: %s', query)
        self.cursor.execute(query)
        columnData = self.cursor.fetchall()
        if not columnData:
            logger.warning('%s has no data, return None'%tableName)
            return None
        self.columnName = [x[1] for x in columnData]
        return self.columnName

    @staticmethod
    def findObj(attr, value, objList):
        '''
        Filter objects whose attribute, converted to str, contains any of the
        given substrings.

        :param attr: attribute name; dotted paths such as 'a.b' are supported.
        :param value: iterable of substrings (a plain string is iterated
            character by character).
        :param objList: objects to filter.
        :return: matching objects in objList order, each included at most once.
        '''
        patterns = list(value)  # materialise once so generators work for every object
        if not patterns:
            return []
        matched = []
        for obj in objList:
            # getattr chain instead of eval('x.%s' % attr): same lookup, no code execution.
            text = str(functools.reduce(getattr, attr.split('.'), obj))
            if any(pattern in text for pattern in patterns):
                matched.append(obj)
        return matched

    @staticmethod
    def modifyItemString(itemString):
        r'''
        Split a comma-separated column list into clean column names.

        Example::

            '\n            SecuCode,SecuAbbr,\n            PreClosePrice'
            -> ['SecuCode', 'SecuAbbr', 'PreClosePrice']
        '''
        # Strip each item entirely. The previous version removed exactly two
        # characters after a leading newline ('\nSecuCode' -> 'ecuCode') and
        # left any other surrounding whitespace in the name.
        return [item.strip() for item in itemString.split(',')]

    def pandaData(self, rawData, itemString):
        '''
        Build a DataFrame from database rows.

        :param rawData: iterable of rows (e.g. cursor.fetchall()); each row is
            converted to a list.
        :param itemString: comma-separated column names, parsed by modifyItemString.
        :return: DataFrame with those column names.
        '''
        myPD = pd.DataFrame([list(row) for row in rawData])
        myPD.columns = self.modifyItemString(itemString)
        logger.info(myPD[:10])
        return myPD


    #########################################################################################
    ########################### take care of date ############################################

    def getTradingDays(self, startDate, endDate):
        '''
        Return the trading days in a date range.

        A day counts as a trading day if any of the HS300, SZ50 or ZZ500
        indices (codes 000300, 000905, 000016) has a row for it in
        embase2.dbo.EM_QUOTE_INDEX_DAILY.

        :param startDate: 'YYYY-MM-DD'; rows with TradeDate >= startDate 00:00 are used.
        :param endDate: 'YYYY-MM-DD'; rows with TradeDate <= endDate 00:00 are used.
        :return: sorted list of distinct 'YYYY-MM-DD' strings strictly before
            today, or None if the query returns no rows.
        '''
        itemString = 'SecuCode,TradeDate'
        index = ('000300','000905','000016')

        # NOTE(review): dates are interpolated into the SQL text; pass only
        # trusted 'YYYY-MM-DD' strings.
        query = '''
        SELECT
            %s
        FROM
            embase2.[dbo].[EM_QUOTE_INDEX_DAILY]
        WHERE
                SecuCode in %s
        AND
                TradeDate >= '%s 00:00:00.000'
        AND     TradeDate <= '%s 00:00:00.000'
        ORDER BY TradeDate ASC
        '''%(itemString,index,startDate,endDate)

        self.cursor.execute(query)
        raw = self.cursor.fetchall()
        logger.info('raw : %s', raw)
        if not raw:
            logger.critical( 'index market value|trading days| has no data, return None')
            return None

        indexPD = self.pandaData(raw,itemString)
        indexPD.fillna(value=0, inplace=True) # to replace None with value 0

        today = datetime.date.today().isoformat()
        tradeDates = sorted({str(x)[:10] for x in indexPD['TradeDate'].tolist()})
        # Drop today and any future dates.
        return [day for day in tradeDates if day < today]

    def getToday(self):
        '''Return today's date as 'YYYY-MM-DD'.'''
        return str(datetime.datetime.today())[:10]

    def getYesterday(self):
        '''Return yesterday's date as 'YYYY-MM-DD'.'''
        end = datetime.datetime.today()
        start = end - datetime.timedelta(1)
        return str(start)[:10]

    def getNdayBefore(self, date, N):
        '''Return the 'YYYY-MM-DD' date N calendar days before `date` ('YYYY-MM-DD').'''
        end = datetime.datetime.strptime(date,'%Y-%m-%d')
        start = end - datetime.timedelta(N)
        return str(start)[:10]

    def _latestTradingDayUpTo(self, endDate, lookbackDays=70):
        '''Return the last day from getTradingDays(endDate - lookbackDays, endDate), or None.'''
        # Round-trip through strptime to validate and zero-pad the date.
        endDate = datetime.datetime.strptime(endDate, '%Y-%m-%d').strftime('%Y-%m-%d')
        startDate = self.getNdayBefore(endDate, lookbackDays)
        # getTradingDays takes (startDate, endDate); the arguments used to be
        # passed in reverse, which made the query always come back empty.
        tradingDays = self.getTradingDays(startDate, endDate)
        return tradingDays[-1] if tradingDays else None

    def getMostRecentTrading(self):
        '''
        Return the most recent trading day before today as 'YYYY-MM-DD',
        searching the last 70 calendar days; None if none is found.
        '''
        return self._latestTradingDayUpTo(self.getToday())

    def getMostRecentTradingDay(self, then):
        '''
        Return the most recent trading day up to `then`, e.g. then='2015-06-05'.

        Searches the 70 calendar days ending at `then`; `then` itself can be
        returned when its TradeDate is stored at midnight. Days on or after
        today are never returned. Returns None if no trading day is found.
        '''
        return self._latestTradingDayUpTo(then)

    @staticmethod
    def deprecated(func):
        """Decorator that marks `func` as deprecated.

        Every call emits a DeprecationWarning attributed to the caller
        (stacklevel=2) and then calls `func`. The 'always' filter is applied
        only inside warnings.catch_warnings, so the caller's own warning
        filters are restored afterwards instead of being overwritten.
        """
        @functools.wraps(func)
        def new_func(*args, **kwargs):
            with warnings.catch_warnings():
                warnings.simplefilter('always', DeprecationWarning)  # show on every call
                warnings.warn("Call to deprecated function {}.".format(func.__name__),
                              category=DeprecationWarning,
                              stacklevel=2)
            return func(*args, **kwargs)
        return new_func



if __name__ == '__main__':
    # Smoke test: connect, print the trading days of the last 30 calendar
    # days, then close the connection. getTradingDays needs a start and an
    # end date; calling it without arguments raised TypeError.
    myAPI = ApiParent()
    with myAPI:
        endDate = myAPI.getToday()
        startDate = myAPI.getNdayBefore(endDate, 30)
        print(myAPI.getTradingDays(startDate, endDate))