Skip to content

Commit 9bfeca7

Browse files
committed
add tests
1 parent fc61ad5 commit 9bfeca7

30 files changed

Lines changed: 14671 additions & 625 deletions

quantdigger/datasource/data.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# -*- coding: utf-8 -*-
22
import os
33
import pandas as pd
4+
import time
45
from datetime import datetime, timedelta
56
from quantdigger.datasource.source import CsvSource, SqlLiteSource
67
from quantdigger.datasource.datautil import tick2period
@@ -19,16 +20,35 @@ class LocalData(object):
1920
def __init__(self):
2021
"""
2122
"""
22-
self._csv = CsvSource(''.join([os.getcwd(), os.sep, 'data', os.sep]))
23-
self._sql = SqlLiteSource(''.join([os.getcwd(), os.sep, 'data', os.sep, 'digger.db']))
23+
self._csv = CsvSource(os.path.join(os.getcwd(), 'data'))
24+
self._sql = SqlLiteSource(os.path.join(os.getcwd(), 'data', 'digger.db'))
2425
self._src = self._sql # 设置数据源
2526

2627
def load_bars(self, pcontract, dt_start, dt_end, window_size):
28+
""" 获取本地历史数据
29+
30+
Args:
31+
pcontract (PContract): 周期合约
32+
33+
dt_start (datetime): 数据的开始时间
34+
35+
dt_end (datetime): 数据的结束时间
36+
37+
window_size (int): 窗口大小,0表示大小为数据长度。
38+
39+
Returns:
40+
SourceWrapper. 数据
41+
"""
2742
if pcontract.contract.exch_type == 'stock':
2843
return []
2944
else:
3045
return self._src.load_bars(pcontract, dt_start, dt_end, window_size);
3146

47+
def load_data(self, pcontract, dt_start=datetime(1980,1,1),
48+
dt_end=datetime(2100,1,1), window_size=0):
49+
""" 返回DataFrame数据 """
50+
return self.load_bars(pcontract, dt_start, dt_end, window_size).data
51+
3252
def loadTickData(self):
3353
raise NotImplementedError
3454

