1 Star 0 Fork 0

chuang / gorm-shard

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
sharding.go 15.21 KB
一键复制 编辑 原始数据 按行查看 历史
chuang 提交于 2024-02-29 10:29 . 搬运工
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563
package sharding
import (
"errors"
"fmt"
"hash/crc32"
"strconv"
"strings"
"sync"
"github.com/bwmarrin/snowflake"
"github.com/longbridgeapp/sqlparser"
"golang.org/x/exp/slices"
"gorm.io/gorm"
)
// Sentinel errors returned while resolving statements against sharded tables.
var (
	// ErrMissingShardingKey is returned when a statement provides neither the
	// sharding key nor an id through an `=` comparison (or, for INSERT, when
	// the sharding key column is absent).
	ErrMissingShardingKey = errors.New("sharding key or id required, and use operator =")
	// ErrInvalidID is returned when the value compared with `id` is neither a
	// bind placeholder nor a number literal.
	ErrInvalidID = errors.New("invalid id format")
	// ErrInsertDiffSuffix is returned when the rows of one INSERT statement map
	// to different sharded tables.
	ErrInsertDiffSuffix = errors.New("can not insert different suffix table in one query")
)

// ShardingIgnoreStoreKey is the gorm setting key checked by switchConn: when a
// value is stored under it (db.Set), the statement's connection pool is not
// wrapped and the statement bypasses sharding.
var (
	ShardingIgnoreStoreKey = "sharding_ignore"
)
// Sharding is the gorm plugin that routes statements on registered tables to
// their sharded tables. Create it with Register and install it with db.Use.
type Sharding struct {
	// DB is the gorm handle the plugin was initialized with (set in Initialize).
	*gorm.DB
	// ConnPool is the connection-pool wrapper most recently installed by switchConn.
	ConnPool *ConnPool
	// configs maps a table name to its compiled sharding configuration (see compile).
	configs map[string]Config
	// querys holds bookkeeping values; LastQuery reads the "last_query" entry.
	querys sync.Map
	// snowflakeNodes holds one snowflake node per node id 0-1023, built in Initialize.
	snowflakeNodes []*snowflake.Node
	// _config is the configuration passed to Register; it is copied to every table in _tables.
	_config Config
	// _tables are the tables passed to Register: table names (string) or gorm models.
	_tables []any
	// mutex serializes the ConnPool swap performed in switchConn.
	mutex sync.RWMutex
}
// Config specifies the configuration for sharding.
//
// Register applies one Config to every table passed to it; compile then fills
// in the defaults described on the individual fields.
type Config struct {
	// FillId enables automatic filling of the `id` column on INSERT: when the
	// INSERT column list has no `id` column, one is appended with a value from
	// PrimaryKeyGeneratorFn (nothing is appended when the generator returns 0).
	FillId bool
	// When DoubleWrite is enabled, data is written to both the main table and the sharding table.
	DoubleWrite bool
	// ShardingKey specifies the table column you want to use for sharding the table rows.
	// For example, for a product order table, you may want to split the rows by `user_id`.
	ShardingKey string
	// NumberOfShards specifies how many tables the rows are sharded into.
	// It is required when ShardingAlgorithm is nil; with PKSnowflake it must not exceed 1024.
	NumberOfShards uint
	// tableFormat is the fmt format (e.g. "_%02d") used by the default
	// ShardingAlgorithm and ShardingAlgorithmByPrimaryKey to build suffixes;
	// compile derives it from NumberOfShards when ShardingAlgorithm is nil.
	tableFormat string
	// ShardingAlgorithm specifies a function to generate the sharding
	// table's suffix by the column value.
	// When nil, compile installs a modulo algorithm over NumberOfShards.
	// For example, this function implements a mod sharding algorithm.
	//
	//	func(value any) (suffix string, err error) {
	//		if uid, ok := value.(int64); ok {
	//			return fmt.Sprintf("_%02d", uid%64), nil
	//		}
	//		return "", errors.New("invalid user_id")
	//	}
	ShardingAlgorithm func(columnValue any) (suffix string, err error)
	// ShardingSuffixs specifies a function to generate all table's suffixes.
	// Used to support Migrator and generate PrimaryKey.
	// When nil, compile derives it by calling ShardingAlgorithm for 0..NumberOfShards-1.
	// For example, this function gets all mod sharding suffixes.
	//
	//	func() (suffixs []string) {
	//		numberOfShards := 5
	//		for i := 0; i < numberOfShards; i++ {
	//			suffixs = append(suffixs, fmt.Sprintf("_%02d", i%numberOfShards))
	//		}
	//		return
	//	}
	ShardingSuffixs func() (suffixs []string)
	// ShardingAlgorithmByPrimaryKey specifies a function to generate the sharding
	// table's suffix by the primary key. Used when no sharding key is specified.
	// When nil and PrimaryKeyGenerator is PKSnowflake, compile installs one that
	// formats the snowflake node id with tableFormat.
	// For example, this function uses the Snowflake library to generate the suffix.
	//
	//	func(id int64) (suffix string) {
	//		return fmt.Sprintf("_%02d", snowflake.ParseInt64(id).Node())
	//	}
	ShardingAlgorithmByPrimaryKey func(id int64) (suffix string)
	// PrimaryKeyGenerator specifies the primary key generate algorithm.
	// Used only on insert when the record does not contain an id field.
	// Options are PKSnowflake, PKPGSequence, PKMySQLSequence and PKCustom.
	// When using PKCustom, you must also specify PrimaryKeyGeneratorFn.
	PrimaryKeyGenerator int
	// PrimaryKeyGeneratorFn specifies a function to generate the primary key.
	// For PKSnowflake, PKPGSequence and PKMySQLSequence, compile overwrites it.
	// When using an auto-increment like generator, the tableIdx argument can be ignored.
	// For example, this function uses the Snowflake library to generate the primary key.
	// If you don't want to auto-fill the `id` or use a primary key that isn't called `id`, just return 0.
	//
	//	func(tableIdx int64) int64 {
	//		return nodes[tableIdx].Generate().Int64()
	//	}
	PrimaryKeyGeneratorFn func(tableIdx int64) int64
}
// Register creates a sharding plugin that applies config to every listed
// table. Each table is either a table name (string) or a gorm model; the
// configuration is compiled when the plugin is installed with db.Use.
func Register(config Config, tables ...any) *Sharding {
	plugin := new(Sharding)
	plugin._config = config
	plugin._tables = tables
	return plugin
}
// compile expands the tables passed to Register into per-table configs and
// fills in the defaults the rest of the plugin relies on.
//
// Each entry of s._tables is a table name (string) or a model gorm can parse;
// every table receives a copy of s._config. For each table compile then:
//   - installs PrimaryKeyGeneratorFn for PKSnowflake, PKPGSequence and
//     PKMySQLSequence (creating the sequence first for the latter two);
//   - installs a modulo ShardingAlgorithm over NumberOfShards when none is set;
//   - derives ShardingSuffixs from ShardingAlgorithm when none is set;
//   - installs a snowflake-node ShardingAlgorithmByPrimaryKey when none is set,
//     PKSnowflake is used and the default suffix format is known.
//
// It returns an error for an invalid configuration or a failed sequence setup.
func (s *Sharding) compile() error {
	if s.configs == nil {
		s.configs = make(map[string]Config)
	}
	for _, table := range s._tables {
		if name, ok := table.(string); ok {
			s.configs[name] = s._config
			continue
		}
		stmt := &gorm.Statement{DB: s.DB}
		if err := stmt.Parse(table); err != nil {
			return err
		}
		s.configs[stmt.Table] = s._config
	}

	for t, c := range s.configs {
		// Per-iteration copies: the closures below capture t and c, and before
		// Go 1.22 every iteration would otherwise share the same two variables
		// (all sequence generators would use the last table's name).
		t, c := t, c

		switch c.PrimaryKeyGenerator {
		case PKSnowflake:
			// Initialize only creates snowflake nodes 0-1023.
			if c.NumberOfShards > 1024 {
				return errors.New("snowflake NumberOfShards should be at most 1024")
			}
			c.PrimaryKeyGeneratorFn = s.genSnowflakeKey
		case PKPGSequence:
			// Execute SQL to CREATE SEQUENCE for this table if not exist.
			if err := s.createPostgreSQLSequenceKeyIfNotExist(t); err != nil {
				return err
			}
			c.PrimaryKeyGeneratorFn = func(index int64) int64 {
				return s.genPostgreSQLSequenceKey(t, index)
			}
		case PKMySQLSequence:
			if err := s.createMySQLSequenceKeyIfNotExist(t); err != nil {
				return err
			}
			c.PrimaryKeyGeneratorFn = func(index int64) int64 {
				return s.genMySQLSequenceKey(t, index)
			}
		case PKCustom:
			if c.PrimaryKeyGeneratorFn == nil {
				return errors.New("PrimaryKeyGeneratorFn is required when use PKCustom")
			}
		default:
			return errors.New("PrimaryKeyGenerator can only be one of PKSnowflake, PKPGSequence, PKMySQLSequence and PKCustom")
		}

		if c.ShardingAlgorithm == nil {
			if c.NumberOfShards == 0 {
				return errors.New("specify NumberOfShards or ShardingAlgorithm")
			}
			// Zero-pad to the digit count of NumberOfShards (1-9 -> "_%01d",
			// 10-99 -> "_%02d", ...). A fixed ladder that stops at 9999 would
			// leave the format empty for larger shard counts.
			c.tableFormat = "_%0" + strconv.Itoa(len(strconv.FormatUint(uint64(c.NumberOfShards), 10))) + "d"
			c.ShardingAlgorithm = func(value any) (suffix string, err error) {
				shards := int64(c.NumberOfShards)
				var id int64
				switch v := value.(type) {
				case int:
					id = int64(v)
				case int8:
					id = int64(v)
				case int16:
					id = int64(v)
				case int32:
					id = int64(v)
				case int64:
					id = v
				case uint8:
					id = int64(v)
				case uint16:
					id = int64(v)
				case uint32:
					id = int64(v)
				case uint:
					// Reduce in unsigned space so values above MaxInt64 stay valid.
					return fmt.Sprintf(c.tableFormat, uint64(v)%uint64(c.NumberOfShards)), nil
				case uint64:
					return fmt.Sprintf(c.tableFormat, v%uint64(c.NumberOfShards)), nil
				case string:
					// Numeric strings (e.g. number literals from SQL) shard like
					// integers; any other string is hashed.
					if parsed, convErr := strconv.Atoi(v); convErr == nil {
						id = int64(parsed)
					} else {
						id = int64(crc32.ChecksumIEEE([]byte(v)))
					}
				default:
					return "", fmt.Errorf("default algorithm only support integer and string column," +
						"if you use other type, specify you own ShardingAlgorithm")
				}
				// Go's % keeps the dividend's sign; fold negative keys back into
				// [0, NumberOfShards) so the suffix always names a real table.
				idx := id % shards
				if idx < 0 {
					idx += shards
				}
				return fmt.Sprintf(c.tableFormat, idx), nil
			}
		}

		if c.ShardingSuffixs == nil {
			c.ShardingSuffixs = func() (suffixs []string) {
				for i := 0; i < int(c.NumberOfShards); i++ {
					suffix, err := c.ShardingAlgorithm(i)
					if err != nil {
						return nil
					}
					suffixs = append(suffixs, suffix)
				}
				return
			}
		}

		// tableFormat is only known for the default algorithm. With a custom
		// ShardingAlgorithm an empty format would yield a bogus table suffix,
		// so leave ShardingAlgorithmByPrimaryKey unset and let getSuffix report it.
		if c.ShardingAlgorithmByPrimaryKey == nil && c.PrimaryKeyGenerator == PKSnowflake && c.tableFormat != "" {
			c.ShardingAlgorithmByPrimaryKey = func(id int64) (suffix string) {
				return fmt.Sprintf(c.tableFormat, snowflake.ParseInt64(id).Node())
			}
		}
		s.configs[t] = c
	}
	return nil
}
// Name returns the name this plugin is registered under; it implements the
// gorm.Plugin interface.
func (s *Sharding) Name() string {
	const pluginName = "gorm:sharding"
	return pluginName
}
// LastQuery returns the string stored under the "last_query" key of s.querys,
// or "" when nothing has been stored yet. NOTE(review): the value is written
// outside this file; presumably it is the last executed SQL — confirm.
func (s *Sharding) LastQuery() string {
	stored, found := s.querys.Load("last_query")
	if !found {
		return ""
	}
	return stored.(string)
}
// Initialize implements the gorm.Plugin interface. It wraps db's dialector,
// registers the connection-switching callbacks, prepares sequence tables and
// snowflake nodes, and finally compiles the per-table configuration.
func (s *Sharding) Initialize(db *gorm.DB) error {
	db.Dialector = NewShardingDialector(db.Dialector, s)
	s.DB = db
	if err := s.registerCallbacks(db); err != nil {
		return err
	}

	// NOTE(review): s.configs is populated by compile, which runs at the end of
	// this method, so for a freshly registered plugin this loop iterates over
	// nothing. compile also creates the sequences (create*SequenceKeyIfNotExist);
	// this loop may be redundant — confirm before removing.
	for t, c := range s.configs {
		if c.PrimaryKeyGenerator == PKPGSequence {
			err := s.DB.Exec("CREATE SEQUENCE IF NOT EXISTS " + pgSeqName(t)).Error
			if err != nil {
				return fmt.Errorf("init postgresql sequence error, %w", err)
			}
		}
		if c.PrimaryKeyGenerator == PKMySQLSequence {
			err := s.DB.Exec("CREATE TABLE IF NOT EXISTS " + mySQLSeqName(t) + " (id INT NOT NULL)").Error
			if err != nil {
				return fmt.Errorf("init mysql create sequence error, %w", err)
			}
			err = s.DB.Exec("INSERT INTO " + mySQLSeqName(t) + " VALUES (0)").Error
			if err != nil {
				return fmt.Errorf("init mysql insert sequence error, %w", err)
			}
		}
	}

	// One snowflake node per node id 0-1023.
	s.snowflakeNodes = make([]*snowflake.Node, 1024)
	for i := int64(0); i < 1024; i++ {
		n, err := snowflake.NewNode(i)
		if err != nil {
			return fmt.Errorf("init snowflake node error, %w", err)
		}
		s.snowflakeNodes[i] = n
	}
	return s.compile()
}

