Skip to content

Commit 87b3e97

Browse files
committed
add plot_line interface to context
1 parent fd19246 commit 87b3e97

12 files changed

Lines changed: 286 additions & 166 deletions

File tree

demo/plot_strategy.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,24 @@ def on_init(self, ctx):
2525
ctx.ma100 = MA(ctx.close, 100, 'ma100', 'y', 2) #, 'ma200', 'b', '1')
2626
ctx.ma200 = MA(ctx.close, 200, 'ma200', 'b', 2) #, 'ma200', 'b', '1')
2727
#ctx.boll = BOLL(ctx.close, 20)
28-
ctx.ma2 = NumberSeries()
28+
ctx.dt = DateTimeSeries()
29+
ctx.month_price = NumberSeries()
2930

3031
#def on_symbol(self, ctx):
3132
#pass
3233

3334
def on_bar(self, ctx):
34-
if ctx.curbar > 1:
35-
ctx.ma2.update((ctx.close + ctx.close[1])/2)
35+
ctx.dt.update(ctx.datetime)
36+
#print ctx.dt[1].isocalendar()[1], ctx.dt[0].isocalendar()[1]
37+
if ctx.dt[1].month != ctx.dt[0].month:
38+
ctx.month_price.update(ctx.close)
3639
if ctx.curbar > 200:
3740
if ctx.pos() == 0 and ctx.ma100[2] < ctx.ma200[2] and ctx.ma100[1] > ctx.ma200[1]:
3841
ctx.buy(ctx.close, 1)
3942
elif ctx.pos() > 0 and ctx.ma100[2] > ctx.ma200[2] and \
4043
ctx.ma100[1] < ctx.ma200[1]:
4144
ctx.sell(ctx.close, ctx.pos())
42-
45+
ctx.plot_line("month_price", ctx.curbar, ctx.month_price, 'b-', lw=2)
4346
#boll['upper'].append(ctx.boll['upper'][0])
4447
#boll['middler'].append(ctx.boll['middler'][0])
4548
#boll['lower'].append(ctx.boll['lower'][0])
@@ -67,6 +70,7 @@ def on_bar(self, ctx):
6770
elif ctx.pos() > 0 and ctx.ma50[2] > ctx.ma100[2] and \
6871
ctx.ma50[1] < ctx.ma100[1]:
6972
ctx.sell(ctx.close, ctx.pos())
73+
7074
return
7175

7276
def on_exit(self, ctx):
@@ -94,7 +98,8 @@ def on_exit(self, ctx):
9498
curve1 = finance.create_equity_curve(profile.all_holdings(1))
9599
curve = finance.create_equity_curve(profile.all_holdings())
96100
plotting.plot_strategy(profile.data(0), profile.technicals(0),
97-
profile.deals(0), curve0.equity.values)
101+
profile.deals(0), curve0.equity.values,
102+
profile.marks(0))
98103
#plotting.plot_curves([curve0.equity, curve1.equity, curve.equity],
99104
#colors=['r', 'g', 'b'],
100105
#names=[profile.name(0), profile.name(1), 'A0'])

quantdigger/digger/plotting.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import matplotlib.pyplot as plt
55
from matplotlib.ticker import Formatter
66
from quantdigger.widgets.mplotwidgets import widgets, mplots
7-
from quantdigger.technicals import EquityCurve, Volume
7+
from quantdigger.technicals import Line, LineWithX, Volume
88

99

1010
def xticks_to_display(data_length):
@@ -18,7 +18,7 @@ def xticks_to_display(data_length):
1818
return xticks
1919

2020

