zhenlanghuo's Blog

Golang database/sql与go-sql-driver/mysql 源码阅读笔记 -- go-sql-driver/mysql篇

字数统计: 2.6k阅读时长: 13 min
2019/07/21 Share

driver.go

connector.go

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
// Connect implements driver.Connector interface.
// Connect returns a connection to the database.
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
var err error

// New mysqlConn
mc := &mysqlConn{
maxAllowedPacket: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
closech: make(chan struct{}),
cfg: c.cfg,
}
mc.parseTime = mc.cfg.ParseTime

// Connect to Server
// 连接服务器
dialsLock.RLock()
// 获取自定义的dial函数
dial, ok := dials[mc.cfg.Net]
dialsLock.RUnlock()
if ok {
// 使用自定义的dial函数进行连接
mc.netConn, err = dial(ctx, mc.cfg.Addr)
} else {
nd := net.Dialer{Timeout: mc.cfg.Timeout}
mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
}

if err != nil {
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
errLog.Print("net.Error from Dial()': ", nerr.Error())
return nil, driver.ErrBadConn
}
return nil, err
}

// Enable TCP Keepalives on TCP connections
// 如果是TCP连接,设置长连接
if tc, ok := mc.netConn.(*net.TCPConn); ok {
if err := tc.SetKeepAlive(true); err != nil {
// Don't send COM_QUIT before handshake.
mc.netConn.Close()
mc.netConn = nil
return nil, err
}
}

// Call startWatcher for context support (From Go 1.8)
// 调用startWatcher函数以支持对context的处理,详细看下面对该函数的笔记
mc.startWatcher()
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
defer mc.finish()

// 把连接句柄交给buffer对象,让buffer对象管理socket的读处理,提供缓冲读功能
mc.buf = newBuffer(mc.netConn)

// Set I/O timeouts
mc.buf.timeout = mc.cfg.ReadTimeout
mc.writeTimeout = mc.cfg.WriteTimeout

// 接下来就是与服务器握手认证阶段的处理

// Reading Handshake Initialization Packet
authData, plugin, err := mc.readHandshakePacket()
if err != nil {
mc.cleanup()
return nil, err
}

if plugin == "" {
plugin = defaultAuthPlugin
}

// Send Client Authentication Packet
authResp, err := mc.auth(authData, plugin)
if err != nil {
// try the default auth plugin, if using the requested plugin failed
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
plugin = defaultAuthPlugin
authResp, err = mc.auth(authData, plugin)
if err != nil {
mc.cleanup()
return nil, err
}
}
if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
mc.cleanup()
return nil, err
}

// Handle response to auth packet, switch methods if possible
if err = mc.handleAuthResult(authData, plugin); err != nil {
// Authentication failed and MySQL has already closed the connection
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
// Do not send COM_QUIT, just cleanup and return the error.
mc.cleanup()
return nil, err
}

if mc.cfg.MaxAllowedPacket > 0 {
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
} else {
// Get max allowed packet size
maxap, err := mc.getSystemVar("max_allowed_packet")
if err != nil {
mc.Close()
return nil, err
}
mc.maxAllowedPacket = stringToInt(maxap) - 1
}
if mc.maxAllowedPacket < maxPacketSize {
mc.maxWriteSize = mc.maxAllowedPacket
}

// Handle DSN Params
// 对服务器进行会话的用户/系统环境变量的设置
err = mc.handleParams()
if err != nil {
mc.Close()
return nil, err
}

return mc, nil
}

connection.go

startWatcher、watchCancel

startWatcher是开启了一个go协程,循环监听watcher管道的消息,watcher管道放的是Context类型的对象,接收到一个context对象以后,监听该对象的结束管道,收到context对象的结束消息以后,调用mc.cancel。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
func (mc *mysqlConn) startWatcher() {
watcher := make(chan context.Context, 1)
mc.watcher = watcher
finished := make(chan struct{})
mc.finished = finished
go func() {
for {
var ctx context.Context
select {
case ctx = <-watcher:
case <-mc.closech:
return
}

select {
case <-ctx.Done():
mc.cancel(ctx.Err())
case <-finished:
case <-mc.closech:
return
}
}
}()
}

cancel函数是将传入的error保存在mc.canceled变量中,然后调用mc.cleanup。
而celanup函数的工作主要就是关闭连接。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
// finish is called when the query has canceled.
func (mc *mysqlConn) cancel(err error) {
mc.canceled.Set(err)
mc.cleanup()
}

