Ai
1 Star 0 Fork 0

h79/goutils

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
adapter.go 7.81 KB
一键复制 编辑 原始数据 按行查看 历史
huqiuyun 提交于 2025-11-18 22:40 +08:00 . es 支持tls
package db
import (
"crypto/rsa"
"errors"
"fmt"
"net/url"
"time"
commonconfig "gitee.com/h79/goutils/common/config"
"gitee.com/h79/goutils/common/data"
"gitee.com/h79/goutils/common/logger"
commonoption "gitee.com/h79/goutils/common/option"
"gitee.com/h79/goutils/common/result"
commontls "gitee.com/h79/goutils/common/tls"
"gitee.com/h79/goutils/dao/config"
"gitee.com/h79/goutils/dao/log"
"gitee.com/h79/goutils/dao/option"
drivermysql "github.com/go-sql-driver/mysql"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlserver"
"gorm.io/plugin/dbresolver"
"runtime"
"gorm.io/gorm"
)
// DialFunc builds a gorm.Dialector for one driver from a data source name (DSN).
type DialFunc func(dsn string) gorm.Dialector

// openFuncs maps a driver name (the DriverType of a config.Database) to the
// gorm dialect constructor used to open it. Further drivers can be registered
// with AddDriver.
var openFuncs = map[string]DialFunc{
	"postgres":  postgres.Open,
	"mysql":     mysql.Open,
	"sqlserver": sqlserver.Open,
}
// Compile-time assertion that *Adapter satisfies the package's Sql interface.
var _ Sql = (*Adapter)(nil)
// Adapter wraps the gorm handle opened by NewAdapter together with metadata
// taken from the configuration it was built from.
// NOTE(review): an earlier comment described this as an adapter "for policy
// storage"; nothing in this file stores policies — confirm before relying on it.
type Adapter struct {
// driver key of the master config, used to pick the gorm dialect
driverName string
// master database name; returned by Name
databaseName string
// DSN the master connection was opened with
dsn string
// config version string; returned by Version
version string
// gorm handle; set by NewAdapter, cleared by Close
db *gorm.DB
// NOTE(review): not read anywhere in this file — confirm its use elsewhere in the package
statsEnabled bool
}
// ScopesFunc is a reusable query scope: it receives a gorm query chain and
// returns it with extra clauses applied. Its shape matches what gorm's
// (*gorm.DB).Scopes accepts.
type ScopesFunc func(tx *gorm.DB) *gorm.DB
// finalizer is the destructor for Adapter, installed by NewAdapter through
// runtime.SetFinalizer. It closes the connection pool behind a.db.
//
// It runs on the runtime's finalizer goroutine, where a panic would take down
// the whole process, so failures are logged rather than panicking. An adapter
// whose handle was already released (a.db == nil, e.g. after Close) is skipped
// instead of dereferencing a nil *gorm.DB.
func finalizer(a *Adapter) {
	if a == nil || a.db == nil {
		return
	}
	sqlDB, err := a.db.DB()
	if err != nil {
		logger.D("DB", "finalizer get db err: "+err.Error())
		return
	}
	if err = sqlDB.Close(); err != nil {
		logger.D("DB", "finalizer close db err: "+err.Error())
	}
}
// DefaultDnsFunc builds the data source name (DSN) for cnf according to
// cnf.DriverType:
//
//   - "mysql": user:pwd@tcp(host:port)/name?charset=<c>&parseTime=True&loc=Local.
//     The charset defaults to utf8mb4. "&tls=<cnf.Tls.Key>" is appended when
//     tls is true, and "&serverPubKey=<cnf.ServerPubKey.Key>" when serverPubKey
//     is true.
//   - "postgres": libpq keyword/value form. A value that is empty or contains
//     whitespace, a quote or a backslash is single-quoted so it cannot swallow
//     the following keyword.
//   - "sqlite3": the database name as-is.
//   - "sql" or "sqlserver": sqlserver://user:pwd@host:port?database=name with
//     URL-escaped credentials and query. "sqlserver" is accepted because that
//     is the key under which the dialect is registered in openFuncs.
//
// A nil cnf or any other driver type yields "". cnf is not modified. The tls
// and serverPubKey flags only affect MySQL DSNs.
var DefaultDnsFunc = func(cnf *config.Database, tls, serverPubKey bool) string {
	if cnf == nil {
		return ""
	}
	switch cnf.DriverType {
	case "mysql":
		charset := cnf.Charset
		if charset == "" {
			charset = "utf8mb4"
		}
		dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=True&loc=Local",
			cnf.User, cnf.Pwd, cnf.Host, cnf.Port, cnf.Name, charset)
		if tls {
			dsn += "&tls=" + url.QueryEscape(cnf.Tls.Key)
		}
		if serverPubKey {
			dsn += "&serverPubKey=" + url.QueryEscape(cnf.ServerPubKey.Key)
		}
		return dsn
	case "postgres":
		return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s",
			quotePgDSNValue(cnf.Host), cnf.Port, quotePgDSNValue(cnf.User),
			quotePgDSNValue(cnf.Pwd), quotePgDSNValue(cnf.Name))
	case "sqlite3":
		return cnf.Name
	case "sql", "sqlserver":
		// url.URL escapes reserved characters in the credentials and query;
		// plain alphanumeric values render exactly as a hand-built string would.
		u := url.URL{
			Scheme:   "sqlserver",
			User:     url.UserPassword(cnf.User, cnf.Pwd),
			Host:     fmt.Sprintf("%s:%d", cnf.Host, cnf.Port),
			RawQuery: url.Values{"database": {cnf.Name}}.Encode(),
		}
		return u.String()
	}
	return ""
}

