11# -*- coding: utf-8 -*-
2+ import csv
3+ import datetime
24import os
35import pandas as pd
46import string
5- from quantdigger .errors import FileDoesNotExist
7+ import time
8+ from quantdigger .errors import FileDoesNotExist , DataFieldError
69from quantdigger .datasource import datautil
710
11+
class SourceWrapper(object):
    """Base container pairing one contract's bar data with an optional cursor.

    Attributes are plain and public: ``data``, ``cursor``, ``curbar`` and
    ``pcontract``; the length reported by ``len()`` is kept in ``_max_length``.
    """

    def __init__(self, pcontract, data, cursor, max_length=0):
        """
        max_length=0 means step-by-step mode.
        """
        self.data = data
        self.cursor = cursor
        self._max_length = max_length
        # Index of the current bar; starts at -1 so that the first increment
        # performed by a subclass's rolling_foward() yields bar 0.
        self.curbar = -1
        self.pcontract = pcontract

    def __len__(self):
        """Return the ``max_length`` value given at construction."""
        return self._max_length
1725
26+
class SqliteSourceWrapper(SourceWrapper):
    """Source wrapper whose cursor, when present, exposes ``fetchone()``."""

    def __init__(self, pcontract, data, cursor, max_length=0):
        super(SqliteSourceWrapper, self).__init__(pcontract, data, cursor, max_length)

    def rolling_foward(self):
        """Advance to the next bar.

        Returns a ``(value, curbar)`` tuple.  With a cursor, ``value`` is
        whatever ``cursor.fetchone()`` returns.  Without a cursor (vector
        mode), ``value`` is ``True`` while bars remain and ``None`` once the
        bar index has reached the configured maximum length.
        """
        self.curbar += 1
        # self.cursor being None means vector (whole-series) computation.
        if self.cursor is not None:
            return self.cursor.fetchone(), self.curbar
        # Past the maximum length of the vector.  ">=" (rather than "==")
        # keeps reporting exhaustion if the caller rolls forward again.
        if self.curbar >= self._max_length:
            return None, self.curbar
        return True, self.curbar
41+
42+
class CsvSourceWrapper(SourceWrapper):
    """Source wrapper whose cursor, when present, is an iterator of CSV rows."""

    def __init__(self, pcontract, data, cursor, max_length=0):
        super(CsvSourceWrapper, self).__init__(pcontract, data, cursor, max_length)

    def rolling_foward(self):
        """Advance to the next bar.

        Returns a ``(value, curbar)`` tuple.  With a cursor, ``value`` is the
        next row with its first field parsed into a ``datetime`` (format
        ``%Y-%m-%d %H:%M:%S``), or ``None`` when the iterator is exhausted.
        Without a cursor (vector mode), ``value`` is ``True`` while bars
        remain and ``None`` once the bar index has reached the maximum length.
        """
        self.curbar += 1
        # self.cursor being None means vector (whole-series) computation.
        if self.cursor is not None:
            try:
                # Builtin next() works for both Python 2 and 3 iterators,
                # unlike the Python-2-only ``.next()`` method.
                row = next(self.cursor)
            except StopIteration:
                return None, self.curbar
            row[0] = datetime.datetime.strptime(row[0], "%Y-%m-%d %H:%M:%S")
            return row, self.curbar
        # ">=" (rather than "==") keeps reporting exhaustion if the caller
        # rolls forward again after the end of the vector.
        if self.curbar >= self._max_length:
            return None, self.curbar
        return True, self.curbar
63+
64+
def convert_datetime(tf):
    """Convert a millisecond timestamp to a local ``datetime``.

    ``tf`` is anything ``float()`` accepts; it is divided by 1000 and passed
    to ``datetime.datetime.fromtimestamp``.
    """
    seconds = float(tf) / 1000
    return datetime.datetime.fromtimestamp(seconds)
2767
2868class SqlLiteSource (object ):
2969 """
3070 """
def __init__(self, fname):
    """Open the SQLite database ``fname`` and create a cursor on it.

    Declared-type detection is enabled on the connection and
    ``convert_datetime`` is registered as the converter for columns
    declared as ``timestamp``.
    """
    import sqlite3
    self.db = sqlite3.connect(fname,
                              detect_types=sqlite3.PARSE_DECLTYPES)
    # register_converter is module-wide in sqlite3, not per connection.
    sqlite3.register_converter('timestamp', convert_datetime)
    ## @todo remove self.cursor
    self.cursor = self.db.cursor()
