맨위로버튼이미지

from lib.UpbitTrade import UpbitTrade
from lib.ColoredLogger import COLORS_LOG as cl
from lib.Upsert import aio_data_update
# from lib.upbitSellBuy import *
from lib.dl_upbitSellBuy import predict
from lib.Common import send_slack
from lib.Common import make_real_rate
from lib.Common import dfTrend
from lib.Common import dfInclease
import pyupbit
import time
import sys
import os
import aiomysql
import pymysql
import asyncio

import logging
import logging.config
import json

import common.constrant as Const
import config.localconf as conf
import pandas as pd
import numpy as np

config = json.load(open('config/logging.json'))
logging.config.dictConfig(config)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

table_name="xxx.TB_ETH_TRADE"
table_report="xxx.TB_ETH_DAILY_REPORT"
table_config="xxx.TB_ETH_CASH_CONFIG"
table_crashed="xxx.TB_CRASHED_REPORT"
table_realrate="xxx.TB_REAL_REPORT"

pd.set_option('display.max_columns', 13)
pd.set_option('display.width', 200)

def get_avg_buy_price():
    """Return the average purchase price per unit of the held coin.

    Divides the amount reported by ``upbit.get_amount('ALL')`` by the coin
    quantity ``trade.get_balance()[0]``.  Returns 0 when no coin is held, or
    when either call fails (the failure is logged, not raised).

    NOTE(review): ``get_amount('ALL')`` may cover every held currency rather
    than only ``conf.TRADE_TICKER`` — confirm if other coins can be held.
    """
    try:
        total_amount = upbit.get_amount('ALL')
        held_units = trade.get_balance()[0]
        return 0 if held_units == 0 else total_amount / held_units
    except Exception as err:
        _, _, tb = sys.exc_info()
        src_file = os.path.basename(tb.tb_frame.f_code.co_filename)
        logger.error("get_avg_buy_price exception! %s ` : %s %d" % (str(err), src_file, tb.tb_lineno))
        return 0

def get_avg_sell_price():
    """Return ``avg_sell_price`` of ``conf.TRADE_TICKER`` from the Upbit balances.

    Returns 0 when there are no balances, when the ticker is not among them,
    or when any call fails (the failure is logged, not raised).

    NOTE(review): this reads an ``avg_sell_price`` key from each balance
    entry; confirm the Upbit accounts payload provides it, otherwise the
    lookup always takes the error path and returns 0.
    """
    try:
        balances, _req = upbit.get_balances(contain_req=True)
        if not balances:
            return 0

        for entry in balances:
            if entry['currency'] == conf.TRADE_TICKER:
                return float(entry['avg_sell_price'])
        return 0
    except Exception as ex:
        # The handler must bind ``ex``: it is what the log line formats.
        exc_type, exc_obj, exc_tb = sys.exc_info()
        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
        logger.error("get_avg_sell_price exception! %s ` : %s %d" % (str(ex), fname, exc_tb.tb_lineno))
        return 0

async def read_all(sqlcon) -> pd.DataFrame:
    """Read every row of the candle table ``table_name`` into a DataFrame.

    Args:
        sqlcon: an open aiomysql connection.

    Returns:
        A DataFrame built from the rows (one column per DB column), or None
        if the query fails — the error is logged, not raised.
    """
    try:
        # ``async with`` closes the cursor on every path, and never refers to
        # a cursor that failed to open.
        async with sqlcon.cursor(aiomysql.cursors.DictCursor) as cursor:
            await cursor.execute("select * from %s" % table_name)
            data = await cursor.fetchall()
        return pd.DataFrame(data)
    except Exception as ex:
        exc_type, exc_obj, exc_tb = sys.exc_info()
        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
        logger.error("read from db exception! %s ` : %s %d" % (str(ex) , fname, exc_tb.tb_lineno))
        return None