21-
def plot_strategy(price_data, indicators={}, deals=[], curve=[]):
21+
def plot_strategy(price_data, indicators={}, deals=[], curve=[], marks=[]):
2222
"""
2323
显示回测结果。
2424
"""
@@ -38,12 +38,37 @@ def plot_strategy(price_data, indicators={}, deals=[], curve=[]):
3838
signal = mplots.TradingSignalPos(price_data, deals, lw=2)
3939
frame.add_indicator(0, signal)
4040
if len(curve) > 0:
41-
curve = EquityCurve(curve)
42-
frame.add_indicator(0, curve, True)
41+
curve = Line(curve)
42+
#frame.add_indicator(0, curve, True)
4343
frame.add_indicator(1, Volume(price_data.open, price_data.close, price_data.volume))
4444
## 添加指标
4545
for name, indic in indicators.iteritems():
4646
frame.add_indicator(0, indic)
47+
# 绘制标志
48+
if marks:
49+
if marks[0]:
50+
# plot lines
51+
for name, values in marks[0].iteritems():
52+
v = values[0]
53+
line_pieces = [[v[0]], [v[1]], v[2], v[3], v[4]]
54+
line = []
55+
for v in values[1: ]:
56+
## @TODO 如果是带“点”的,以点的特征聚类,会减少indicator对象的数目
57+
if v[2] != line_pieces[2] or v[3] != line_pieces[3] or v[4] != line_pieces[4]:
58+
line.append(line_pieces)
59+
line_pieces = [[v[0]], [v[1]], v[2], v[3], v[4]]
60+
else:
61+
line_pieces[0].append(v[0])
62+
line_pieces[1].append(v[1])
63+
line.append(line_pieces)
64+
for v in line:
65+
## @TODO 这里的sytle明确指出有点奇怪,不一致。
66+
curve = LineWithX(v[0], v[1], style=v[2], lw=v[3], ms=v[4])
67+
frame.add_indicator(0, curve, False)
68+
if marks[1]:
69+
# plot texts
70+
for name, values in marks[0].iteritems():
71+
print name
4772
frame.draw_widgets()
4873
plt.show()
4974

quantdigger/engine/api.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,6 @@ def connect(self):
1515
""" 连接器 """
1616
pass
1717

18-
#@abstractmethod
19-
#def register_handlers(self, handlers):
20-
#""" 注册回调函数 """
21-
#pass
22-
2318
@abstractmethod
2419
def query_contract(self, contract, sync=False):
2520
""" 合约查询 """

quantdigger/engine/blotter.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,17 @@
2121

2222
class Profile(object):
2323
""" 组合结果 """
24-
def __init__(self, blotters, dcontexts, strpcon, i):
25-
self._blts = blotters # 组合内所有策略的blotter
24+
def __init__(self, scontexts, dcontexts, strpcon, i):
25+
"""
26+
27+
Args:
28+
scontexts (list): 策略上下文集合
29+
dcontexts (list): 数据上下文集合
30+
strpcon (str): 主合约
31+
i (int): 当前profile所对应的组合索引
32+
"""
33+
self._marks = [ctx.marks for ctx in scontexts]
34+
self._blts = [ctx.blotter for ctx in scontexts]
2635
self._dcontexts = {}
2736
self._ith_comb = i # 对应于第几个组合
2837
self._main_pcontract = strpcon
@@ -93,17 +102,6 @@ def all_holdings(self, j=None):
93102
hd['equity'] += rhd['equity']
94103
return holdings
95104

96-
#def current_positions(self, j):
97-
#""" 当前持仓
98-
99-
#Args:
100-
#j (int): 第j个策略
101-
102-
#Returns:
103-
#dict. { Contract: Position }
104-
#"""
105-
#return self._blts[j].current_positions.values()
106-
107105
def holding(self, j=None):
108106
""" 当前账号情况
109107
@@ -126,8 +124,15 @@ def holding(self, j=None):
126124
holdings['history_profit'] += rhd['history_profit']
127125
return holdings
128126

127+
def marks(self, j=None):
128+
""" 返回第j个策略的绘图标志集合 """
129+
if j is not None:
130+
return self._marks[j]
131+
return self._marks[0]
132+
129133
def technicals(self, j=None, strpcon=None):
130134
# @TODO test case
135+
# @TODO 没必要针对不同的strpcon做分析
131136
""" 返回第j个策略的指标, 默认返回组合的所有指标。
132137
133138
Args:
@@ -353,9 +358,6 @@ def update_status(self, dt, at_baropen, append):
353358
self.holding['cash'] = dh['cash']
354359
self.holding['equity'] = dh['equity']
355360
self.holding['position_profit'] = pos_profit
356-
#if self.name == 'A2' and append == False:
357-
#print dh['equity'], "**" , self._datetime
358-
359361
if append:
360362
self._all_holdings.append(dh)
361363
else:
@@ -389,8 +391,6 @@ def update_signal(self, event):
389391
pos = self.positions[
390392
PositionKey(order.contract, order.direction)]
391393
pos.closable -= order.quantity
392-
#print "Receive %d signals!" % len(event.orders)
393-
#self.generate_naive_order(event.orders)
394394

