강화학습을 이용한 비트코인 매매프로그램(9) - CNN+RNN 모델
현재 강화학습 모델은 과거 데이터로 히스토리를 만든 다음 주가차트와 보조지표가 다 있는 이미지로 env의 상태 tensor를 만들어 프로세싱한 모델이 강제로 3가지로 라벨링을 하기 때문에 비슷한 결과에도 어떤때는 보유 또는 관망이 되고 어떤때는 매매가 되므로 optimizer가 하이퍼 파라메트를 결정하기가 쉽지 않게 됩니다. 그래서 loss가 충분히 소멸하지 않고 계속 널뛰기를 하는 현상이 있습니다. DQN 내부의 CNN의 outputs값은 분류 총 값이지만 실제 비지도 학습으로 분류된 값은 약 500~800개 정도로 넓은 범위를 가지고 있습니다. 이것을 최종 nn.Linear함수를 거치면서 모든 값이 소멸하고 3가지로 축소 됩니다. 만약에 차트가 조금 상승할때 팔고 조금하락할때 사고를 반복하면 수익률은 점점 가파르게 감소하여 잔고가 0이 되게 됩니다. 실제 히스토리 함수가 매매를 자주 하지 않는 이유는 작은 변동에 매매가 발생할 경우 수익도 크지 않을뿐더러 낙폭이 크지는 경우 큰 손실이 발생한 다음은 손실을 복구하지 못하는 상황이 발생하게 됩니다. 그래서 매매를 자주 하지 않게 설계되어 있습니다. 이 부분은 실제로는 수학적 손실,손익과는 논리적으로 맞지 않기 때문에 이상적인 목표를 가진 CNN의 입장에서는 모두 loss로 인식이 되게 됩니다. 이 문제를 극복하기 위해서는 CNN자체를 upgrade하는 방법과 CNN과 RNN을 병합하는 방법을 연구해야 합니다. 저도 어느게 맞다고 말할 수 없습니다. 아직도 강화학습은 연구 대상이지 실용단계로 보기는 힘듭니다. RNN이란 말을 들어시고 뭔가 이상함을 느꼈다면 제가 앞에서 이야기한 인간이 개입한 데이터는 시계열 데이터로 보기 힘들다는 이야기 때문일 겁니다. 그러나 RNN은 주가분석같은 시계열 데이터 분석뿐만 아니라 자연어 학습에도 사용됩니다. 우리가 연구 중인 주식이나 코인의 매매 차트를 분류하면 조금 전에 말씀드린것처럼 약 800개의 패턴이 발생합니다. 그 이상을 넘어 가지는 않습니다. RNN 자연어 분석 예제를 보시면 입력값이 입력 글짜 수(여기에 차트의 비지도 학습 분류 패턴 수를 입력),hidden layer 갯수, 출력 글짜 수(여기에 매매프로그램이 사용할 행동갯수를 입력) 입니다. 이 부분을 잘 응용하면 CNN의 출력수(비지도 분류값 약 800개)를 RNN 의 입력값으로 hidden값은 init_hidden함수로 자동 계산되고 출력값은 우리가 구하려는 3가지(매도,매수,관망)일겁니다. 여기에 임베딩(GRU를 사용할 경우 encoding을 사용하여 자연어 분석시 학습률을 높일 수 있음)을 더하면 조금 더 복잡해지겠지만 단순히 차트만으로 분류하는게 아닌 순서적 흐름을 이용한 분류가 가능해 질것 입니다. 대부분의 값은 관망값을 가져야 합니다. 그렇다면 히든값은 무엇을 의미할까요? RNN에 들어갈 첫번째 값이 있다면 이 값은 시간적인 값t0의 값을 가질겁니다. 히든값은 그 다음에 올 수 있는 수 많은 가능성의 t1값의 tensor일 것입니다. 두 값의 연산에 의해서 t1이 결정되면 hidden은 가능성의값 t2의 텐서가 될것이고 그로 인해 실제 t2가 결정될것 입니다. 학습이 진행되면 정확도는 증가할 것으로 예상됩니다.
import torch
import torch.nn as nn
import torch.nn.functional as F
class CNNRNN(nn.Module):
def __init__(self, device, h, w, outputs, hdnsize):
super(CNNRNN, self).__init__()
self.device = device
self.hidden_size = hdnsize
self.conv1 = nn.Conv2d(4, h*w, kernel_size=5, stride=2)
self.bn1 = nn.BatchNorm2d(h*w)
self.conv2 = nn.Conv2d(h*w, hdnsize, kernel_size=5, stride=2)
self.bn2 = nn.BatchNorm2d(hdnsize)
self.conv3 = nn.Conv2d(hdnsize, hdnsize, kernel_size=5, stride=2)
self.bn3 = nn.BatchNorm2d(hdnsize)
self.i2h = nn.Linear(54 * hdnsize, hdnsize)
self.h2h = nn.Linear(hdnsize, hdnsize)
self.i2o = nn.Linear(hdnsize, outputs)
self.act_fn = nn.Tanh()
def init_hidden(self):
return torch.zeros(1, self.hidden_size)
def forward(self, x):
x = x.to(self.device)
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
hidden = self.init_hidden().to(self.device)
x= self.i2h(x.view(x.size(0), -1))
hidden = self.h2h(hidden)
hidden = self.act_fn(x + hidden)
return self.i2o(hidden)
cnnrnn.py 입니다.
import random
from collections import deque, namedtuple
from IPython.display import display, Math
from Account import Account
import math
from itertools import count
import os
import sys
import os.path as path
import plotly as plt
import torch
import torch.nn as nn
import torch.optim as optim
from Market import Market
import torchvision
import time
from cnnrnn import CNNRNN
from memory import ReplayMemory, Experience
action_kind = 3
max_episode = 5000
screen_height = 100
screen_width = 140
data_size = 250
visit_cnt = [0] * action_kind
# replay_buffer = deque()
epsilon = 0.3
dis = 0.9
BATCH_SIZE = 8
# GAMMA = 0.999
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200
TARGET_UPDATE = 10
steps_done = 0
loss = any
WINDOW_START = 0
WINDOW_SIZE = 1500
# for i in range(torch.cuda.device_count()):
# print(torch.cuda.get_device_name(i))
device_str = "cuda"
device = torch.device(device_str)
memory = ReplayMemory(BATCH_SIZE)
converter = torchvision.transforms.ToTensor()
market = Market()
train_net = CNNRNN(device, screen_height, screen_width, action_kind, 32).to(device)
train_net = nn.DataParallel(train_net, device_ids=[0,1]).to(device)
episode_durations = []
optimizer = optim.RMSprop(train_net.parameters())
def optimize_action(memory):
if len(memory) < BATCH_SIZE:
return None
epsode = memory.pop(BATCH_SIZE)
batch = Experience(*zip(*epsode))
state_batch = torch.cat(batch.state)
action_batch = torch.cat(batch.action)
train_net.train()
state_action_values = train_net(state_batch).gather(1, action_batch)
criterion = nn.SmoothL1Loss()
loss = criterion(state_action_values, action_batch)
# Optimize the model
optimizer.zero_grad()
loss.backward()
for param in train_net.parameters():
param.grad.data.clamp_(-1, 1)
optimizer.step()
return loss
def select_action(df, idx):
try:
global steps_done
sample = random.random()
eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
steps_done += 1
if sample > eps_threshold:
if df.loc[idx, "closemax"] == df.loc[idx, "close"]:
action = 2
return action
elif df.loc[idx, "closemin"] == df.loc[idx, "close"]:
action = 1
return action
else:
return 0
else:
action = random.randrange(action_kind)
return action
except Exception as ex:
exc_type, exc_obj, exc_tb = sys.exc_info()
fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
print("`select_action -> exception! %s : %s %d" % (str(ex) , fname, exc_tb.tb_lineno))
return 0
def get_chart(market, idx, max_data):
img = market.get_chart(idx, max_data=max_data)
if img is None:
return None
# img = Image.fromarray(np.uint8(cm.gist_earth(plt.io.to_image(fig, format='png')*255)))
# im = Image.fromarray(img, bytes=True)
# im = Image.fromarray(np.uint8(cm.gist_earth(img))/255)
# im = Image.fromarray(np.uint8(img)/255)
# img = img.resize((700, 500), resample=Image.BICUBIC)
# img = Image.fromarray(cm.gist_earth(plt.io.to_image(fig, format='png'), bytes=True))
# display(img)
chart = converter(img).unsqueeze(0).to(device)
return chart
def plot_durations(last_chart, curr_chart):
plt.figure()
# plt.subplot(1,2,1)
img = plt.imshow(last_chart.cpu().squeeze(0).permute(1, 2, 0).numpy(), interpolation='none')
plt.title('Example extracted screen')
plt.figure(2)
# plt.subplot(1,2,2)
plt.clf()
durations_t = torch.tensor(episode_durations, dtype=torch.float)
plt.title('Training...')
plt.xlabel('Episode')
plt.ylabel('Duration')
plt.plot(durations_t.numpy())
# plt.show()
# Take 100 episode averages and plot them too
if len(durations_t) >= 100:
means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
means = torch.cat((torch.zeros(99), means))
plt.plot(means.numpy())
img.set_data(curr_chart.cpu().squeeze(0).permute(1, 2, 0).numpy())
plt.pause(0.01) # pause a bit so that plots are updated
display.clear_output(wait=True)
display.display(plt.gcf())
def main():
if path.exists("pt/train_cnnrnn_{}.pt".format(device_str)):
train_net.load_state_dict(torch.load("pt/train_cnnrnn_{}.pt".format(device_str)))
for _ in range(10):
df = market.get_data()
# df = df.head(WINDOW_SIZE)
for epoch in range(max_episode):
account = Account(df, 50000000)
account.reset()
last_chart = get_chart(market, data_size, data_size)
account.reset()
for idx,_ in enumerate(df.index, start=(data_size + 1)):
try:
since = time.time()
curr_chart = get_chart(market, idx, data_size)
if curr_chart is not None:
state = curr_chart - last_chart
else:
continue
last_chart = curr_chart
reward = 0
num_action = select_action(df, idx)
reward, real_action = account.exec_action(num_action, idx)
print("idx:%d==>action:%d, price:%.2f"%(idx, num_action, df.loc[idx, 'close']))
reward = torch.tensor([reward], device=device)
action = torch.tensor([[num_action]], device=device, dtype=torch.int64)
memory.push(state, action, reward)
while len(memory) >= BATCH_SIZE:
loss = optimize_action(memory)
if loss is not None:
print("epoch[%d:%d] epsode is next loss[%.10f]" % (epoch, idx, loss.item()))
if idx % TARGET_UPDATE == 0:
torch.save(train_net.state_dict(),"pt/train_cnnrnn_{}.pt".format(device_str))
spend = time.time() - since
print("idx:%d price [%.4f] unit[%.4f] used time[%.2f] agent rate:%.05f remind money:%.02f"
% (idx, df.loc[idx, 'close'], account.unit, spend, account.rate, account.balance + account.unit * df.loc[idx, 'close']))
if account.is_bankrupt():
break
if idx == df.index.max():
break
except Exception as ex:
exc_type, exc_obj, exc_tb = sys.exc_info()
fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
print("`recall_training -> exception! %s : %s %d" % (str(ex) , fname, exc_tb.tb_lineno))
print("end training DQN")
print('Complete Training')
if __name__ == "__main__":
main()
main_cnnrnn.py 입니다.
CNN 단독 모델과 CNN + RNN 모델의 학습 증가율을 비교해 보시기 바랍니다.
대표이미지출처:https://blog.kakaocdn.net/dna/IFp6A/btqASygMYCA/AAAAAAAAAAAAAAAAAAAAAHPRKM8z8_i1rJvLfrqQ8vMKOITsjcC7Q0zi2JxkpFQQ/img.png?credential=yqXZFxpELC7KVnFOS48ylbz2pIh7yKj8&expires=1767193199&allow_ip=&allow_referer=&signature=%2B1IRSqP3ty8OgHAVFr56dwkhsAQ%3D
RNN 소스코드 출처 : 책 파이토치 첫걸음에서 발췌 및 응용
