지금은 코인을 중심으로 프로그램을 작성하지만 주식에도 그대로 적용이 가능합니다. 코인 거래소는 upbit를 이용하고 있습니다. 지금은 코인원이 수수료가 가장 싸므로 api를 조금 수정하시면 코인원식으로 변경이 가능 할 겁니다. 강화학습에서 중요한 포인트는 env와 agent입니다. 코인에서는 env는 마켓으로 agent는 account라고 저는 명명을 했습니다. 꼭 그렇게 않해도 됩니다. 여기에 게시되는 소스는 아무런 라이센스 없이 사용하셔도 됩니다. 우리가 사용할 DQN에 대하여 간단하게 설명 드리겠습니다. 그전에 참고할 만한 책을 한권 소개 드립니다. '바닥부터 배우는 강화학습'(https://www.googleadservices.com/pagead/aclk?sa=L&ai=DChcSEwj6gIbC29D7AhVy3EwCHXJfAU4YABAKGgJ0bQ&ohost=www.google.com&cid=CAESbeD2J1z4qWevI6iWwrPeahjlRAd5PHQxq3vxCnxSB-XFcQcQlPCdbI5vS31FJsWDEdYDXr7wQRya94UUUP1TeQ_sv2Gp1efEdpeYMlECKaQ1jwRuNHUqGOxGxPQPMtaqjppqY8HY39-QdAnHBII&sig=AOD64_1r3tIroFWvKdgd3IEQB1BX0mmAdg&ctype=5&q=&ved=2ahUKEwizuv_B29D7AhUVBd4KHTcyApkQ9aACKAB6BAgEEAw&adurl=)
을 한번 읽어 보시기 바랍니다. 아주 기초 이론이고 실제 강화학습은 이 책 내용과는 다를 수 있습니다. 강화학습의 그리드월드 내용을 보시면 약간 지도학습의 모습과 닮았습니다. 모든 강화학습의 방법은 책 내용데로 해서는 결과를 보기는 상당히 힘듧니다. 기본적으로 agent는 일단 인공신경망을 배제하고 env와 agent를 구성할 필요가 있습니다. 그리고 나서 실제 agent가 발전하는게 확인이 되면 그다음은 인공신경망을 이용하여 Q(s,a)함수를 Deep learning 모듈화 하는 것입니다. 사실 이 모든 것은 어떤 수학자가 주장한 모든 문제는 곡선화 할 수 있으며 곡선화를 하면 해당 답을 구할 수 있다는 일반 함수론에서 출발하여 Q-learning을 DQN을 이용하여 함수로 만들 수 있다는 개념으로 출발했습니다. 그러니 Q(s,a)함수를 먼저 구하고 그 다음 그 함수를 DQN모듈로 변환하는 것이 바로 강화학습인것이죠. 물론 DB를 사용할 수도 있을 것입니다. 아주 정밀한 값도 저장할 수 있지만 만약에 데이터가 무한하다면 DB로 그 많은 데이터를 저장할 수 있을 까요? 그래서 강화학습에서는 DQN(CNN)을 이용하여 함수화 하자는 거죠 토치의 저장기능을 사용하면 몇 byte의 값으로도 그 함수를 저장할 수 있으니까요? 다시 시작할 때는 그냥 그 모듈을 다시 읽어서 재실행하면 끝입니다. DB보다는 컴퓨터 비용이 훨씬 저렴하죠 아니 저렴한 정도가 아니라 거의 공짜 수준이라고 할 수 있죠?
일단 Market class 소스를 보시죠
from matplotlib.backend_bases import FigureCanvasBase
import py
import plotly.graph_objects as go
import plotly.subplots as ms
import plotly.express as px
import plotly as plt
import pymysql
import pandas as pd
import numpy as np
import time
import talib
from PIL import Image
ticker = 'ETH'
class Market():
def __init__(self) -> None:
self.df = pd.DataFrame()
def add_bbands(self):
try:
print("볼리져 밴드 구하기:%s" % time.ctime(time.time()))
self.df['ma20'] = talib.SMA(np.asarray(self.df['close']), 20)
self.df['stddev'] = self.df['close'].rolling(window=20).std() # 20일 이동표준편차
upper, middle, lower = talib.BBANDS(np.asarray(self.df['close']), timeperiod=40, nbdevup=2.3, nbdevdn=2.3, matype=0)
self.df['lower']=lower
self.df['middle']=middle
self.df['upper']=upper
# self.df['bbands_sub']=self.df.upper - self.df.lower
# self.df['bbands_submax'] = self.df.bbands_sub.rolling(window=840, min_periods=10).max()
# self.df['bbands_submin'] = self.df.bbands_sub.rolling(window=840, min_periods=10).min()
# self.df['bbs_deg'] = (self.df.mavg - self.df.mavg.shift()).apply(lambda x: math.degrees(math.atan2(x, 40)))
# self.df['pct8']=(self.df.close - self.df.lower)/(self.df.upper - self.df.lower)
# self.df['pct8vsclose']=self.df.close - self.df.pct8
except Exception as ex:
print("`add_bbands -> exception! %s `" % str(ex))
def add_relative_strength(self):
try:
print("상대 강도 지수 구하기:%s" % time.ctime(time.time()))
rsi14 = talib.RSI(np.asarray(self.df['close']), 14)
self.df['rsi14'] = rsi14
except Exception as ex:
print("`add_relative_strength -> exception! %s `" % str(ex))
def add_iip(self):
try:
#역추세전략을 위한 IIP계산
self.df['II'] = (2*self.df['close']-self.df['high']-self.df['low'])/(self.df['high']-self.df['low'])*self.df['volume']
self.df['IIP21'] = self.df['II'].rolling(window=21).sum()/self.df['volume'].rolling(window=21).sum()*100
except Exception as ex:
print("`add_iip -> exception! %s `" % str(ex))
def add_stock_cast(self):
try:
# 스토캐스틱 구하기
self.df['ndays_high'] = self.df['high'].rolling(window=14, min_periods=1).max() # 14일 중 최고가
self.df['ndays_low'] = self.df['low'].rolling(window=14, min_periods=1).min() # 14일 중 최저가
self.df['fast_k'] = (self.df['close'] - self.df['ndays_low']) / (self.df['ndays_high'] - self.df['ndays_low']) * 100 # Fast %K 구하기
self.df['slow_d'] = self.df['fast_k'].rolling(window=3).mean() # Slow %D 구하기 except Exception as ex:
except Exception as ex:
print("`add_stock_cast -> exception! %s `" % str(ex))
def add_mfi(self):
try:
# MFI 구하기
self.df['PB'] = (self.df['close'] - self.df['lower']) / (self.df['upper'] - self.df['lower'])
self.df['TP'] = (self.df['high'] + self.df['low'] + self.df['close']) / 3
self.df['PMF'] = 0
self.df['NMF'] = 0
for i in range(len(self.df.close)-1):
if self.df.TP.values[i] < self.df.TP.values[i+1]:
self.df.PMF.values[i+1] = self.df.TP.values[i+1] * self.df.volume.values[i+1]
self.df.NMF.values[i+1] = 0
else:
self.df.NMF.values[i+1] = self.df.TP.values[i+1] * self.df.volume.values[i+1]
self.df.PMF.values[i+1] = 0
self.df['MFR'] = (self.df.PMF.rolling(window=10).sum() / self.df.NMF.rolling(window=10).sum())
self.df['MFI10'] = 100 - 100 / (1 + self.df['MFR'])
except Exception as ex:
print("`add_mfi -> exception! %s `" % str(ex))
def add_macd(self, sort, long, sig):
try:
print("MACD 구하기:%s, sort:%d, long:%d, sig:%d" % (time.ctime(time.time()), sort, long, sig))
macd, macdsignal, macdhist = talib.MACD(np.asarray(self.df['close']), sort, long, sig)
self.df['macd'] = macd
# self.df['macdmax'] = self.df.macd.rolling(window=840, min_periods=100).max()
# self.df['macdmin'] = self.df.macd.rolling(window=840, min_periods=100).min()
# self.df['macd_max_rate'] = self.df.apply(lambda x: (x['macdmax'] - x['macd']) * 100 / (x['macdmax'] - x['macdmin']), axis=1)
# self.df['macd_min_rate'] = self.df.apply(lambda x: (x['macd'] - x['macdmin']) * 100 / (x['macdmax'] - x['macdmin']), axis=1)
# self.df['prv_macd_degrees'] = self.df.macd_degrees.shift()
self.df['signal'] = macdsignal
self.df['flag'] = macdhist
# self.df['prv_osc_degrees'] = self.df.osc_degrees.shift()
except Exception as ex:
print("`add_macd -> exception! %s `" % str(ex))
def get_data(self):
table_name="TB_ETH_TRADE"
sqlcon = pymysql.connect(host='192.168.xx.xx', user='xxxx', password='xxxx', db='yubank2')
cursor = sqlcon.cursor(pymysql.cursors.DictCursor)
str_query = """select
'time'
,`start` as open
, high
, low
, close
, volume
, macd
, macdmax
, macdmin
, `signal`
, `flag`
, osc_degrees
, sma1200
, sma1200_degrees
, wma1200
, wma1200_degrees
from %s""" % table_name
cursor.execute(str_query)
data = cursor.fetchall()
self.df = pd.DataFrame(data)
sqlcon.close()
self.df['sma3'] = talib.SMA(np.asarray(self.df['close']), 3)
self.df['closemax'] = self.df.close.rolling(window=100, center=True).max()
self.df['closemin'] = self.df.close.rolling(window=100, center=True).min()
self.add_macd(12,26,9)
self.add_bbands()
self.add_relative_strength()
self.add_stock_cast()
self.add_iip()
self.add_mfi()
self.df.dropna(inplace=True)
self.df.reset_index(drop=True,inplace=True)
return self.df
def get_test_data(self):
table_name="TB_ETH_TRADE"
sqlcon = pymysql.connect(host='192.168.xx.xx', user='xxxx', password='xxxx', db='yubank2')
cursor = sqlcon.cursor(pymysql.cursors.DictCursor)
str_query = """select
`start` as open
, high
, low
, close
, volume
, macd
, macdmax
, macdmin
, `signal`
, `flag`
, osc_degrees
, sma1200
, sma1200_degrees
, wma1200
, wma1200_degrees
from %s""" % table_name
cursor.execute(str_query)
data = cursor.fetchall()
self.df = pd.DataFrame(data)
sqlcon.close()
self.add_bbands()
self.add_relative_strength()
self.df['clsmin'] = self.df.close.rolling(window=600, min_periods=600, center=False).min()
self.df['clsmax'] = self.df.close.rolling(window=600, min_periods=600, center=False).max()
self.df['action'] = self.df.apply(lambda x: 1 if x['close'] > x['up'] else( 1 if x['close'] < x['dn'] else 0), axis=1)
# self.df.drop(self.df[self.df['action'] == 0].index, inplace=True)
self.df.dropna(inplace=True)
self.df.reset_index(drop=True,inplace=True)
return self.df
def get_lstm_data(self):
table_name="TB_ETH_TRADE"
sqlcon = pymysql.connect(host='192.168.xx.xx', user='xxxx', password='xxxx', db='yubank2')
cursor = sqlcon.cursor(pymysql.cursors.DictCursor)
str_query = """select
`date`
,`start` as open
, high
, low
, close
from %s""" % table_name
cursor.execute(str_query)
data = cursor.fetchall()
self.df = pd.DataFrame(data)
sqlcon.close()
self.df['fclose'] = self.df.close.shift(-1)
self.df.dropna(inplace=True)
self.df.reset_index(drop=True,inplace=True)
return self.df
def get_chart(self, idx, max_data:int=300):
try:
df = self.df.head(idx + max_data).tail(max_data)
df.reset_index(drop=True, inplace=True)
if df.index.max() < 299:
return None
candle = go.Candlestick(x=df.index,open=df['open'],high=df['high'],low=df['low'],close=df['close'], increasing_line_color = 'red',decreasing_line_color = 'blue', showlegend=False)
upper = go.Scatter(x=df.index, y=df['upper'], line=dict(color='red', width=2), name='upper', showlegend=False)
ma20 = go.Scatter(x=df.index, y=df['ma20'], line=dict(color='black', width=2), name='ma20', showlegend=False)
lower = go.Scatter(x=df.index, y=df['lower'], line=dict(color='blue', width=2), name='lower', showlegend=False)
volume = go.Bar(x=df.index, y=df['volume'], marker_color='red', name='volume', showlegend=False)
MACD = go.Scatter(x=df.index, y=df['macd'], line=dict(color='blue', width=2), name='MACD', legendgroup='group2', legendgrouptitle_text='MACD')
MACD_Signal = go.Scatter(x=df.index, y=df['signal'], line=dict(dash='dashdot', color='green', width=2), name='MACD_Signal')
MACD_Oscil = go.Bar(x=df.index, y=df['flag'], marker_color='purple', name='MACD_Oscil')
fast_k = go.Scatter(x=df.index, y=df['fast_k'], line=dict(color='skyblue', width=2), name='fast_k', legendgroup='group3', legendgrouptitle_text='%K %D')
slow_d = go.Scatter(x=df.index, y=df['slow_d'], line=dict(dash='dashdot', color='black', width=2), name='slow_d')
PB = go.Scatter(x=df.index, y=df['PB']*100, line=dict(color='blue', width=2), name='PB', legendgroup='group4', legendgrouptitle_text='PB, MFI')
MFI10 = go.Scatter(x=df.index, y=df['MFI10'], line=dict(dash='dashdot', color='green', width=2), name='MFI10')
RSI = go.Scatter(x=df.index, y=df['rsi14'], line=dict(color='red', width=2), name='RSI', legendgroup='group5', legendgrouptitle_text='RSI')
# 스타일
fig = ms.make_subplots(rows=5, cols=2, specs=[[{'rowspan':4},{}],[None,{}],[None,{}],[None,{}],[{},{}]], shared_xaxes=True, horizontal_spacing=0.03, vertical_spacing=0.01)
fig.add_trace(candle,row=1,col=1)
fig.add_trace(upper,row=1,col=1)
fig.add_trace(ma20,row=1,col=1)
fig.add_trace(lower,row=1,col=1)
fig.add_trace(volume,row=5,col=1)
fig.add_trace(candle,row=1,col=2)
fig.add_trace(upper,row=1,col=2)
fig.add_trace(ma20,row=1,col=2)
fig.add_trace(lower,row=1,col=2)
fig.add_trace(MACD,row=2,col=2)
fig.add_trace(MACD_Signal,row=2,col=2)
fig.add_trace(MACD_Oscil,row=2,col=2)
fig.add_trace(fast_k,row=3,col=2)
fig.add_trace(slow_d,row=3,col=2)
fig.add_trace(PB,row=4,col=2)
fig.add_trace(MFI10,row=4,col=2)
fig.add_trace(RSI,row=5,col=2)
# 추세추종
# trend_fol = 0
# trend_refol = 0
# for i in df.index:
# if df['PB'][i] > 0.8 and df['MFI10'][i] > 80:
# trend_fol = go.Scatter(x=[df.index[i]], y=[df['close'][i]], marker_color='orange', marker_size=20, marker_symbol='triangle-up', opacity=0.7, showlegend=False)
# fig.add_trace(trend_fol,row=1,col=1)
# elif df['PB'][i] < 0.2 and df['MFI10'][i] < 20:
# trend_fol = go.Scatter(x=[df.index[i]], y=[df['close'][i]], marker_color='darkblue', marker_size=20, marker_symbol='triangle-down', opacity=0.7, showlegend=False)
# fig.add_trace(trend_fol,row=1,col=1)
# 역추세추종
# for i in df.index:
# if df['PB'][i] < 0.05 and df['IIP21'][i] > 0:
# trend_refol = go.Scatter(x=[df.index[i]], y=[df['close'][i]], marker_color='purple', marker_size=20, marker_symbol='triangle-up', opacity=0.7, showlegend=False) #보라
# fig.add_trace(trend_refol,row=1,col=1)
# elif df['PB'][i] > 0.95 and df['IIP21'][i] < 0:
# trend_refol = go.Scatter(x=[df.index[i]], y=[df['close'][i]], marker_color='skyblue', marker_size=20, marker_symbol='triangle-down', opacity=0.7, showlegend=False) #하늘
# fig.add_trace(trend_refol,row=1,col=1)
# fig.add_trace(trend_fol,row=1,col=1)
# 추세추총전략을 통해 캔들차트에 표시합니다.
# fig.add_trace(trend_refol,row=1,col=1)
# 역추세 전략을 통해 캔들차트에 표시합니다.
# fig.update_layout(autosize=True, xaxis1_rangeslider_visible=False, xaxis2_rangeslider_visible=False, margin=dict(l=50,r=50,t=50,b=50), template='seaborn', title=f'({ticker})의 날짜: ETH [추세추종전략:오↑파↓] [역추세전략:보↑하↓]')
# fig.update_xaxes(tickformat='%y년%m월%d일', zeroline=True, zerolinewidth=1, zerolinecolor='black', showgrid=True, gridwidth=2, gridcolor='lightgray', showline=True,linewidth=2, linecolor='black', mirror=True)
# fig.update_yaxes(tickformat=',d', zeroline=True, zerolinewidth=1, zerolinecolor='black', showgrid=True, gridwidth=2, gridcolor='lightgray',showline=True,linewidth=2, linecolor='black', mirror=True)
# fig.update_traces(xhoverformat='%y년%m월%d일')
# size = len(img)
# img = plt.io.to_image(fig, format='png')
# canvas = FigureCanvasBase(fig)
# img = Image.frombytes(mode='RGB', size=(700, 500), data=fig.to_image().to, decoder_name='raw')
# img = Image.frombytes('RGBA', (700, 500), plt.io.to_image(fig, format='png'), 'raw')
# img = Image.fromarray('RGBA', (700, 500), np.array(plt.io.to_image(fig, format='png')), 'raw')
# img = Image.fromarray(np.array(plt.io.to_image(fig, format='png')), 'RGB')
import io
# img = Image.fromarray(np.array(plt.io.to_image(fig, format='png')), 'L')
img = Image.open(io.BytesIO(plt.io.to_image(fig, format='png', width=140, height=100)))
img.convert("RGB")
img.thumbnail((100, 140), Image.ANTIALIAS)
# img = Image.Image(fig.to_image(), 'RGB')
# img.show()
return img
except Exception as ex:
print("`get_chart -> exception! %s `" % str(ex))
return None
차트 그리기는 https://sjblog1.tistory.com/m/45 의 내용을 참고했습니다. rolling함수를 사용할 경우 속도 문제가 있어서 ta-lib로 변경했습니다. ta-lib의 경우 거의 모든 주식차트 함수를 포함하고 있으며 c++로 개발이 되어 있어서 속도면에서 훨씬 빠릅니다.
클래스를 선언하고 get_chart함수를 호출하면 chart가 bmp 타입(png)의 이미지 데이터로 리턴이 됩니다. 데이터는 저의 경우 DB에 3분 단위로 매일 데이터를 저장하고 있고 그 기간이 벌써 1년반정도 되었씁니다. 꼭 DB를 사용하지 않아도 pyupbit로 데이터를 가져 오면 아마 1주일치 정도의 데이터를 가져 올 수 있습니다. 물론 더 많은 데이터가 있으면 좋겠지만 일주일 정도로도 훈련이 가능합니다. 다음 차에서는 상태값 S를 구하는 것에 대하여 설명 드리겠습니다.
'python > 자동매매 프로그램' 카테고리의 다른 글
강화학습을 이용한 비트코인 매매프로그램(5)-데이터에 Min/Max 추가하기 (0) | 2022.11.30 |
---|---|
강화학습을 이용한 비트코인 매매프로그램(4)-LSTM의 주가분석 문제점 (1) | 2022.11.30 |
강화학습을 이용한 비트코인 매매프로그램(2) - 개발 환경 셋팅 (2) | 2022.11.28 |
강화학습을 이용한 비트코인 매매프로그램(1)-강화 학습으로의 전환 (0) | 2022.11.23 |
python Bitcoin 자동 매매 프로그램(8) - Upsert 구현 (1) | 2021.08.21 |