// registerCallbacks registers switchConn as "gorm:sharding" with Before("*")
// on db's create, query, update, delete, row and raw processors. It returns
// the first registration error reported by gorm instead of discarding it.
func (s *Sharding) registerCallbacks(db *gorm.DB) error {
	callbacks := db.Callback()
	registrations := []struct {
		processor string
		register  func(name string, fn func(*gorm.DB)) error
	}{
		{"create", callbacks.Create().Before("*").Register},
		{"query", callbacks.Query().Before("*").Register},
		{"update", callbacks.Update().Before("*").Register},
		{"delete", callbacks.Delete().Before("*").Register},
		{"row", callbacks.Row().Before("*").Register},
		{"raw", callbacks.Raw().Before("*").Register},
	}
	for _, r := range registrations {
		if err := r.register("gorm:sharding", s.switchConn); err != nil {
			return fmt.Errorf("register gorm:sharding %s callback: %w", r.processor, err)
		}
	}
	return nil
}
// switchConn wraps the statement's connection pool in a sharding ConnPool so
// the statement is routed through the plugin, and records the wrapper in
// s.ConnPool.
//
// Statements that have a value stored under ShardingIgnoreStoreKey are left
// untouched. For example, with DoubleWrite enabled, migrations need to query
// schema information by the unsharded table name.
func (s *Sharding) switchConn(db *gorm.DB) {
	if _, ignored := db.Get(ShardingIgnoreStoreKey); ignored {
		return
	}

	s.mutex.Lock()
	defer s.mutex.Unlock()

	if pool := db.Statement.ConnPool; pool != nil {
		s.ConnPool = &ConnPool{ConnPool: pool, sharding: s}
		db.Statement.ConnPool = s.ConnPool
	}
}
// resolve splits query into the statement for the main (unsharded) table and
// the statement rewritten to target the sharded table.
//
// It returns ftQuery (the statement against the main table), stQuery (the
// statement against tableName+suffix) and tableName (the table named by the
// statement, without suffix). The suffix comes from the config registered for
// tableName. It is applied to the sharding key value or, for non-INSERT
// statements without a key, to the id via ShardingAlgorithmByPrimaryKey.
//
// The query is returned unchanged with a nil error in these cases:
//   - no tables are configured;
//   - the statement cannot be parsed;
//   - a SELECT does not read from a single plain table, or carries the
//     "nosharding" hint;
//   - the table has no sharding config.
//
// Statement kinds other than SELECT/INSERT/UPDATE/DELETE yield
// sqlparser.ErrNotImplemented.
//
// For INSERT, every row must resolve to the same suffix (ErrInsertDiffSuffix).
// When FillId is enabled and the column list has no "id" column, an id from
// PrimaryKeyGeneratorFn is appended to each row (skipped when it returns 0).
func (s *Sharding) resolve(query string, args ...any) (ftQuery, stQuery, tableName string, err error) {
	ftQuery = query
	stQuery = query
	if len(s.configs) == 0 {
		return
	}

	expr, parseErr := sqlparser.NewParser(strings.NewReader(query)).ParseStatement()
	if parseErr != nil {
		// Statements the parser cannot handle are passed through untouched.
		return ftQuery, stQuery, tableName, nil
	}

	var (
		table      *sqlparser.TableName
		condition  sqlparser.Expr
		insertStmt *sqlparser.InsertStatement
	)
	switch stmt := expr.(type) {
	case *sqlparser.SelectStatement:
		tbl, ok := stmt.FromItems.(*sqlparser.TableName)
		if !ok {
			return
		}
		if stmt.Hint != nil && stmt.Hint.Value == "nosharding" {
			return
		}
		table = tbl
		condition = stmt.Condition
	case *sqlparser.InsertStatement:
		table = stmt.TableName
		insertStmt = stmt
	case *sqlparser.UpdateStatement:
		table = stmt.TableName
		condition = stmt.Condition
	case *sqlparser.DeleteStatement:
		table = stmt.TableName
		condition = stmt.Condition
	default:
		return ftQuery, stQuery, "", sqlparser.ErrNotImplemented
	}

	tableName = table.Name.Name
	r, ok := s.configs[tableName]
	if !ok {
		return
	}

	if insertStmt != nil {
		if len(insertStmt.Expressions) == 0 {
			// No VALUES rows (e.g. INSERT ... SELECT): there is no sharding key
			// to read and therefore no suffix to build the table name from.
			err = ErrMissingShardingKey
			return
		}
		insertNames := insertStmt.ColumnNames

		// Generate ids only when the statement does not provide one already;
		// generating anyway would consume sequence values for nothing.
		fillID := r.FillId
		for _, name := range insertNames {
			if name.Name == "id" {
				fillID = false
				break
			}
		}
		var namesWithID []*sqlparser.Ident
		if fillID {
			// The full slice expression forces append to copy, so the parsed
			// ColumnNames backing array is never written through.
			namesWithID = append(insertNames[:len(insertNames):len(insertNames)], &sqlparser.Ident{Name: "id"})
		}

		var suffix string
		for _, row := range insertStmt.Expressions {
			var (
				value     any
				id        int64
				keyFind   bool
				rowSuffix string
			)
			if value, id, keyFind, err = s.insertValue(r.ShardingKey, insertNames, row.Exprs, args...); err != nil {
				return
			}
			if rowSuffix, err = getSuffix(value, id, keyFind, r); err != nil {
				return
			}
			if suffix != "" && suffix != rowSuffix {
				err = ErrInsertDiffSuffix
				return
			}
			suffix = rowSuffix
			if !fillID {
				continue
			}

			// The generator is keyed by table index: numeric suffixes are used
			// directly, other suffixes by their position in ShardingSuffixs.
			suffixWord := strings.Replace(suffix, "_", "", 1)
			tblIdx, convErr := strconv.Atoi(suffixWord)
			if convErr != nil {
				allSuffixes := r.ShardingSuffixs()
				// The default ShardingSuffixs keeps the leading underscore, so
				// match the suffix verbatim before trying the stripped form.
				if tblIdx = slices.Index(allSuffixes, suffix); tblIdx == -1 {
					tblIdx = slices.Index(allSuffixes, suffixWord)
				}
				if tblIdx == -1 {
					return ftQuery, stQuery, tableName, errors.New("table suffix '" + suffixWord + "' is not in ShardingSuffixs. In order to generate the primary key, ShardingSuffixs should include all table suffixes")
				}
			}
			if newID := r.PrimaryKeyGeneratorFn(int64(tblIdx)); newID != 0 {
				insertStmt.ColumnNames = namesWithID
				row.Exprs = append(row.Exprs, &sqlparser.NumberLit{Value: strconv.FormatInt(newID, 10)})
			}
		}

		ftQuery = insertStmt.String()
		insertStmt.TableName = &sqlparser.TableName{Name: &sqlparser.Ident{Name: tableName + suffix}}
		stQuery = insertStmt.String()
		return
	}

	var (
		value   any
		id      int64
		keyFind bool
		suffix  string
	)
	if value, id, keyFind, err = s.nonInsertValue(r.ShardingKey, condition, args...); err != nil {
		return
	}
	if suffix, err = getSuffix(value, id, keyFind, r); err != nil {
		return
	}
	newTable := &sqlparser.TableName{Name: &sqlparser.Ident{Name: tableName + suffix}}
	switch stmt := expr.(type) {
	case *sqlparser.SelectStatement:
		ftQuery = stmt.String()
		stmt.FromItems = newTable
		stmt.OrderBy = replaceOrderByTableName(stmt.OrderBy, tableName, newTable.Name.Name)
		stQuery = stmt.String()
	case *sqlparser.UpdateStatement:
		ftQuery = stmt.String()
		stmt.TableName = newTable
		stQuery = stmt.String()
	case *sqlparser.DeleteStatement:
		ftQuery = stmt.String()
		stmt.TableName = newTable
		stQuery = stmt.String()
	}
	return
}
// getSuffix picks the sharded-table suffix for one statement or row.
//
// When the sharding key was found (keyFind), the suffix and error come from
// r.ShardingAlgorithm(value) unchanged. Otherwise it falls back to
// r.ShardingAlgorithmByPrimaryKey(id), and returns an error if that fallback
// is not configured.
func getSuffix(value any, id int64, keyFind bool, r Config) (suffix string, err error) {
	if !keyFind {
		byPK := r.ShardingAlgorithmByPrimaryKey
		if byPK == nil {
			return "", fmt.Errorf("there is not sharding key and ShardingAlgorithmByPrimaryKey is not configured")
		}
		return byPK(id), nil
	}
	return r.ShardingAlgorithm(value)
}
// insertValue finds the sharding key column in one INSERT row and returns the
// value supplied for it.
//
// names are the INSERT column names, exprs the values of one row and args the
// bind arguments of the whole statement. A bind placeholder is resolved
// against args. String and number literals are returned as their literal text
// (a string). id is always 0 and keyFind is true on success.
//
// Errors:
//   - names and exprs differ in length;
//   - a placeholder has no matching argument (instead of panicking on the index);
//   - sqlparser.ErrNotImplemented for any other value expression;
//   - ErrMissingShardingKey when key is not among names.
func (s *Sharding) insertValue(key string, names []*sqlparser.Ident, exprs []sqlparser.Expr, args ...any) (value any, id int64, keyFind bool, err error) {
	if len(names) != len(exprs) {
		return nil, 0, false, errors.New("column names and expressions mismatch")
	}
	for i, name := range names {
		if name.Name != key {
			continue
		}
		switch expr := exprs[i].(type) {
		case *sqlparser.BindExpr:
			pos := int(expr.Pos)
			if pos < 0 || pos >= len(args) {
				return nil, 0, false, fmt.Errorf("sharding key %q: bind parameter %d out of range (%d args)", key, pos, len(args))
			}
			value = args[pos]
		case *sqlparser.StringLit:
			value = expr.Value
		case *sqlparser.NumberLit:
			value = expr.Value
		default:
			return nil, 0, false, sqlparser.ErrNotImplemented
		}
		return value, 0, true, nil
	}
	return nil, 0, false, ErrMissingShardingKey
}
// nonInsertValue walks the WHERE condition of a SELECT/UPDATE/DELETE looking
// for `key = <value>` and `id = <value>` comparisons on unqualified columns.
//
// For the sharding key:
//   - keyFind is set to true;
//   - value is the bound argument for a placeholder;
//   - value is the literal text (a string) for string and number literals.
//
// For id:
//   - a bound argument of any integer type is converted to int64;
//   - a number literal is parsed as int64;
//   - anything else yields ErrInvalidID.
//
// When a column is compared more than once, the last comparison visited wins.
// A placeholder without a matching argument is reported as an error instead of
// panicking. It returns ErrMissingShardingKey when neither the key nor a
// non-zero id was found.
func (s *Sharding) nonInsertValue(key string, condition sqlparser.Expr, args ...any) (value any, id int64, keyFind bool, err error) {
	// bindArg resolves a placeholder position against args with a bounds check.
	bindArg := func(pos int) (any, error) {
		if pos < 0 || pos >= len(args) {
			return nil, fmt.Errorf("bind parameter %d out of range (%d args)", pos, len(args))
		}
		return args[pos], nil
	}

	err = sqlparser.Walk(sqlparser.VisitFunc(func(node sqlparser.Node) error {
		n, ok := node.(*sqlparser.BinaryExpr)
		if !ok || n.Op != sqlparser.EQ {
			return nil
		}
		x, ok := n.X.(*sqlparser.Ident)
		if !ok {
			return nil
		}
		switch x.Name {
		case key:
			keyFind = true
			switch expr := n.Y.(type) {
			case *sqlparser.BindExpr:
				v, bindErr := bindArg(int(expr.Pos))
				if bindErr != nil {
					return bindErr
				}
				value = v
			case *sqlparser.StringLit:
				value = expr.Value
			case *sqlparser.NumberLit:
				value = expr.Value
			default:
				return sqlparser.ErrNotImplemented
			}
		case "id":
			switch expr := n.Y.(type) {
			case *sqlparser.BindExpr:
				v, bindErr := bindArg(int(expr.Pos))
				if bindErr != nil {
					return bindErr
				}
				const maxInt64 = 1<<63 - 1
				// Accept every integer type callers commonly bind (int, uint, ...),
				// not only int64.
				switch v := v.(type) {
				case int64:
					id = v
				case int:
					id = int64(v)
				case int32:
					id = int64(v)
				case int16:
					id = int64(v)
				case int8:
					id = int64(v)
				case uint32:
					id = int64(v)
				case uint16:
					id = int64(v)
				case uint8:
					id = int64(v)
				case uint:
					if uint64(v) > maxInt64 {
						return fmt.Errorf("ID %d overflows int64", v)
					}
					id = int64(v)
				case uint64:
					if v > maxInt64 {
						return fmt.Errorf("ID %d overflows int64", v)
					}
					id = int64(v)
				default:
					return fmt.Errorf("ID should be an integer type, got %T", v)
				}
			case *sqlparser.NumberLit:
				parsed, parseErr := strconv.ParseInt(expr.Value, 10, 64)
				if parseErr != nil {
					return parseErr
				}
				id = parsed
			default:
				return ErrInvalidID
			}
		}
		return nil
	}), condition)
	if err != nil {
		return
	}
	if !keyFind && id == 0 {
		return nil, 0, keyFind, ErrMissingShardingKey
	}
	return
}
// replaceOrderByTableName renames the table qualifier of ORDER BY terms of the
// form oldName.column to newName. The terms are modified in place, and the
// same slice is returned for convenient reassignment.
func replaceOrderByTableName(orderBy []*sqlparser.OrderingTerm, oldName, newName string) []*sqlparser.OrderingTerm {
	for _, term := range orderBy {
		ref, isRef := term.X.(*sqlparser.QualifiedRef)
		if !isRef || ref.Table.Name != oldName {
			continue
		}
		ref.Table.Name = newName
	}
	return orderBy
}
Go
1
https://gitee.com/hellochuang/gorm-shard.git
git@gitee.com:hellochuang/gorm-shard.git
hellochuang
gorm-shard
gorm-shard
v1.0.0

搜索帮助

53164aa7 5694891 3bd8fe86 5694891