From 477a80f658a04fc477d800887c55654ac307ff45 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 17 Jan 2019 14:07:23 +0800 Subject: [PATCH] upgrade go-sql-driver/mysql to fix invalid connection error (#5748) should fix #5736 --- Gopkg.lock | 4 +- Gopkg.toml | 2 +- vendor/github.com/go-sql-driver/mysql/AUTHORS | 5 + vendor/github.com/go-sql-driver/mysql/auth.go | 44 +-- .../github.com/go-sql-driver/mysql/buffer.go | 49 ++-- .../go-sql-driver/mysql/connection.go | 210 ++++++++++++++- .../go-sql-driver/mysql/connection_go18.go | 208 --------------- .../github.com/go-sql-driver/mysql/driver.go | 24 +- vendor/github.com/go-sql-driver/mysql/dsn.go | 2 +- .../github.com/go-sql-driver/mysql/packets.go | 99 ++++--- .../github.com/go-sql-driver/mysql/utils.go | 251 +++++++++++------- .../go-sql-driver/mysql/utils_go17.go | 40 --- .../go-sql-driver/mysql/utils_go18.go | 50 ---- 13 files changed, 466 insertions(+), 522 deletions(-) delete mode 100644 vendor/github.com/go-sql-driver/mysql/connection_go18.go delete mode 100644 vendor/github.com/go-sql-driver/mysql/utils_go17.go delete mode 100644 vendor/github.com/go-sql-driver/mysql/utils_go18.go diff --git a/Gopkg.lock b/Gopkg.lock index 17e4397b1..5c2b54e3f 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -383,11 +383,11 @@ revision = "a77f45a7ce909c0ff14b28279fa1a2b674acb70f" [[projects]] - digest = "1:747c1fcb10f8f6734551465ab73c6ed9c551aa6e66250fb6683d1624f554546a" + digest = "1:dce58f88343bd78f4d32dd9601aab4fa5d9994fd2cafa185c51bbd858851cdf9" name = "github.com/go-sql-driver/mysql" packages = ["."] pruneopts = "NUT" - revision = "d523deb1b23d913de5bdada721a6071e71283618" + revision = "c45f530f8e7fe40f4687eaa50d0c8c5f1b66f9e0" [[projects]] digest = "1:06d21295033f211588d0ad7ff391cc1b27e72b60cb6d4b7db0d70cffae4cf228" diff --git a/Gopkg.toml b/Gopkg.toml index 2eb81803a..51f2b2cab 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -46,7 +46,7 @@ ignored = ["google.golang.org/appengine*"] [[override]] name = "github.com/go-sql-driver/mysql" - revision = "d523deb1b23d913de5bdada721a6071e71283618" + revision = "c45f530f8e7fe40f4687eaa50d0c8c5f1b66f9e0" [[override]] name = "github.com/mattn/go-sqlite3" diff --git a/vendor/github.com/go-sql-driver/mysql/AUTHORS b/vendor/github.com/go-sql-driver/mysql/AUTHORS index 73ff68fbc..5ce4f7eca 100644 --- a/vendor/github.com/go-sql-driver/mysql/AUTHORS +++ b/vendor/github.com/go-sql-driver/mysql/AUTHORS @@ -35,6 +35,7 @@ Hanno Braun Henri Yandell Hirotaka Yamamoto ICHINOSE Shogo +Ilia Cimpoes INADA Naoki Jacek Szwec James Harr @@ -72,6 +73,9 @@ Shuode Li Soroush Pour Stan Putrya Stanley Gunawan +Steven Hartland +Thomas Wodarek +Tom Jenkinson Xiangyu Hu Xiaobing Jiang Xiuming Chen @@ -87,3 +91,4 @@ Keybase Inc. Percona LLC Pivotal Inc. Stripe Inc. +Multiplay Ltd. diff --git a/vendor/github.com/go-sql-driver/mysql/auth.go b/vendor/github.com/go-sql-driver/mysql/auth.go index 0b59f52ee..fec7040d4 100644 --- a/vendor/github.com/go-sql-driver/mysql/auth.go +++ b/vendor/github.com/go-sql-driver/mysql/auth.go @@ -234,64 +234,64 @@ func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) erro if err != nil { return err } - return mc.writeAuthSwitchPacket(enc, false) + return mc.writeAuthSwitchPacket(enc) } -func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error) { +func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { switch plugin { case "caching_sha2_password": authResp := scrambleSHA256Password(authData, mc.cfg.Passwd) - return authResp, (authResp == nil), nil + return authResp, nil case "mysql_old_password": if !mc.cfg.AllowOldPasswords { - return nil, false, ErrOldPassword + return nil, ErrOldPassword } // Note: there are edge cases where this should work but doesn't; // this is currently "wontfix": // https://github.com/go-sql-driver/mysql/issues/184 - authResp := scrambleOldPassword(authData[:8], mc.cfg.Passwd) - return authResp, true, nil + authResp := append(scrambleOldPassword(authData[:8], mc.cfg.Passwd), 0) + return authResp, nil case "mysql_clear_password": if !mc.cfg.AllowCleartextPasswords { - return nil, false, ErrCleartextPassword + return nil, ErrCleartextPassword } // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html - return []byte(mc.cfg.Passwd), true, nil + return append([]byte(mc.cfg.Passwd), 0), nil case "mysql_native_password": if !mc.cfg.AllowNativePasswords { - return nil, false, ErrNativePassword + return nil, ErrNativePassword } // https://dev.mysql.com/doc/internals/en/secure-password-authentication.html // Native password authentication only need and will need 20-byte challenge. authResp := scramblePassword(authData[:20], mc.cfg.Passwd) - return authResp, false, nil + return authResp, nil case "sha256_password": if len(mc.cfg.Passwd) == 0 { - return nil, true, nil + return []byte{0}, nil } if mc.cfg.tls != nil || mc.cfg.Net == "unix" { // write cleartext auth packet - return []byte(mc.cfg.Passwd), true, nil + return append([]byte(mc.cfg.Passwd), 0), nil } pubKey := mc.cfg.pubKey if pubKey == nil { // request public key from server - return []byte{1}, false, nil + return []byte{1}, nil } // encrypted password enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) - return enc, false, err + return enc, err default: errLog.Print("unknown auth plugin:", plugin) - return nil, false, ErrUnknownPlugin + return nil, ErrUnknownPlugin } } @@ -315,11 +315,11 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { plugin = newPlugin - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { return err } - if err = mc.writeAuthSwitchPacket(authResp, addNUL); err != nil { + if err = mc.writeAuthSwitchPacket(authResp); err != nil { return err } @@ -352,7 +352,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { case cachingSha2PasswordPerformFullAuthentication: if mc.cfg.tls != nil || mc.cfg.Net == "unix" { // write cleartext auth packet - err = mc.writeAuthSwitchPacket([]byte(mc.cfg.Passwd), true) + err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0)) if err != nil { return err } @@ -360,13 +360,15 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { pubKey := mc.cfg.pubKey if pubKey == nil { // request public key from server - data := mc.buf.takeSmallBuffer(4 + 1) + data, err := mc.buf.takeSmallBuffer(4 + 1) + if err != nil { + return err + } data[4] = cachingSha2PasswordRequestPublicKey mc.writePacket(data) // parse public key - data, err := mc.readPacket() - if err != nil { + if data, err = mc.readPacket(); err != nil { return err } diff --git a/vendor/github.com/go-sql-driver/mysql/buffer.go b/vendor/github.com/go-sql-driver/mysql/buffer.go index eb4748bf4..19486bd6f 100644 --- a/vendor/github.com/go-sql-driver/mysql/buffer.go +++ b/vendor/github.com/go-sql-driver/mysql/buffer.go @@ -22,17 +22,17 @@ const defaultBufSize = 4096 // The buffer is similar to bufio.Reader / Writer but zero-copy-ish // Also highly optimized for this particular use case. type buffer struct { - buf []byte + buf []byte // buf is a byte buffer who's length and capacity are equal. nc net.Conn idx int length int timeout time.Duration } +// newBuffer allocates and returns a new buffer. func newBuffer(nc net.Conn) buffer { - var b [defaultBufSize]byte return buffer{ - buf: b[:], + buf: make([]byte, defaultBufSize), nc: nc, } } @@ -105,43 +105,56 @@ func (b *buffer) readNext(need int) ([]byte, error) { return b.buf[offset:b.idx], nil } -// returns a buffer with the requested size. +// takeBuffer returns a buffer with the requested size. // If possible, a slice from the existing buffer is returned. // Otherwise a bigger buffer is made. // Only one buffer (total) can be used at a time. -func (b *buffer) takeBuffer(length int) []byte { +func (b *buffer) takeBuffer(length int) ([]byte, error) { if b.length > 0 { - return nil + return nil, ErrBusyBuffer } // test (cheap) general case first - if length <= defaultBufSize || length <= cap(b.buf) { - return b.buf[:length] + if length <= cap(b.buf) { + return b.buf[:length], nil } if length < maxPacketSize { b.buf = make([]byte, length) - return b.buf + return b.buf, nil } - return make([]byte, length) + + // buffer is larger than we want to store. + return make([]byte, length), nil } -// shortcut which can be used if the requested buffer is guaranteed to be -// smaller than defaultBufSize +// takeSmallBuffer is shortcut which can be used if length is +// known to be smaller than defaultBufSize. // Only one buffer (total) can be used at a time. -func (b *buffer) takeSmallBuffer(length int) []byte { +func (b *buffer) takeSmallBuffer(length int) ([]byte, error) { if b.length > 0 { - return nil + return nil, ErrBusyBuffer } - return b.buf[:length] + return b.buf[:length], nil } // takeCompleteBuffer returns the complete existing buffer. // This can be used if the necessary buffer size is unknown. +// cap and len of the returned buffer will be equal. // Only one buffer (total) can be used at a time. -func (b *buffer) takeCompleteBuffer() []byte { +func (b *buffer) takeCompleteBuffer() ([]byte, error) { + if b.length > 0 { + return nil, ErrBusyBuffer + } + return b.buf, nil +} + +// store stores buf, an updated buffer, if its suitable to do so. +func (b *buffer) store(buf []byte) error { if b.length > 0 { - return nil + return ErrBusyBuffer + } else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) { + b.buf = buf[:cap(buf)] } - return b.buf + return nil } diff --git a/vendor/github.com/go-sql-driver/mysql/connection.go b/vendor/github.com/go-sql-driver/mysql/connection.go index e57061412..fc4ec7597 100644 --- a/vendor/github.com/go-sql-driver/mysql/connection.go +++ b/vendor/github.com/go-sql-driver/mysql/connection.go @@ -9,6 +9,8 @@ package mysql import ( + "context" + "database/sql" "database/sql/driver" "io" "net" @@ -17,16 +19,6 @@ import ( "time" ) -// a copy of context.Context for Go 1.7 and earlier -type mysqlContext interface { - Done() <-chan struct{} - Err() error - - // defined in context.Context, but not used in this driver: - // Deadline() (deadline time.Time, ok bool) - // Value(key interface{}) interface{} -} - type mysqlConn struct { buf buffer netConn net.Conn @@ -43,7 +35,7 @@ type mysqlConn struct { // for context support (Go 1.8+) watching bool - watcher chan<- mysqlContext + watcher chan<- context.Context closech chan struct{} finished chan<- struct{} canceled atomicError // set non-nil if conn is canceled @@ -190,10 +182,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin return "", driver.ErrSkip } - buf := mc.buf.takeCompleteBuffer() - if buf == nil { + buf, err := mc.buf.takeCompleteBuffer() + if err != nil { // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return "", ErrInvalidConn } buf = buf[:0] @@ -459,3 +451,193 @@ func (mc *mysqlConn) finish() { case <-mc.closech: } } + +// Ping implements driver.Pinger interface +func (mc *mysqlConn) Ping(ctx context.Context) (err error) { + if mc.closed.IsSet() { + errLog.Print(ErrInvalidConn) + return driver.ErrBadConn + } + + if err = mc.watchCancel(ctx); err != nil { + return + } + defer mc.finish() + + if err = mc.writeCommandPacket(comPing); err != nil { + return mc.markBadConn(err) + } + + return mc.readResultOK() +} + +// BeginTx implements driver.ConnBeginTx interface +func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + defer mc.finish() + + if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { + level, err := mapIsolationLevel(opts.Isolation) + if err != nil { + return nil, err + } + err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level) + if err != nil { + return nil, err + } + } + + return mc.begin(opts.ReadOnly) +} + +func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + + rows, err := mc.query(query, dargs) + if err != nil { + mc.finish() + return nil, err + } + rows.finish = mc.finish + return rows, err +} + +func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + defer mc.finish() + + return mc.Exec(query, dargs) +} + +func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + + stmt, err := mc.Prepare(query) + mc.finish() + if err != nil { + return nil, err + } + + select { + default: + case <-ctx.Done(): + stmt.Close() + return nil, ctx.Err() + } + return stmt, nil +} + +func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := stmt.mc.watchCancel(ctx); err != nil { + return nil, err + } + + rows, err := stmt.query(dargs) + if err != nil { + stmt.mc.finish() + return nil, err + } + rows.finish = stmt.mc.finish + return rows, err +} + +func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := stmt.mc.watchCancel(ctx); err != nil { + return nil, err + } + defer stmt.mc.finish() + + return stmt.Exec(dargs) +} + +func (mc *mysqlConn) watchCancel(ctx context.Context) error { + if mc.watching { + // Reach here if canceled, + // so the connection is already invalid + mc.cleanup() + return nil + } + // When ctx is already cancelled, don't watch it. + if err := ctx.Err(); err != nil { + return err + } + // When ctx is not cancellable, don't watch it. + if ctx.Done() == nil { + return nil + } + // When watcher is not alive, can't watch it. + if mc.watcher == nil { + return nil + } + + mc.watching = true + mc.watcher <- ctx + return nil +} + +func (mc *mysqlConn) startWatcher() { + watcher := make(chan context.Context, 1) + mc.watcher = watcher + finished := make(chan struct{}) + mc.finished = finished + go func() { + for { + var ctx context.Context + select { + case ctx = <-watcher: + case <-mc.closech: + return + } + + select { + case <-ctx.Done(): + mc.cancel(ctx.Err()) + case <-finished: + case <-mc.closech: + return + } + } + }() +} + +func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = converter{}.ConvertValue(nv.Value) + return +} + +// ResetSession implements driver.SessionResetter. +// (From Go 1.10) +func (mc *mysqlConn) ResetSession(ctx context.Context) error { + if mc.closed.IsSet() { + return driver.ErrBadConn + } + return nil +} diff --git a/vendor/github.com/go-sql-driver/mysql/connection_go18.go b/vendor/github.com/go-sql-driver/mysql/connection_go18.go deleted file mode 100644 index 62796bfce..000000000 --- a/vendor/github.com/go-sql-driver/mysql/connection_go18.go +++ /dev/null @@ -1,208 +0,0 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package -// -// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. - -// +build go1.8 - -package mysql - -import ( - "context" - "database/sql" - "database/sql/driver" -) - -// Ping implements driver.Pinger interface -func (mc *mysqlConn) Ping(ctx context.Context) (err error) { - if mc.closed.IsSet() { - errLog.Print(ErrInvalidConn) - return driver.ErrBadConn - } - - if err = mc.watchCancel(ctx); err != nil { - return - } - defer mc.finish() - - if err = mc.writeCommandPacket(comPing); err != nil { - return - } - - return mc.readResultOK() -} - -// BeginTx implements driver.ConnBeginTx interface -func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { - if err := mc.watchCancel(ctx); err != nil { - return nil, err - } - defer mc.finish() - - if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { - level, err := mapIsolationLevel(opts.Isolation) - if err != nil { - return nil, err - } - err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level) - if err != nil { - return nil, err - } - } - - return mc.begin(opts.ReadOnly) -} - -func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - dargs, err := namedValueToValue(args) - if err != nil { - return nil, err - } - - if err := mc.watchCancel(ctx); err != nil { - return nil, err - } - - rows, err := mc.query(query, dargs) - if err != nil { - mc.finish() - return nil, err - } - rows.finish = mc.finish - return rows, err -} - -func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - dargs, err := namedValueToValue(args) - if err != nil { - return nil, err - } - - if err := mc.watchCancel(ctx); err != nil { - return nil, err - } - defer mc.finish() - - return mc.Exec(query, dargs) -} - -func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - if err := mc.watchCancel(ctx); err != nil { - return nil, err - } - - stmt, err := mc.Prepare(query) - mc.finish() - if err != nil { - return nil, err - } - - select { - default: - case <-ctx.Done(): - stmt.Close() - return nil, ctx.Err() - } - return stmt, nil -} - -func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { - dargs, err := namedValueToValue(args) - if err != nil { - return nil, err - } - - if err := stmt.mc.watchCancel(ctx); err != nil { - return nil, err - } - - rows, err := stmt.query(dargs) - if err != nil { - stmt.mc.finish() - return nil, err - } - rows.finish = stmt.mc.finish - return rows, err -} - -func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { - dargs, err := namedValueToValue(args) - if err != nil { - return nil, err - } - - if err := stmt.mc.watchCancel(ctx); err != nil { - return nil, err - } - defer stmt.mc.finish() - - return stmt.Exec(dargs) -} - -func (mc *mysqlConn) watchCancel(ctx context.Context) error { - if mc.watching { - // Reach here if canceled, - // so the connection is already invalid - mc.cleanup() - return nil - } - if ctx.Done() == nil { - return nil - } - - mc.watching = true - select { - default: - case <-ctx.Done(): - return ctx.Err() - } - if mc.watcher == nil { - return nil - } - - mc.watcher <- ctx - - return nil -} - -func (mc *mysqlConn) startWatcher() { - watcher := make(chan mysqlContext, 1) - mc.watcher = watcher - finished := make(chan struct{}) - mc.finished = finished - go func() { - for { - var ctx mysqlContext - select { - case ctx = <-watcher: - case <-mc.closech: - return - } - - select { - case <-ctx.Done(): - mc.cancel(ctx.Err()) - case <-finished: - case <-mc.closech: - return - } - } - }() -} - -func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { - nv.Value, err = converter{}.ConvertValue(nv.Value) - return -} - -// ResetSession implements driver.SessionResetter. -// (From Go 1.10) -func (mc *mysqlConn) ResetSession(ctx context.Context) error { - if mc.closed.IsSet() { - return driver.ErrBadConn - } - return nil -} diff --git a/vendor/github.com/go-sql-driver/mysql/driver.go b/vendor/github.com/go-sql-driver/mysql/driver.go index 1a75a16ec..9f4967087 100644 --- a/vendor/github.com/go-sql-driver/mysql/driver.go +++ b/vendor/github.com/go-sql-driver/mysql/driver.go @@ -23,11 +23,6 @@ import ( "sync" ) -// watcher interface is used for context support (From Go 1.8) -type watcher interface { - startWatcher() -} - // MySQLDriver is exported to make the driver directly accessible. // In general the driver is used via the database/sql package. type MySQLDriver struct{} @@ -55,7 +50,7 @@ func RegisterDial(net string, dial DialFunc) { // Open new Connection. // See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how -// the DSN string is formated +// the DSN string is formatted func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { var err error @@ -82,6 +77,10 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) } if err != nil { + if nerr, ok := err.(net.Error); ok && nerr.Temporary() { + errLog.Print("net.Error from Dial()': ", nerr.Error()) + return nil, driver.ErrBadConn + } return nil, err } @@ -96,9 +95,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { } // Call startWatcher for context support (From Go 1.8) - if s, ok := interface{}(mc).(watcher); ok { - s.startWatcher() - } + mc.startWatcher() mc.buf = newBuffer(mc.netConn) @@ -112,20 +109,23 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.cleanup() return nil, err } + if plugin == "" { + plugin = defaultAuthPlugin + } // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { // try the default auth plugin, if using the requested plugin failed errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) plugin = defaultAuthPlugin - authResp, addNUL, err = mc.auth(authData, plugin) + authResp, err = mc.auth(authData, plugin) if err != nil { mc.cleanup() return nil, err } } - if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil { + if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil { mc.cleanup() return nil, err } diff --git a/vendor/github.com/go-sql-driver/mysql/dsn.go b/vendor/github.com/go-sql-driver/mysql/dsn.go index be014babe..b9134722e 100644 --- a/vendor/github.com/go-sql-driver/mysql/dsn.go +++ b/vendor/github.com/go-sql-driver/mysql/dsn.go @@ -560,7 +560,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { } else { cfg.TLSConfig = "false" } - } else if vl := strings.ToLower(value); vl == "skip-verify" { + } else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" { cfg.TLSConfig = vl cfg.tls = &tls.Config{InsecureSkipVerify: true} } else { diff --git a/vendor/github.com/go-sql-driver/mysql/packets.go b/vendor/github.com/go-sql-driver/mysql/packets.go index d873a97b2..5e0853767 100644 --- a/vendor/github.com/go-sql-driver/mysql/packets.go +++ b/vendor/github.com/go-sql-driver/mysql/packets.go @@ -51,7 +51,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { mc.sequence++ // packets with length 0 terminate a previous packet which is a - // multiple of (2^24)−1 bytes long + // multiple of (2^24)-1 bytes long if pktLen == 0 { // there was no previous packet if prevData == nil { @@ -154,15 +154,15 @@ func (mc *mysqlConn) writePacket(data []byte) error { // Handshake Initialization Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake -func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) { - data, err := mc.readPacket() +func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) { + data, err = mc.readPacket() if err != nil { // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since // in connection initialization we don't risk retrying non-idempotent actions. if err == ErrInvalidConn { return nil, "", driver.ErrBadConn } - return nil, "", err + return } if data[0] == iERR { @@ -194,11 +194,14 @@ func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) { return nil, "", ErrOldProtocol } if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { - return nil, "", ErrNoTLS + if mc.cfg.TLSConfig == "preferred" { + mc.cfg.tls = nil + } else { + return nil, "", ErrNoTLS + } } pos += 2 - plugin := "" if len(data) > pos { // character set [1 byte] // status flags [2 bytes] @@ -236,8 +239,6 @@ func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) { return b[:], plugin, nil } - plugin = defaultAuthPlugin - // make a memory safe copy of the cipher slice var b [8]byte copy(b[:], authData) @@ -246,7 +247,7 @@ func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) { // Client Authentication Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error { +func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error { // Adjust client flags based on server support clientFlags := clientProtocol41 | clientSecureConn | @@ -272,7 +273,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, // encode length of the auth plugin data var authRespLEIBuf [9]byte - authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp))) + authRespLen := len(authResp) + authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen)) if len(authRespLEI) > 1 { // if the length can not be written in 1 byte, it must be written as a // length encoded integer @@ -280,9 +282,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, } pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 - if addNUL { - pktLen++ - } // To specify a db name if n := len(mc.cfg.DBName); n > 0 { @@ -291,10 +290,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, } // Calculate packet length and get buffer with that size - data := mc.buf.takeSmallBuffer(pktLen + 4) - if data == nil { + data, err := mc.buf.takeSmallBuffer(pktLen + 4) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -353,10 +352,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, // Auth Data [length encoded integer] pos += copy(data[pos:], authRespLEI) pos += copy(data[pos:], authResp) - if addNUL { - data[pos] = 0x00 - pos++ - } // Databasename [null terminated string] if len(mc.cfg.DBName) > 0 { @@ -367,30 +362,24 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, pos += copy(data[pos:], plugin) data[pos] = 0x00 + pos++ // Send Auth packet - return mc.writePacket(data) + return mc.writePacket(data[:pos]) } // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error { +func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { pktLen := 4 + len(authData) - if addNUL { - pktLen++ - } - data := mc.buf.takeSmallBuffer(pktLen) - if data == nil { + data, err := mc.buf.takeSmallBuffer(pktLen) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } // Add the auth data [EOF] copy(data[4:], authData) - if addNUL { - data[pktLen-1] = 0x00 - } - return mc.writePacket(data) } @@ -402,10 +391,10 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { // Reset Packet Sequence mc.sequence = 0 - data := mc.buf.takeSmallBuffer(4 + 1) - if data == nil { + data, err := mc.buf.takeSmallBuffer(4 + 1) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -421,10 +410,10 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { mc.sequence = 0 pktLen := 1 + len(arg) - data := mc.buf.takeBuffer(pktLen + 4) - if data == nil { + data, err := mc.buf.takeBuffer(pktLen + 4) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -442,10 +431,10 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { // Reset Packet Sequence mc.sequence = 0 - data := mc.buf.takeSmallBuffer(4 + 1 + 4) - if data == nil { + data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -482,7 +471,7 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { return data[1:], "", err case iEOF: - if len(data) < 1 { + if len(data) == 1 { // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest return nil, "mysql_old_password", nil } @@ -898,7 +887,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { const minPktLen = 4 + 1 + 4 + 1 + 4 mc := stmt.mc - // Determine threshould dynamically to avoid packet size shortage. + // Determine threshold dynamically to avoid packet size shortage. longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1) if longDataSize < 64 { longDataSize = 64 @@ -908,15 +897,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { mc.sequence = 0 var data []byte + var err error if len(args) == 0 { - data = mc.buf.takeBuffer(minPktLen) + data, err = mc.buf.takeBuffer(minPktLen) } else { - data = mc.buf.takeCompleteBuffer() + data, err = mc.buf.takeCompleteBuffer() + // In this case the len(data) == cap(data) which is used to optimise the flow below. } - if data == nil { + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -942,7 +933,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { pos := minPktLen var nullMask []byte - if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) { + if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) { // buffer has to be extended but we don't know by how much so // we depend on append after all data with known sizes fit. // We stop at that because we deal with a lot of columns here @@ -951,10 +942,11 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { copy(tmp[:pos], data[:pos]) data = tmp nullMask = data[pos : pos+maskLen] + // No need to clean nullMask as make ensures that. pos += maskLen } else { nullMask = data[pos : pos+maskLen] - for i := 0; i < maskLen; i++ { + for i := range nullMask { nullMask[i] = 0 } pos += maskLen @@ -1091,7 +1083,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // In that case we must build the data packet with the new values buffer if valuesCap != cap(paramValues) { data = append(data[:pos], paramValues...) - mc.buf.buf = data + if err = mc.buf.store(data); err != nil { + errLog.Print(err) + return errBadConnNoWrite + } } pos += len(paramValues) @@ -1261,7 +1256,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { rows.rs.columns[i].decimals, ) } - dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true) + dest[i], err = formatBinaryTime(data[pos:pos+int(num)], dstlen) case rows.mc.parseTime: dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) default: @@ -1281,7 +1276,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { ) } } - dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, false) + dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen) } if err == nil { diff --git a/vendor/github.com/go-sql-driver/mysql/utils.go b/vendor/github.com/go-sql-driver/mysql/utils.go index 84d595b6b..cb3650bb9 100644 --- a/vendor/github.com/go-sql-driver/mysql/utils.go +++ b/vendor/github.com/go-sql-driver/mysql/utils.go @@ -10,10 +10,13 @@ package mysql import ( "crypto/tls" + "database/sql" "database/sql/driver" "encoding/binary" + "errors" "fmt" "io" + "strconv" "strings" "sync" "sync/atomic" @@ -79,7 +82,7 @@ func DeregisterTLSConfig(key string) { func getTLSConfigClone(key string) (config *tls.Config) { tlsConfigLock.RLock() if v, ok := tlsConfigRegistry[key]; ok { - config = cloneTLSConfig(v) + config = v.Clone() } tlsConfigLock.RUnlock() return @@ -227,141 +230,156 @@ var zeroDateTime = []byte("0000-00-00 00:00:00.000000") const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999" -func formatBinaryDateTime(src []byte, length uint8, justTime bool) (driver.Value, error) { - // length expects the deterministic length of the zero value, - // negative time and 100+ hours are automatically added if needed - if len(src) == 0 { - if justTime { - return zeroDateTime[11 : 11+length], nil - } - return zeroDateTime[:length], nil +func appendMicrosecs(dst, src []byte, decimals int) []byte { + if decimals <= 0 { + return dst } - var dst []byte // return value - var pt, p1, p2, p3 byte // current digit pair - var zOffs byte // offset of value in zeroDateTime - if justTime { - switch length { - case - 8, // time (can be up to 10 when negative and 100+ hours) - 10, 11, 12, 13, 14, 15: // time with fractional seconds - default: - return nil, fmt.Errorf("illegal TIME length %d", length) - } - switch len(src) { - case 8, 12: - default: - return nil, fmt.Errorf("invalid TIME packet length %d", len(src)) - } - // +2 to enable negative time and 100+ hours - dst = make([]byte, 0, length+2) - if src[0] == 1 { - dst = append(dst, '-') - } - if src[1] != 0 { - hour := uint16(src[1])*24 + uint16(src[5]) - pt = byte(hour / 100) - p1 = byte(hour - 100*uint16(pt)) - dst = append(dst, digits01[pt]) - } else { - p1 = src[5] - } - zOffs = 11 - src = src[6:] - } else { - switch length { - case 10, 19, 21, 22, 23, 24, 25, 26: - default: - t := "DATE" - if length > 10 { - t += "TIME" - } - return nil, fmt.Errorf("illegal %s length %d", t, length) - } - switch len(src) { - case 4, 7, 11: - default: - t := "DATE" - if length > 10 { - t += "TIME" - } - return nil, fmt.Errorf("illegal %s packet length %d", t, len(src)) - } - dst = make([]byte, 0, length) - // start with the date - year := binary.LittleEndian.Uint16(src[:2]) - pt = byte(year / 100) - p1 = byte(year - 100*uint16(pt)) - p2, p3 = src[2], src[3] - dst = append(dst, - digits10[pt], digits01[pt], - digits10[p1], digits01[p1], '-', - digits10[p2], digits01[p2], '-', - digits10[p3], digits01[p3], - ) - if length == 10 { - return dst, nil - } - if len(src) == 4 { - return append(dst, zeroDateTime[10:length]...), nil - } - dst = append(dst, ' ') - p1 = src[4] // hour - src = src[5:] - } - // p1 is 2-digit hour, src is after hour - p2, p3 = src[0], src[1] - dst = append(dst, - digits10[p1], digits01[p1], ':', - digits10[p2], digits01[p2], ':', - digits10[p3], digits01[p3], - ) - if length <= byte(len(dst)) { - return dst, nil - } - src = src[2:] if len(src) == 0 { - return append(dst, zeroDateTime[19:zOffs+length]...), nil + return append(dst, ".000000"[:decimals+1]...) } + microsecs := binary.LittleEndian.Uint32(src[:4]) - p1 = byte(microsecs / 10000) + p1 := byte(microsecs / 10000) microsecs -= 10000 * uint32(p1) - p2 = byte(microsecs / 100) + p2 := byte(microsecs / 100) microsecs -= 100 * uint32(p2) - p3 = byte(microsecs) - switch decimals := zOffs + length - 20; decimals { + p3 := byte(microsecs) + + switch decimals { default: return append(dst, '.', digits10[p1], digits01[p1], digits10[p2], digits01[p2], digits10[p3], digits01[p3], - ), nil + ) case 1: return append(dst, '.', digits10[p1], - ), nil + ) case 2: return append(dst, '.', digits10[p1], digits01[p1], - ), nil + ) case 3: return append(dst, '.', digits10[p1], digits01[p1], digits10[p2], - ), nil + ) case 4: return append(dst, '.', digits10[p1], digits01[p1], digits10[p2], digits01[p2], - ), nil + ) case 5: return append(dst, '.', digits10[p1], digits01[p1], digits10[p2], digits01[p2], digits10[p3], - ), nil + ) } } +func formatBinaryDateTime(src []byte, length uint8) (driver.Value, error) { + // length expects the deterministic length of the zero value, + // negative time and 100+ hours are automatically added if needed + if len(src) == 0 { + return zeroDateTime[:length], nil + } + var dst []byte // return value + var p1, p2, p3 byte // current digit pair + + switch length { + case 10, 19, 21, 22, 23, 24, 25, 26: + default: + t := "DATE" + if length > 10 { + t += "TIME" + } + return nil, fmt.Errorf("illegal %s length %d", t, length) + } + switch len(src) { + case 4, 7, 11: + default: + t := "DATE" + if length > 10 { + t += "TIME" + } + return nil, fmt.Errorf("illegal %s packet length %d", t, len(src)) + } + dst = make([]byte, 0, length) + // start with the date + year := binary.LittleEndian.Uint16(src[:2]) + pt := year / 100 + p1 = byte(year - 100*uint16(pt)) + p2, p3 = src[2], src[3] + dst = append(dst, + digits10[pt], digits01[pt], + digits10[p1], digits01[p1], '-', + digits10[p2], digits01[p2], '-', + digits10[p3], digits01[p3], + ) + if length == 10 { + return dst, nil + } + if len(src) == 4 { + return append(dst, zeroDateTime[10:length]...), nil + } + dst = append(dst, ' ') + p1 = src[4] // hour + src = src[5:] + + // p1 is 2-digit hour, src is after hour + p2, p3 = src[0], src[1] + dst = append(dst, + digits10[p1], digits01[p1], ':', + digits10[p2], digits01[p2], ':', + digits10[p3], digits01[p3], + ) + return appendMicrosecs(dst, src[2:], int(length)-20), nil +} + +func formatBinaryTime(src []byte, length uint8) (driver.Value, error) { + // length expects the deterministic length of the zero value, + // negative time and 100+ hours are automatically added if needed + if len(src) == 0 { + return zeroDateTime[11 : 11+length], nil + } + var dst []byte // return value + + switch length { + case + 8, // time (can be up to 10 when negative and 100+ hours) + 10, 11, 12, 13, 14, 15: // time with fractional seconds + default: + return nil, fmt.Errorf("illegal TIME length %d", length) + } + switch len(src) { + case 8, 12: + default: + return nil, fmt.Errorf("invalid TIME packet length %d", len(src)) + } + // +2 to enable negative time and 100+ hours + dst = make([]byte, 0, length+2) + if src[0] == 1 { + dst = append(dst, '-') + } + days := binary.LittleEndian.Uint32(src[1:5]) + hours := int64(days)*24 + int64(src[5]) + + if hours >= 100 { + dst = strconv.AppendInt(dst, hours, 10) + } else { + dst = append(dst, digits10[hours], digits01[hours]) + } + + min, sec := src[6], src[7] + dst = append(dst, ':', + digits10[min], digits01[min], ':', + digits10[sec], digits01[sec], + ) + return appendMicrosecs(dst, src[8:], int(length)-9), nil +} + /****************************************************************************** * Convert from and to bytes * ******************************************************************************/ @@ -708,3 +726,30 @@ func (ae *atomicError) Value() error { } return nil } + +func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { + dargs := make([]driver.Value, len(named)) + for n, param := range named { + if len(param.Name) > 0 { + // TODO: support the use of Named Parameters #561 + return nil, errors.New("mysql: driver does not support the use of Named Parameters") + } + dargs[n] = param.Value + } + return dargs, nil +} + +func mapIsolationLevel(level driver.IsolationLevel) (string, error) { + switch sql.IsolationLevel(level) { + case sql.LevelRepeatableRead: + return "REPEATABLE READ", nil + case sql.LevelReadCommitted: + return "READ COMMITTED", nil + case sql.LevelReadUncommitted: + return "READ UNCOMMITTED", nil + case sql.LevelSerializable: + return "SERIALIZABLE", nil + default: + return "", fmt.Errorf("mysql: unsupported isolation level: %v", level) + } +} diff --git a/vendor/github.com/go-sql-driver/mysql/utils_go17.go b/vendor/github.com/go-sql-driver/mysql/utils_go17.go deleted file mode 100644 index f59563456..000000000 --- a/vendor/github.com/go-sql-driver/mysql/utils_go17.go +++ /dev/null @@ -1,40 +0,0 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package -// -// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. - -// +build go1.7 -// +build !go1.8 - -package mysql - -import "crypto/tls" - -func cloneTLSConfig(c *tls.Config) *tls.Config { - return &tls.Config{ - Rand: c.Rand, - Time: c.Time, - Certificates: c.Certificates, - NameToCertificate: c.NameToCertificate, - GetCertificate: c.GetCertificate, - RootCAs: c.RootCAs, - NextProtos: c.NextProtos, - ServerName: c.ServerName, - ClientAuth: c.ClientAuth, - ClientCAs: c.ClientCAs, - InsecureSkipVerify: c.InsecureSkipVerify, - CipherSuites: c.CipherSuites, - PreferServerCipherSuites: c.PreferServerCipherSuites, - SessionTicketsDisabled: c.SessionTicketsDisabled, - SessionTicketKey: c.SessionTicketKey, - ClientSessionCache: c.ClientSessionCache, - MinVersion: c.MinVersion, - MaxVersion: c.MaxVersion, - CurvePreferences: c.CurvePreferences, - DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, - Renegotiation: c.Renegotiation, - } -} diff --git a/vendor/github.com/go-sql-driver/mysql/utils_go18.go b/vendor/github.com/go-sql-driver/mysql/utils_go18.go deleted file mode 100644 index c35c2a6aa..000000000 --- a/vendor/github.com/go-sql-driver/mysql/utils_go18.go +++ /dev/null @@ -1,50 +0,0 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package -// -// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. - -// +build go1.8 - -package mysql - -import ( - "crypto/tls" - "database/sql" - "database/sql/driver" - "errors" - "fmt" -) - -func cloneTLSConfig(c *tls.Config) *tls.Config { - return c.Clone() -} - -func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { - dargs := make([]driver.Value, len(named)) - for n, param := range named { - if len(param.Name) > 0 { - // TODO: support the use of Named Parameters #561 - return nil, errors.New("mysql: driver does not support the use of Named Parameters") - } - dargs[n] = param.Value - } - return dargs, nil -} - -func mapIsolationLevel(level driver.IsolationLevel) (string, error) { - switch sql.IsolationLevel(level) { - case sql.LevelRepeatableRead: - return "REPEATABLE READ", nil - case sql.LevelReadCommitted: - return "READ COMMITTED", nil - case sql.LevelReadUncommitted: - return "READ UNCOMMITTED", nil - case sql.LevelSerializable: - return "SERIALIZABLE", nil - default: - return "", fmt.Errorf("mysql: unsupported isolation level: %v", level) - } -}