// quotePgDSNValue formats v as a value in a libpq keyword/value connection
// string. A non-empty value free of whitespace, single quotes and backslashes
// is returned unchanged. Anything else is wrapped in single quotes, with ' and
// \ backslash-escaped.
func quotePgDSNValue(v string) string {
	plain := v != ""
	for i := 0; plain && i < len(v); i++ {
		switch v[i] {
		case ' ', '\t', '\n', '\r', '\v', '\f', '\'', '\\':
			plain = false
		}
	}
	if plain {
		return v
	}
	quoted := make([]byte, 0, len(v)+4)
	quoted = append(quoted, '\'')
	for i := 0; i < len(v); i++ {
		if c := v[i]; c == '\'' || c == '\\' {
			quoted = append(quoted, '\\')
		}
		quoted = append(quoted, v[i])
	}
	return string(append(quoted, '\''))
}
// dnsFunc is the option payload carrying a caller-supplied DSN builder. It is
// registered under option.SqlDnsOpt.
type dnsFunc func(cnf *config.Database, tls bool, pubKey bool) string

// WithDnsOption returns an option that makes getDns use f instead of
// DefaultDnsFunc. f receives the database config plus the flags NewAdapter
// computes for MySQL TLS and server-public-key registration.
func WithDnsOption(f func(cnf *config.Database, tls, pubKey bool) string) commonoption.Option {
	var opt dnsFunc = f
	return opt
}

// String names the option for diagnostics.
func (fn dnsFunc) String() string { return "sql:dns" }

// Type reports the option id under which dnsFuncExist looks the option up.
func (fn dnsFunc) Type() int { return option.SqlDnsOpt }