@@ -100,7 +120,7 @@ def load_bars(self, pcontract, dt_start=datetime(1980,1,1),
100120
dt_end(datetime): 结束时间
101121
102122
Returns:
103-
DataFrame.
123+
SourceWrapper.
104124
"""
105125
if type(dt_start) == str:
106126
dt_start = datetime.strptime(dt_start, "%Y-%m-%d")

quantdigger/datasource/source.py

Lines changed: 93 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,78 @@
11
# -*- coding: utf-8 -*-
2+
import csv
3+
import datetime
24
import os
35
import pandas as pd
46
import string
5-
from quantdigger.errors import FileDoesNotExist
7+
import time
8+
from quantdigger.errors import FileDoesNotExist, DataFieldError
69
from quantdigger.datasource import datautil
710

11+
812
class SourceWrapper(object):
9-
def __init__(self, data, cursor, max_length):
13+
def __init__(self, pcontract, data, cursor, max_length=0):
14+
"""
15+
max_length=0,表示逐步模式
16+
"""
1017
self.data = data
1118
self.cursor = cursor
1219
self._max_length = max_length
20+
self.curbar = -1
21+
self.pcontract = pcontract
1322

1423
def __len__(self):
15-
# 总长度
1624
return self._max_length
1725

26+
1827
class SqliteSourceWrapper(SourceWrapper):
19-
def __init__(self, data, cursor, max_length):
20-
super(SqliteSourceWrapper, self).__init__(data, cursor, max_length)
28+
def __init__(self, pcontract, data, cursor, max_length=0):
29+
super(SqliteSourceWrapper, self).__init__(pcontract, data, cursor, max_length)
30+
31+
def rolling_foward(self):
32+
self.curbar += 1
33+
# self.cursor为None说明是向量运算
34+
if self.cursor:
35+
return self.cursor.fetchone(), self.curbar
36+
# 超过向量的最大长度。
37+
if self.curbar == self._max_length:
38+
return None, self.curbar
39+
else:
40+
return True, self.curbar
41+
42+
43+
class CsvSourceWrapper(SourceWrapper):
44+
def __init__(self, pcontract, data, cursor, max_length=0):
45+
super(CsvSourceWrapper, self).__init__(pcontract, data, cursor, max_length)
2146

2247
def rolling_foward(self):
23-
# cursor为None说明是向量运算
48+
self.curbar += 1
49+
# self.cursor为None说明是向量运算
2450
if self.cursor:
25-
return self.cursor.fetchone()
26-
return None
51+
try:
52+
row = self.cursor.next()
53+
except StopIteration:
54+
return None, self.curbar
55+
else:
56+
dt = datetime.datetime.strptime(row[0], "%Y-%m-%d %H:%M:%S")
57+
row[0] = dt
58+
return row, self.curbar
59+
if self.curbar == self._max_length:
60+
return None, self.curbar
61+
else:
62+
return True, self.curbar
63+
64+
65+
def convert_datetime(tf):
66+
return datetime.datetime.fromtimestamp(float(tf)/1000)
2767

2868
class SqlLiteSource(object):
2969
"""
3070
"""
3171
def __init__(self, fname):
3272
import sqlite3
33-
self.db = sqlite3.connect(fname)
73+
self.db = sqlite3.connect(fname,
74+
detect_types = sqlite3.PARSE_DECLTYPES)
75+
sqlite3.register_converter('timestamp', convert_datetime)
3476
## @todo remove self.cursor
3577
self.cursor = self.db.cursor()
3678

@@ -39,18 +81,18 @@ def load_bars(self, pcontract, dt_start, dt_end, window_size):
3981
id_start, u = datautil.encode2id(pcontract.period, dt_start)
4082
id_end, u = datautil.encode2id(pcontract.period, dt_end)
4183
table = string.replace(str(pcontract.contract), '.', '_')
84+
#sql = "SELECT COUNT(*) FROM {tb} \
85+
#WHERE {start}<=id AND id<={end}".format(tb=table, start=id_start, end=id_end)
86+
#max_length = cursor.execute(sql).fetchone()[0]
4287
#
43-
sql = "SELECT COUNT(*) FROM {tb} \
44-
WHERE {start}<=id AND id<={end}".format(tb=table, start=id_start, end=id_end)
45-
max_length = cursor.execute(sql).fetchone()[0]
46-
#
47-
sql = "SELECT utime, open, close, high, low, volume FROM {tb} \
88+
sql = "SELECT datetime, open, close, high, low, volume FROM {tb} \
4889
WHERE {start}<=id AND id<={end}".format(tb=table, start=id_start, end=id_end)
4990

91+
data = pd.read_sql_query(sql, self.db, index_col='datetime')
5092
if window_size == 0:
51-
data = pd.read_sql_query(sql, self.db, index_col='utime')
93+
data = pd.read_sql_query(sql, self.db, index_col='datetime')
5294
## @todo
53-
return SqliteSourceWrapper(data, None, len(data))
95+
return SqliteSourceWrapper(pcontract, data, None, len(data))
5496
else:
5597
cursor.execute(sql)
5698
data = pd.DataFrame({
@@ -61,7 +103,7 @@ def load_bars(self, pcontract, dt_start, dt_end, window_size):
61103
'volume': []
62104
})
63105
data.index = []
64-
return SqliteSourceWrapper(data, cursor, max_length)
106+
return SqliteSourceWrapper(pcontract, data, cursor, window_size)
65107

66108
def read_csv(self, path):
67109
""" 导入路径path下所有csv数据文件到sqlite中,每个文件对应数据库中的一周表格。
@@ -90,9 +132,9 @@ def to_csv(self, index=True, index_label='index'):
90132
for table_name in tables:
91133
table_name = table_name[0]
92134
table = pd.read_sql_query("SELECT * from %s" % table_name, self.db)
93-
#table['datetime'] = map(lambda x : datetime.fromtimestamp(x / 1000), table['utime'])
135+
#table['datetime'] = map(lambda x : datetime.fromtimestamp(x / 1000), table['datetime'])
94136
table.to_csv(table_name + '.csv', index=index, index_label=index_label,
95-
columns = ['utime', 'open', 'close', 'high', 'low', 'volume'])
137+
columns = ['datetime', 'open', 'close', 'high', 'low', 'volume'])
96138