async def read_from(sqlcon, ntime) -> pd.DataFrame:
    """Read rows of ``table_name`` whose ``time`` is strictly greater than ``ntime``.

    Args:
        sqlcon: an open aiomysql connection.
        ntime: lower bound (exclusive) for the ``time`` column; truncated to
            an integer, exactly as ``%d`` formatting would.

    Returns:
        A DataFrame of the newer rows (empty if there are none), or None if
        the query fails — the error is logged, not raised.
    """
    try:
        # Table name is a module constant; the time bound is bound as a parameter.
        query = "select * from %s where `time` > %%s" % table_name
        async with sqlcon.cursor(aiomysql.cursors.DictCursor) as cursor:
            await cursor.execute(query, (int(ntime),))
            data = await cursor.fetchall()
        return pd.DataFrame(data)
    except Exception as ex:
        exc_type, exc_obj, exc_tb = sys.exc_info()
        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
        logger.error("read from db exception! %s ` : %s %d" % (str(ex) , fname, exc_tb.tb_lineno))
        return None

async def _wait_for_orders_to_clear(ticker, poll_interval=0.01):
    """Wait, without blocking the event loop, until ``upbit.get_order(ticker)`` is empty.

    NOTE(review): there is no timeout — if an order never fills, or the API
    keeps returning a non-empty error payload, this waits forever.
    """
    while len(upbit.get_order(ticker)) != 0:
        await asyncio.sleep(poll_interval)


async def _record_sell(sqlcon, ticker, close_price):
    """Mark the user's open ``main.TB_TRADE_HIST`` position as sold.

    Waits for the sell order to clear, then, if a row with ``qty > 0`` and
    ``end_dt = 0`` exists for the global ``user``, copies ``qty`` into
    ``sell_qty``, zeroes ``qty`` and stores ``close_price``,
    ``close_price * qty`` and the current epoch time as ``end_dt``.
    Errors after the cursor is opened are logged, not raised.
    """
    cursor = await sqlcon.cursor(aiomysql.cursors.DictCursor)
    try:
        await _wait_for_orders_to_clear(ticker)
        str_query = """SELECT * 
            FROM main.TB_TRADE_HIST 
            WHERE user_id='%s' 
            AND qty > 0 
            AND end_dt = 0""" % user
        await cursor.execute(str_query)
        data = await cursor.fetchall()
        if len(data) > 0:
            fetch_result = pd.DataFrame(data)
            start_dt = fetch_result['start_dt'][0]
            qty = fetch_result['qty'][0]
            str_query = """
            UPDATE main.TB_TRADE_HIST 
            SET 
                sell_qty = qty
                , qty=0
                , close=%f 
                , sell_sum =%f
                , end_dt=%d
            WHERE user_id = '%s' AND start_dt = %d""" % (
                close_price, close_price * qty, time.time(), user, start_dt
            )
            await cursor.execute(str_query)
            await sqlcon.commit()
    except Exception as ex:
        exc_type, exc_obj, exc_tb = sys.exc_info()
        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
        logger.error("read db error -> exception! " + str(ex) + "` %s: %d", fname, exc_tb.tb_lineno)
    finally:
        await cursor.close()


async def _record_buy(sqlcon, ticker, start_dt):
    """Write a ``main.TB_TRADE_HIST`` row for a just-filled buy.

    Waits for the buy order to clear, then stores the held quantity, the
    average buy price and their product with ``start_dt`` as the position
    start, via ``aio_data_update(..., "replace")``.  Errors are logged, not
    raised.  No cursor is opened here, so there is none to close.
    """
    try:
        await _wait_for_orders_to_clear(ticker)
        avg_buy_price = get_avg_buy_price()
        balance = trade.get_balance()
        trdhist_clmns = ['user_id', 'start_dt', 'ticker', 'qty', 'cost', 'cost_sum', 'sell_qty', 'close', 'sell_sum', 'end_dt']
        trdhist_data = [[user, start_dt, conf.TRADE_TICKER, balance[0], avg_buy_price, balance[0] * avg_buy_price,
                         0, 0, 0, 0]]
        trdhist_df = pd.DataFrame.from_records(data=trdhist_data, columns=trdhist_clmns)
        await aio_data_update(trdhist_df, sqlcon, "main.TB_TRADE_HIST", "replace")
    except Exception as ex:
        exc_type, exc_obj, exc_tb = sys.exc_info()
        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
        logger.error("read db error -> exception! " + str(ex) + "` %s: %d", fname, exc_tb.tb_lineno)


