diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..6937b3118cb695705398da3c5e889ffa4fb2c2ee --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +.DS_Store +.idea +.vscode +.history \ No newline at end of file diff --git a/README.cn.md b/README.cn.md index c8bf9fa10964866230e358dd970ab13b4b213207..a18fc14920ee5827e04603235c31cfd9ce0540de 100644 --- a/README.cn.md +++ b/README.cn.md @@ -30,7 +30,7 @@ alter system set password_encryption_type=0; ## 特性 -* 适配openGauss SHA256密码认证 +* 适配openGauss SHA256/SM3 密码认证 * 支持连接字符串多host定义 * SSL * 处理`database/sql`坏连接 diff --git a/README.en.md b/README.en.md index e06c1e3b79b0b0f78079f1089c2f90895654b15c..21da5643e4447a232d776aaf708cc82fd791c30d 100644 --- a/README.en.md +++ b/README.en.md @@ -24,7 +24,7 @@ We still prefer to use a more secure encryption method like sha256, so the modif ## Features -* Adapt openGauss SHA256 password authentication +* Adapt openGauss SHA256/SM3 password authentication * Support for multiple host defined connections * SSL * Handles bad connections for `database/sql` diff --git a/README.md b/README.md index 99cb5466db0f83903c1ed9db9e10ac472438b887..3528c7acf68231da89a4cffb3480ac94b49b75b0 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ We still prefer to use a more secure encryption method like sha256, so the modif ## Features -* Adapt openGauss SHA256 password authentication +* Adapt openGauss SHA256/SM3 password authentication * Support for multiple host defined connections * SSL * Handles bad connections for `database/sql` diff --git a/TESTS.cn.MD b/TESTS.cn.MD index 35d15358f805b63f77903cb23bbf55cb490f77bc..9d0bd0459f435395d5a4e61cdf80c79316706d3d 100644 --- a/TESTS.cn.MD +++ b/TESTS.cn.MD @@ -30,5 +30,5 @@ enmotech/opengauss:latest 运行测试: ``` -PGHOST=localhost PGPORT=5432 PGUSER=gaussdb PGPASSWORD=Test@123 PGSSLMODE=disable PGDATABASE=postgres go test +PGHOST=localhost PGPORT=5433 PGUSER=sha256 PGPASSWORD=sha256@abc123 PGSSLMODE=disable PGDATABASE=postgres go test ``` diff --git a/config.go b/config.go index 2b989d598604c60b4020fe56861e3d5811a061b9..3af931c0552641fb66a064536da26d72a2430909 100644 --- a/config.go +++ b/config.go @@ -39,7 +39,8 @@ type Config struct { // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. // It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next // fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs. - // ValidateConnect ValidateConnectFunc + + ValidateConnect ValidateConnectFunc // AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables // or prepare statements). If this returns an error the connection attempt fails. @@ -348,13 +349,24 @@ func ParseConfig(connString string) (*Config, error) { } } - if settings["target_session_attrs"] == "read-write" || settings["target_session_attrs"] == "read-only" { - // config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite - config.targetSessionAttrs = settings["target_session_attrs"] - } else { - return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", settings["target_session_attrs"])} + switch tsa := settings["target_session_attrs"]; tsa { + case "read-write": + config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite + config.targetSessionAttrs = tsa + case "read-only": + config.ValidateConnect = ValidateConnectTargetSessionAttrsReadOnly + config.targetSessionAttrs = tsa + case "primary": + config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary + config.targetSessionAttrs = tsa + case "standby": + config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby + config.targetSessionAttrs = tsa + case "any", "prefer-standby": + // do nothing + default: + return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)} } - return config, nil } diff --git a/conn.go b/conn.go index f8610fc9db00dde66098531dca48afcb37585d76..61be14eb502c84eccfc3c1c9b0dc5847542fa782 100644 --- a/conn.go +++ b/conn.go @@ -32,6 +32,28 @@ var ( errNoLastInsertID = errors.New("no LastInsertId available after the empty statement") ) +const ( + AuthReqOk = 0 + AUTH_REQ_KRB4 = 1 + AUTH_REQ_KRB5 = 2 + AuthReqPassword = 3 + AUTH_REQ_CRYPT = 4 + AuthReqMd5 = 5 + AUTH_REQ_SCM = 6 + AuthReqGss = 7 + AuthReqGssContinue = 8 + AUTH_REQ_SSPI = 9 + AuthReqSha256 = 10 + AuthReqMd5Sha256 = 11 + AuthReqSm3 = 13 + + PlainPassword = 0 + Md5Password = 1 + Sha256Password = 2 + + Sm3Password = 3 +) + type parameterStatus struct { // server version in the same format as server_version_num, or 0 if // unavailable @@ -757,13 +779,16 @@ func (cn *conn) startup() { cn.auth(r) case 'Z': cn.processReadyForQuery(r) - if found, err := cn.ValidateConnect(); err != nil { - cn.c.Close() - panic(err) - } else if found { - return + if cn.config.ValidateConnect != nil { + err := cn.config.ValidateConnect(cn) + if err != nil { + cn.c.Close() + panic(err) + } else { + return + } } - errorf("ValidateConnect failed") + return default: errorf("unknown response for startup: %q", t) } @@ -772,9 +797,9 @@ func (cn *conn) startup() { func (cn *conn) auth(r *readBuf) { switch code := r.int32(); code { - case 0: + case AuthReqOk: // OK - case 3: + case AuthReqPassword: w := cn.writeBuf('p') w.string(cn.config.Password) cn.send(w) @@ -787,7 +812,7 @@ func (cn *conn) auth(r *readBuf) { if r.int32() != 0 { errorf("unexpected authentication response: %q", t) } - case 5: + case AuthReqMd5: s := string(r.next(4)) w := cn.writeBuf('p') w.string("md5" + md5s(md5s(cn.config.Password+cn.config.User)+s)) @@ -801,7 +826,7 @@ func (cn *conn) auth(r *readBuf) { if r.int32() != 0 { errorf("unexpected authentication response: %q", t) } - case 7: // GSSAPI, startup + case AuthReqGss: // GSSAPI, startup if newGss == nil { errorf("kerberos error: no GSSAPI provider registered (import gitee.com/opengauss/openGauss-connector-go-pq/auth/kerberos if you need Kerberos support)") } @@ -835,7 +860,8 @@ func (cn *conn) auth(r *readBuf) { // Store for GSSAPI continue message cn.gss = cli - case 8: // GSSAPI continue + + case AuthReqGssContinue: // GSSAPI continue if cn.gss == nil { errorf("GSSAPI protocol error") @@ -850,18 +876,19 @@ func (cn *conn) auth(r *readBuf) { cn.send(w) } - case 10: + case AuthReqSha256: + // 这里在openGauss为sha256加密办法,主要代码流程来自jdbc相关实现 passwordStoredMethod := r.int32() digest := "" if len(cn.config.Password) == 0 { errorf("The server requested password-based authentication, but no password was provided.") } - if passwordStoredMethod == 0 || passwordStoredMethod == 2 { + if passwordStoredMethod == PlainPassword || passwordStoredMethod == Sha256Password { random64code := string(r.next(64)) token := string(r.next(8)) serverIteration := r.int32() - result := RFC5802Algorithm(cn.config.Password, random64code, token, "", serverIteration) + result := RFC5802Algorithm(cn.config.Password, random64code, token, "", serverIteration, "sha256") if len(result) == 0 { errorf("Invalid username/password,login denied.") } @@ -883,7 +910,7 @@ func (cn *conn) auth(r *readBuf) { errorf("unexpected authentication response: %q", t) } // return - } else if passwordStoredMethod == 1 { + } else if passwordStoredMethod == Md5Password { s := string(r.next(4)) digest = "md5" + md5s(md5s(cn.config.Password+cn.config.User)+s) w := cn.writeBuf('p') @@ -900,12 +927,11 @@ func (cn *conn) auth(r *readBuf) { errorf("unexpected authentication response: %q", t) } } else { - errorf("The password-stored method is not supported ,must be plain , md5 or sha256.") + errorf("The password-stored method is not supported ,must be plain, md5 or sha256.") } // AUTH_REQ_MD5_SHA256 - case 11: - + case AuthReqMd5Sha256: random64code := string(r.next(64)) md5Salt := r.next(4) result := Md5Sha256encode(cn.config.Password, random64code, md5Salt) @@ -926,47 +952,41 @@ func (cn *conn) auth(r *readBuf) { if r.int32() != 0 { errorf("unexpected authentication response: %q", t) } + case AuthReqSm3: // sm3 + passwordStoredMethod := r.int32() + if passwordStoredMethod == Sm3Password { + random64code := string(r.next(64)) + token := string(r.next(8)) + serverIteration := r.int32() - default: - errorf("unknown authentication response: %d", code) - } -} + result := RFC5802Algorithm(cn.config.Password, random64code, token, "", serverIteration, "sm3") + if len(result) == 0 { + errorf("Invalid username/password,login denied.") + } -func (cn *conn) ValidateConnect() (bool, error) { - if cn.config.targetSessionAttrs == "" { - return true, nil - } - sqlText := "show transaction_read_only" + w := cn.writeBuf('p') + w.buf = []byte("p") + w.pos = 1 + w.int32(4 + len(result) + 1) + w.bytes(result) + w.byte(0) + cn.send(w) - cn.log(context.Background(), LogLevelDebug, "Check server is transaction_read_only ?", map[string]interface{}{"sql": sqlText, - "host": cn.config.Host, "port": cn.config.Port, "target_session_attrs": cn.config.targetSessionAttrs}) - inReRows, err := cn.query(sqlText, nil) - if err != nil { - cn.log(context.Background(), LogLevelDebug, "err:"+err.Error(), map[string]interface{}{}) - return false, err - } - defer inReRows.Close() - var dbTranReadOnly string - lastCols := []driver.Value{&dbTranReadOnly} - err = inReRows.Next(lastCols) - if err != nil { - cn.log(context.Background(), LogLevelDebug, "err:"+err.Error(), map[string]interface{}{}) - return false, err - } - readOnly := lastCols[0].(string) - cn.log(context.Background(), LogLevelDebug, "Check server is readOnly ?", map[string]interface{}{"readOnly": readOnly, - "host": cn.config.Host, "port": cn.config.Port}) - - if strings.EqualFold(cn.config.targetSessionAttrs, targetSessionAttrsReadWrite) && - strings.EqualFold(readOnly, "off") { - return true, nil - } else if strings.EqualFold(cn.config.targetSessionAttrs, targetSessionAttrsReadOnly) && - strings.EqualFold(readOnly, "on") { - return true, nil - } else { - return false, nil - } + t, r := cn.recv() + if t != 'R' { + errorf("unexpected password response: %q", t) + } + + if r.int32() != 0 { + errorf("unexpected authentication response: %q", t) + } + } else { + errorf("The password-stored method is not supported ,must be sm3.") + } + default: + errorf("unknown authentication response: %d", code) + } } type format int diff --git a/conn_go18.go b/conn_go18.go index cf3e31ed164d64179604f4d1810ce393c42925b7..753d177a01d9bcebef264a2f29f84f0670e05a76 100644 --- a/conn_go18.go +++ b/conn_go18.go @@ -13,7 +13,7 @@ import ( "time" ) -// Implement the "QueryerContext" interface +// QueryContext Implement the "QueryerContext" interface func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { list := make([]driver.Value, len(args)) for i, nv := range args { @@ -31,7 +31,7 @@ func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.Na return r, nil } -// Implement the "ExecerContext" interface +// ExecContext Implement the "ExecerContext" interface func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { list := make([]driver.Value, len(args)) for i, nv := range args { @@ -45,7 +45,7 @@ func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.Nam return cn.Exec(query, list) } -// Implement the "ConnBeginTx" interface +// BeginTx Implement the "ConnBeginTx" interface func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { var mode string diff --git a/conn_test.go b/conn_test.go index 6b29c163c4ea19a90cbe0a972e490e5805900edc..eccdb98e5fd037ec1d91fcaf7b09a96172b0e389 100644 --- a/conn_test.go +++ b/conn_test.go @@ -48,8 +48,8 @@ func testConninfo(conninfo string) string { return conninfo } -func openTestConnConninfo(conninfo string) (*sql.DB, error) { - return sql.Open("opengauss", testConninfo(conninfo)) +func openTestConnConninfo(connInfo string) (*sql.DB, error) { + return sql.Open("opengauss", testConninfo(connInfo)) } func openTestConn(t Fatalistic) *sql.DB { diff --git a/conn_validate.go b/conn_validate.go new file mode 100644 index 0000000000000000000000000000000000000000..ca2fea6eea2532e4e985b0533ff229f14ac20cef --- /dev/null +++ b/conn_validate.go @@ -0,0 +1,112 @@ +// 2022/1/14 Bin Liu + +package pq + +import ( + "context" + "database/sql/driver" + "errors" + "fmt" + "strings" +) + +type ValidateConnectFunc func(conn *conn) error + +const ( + showTransactionReadOnly = "show transaction_read_only" + pgIsInRecovery = "select pg_is_in_recovery()" +) + +func validateConnectTargetSessionAttrsTransaction(cn *conn, expectedStatus string) (bool, error) { + cn.log(context.Background(), LogLevelDebug, "Check server is transaction_read_only ?", map[string]interface{}{"sql": showTransactionReadOnly, + "host": cn.config.Host, "port": cn.config.Port, "target_session_attrs": cn.config.targetSessionAttrs}) + inReRows, err := cn.query(showTransactionReadOnly, nil) + defer inReRows.Close() + var dbTranReadOnly string + lastCols := []driver.Value{&dbTranReadOnly} + err = inReRows.Next(lastCols) + if err != nil { + cn.log(context.Background(), LogLevelDebug, "err:"+err.Error(), map[string]interface{}{}) + return false, err + } + readOnly := lastCols[0].(string) + cn.log(context.Background(), LogLevelDebug, "Check server is readOnly ?", map[string]interface{}{"readOnly": readOnly, + "host": cn.config.Host, "port": cn.config.Port}) + if strings.EqualFold(readOnly, expectedStatus) { + return true, nil + } + return false, nil +} +func ValidateConnectTargetSessionAttrsReadWrite(cn *conn) error { + // omm=# show transaction_read_only; + // transaction_read_only + // ----------------------- + // off + // (1 row) + b, err := validateConnectTargetSessionAttrsTransaction(cn, "off") + if err != nil { + return err + } + if !b { + return errors.New("connection is not read write") + } + return nil +} +func ValidateConnectTargetSessionAttrsReadOnly(cn *conn) error { + // omm=# show transaction_read_only; + // transaction_read_only + // ----------------------- + // on + // (1 row) + b, err := validateConnectTargetSessionAttrsTransaction(cn, "on") + if err != nil { + return err + } + if !b { + return errors.New("connection is not read only") + } + return nil +} + +func validateConnectTargetSessionAttrsRecovery(cn *conn, expectedIsRecovery bool) (bool, error) { + cn.log(context.Background(), LogLevelDebug, "Check server is pg_is_in_recovery ?", map[string]interface{}{"sql": pgIsInRecovery, + "host": cn.config.Host, "port": cn.config.Port, "target_session_attrs": cn.config.targetSessionAttrs}) + inReRows, err := cn.query(pgIsInRecovery, nil) + defer inReRows.Close() + var dbTranReadOnly string + lastCols := []driver.Value{&dbTranReadOnly} + err = inReRows.Next(lastCols) + if err != nil { + cn.log(context.Background(), LogLevelDebug, "err:"+err.Error(), map[string]interface{}{}) + return false, err + } + pgIsRecovery := lastCols[0].(bool) + fmt.Println(pgIsRecovery, expectedIsRecovery) + cn.log(context.Background(), LogLevelDebug, "Check server is pg_is_in_recovery ?", map[string]interface{}{"pgIsRecovery": pgIsRecovery, + "host": cn.config.Host, "port": cn.config.Port}) + if expectedIsRecovery == pgIsRecovery { + return true, nil + } + return false, nil +} +func ValidateConnectTargetSessionAttrsPrimary(cn *conn) error { + b, err := validateConnectTargetSessionAttrsRecovery(cn, false) + if err != nil { + return err + } + if !b { + return errors.New("connection is not primary instance") + } + return nil +} +func ValidateConnectTargetSessionAttrsStandby(cn *conn) error { + + b, err := validateConnectTargetSessionAttrsRecovery(cn, true) + if err != nil { + return err + } + if !b { + return errors.New("connection is not standby instance") + } + return nil +} diff --git a/connector.go b/connector.go index 095155cf69ab21b4be1c392be03bbaf25069abba..511e29082e554e05309c5513a0ea9c8688241984 100644 --- a/connector.go +++ b/connector.go @@ -214,7 +214,7 @@ func (c *Connector) connectFallbackConfig(ctx context.Context, config *Config, f fallbackConfig: fallbackConfig, } - cn.log(ctx, LogLevelInfo, "Dialing server", map[string]interface{}{"host": fallbackConfig.Host, "port": fallbackConfig.Port}) + cn.log(ctx, LogLevelInfo, "dialing server", map[string]interface{}{"host": fallbackConfig.Host, "port": fallbackConfig.Port}) network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) cn.c, err = config.DialFunc(ctx, network, address) if err != nil { diff --git a/defaults.go b/defaults.go index 15c5feff7dd1e1646964c991ab94a0fc27e7c17b..84d579368c4a70c9001e40d36dab1fdc75a597e0 100644 --- a/defaults.go +++ b/defaults.go @@ -24,7 +24,7 @@ func defaultSettings() map[string]string { settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") } - settings["target_session_attrs"] = targetSessionAttrsReadWrite + settings["target_session_attrs"] = "any" settings["min_read_buffer_size"] = "8192" diff --git a/defaults_windows.go b/defaults_windows.go index 7576a3b97a72bea2400a6c33f6f6f77dba092cb2..5fbeb8953d345382a9e8f7383d96e90e3506675b 100644 --- a/defaults_windows.go +++ b/defaults_windows.go @@ -31,7 +31,7 @@ func defaultSettings() map[string]string { settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") } - settings["target_session_attrs"] = targetSessionAttrsReadWrite + settings["target_session_attrs"] = "any" settings["min_read_buffer_size"] = "8192" diff --git a/error.go b/error.go index 50c5e25e88ac736fd77bca16c203be1bc976ae95..985323586a1cccc615913a49b902ad3aca5a7152 100644 --- a/error.go +++ b/error.go @@ -494,11 +494,12 @@ func (cn *conn) errRecover(err *error) { cn.setBad() panic(v) case *Error: - if v.Fatal() { - *err = driver.ErrBadConn - } else { - *err = v - } + // if v.Fatal() { + // *err = driver.ErrBadConn + // } else { + // *err = v + // } + *err = v case *net.OpError: cn.setBad() *err = v diff --git a/example/multi_ip/multi_ip.go b/example/multi_ip/multi_ip.go index ca14ad2dffcdf0384de52d3b6740b5e3eb38c001..726f046e7f1d4363dd3f495ad120af68fbfbcd06 100644 --- a/example/multi_ip/multi_ip.go +++ b/example/multi_ip/multi_ip.go @@ -29,8 +29,8 @@ DSN="user=gaussdb password=secret host=foo,bar,baz port=5432,5432,5433 dbname=my ) func main() { - // os.Setenv("DSN","postgres://dbuser_monitor:Mon@1234@127.0.0.1:1112,127.0.0.1:1111/postgres?" + - // "sslmode=disable&loggerLevel=debug&target_session_attrs=read-only") + // os.Setenv("DSN", "postgres://mogdb:mtkOP@123@127.0.0.1:5436,127.0.0.1:1111/postgres?"+ + // "sslmode=disable&loggerLevel=debug") connStr := os.Getenv("DSN") if connStr == "" { fmt.Println("please define the env DSN. example:\n" + dsnExample) @@ -75,20 +75,21 @@ func getNodeName(db *sql.DB) error { // return err // } // defer tx.Commit() - var nodeName, sysdate string + var sysdate string var pgIsInRecovery bool - err = db.QueryRow("select sysdate,node_name,pg_is_in_recovery() from dbe_perf.global_instance_time limit 1 "). - Scan(&sysdate, &nodeName, &pgIsInRecovery) + var nodeName string + err = db.QueryRow("select sysdate,pg_is_in_recovery();"). + Scan(&sysdate, &pgIsInRecovery) if err != nil { return err } var channel string - err = db.QueryRow("select channel from pg_stat_get_wal_senders() limit 1 "). - Scan(&channel) + // err = db.QueryRow("select channel from pg_stat_get_wal_senders() limit 1 "). + // Scan(&channel) fmt.Println(sysdate, nodeName, pgIsInRecovery, channel) - if err != nil { - return err - } + // if err != nil { + // return err + // } return nil } diff --git a/go.mod b/go.mod index c262143b6c4f747df84b484233fd4f099536d262..5570d30091376cda2583f7e91afb853097011ca5 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,9 @@ require ( github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/kr/pretty v0.1.0 // indirect github.com/stretchr/testify v1.7.0 + github.com/tjfoc/gmsm v1.4.1 golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect + ) diff --git a/go.sum b/go.sum index fda3bf0dfca525ac151b79d2f1c33644a3a9d5fc..fa5e34a4f086c5034f454229a5ed20f4f0184e05 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,29 @@ -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= @@ -11,25 +33,69 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho= +github.com/tjfoc/gmsm v1.4.1/go.mod h1:j4INPkHWMrhJb38G+J6W4Tw0AbuN8Thu3PbdVYhVcTE= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201012173705-84dcc777aaee/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/rfcdigest.go b/rfcdigest.go index b120f786e11a7371f82db5f170793015a7291ddb..f339ec6af903704c6dca331fc8a661e7abf11def 100644 --- a/rfcdigest.go +++ b/rfcdigest.go @@ -8,6 +8,7 @@ import ( "crypto/sha256" + "github.com/tjfoc/gmsm/sm3" "golang.org/x/crypto/pbkdf2" ) @@ -69,6 +70,13 @@ func getSha256(message []byte) []byte { return hash.Sum(nil) } +func getSm3(message []byte) []byte { + hash := sm3.New() + hash.Write(message) + + return hash.Sum(nil) +} + func XorBetweenPassword(password1 []byte, password2 []byte, length int) []byte { array := make([]byte, length) for i := 0; i < length; i++ { @@ -124,11 +132,17 @@ RFC5802Algorithm return result; } */ -func RFC5802Algorithm(password string, random64code string, token string, serverSignature string, serverIteration int) []byte { +func RFC5802Algorithm(password string, random64code string, token string, serverSignature string, serverIteration int, method string) []byte { k := generateKFromPBKDF2(password, random64code, serverIteration) serverKey := getKeyFromHmac(k, []byte("Sever Key")) clientKey := getKeyFromHmac(k, []byte("Client Key")) - storedKey := getSha256(clientKey) + var storedKey []byte + + if strings.EqualFold(method, "sha256") { + storedKey = getSha256(clientKey) + } else if strings.EqualFold(method, "sm3") { + storedKey = getSm3(clientKey) + } tokenByte := hexStringToBytes(token) clientSignature := getKeyFromHmac(serverKey, tokenByte) if serverSignature != "" && serverSignature != bytesToHexString(clientSignature) { diff --git a/rfcdigest_est.go b/rfcdigest_est.go deleted file mode 100644 index a4ff32fcf0384ba6132078bc4b4e8d91b2b18db1..0000000000000000000000000000000000000000 --- a/rfcdigest_est.go +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright © 2021 Bin Liu - -package pq - -import "testing" - -func TestRFC5802Algorithm(t *testing.T) { - RFC5802Algorithm("1", "2", "3", "4", 5) -} diff --git a/rfcdigest_test.go b/rfcdigest_test.go new file mode 100644 index 0000000000000000000000000000000000000000..82d73878cd0aa0b9e8ffe1fa8cdbf6308b3c7c58 --- /dev/null +++ b/rfcdigest_test.go @@ -0,0 +1,56 @@ +// Copyright © 2021 Bin Liu + +package pq + +import ( + "reflect" + "testing" +) + +func TestRFC5802Algorithm(t *testing.T) { + type args struct { + password string + random64code string + token string + serverSignature string + serverIteration int + method string + } + tests := []struct { + name string + args args + want []byte + }{ + { + name: "sm3", + args: args{ + password: "sm3@abc123", + random64code: "5ae737626add65f8da1b063104a6c4e2dc25b7343d8512a74826dc5b5e3e5188", + token: string([]byte{0, 0, 0, 0, 0, 0, 0, 0}), + serverSignature: "", + serverIteration: 10000, + method: "sm3", + }, + want: []byte{48, 48, 57, 102, 56, 52, 52, 99, 49, 57, 56, 48, 102, 53, 48, 49, 99, 54, 54, 99, 54, 56, 52, 49, 50, 98, 48, 97, 98, 99, 98, 53, 97, 101, 49, 55, 54, 100, 101, 51, 50, 102, 102, 98, 98, 98, 97, 101, 57, 55, 48, 98, 56, 50, 57, 50, 50, 49, 99, 100, 48, 99, 48, 56}, + }, + { + name: "sha256", + args: args{ + password: "sha256@abc123", + random64code: "3458fe51abe962f7b6011a1d73fc14edf50539fae89fb9dda75fbb642d9859bf", + token: string([]byte{50, 99, 102, 55, 49, 102, 49, 48}), + serverSignature: "", + serverIteration: 10000, + method: "sha256", + }, + want: []byte{102, 53, 50, 49, 51, 97, 102, 57, 51, 57, 55, 52, 97, 101, 51, 102, 48, 97, 51, 101, 101, 100, 98, 97, 55, 98, 48, 101, 102, 50, 55, 55, 57, 54, 98, 99, 100, 52, 100, 56, 52, 98, 98, 55, 51, 55, 100, 57, 99, 51, 53, 102, 99, 101, 53, 52, 102, 99, 102, 101, 50, 56, 102, 100}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := RFC5802Algorithm(tt.args.password, tt.args.random64code, tt.args.token, tt.args.serverSignature, tt.args.serverIteration, tt.args.method); !reflect.DeepEqual(got, tt.want) { + t.Errorf("RFC5802Algorithm() = %v, want %v", got, tt.want) + } + }) + } +}