1 Star 0 Fork 0

h79/goutils

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
adapter.go 7.98 KB
一键复制 编辑 原始数据 按行查看 历史
huqiuyun 提交于 2024-05-27 15:33 . mysql Charset
package db
import (
"crypto/rsa"
"crypto/tls"
"errors"
"fmt"
commonconfig "gitee.com/h79/goutils/common/config"
"gitee.com/h79/goutils/common/result"
"gitee.com/h79/goutils/dao/config"
"gitee.com/h79/goutils/dao/option"
daotls "gitee.com/h79/goutils/dao/util"
drivermysql "github.com/go-sql-driver/mysql"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlserver"
"gorm.io/plugin/dbresolver"
"net/url"
"strings"
"time"
"gorm.io/gorm"
"runtime"
)
// DialFunc creates a gorm.Dialector from a data source name.
type DialFunc func(string) gorm.Dialector

// openFuncs maps a config DriverType to the constructor of its gorm
// dialector. Further drivers can be registered with AddDriver.
var openFuncs = map[string]DialFunc{
	"postgres":  postgres.Open,
	"mysql":     mysql.Open,
	"sqlserver": sqlserver.Open,
}

// Compile-time check that *Adapter implements Sql.
var _ Sql = (*Adapter)(nil)
// Adapter owns the *gorm.DB opened by NewAdapter from a config.Sql and
// satisfies the Sql interface.
type Adapter struct {
	// Master connection settings captured by NewAdapter: driver type,
	// database name and the data source name used to open it.
	driverName, databaseName, dsn string

	// db is the opened handle; Close sets it to nil.
	db *gorm.DB
}

// ScopesFunc is a reusable query scope with the func(*gorm.DB) *gorm.DB
// shape accepted by gorm's DB.Scopes.
type ScopesFunc func(db *gorm.DB) *gorm.DB
// finalizer is the destructor for Adapter, installed with runtime.SetFinalizer
// by NewAdapter. It closes the underlying *sql.DB pool if the Adapter still
// holds a database handle.
//
// Adapter.Close sets a.db to nil, so the handle must be checked before use;
// calling a.db.DB() on a nil *gorm.DB would panic. Errors are deliberately
// ignored: there is no caller to report them to, and a panic inside a
// finalizer terminates the whole process.
func finalizer(a *Adapter) {
	if a == nil || a.db == nil {
		return
	}
	if sqlDB, err := a.db.DB(); err == nil {
		_ = sqlDB.Close()
	}
}
// DefaultDnsFunc builds the data source name (DSN) for cnf.DriverType.
//
// Supported driver types:
//   - "mysql": user:pwd@tcp(host:port)/name with charset (cnf.Charset, which
//     is set to "utf8mb4" in cnf when empty), parseTime=True and loc=Local.
//     When tls / serverPubKey are true, the names cnf.Tls.Key and
//     cnf.ServerPubKey.Key are appended as the tls / serverPubKey parameters.
//   - "postgres": a keyword=value string; values that are empty or contain
//     whitespace, quotes or backslashes are single-quoted and escaped so an
//     empty password cannot swallow the following dbname=... token.
//   - "sqlite3": the database name (cnf.Name) as-is.
//   - "sql" or "sqlserver": a sqlserver:// URL with URL-escaped credentials
//     and database name. "sqlserver" matches the key used in openFuncs.
//
// It returns "" for a nil cnf or an unknown driver type.
var DefaultDnsFunc = func(cnf *config.Database, tls, serverPubKey bool) string {
	if cnf == nil {
		return ""
	}
	switch cnf.DriverType {
	case "mysql":
		if cnf.Charset == "" {
			// The default is written back into cnf, as before.
			cnf.Charset = "utf8mb4"
		}
		dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=True&loc=Local", cnf.User, cnf.Pwd, cnf.Host, cnf.Port, cnf.Name, cnf.Charset)
		if tls {
			dsn += "&tls=" + url.QueryEscape(cnf.Tls.Key)
		}
		if serverPubKey {
			dsn += "&serverPubKey=" + url.QueryEscape(cnf.ServerPubKey.Key)
		}
		return dsn
	case "postgres":
		return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s",
			quotePostgresDSNValue(cnf.Host), cnf.Port,
			quotePostgresDSNValue(cnf.User), quotePostgresDSNValue(cnf.Pwd),
			quotePostgresDSNValue(cnf.Name))
	case "sqlite3":
		return cnf.Name
	case "sql", "sqlserver":
		// Build through net/url so special characters in the credentials or
		// database name are escaped; for plain values the result is
		// sqlserver://user:pwd@host:port?database=name.
		u := url.URL{
			Scheme:   "sqlserver",
			User:     url.UserPassword(cnf.User, cnf.Pwd),
			Host:     fmt.Sprintf("%s:%d", cnf.Host, cnf.Port),
			RawQuery: url.Values{"database": {cnf.Name}}.Encode(),
		}
		return u.String()
	}
	return ""
}

