Skip to content

Commit 0665928

Browse files
committed
driver/sqlparser: add FetchAllWithFunc and support explain
summary: 1. driver: add FetchAllWithFunc function, the row cursor can be interrupted by the the callback. 2. sqlparser: support the 'EXPLAIN' statement.
1 parent 7c23749 commit 0665928

9 files changed

Lines changed: 954 additions & 829 deletions

File tree

driver/client.go

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,20 @@ type Conn interface {
3030
Close() error
3131
Closed() bool
3232
Cleanup()
33+
NextPacket() ([]byte, error)
3334

34-
// ConnectionID is the connection id at greeting
35+
// ConnectionID is the connection id at greeting.
3536
ConnectionID() uint32
36-
NextPacket() ([]byte, error)
3737

38-
// Query gets the row iterator
38+
// Query get the row cursor.
3939
Query(sql string) (Rows, error)
4040
Exec(sql string) error
4141

42-
// FetchAll fetchs all results
42+
// FetchAll fetchs all results.
4343
FetchAll(sql string, maxrows int) (*sqltypes.Result, error)
44+
45+
// FetchAllWithFunc fetchs all results but the row cursor can be interrupted by the fn.
46+
FetchAllWithFunc(sql string, maxrows int, fn Func) (*sqltypes.Result, error)
4447
}
4548

