import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pymysql
import warnings
import random
warnings.filterwarnings("ignore")
def read_db(path, window=20, n_std=2.0, skip_tail=6):
    """Load the CSV at *path* and flag rows whose ``qx_zu`` drops below a band.

    For every row ``i >= window`` the threshold ``cha`` is
    ``mean - n_std * std`` (population std, ``ddof=0``) of ``qx_zu`` over the
    ``window`` rows strictly before ``i`` (positions ``i-window .. i-1``);
    NaNs inside the window are skipped.  The first ``window`` rows get
    ``cha = 0`` and are never flagged.

    Two columns are added:

    * ``cha``    -- the threshold described above.
    * ``jieguo`` -- the row's ``close`` value when ``qx_zu < cha``, else 0.

    The ``date`` column is parsed with format ``%Y/%m/%d``.

    :param path: anything ``pd.read_csv`` accepts (file path or buffer);
        must contain ``date``, ``close`` and ``qx_zu`` columns.
    :param window: number of preceding rows used for the band (default 20).
    :param n_std: width of the band in standard deviations (default 2).
    :param skip_tail: number of trailing CSV rows to discard (default 6;
        the reason for dropping them is not visible here -- presumably
        incomplete/footer rows, TODO confirm).
    :return: the (row-trimmed) DataFrame with ``cha`` and ``jieguo`` added,
        keeping the original index and row order.
    :raises ValueError: if ``window < 1``.
    """
    if window < 1:
        raise ValueError("window must be >= 1")

    raw = pd.read_csv(path)
    # Explicit stop instead of `[:-skip_tail]`: `[:-0]` would yield an empty
    # frame.  `.copy()` makes the column assignments below unambiguous.
    info = raw.iloc[: max(len(raw) - skip_tail, 0)].copy()

    qx = info['qx_zu']
    # Rolling stats cover positions i-window+1 .. i; shift(1) moves them so
    # row i sees only the `window` rows before it.  min_periods=1 mirrors the
    # NaN-skipping behaviour of Series.mean()/Series.std().
    rolling = qx.rolling(window, min_periods=1)
    threshold = (rolling.mean() - n_std * rolling.std(ddof=0)).shift(1)

    warm_up = np.arange(len(info)) < window  # rows without a full look-back
    cha = threshold.to_numpy(dtype=float, copy=True)
    cha[warm_up] = 0.0
    info['cha'] = cha

    # NaN thresholds compare False, so such rows are simply not flagged.
    flagged = (qx.to_numpy(dtype=float) < cha) & ~warm_up
    info['jieguo'] = np.where(flagged, info['close'].to_numpy(), 0)

    info['date'] = pd.to_datetime(info['date'], format='%Y/%m/%d')
    return info
def show_info(info, save_path=None):
    """Plot the closing price and overlay the warning signals as red bars.

    ``info['close']`` is drawn against ``info['date']`` on the left axis and
    ``info['jieguo']`` (non-zero on flagged rows, as returned by ``read_db``)
    is drawn as bars on a twin right axis.  The figure is then shown.

    :param info: DataFrame with ``date``, ``close`` and ``jieguo`` columns.
    :param save_path: optional output file; when given, the figure is also
        saved there before being shown.
    :return: None
    """
    # SimHei supplies the CJK glyphs used in the labels; it has no Unicode
    # minus glyph, so render negative tick values with an ASCII '-'.
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False

    fig, ax_price = plt.subplots()
    ax_price.tick_params(axis='x', labelsize=10)  # x tick label size
    # The line needs a label, otherwise legend() has no entries to show.
    ax_price.plot(info['date'], info['close'], 'black', label='沪深300收盘价')
    ax_price.set_xlabel('时间')
    ax_price.set_ylabel('沪深300收盘价', color='black')
    ax_price.legend(loc='lower right')

    # Second y-axis for the signal bars.
    ax_signal = ax_price.twinx()
    ax_signal.bar(info['date'], info['jieguo'], label='预警信号', color='r')
    # Hide only the ticks: axis('off') would also hide the ylabel below.
    ax_signal.set_yticks([])
    ax_signal.set_ylabel('预警信号', color='r')
    ax_signal.legend(loc='upper right')

    if save_path is not None:
        fig.savefig(save_path)
    plt.show()
if __name__ == '__main__':
    # Script entry point: load h300_close.csv, compute the signals, plot them.
    frame = read_db('h300_close.csv')
    show_info(info=frame)