async def trade_proc():
    """Main trading loop: poll new candles, ask the DQN model, and trade.

    Loads the whole candle table once, then repeats roughly every
    ``Const.SLEEP_TIME`` seconds:
      * refresh the globals ``trade_ok`` / ``base_balance`` (``get_trade_yn``);
      * append candle rows newer than the last seen ``time``;
      * log the average buy price, balance, return rate and indicator state;
      * run ``predict`` and, when ``trade_ok`` is set, sell (only if coin is
        held) or buy, recording the trade in ``main.TB_TRADE_HIST``.
    An error inside one iteration is logged and followed by a 60 s pause;
    an error outside the loop ends the coroutine after logging.
    """
    ticker = "KRW-" + conf.TRADE_TICKER
    old_rate = 0
    rate_trend = ""
    pool = None

    try:
        # NOTE(review): DB credentials are hard-coded; consider moving them to config.
        pool = await aiomysql.create_pool(host="193.xxx.xxx.xx", user="xxxxxx", password='xxxxxxx', db='xxxxx')
        async with pool.acquire() as sqlcon:
            df = await read_all(sqlcon)
            idx = df.index.max()
            idx_time = df.loc[idx, 'time']
            # NOTE(review): closing a pooled connection means a fresh connection
            # is used each iteration; confirm this is intended.
            sqlcon.close()
        while True:
            # Taken before any work so the ``finally`` below can always use it.
            start_time = time.time()
            async with pool.acquire() as sqlcon:
                try:
                    await get_trade_yn(sqlcon)

                    add_df = await read_from(sqlcon, idx_time)
                    df = pd.concat([df, add_df])
                    df = df.drop_duplicates(['time'], keep='first', inplace=False, ignore_index=True)
                    try:
                        avg_buy_price = get_avg_buy_price()
                        logger.info(cl["BOLD"] + cl["RED"] + "평균매수단가:%.2f" + cl["RESET"], avg_buy_price)
                    except Exception as ex:
                        exc_type, exc_obj, exc_tb = sys.exc_info()
                        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
                        logger.error("get_avg_buy_price -> exception! %s : %s %d" % (str(ex) , fname, exc_tb.tb_lineno))

                    # Balance lookup
                    current_unit = 0.0
                    try:
                        balance = trade.get_balance()
                        price = pyupbit.get_current_price(ticker)
                        logger.info("현재가 : %.4f" % price)
                        limit_cash = (balance[0] * price + balance[1]) * conf.REINVESTMENT_RATE if balance[0] != 0 or balance[1] != 0 else base_balance
                        logger.info("LIMIT_CASH: %.2f", limit_cash)
                        # NOTE(review): ``limit_cash`` is computed and logged, but the
                        # constant Const.LIMIT_CASH is what gets applied — confirm which
                        # value UpbitTrade should receive.
                        trade.set_limitcash(Const.LIMIT_CASH)
                        valuation = balance[0] * price + balance[1]
                        rate = ((valuation - base_balance) * 100) / base_balance if base_balance is not None and base_balance != 0 else 0

                        # Last five moves of the return rate:
                        # "=" unchanged, "V" fell, "^" rose since the previous iteration.
                        trend_rate = old_rate - rate
                        old_rate = rate
                        if trend_rate == 0:
                            rate_trend += "="
                        elif trend_rate > 0:
                            rate_trend += "V"
                        else:
                            rate_trend += "^"
                        rate_trend = rate_trend[-5:]

                        logger.info(cl["ITELIC"] + cl["RED"] + "평가금액:%.2f 수익률: %.2f%% %s" + cl["RESET"], valuation, rate, rate_trend)
                        current_unit = balance[0]
                        logger.info(cl["BOLD"] + cl["RED"] + "잔고조회: %.5f %.2f" + cl["RESET"], balance[0], balance[1])
                    except Exception as ex:
                        exc_type, exc_obj, exc_tb = sys.exc_info()
                        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
                        logger.error("get_current_price -> exception! %s : %s %d" % (str(ex) , fname, exc_tb.tb_lineno))

                    logger.info(cl["BOLD"] + cl["GREEN"] + "clsmin %.2f(%s), clsmax:%.2f(%s), close:%.2f(%s)" + cl["RESET"], 
                        df.loc[idx, 'clsmin'], dfInclease(df.index.max(), df, "clsmin"), df.loc[idx, 'clsmax'], 
                        dfInclease(df.index.max(), df, "clsmax"), df.loc[idx, 'close'], dfInclease(df.index.max(), df, "close"))
                    logger.info(cl["BOLD"] + cl["GREEN"] + "min_rate:%.2f%%(%s) %s max_rate:%.2f%%(%s) %s" + cl["RESET"], 
                        df.loc[idx, 'min_rate'], dfInclease(df.index.max(), df, "min_rate"), 
                        dfTrend(df.index.max(), df, "minratedeg"), df.loc[idx, 'max_rate'], 
                        dfInclease(df.index.max(), df, "max_rate"), 
                        dfTrend(df.index.max(), df, "maxratedeg"))
                    logger.info("time date:" + time.strftime("%Y%m%d%H%M%S", time.localtime(idx_time)))
                    logger.info(cl["BOLD"] + cl["GREEN"] + "trade_ok:%s" + cl["RESET"], str(trade_ok))

                    # NOTE: predict adds indicator columns to ``df`` in place.
                    pred = predict(logger, df, idx)
                    logger.info(cl["BOLD"] + cl["GREEN"] + "predict action:%s" + cl["RESET"], str(pred.get_dl_action()))
                    if pred.fsell() and current_unit > 0:
                        logger.info("매도 타이밍: %s ", time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(idx_time)))
                        if trade_ok:
                            if trade.sell_crypto_currency() is not None:
                                await _record_sell(sqlcon, ticker, df.loc[idx, 'close'])
                    elif pred.fbuy():
                        logger.info("매수 타이밍 %s", time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(idx_time)))
                        if trade_ok:
                            if trade.buy_crypto_currency() is not None:
                                await _record_buy(sqlcon, ticker, df.loc[idx, 'time'])

                    logger.info(df.tail(5))

                    idx = df.index.max()
                    idx_time = df.loc[idx, 'time']
                    sqlcon.close()
                except Exception as ex:
                    exc_type, exc_obj, exc_tb = sys.exc_info()
                    fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
                    logger.error("`main -> exception! %s ` : %s %d" % (str(ex) , fname, exc_tb.tb_lineno))
                    # Non-blocking back-off: time.sleep would freeze the event loop.
                    await asyncio.sleep(60)
                finally:
                    elapsed = time.time() - start_time
                    if 0 < elapsed <= Const.SLEEP_TIME:
                        await asyncio.sleep(Const.SLEEP_TIME - elapsed)
    except Exception as ex:
        exc_type, exc_obj, exc_tb = sys.exc_info()
        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
        logger.error("`main -> exception! %s ` : %s %d" % (str(ex) , fname, exc_tb.tb_lineno))
        await asyncio.sleep(60)
    finally:
        # ``pool`` is still None if create_pool itself failed.
        if pool is not None:
            pool.close()
            await pool.wait_closed()

