# encoding: utf-8

from apistar import Route, App
import redis
from .redisConfig import redisConf,jsonList
from apistar import http
import pyodbc
# Module-level Redis client built from redisConf; strategyParams pushes submitted
# strategies onto the `jsonList` list through it.
r = redis.StrictRedis(host=redisConf['host'], port=redisConf['port'])
import json

"""
API 设计
========================================================
"""


# API 接收用户选股参数 post


class clientAPI():
    """HTTP API for submitting backtest strategies and reading their results.

    Routes registered by ``serve``:

    * ``POST /p`` -> ``strategyParams``: accept strategy parameters as JSON and
      push them onto the Redis list ``jsonList``.
    * ``GET /g`` -> ``strategyResults``: read backtest records for a strategy
      from ``BT.[dbo].BTrecords``.
    """

    # NOTE(review): the ODBC connection is opened when this class body runs
    # (i.e. at import time) and a single cursor is shared by every request.
    # That is only safe if the server handles requests one at a time -- confirm.
    DBcnxn = pyodbc.connect(DSN = 'TEST')
    DBcursor = DBcnxn.cursor()

    def __init__(self):
        pass

    class CORSMiddleware(App):
        """App wrapper that adds Cross-Origin Resource Sharing (CORS) headers to every response."""

        def __init__(self, origin='*', **kwargs):
            """Create the app; ``origin`` is sent as Access-Control-Allow-Origin, ``kwargs`` go to ``App``."""
            super().__init__(**kwargs)
            self.origin = origin

        def __call__(self, environ, start_response):
            """WSGI entry point: answer CORS preflights directly, otherwise dispatch to ``App``."""

            def add_cors_headers(status, headers, exc_info=None):
                # Wraps the server's start_response so the CORS headers are appended.
                headers = http.Headers(headers)
                headers.add("Access-Control-Allow-Origin", self.origin)
                headers.add("Access-Control-Allow-Headers", "Origin, Content-Type, Authorization")
                # NOTE(review): browsers refuse credentialed responses whose
                # Allow-Origin is "*"; use an explicit origin if cookies or
                # Authorization headers must work cross-origin.
                headers.add("Access-Control-Allow-Credentials", "true")
                headers.add("Access-Control-Allow-Methods", "GET , POST, OPTIONS")
                # Forward exc_info: PEP 3333 makes a second start_response call
                # (e.g. from an error handler) a fatal error unless exc_info is passed.
                return start_response(status, headers.to_list(), exc_info)

            if environ.get("REQUEST_METHOD") == "OPTIONS":
                # CORS preflight: reply immediately without routing the request.
                add_cors_headers("200 OK", [("Best-NBA-player","LeBron James")])
                return [b'200 OK']

            return super().__call__(environ, add_cors_headers)

    def checkDict(self,mydict):
        """Print a notice if ``mydict['strategyID']`` already exists in BTrecords.

        Intended to check that a submitted strategy ID is unique against all
        strategies stored in the database. It only prints: nothing is returned
        and the submission is not rejected. Not called anywhere in this file.
        """
        print('\ncheckDict: ',type(mydict),mydict)

        self.DBcursor.execute('select DISTINCT(strategyID) from BT.[dbo].BTrecords')
        knownIDs = [row[0] for row in self.DBcursor.fetchall()]
        if mydict['strategyID'] in knownIDs:
            print('已经跑过的策略')  # "this strategy has already been run"

    def strategyParams(self,request: http.Request)-> http.JSONResponse: #post
        """Accept strategy parameters posted as JSON and queue them for backtesting.

        The request body is decoded as UTF-8 JSON. If that yields a string (a
        double-encoded payload), it is decoded once more. The ``strategyID`` is
        appended to ``./postedID.txt``, and ``str(dict)`` is pushed onto the Redis
        list ``jsonList``.

        Returns:
            The parsed dict echoed back as a JSON response with status 200.
        """
        print('-----------------------------------接收策略参数------------------------------------------------')
        body = request.body.decode('utf-8')
        print('\n...', type(body), body)
        print('end')

        mydict = json.loads(body)
        # Unwrap double-encoded payloads *before* indexing into the dict;
        # indexing a str with 'strategyID' would raise TypeError.
        if isinstance(mydict, str):
            mydict = json.loads(mydict)

        with open('./postedID.txt','a+') as f:
            # str() so a numeric strategyID does not make write() raise TypeError.
            f.write(str(mydict['strategyID']))
            f.write('\n')

        # NOTE(review): str(mydict) is the Python repr, not JSON; the consumer of
        # jsonList presumably parses that repr -- confirm before switching to json.dumps.
        r.lpush(jsonList, str(mydict))

        return http.JSONResponse(mydict,status_code=200, headers={})

    def strategyResults(self,strategyID,start,end = 0):
        """Return stored backtest results for one strategy over a range of periods.

        Args:
            strategyID: Strategy identifier to look up. Presumably bound from the
                query string by parameter name -- an older note called the
                request field ``strategy_id``; confirm against the client.
            start: Index of the first period to return (required; parsed with int()).
            end: Optional index of the last period (parsed with int()).

        Index mapping: ``start <= 0`` selects the stored startDate, an index
        below ``length`` selects that entry of the trading-day list, and any
        larger index selects the stored endDate. ``end <= 0`` also selects
        endDate.

        BTrecords columns, by position:
            0 strategyID, 1 status, 2 startDate, 3 endDate, 4 length, 5 btdate,
            6 positionList, 7 cash, 8 positionListValue, 9 totalValue,
            10 tradingDayAPI

        Returns:
            JSONResponse with keys strategy_id, status, length, start, end and
            data (one dict per btdate, ordered by btdate). For an unknown
            strategy the content passed is ``json.dumps(None)`` (the str 'null').
        """
        # strategyID comes straight from the request, so it is bound as a query
        # parameter instead of being formatted into the SQL text (injection).
        self.DBcursor.execute(
            "select top 1 * from BT.[dbo].BTrecords where strategyID = ?",
            str(strategyID))
        metaRows = self.DBcursor.fetchall()

        if not metaRows:
            return http.JSONResponse(json.dumps(None),status_code=200, headers={})

        # Strategy-level metadata comes from the single row fetched above.
        meta = metaRows[-1]
        status = meta[1]
        startDate = meta[2]
        endDate = meta[3]
        # NOTE(review): eval() on a database column. tradingDayAPI presumably
        # holds the repr of a list of dates; ast.literal_eval would be a safer
        # parser once that format is confirmed.
        tradingDates = eval(meta[10])
        start = int(start)
        end = int(end)
        length = int(meta[4])

        if start <= 0:
            pickStartDate = startDate
        elif start < length:
            pickStartDate = tradingDates[start]
        else:
            pickStartDate = endDate

        if 0 < end < length:
            pickEndDate = tradingDates[end]
        else:
            pickEndDate = endDate

        # str() mirrors the original '%s' formatting, so btdate is still
        # compared against string values.
        self.DBcursor.execute(
            "select * from BT.[dbo].BTrecords where strategyID = ? "
            "and btdate >= ? and btdate <= ? order by btdate",
            str(strategyID), str(pickStartDate), str(pickEndDate))

        dataList = [
            {
                'btdate': row[5],
                'positionList': row[6],
                'cash': row[7],
                'positionListValue': row[8],
                'totalValue': row[9],
            }
            for row in self.DBcursor.fetchall()
        ]

        # Tell the client how far the returned data reaches.
        end = start + len(dataList)

        mydict = {
            # Echo the requested strategy ID back to the client.
            'strategy_id': strategyID,
            'status': status,
            'length': length,
            'start': start,
            'end': end,
            'data': dataList,
        }
        return http.JSONResponse(mydict,status_code=200, headers={})

    def serve(self, host='0.0.0.0', port=8080, debug=True):
        """Register the routes, wrap them in ``CORSMiddleware`` and start serving.

        Args:
            host: Interface to bind; the default listens on all interfaces.
            port: TCP port to listen on.
            debug: Passed as both ``use_debugger`` and ``use_reloader``.

        WARNING(review): apistar's dev server presumably forwards use_debugger
        to Werkzeug, whose interactive debugger allows arbitrary code execution.
        Do not run with ``debug=True`` on a host reachable from other machines.
        """
        routes = [
            Route('/p', "POST", self.strategyParams),
            Route('/g', "GET", self.strategyResults),
        ]

        app = self.CORSMiddleware(
            origin='*',
            routes=routes,
        )
        app.serve(host, port, use_debugger=debug, use_reloader=debug)

if __name__ == "__main__":
    # Script entry point: create the API object and start its HTTP server.
    clientAPI().serve()
