Skip to content

Commit 29a2c51

Browse files
committed
完成交易引擎的测试重构
1 parent e889147 commit 29a2c51

7 files changed

Lines changed: 51 additions & 69 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ tests/work/
1212

1313
_local_*
1414
demo/log
15+
__pycache__

quantdigger/technicals/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from six.moves import range
1010
import talib
11-
import matplotlib.finance as finance
1211

1312
from quantdigger.technicals.base import \
1413
TechnicalBase, ndarray, tech_init
@@ -152,7 +151,7 @@ def plot(self, widget):
152151
#emaslow = MA(x, nslow, type='exponential').value
153152
#emafast = MA(x, nfast, type='exponential').value
154153
#return emaslow, emafast, emafast - emaslow
155-
154+
156155
#def plot(self, widget):
157156
#self.widget = widget
158157
#fillcolor = 'darkslategrey'
@@ -176,6 +175,7 @@ def __init__(self, open, close, volume, name='volume',
176175
self.values = ndarray(volume)
177176

178177
def plot(self, widget):
178+
import matplotlib.finance as finance
179179
self.widget = widget
180180
finance.volume_overlay(widget, self.open, self.close, self.volume,
181181
self.colorup, self.colordown, self.width)

tests/sql_main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# -*- coding: utf-8 -*-
22
from test_data import *
33
from test_engine import *
4-
from test_future import *
5-
from test_stock import *
4+
from trading.test_future import *
5+
from trading.test_stock import *
66
from quantdigger import locd, set_config
77

88
if __name__ == '__main__':
99
set_config({ 'source': 'sqlite' })
1010
unittest.main()
11-
assert locd.source == 'sqlite'
11+
assert locd.source == 'sqlite'
1212
# 这里的代码不会被运行

tests/trading/stock_util.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
import datetime
55
import six
66
from quantdigger import settings
7-
from copy import deepcopy
87

98
window_size = 0
109
OFFSET = 0.6
10+
capital = 20000000
1111
bt1 = datetime.datetime.strptime("09:01:00", "%H:%M:%S").time()
1212
bt2 = datetime.datetime.strptime("09:02:00", "%H:%M:%S").time()
1313
bt3 = datetime.datetime.strptime("09:03:00", "%H:%M:%S").time()
@@ -25,10 +25,10 @@ def trade_closed_curbar(data, capital, long_margin, short_margin, volume_multipl
2525
"""
2626
assert(volume_multiple == 1 and long_margin == 1)
2727
UNIT = 1
28-
date_quantity= { }
28+
date_quantity= {}
2929
poscost = 0
3030
close_profit = 0
31-
equities = [] # 累计平仓盈亏
31+
equities = []
3232
dts = []
3333
cashes = []
3434

@@ -51,25 +51,25 @@ def trade_closed_curbar(data, capital, long_margin, short_margin, volume_multipl
5151
date_quantity.setdefault(curdate, 0)
5252
quantity = sum(date_quantity.values())
5353
poscost = (poscost * quantity + curprice *
54-
(1 + direction * settings['stock_commission'])*UNIT) / (quantity+UNIT)
54+
(1 + direction * settings['stock_commission']) * UNIT) / (quantity + UNIT)
5555
date_quantity[curdate] += UNIT
5656
elif curtime == st1:
5757
for posdate, quantity in six.iteritems(date_quantity):
5858
if posdate < curdate and quantity > 0: # 隔日交易
5959
open_close_profit = close_profit
6060
open_quantity = sum(date_quantity.values())
61-
close_profit += direction * (curprice*(1 - direction * settings['stock_commission'])-poscost) *\
62-
2*UNIT * volume_multiple
63-
date_quantity[posdate] -= 2*UNIT
61+
close_profit += direction * (curprice * (1 - direction *
62+
settings['stock_commission']) - poscost) * 2 * UNIT * volume_multiple
63+
date_quantity[posdate] -= 2 * UNIT
6464
elif posdate > curdate:
6565
assert(False)
6666
elif curtime == st2:
6767
for posdate, quantity in six.iteritems(date_quantity):
6868
if posdate < curdate and quantity > 0:
6969
open_close_profit = close_profit
7070
open_quantity = sum(date_quantity.values())
71-
close_profit += direction * (curprice*(1 - direction * settings['stock_commission'])-poscost) *\
72-
UNIT * volume_multiple
71+
close_profit += direction * (curprice * (1 - direction *
72+
settings['stock_commission']) - poscost) * UNIT * volume_multiple
7373
date_quantity[posdate] -= UNIT
7474
assert(date_quantity[posdate] == 0)
7575
elif posdate > curdate:
@@ -84,8 +84,8 @@ def trade_closed_curbar(data, capital, long_margin, short_margin, volume_multipl
8484
open_close_profit = close_profit
8585
open_quantity = sum(date_quantity.values())
8686
quantity = sum(date_quantity.values())
87-
close_profit += direction * (curprice*(1 - direction * settings['stock_commission'])-poscost) *\
88-
quantity * volume_multiple
87+
close_profit += direction * (curprice * (1 - direction * settings['stock_commission']) -
88+
poscost) * quantity * volume_multiple
8989
date_quantity.clear()
9090

9191
quantity = sum(date_quantity.values())
@@ -94,28 +94,27 @@ def trade_closed_curbar(data, capital, long_margin, short_margin, volume_multipl
9494
open_pos_profit = direction * (open_price - open_poscost) * open_quantity * volume_multiple
9595
open_posmargin = open_price * open_quantity * volume_multiple * long_margin
9696

97-
equities.append(capital+close_profit+pos_profit)
98-
cashes.append(equities[-1]-posmargin)
97+
equities.append(capital + close_profit + pos_profit)
98+
cashes.append(equities[-1] - posmargin)
9999
open_equities.append(capital + open_close_profit + open_pos_profit)
100-
open_cashes.append(open_equities[-1]-open_posmargin)
100+
open_cashes.append(open_equities[-1] - open_posmargin)
101101
dts.append(curdt)
102102
num += 1
103103
return equities, cashes, open_equities, open_cashes, dts
104104

105105

106106
def buy_monday_sell_friday(data, capital, long_margin, volume_multiple):
107107
""" 策略: 多头限价开仓且当根bar成交
108-
买入点: [bt1, bt2, bt3]
109-
当天卖出点: [st1, st2]
108+
周一买,周五卖
110109
"""
111110
assert(volume_multiple == 1 and long_margin == 1)
112111
UNIT = 1
113112
poscost = 0
114113
quantity = 0
115114
close_profit = 0
116-
equities = {} # 累计平仓盈亏
115+
equities = {}
117116
dts = []
118-
cashes = { }
117+
cashes = {}
119118

120119
open_poscost = 0
121120
open_cashes = {}
@@ -131,13 +130,13 @@ def buy_monday_sell_friday(data, capital, long_margin, volume_multiple):
131130
open_poscost = poscost
132131
open_quantity = quantity
133132
open_close_profit = close_profit
134-
poscost = (poscost * quantity + curprice*(1+settings['stock_commission'])*UNIT) / (quantity+UNIT)
133+
poscost = (poscost * quantity + curprice * (1 + settings['stock_commission']) * UNIT)\
134+
/ (quantity + UNIT)
135135
quantity += UNIT
136136
elif weekday == 4 and quantity > 0:
137137
open_close_profit = close_profit
138138
open_quantity = quantity
139-
close_profit += (curprice*(1-settings['stock_commission'])-poscost) *\
140-
quantity * volume_multiple
139+
close_profit += (curprice * (1 - settings['stock_commission']) - poscost) * quantity * volume_multiple
141140
quantity = 0
142141
else:
143142
open_close_profit = close_profit

tests/trading/temp.py

Lines changed: 0 additions & 5 deletions
This file was deleted.

tests/trading/test_future.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,22 @@
55
import os
66

77
from six.moves import range
8-
from future_util import (
8+
9+
from quantdigger.datastruct import TradeSide, Contract, Direction
10+
from quantdigger import (
11+
add_strategy,
12+
set_symbols,
13+
Strategy,
14+
run
15+
)
16+
17+
from .future_util import (
918
trade_closed_curbar,
1019
in_closed_nextbar,
1120
out_closed_nextbar,
1221
entries_maked_nextbar,
1322
market_trade_closed_curbar,
1423
OFFSET,
15-
)
16-
from source import (
1724
capital,
1825
bt1,
1926
bt2,
@@ -23,15 +30,6 @@
2330
st3,
2431
)
2532

26-
from quantdigger.datastruct import TradeSide, Contract, Direction
27-
28-
from quantdigger import (
29-
add_strategy,
30-
set_symbols,
31-
Strategy,
32-
run
33-
)
34-
3533
fname = os.path.join(os.getcwd(), 'data', '1MINUTE', 'TEST', 'FUTURE.csv')
3634
source = pd.read_csv(fname, parse_dates=True, index_col=0)
3735

tests/trading/test_stock.py

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,26 +10,24 @@
1010
Strategy,
1111
run
1212
)
13-
from source import (
13+
14+
from .stock_util import (
15+
buy_monday_sell_friday,
16+
trade_closed_curbar,
1417
capital,
1518
bt1,
1619
bt2,
1720
bt3,
1821
st1,
1922
st2
2023
)
21-
from stock_util import (
22-
buy_monday_sell_friday,
23-
trade_closed_curbar,
24-
)
2524

2625
fname = os.path.join(os.getcwd(), 'data', '1MINUTE', 'TEST', 'STOCK.csv')
2726
source = pd.read_csv(fname, parse_dates=True, index_col=0)
2827

29-
3028
class TestOneDataOneCombinationStock(unittest.TestCase):
3129
"""
32-
ctx.pos 可平仓位
30+
ctx.pos 可平仓位, 当天买隔天卖,当天不能卖。
3331
"""
3432

3533
def test_case(self):
@@ -58,10 +56,11 @@ def on_bar(self, ctx):
5856
self.equities.append(ctx.equity())
5957

6058
def test(self, test):
61-
equities, cashes, open_equities, open_casheses, dts = trade_closed_curbar(source, capital*0.3, lmg, smg, multi, 1)
59+
equities, cashes, open_equities, open_casheses, dts =\
60+
trade_closed_curbar(source, capital * 0.3, lmg, smg, multi, 1)
6261

6362
test.assertTrue(len(self.cashes) == len(cashes), 'cash接口测试失败!')
64-
for i in range(0, len(self.cashes)): # 最后一根强平了无法比较
63+
for i in range(0, len(self.cashes)):
6564
test.assertAlmostEqual(self.cashes[i], open_casheses[i])
6665
test.assertAlmostEqual(self.equities[i], open_equities[i])
6766

@@ -70,7 +69,6 @@ def test(self, test):
7069
test.assertAlmostEqual(hd['equity'], equities[i])
7170
test.assertAlmostEqual(hd['cash'], cashes[i])
7271

73-
7472
class DemoStrategy2(Strategy):
7573
""" 限价买多卖空的策略 """
7674

@@ -110,7 +108,6 @@ def test(self, test):
110108
test.assertAlmostEqual(self.cashes[i], cashes[i])
111109

112110
class DemoStrategy3(Strategy):
113-
""" 测试平仓未成交时的持仓,撤单后的持仓,撤单。 """
114111
def on_init(self, ctx):
115112
"""初始化数据"""
116113
pass
@@ -189,7 +186,6 @@ def test(self, test, profile):
189186
count = 0
190187
all_holdings0 = profile.all_holdings(0)
191188
for i, hd in enumerate(all_holdings0):
192-
# 刚好最后一根没持仓,无需考虑强平, 见weekday输出
193189
dt = hd['datetime']
194190
if dt in cashes:
195191
test.assertAlmostEqual(hd['cash'], cashes[dt])
@@ -199,11 +195,10 @@ def test(self, test, profile):
199195
count += 1
200196
else:
201197
# 两支股票的混合,总数据长度和source不一样。
202-
test.assertAlmostEqual(all_holdings0[i-1]['cash'], hd['cash'])
203-
test.assertAlmostEqual(all_holdings0[i-1]['equity'], hd['equity'])
198+
test.assertAlmostEqual(all_holdings0[i - 1]['cash'], hd['cash'])
199+
test.assertAlmostEqual(all_holdings0[i - 1]['equity'], hd['equity'])
204200
test.assertTrue(count == len(dts))
205201

206-
207202
class DemoStrategy2(Strategy):
208203
""" 选股,并且时间没对齐的日线数据。 """
209204
def __init__(self, name):
@@ -235,21 +230,19 @@ def on_bar(self, ctx):
235230
self.tosells = []
236231

237232
def test(self, test, profile):
238-
# test Strategy2
239233
fname = os.path.join(os.getcwd(), 'data', '1DAY', 'SH', '600521.csv')
240234
source = pd.read_csv(fname, parse_dates=True, index_col=0)
241235
fname = os.path.join(os.getcwd(), 'data', '1DAY', 'SH', '600522.csv')
242236
source2 = pd.read_csv(fname, parse_dates=True, index_col=0)
243237
equities0, cashes0, open_equities0, open_cashes0, dts = \
244238
buy_monday_sell_friday(source, capital * 0.3 / 2, 1, 1)
245239
equities1, cashes1, open_equities1, open_cashes1, dts = \
246-
buy_monday_sell_friday(source2, capital*0.3/2, 1, 1)
240+
buy_monday_sell_friday(source2, capital * 0.3 / 2, 1, 1)
247241
last_equity0 = 0
248-
last_equity1 = 0
242+
last_equity1 = 0
249243
last_cash0 = 0
250244
last_cash1 = 0
251245
for i, hd in enumerate(profile.all_holdings(1)):
252-
# 刚好最后一根没持仓,无需考虑强平, 见weekday输出
253246
dt = hd['datetime']
254247
equity = 0
255248
cash = 0
@@ -271,14 +264,13 @@ def test(self, test, profile):
271264
if dt in equities1:
272265
equity += equities1[dt]
273266
cash += cashes1[dt]
274-
last_equity1 = equities1[dt]
267+
last_equity1 = equities1[dt]
275268
last_cash1 = cashes1[dt]
276269
open_equity += open_equities1[dt]
277270
open_cash += open_cashes1[dt]
278271
else:
279272
equity += last_equity1
280273
cash += last_cash1
281-
# 新的ctx.cash()将会以最近数据的收盘价为准。
282274
open_equity += last_equity1
283275
open_cash += last_cash1
284276

@@ -287,7 +279,6 @@ def test(self, test, profile):
287279
test.assertAlmostEqual(self._equities[dt], open_equity)
288280
test.assertAlmostEqual(self._cashes[dt], open_cash)
289281

290-
291282
class DemoStrategy3(Strategy):
292283
""" 测试平仓未成交时的持仓,撤单后的持仓,撤单。 """
293284
def on_init(self, ctx):
@@ -304,14 +295,12 @@ def on_bar(self, ctx):
304295
profile = add_strategy([b1, b2, b3], {
305296
'capital': capital,
306297
'ratio': [0.3, 0.3, 0.4]
307-
})
298+
}
299+
)
308300
run()
309301
b1.test(self, profile)
310302
b2.test(self, profile)
311303

312304

313-
314-
315-
316305
if __name__ == '__main__':
317306
unittest.main()

0 commit comments

Comments
 (0)