async def get_trade_yn(sqlcon):
    """Refresh the module globals ``trade_ok`` and ``base_balance``.

    Reads ``TRADE_YN`` and ``BASE_BALANCE`` from ``main.TN_PAYMENT_ACNT``
    for the user named by the ``user`` environment variable.  ``trade_ok``
    is True only when ``TRADE_YN`` equals ``'Y'``; ``base_balance`` is
    ``BASE_BALANCE`` as a float.

    Args:
        sqlcon: an open aiomysql connection.

    Raises:
        IndexError: if the user has no ``TN_PAYMENT_ACNT`` row; neither
            global is changed in that case.
    """
    global trade_ok
    global base_balance
    user = os.getenv('user')

    str_sql = """
        SELECT 
            a.TRADE_YN
            ,a.BASE_BALANCE
        FROM main.TN_PAYMENT_ACNT a
        WHERE a.USER_ID = %s
    """
    # Parameterised query; ``async with`` closes the cursor even if execute fails.
    async with sqlcon.cursor(aiomysql.cursors.DictCursor) as cursor:
        await cursor.execute(str_sql, (user,))
        rows = await cursor.fetchall()

    if not rows:
        raise IndexError("no main.TN_PAYMENT_ACNT row for user %r" % (user,))

    account = rows[0]
    base_balance = float(account['BASE_BALANCE'])
    trade_ok = (account['TRADE_YN'] == 'Y')

