diff --git a/Plane.py b/Plane.py deleted file mode 100644 index f613eb57db770335998e7dfd9eb59d4e1b1b2dad..0000000000000000000000000000000000000000 --- a/Plane.py +++ /dev/null @@ -1,66 +0,0 @@ -# coding: utf-8 -# 飞机的基类和敌方飞机 玩家飞机(缺乏玩家控制移动部分) - -import time -import random -import pygame -# 飞机基类的属性:基类+生命+飞机类型 方法:出现在窗口 被打中减少hp - -class Base(object): # 基类 - def __init__(self, screen_temp, x, y, image_name): - self.x = x - self.y = y - self.screen = screen_temp - self.image = pygame.image.load(image_name) - -class BasePlane(Base): # 飞机基类 - def __init__(self, screen_temp, x, y, image_name, plane_hp, plane_type_num, ): - Base.__init__(self, screen_temp, x, y,image_name) - self.bullet_list = [] - self.hp = plane_hp - self.plane_type = plane_type_num - - def display(self): # 加载飞机的图像和位置 待update后出现 - self.screen.blit(self.image, (self.x, self.y)) - - def isHitted(self, plane, width, height): # TODO:飞机和子弹“相遇”的判断函数 - {} - self.hp -= 1 - -class EnemyPlane(BasePlane): - def __init__(self, screen_temp, image_name, enemy_type_num, ): - random_x = random.randint(150, self.screen.get_width()-150) # 根据屏幕大小适当调整,这个是敌机的起始位置 - random_y = random.randint(-50, -20) - if enemy_type_num == 3: # 一类敌机最强 血条3, 三个类型血条和移动速度不同 - enemy_hp = 3 - elif enemy_type_num == 2: - enemy_hp = 2 - elif enemy_type_num == 1: - enemy_hp = 1 - else: - print("敌机目前只有三个类型") - BasePlane.__init__(self, screen_temp, random_x, random_y, image_name, enemy_hp, enemy_type_num) - - # 敌机的移动逻辑:在hp不为0的前提下,随着时间每过1s移动一次,遇到边界就换方向 - # TODO: 时间判断刷新次数较长 暂未想其他方法 待更改 - def enemy_move(self): - move_step = {"1": 50, "2": 80, "3": 120} # 不同类型飞机一秒移动步数不一样 - move_direction = {"0": -1, "1":1} # 随机判定起始的移动方向 结果为1/-1 - random_direction = move_direction[str(random.randint(0, 1))] - while self.hp != 0: # 判断有命才能动 - time_start = int(time.time()) # 获取个时间 暂定一秒一次 一次步数根据档次不同 - if time_start + 1 == int(time.time()): - self.x += move_step[str(self.plane_type)] * random_direction # 步数*1/-1 控制向左还是右 - self.y -= 50 - if self.x < 20 or self.x > self.screen.get_width() - 20: # 敌机移动到边界,就换方向 - random_direction *= -1 - - -class PlayerPlane(BasePlane): - def __init__(self, screen_temp, image_name, myplane_type_num): - x = self.screen.get_width()/2 # 默认玩家出现在最下面最中间 - y = self.screen.get_height() - BasePlane.__init__(self, screen_temp,x, y,image_name,3, myplane_type_num) - - def move(self): - {} # TODO: 玩家的移动控制 \ No newline at end of file diff --git a/Planewar.py b/Planewar.py index 1599c54e04262c7d6719a5adba434dc2bb87f6c4..02fad3f7827b6c07c7a6817fa424604388681bc5 100644 --- a/Planewar.py +++ b/Planewar.py @@ -8,9 +8,14 @@ Created on Tue Nov 16 21:31:30 2021 import pygame import time import random +import pickle - - +width = 901 +height = 897 +score = 0 +enemy_list = [] +bullet_list = [] +supply_list = [] class Base(object): """所有类的基类""" @@ -18,63 +23,112 @@ class Base(object): self.x = x self.y = y self.screen = screen_temp - self.image = pygame.image.load(image_name) + self.image = pygame.image.load(image_name).convert_alpha() + self.w = self.image.get_rect().size[0] + self.h = self.image.get_rect().size[1] class BasePlane(Base): """飞机基类""" - def __init__(self, plane_type, screen_temp, x, y, image_name, picture_num, HP_temp): + global bullet_list + global enemy_list + global width, height + def __init__(self, plane_type, screen_temp, x, y, image_name, picture_num, HP_temp, authority): Base.__init__(self, screen_temp, x, y, image_name)#plane_type飞机类型 self.bullet_list = [] #存储发射出去的子弹的引用 self.plane_type = plane_type #飞机类型标示 - self.HP = HP_temp #飞机hp - self.fire_bullet_count = 0#飞机已发射子弹计数 + self.hp = HP_temp #飞机hp + self.authority = authority + self.time = 0 + #self.fire_bullet_count = 0#飞机已发射子弹计数 def display(self): - """显示飞机""" - self.screen.blit(self.image, (self.x, self.y)) - - - #判断是否被击中 - def isHitted(self, plane, width, height): - {}# widht和height表示范围 + global score + if self.hp > 0: + self.screen.blit(self.image, (self.x, self.y)) + self.hit() + else: + if self.time < 100: + explode_image = pygame.image.load('image/explode.png').convert_alpha() + self.screen.blit(self.image, (self.x, self.y)) + self.screen.blit(explode_image, (self.x, self.y)) + self.time += 1 + else: + if self.authority == 'enemy' and self.hp <= 0: + enemy_list.remove(self) + score += 1 + for item in self.bullet_list: + if not item.judge(): + self.bullet_list.remove(item) + + def hit(self): + if bullet_list and self.hp: + for bullet in bullet_list: + if bullet.authority != self.authority: + if bullet.x < self.x + self.w and bullet.x > self.x and bullet.y < self.y + self.h and bullet.y > self.y: + self.hp -= bullet.damage + bullet.damage = 0 + bullet.visible = False + + + def pos_check(self): + if self.x < 0: + self.x = self.w + elif self.x > width: + self.x = 0 + if self.authority == 'player': + if self.y < 0: + self.y = 0 + elif self.y > height - self.h: + self.y = height - self.h + if self.authority == 'enemy' and self.y > height: + enemy_list.remove(self) class EnemyPlane(BasePlane): def __init__(self, screen_temp): random_num_x = random.randint(12, 418) random_num_y = random.randint(-50, -40) self.direction = "left" - BasePlane.__init__(self, 0, screen_temp, random_num_x, random_num_y, "image/plane/enemy.png", 4, 20) - + BasePlane.__init__(self, 0, screen_temp, random_num_x, random_num_y, "image/plane/enemy.png", 4, 1, "enemy") + def move(self): - if self.direction == "right": - self.x += 5 - elif self.direction == "left": - self.x -= 5 - if self.x > 430: - self.direction = "left" - elif self.x < 0: - self.direction = "right" - - + self.pos_check() + self.fire() + if self.hp > 0: + if self.direction == 'right': + self.x += 1 + elif self.direction == 'left': + self.x -= 1 + if self.x > 430: + self.direction = 'left' + elif self.x < 0: + self.direction = 'right' + self.y += 3 + + def fire(self): + if len(self.bullet_list) != 0: + if self.y - 14 - 80 > self.bullet_list[-1].y: + self.bullet_list.append(Bullet(self.screen, self.x, self.y+60, 'enemy')) + else: + self.bullet_list.append(Bullet(self.screen, self.x, self.y+60, 'enemy')) class PlayerPlane(BasePlane): global supply_size def __init__(self, screen_temp, player_no): - BasePlane.__init__(self, 3, screen_temp, 210, 728, "image/plane/plane.png", 4, 10) #super().__init__() + #x = self.screen.get_width()/2 # 默认玩家出现在最下面最中间 + #y = self.screen.get_height() + x = 400 + y = 200 + BasePlane.__init__(self, 3, screen_temp, x, y, "image/plane/plane.png", 4, 10, "player") #super().__init__() self.player_no = player_no + self.level = 1 + self.exp = 0 + self.hp_amount = 100 self.move_dict = {'horizontal' : 0, 'vertical' : 0, 'space' : 0} - self.bullet_list = [] + - def pos_check(self): - if self.x < 0: - self.x = 830 - elif self.x > 830: - self.x = 0 - if self.y < 0: - self.y = 0 - elif self.y > 800: - self.y = 800 - + def levelup(self): + self.level += 1 + def move(self): self.pos_check() if self.move_dict['horizontal'] != 0 and self.move_dict['vertical'] != 0: @@ -98,20 +152,26 @@ class Bullet(Base): self.screen = screen_temp self.image = pygame.image.load("image/bullet/bullet.png") self.authority = authority - self.width = 9 - self.height = 21 - + self.damage = 1 + self.visible = True + if self.authority == 'enemy': + self.image = pygame.transform.rotate(self.image, 180) + def display(self): - self.screen.blit(self.image, (self.x, self.y)) - + if self.visible == True: + self.screen.blit(self.image, (self.x, self.y)) + def judge(self): - if self.y > 0 or self.y < 897: + if self.y > 0 and self.y < 897: return True else: return False def move(self): - self.y -= 10 + if self.authority == 'player': + self.y -= 8 + elif self.authority == 'enemy': + self.y += 8 class PlayerBullet(Bullet): global bullet_type @@ -119,7 +179,40 @@ class PlayerBullet(Bullet): class EnemyBullet(Bullet): global bullet_type -class button(object): +class Game(object): + def __init__(self, screen, player_list, enemy_list): + self.player_list = player_list + self.enemy_list = enemy_list + + #保存游戏 + def save_game(self): + with open('save.dat', 'wb') as f: + pickle.dump(self, f) + + #读取游戏 + def load_game(self): + with open('save.dat', 'rb+') as f: + model = pickle.load(f) + return model + +#标签 +class Label(object): + def __init__(self, screen, x, y, w, h): + self.x = x + self.y = y + self.w = w + self.h = h + self.screen = screen + + def display(self, text, size = 32): + font = pygame.font.SysFont('Times New Romans', size) #设置字体及大小 + textSurf = font.render(text, True, (0, 0, 0)) + textRect = textSurf.get_rect() + textRect.center = ((self.x + (self.w / 2)), (self.y + (self.h / 2))) + self.screen.blit(textSurf, textRect) #绘制标签 + +#按钮 +class Button(object): def __init__(self, screen, x, y, w, h, text): self.x = x self.y = y @@ -143,20 +236,41 @@ class button(object): def click(self): mouse_pos = pygame.mouse.get_pos() if mouse_pos[0] < self.x + self.w and mouse_pos[0] > self.x and mouse_pos[1] < self.y + self.h and mouse_pos[1] > self.y: - if self.text == 'Back': - pass - elif self.text == 'Restart': - pass - elif self.text == 'Settings': - pass - elif self.text == 'Restart': - pass + return True + +class Supply(Base): + def __init__(self, screen_temp, x, y, image_name, supply_type): + super().__init__(screen_temp, x, y, image_name) + + def display(self): + self.screen.blit(self.image, (self.x, self.y)) + + def move(self): + self.y += 1 +def sum_bullet(player_list, enemy_list): + global bullet_list + for item in player_list: + bullet_list.extend(item.bullet_list) + for item in enemy_list: + bullet_list.extend(item.bullet_list) + return bullet_list + +def create_enemy(screen): + global width + global enemy_list + if len(enemy_list) < 5: + enemy_list.append(EnemyPlane(screen)) + +def game_fail(): + pass + def pause(player): - button_list = [button(player.screen, 160, 200, 160, 100, 'Back'), - button(player.screen, 160, 300, 160, 100, 'Restart'), - button(player.screen, 160, 400, 160, 100, 'Settings'), - button(player.screen, 160, 500, 160, 100, 'Exit')] + button_list = [Button(player.screen, 160, 200, 160, 100, 'Back'), + Button(player.screen, 160, 300, 160, 100, 'Restart'), + Button(player.screen, 160, 400, 160, 100, 'Save'), + Button(player.screen, 160, 500, 160, 100, 'Settings'), + Button(player.screen, 160, 600, 160, 100, 'Exit')] isPause = True while isPause: for item in button_list: @@ -170,13 +284,24 @@ def pause(player): elif event.type == event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE: isPause = False break - controller = pygame.joystick.Joystick(0) - controller.init() - if controller.get_button(7): - isPause = False + controller_num = pygame.joystick.get_count() + if(controller_num != 0): + controller = pygame.joystick.Joystick(0) + controller.init() + if controller.get_button(7): + isPause = False pygame.display.update() time.sleep(0.01) - + +def ai_save_data(): + pass + +#自动游戏,ai控制移动和开火 +def ai_control(player, event): + '''model = load_model('lstm_300.h5', custom_objects={'r2': r2}) + player.move_dict['horizontal'] = model.predict(test_x)''' + pass + #键盘控制移动和开火 #玩家一:WASD+空格 #玩家二:方向键+小键盘回车键 @@ -187,7 +312,6 @@ def key_control(player, event): if player.player_no == 0: if event.key == pygame.K_a: player.move_dict['horizontal'] = -1 - print(player.move_dict) if event.key == pygame.K_d: player.move_dict['horizontal'] = 1 if event.key == pygame.K_w: @@ -210,7 +334,6 @@ def key_control(player, event): if event.key == pygame.K_ESCAPE: #按下ESC键暂停游戏 pause(player) - elif event.type == pygame.KEYUP and player: if player.player_no == 0: if event.key == pygame.K_a: @@ -234,6 +357,8 @@ def key_control(player, event): player.move_dict['vertical'] = 0 if event.key == pygame.K_KP_ENTER: player.move_dict['space'] = 0 + + #手柄控制移动和开火 #左摇杆+A(Xbox)/×(Playstation)/B(Nintendo) @@ -254,6 +379,12 @@ def joystick_control(player, controller): pause(player) def main(): + global width, height + global score + global enemy_list + global bullet_list + global supply_list + pygame.init() # 初始化pygame bg_size = width, height = 901,897 screen = pygame.display.set_mode(bg_size) # 显示窗口 @@ -263,15 +394,8 @@ def main(): #创建一个背景图片,路径需做出背景图片放入文件夹中填入路径,haven't finished background = pygame.image.load("image/bg.png").convert() - - screen.blit(background, (0,0)) - - #创建一个玩家飞机 - player = PlayerPlane(screen, player_no = 0) - - #创建一个敌机对象 - enemy = EnemyPlane(screen) + score_label = Label(screen, width - 120, 0, 120, 30) pygame.joystick.init() controller_num = pygame.joystick.get_count() if(controller_num != 0): @@ -279,32 +403,112 @@ def main(): for i in range(controller_num): controller.append(pygame.joystick.Joystick(i)) controller[i].init() + mode = 'single' + start = False + button_list = [Button(screen, width / 2 - 80, 350, 160, 100, 'New Game'), + Button(screen, width / 2 - 80, 450, 160, 100, 'Continue'), + Button(screen, width / 2 - 80, 550, 160, 100, 'Settings'), + Button(screen, width / 2 - 80, 650, 160, 100, 'Exit')] + + while not start: + operation = -1 + screen.blit(background, (0,0)) + label = Label(screen, width / 2 - 120, height / 2 - 200, 240, 120) + + label.display('Plane War', 72) + for button in button_list: + button.display() + for event in pygame.event.get(): + if event.type == pygame.QUIT: + pygame.quit() + elif event.type == pygame.MOUSEBUTTONDOWN: + if(len(button_list) == 4): + if button_list[0].click(): + button_list.clear() + button_list = [Button(screen, width / 2 - 80, 350, 160, 100, 'Single Player'), + Button(screen, width / 2 - 80, 650, 160, 100, 'Double Player')] + elif button_list[1].click(): + #game = Game.load_game() + pass + elif button_list[2].click(): + pass + elif button_list[3].click(): + pygame.quit() + else: + if button_list[0].click(): + start = True + break + elif button_list[1].click(): + mode = 'double' + start = True + break + + if operation == 0: + break + + time.sleep(0.01) + pygame.display.update() + + if mode == 'single': + player_list = [PlayerPlane(screen, player_no = 0)] + else: + player_list = [PlayerPlane(screen, player_no = 0), + PlayerPlane(screen, player_no = 1)] while True: screen.blit(background, (0,0)) - - player.display() - enemy.display() - enemy.move()#调用敌机移动 - - bullet_del_list = [] - for item in player.bullet_list: + for player in player_list: + player.display() + + create_enemy(screen) + for enemy in enemy_list: + enemy.display() + enemy.move() + bullet_list = [] + bullet_list = sum_bullet(player_list, enemy_list) + #bullet_del_list = [] + for item in bullet_list: if item.judge(): item.display() item.move() - else: + ''' else: bullet_del_list.append(item) for item in bullet_del_list: - player.bullet_list.remove(item) + bullet_list.remove(item)''' for event in pygame.event.get(): if event.type == pygame.QUIT: pygame.quit() - key_control(player, event) + for player in player_list: + key_control(player, event) + + if player_list[0].hp <= 0: + player_list.clear() + enemy_list.clear() + pygame.display.update() + label1 = Label(screen, width / 2 - 120, height / 2 - 200, 240, 120) + label2 = Label(screen, width / 2 - 60, height / 2 - 30, 120, 60) + button1 = Button(screen, width / 2 - 80, height / 2 + 100, 160, 100, 'Restart') + button2 = Button(screen, width / 2 - 80, height / 2 + 250, 160, 100, 'Back') + while True: + label1.display('Game Over!', 72) + label2.display('Your Score:' + str(score)) + button1.display() + button2.display() + for event in pygame.event.get(): + if event.type == pygame.QUIT: + pygame.quit() + pygame.display.update() + time.sleep(0.01) + if(controller_num != 0): - joystick_control(player, controller[0]) - player.move() - pygame.display.update() + joystick_control(player_list[0], controller[0]) + + for player in player_list: + player.move() + + score_label.display('Score:' + str(score)) time.sleep(0.01) + pygame.display.update() if __name__ == "__main__": main() diff --git a/cnn.py b/cnn.py new file mode 100644 index 0000000000000000000000000000000000000000..e82c64ec6fedfd8c41c5673f462a2b281fd6bffc --- /dev/null +++ b/cnn.py @@ -0,0 +1,402 @@ +# -*- coding: utf-8 -*- +""" +Created on Sat Dec 4 20:50:49 2021 + +@author: admin +""" + +from keras.layers import LSTM, Dense +from keras.models import Sequential +from keras.models import load_model +import keras.backend as K +import matplotlib.pyplot as plt +from sklearn.metrics import r2_score +import pandas as pd +import numpy as np +import os + + +def cnn_load_data(stock_code): + data = pd.read_excel('data.xlsx') + train_data = [] + train_label = [] + for index,row in data.iterrows(): + train_cells = [] + train_cells.append(row['horizontal']) + train_cells.append(row['vertical']) + train_cells.append(row['space']) + train_cells = np.array(train_cells, dtype='float') + train_data.append(train_cells) + train_label.append(row['close']) + train_data = train_data[:len(train_data)-1] + train_label = train_label[1:] + length = len(train_data) + length = round(length * 0.8) + test_data = train_data[length:] + test_label = train_label[length:] + train_data = train_data[:length] + train_label = train_label[:length] + train_data = np.array(train_data, dtype='float') + train_label = np.array(train_label, dtype='float') + test_data = np.array(test_data, dtype='float') + test_label = np.array(test_label, dtype='float') + #label = to_categorical(label,num_classes=class_num) + return train_data,train_label,test_data,test_label + +def lstm_load_data(stock_code): + '''file = Path('./tdxstocks/day/' + finance.get_tdx_type(stock_code) + '.xlsx') + if not file.exists(): + tdxdata.get_tdx_day(finance.get_tdx_type(stock_code) + '.day') + data = pd.read_excel(file, index_col = 'date')''' + data = ts.get_hist_data(stock_code) + mf_data = trendline.mainforce_monitor_ml(data) + gs_data = trendline.golden_snipe_ml(data) + data = data[['open', 'high', 'close', 'low', 'ma5', 'ma10', 'ma20']] + data = data.iloc[:len(data) - 100] + data = pd.merge(data, mf_data, on='date') + data = pd.merge(data, gs_data, on='date') + data = data.iloc[::-1] + train_data = [] + train_label = [] + for index,row in data.iterrows(): + train_cells = [] + train_cells.append(row['horizontal']) + train_cells.append(row['vertical']) + train_cells.append(row['space']) + train_cells = np.array(train_cells, dtype='float') + train_data.append(train_cells) + train_label.append(row['close']) + train_data = train_data[:len(train_data)-1] + train_label = train_label[1:] + length = len(train_data) + length = round(length * 0.8) + test_data = train_data[length:] + test_label = train_label[length:] + train_data = train_data[:length] + train_label = train_label[:length] + train_data = np.array(train_data, dtype='float') + train_label = np.array(train_label, dtype='float') + test_data = np.array(test_data, dtype='float') + test_label = np.array(test_label, dtype='float') + train_data = train_data.reshape((train_data.shape[0], 1, train_data.shape[1])) + test_data = test_data.reshape((test_data.shape[0], 1, test_data.shape[1])) + #label = to_categorical(label,num_classes=class_num) + return train_data,train_label,test_data,test_label + +def r2(y_true, y_pred): + a = K.square(y_pred - y_true) + b = K.sum(a) + c = K.mean(y_true) + d = K.square(y_true - c) + e = K.sum(d) + f = 1 - b/e + return f + +class CNN: + def neural_model(): + #input_shape = (bin_n * 4, 1) + model = Sequential() + model.add(Dense(64, activation='relu', input_dim=17)) + model.add(Dense(64, activation='relu')) + model.add(Dense(64, activation='relu')) + model.add(Dense(64, activation='relu')) + model.add(Dense(64, activation='relu')) + model.add(Dense(64, activation='relu')) + model.add(Dense(64, activation='relu')) + model.add(Dense(64, activation='relu')) + model.add(Dense(64, activation='relu')) + model.add(Dense(64, activation='relu')) + model.add(Dense(1)) + return model + +def cnn_train(stock_code): + batch_size = 32 + epochs = 1000 + model = CNN.neural_model() + train_x, train_y, test_x, test_y = cnn_load_data(stock_code) + + model.compile(loss='mse', optimizer='rmsprop', metrics=['mae',r2]) + model.fit(train_x, train_y, batch_size=batch_size, epochs=epochs, validation_split=0.2) + + pred_test_y = model.predict(test_x) + #print(pred_test_y) + + pred_acc = r2_score(test_y, pred_test_y) + print('pred_acc', pred_acc) + + plt.rcParams['font.sans-serif'] = ['SimHei'] + plt.rcParams['axes.unicode_minus'] = False + + plt.figure(figsize=(8, 4), dpi=80) + plt.plot(range(len(test_y)), test_y, ls='-.',lw=2,c='r',label='真实值') + plt.plot(range(len(pred_test_y)), pred_test_y, ls='-',lw=2,c='b',label='预测值') + + plt.grid(alpha=0.4, linestyle=':') + plt.legend() + plt.xlabel('number') + plt.ylabel('股价') + + plt.show() + +def cnn_predict(stock_code, length): + '''file = Path('./tdxstocks/day/' + finance.get_tdx_type(stock_code) + '.xlsx') + if not file.exists(): + tdxdata.get_tdx_day(finance.get_tdx_type(stock_code) + '.day') + data = pd.read_excel(file, index_col = 'date')''' + data = ts.get_hist_data(stock_code) + mf_data = trendline.mainforce_monitor_ml(data) + gs_data = trendline.golden_snipe_ml(data) + data = data[['open', 'high', 'close', 'low', 'ma5', 'ma10', 'ma20']] + data = data.iloc[:len(data) - 100] + data = pd.merge(data, mf_data, on='date') + data = pd.merge(data, gs_data, on='date') + data = data.iloc[::-1] + traindata = [] + label = [] + for index,row in data.iterrows(): + train_cells = [] + train_cells.append(row['horizontal']) + train_cells.append(row['vertical']) + train_cells.append(row['space']) + train_cells = np.array(train_cells, dtype='float') + traindata.append(train_cells) + label.append(row['close']) + traindata = traindata[:len(traindata)-1] + label = label[1:] + traindata = np.array(traindata, dtype='float') + #traindata = np.expand_dims(traindata, axis=2) + label = np.array(label, dtype='float') + #label = to_categorical(label,num_classes=class_num) + +class ml_LSTM: + def neural_model(): + model = Sequential() + model.add(LSTM(50, return_sequences=True, input_shape=(1, 18))) + model.add(Dense(64, activation='relu')) + model.add(Dense(64, activation='relu')) + model.add(Dense(64, activation='relu')) + model.add(Dense(64, activation='relu')) + model.add(Dense(64, activation='relu')) + model.add(LSTM(50, return_sequences=True)) + model.add(Dense(64, activation='relu')) + model.add(Dense(64, activation='relu')) + model.add(Dense(64, activation='relu')) + model.add(Dense(64, activation='relu')) + model.add(Dense(64, activation='relu')) + model.add(LSTM(50)) + model.add(Dense(1)) + return model + +def lstm_train(stock_code): + batch_size = 8 + epochs = 1000 + train_x, train_y, test_x, test_y = lstm_load_data(stock_code) + model = ml_LSTM.neural_model(train_x) + #model.compile(loss='mae', optimizer='adam') + model.compile(loss='mse', optimizer='rmsprop', metrics=['mae',r2]) + #model.fit(train_x, train_y, batch_size=batch_size, epochs=epochs, validation_split=0.2) + history = model.fit(train_x, train_y, batch_size=batch_size, epochs=epochs, validation_split=0.2) + + plt.figure() + plt.plot(history.history['loss'], label='train') + plt.plot(history.history['val_loss'], label='test') + plt.legend() + #plt.show() + + pred_test_y = model.predict(test_x) + #print(pred_test_y) + + pred_acc = r2_score(test_y, pred_test_y) + print(test_y) + print(pred_test_y) + print('pred_acc', pred_acc) + + plt.rcParams['font.sans-serif'] = ['SimHei'] + plt.rcParams['axes.unicode_minus'] = False + + plt.figure(figsize=(8, 4), dpi=80) + plt.plot(range(len(test_y)), test_y, ls='-.',lw=2,c='r',label='真实值') + plt.plot(range(len(pred_test_y)), pred_test_y, ls='-',lw=2,c='b',label='预测值') + + plt.grid(alpha=0.4, linestyle=':') + plt.legend() + plt.xlabel('number') + plt.ylabel('股价') + + plt.show() + +def lstm_train_all(): + listfile = os.listdir('./tdxstocks/day/') + batch_size = 32 + epochs = 1 + #model = ml_LSTM.neural_model() + #model.compile(loss='mse', optimizer='rmsprop', metrics=['mae',r2]) + model = load_model('lstm_model.h5', custom_objects={'r2': r2}) + for stock_code in listfile: + data = pd.read_excel('./tdxstocks/day/' + stock_code, index_col = 'date') + if len(data) < 250: + continue + data = data[:600] + + data = data.iloc[::-1] + train_data = [] + train_label = [] + for index,row in data.iterrows(): + train_cells = [] + train_cells.append(row['horizontal']) + train_cells.append(row['vertical']) + train_cells.append(row['space']) + train_cells = np.array(train_cells, dtype='float') + train_data.append(train_cells) + train_label.append(row['close']) + train_data = train_data[:len(train_data)-1] + train_label = train_label[1:] + train_data = np.array(train_data, dtype='float') + train_label = np.array(train_label, dtype='float') + train_data = train_data.reshape((train_data.shape[0], 1, train_data.shape[1])) + #label = to_categorical(label,num_classes=class_num) + + #model.compile(loss='mae', optimizer='adam') + + model.fit(train_data, train_label, batch_size=batch_size, epochs=epochs) + + stock_code = stock_code[2:8] + print(stock_code + ' finished') + model.save('lstm_model.h5') + +def lstm_train_300(): + df = pd.read_excel('./tdxstocks/hs300.xlsx', index_col = 0, converters = {'code':str}) + batch_size = 32 + epochs = 1 + model = ml_LSTM.neural_model() + model.compile(loss='mse', optimizer='rmsprop', metrics=['mae',r2]) + #model = load_model('lstm_model.h5', custom_objects={'r2': r2}) + for iters in range(0,30): + for index,row in df.iterrows(): + stock_code = finance.get_tdx_type(row[0]) + data = pd.read_excel('./tdxstocks/day/' + stock_code + '.xlsx', index_col = 'date') + if len(data) < 250: + continue + data = data[:600] + mf_data = trendline.mainforce_monitor_ml(data) + gs_data = trendline.golden_snipe_ml(data) + ma5_data = trendline.ma(data, 5) + ma10_data = trendline.ma(data, 10) + ma20_data = trendline.ma(data, 20) + ma30_data = trendline.ma(data, 30) + data = data[['open', 'high', 'close', 'low']] + data = data.iloc[:len(data) - 100] + data = pd.merge(data, ma5_data, on='date') + data = pd.merge(data, ma10_data, on='date') + data = pd.merge(data, ma20_data, on='date') + data = pd.merge(data, ma30_data, on='date') + data = pd.merge(data, mf_data, on='date') + data = pd.merge(data, gs_data, on='date') + data = data.iloc[::-1] + train_data = [] + train_label = [] + for index,row in data.iterrows(): + train_cells = [] + train_cells.append(row['horizontal']) + train_cells.append(row['vertical']) + train_cells.append(row['space']) + train_cells = np.array(train_cells, dtype='float') + train_data.append(train_cells) + train_label.append(row['close']) + train_data = train_data[:len(train_data)-1] + train_label = train_label[1:] + train_data = np.array(train_data, dtype='float') + train_label = np.array(train_label, dtype='float') + train_data = train_data.reshape((train_data.shape[0], 1, train_data.shape[1])) + #label = to_categorical(label,num_classes=class_num) + + #model.compile(loss='mae', optimizer='adam') + + model.fit(train_data, train_label, batch_size=batch_size, epochs=epochs) + + stock_code = stock_code[2:8] + print(stock_code + ' iter' + str(iters) + ' finished') + model.save('lstm_300.h5') + +def lstm_train_500(): + df = pd.read_excel('./tdxstocks/zz500.xlsx', index_col = 0, converters = {'code':str}) + batch_size = 32 + epochs = 1 + model = ml_LSTM.neural_model() + model.compile(loss='mse', optimizer='rmsprop', metrics=['mae',r2]) + #model = load_model('lstm_model.h5', custom_objects={'r2': r2}) + for iters in range(0,30): + for index,row in df.iterrows(): + data = data.iloc[::-1] + train_data = [] + train_label = [] + for index,row in data.iterrows(): + train_cells = [] + train_cells.append(row['horizontal']) + train_cells.append(row['vertical']) + train_cells.append(row['space']) + train_cells = np.array(train_cells, dtype='float') + train_data.append(train_cells) + train_label.append(row['close']) + train_data = train_data[:len(train_data)-1] + train_label = train_label[1:] + train_data = np.array(train_data, dtype='float') + train_label = np.array(train_label, dtype='float') + train_data = train_data.reshape((train_data.shape[0], 1, train_data.shape[1])) + #label = to_categorical(label,num_classes=class_num) + + #model.compile(loss='mae', optimizer='adam') + + model.fit(train_data, train_label, batch_size=batch_size, epochs=epochs) + + stock_code = stock_code[2:8] + print(stock_code + ' iter' + str(iters) + ' finished') + model.save('lstm_500.h5') + +def lstm_predict(stock_code): + train_data = [] + train_label = [] + for index,row in data.iterrows(): + train_cells = [] + train_cells.append(row['horizontal']) + train_cells.append(row['vertical']) + train_cells.append(row['space']) + train_cells = np.array(train_cells, dtype='float') + train_data.append(train_cells) + train_label.append(row['close']) + train_data = train_data[:len(train_data)-1] + train_label = train_label[1:] + test_x = np.array(train_data, dtype='float') + test_y = np.array(train_label, dtype='float') + test_x = test_x.reshape((test_x.shape[0], 1, test_x.shape[1])) + model = load_model('lstm_300.h5', custom_objects={'r2': r2}) + + pred_test_y = model.predict(test_x) + #print(pred_test_y) + + pred_acc = r2_score(test_y, pred_test_y) + print(pred_acc) + + plt.rcParams['font.sans-serif'] = ['SimHei'] + plt.rcParams['axes.unicode_minus'] = False + + plt.figure(figsize=(8, 4), dpi=80) + plt.plot(range(len(test_y)), test_y, ls='-.',lw=2,c='r',label='真实值') + plt.plot(range(len(pred_test_y)), pred_test_y, ls='-',lw=2,c='b',label='预测值') + + plt.grid(alpha=0.4, linestyle=':') + plt.legend() + plt.xlabel('number') + plt.ylabel('股价') + + plt.show() + + +if __name__ == '__main__': + #lstm_train('002221') + #cnn_train('002221') + #predict_nextday('002221') + #load_data('002221') + #lstm_train_all() + #lstm_train_300() + lstm_predict('002221') \ No newline at end of file diff --git a/image/explode.png b/image/explode.png new file mode 100644 index 0000000000000000000000000000000000000000..1ea338e054611b3c3288cb969c4681f1cc04c9cf Binary files /dev/null and b/image/explode.png differ