395395
def update_fill(self, event):
396396
""" 处理委托单成交事件。 """
@@ -441,10 +441,6 @@ def _update_holding(self, trans):
441441
flag = 1 if trans.direction == Direction.LONG else -1
442442
profit = (trans.price-self.positions[poskey].cost) * \
443443
trans.quantity * flag * trans.volume_multiple
444-
#print profit, trans.price, trans.quantity
445-
#if self.name == 'A1': # 平仓调试
446-
#print "***********"
447-
#print self._datetime, profit
448444
self.holding['history_profit'] += profit
449445
self._all_transactions.append(trans)
450446

quantdigger/engine/context.py

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -204,13 +204,16 @@ class StrategyContext(object):
204204
:ivar name: 策略名
205205
:ivar blotter: 订单管理
206206
:ivar exchange: 价格撮合器
207+
:ivar marks: 绘图标志集合
207208
"""
208209
def __init__(self, name, settings={}):
209210
self.events_pool = EventsPool()
210211
# @TODO merge blotter and exchange
211212
self.blotter = SimpleBlotter(name, self.events_pool, settings)
212213
self.exchange = Exchange(name, self.events_pool, strict=True)
213214
self.name = name
215+
# 0: line_marks, 1: text_marks
216+
self.marks = [{}, {}]
214217
self._entry_orders = []
215218
self._exit_orders = []
216219
self._datetime = None
@@ -246,6 +249,35 @@ def process_trading_events(self, at_baropen):
246249
self.events_pool.put(OnceEvent())
247250
self._process_trading_events(at_baropen, append)
248251

252+
def plot_line(self, name, x, y, styles, lw=1, ms=10):
253+
""" 绘制曲线
254+
255+
Args:
256+
name (str): 标志名称
257+
x (datetime): 时间坐标
258+
y (float): y坐标
259+
styles (str): 控制颜色,线的风格,点的风格
260+
lw (int): 线宽
261+
ms (int): 点的大小
262+
"""
263+
mark = self.marks[0].setdefault(name, [])
264+
mark.append((x, y, styles, lw, ms))
265+
266+
def plot_text(self, name, x, y, text, color='black', size=10, rotation=0):
267+
""" 绘制文本
268+
269+
Args:
270+
name (str): 标志名称
271+
x (float): x坐标
272+
y (float): y坐标
273+
text (str): 文本内容
274+
color (str): 颜色
275+
size (int): 字体大小
276+
rotation (float): 旋转角度
277+
"""
278+
mark = self.marks[1].setdefault(name, [])
279+
mark.append((x, y, text, color, size, rotation))
280+
249281
def _process_trading_events(self, at_baropen, append):
250282
""""""
251283
while True:
@@ -257,9 +289,9 @@ def _process_trading_events(self, at_baropen, append):
257289
except IndexError:
258290
break
259291
else:
260-
#if event.type == 'MARKET':
261-
##strategy.calculate_signals(event)
262-
#port.update_timeindex(event)
292+
# if event.type == 'MARKET':
293+
# strategy.calculate_signals(event)
294+
# port.update_timeindex(event)
263295
if event.type == Event.SIGNAL:
264296
assert(not at_baropen)
265297
self.blotter.update_signal(event)
@@ -414,9 +446,9 @@ def time_aligned(self):
414446
return (self._cur_data_context.datetime[0] <= self.ctx_datetime and
415447
self._cur_data_context.last_date <= self.ctx_datetime)
416448
## 第一根是必须运行
417-
#return (self._cur_data_context.datetime[0] <= self.ctx_dt_series and
418-
#self._cur_data_context.ctx_dt_series <= self.ctx_dt_series) or \
419-
#self._cur_data_context.curbar == 0
449+
# return (self._cur_data_context.datetime[0] <= self.ctx_dt_series and
450+
# self._cur_data_context.ctx_dt_series <= self.ctx_dt_series) or \
451+
# self._cur_data_context.curbar == 0
420452

421453
def switch_to_strategy(self, i, j, trading=False):
422454
self._trading = trading
@@ -498,7 +530,7 @@ def symbol(self):
498530

499531
@property
500532
def curbar(self):
501-
""" 当前是第几根k线, 从0开始 """
533+
""" 当前是第几根k线, 从1开始 """
502534
if self.on_bar:
503535
return self.step + 1
504536
else:
@@ -717,6 +749,12 @@ def profit(self, contract=None):
717749
#return
718750
pass
719751

752+
def plot_line(self, name, x, y, styles, lw=1, ms=10):
753+
self._cur_strategy_context.plot_line(name, x-1, float(y), styles, lw, ms)
754+
755+
def plot_text(self, name, x, y, text, color='black', size=10, rotation=0):
756+
self._cur_strategy_context.plot_text(name, x-1, float(y), text, color, size, rotation)
757+
720758
def day_profit(self, contract=None):
721759
""" 当前持仓的浮动盈亏 """
722760
#if not self._trading:

quantdigger/engine/event.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# -*- coding: utf-8 -*-
2-
# event.py
3-
#from flufl.enum import Enum
2+
# from flufl.enum import Enum
43

54

65
class EventsPool(object):

quantdigger/engine/execute_unit.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,13 @@ def __init__(self,
2121
"""
2222
Args:
2323
pcontracts (list): list of pcontracts(string)
24+
2425
dt_start (datetime/str): start time of all pcontracts
26+
2527
dt_end (datetime/str): end time of all pcontracts
28+
2629
n (int): last n bars
30+
2731
spec_date (dict): time range for specific pcontracts
2832
"""
2933
self.finished_data = []
@@ -122,8 +126,7 @@ def add_comb(self, comb, settings):
122126
# logger.debug(iset)
123127
ctxs.append(StrategyContext(s.name, iset))
124128
self.context.add_strategy_context(ctxs)
125-
blotters = [ctx.blotter for ctx in ctxs]
126-
return blotter.Profile(blotters,
129+
return blotter.Profile(ctxs,
127130
self._all_data,
128131
self.pcontracts[0],
129132
len(self._combs)-1)

quantdigger/engine/series.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,24 +34,16 @@ def reset_data(self, data, wsize):
3434
# 序列系统变量
3535
self.data = data
3636

37-
#def _sort_data(self):
38-
#temp = self.data[self._index]
39-
#max_index = self.window_size-1
40-
#for i in range(0, max_index):
41-
#ordered_index = (self.curbar - (max_index-self._index)) % self._window_size
42-
#self.data[self._index] = self.data[ordered_index]
43-
#self._index = ordered_index
44-
#self.data[max_index] = temp
45-
#self._index = max_index
46-
#return
47-
4837
def update_curbar(self, curbar):
4938
""" 更新当前Bar索引 """
5039
self.curbar = curbar
5140

5241
def update(self, v):
5342
""" 更新最后一个值 """
54-
self.data[self.curbar] = v
43+
if isinstance(v, SeriesBase):
44+
self.data[self.curbar] = v[0]
45+
else:
46+
self.data[self.curbar] = v
5547

5648
def __len__(self):
5749
return len(self.data)
@@ -202,6 +194,7 @@ def __call__(self, index):
202194
class DateTimeSeries(SeriesBase):
203195
""" 时间序列变量 """
204196
DEFAULT_VALUE = datetime.datetime(1980, 1, 1)
197+
value_type = datetime.datetime
205198

206199
def __init__(self, data=[], name='DateTimeSeries'):
207200
super(DateTimeSeries, self).__init__(data, name,

0 commit comments

Comments
 (0)