def get_env():
    """Create the module-level Upbit client and trader for the current user.

    Reads the user id from the ``user`` environment variable, loads that
    user's API keys and base balance from ``TN_PAYMENT_ACNT``, and sets the
    globals ``user``, ``upbit`` (``pyupbit.Upbit``) and ``trade``
    (``UpbitTrade``, with ``base_balance * conf.REINVESTMENT_RATE`` as its
    initial limit).

    Raises:
        IndexError: if no ``TN_PAYMENT_ACNT`` row exists for the user.
    """
    global upbit
    global trade
    global user
    user = os.getenv('user')
    logger.info("user[%s]", user)

    str_sql = """
        SELECT 
            a.PUB_KEY
            ,a.APP_KEY
            ,a.SECU_KEY
            ,a.BASE_BALANCE
        FROM TN_PAYMENT_ACNT a
        WHERE a.USER_ID = %s
    """
    # NOTE(review): DB credentials are hard-coded; consider moving them to config.
    sqlcon = pymysql.connect(host='193.xxx.xxx.xx', user='xxxxxx', password='xxxxxxx', db='xxxx')
    try:
        with sqlcon.cursor(pymysql.cursors.DictCursor) as cursor:
            cursor.execute(str_sql, (user,))
            rows = cursor.fetchall()
    finally:
        # The connection is only needed for this lookup; always release it.
        sqlcon.close()

    if not rows:
        raise IndexError("no TN_PAYMENT_ACNT row for user %r" % (user,))

    account = rows[0]
    # PUB_KEY is selected but unused: key decryption (AESCryptoCBC) is
    # disabled, so the stored keys are passed to Upbit as-is.
    app_key = account['APP_KEY']
    sec_key = account['SECU_KEY']
    base_balance = float(account['BASE_BALANCE'])

    upbit = pyupbit.Upbit(app_key, sec_key)
    trade = UpbitTrade(upbit, logger, conf.TRADE_TICKER, base_balance, base_balance * conf.REINVESTMENT_RATE)

if __name__ == '__main__':
    # Set up the user's Upbit client/trader, then run the trading loop.
    get_env()
    asyncio.run(trade_proc())

 

import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F

import plotly.graph_objects as go
import plotly.subplots as ms
import plotly as plt
from PIL import Image
import io
import os.path as path
import talib
import numpy as np

data_size = 100
action_kind = 4
screen_height = 50
screen_width  = 70