// postgresDSNEscaper escapes the characters that are special inside a
// single-quoted keyword=value connection-string value.
var postgresDSNEscaper = strings.NewReplacer(`\`, `\\`, `'`, `\'`)

// quotePostgresDSNValue returns v unchanged when it can be used bare in a
// keyword=value connection string. Otherwise (empty, or containing
// whitespace, a single quote or a backslash) it returns v single-quoted
// with ' and \ backslash-escaped.
func quotePostgresDSNValue(v string) string {
	if v != "" && !strings.ContainsAny(v, " \t\n\r\v\f'\\") {
		return v
	}
	return "'" + postgresDSNEscaper.Replace(v) + "'"
}
// WithDnsOption returns an option that makes getDns use f instead of
// DefaultDnsFunc to build the data source name.
func WithDnsOption(f func(cnf *config.Database, tls, pubKey bool) string) option.Option {
	return dnsFunc(f)
}

// dnsFunc adapts a DSN builder function to option.Option.
type dnsFunc func(cnf *config.Database, tls bool, pubKey bool) string

// String returns the option's name.
func (t dnsFunc) String() string {
	return "sql:dns"
}

// Type returns option.TypeSqlDns, the id under which the option is looked up.
func (t dnsFunc) Type() int { return option.TypeSqlDns }

// Value returns the builder function itself.
func (t dnsFunc) Value() interface{} { return t }

// dnsFuncExist returns the DSN builder carried by opts, or nil when there is
// none. An option that uses option.TypeSqlDns but does not carry a dnsFunc
// is treated as absent instead of panicking on the type assertion.
func dnsFuncExist(opts ...option.Option) dnsFunc {
	r, ok := option.Exist(option.TypeSqlDns, opts...)
	if !ok {
		return nil
	}
	fn, _ := r.Value().(dnsFunc)
	return fn
}
// DefaultTlsFunc builds the *tls.Config that UseTls registers under
// cnf.Tls.Key.
//
// The keys "true", "false", "skip-verify" and "preferred" (compared
// case-insensitively) need no custom configuration, so (nil, nil) is
// returned for them. For any other key, the client certificate and root CA
// pool are loaded with daotls.GetCertificate and any loading error is
// returned.
var DefaultTlsFunc = func(cnf *config.Database) (*tls.Config, error) {
	for _, builtin := range [...]string{"true", "false", "skip-verify", "preferred"} {
		if strings.EqualFold(cnf.Tls.Key, builtin) {
			return nil, nil
		}
	}
	clientCert, caPool, loadErr := daotls.GetCertificate(&cnf.Tls)
	if loadErr != nil {
		return nil, loadErr
	}
	cfg := &tls.Config{
		Certificates: []tls.Certificate{clientCert},
		RootCAs:      caPool,
	}
	return cfg, nil
}
// WithTlsOption returns an option that makes UseTls use f instead of
// DefaultTlsFunc to build the *tls.Config.
func WithTlsOption(f func(cnf *config.Database) (*tls.Config, error)) option.Option {
	return tlsFunc(f)
}

// tlsFunc adapts a TLS config builder to option.Option.
type tlsFunc func(cnf *config.Database) (*tls.Config, error)

// String returns the option's name.
func (t tlsFunc) String() string {
	return "sql:tls"
}

// Type returns option.TypeSqlTls, the id under which the option is looked up.
func (t tlsFunc) Type() int { return option.TypeSqlTls }

// Value returns the builder function itself.
func (t tlsFunc) Value() interface{} { return t }

// tlsFuncExist returns the TLS config builder carried by opts, or nil when
// there is none. An option that uses option.TypeSqlTls but does not carry a
// tlsFunc is treated as absent instead of panicking on the type assertion.
func tlsFuncExist(opts ...option.Option) tlsFunc {
	r, ok := option.Exist(option.TypeSqlTls, opts...)
	if !ok {
		return nil
	}
	fn, _ := r.Value().(tlsFunc)
	return fn
}
// WithServerPubKeyOption returns an option that makes UseServerPubKey use f
// to obtain the server's RSA public key instead of daotls.GetServerPubKey.
func WithServerPubKeyOption(f func(cnf *config.Database) (*rsa.PublicKey, error)) option.Option {
	return ServerPubKeyFunc(f)
}

// ServerPubKeyFunc loads the RSA public key of the database server described
// by cnf; it also implements option.Option.
type ServerPubKeyFunc func(cnf *config.Database) (*rsa.PublicKey, error)

// String returns the option's name.
func (t ServerPubKeyFunc) String() string {
	return "sql:serverPubKey"
}

// Type returns option.TypeSqlServerPubKey, the id under which the option is
// looked up.
func (t ServerPubKeyFunc) Type() int { return option.TypeSqlServerPubKey }

// Value returns the loader function itself.
func (t ServerPubKeyFunc) Value() interface{} { return t }

// serverPubKeyFuncExist returns the key loader carried by opts, or nil when
// there is none. An option that uses option.TypeSqlServerPubKey but does not
// carry a ServerPubKeyFunc is treated as absent instead of panicking on the
// type assertion.
func serverPubKeyFuncExist(opts ...option.Option) ServerPubKeyFunc {
	r, ok := option.Exist(option.TypeSqlServerPubKey, opts...)
	if !ok {
		return nil
	}
	fn, _ := r.Value().(ServerPubKeyFunc)
	return fn
}
// getDns returns the data source name for cnf. It uses the dnsFunc option
// from opts when one is present, otherwise DefaultDnsFunc. The tls and
// serverPubKey flags are forwarded unchanged.
func getDns(cnf *config.Database, tls, serverPubKey bool, opts ...option.Option) string {
	if custom := dnsFuncExist(opts...); custom != nil {
		return custom(cnf, tls, serverPubKey)
	}
	return DefaultDnsFunc(cnf, tls, serverPubKey)
}
// UseTls prepares TLS for a mysql connection described by cnf.
//
// The *tls.Config is built by the tlsFunc option from opts, or by
// DefaultTlsFunc when none is given. A non-nil config is registered with the
// mysql driver under cnf.Tls.Key; a nil config means there is nothing to
// register.
//
// It returns result.RErrNotSupport for non-mysql drivers, the builder's
// error, or the error from drivermysql.RegisterTLSConfig.
func UseTls(cnf *config.Database, opts ...option.Option) error {
	if cnf.DriverType != "mysql" {
		return result.RErrNotSupport
	}
	build := DefaultTlsFunc
	if custom := tlsFuncExist(opts...); custom != nil {
		build = custom
	}
	tlsCfg, buildErr := build(cnf)
	switch {
	case buildErr != nil:
		return buildErr
	case tlsCfg == nil:
		return nil
	}
	return drivermysql.RegisterTLSConfig(cnf.Tls.Key, tlsCfg)
}
// UseServerPubKey registers the RSA public key of a mysql server with the
// mysql driver under cnf.ServerPubKey.Key, so that a DSN can reference it
// through the serverPubKey parameter.
//
// The key comes from the ServerPubKeyFunc option in opts when present,
// otherwise from daotls.GetServerPubKey(&cnf.ServerPubKey).
//
// It returns result.RErrNotSupport for non-mysql drivers and the loader's
// error if loading fails. It also returns an error when the loader yields a
// nil key: registering nil would leave the DSN pointing at a key the driver
// cannot use.
func UseServerPubKey(cnf *config.Database, opts ...option.Option) error {
	if cnf.DriverType != "mysql" {
		return result.RErrNotSupport
	}
	fn := serverPubKeyFuncExist(opts...)
	if fn == nil {
		fn = func(c *config.Database) (*rsa.PublicKey, error) {
			return daotls.GetServerPubKey(&c.ServerPubKey)
		}
	}
	pk, err := fn(cnf)
	if err != nil {
		return err
	}
	if pk == nil {
		return errors.New("server public key is nil")
	}
	drivermysql.RegisterServerPubKey(cnf.ServerPubKey.Key, pk)
	return nil
}
// NewAdapter is the constructor for Adapter.
//
// It opens the master database described by cfg.Master. It then registers a
// dbresolver plugin built from cfg.Sources and cfg.Replicas, using random
// load balancing.
//
// For every connection, TLS and the mysql server public key are applied
// best-effort. If UseTls or UseServerPubKey fails (for example because the
// driver is not mysql), the matching DSN parameter is simply left out.
//
// Positive cfg.MaxOpenConns, cfg.MaxIdleConns, cfg.MaxLifetime and
// cfg.MaxIdleTime are applied to the resolver; cfg.MaxIdleTime is multiplied
// by time.Minute. When cfg.Logger.LogLevel > 1, a Logger is installed with
// cfg.Logger.SlowThreshold scaled by time.Millisecond (200 when unset).
//
// If anything fails after the master connection has been opened, its pool
// is closed before the error is returned.
func NewAdapter(cfg *config.Sql, opts ...option.Option) (*Adapter, error) {
	a := &Adapter{}
	tlsOn, pubKeyOn := applySecurity(&cfg.Master, opts...)
	a.driverName = cfg.Master.DriverType
	a.databaseName = cfg.Master.Name
	a.dsn = getDns(&cfg.Master, tlsOn, pubKeyOn, opts...)
	// Open the master DB.
	db, err := openDB(a.driverName, a.dsn)
	if err != nil {
		return nil, err
	}
	// From here on the master pool is open and must be released on failure.
	sources, err := resolverDialectors(cfg.Sources, opts...)
	if err != nil {
		releaseDB(db)
		return nil, err
	}
	replicas, err := resolverDialectors(cfg.Replicas, opts...)
	if err != nil {
		releaseDB(db)
		return nil, err
	}
	resolver := dbresolver.Register(dbresolver.Config{
		Sources:  sources,
		Replicas: replicas,
		// sources/replicas load balancing policy
		Policy: dbresolver.RandomPolicy{},
	})
	if cfg.MaxOpenConns > 0 {
		resolver.SetMaxOpenConns(cfg.MaxOpenConns)
	}
	if cfg.MaxIdleConns > 0 {
		resolver.SetMaxIdleConns(cfg.MaxIdleConns)
	}
	if cfg.MaxLifetime > 0 {
		resolver.SetConnMaxLifetime(cfg.MaxLifetime)
	}
	if cfg.MaxIdleTime > 0 {
		// NOTE(review): unlike MaxLifetime, MaxIdleTime is scaled by
		// time.Minute, so it is presumably configured as a bare number of
		// minutes — confirm against config.Sql.
		resolver.SetConnMaxIdleTime(time.Minute * cfg.MaxIdleTime)
	}
	if err = db.Use(resolver); err != nil {
		releaseDB(db)
		return nil, err
	}
	if cfg.Logger.LogLevel > 1 {
		// Convert on a copy. Scaling cfg.Logger.SlowThreshold in place
		// multiplied it by time.Millisecond again on every NewAdapter call
		// that reused the same cfg (assumes cfg.Logger is a struct value).
		logCfg := cfg.Logger
		if logCfg.SlowThreshold <= 0 {
			logCfg.SlowThreshold = 200
		}
		logCfg.SlowThreshold *= time.Millisecond
		log := &Logger{
			SqlLogger: logCfg,
		}
		db.Logger = log
		if commonconfig.RegisterConfig != nil {
			commonconfig.RegisterConfig("DB|"+cfg.Name, log.handlerConfig)
		}
	}
	a.db = db
	// Close the pool when the Adapter becomes unreachable.
	// NOTE(review): this also fires if a *gorm.DB obtained from Db() is still
	// in use while the Adapter itself is no longer referenced; callers should
	// keep the Adapter alive as long as they use the handle.
	runtime.SetFinalizer(a, finalizer)
	return a, nil
}

// applySecurity registers TLS and server-public-key material for cnf on a
// best-effort basis. It reports which of them the DSN may reference.
func applySecurity(cnf *config.Database, opts ...option.Option) (tlsOn, pubKeyOn bool) {
	tlsOn = UseTls(cnf, opts...) == nil
	pubKeyOn = UseServerPubKey(cnf, opts...) == nil
	return tlsOn, pubKeyOn
}

// resolverDialectors builds one gorm.Dialector per entry of list, for use as
// dbresolver sources or replicas. The range variable is a copy, so defaults
// filled in by the DSN builder (such as the mysql charset) do not modify the
// caller's configuration.
func resolverDialectors(list []config.Database, opts ...option.Option) ([]gorm.Dialector, error) {
	var out []gorm.Dialector
	for _, cnf := range list {
		tlsOn, pubKeyOn := applySecurity(&cnf, opts...)
		dr, err := getDriver(cnf.DriverType, getDns(&cnf, tlsOn, pubKeyOn, opts...))
		if err != nil {
			return nil, err
		}
		out = append(out, dr)
	}
	return out, nil
}

// releaseDB closes the connection pool behind db, ignoring errors. It is used
// to clean up after NewAdapter fails part-way.
func releaseDB(db *gorm.DB) {
	if sqlDB, err := db.DB(); err == nil {
		_ = sqlDB.Close()
	}
}
// Db exposes the *gorm.DB handle opened by NewAdapter; it is nil after Close.
func (a *Adapter) Db() *gorm.DB {
	handle := a.db
	return handle
}
// Name reports the master database name (cfg.Master.Name) recorded by
// NewAdapter.
func (a *Adapter) Name() string {
	name := a.databaseName
	return name
}
// Close releases the underlying connection pool and detaches the database
// handle from the Adapter, so Db returns nil afterwards. Calling it more
// than once is a no-op.
//
// Previously Close only dropped the reference. That left the pool open until
// garbage collection, and the finalizer then dereferenced the nil handle.
func (a *Adapter) Close() {
	if a.db == nil {
		return
	}
	if sqlDB, err := a.db.DB(); err == nil {
		_ = sqlDB.Close()
	}
	a.db = nil
	// The pool is already closed; the finalizer has nothing left to do.
	runtime.SetFinalizer(a, nil)
}
// AddDriver registers dial as the dialector constructor for driverName. It
// replaces any constructor already registered under that name, including
// the built-in "mysql", "postgres" and "sqlserver" entries.
//
// NOTE(review): openFuncs is a plain map with no locking; register drivers
// during initialization, not concurrently with NewAdapter.
func AddDriver(driverName string, dial DialFunc) {
	registry := openFuncs
	registry[driverName] = dial
}
// getDriver builds the gorm.Dialector for driverName from dataSourceName,
// using the constructor registered in openFuncs.
//
// It returns an error when no constructor is registered for driverName. A
// nil constructor (e.g. AddDriver(name, nil)) is also rejected rather than
// called, which would panic.
func getDriver(driverName, dataSourceName string) (gorm.Dialector, error) {
	dial := openFuncs[driverName]
	if dial == nil {
		return nil, errors.New("database dialect is not supported: " + driverName)
	}
	return dial(dataSourceName), nil
}
// openDB resolves the dialector for driverName and opens it with gorm using
// the default gorm.Config.
//
// gorm.Open may return a non-nil *gorm.DB together with an error (for
// example when the initial ping fails). In that case the pool it created is
// closed here and only the error is returned, so nothing leaks.
func openDB(driverName, dataSourceName string) (*gorm.DB, error) {
	dr, err := getDriver(driverName, dataSourceName)
	if err != nil {
		return nil, err
	}
	db, err := gorm.Open(dr, &gorm.Config{})
	if err != nil {
		if db != nil {
			if sqlDB, dbErr := db.DB(); dbErr == nil {
				_ = sqlDB.Close()
			}
		}
		return nil, err
	}
	return db, nil
}
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Go
1
https://gitee.com/h79/goutils.git
git@gitee.com:h79/goutils.git
h79
goutils
goutils
v1.20.70

搜索帮助