• 数据分析--简单回测框架开发


    简易回测框架开发:

    框架内容:

      上下文信息保存,context

      获取数据

      下单函数

      用户接口

      ........

    import pandas as pd
    import matplotlib.pyplot as plt
    import tushare
    import datetime
    import dateutil
    
    '''
    获取所有的股票交易日,交易日信息保存在csv文件
    '''
    try:
        trade_cal = pd.read_csv("trade_cal.csv")
    except:
        trade_cal = tushare.trade_cal()
        trade_cal.to_csv("trade_cal.csv")
    
    class Context:
        def __init__(self, cash, start_date, end_date):
            '''
            保存股票信息
            :param cash: 现金量
            :param start_date: 量化策略开始时间
            :param end_date: 量化策略结束时间
            :param positions: 持仓股票和对应的数量
            :param benchmark: 参考股票
            :param date_range: 开始-结束之间的所有交易日
            :param dt:  当前日期 (循环时当前日期会发生变化)
            '''
            self.cash = cash
            self.start_date = start_date
            self.end_date = end_date
            self.positions = {}  # 持仓信息
            self.benchmark = None
            self.date_range = trade_cal[(trade_cal['isOpen']==1)&
                                        (trade_cal['calendarDate']>=start_date)&
                                         (trade_cal['calendarDate']<=end_date)]['calendarDate'].values
            self.dt = None
    
    class G:
        '''
        保存用户的全局参数
        '''
        pass
    
    '''
    默认的初始化信息
    '''
    g = G()
    CASH = 100000
    START_DATE = '2016-01-07'
    END_DATE = '2017-01-31'
    context = Context(CASH,START_DATE,END_DATE)
    
    
    def attribute_history(security,
                          count,
                          field=('open','close','high','low','volume')):
        '''
        获取某股票count天的历史行情,每运行一次该函数,日期范围后移
    
        :param security: 股票代码
        :param count: 天数
        :param field: 字段
        :return:
        '''
        end_date = (context.dt - datetime.timedelta(days=1)).strftime('%Y-%m-%d')
        start_date = trade_cal[(trade_cal['isOpen']==1)&
                               (trade_cal['calendarDate']<=end_date)][-count:]['calendarDate'].iloc[0]
        return attribute_daterange_history(security,start_date,end_date,field)
    
    def attribute_daterange_history(security,
                                    start_date,end_date,
                                    field=('open','close','high','low','volume')):
        '''
        底层,获取某股票某一段时间的历史行情
        :param security:
        :param start_date:
        :param end_date:
        :param field:
        :return:
        '''
        df = tushare.get_k_data(security,start_date,end_date)
        df.index = df['date']
        return df[list(field)]
    
    
    def get_today_data(security):
        '''
        获取context的"当天"的股票信息,停牌返回Null
        :param security:
        :return:
        '''
        try:
            today = context.dt.strftime('%Y-%m-%d')
            df = tushare.get_k_data(security,today,today)
            df.index = df['date']
            data = df.loc[today]
        except KeyError:  # 股票停牌
            data = pd.Series()
        return data
    
    
    def _order(today_data, security, amount):
        '''
        底层买股票的函数
        :param today_data: "当天"的股票价格OCHL
        :param security: 股票代码
        :param amount: 交易股数,正数为买入,负数为卖出
        :return:
        '''
        p = today_data['open']
        # 找不到该股票默认为0股
        old_amount = context.positions.get(security, 0)
    
        if len(today_data) == 0:
            print("今日停牌")
            return
        if context.cash - amount * p < 0:
            amount = context.cash // p
            print('%s:现金不足,已调整为%d' %(today_data['date'],amount))
        if amount % 100 != 0:
            # 买或卖不是100的倍数就调整为100的倍数,卖光则不调整
            if amount != -old_amount:
                # 2345 => 2300
                amount = int(amount / 100) * 100
                print('%s:不是100的倍数,已调整为%d' %(today_data['date'],amount))
        if old_amount < -amount:
            amount = -old_amount
            print('%s:卖出股票不能超过持仓数,已调整为%d'%(today_data['date'],amount))
    
        # 更新持仓信息
        context.positions[security] = old_amount + amount
        # 更新钱
        context.cash -= amount*p
        # 持仓为0就删掉
        if context.positions[security] == 0:
            del context.positions[security]
    
    def order(security, amount):
        # 买入股票。amount为正表示买入,负表示卖出
        today_data = get_today_data(security)
        _order(today_data, security, amount)
    def order_target(security, amount):
        # 把股票交易到多少股,不能为负数,比原来小是卖出,比原来大是买入
        if amount < 0:
            print("数量不能为负,已调整为0")
            amount = 0
        today_data = get_today_data(security)
        hold_amount = context.positions.get(security, 0) # TODO: T + 1 closeable total
        delta_amount = amount - hold_amount
        _order(today_data,security,delta_amount)
    def order_value(security, value):
        # 买多少钱的股票或者卖多少钱的股票
        today_data = get_today_data(security)
        amount = value / today_data['open']
        _order(today_data,security,amount)
    def order_target_value(security, value):
        # 买到或者卖到多少钱
        if value < 0:
            print("价值不能为负,已调整为0")
            value = 0
        today_data = get_today_data(security)
        hold_value = context.positions.get(security,0) * today_data['open']
        dalta_value = value - hold_value
        order_value(security,dalta_value)
    
    def run():
        plt_df = pd.DataFrame(index=pd.to_datetime(context.date_range),
                              columns=['value'])
        # 最初的钱,算收益率用
        init_value = context.cash
        # 保存停牌前一天的股票价格
        last_price = {}
        # 用户接口1
        initialize(context)
        for dt in context.date_range:
            context.dt = dateutil.parser.parse(dt)
            # 用户接口2
            handle_data(context)
            # 股票和现金的总价值
            value = context.cash
            for stock in context.positions:
                # 考虑停牌的情况
                today_data = get_today_data(stock)
                if len(today_data) == 0:
                    p = last_price[stock]
                else:
                    p = today_data['open']
                    last_price[stock] = p
                value += p * context.positions[stock]
            plt_df.loc[dt, 'value'] = value
        plt_df['ratio'] = (plt_df['value']-init_value) / init_value
    
        bm_df = attribute_daterange_history(context.benchmark,
                                    context.start_date,
                                    context.end_date)
        bm_init = bm_df['open'][0]
        plt_df['benchmark_raito'] = (bm_df['open']-bm_init) / bm_init
        print(plt_df)
        plt_df[['ratio','benchmark_raito']].plot()
        plt.show()
    
    '''
    initialize和handle_data是用户的操作
    '''
    def initialize(context):
        context.benchmark = '601318'
        g.p1 = 5
        g.p2 = 60
        g.security = '601318'
    def handle_data(context):
        hist = attribute_history(g.security, g.p2)
        ma5 = hist['close'][-g.p1:].mean()
        ma60 = hist['close'].mean()
    
        if ma5 > ma60 and g.security not in context.positions:
            order_value(g.security, context.cash)
        elif ma5 < ma60 and g.security in context.positions:
            order_target(g.security,0)
    
    if __name__ == '__main__':
        run()

  • 相关阅读:
    AJAX
    JQUERY基础
    PHP 数据库抽象层pdo
    会话控制:session与cookie
    php 如何造一个简短原始的数据库类用来增加工作效率
    php 数据访问(以mysql数据库为例)
    面向对象设计原则
    php 设计模式 例子
    PHP中静态与抽象的概念
    键盘的按钮键名
  • 原文地址:https://www.cnblogs.com/staff/p/10973790.html
Copyright © 2020-2023  润新知