97139
def get_tables(self):
98140
""" 返回数据库所有的表格"""
@@ -108,41 +150,54 @@ def get_table_fields(self, tb):
108150
def _df2sqlite(self, df, tbname):
109151
self.cursor.execute('''CREATE TABLE {tb}
110152
(id int primary key,
111-
utime timestamp,
153+
datetime timestamp,
112154
open real,
113155
close real,
114156
high real,
115157
low real,
116158
volume int)'''.format(tb = tbname))
117159
data = []
118160
for index, row in df.iterrows():
119-
id, utime = datautil.encode2id('1.Minute', index)
120-
data.append((id, utime, row['open'], row['close'], row['high'], row['low'], row['vol']))
161+
id, datetime = datautil.encode2id('1.Minute', index)
162+
data.append((id, datetime, row['open'], row['close'], row['high'], row['low'], row['vol']))
121163
sql = "INSERT INTO %s VALUES (?,?,?,?,?,?,?)" % tbname
122164
self.cursor.executemany(sql, data)
123165
self.db.commit()
124166

125-
import time
167+
126168
class CsvSource(object):
127169
"""
128170
(datetime, open, close, high, low, volume)
129-
## @todo
130-
(utime, open, close, high, low, volume)
131171
"""
132172
def __init__(self, root):
133173
self._root = root
134174

135-
def load_bars(self, pcontract, dt_start, dt_end):
136-
fname = ''.join([self._root, str(pcontract), ".csv"])
137-
try:
138-
data = pd.read_csv(fname, index_col=0, parse_dates=True)
139-
dt_start = pd.to_datetime(dt_start)
140-
dt_end = pd.to_datetime(dt_end)
141-
data = data[(dt_start <= data.index) & (data.index <= dt_end)]
142-
data.index = map(lambda x : int(time.mktime(x.timetuple())*1000), data.index)
143-
assert data.index.is_unique
144-
except IOError:
145-
#print u"**Warning: File \"%s\" doesn't exist!"%fname
146-
raise FileDoesNotExist(file=fname)
175+
def load_bars(self, pcontract, dt_start, dt_end, window_size):
176+
fname = os.path.join(self._root, str(pcontract) + ".csv")
177+
if window_size == 0:
178+
try:
179+
data = pd.read_csv(fname, index_col=0, parse_dates=True)
180+
except IOError:
181+
raise FileDoesNotExist(file=fname)
182+
else:
183+
dt_start = pd.to_datetime(dt_start)
184+
dt_end = pd.to_datetime(dt_end)
185+
data = data[(dt_start <= data.index) & (data.index <= dt_end)]
186+
#data.index = map(lambda x : int(time.mktime(x.timetuple())*1000), data.index)
187+
assert data.index.is_unique
188+
return CsvSourceWrapper(pcontract, data, None, len(data))
147189
else:
148-
return data
190+
data = pd.DataFrame({
191+
'open': [],
192+
'close': [],
193+
'high': [],
194+
'low': [],
195+
'volume': []
196+
})
197+
data.index = []
198+
cursor = csv.reader(open(fname, 'rb'))
199+
fmt = ['datetime', 'open', 'close', 'high', 'low', 'volume']
200+
header = cursor.next()
201+
if header != fmt:
202+
raise DataFieldError(error_fields=header, right_fields=fmt)
203+
return CsvSourceWrapper(pcontract, data, cursor, window_size)

quantdigger/datastruct.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,11 @@ def __str__(self):
380380
""" return string like 'IF000.SHEF-10.Minutes' """
381381
return "%s-%s" % (self.contract, self.period)
382382

383+
@classmethod
384+
def from_string(self, strpc):
385+
t = strpc.split('-')
386+
return PContract(Contract(t[0]), Period(t[1]))
387+
383388
def __hash__(self):
384389
if hasattr(self, '_hash'):
385390
return self._hash

quantdigger/demo/.DS_Store

0 Bytes
Binary file not shown.

quantdigger/demo/data/.DS_Store

0 Bytes
Binary file not shown.

quantdigger/demo/main.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,10 @@ def on_bar(self):
6363
signals = []
6464
for trans in algo.blotter.transactions:
6565
deals.update_positions(positions, signals, trans);
66-
#d = simulator.data[pcon]['close']
67-
#for i in d:
68-
#print i
69-
#assert False
7066
plotting.plot_result(simulator.data[pcon],
7167
algo._indicators,
7268
signals,
73-
algo.blotter)
69+
algo.blotter.all_holdings)
7470

7571

7672
except Exception, e:

0 commit comments

Comments
 (0)