4649
type conn struct {
@@ -238,39 +241,52 @@ func (c *conn) Exec(sql string) error {
238241
}
239242

240243
func (c *conn) FetchAll(sql string, maxrows int) (*sqltypes.Result, error) {
241-
var r *sqltypes.Result
244+
return c.FetchAllWithFunc(sql, maxrows, func(rows Rows) error { return nil })
245+
}
242246

243-
rows, err := c.query(sqldb.COM_QUERY, sql)
244-
if err != nil {
247+
// Func calls on every rows.Next.
248+
// If func returns error, the row.Next() is interrupted and the error is return.
249+
type Func func(rows Rows) error
250+
251+
func (c *conn) FetchAllWithFunc(sql string, maxrows int, fn Func) (*sqltypes.Result, error) {
252+
var err error
253+
var iRows Rows
254+
var qrRow []sqltypes.Value
255+
var qrRows [][]sqltypes.Value
256+
257+
if iRows, err = c.query(sqldb.COM_QUERY, sql); err != nil {
245258
return nil, err
246259
}
247260

248-
r = &sqltypes.Result{
249-
Fields: rows.Fields(),
250-
RowsAffected: rows.RowsAffected(),
251-
InsertID: rows.LastInsertID(),
252-
}
261+
for iRows.Next() {
262+
// callback check.
263+
if err = fn(iRows); err != nil {
264+
break
265+
}
253266

254-
for rows.Next() {
255-
if len(r.Rows) == maxrows {
267+
// Max rows check.
268+
if len(qrRows) == maxrows {
256269
break
257270
}
258-
row, err := rows.RowValues()
259-
if err != nil {
271+
if qrRow, err = iRows.RowValues(); err != nil {
272+
c.Cleanup()
260273
return nil, err
261274
}
262-
r.Rows = append(r.Rows, row)
263-
}
264-
if len(r.Rows) > 0 {
265-
r.RowsAffected = uint64(len(r.Rows))
275+
qrRows = append(qrRows, qrRow)
266276
}
267277

268-
// Check last error
269-
if err := rows.Close(); err != nil {
278+
// Drain the results and check last error.
279+
if err := iRows.Close(); err != nil {
270280
c.Cleanup()
271281
return nil, err
272282
}
273-
return r, nil
283+
qr := &sqltypes.Result{
284+
Fields: iRows.Fields(),
285+
RowsAffected: uint64(len(qrRows)),
286+
InsertID: iRows.LastInsertID(),
287+
Rows: qrRows,
288+
}
289+
return qr, err
274290
}
275291

276292
// NextPacket used to get the next packet

driver/client_test.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@
1010
package driver
1111

1212
import (
13+
"errors"
1314
"testing"
1415

1516
"github.com/stretchr/testify/assert"
1617

18+
querypb "github.com/XeLabs/go-mysqlstack/sqlparser/depends/query"
1719
"github.com/XeLabs/go-mysqlstack/sqlparser/depends/sqltypes"
1820
"github.com/XeLabs/go-mysqlstack/xlog"
1921
)
@@ -84,3 +86,55 @@ func TestClientClosed(t *testing.T) {
8486
assert.Equal(t, want, got)
8587
}
8688
}
89+
90+
func TestClientFetchAllWithFunc(t *testing.T) {
91+
result1 := &sqltypes.Result{
92+
Fields: []*querypb.Field{
93+
{
94+
Name: "id",
95+
Type: querypb.Type_INT32,
96+
},
97+
{
98+
Name: "name",
99+
Type: querypb.Type_VARCHAR,
100+
},
101+
},
102+
Rows: [][]sqltypes.Value{
103+
{
104+
sqltypes.MakeTrusted(querypb.Type_INT32, []byte("10")),
105+
sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("nice name")),
106+
},
107+
{
108+
sqltypes.MakeTrusted(querypb.Type_INT32, []byte("20")),
109+
sqltypes.NULL,
110+
},
111+
},
112+
}
113+
114+
log := xlog.NewStdLog(xlog.Level(xlog.DEBUG))
115+
th := NewTestHandler(log)
116+
svr, err := MockMysqlServer(log, th)
117+
assert.Nil(t, err)
118+
defer svr.Close()
119+
address := svr.Addr()
120+
121+
// query
122+
{
123+
124+
client, err := NewConn("mock", "mock", address, "test")
125+
assert.Nil(t, err)
126+
defer client.Close()
127+
128+
th.AddQuery("SELECT2", result1)
129+
checkFunc := func(rows Rows) error {
130+
if rows.Bytes() > 2 {
131+
return errors.New("client.checkFunc.error")
132+
}
133+
return nil
134+
}
135+
_, err = client.FetchAllWithFunc("SELECT2", -1, checkFunc)
136+
want := "client.checkFunc.error"
137+
got := err.Error()
138+
assert.Equal(t, want, got)
139+
}
140+
}

driver/rows.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@ import (
2121

2222
var _ Rows = &TextRows{}
2323

24+
// Rows presents row cursor interface.
2425
type Rows interface {
2526
Next() bool
2627
Close() error
2728
Datas() []byte
29+
Bytes() int
2830
RowsAffected() uint64
2931
LastInsertID() uint64
3032
LastError() error
@@ -37,6 +39,7 @@ type TextRows struct {
3739
end bool
3840
err error
3941
data []byte
42+
bytes int
4043
rowsAffected uint64
4144
insertID uint64
4245
buffer *common.Buffer
@@ -63,7 +66,7 @@ func (r *TextRows) Next() bool {
6366
}
6467

6568
// if fields count is 0
66-
// the packet is OK-Packet without Resultset
69+
// the packet is OK-Packet without Resultset.
6770
if len(r.fields) == 0 {
6871
r.end = true
6972
return false
@@ -92,14 +95,13 @@ func (r *TextRows) Next() bool {
9295
return true
9396
}
9497

95-
// Close drain the rest packets and check the error
98+
// Close drain the rest packets and check the error.
9699
func (r *TextRows) Close() error {
97100
for r.Next() {
98101
}
99102
if err := r.LastError(); err != nil {
100103
return err
101104
}
102-
103105
return nil
104106
}
105107

@@ -117,12 +119,13 @@ func (r *TextRows) RowValues() ([]sqltypes.Value, error) {
117119
r.c.Cleanup()
118120
return nil, err
119121
}
122+
r.bytes += len(v)
123+
120124
// if v is NIL, it's a NULL column
121125
if v != nil {
122126
result[i] = sqltypes.MakeTrusted(r.fields[i].Type, v)
123127
}
124128
}
125-
126129
return result, nil
127130
}
128131

@@ -134,6 +137,11 @@ func (r *TextRows) Fields() []*querypb.Field {
134137
return r.fields
135138
}
136139

140+
// Bytes returns all the memory usage which read by this row cursor.
141+
func (r *TextRows) Bytes() int {
142+
return r.bytes
143+
}
144+
137145
func (r *TextRows) RowsAffected() uint64 {
138146
return r.rowsAffected
139147
}

driver/rows_test.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,12 @@ func TestRows(t *testing.T) {
7979
assert.Nil(t, err)
8080
assert.Equal(t, result1.Fields, rows.Fields())
8181
for rows.Next() {
82-
_ = rows.Datas()
82+
//_ = rows.Datas()
8383
_, _ = rows.RowValues()
8484
}
85+
86+
want := 13
87+
got := int(rows.Bytes())
88+
assert.Equal(t, want, got)
8589
}
8690
}

sqlparser/explain.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// Copyright 2012, Google Inc. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package sqlparser
6+
7+
import ()
8+
9+
func (*Explain) iStatement() {}
10+
11+
// Explain represents a explain statement.
12+
type Explain struct {
13+
}
14+
15+
// Format formats the node.
16+
func (node *Explain) Format(buf *TrackedBuffer) {
17+
buf.WriteString("explain")
18+
}
19+
20+
// WalkSubtree walks the nodes of the subtree.
21+
func (node *Explain) WalkSubtree(visit Visit) error {
22+
return nil
23+
}

sqlparser/parse_test.go

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -573,9 +573,6 @@ func TestValid(t *testing.T) {
573573
}, {
574574
input: "describe foobar",
575575
output: "other",
576-
}, {
577-
input: "explain foobar",
578-
output: "other",
579576
}, {
580577
input: "select /* EQ true */ 1 from t where a = true",
581578
}, {
@@ -930,6 +927,30 @@ func TestKill(t *testing.T) {
930927
}
931928
}
932929

930+
func TestExplain(t *testing.T) {
931+
sqls := []struct {
932+
input string
933+
output string
934+
}{{
935+
input: "explain select * from xx",
936+
output: "explain",
937+
}}
938+
for _, tcase := range sqls {
939+
if tcase.output == "" {
940+
tcase.output = tcase.input
941+
}
942+
tree, err := Parse(tcase.input)
943+
if err != nil {
944+
t.Errorf("input: %s, err: %v", tcase.input, err)
945+
continue
946+
}
947+
out := String(tree)
948+
if out != tcase.output {
949+
t.Errorf("out: %s, want %s", out, tcase.output)
950+
}
951+
}
952+
}
953+
933954
func BenchmarkParse1(b *testing.B) {
934955
sql := "select 'abcd', 20, 30.0, eid from a where 1=eid and name='3'"
935956
for i := 0; i < b.N; i++ {

0 commit comments

Comments
 (0)