// Closes the network connection and unsets internal variables. Do not call this
// function after successfully authentication, call Close instead. This function
// is called before auth or on auth failure because MySQL will have already
// closed the network connection.
func (mc *mysqlConn) cleanup() {
if !mc.closed.TrySet(true) {
return
}

// Makes cleanup idempotent
close(mc.closech)
if mc.netConn == nil {
return
}
if err := mc.netConn.Close(); err != nil {
errLog.Print(err)
}
}

startWatcher是监听watcher管道,而watcher管道里的消息是由watchCancel函数进行放入的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
func (mc *mysqlConn) watchCancel(ctx context.Context) error {
if mc.watching {
// Reach here if canceled,
// so the connection is already invalid
mc.cleanup()
return nil
}
// When ctx is already cancelled, don't watch it.
if err := ctx.Err(); err != nil {
return err
}
// When ctx is not cancellable, don't watch it.
if ctx.Done() == nil {
return nil
}
// When watcher is not alive, can't watch it.
if mc.watcher == nil {
return nil
}

mc.watching = true
mc.watcher <- ctx
return nil
}

startWatcher函数在连接服务器以后就调用了,之后需要监听context的时候就调用watchCancel,同一个连接同一时间只能监听一个context。

watchCanel一般会在mc.XxxContext函数中被调用,如mc.ExecContext

1
2
3
4
5
6
7
8
9
10
11
12
13
func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}

if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
defer mc.finish()

return mc.Exec(query, dargs)
}

在调用mc.Exec之前,先调用watchCancel来对ctx进行监听,看到这里的时候我产生了一个问题,调用sql.DB.XxxContext函数,当context超时时,返回的error是context timeout,那么这个错误到底是从哪里返回的呢。

沿着函数的调用,终于发现,context的error是由mc.readPacket和mc.writePacket返回的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
// Read packet to buffer 'data'
func (mc *mysqlConn) readPacket() ([]byte, error) {
var prevData []byte
for {
// read packet header
data, err := mc.buf.readNext(4)
if err != nil {
// 先检查mc.canceled中是否保存有error,如果有,返回该error
if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr
}
...省略一些代码...
}

...省略一些代码...

// read packet body [pktLen bytes]
data, err = mc.buf.readNext(pktLen)
if err != nil {
// 先检查mc.canceled中是否保存有error,如果有,返回该error
if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr
}
...省略一些代码...
}

...省略一些代码...
}
}

// Write packet buffer 'data'
func (mc *mysqlConn) writePacket(data []byte) error {
...省略一些代码...

if mc.reset {
...省略一些代码...
}

for {
...省略一些代码...

n, err := mc.netConn.Write(data[:4+size])
if err == nil && n == 4+size {
...省略一些代码...
}

// Handle error
if err == nil { // n != len(data)
...省略一些代码...
} else {
// 先检查mc.canceled中是否保存有error,如果有,返回该error
if cerr := mc.canceled.Value(); cerr != nil {
return cerr
}
...省略一些代码...
}
return ErrInvalidConn
}
}

在读写操作返回错误时,先检查mc.canceled中是否保存有error,如果有直接返回该error;前面说过,当ctx.Done有消息返回时,会调用mc.cancel,会将ctx.Err保存到mc.canceled,还会关闭连接的socket;因此在context被cancel或超时或者其他操作导致ctx.Done返回消息后,socket被被关闭,在socket上进行读写操作会返回error,然后将mc.canceled里保存的context的error返回。

packets.go

数据库驱动,其实也就是负责与数据库服务端通信的客户端,客户端只负责发送要执行的sql或者一些设置命令给服务端,并接受服务端的响应或者返回的数据。

mysql服务端与客户端之间的通信协议,参考一下的博文:
MySQL协议分析
MySQL网络协议分析

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
// Read packet to buffer 'data'
func (mc *mysqlConn) readPacket() ([]byte, error) {
var prevData []byte
// for循环是为了读取被分片的数据,当需要发送的数据大小大于maxPacketSize(16MB-1byte)时就需要分片。
for {
// 读取固定4个字节的包头数据
data, err := mc.buf.readNext(4)
if err != nil {
// 上文提到的当从buffer读取数据失败时先检查mc.canceled中是否保存有error,因为可能context已经超时或者被取消而关闭了连接,此时应该返回canceled里的error。
if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr
}
errLog.Print(err)
mc.Close()
return nil, ErrInvalidConn
}

// packet length [24 bit]
// 包头前3个字节表示除去包头4个字节包数据的大小,单位为字节
pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)

