@@ -17,91 +17,122 @@ import (
1717 "net"
1818)
1919
20- type Conn struct {
20+ type Conn interface {
21+ Cleanup ()
22+ Close () error
23+ Closed () bool
24+ NextPacket () ([]byte , error )
25+ Query (sql string ) (rows Rows , err error )
26+ Exec (sql string ) (rows Rows , err error )
27+ }
28+
29+ type conn struct {
2130 netConn net.Conn
2231 greeting * proto.Greeting
2332 auth * proto.Auth
2433 packets * packet.Packets
2534}
2635
27- func NewConn (username , password , protocol , address , database string ) (* Conn , error ) {
28- netconn , err := net .Dial (protocol , address )
29- if err != nil {
30- return nil , errors .WithStack (err )
36+ func (c * conn ) handleErrorPacket (data []byte ) error {
37+ if data [0 ] == proto .ERR_PACKET {
38+ pkt , e := c .packets .ParseERR (data , c .greeting .Capability )
39+ if e != nil {
40+ c .Cleanup ()
41+ return e
42+ }
43+ return errors .New (pkt .ErrorMessage )
3144 }
3245
33- conn := & Conn {
34- netConn : netconn ,
35- greeting : proto .NewGreeting (0 ),
36- auth : proto .NewAuth (),
37- packets : packet .NewPackets (netconn ),
46+ return nil
47+ }
48+
49+ func NewConn (username , password , address , database string ) (c * conn , err error ) {
50+ var payload []byte
51+
52+ c = & conn {}
53+ if c .netConn , err = net .Dial ("tcp" , address ); err != nil {
54+ return nil , errors .WithStack (err )
3855 }
3956
57+ c .auth = proto .NewAuth ()
58+ c .greeting = proto .NewGreeting (0 )
59+ c .packets = packet .NewPackets (c .netConn )
60+
4061 {
62+
4163 // greeting read
42- payload , err := conn .packets .Next ()
43- if err != nil {
44- return nil , err
64+ if payload , err = c .packets .Next (); err != nil {
65+ c .Cleanup ()
66+ return
67+ }
68+ if err = c .handleErrorPacket (payload ); err != nil {
69+ return
4570 }
4671
4772 // greeting unpack
48- err = conn .greeting .UnPack (payload )
49- if err != nil {
50- return nil , err
73+ if err = c .greeting .UnPack (payload ); err != nil {
74+ c . Cleanup ()
75+ return
5176 }
5277 }
5378
5479 {
5580 // auth pack
56- payload := conn .auth .Pack (
57- conn .greeting .Capability ,
58- conn .greeting .Charset ,
81+ payload := c .auth .Pack (
82+ c .greeting .Capability ,
83+ c .greeting .Charset ,
5984 username ,
6085 password ,
61- conn .greeting .Salt ,
86+ c .greeting .Salt ,
6287 database ,
6388 )
6489
6590 // auth write
66- err := conn .packets .Write (payload )
67- if err != nil {
68- return nil , err
91+ if err = c .packets .Write (payload ); err != nil {
92+ c . Cleanup ()
93+ return
6994 }
7095 }
7196
7297 {
7398 // read
74- payload , err := conn .packets .Next ()
75- if err != nil {
76- return nil , err
99+ if payload , err = c .packets .Next (); err != nil {
100+ c . Cleanup ()
101+ return
77102 }
78103
79- if payload [0 ] != proto .OK_PACKET {
80- pkt , err := conn .packets .ParseERR (payload , conn .greeting .Capability )
81- if err != nil {
82- return nil , err
83- }
84- return nil , errors .New (pkt .ErrorMessage )
104+ if err = c .handleErrorPacket (payload ); err != nil {
105+ return
85106 }
86107 }
87108
88- return conn , nil
109+ return c , nil
110+ }
111+
112+ func (c * conn ) NextPacket () ([]byte , error ) {
113+ return c .packets .Next ()
114+ }
115+
116+ func (c * conn ) Cleanup () {
117+ if c .netConn != nil {
118+ c .netConn .Close ()
119+ c .netConn = nil
120+ }
89121}
90122
91123// Close closes the connection
92- func (c * Conn ) Close () error {
124+ func (c * conn ) Close () error {
93125 if c .netConn != nil {
94126 if err := c .packets .WriteCommand (consts .COM_QUIT , nil ); err != nil {
95127 return err
96128 }
97-
98- if c .netConn != nil {
99- if err := c .netConn .Close (); err != nil {
100- return errors .WithStack (err )
101- }
102- c .netConn = nil
103- }
129+ c .Cleanup ()
104130 }
105131
106132 return nil
107133}
134+
135+ // Closed checks the connection broken or not
136+ func (c * conn ) Closed () bool {
137+ return c .netConn == nil
138+ }
0 commit comments