// Value returns the builder itself.
func (fn dnsFunc) Value() interface{} { return fn }
// dnsFuncExist returns the DSN builder supplied via WithDnsOption, or nil when
// none was given.
//
// It uses a checked type assertion. An unrelated option that happens to use
// the same type id (option.SqlDnsOpt) therefore makes getDns fall back to
// DefaultDnsFunc instead of panicking.
func dnsFuncExist(opts ...commonoption.Option) dnsFunc {
	r, ok := commonoption.Exist(option.SqlDnsOpt, opts...)
	if !ok || r == nil {
		return nil
	}
	fn, _ := r.Value().(dnsFunc)
	return fn
}
// getDns builds the DSN for cnf. It uses the builder passed through
// WithDnsOption when there is one, and DefaultDnsFunc otherwise.
func getDns(cnf *config.Database, tls, serverPubKey bool, opts ...commonoption.Option) string {
	if custom := dnsFuncExist(opts...); custom != nil {
		return custom(cnf, tls, serverPubKey)
	}
	return DefaultDnsFunc(cnf, tls, serverPubKey)
}
// mySqlTls builds a TLS configuration for cnf and registers it with the MySQL
// driver under cnf.Tls.Key, so a DSN can reference it as "tls=<key>".
//
// It returns result.RErrNotSupport for non-MySQL drivers, and otherwise any
// error from the builder (the option-supplied one, or DefaultTlsFunc) or from
// registration. NewAdapter treats a nil result as "append tls=<key> to the DSN".
//
// NOTE(review): when the builder returns a nil config with a nil error,
// nothing is registered but nil is still returned, so the DSN references
// cnf.Tls.Key anyway. Confirm this is intended (e.g. for the driver's built-in
// values such as "skip-verify") and not a way to reference an unregistered name.
func mySqlTls(cnf *config.Database, opts ...commonoption.Option) error {
	if cnf.DriverType != "mysql" {
		return result.RErrNotSupport
	}
	build := tlsFuncExist(opts...)
	if build == nil {
		build = DefaultTlsFunc
	}
	tlsCfg, err := build(&cnf.Tls)
	switch {
	case err != nil:
		return err
	case tlsCfg == nil:
		return nil
	default:
		return drivermysql.RegisterTLSConfig(cnf.Tls.Key, tlsCfg)
	}
}
// mySqlServerPubKey loads the MySQL server's RSA public key and registers it
// with the MySQL driver under cnf.ServerPubKey.Key, so a DSN can reference it
// as "serverPubKey=<key>".
//
// It returns result.RErrNotSupport for non-MySQL drivers and the loader's
// error when loading fails. It also returns an error when the loader yields no
// key. NewAdapter treats a nil result as "append serverPubKey to the DSN", and
// go-sql-driver/mysql rejects a DSN whose serverPubKey name maps to a nil key.
// Registering nil would also overwrite a key previously registered under the
// same name.
func mySqlServerPubKey(cnf *config.Database, opts ...commonoption.Option) error {
	if cnf.DriverType != "mysql" {
		return result.RErrNotSupport
	}
	fn := serverPubKeyFuncExist(opts...)
	if fn == nil {
		fn = func(cnf *commontls.ServerPubKey) (*rsa.PublicKey, error) {
			return commontls.GetServerPubKey(cnf)
		}
	}
	pk, err := fn(&cnf.ServerPubKey)
	if err != nil {
		return err
	}
	if pk == nil {
		return errors.New("mysql server public key not available")
	}
	drivermysql.RegisterServerPubKey(cnf.ServerPubKey.Key, pk)
	return nil
}
// databaseSecret overrides cfg.User and cfg.Pwd with the credentials that
// option.UseSecret resolves for cfg.Name. The master argument is passed
// through to UseSecret; NewAdapter passes 0 for the master, 1 for sources and
// 2 for replicas. Nothing changes unless the secret reports itself valid, and
// an empty secret field keeps the configured value.
func databaseSecret(cfg *config.Database, master int, opts ...commonoption.Option) {
	secret := option.UseSecret(cfg.Name, master, opts...)
	if secret.HasValid() {
		if user := secret.User; user != "" {
			cfg.User = user
		}
		if pwd := secret.Pwd; pwd != "" {
			cfg.Pwd = pwd
		}
	}
}
// NewAdapter is the constructor for Adapter.
//
// Setup steps:
//   - Prepare the master database: credentials via option.UseSecret, optional
//     MySQL TLS and server-public-key registration, and the DSN via getDns.
//   - Open the master with gorm.
//   - Register every entry of cfg.Sources and cfg.Replicas through a dbresolver
//     plugin with a random load-balancing policy.
//   - Apply positive MaxOpenConns / MaxIdleConns to the resolver, plus
//     MaxLifetime / MaxIdleTime, which are counted in minutes.
//   - When cfg.Logger.LogLevel > 1, install a log.DbLogger whose SlowThreshold
//     is read as milliseconds (default 200).
//
// cfg.Master receives the resolved credentials in place. Sources and Replicas
// entries are processed on copies. If any step fails after the master
// connection has been opened, that pool is closed before the error is
// returned. On success a finalizer closes the pool once the Adapter is
// garbage collected.
func NewAdapter(cfg *config.Sql, opts ...commonoption.Option) (*Adapter, error) {
	a := &Adapter{version: cfg.Version}
	dsn := prepareAdapterDSN(&cfg.Master, 0, opts...)
	a.driverName = cfg.Master.DriverType
	a.databaseName = cfg.Master.Name
	a.dsn = dsn

	// Open the DB
	db, err := openDB(a.driverName, a.dsn)
	if err != nil {
		return nil, err
	}
	if err = configureAdapterDB(db, cfg, opts...); err != nil {
		closeAdapterDB(db)
		return nil, err
	}
	a.db = db
	// Call the destructor when the object is released.
	runtime.SetFinalizer(a, finalizer)
	return a, nil
}