3678
@@ -39,18 +81,18 @@ def load_bars(self, pcontract, dt_start, dt_end, window_size):
3981 id_start , u = datautil .encode2id (pcontract .period , dt_start )
4082 id_end , u = datautil .encode2id (pcontract .period , dt_end )
4183 table = string .replace (str (pcontract .contract ), '.' , '_' )
84+ #sql = "SELECT COUNT(*) FROM {tb} \
85+ #WHERE {start}<=id AND id<={end}".format(tb=table, start=id_start, end=id_end)
86+ #max_length = cursor.execute(sql).fetchone()[0]
4287 #
43- sql = "SELECT COUNT(*) FROM {tb} \
44- WHERE {start}<=id AND id<={end}" .format (tb = table , start = id_start , end = id_end )
45- max_length = cursor .execute (sql ).fetchone ()[0 ]
46- #
47- sql = "SELECT utime, open, close, high, low, volume FROM {tb} \
88+ sql = "SELECT datetime, open, close, high, low, volume FROM {tb} \
4889 WHERE {start}<=id AND id<={end}" .format (tb = table , start = id_start , end = id_end )
4990
91+ data = pd .read_sql_query (sql , self .db , index_col = 'datetime' )
5092 if window_size == 0 :
51- data = pd .read_sql_query (sql , self .db , index_col = 'utime ' )
93+ data = pd .read_sql_query (sql , self .db , index_col = 'datetime ' )
5294 ## @todo
53- return SqliteSourceWrapper (data , None , len (data ))
95+ return SqliteSourceWrapper (pcontract , data , None , len (data ))
5496 else :
5597 cursor .execute (sql )
5698 data = pd .DataFrame ({
@@ -61,7 +103,7 @@ def load_bars(self, pcontract, dt_start, dt_end, window_size):
61103 'volume' : []
62104 })
63105 data .index = []
64- return SqliteSourceWrapper (data , cursor , max_length )
106+ return SqliteSourceWrapper (pcontract , data , cursor , window_size )
65107
66108 def read_csv (self , path ):
67109 """ 导入路径path下所有csv数据文件到sqlite中,每个文件对应数据库中的一周表格。
@@ -90,9 +132,9 @@ def to_csv(self, index=True, index_label='index'):
90132 for table_name in tables :
91133 table_name = table_name [0 ]
92134 table = pd .read_sql_query ("SELECT * from %s" % table_name , self .db )
93- #table['datetime'] = map(lambda x : datetime.fromtimestamp(x / 1000), table['utime '])
135+ #table['datetime'] = map(lambda x : datetime.fromtimestamp(x / 1000), table['datetime '])
94136 table .to_csv (table_name + '.csv' , index = index , index_label = index_label ,
95- columns = ['utime ' , 'open' , 'close' , 'high' , 'low' , 'volume' ])
137+ columns = ['datetime ' , 'open' , 'close' , 'high' , 'low' , 'volume' ])
96138
97139 def get_tables (self ):
98140 """ 返回数据库所有的表格"""
@@ -108,41 +150,54 @@ def get_table_fields(self, tb):
def _df2sqlite(self, df, tbname):
    """Create table ``tbname`` and insert every row of ``df`` into it.

    ``df`` must provide the columns 'open', 'close', 'high', 'low' and
    'vol'; its index values are passed to ``datautil.encode2id`` together
    with the period string '1.Minute'.  The table name is interpolated
    into the SQL text (identifiers cannot be bound as parameters), so it
    must come from a trusted source.
    """
    self.cursor.execute('''CREATE TABLE {tb}
             (id int primary key,
              datetime timestamp,
              open real,
              close real,
              high real,
              low real,
              volume int)'''.format(tb=tbname))
    data = []
    for index, row in df.iterrows():
        # Local names avoid shadowing the builtin ``id`` and the
        # module-level ``datetime`` import.
        bar_id, bar_time = datautil.encode2id('1.Minute', index)
        data.append((bar_id, bar_time, row['open'], row['close'],
                     row['high'], row['low'], row['vol']))
    sql = "INSERT INTO %s VALUES (?,?,?,?,?,?,?)" % tbname
    self.cursor.executemany(sql, data)
    self.db.commit()
124166
125- import time
167+
class CsvSource(object):
    """
    Data source reading one CSV file per contract from a root directory.

    Expected columns: (datetime, open, close, high, low, volume)
    """
    def __init__(self, root):
        self._root = root

    def load_bars(self, pcontract, dt_start, dt_end, window_size):
        """Load bars for ``pcontract`` from ``<root>/<pcontract>.csv``.

        window_size == 0: the whole file is read with pandas, filtered to
        the inclusive range [dt_start, dt_end], and returned in a
        ``CsvSourceWrapper`` without a cursor.

        window_size != 0: an empty frame plus a ``csv.reader`` positioned
        just after the header row is returned, with ``window_size`` as the
        wrapper's max_length.
        NOTE(review): dt_start / dt_end are not applied in this mode.

        Raises:
            FileDoesNotExist: the CSV file cannot be opened (either mode).
            DataFieldError: the header row differs from the expected fields
                (streaming mode only; an empty file is reported the same way).
        """
        fname = os.path.join(self._root, str(pcontract) + ".csv")
        if window_size == 0:
            try:
                data = pd.read_csv(fname, index_col=0, parse_dates=True)
            except IOError:
                raise FileDoesNotExist(file=fname)
            dt_start = pd.to_datetime(dt_start)
            dt_end = pd.to_datetime(dt_end)
            data = data[(dt_start <= data.index) & (data.index <= dt_end)]
            assert data.index.is_unique
            return CsvSourceWrapper(pcontract, data, None, len(data))

        data = pd.DataFrame({
            'open': [],
            'close': [],
            'high': [],
            'low': [],
            'volume': []
        })
        data.index = []
        # Report a missing file the same way the vector branch does.
        try:
            fobj = open(fname, 'rb')
        except IOError:
            raise FileDoesNotExist(file=fname)
        cursor = csv.reader(fobj)
        fmt = ['datetime', 'open', 'close', 'high', 'low', 'volume']
        # Default of None turns an empty file into a header mismatch
        # instead of a bare StopIteration.
        header = next(cursor, None)
        if header != fmt:
            # The reader will never be consumed; do not leak the handle.
            fobj.close()
            raise DataFieldError(error_fields=header, right_fields=fmt)
        # The file stays open on purpose: the wrapper streams rows from it.
        return CsvSourceWrapper(pcontract, data, cursor, window_size)
0 commit comments