// check packet sync [8 bit]
// 包头第4个字节是序号,如果包头的序号与客户端记录的序号不一致,表明消息顺序异常,返回error
if data[3] != mc.sequence {
if data[3] > mc.sequence {
return nil, ErrPktSyncMul
}
return nil, ErrPktSync
}
mc.sequence++

// packets with length 0 terminate a previous packet which is a
// multiple of (2^24)-1 bytes long
if pktLen == 0 {
// there was no previous packet
// 包数据大小等于零,且prevData为nil,异常情况,报错
if prevData == nil {
errLog.Print(ErrMalformPkt)
mc.Close()
return nil, ErrInvalidConn
}

// 包数据大小等于零,prevData不为nil,这是数据大小刚好是maxPacketSize的整数倍的情况,返回prevData
return prevData, nil
}

// read packet body [pktLen bytes]
// 读取pktkLen字节数据
data, err = mc.buf.readNext(pktLen)
if err != nil {
// 同上
if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr
}
errLog.Print(err)
mc.Close()
return nil, ErrInvalidConn
}

// return data if this was the last packet
// ptkLen小于maxPacketSize,意味着需要返回数据了
if pktLen < maxPacketSize {
// zero allocations for non-split packets
// 如果prevData为nil,表明该数据没有分片,直接返回data
if prevData == nil {
return data, nil
}

// 如果preData不为nil,表明已经读到分片数据中的最后一部分数据,返回prevData加上data
return append(prevData, data...), nil
}

// 数据分片了,把读到的数据添加到prevData中,然后继续循环。
prevData = append(prevData, data...)
}
}

// Write packet buffer 'data'
func (mc *mysqlConn) writePacket(data []byte) error {
// 传入的data都预留了4个字节的包头位置
pktLen := len(data) - 4

if pktLen > mc.maxAllowedPacket {
return ErrPktTooLarge
}

// Perform a stale connection check. We only perform this check for
// the first query on a connection that has been checked out of the
// connection pool: a fresh connection from the pool is more likely
// to be stale, and it has not performed any previous writes that
// could cause data corruption, so it's safe to return ErrBadConn
// if the check fails.
if mc.reset {
mc.reset = false
conn := mc.netConn
if mc.rawConn != nil {
conn = mc.rawConn
}
var err error
// If this connection has a ReadTimeout which we've been setting on
// reads, reset it to its default value before we attempt a non-blocking
// read, otherwise the scheduler will just time us out before we can read
if mc.cfg.ReadTimeout != 0 {
err = conn.SetReadDeadline(time.Time{})
}
if err == nil {
err = connCheck(conn)
}
if err != nil {
errLog.Print("closing bad idle connection: ", err)
mc.Close()
return driver.ErrBadConn
}
}

for {
// 设置前面4个字节的包头
var size int
if pktLen >= maxPacketSize {
data[0] = 0xff
data[1] = 0xff
data[2] = 0xff
size = maxPacketSize
} else {
data[0] = byte(pktLen)
data[1] = byte(pktLen >> 8)
data[2] = byte(pktLen >> 16)
size = pktLen
}
data[3] = mc.sequence

// Write packet
if mc.writeTimeout > 0 {
if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil {
return err
}
}

// 将data的前4+size个字节写入连接的socket中
n, err := mc.netConn.Write(data[:4+size])
if err == nil && n == 4+size {
mc.sequence++
if size != maxPacketSize {
return nil
}
// 更新pktLen
pktLen -= size
// 将data的浮标往前移动size位,剩后面的4个字节作为下一个分片的包头
data = data[size:]
continue
}

// Handle error
if err == nil { // n != len(data)
mc.cleanup()
errLog.Print(ErrMalformPkt)
} else {
if cerr := mc.canceled.Value(); cerr != nil {
return cerr
}
if n == 0 && pktLen == len(data)-4 {
// only for the first loop iteration when nothing was written yet
return errBadConnNoWrite
}
mc.cleanup()
errLog.Print(err)
}
return ErrInvalidConn
}
}
CATALOG
  1. driver.go
  2. connector.go
  3. connection.go
    1. startWatcher、watchCancel
  4. packets.go