// prepareAdapterDSN resolves credentials for cnf, registers MySQL TLS and
// server-public-key settings where they apply, and returns the DSN.
// The two registration helpers return an error for non-MySQL drivers or when
// their loader fails. That error deliberately means "not enabled", so the
// matching DSN parameter is simply omitted.
func prepareAdapterDSN(cnf *config.Database, master int, opts ...commonoption.Option) string {
	databaseSecret(cnf, master, opts...)
	tlsOn := mySqlTls(cnf, opts...) == nil
	pubKeyOn := mySqlServerPubKey(cnf, opts...) == nil
	return getDns(cnf, tlsOn, pubKeyOn, opts...)
}

// newAdapterDialector prepares a copy of one source/replica config and returns
// the gorm dialector for it.
func newAdapterDialector(cnf config.Database, master int, opts ...commonoption.Option) (gorm.Dialector, error) {
	dsn := prepareAdapterDSN(&cnf, master, opts...)
	return getDriver(cnf.DriverType, dsn)
}

// configureAdapterDB registers the dbresolver plugin (sources, replicas and
// pool limits from cfg) on db and, if enabled, installs the SQL logger.
func configureAdapterDB(db *gorm.DB, cfg *config.Sql, opts ...commonoption.Option) error {
	var sources, replicas []gorm.Dialector
	for i, source := range cfg.Sources {
		dr, err := newAdapterDialector(source, 1, opts...)
		if err != nil {
			return fmt.Errorf("db source %v: %w", i, err)
		}
		sources = append(sources, dr)
	}
	for i, replica := range cfg.Replicas {
		dr, err := newAdapterDialector(replica, 2, opts...)
		if err != nil {
			return fmt.Errorf("db replica %v: %w", i, err)
		}
		replicas = append(replicas, dr)
	}
	resolver := dbresolver.Register(dbresolver.Config{
		Sources:  sources,
		Replicas: replicas,
		// sources/replicas load balancing policy
		Policy: dbresolver.RandomPolicy{},
	})
	if cfg.MaxOpenConns > 0 {
		resolver.SetMaxOpenConns(cfg.MaxOpenConns)
	}
	if cfg.MaxIdleConns > 0 {
		resolver.SetMaxIdleConns(cfg.MaxIdleConns)
	}
	if cfg.MaxLifetime > 0 {
		resolver.SetConnMaxLifetime(time.Minute * cfg.MaxLifetime)
	}
	if cfg.MaxIdleTime > 0 {
		resolver.SetConnMaxIdleTime(time.Minute * cfg.MaxIdleTime)
	}
	if err := db.Use(resolver); err != nil {
		return err
	}
	if cfg.Logger.LogLevel > 1 {
		// Scale a local copy: converting cfg.Logger in place would multiply
		// the threshold again every time the same cfg is reused.
		// NOTE(review): if cfg.Logger is a pointer type, this copy still
		// aliases the caller's config.
		loggerCfg := cfg.Logger
		if loggerCfg.SlowThreshold <= 0 {
			loggerCfg.SlowThreshold = 200
		}
		loggerCfg.SlowThreshold *= time.Millisecond
		lg := &log.DbLogger{SqlLogger: loggerCfg}
		log.UseDbLogger(lg, opts...)
		db.Logger = lg
		if commonconfig.RegisterConfig != nil {
			commonconfig.RegisterConfig("DB|"+cfg.Name, lg.HandlerConfig)
		}
	}
	return nil
}