class DQN(nn.Module):
    """Three-stage convolutional Q-network over a 4-channel image.

    Args:
        device: device every input batch is moved to in ``forward``.
        h, w: used only to size ``conv1``, which outputs ``h * w`` channels.
            They do not set the expected image size.
        outputs: number of Q-values (actions) produced per sample.
        qsize: channel count of ``conv2`` and ``conv3``.

    Each stage is Conv2d(kernel 5, stride 2) -> BatchNorm2d -> ReLU.  ``head``
    takes ``3 * qsize`` features, so the last stage must leave a 1x3 spatial
    map (e.g. a 36x50 input); other input sizes fail at ``head``.
    """

    def __init__(self, device, h, w, outputs, qsize):
        super().__init__()
        mid_channels = h * w
        # Keep this registration order (conv1, bn1, conv2, bn2, conv3, bn3,
        # head): it fixes the RNG order of parameter initialisation.
        self.conv1 = nn.Conv2d(4, mid_channels, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.conv2 = nn.Conv2d(mid_channels, qsize, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(qsize)
        self.device = device
        self.conv3 = nn.Conv2d(qsize, qsize, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(qsize)
        self.head = nn.Linear(qsize * 3, outputs)

    def forward(self, x):
        """Return Q-values of shape ``(batch, outputs)`` for the batch ``x``."""
        x = x.to(self.device)
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, norm in stages:
            x = F.relu(norm(conv(x)))
        return self.head(torch.flatten(x, 1))

def get_chart(df, idx, max_data:int=300, i_w:int=140, i_h:int=100):
    """Render an indicator chart of the rows before ``idx`` as a small PIL image.

    Uses at most ``max_data`` of the first ``idx`` rows of ``df``
    (``df.head(idx)``).  When the index is 0..n-1, this means row ``idx``
    itself is excluded.  Draws a 5x2 subplot grid:
    column 1 shows candles with upper/ma20/lower bands (rows 1-4) and volume
    (row 5); column 2 shows candles, MACD, %K/%D, %B/MFI10 and RSI.

    ``df`` must provide the columns open, high, low, close, volume, upper,
    ma20, lower, macd, signal, flag, fast_k, slow_d, PB, MFI10 and rsi14.

    Returns:
        PIL.Image.Image: the rendered PNG, shrunk in place by ``thumbnail``.
    """
    ndf = df.head(idx).tail(max_data).reset_index(drop=True)

    candle = go.Candlestick(x=ndf.index,open=ndf['open'],high=ndf['high'],low=ndf['low'],close=ndf['close'], increasing_line_color = 'red',decreasing_line_color = 'blue', showlegend=False)
    upper = go.Scatter(x=ndf.index, y=ndf['upper'], line=dict(color='red', width=2), name='upper', showlegend=False)
    ma20 = go.Scatter(x=ndf.index, y=ndf['ma20'], line=dict(color='black', width=2), name='ma20', showlegend=False)
    lower = go.Scatter(x=ndf.index, y=ndf['lower'], line=dict(color='blue', width=2), name='lower', showlegend=False)

    volume = go.Bar(x=ndf.index, y=ndf['volume'], marker_color='red', name='volume', showlegend=False)

    MACD = go.Scatter(x=ndf.index, y=ndf['macd'], line=dict(color='blue', width=2), name='MACD', legendgroup='group2', legendgrouptitle_text='MACD')
    MACD_Signal = go.Scatter(x=ndf.index, y=ndf['signal'], line=dict(dash='dashdot', color='green', width=2), name='MACD_Signal')
    MACD_Oscil = go.Bar(x=ndf.index, y=ndf['flag'], marker_color='purple', name='MACD_Oscil')

    fast_k = go.Scatter(x=ndf.index, y=ndf['fast_k'], line=dict(color='skyblue', width=2), name='fast_k', legendgroup='group3', legendgrouptitle_text='%K %D')
    slow_d = go.Scatter(x=ndf.index, y=ndf['slow_d'], line=dict(dash='dashdot', color='black', width=2), name='slow_d')

    PB = go.Scatter(x=ndf.index, y=ndf['PB']*100, line=dict(color='blue', width=2), name='PB', legendgroup='group4', legendgrouptitle_text='PB, MFI')
    MFI10 = go.Scatter(x=ndf.index, y=ndf['MFI10'], line=dict(dash='dashdot', color='green', width=2), name='MFI10')

    RSI = go.Scatter(x=ndf.index, y=ndf['rsi14'], line=dict(color='red', width=2), name='RSI', legendgroup='group5', legendgrouptitle_text='RSI')

    # Layout: the left column's price panel spans rows 1-4; every other cell is one row.
    fig = ms.make_subplots(rows=5, cols=2, specs=[[{'rowspan':4},{}],[None,{}],[None,{}],[None,{}],[{},{}]], shared_xaxes=True, horizontal_spacing=0.03, vertical_spacing=0.01)

    fig.add_trace(candle,row=1,col=1)
    fig.add_trace(upper,row=1,col=1)
    fig.add_trace(ma20,row=1,col=1)
    fig.add_trace(lower,row=1,col=1)

    fig.add_trace(volume,row=5,col=1)

    fig.add_trace(candle,row=1,col=2)
    fig.add_trace(upper,row=1,col=2)
    fig.add_trace(ma20,row=1,col=2)
    fig.add_trace(lower,row=1,col=2)

    fig.add_trace(MACD,row=2,col=2)
    fig.add_trace(MACD_Signal,row=2,col=2)
    fig.add_trace(MACD_Oscil,row=2,col=2)

    fig.add_trace(fast_k,row=3,col=2)
    fig.add_trace(slow_d,row=3,col=2)

    fig.add_trace(PB,row=4,col=2)
    fig.add_trace(MFI10,row=4,col=2)

    fig.add_trace(RSI,row=5,col=2)

    img = Image.open(io.BytesIO(plt.io.to_image(fig, format='png')))
    # The image is intentionally NOT converted to RGB: DQN.conv1 takes 4 input
    # channels, presumably the RGBA channels of this PNG — confirm before changing.
    # NOTE(review): Image.thumbnail takes (width, height), so this bounds the
    # width by ``i_h`` and the height by ``i_w``.  The DQN's ``head`` size
    # depends on the resulting shape, so do not swap the order without
    # retraining.  LANCZOS is the filter formerly aliased as ANTIALIAS
    # (removed in Pillow 10).
    img.thumbnail((i_h, i_w), Image.LANCZOS)
    return img

from lib.addAuxiliaryData import Auxiliary

class predict():
    """Run the trained DQN on a chart of ``df`` and expose its trade decision.

    Construction does the heavy work:
      * adds indicator columns to ``df`` IN PLACE (``open``, ``sma3`` and
        whatever the ``Auxiliary`` helpers add) and resets ``df``'s index;
      * builds the DQN and loads weights from
        ``/home/yubank/pt/train_dqn_{action_kind:02d}_{device}.pt`` if present;
      * renders the chart image (``get_chart``) used as model input.

    Actions are integers in ``[0, action_kind)``.  Action 0 means buy and
    ``action_kind - 1`` means sell; any other action triggers neither.
    """

    def __init__(self, logger, df, idx) -> None:
        # Indicator columns required by get_chart; written into ``df`` itself.
        df['open'] = df['start']
        df['sma3'] = talib.SMA(np.asarray(df['close']), 3)
        aux = Auxiliary(logger)
        aux.add_macd(df,12,26,9)
        aux.add_bbands(df)
        aux.add_relative_strength(df)
        aux.add_stock_cast(df)
        aux.add_iip(df)
        aux.add_mfi(df)
        df.reset_index(drop=True,inplace=True)

        # Use the GPU when available, otherwise the CPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        converter = torchvision.transforms.ToTensor()
        self.model = DQN(device, screen_height, screen_width, action_kind, 4).to(device)
        # The checkpoint is loaded into the DataParallel wrapper, so the wrapper
        # (and its "module." state_dict prefix) must stay.  Use at most the
        # first two GPUs, but never request a GPU id the host does not have.
        gpu_ids = list(range(min(2, torch.cuda.device_count())))
        self.model = nn.DataParallel(self.model, device_ids=gpu_ids).to(device)

        ckpt_path = "/home/yubank/pt/train_dqn_{0:02d}_{1}.pt".format(action_kind, device)
        if path.exists(ckpt_path):
            self.model.load_state_dict(torch.load(ckpt_path, map_location=device))

        self.model.eval()

        # Chart image -> (1, C, H, W) float tensor on ``device``.
        self.x = get_chart(df, idx, data_size, i_w = screen_width, i_h = screen_height)
        self.x = converter(self.x).unsqueeze(0).to(device).squeeze(1)

        # DQN has no stochastic layers, so in eval mode its action for this
        # input is computed once and reused by fbuy/fsell.
        self._action = None

    def get_dl_action(self):
        """Return the index of the highest Q-value for this chart (cached)."""
        if self._action is None:
            with torch.no_grad():
                output = self.model(self.x)
            _, action = torch.max(output, 1)
            self._action = action.cpu().numpy()[0]
        return self._action

    def fbuy(self):
        """Return True when the model's action is 0 (buy)."""
        return bool(self.get_dl_action() == 0)

    def fsell(self):
        """Return True when the model's action is ``action_kind - 1`` (sell)."""
        return bool(self.get_dl_action() == (action_kind - 1))

 

 

중간에 우여곡절이 조금 있었습니다.

 

학습된 모델을 pt파일로 저장을 한 다음 predict 클래스를 만들었습니다.

 

파이썬에도 static method(@staticmethod)가 있으므로, 클래스를 인스턴스 상태 없이 함수 묶음처럼 사용하는 것도 가능합니다. 이 predict 클래스는 생성자에서 모델과 입력 차트를 준비해 두고, 메서드로 매수/매도 여부를 조회하는 방식입니다.


적용 후 실제 Upbit에서 거래가 발생했고, 지금까지 한 차례도 손실 없이 매매하고 있습니다.

 

일단 1주일 동안 집에서 돌리면서 코로케이션을 알아보고 있는데, 최소 30만 원의 추가 비용이 발생합니다(랙 마운트 케이스 및 파워 구매, 4U 코로케이션 가격 대략 15만 원, 설치비 5만 원 — 설치비는 IDC 내부 회선 설치 및 전원 공사 비용이며, 제 서버 설치는 별도입니다).

 

수익률이 안 나온다면 코로케이션은 진행할 수 없을 것 같네요.

 

맺음말

궁금한 점이나 이상하다고 생각되시면 답글로 남겨 주세요.

세세하고 친절히 답변 드리겠습니다.

반응형
LIST

+ Recent posts