Skip to content

Commit 959a92c

Browse files
committed
修复模拟器bug,添加模拟器的测试案例
1 parent 1524108 commit 959a92c

14 files changed

Lines changed: 5379 additions & 246 deletions

File tree

quantdigger/datastruct.py

Lines changed: 97 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,20 @@ def arg_to_type(self, arg):
4545
else:
4646
return arg
4747

48+
@classmethod
49+
def type_to_str(self, type):
50+
type2str = {
51+
self.BUY: '多头开仓',
52+
self.SHORT: '空头开仓',
53+
self.COVER: '空头平仓',
54+
self.SELL: '多头平仓',
55+
self.COVER_TODAY: '空头平今',
56+
self.SELL_TODAY: '多头平今',
57+
self.KAI: '开仓',
58+
self.PING: '平仓',
59+
}
60+
return type2str[type]
61+
4862

4963
class Captial(object):
5064
""" 账号资金
@@ -72,11 +86,11 @@ def __init__(self, dt, contract, type_, side, direction, price, quantity):
7286
class PriceType(object):
7387
""" 下单类型
7488
75-
:ivar LMT: 限价单 - 0.
76-
:ivar MKT: 市价单 - 1.
89+
:ivar LMT: 限价单 - 1.
90+
:ivar MKT: 市价单 - 0.
7791
"""
78-
LMT = 0
79-
MKT = 1
92+
LMT = 1
93+
MKT = 0
8094

8195
@classmethod
8296
def arg_to_type(self, arg):
@@ -91,6 +105,12 @@ def arg_to_type(self, arg):
91105
else:
92106
return arg
93107

108+
@classmethod
109+
def type_to_str(self, type):
110+
type2str = { self.LMT: 'LMT',
111+
self.MKT: 'MKT'}
112+
return type2str[type]
113+
94114
class HedgeType(object):
95115
""" 下单类型
96116
@@ -112,6 +132,13 @@ def arg_to_type(self, arg):
112132
else:
113133
return arg
114134

135+
@classmethod
136+
def type_to_str(self, type):
137+
type2str = { self.SPEC: 'SPEC',
138+
self.HEDG: 'HEDG'}
139+
140+
return type2str[type]
141+
115142
class Direction(object):
116143
"""
117144
多空方向。
@@ -137,8 +164,8 @@ def arg_to_type(self, arg):
137164

138165
@classmethod
139166
def type_to_str(self, type):
140-
type2str = { self.LONG: 'LONG',
141-
self.SHORT: 'SHORT'}
167+
type2str = { self.LONG: 'long',
168+
self.SHORT: 'short'}
142169

143170
return type2str[type]
144171

@@ -175,12 +202,23 @@ def __init__(self, order=None):
175202
self.margin_ratio = 1
176203

177204
def __hash__(self):
178-
if hasattr(self, '_hash'):
179-
return self._hash
180-
else:
205+
try:
206+
return self._hash
207+
except AttributeError:
181208
self._hash = hash(self.id)
182209
return self._hash
183210

211+
def __eq__(self, r):
212+
return self._hash == r._hash
213+
214+
def __str__(self):
215+
rst = " id: %s\n contract: %s\n direction: %s\n price: %f\n quantity: %d\n side: %s\n datetime: %s\n price_type: %s\n hedge_type: %s\n margin_ratio: %f" % \
216+
(self.id, self.contract, Direction.type_to_str(self.direction),
217+
self.price, self.quantity, TradeSide.type_to_str(self.side),
218+
self.datetime, PriceType.type_to_str(self.price_type),
219+
HedgeType.type_to_str(self.hedge_type), self.margin_ratio)
220+
return rst
221+
184222

185223

186224
class OrderID(object):
@@ -216,6 +254,9 @@ def __gt__(self, other):
216254

217255
def __ge__(self, other):
218256
return self.id >= other.id
257+
258+
def __str__(self):
259+
return str(self.id)
219260

220261

221262
class Order(object):
@@ -262,12 +303,23 @@ def print_order(self):
262303
pass
263304

264305
def __hash__(self):
265-
if hasattr(self, '_hash'):
266-
return self._hash
267-
else:
306+
try:
307+
return self._hash
308+
except AttributeError:
268309
self._hash = hash(self.id)
269310
return self._hash
270311

312+
def __str__(self):
313+
rst = " id: %s\n contract: %s\n direction: %s\n price: %f\n quantity: %d\n side: %s\n datetime: %s\n price_type: %s\n hedge_type: %s\n margin_ratio: %f" % \
314+
(self.id, self.contract, Direction.type_to_str(self.direction),
315+
self.price, self.quantity, TradeSide.type_to_str(self.side),
316+
self.datetime, PriceType.type_to_str(self.price_type),
317+
HedgeType.type_to_str(self.hedge_type), self.margin_ratio)
318+
return rst
319+
320+
def __eq__(self, r):
321+
return self._hash == r._hash
322+
271323

272324
class Contract(object):
273325
""" 合约。
@@ -285,19 +337,23 @@ def __init__(self, str_contract):
285337
assert False
286338
self.exch_type = exchange # 用'stock'表示中国股市
287339
self.code = code
340+
## @TODO 从代码中计算
288341
self._is_stock = True if exchange == 'stock' else False
289342