// closeAdapterDB closes the pool behind db on a best-effort basis, so
// NewAdapter does not leak connections on its error paths.
func closeAdapterDB(db *gorm.DB) {
	sqlDB, err := db.DB()
	if err != nil {
		return
	}
	// The caller is already returning the more relevant setup error.
	_ = sqlDB.Close()
}
// Db returns the underlying gorm handle, or nil once Close has been called.
func (a *Adapter) Db() *gorm.DB { return a.db }

// Name returns the master database name taken from the configuration.
func (a *Adapter) Name() string { return a.databaseName }
// CollectData snapshots the connection-pool statistics (database/sql DBStats)
// of the *sql.DB exposed by gorm's DB() into a *Stats model.
//
// It returns an error when the adapter has no handle (a nil adapter, or after
// Close) or when gorm cannot expose the underlying *sql.DB. The latter is
// wrapped with %w so callers can inspect the cause with errors.Is/As.
func (a *Adapter) CollectData() (data.Model, error) {
	if a == nil || a.db == nil {
		return nil, errors.New("obj is nil")
	}
	sqlDB, err := a.db.DB()
	if err != nil {
		return nil, fmt.Errorf("get db err: %w", err)
	}
	logger.D(" Stats", "DB start...")
	stats := sqlDB.Stats()
	return &Stats{
		MaxOpenConnections: stats.MaxOpenConnections,
		OpenConnections:    stats.OpenConnections,
		InUse:              stats.InUse,
		Idle:               stats.Idle,
		WaitCount:          stats.WaitCount,
		WaitDuration:       stats.WaitDuration,
		MaxIdleClosed:      stats.MaxIdleClosed,
		MaxIdleTimeClosed:  stats.MaxIdleTimeClosed,
		MaxLifetimeClosed:  stats.MaxLifetimeClosed,
	}, nil
}
// Close detaches the gorm handle and closes the underlying connection pool, so
// connections are released now rather than whenever the finalizer runs.
// Calling it more than once, or on an adapter without a handle, does nothing.
func (a *Adapter) Close() {
	if a == nil || a.db == nil {
		return
	}
	db := a.db
	a.db = nil
	sqlDB, err := db.DB()
	if err != nil {
		logger.D("DB", "close get db err: "+err.Error())
		return
	}
	// Close has no error result, so a failure can only be logged.
	if err = sqlDB.Close(); err != nil {
		logger.D("DB", "close db err: "+err.Error())
	}
}
// Version returns the configuration version (config.Sql.Version) the adapter
// was built from.
func (a *Adapter) Version() string { return a.version }
// AddDriver registers or replaces the gorm dialect constructor used for
// driverName. One use is enabling "sqlite3", whose DSN DefaultDnsFunc already
// builds but which openFuncs does not contain by default.
//
// A nil dial unregisters the name, so getDriver reports it as unsupported
// rather than later calling a nil function. The registry is a plain map
// without locking, so register drivers before adapters are created
// concurrently.
func AddDriver(driverName string, dial DialFunc) {
	if dial == nil {
		delete(openFuncs, driverName)
		return
	}
	openFuncs[driverName] = dial
}
// getDriver returns the gorm dialector for driverName, built from
// dataSourceName with the constructor registered in openFuncs.
// It returns an error naming the driver when no non-nil constructor is
// registered for it.
func getDriver(driverName, dataSourceName string) (gorm.Dialector, error) {
	dial, ok := openFuncs[driverName]
	if !ok || dial == nil {
		return nil, errors.New("database dialect is not supported: " + driverName)
	}
	return dial(dataSourceName), nil
}
// openDB resolves the dialect registered for driverName and opens a gorm
// connection with it, using a default gorm.Config.
func openDB(driverName, dataSourceName string) (*gorm.DB, error) {
	dialector, err := getDriver(driverName, dataSourceName)
	if err == nil {
		return gorm.Open(dialector, &gorm.Config{})
	}
	return nil, err
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Go
1
https://gitee.com/h79/goutils.git
git@gitee.com:h79/goutils.git
h79
goutils
goutils
v1.32.69

搜索帮助