290343
def __str__(self):
291344
""""""
292345
return "%s.%s" % (self.code, self.exch_type)
293346

294347
def __hash__(self):
295-
if hasattr(self, '_hash'):
296-
return self._hash
297-
else:
348+
try:
349+
return self._hash
350+
except AttributeError:
298351
self._hash = hash(self.__str__())
299352
return self._hash
300353

354+
def __eq__(self, r):
355+
return self._hash == r._hash
356+
301357
@property
302358
def is_stock(self):
303359
""" 是否是股票"""
@@ -358,13 +414,6 @@ def length(self):
358414
def __str__(self):
359415
return "%d.%s" % (self._length, self._type)
360416

361-
def __hash__(self):
362-
if hasattr(self, '_hash'):
363-
return self._hash
364-
else:
365-
self._hash = hash(self.__str__())
366-
return self._hash
367-
368417

369418
class PContract(object):
370419
""" 特定周期的合约。
@@ -385,13 +434,35 @@ def from_string(self, strpcon):
385434
t = strpcon.split('-')
386435
return PContract(Contract(t[0]), Period(t[1]))
387436

437+
def __hash__(self):
438+
try:
439+
return self._hash
440+
except AttributeError:
441+
self._hash = hash(self.__str__())
442+
return self._hash
443+
444+
def __eq__(self, r):
445+
return self._hash == r._hash
446+
447+
448+
class PositionKey(object):
449+
def __init__(self, contract, direction):
450+
self.contract = contract
451+
self.direction = direction
452+
453+
def __str__(self):
454+
return "%s_%s" % (self.contract, str(self.direction))
455+
388456
def __hash__(self):
389457
if hasattr(self, '_hash'):
390458
return self._hash
391459
else:
392-
self._hash = hash(self.__str__())
460+
self._hash = hash((self.contract, self.direction))
393461
return self._hash
394462

463+
def __eq__(self, r):
464+
return self._hash == r._hash
465+
395466

396467
class Position(object):
397468
""" 单笔仓位信息。
@@ -443,19 +514,10 @@ def position_margin(self, new_price):
443514
price = self.cost if self.contract.is_stock else new_price
444515
return price * self.quantity * self.margin_ratio
445516

446-
def __hash__(self):
447-
if hasattr(self, '_hash'):
448-
return self._hash
449-
else:
450-
self._hash = hash(self.contract)
451-
return self._hash
452-
453517
def __str__(self):
454-
rst = """
455-
Position:
456-
cost - %f
457-
quantity - %d
458-
""" % (self.cost, self.quantity)
518+
rst = " contract: %s\n direction: %s\n cost: %f\n quantity: %d\n " % \
519+
(self.contract, Direction.type_to_str(self.direction),
520+
self.cost, self.quantity)
459521
return rst
460522

461523

quantdigger/demo/plot_strategy.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ def on_init(self, ctx):
3030
def on_bar(self, ctx):
3131
if ctx.curbar > 20:
3232
if ctx.ma10[1] < ctx.ma20[1] and ctx.ma10 > ctx.ma20:
33-
ctx.buy('long', ctx.close, 1)
33+
ctx.buy(ctx.close, 1)
3434
elif ctx.position() > 0 and ctx.ma10[1] > ctx.ma20[1] and \
3535
ctx.ma10 < ctx.ma20:
36-
ctx.sell('long', ctx.close, 1)
36+
ctx.sell(ctx.close, 1)
3737

3838
boll['upper'].append(ctx.boll['upper'][0])
3939
boll['middler'].append(ctx.boll['middler'][0])
@@ -58,10 +58,10 @@ def on_init(self, ctx):
5858
def on_bar(self, ctx):
5959
if ctx.curbar > 10:
6060
if ctx.ma5[1] < ctx.ma10[1] and ctx.ma5 > ctx.ma10:
61-
ctx.buy('long', ctx.close, 1)
61+
ctx.buy(ctx.close, 1)
6262
elif ctx.position() > 0 and ctx.ma5[1] > ctx.ma10[1] and \
6363
ctx.ma5 < ctx.ma10:
64-
ctx.sell('long', ctx.close, 1)
64+
ctx.sell(ctx.close, 1)
6565

6666
def on_final(self, ctx):
6767
return

quantdigger/digger/plotting.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def xticks_to_display(data_length):
1616
return xticks
1717

1818

19-
def plot_strategy(price_data, indicators, signals):
19+
def plot_strategy(price_data, indicators={ }, deals=[]):
2020
"""
2121
显示回测结果。
2222
"""
@@ -31,8 +31,9 @@ def plot_strategy(price_data, indicators, signals):
3131
kwindow = widgets.CandleWindow("kwindow", price_data, 100, 50)
3232
frame.add_widget(0, kwindow, True)
3333
## 交易信号。
34-
signal = mplots.TradingSignalPos(price_data, signals, lw=2)
35-
frame.add_indicator(0, signal)
34+
if deals:
35+
signal = mplots.TradingSignalPos(price_data, deals, lw=2)
36+
frame.add_indicator(0, signal)
3637
## @bug indicators导致的双水平线!
3738
## @todo 完mplot_demo上套。
3839
#frame.add_indicator(0, Volume(None, price_data.open, price_data.close, price_data.volume))

0 commit comments

Comments
 (0)