1 Star 0 Fork 0

chuang / gorm-shard

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
sharding.go 15.21 KB
一键复制 编辑 原始数据 按行查看 历史
chuang 提交于 2024-02-29 10:29 . 搬运工
package sharding
import (
var (
	// ErrMissingShardingKey is returned when a statement on a sharded table
	// provides neither the sharding key nor an `id` compared with `=`.
	ErrMissingShardingKey = errors.New("sharding key or id required, and use operator =")

	// ErrInvalidID is returned when an `id = ...` condition is neither a bind
	// parameter nor a numeric literal.
	ErrInvalidID = errors.New("invalid id format")

	// ErrInsertDiffSuffix is returned when the rows of one INSERT would be
	// routed to different sharding tables.
	ErrInsertDiffSuffix = errors.New("can not insert different suffix table in one query ")
)

var (
	// ShardingIgnoreStoreKey is the statement setting key checked with db.Get
	// by the plugin's callback; when it is present, sharding is skipped for
	// that statement.
	ShardingIgnoreStoreKey = "sharding_ignore"
)
type Sharding struct {
ConnPool *ConnPool
configs map[string]Config
querys sync.Map
snowflakeNodes []*snowflake.Node
_config Config
_tables []any
mutex sync.RWMutex
// Config specifies the configuration for sharding.
type Config struct {
	// FillId enables automatic filling of the `id` column on INSERT. When the
	// statement does not list an `id` column, the id returned by
	// PrimaryKeyGeneratorFn for the target shard is appended to each row. A
	// generator result of 0 leaves the row unchanged.
	FillId bool

	// DoubleWrite, when enabled, makes data be written to both the main table
	// and the sharding table.
	DoubleWrite bool

	// ShardingKey specifies the table column used to shard the table rows.
	// For example, for a product order table, you may want to split the rows
	// by `user_id`.
	ShardingKey string

	// NumberOfShards specifies how many sharding tables there are. It also
	// determines the width of the default table suffix and, when
	// ShardingAlgorithm is nil, the modulus of the default algorithm. With
	// PKSnowflake it must not exceed 1024.
	NumberOfShards uint

	// tableFormat is the fmt format of the sharding table suffix (for example
	// "_%02d"). It is derived from NumberOfShards when the config is compiled.
	tableFormat string

	// ShardingAlgorithm specifies a function to generate the sharding table's
	// suffix from the sharding column value. For example, this function
	// implements a mod sharding algorithm:
	//
	//	func(value any) (suffix string, err error) {
	//		if uid, ok := value.(int64); ok {
	//			return fmt.Sprintf("_%02d", uid%64), nil
	//		}
	//		return "", errors.New("invalid user_id")
	//	}
	ShardingAlgorithm func(columnValue any) (suffix string, err error)

	// ShardingSuffixs specifies a function that returns the suffixes of all
	// sharding tables. It is used to support the Migrator and to generate
	// primary keys, and should return suffixes in the same format as
	// ShardingAlgorithm. For example, this function returns all suffixes for
	// a mod algorithm:
	//
	//	func() (suffixs []string) {
	//		numberOfShards := 5
	//		for i := 0; i < numberOfShards; i++ {
	//			suffixs = append(suffixs, fmt.Sprintf("_%02d", i%numberOfShards))
	//		}
	//		return
	//	}
	ShardingSuffixs func() (suffixs []string)

	// ShardingAlgorithmByPrimaryKey specifies a function to generate the
	// sharding table's suffix from the primary key. It is used when the
	// statement carries no sharding key. For example, this function uses the
	// Snowflake library to derive the suffix:
	//
	//	func(id int64) (suffix string) {
	//		return fmt.Sprintf("_%02d", snowflake.ParseInt64(id).Node())
	//	}
	ShardingAlgorithmByPrimaryKey func(id int64) (suffix string)

	// PrimaryKeyGenerator specifies the primary key generation algorithm. It
	// is used only on insert when the record does not contain an id field.
	// Options are PKSnowflake, PKPGSequence, PKMySQLSequence and PKCustom.
	// When using PKCustom, you must also specify PrimaryKeyGeneratorFn.
	PrimaryKeyGenerator int

	// PrimaryKeyGeneratorFn specifies a function to generate the primary key.
	// With an auto-increment-like generator, the tableIdx argument can be
	// ignored. If you don't want to auto-fill `id`, or the primary key is not
	// called `id`, just return 0. For example, this function uses the
	// Snowflake library to generate the primary key:
	//
	//	func(tableIdx int64) int64 {
	//		return nodes[tableIdx].Generate().Int64()
	//	}
	PrimaryKeyGeneratorFn func(tableIdx int64) int64
}
func Register(config Config, tables ...any) *Sharding {
return &Sharding{
_config: config,
_tables: tables,
func (s *Sharding) compile() error {
if s.configs == nil {
s.configs = make(map[string]Config)
for _, table := range s._tables {
if t, ok := table.(string); ok {
s.configs[t] = s._config
} else {
stmt := &gorm.Statement{DB: s.DB}
if err := stmt.Parse(table); err == nil {
s.configs[stmt.Table] = s._config
} else {
return err
for t, c := range s.configs {
if c.NumberOfShards > 1024 && c.PrimaryKeyGenerator == PKSnowflake {
panic("Snowflake NumberOfShards should less than 1024")
if c.PrimaryKeyGenerator == PKSnowflake {
c.PrimaryKeyGeneratorFn = s.genSnowflakeKey
} else if c.PrimaryKeyGenerator == PKPGSequence {
// Execute SQL to CREATE SEQUENCE for this table if not exist
err := s.createPostgreSQLSequenceKeyIfNotExist(t)
if err != nil {
return err
c.PrimaryKeyGeneratorFn = func(index int64) int64 {
return s.genPostgreSQLSequenceKey(t, index)
} else if c.PrimaryKeyGenerator == PKMySQLSequence {
err := s.createMySQLSequenceKeyIfNotExist(t)
if err != nil {
return err
c.PrimaryKeyGeneratorFn = func(index int64) int64 {
return s.genMySQLSequenceKey(t, index)
} else if c.PrimaryKeyGenerator == PKCustom {
if c.PrimaryKeyGeneratorFn == nil {
return errors.New("PrimaryKeyGeneratorFn is required when use PKCustom")
} else {
return errors.New("PrimaryKeyGenerator can only be one of PKSnowflake, PKPGSequence, PKMySQLSequence and PKCustom")
if c.ShardingAlgorithm == nil {
if c.NumberOfShards == 0 {
return errors.New("specify NumberOfShards or ShardingAlgorithm")
if c.NumberOfShards < 10 {
c.tableFormat = "_%01d"
} else if c.NumberOfShards < 100 {
c.tableFormat = "_%02d"
} else if c.NumberOfShards < 1000 {
c.tableFormat = "_%03d"
} else if c.NumberOfShards < 10000 {
c.tableFormat = "_%04d"
c.ShardingAlgorithm = func(value any) (suffix string, err error) {
id := 0
switch value := value.(type) {
case int:
id = value
case int64:
id = int(value)
case string:
id, err = strconv.Atoi(value)
if err != nil {
id = int(crc32.ChecksumIEEE([]byte(value)))
return "", fmt.Errorf("default algorithm only support integer and string column," +
"if you use other type, specify you own ShardingAlgorithm")
return fmt.Sprintf(c.tableFormat, id%int(c.NumberOfShards)), nil
if c.ShardingSuffixs == nil {
c.ShardingSuffixs = func() (suffixs []string) {
for i := 0; i < int(c.NumberOfShards); i++ {
suffix, err := c.ShardingAlgorithm(i)
if err != nil {
return nil
suffixs = append(suffixs, suffix)
if c.ShardingAlgorithmByPrimaryKey == nil {
if c.PrimaryKeyGenerator == PKSnowflake {
c.ShardingAlgorithmByPrimaryKey = func(id int64) (suffix string) {
return fmt.Sprintf(c.tableFormat, snowflake.ParseInt64(id).Node())
s.configs[t] = c
return nil
// Name plugin name for Gorm plugin interface
func (s *Sharding) Name() string {
return "gorm:sharding"
// LastQuery get last SQL query
func (s *Sharding) LastQuery() string {
if query, ok := s.querys.Load("last_query"); ok {
return query.(string)
return ""
// Initialize implement for Gorm plugin interface
func (s *Sharding) Initialize(db *gorm.DB) error {
db.Dialector = NewShardingDialector(db.Dialector, s)
s.DB = db
for t, c := range s.configs {
if c.PrimaryKeyGenerator == PKPGSequence {
err := s.DB.Exec("CREATE SEQUENCE IF NOT EXISTS " + pgSeqName(t)).Error
if err != nil {
return fmt.Errorf("init postgresql sequence error, %w", err)
if c.PrimaryKeyGenerator == PKMySQLSequence {
err := s.DB.Exec("CREATE TABLE IF NOT EXISTS " + mySQLSeqName(t) + " (id INT NOT NULL)").Error
if err != nil {
return fmt.Errorf("init mysql create sequence error, %w", err)
err = s.DB.Exec("INSERT INTO " + mySQLSeqName(t) + " VALUES (0)").Error
if err != nil {
return fmt.Errorf("init mysql insert sequence error, %w", err)
s.snowflakeNodes = make([]*snowflake.Node, 1024)
for i := int64(0); i < 1024; i++ {
n, err := snowflake.NewNode(i)
if err != nil {
return fmt.Errorf("init snowflake node error, %w", err)
s.snowflakeNodes[i] = n
return s.compile()
func (s *Sharding) registerCallbacks(db *gorm.DB) {
s.Callback().Create().Before("*").Register("gorm:sharding", s.switchConn)
s.Callback().Query().Before("*").Register("gorm:sharding", s.switchConn)
s.Callback().Update().Before("*").Register("gorm:sharding", s.switchConn)
s.Callback().Delete().Before("*").Register("gorm:sharding", s.switchConn)
s.Callback().Row().Before("*").Register("gorm:sharding", s.switchConn)
s.Callback().Raw().Before("*").Register("gorm:sharding", s.switchConn)
func (s *Sharding) switchConn(db *gorm.DB) {
// Support ignore sharding in some case, like:
// When DoubleWrite is enabled, we need to query database schema
// information by table name during the migration.
if _, ok := db.Get(ShardingIgnoreStoreKey); !ok {
if db.Statement.ConnPool != nil {
s.ConnPool = &ConnPool{ConnPool: db.Statement.ConnPool, sharding: s}
db.Statement.ConnPool = s.ConnPool
// resolve split the old query to full table query and sharding table query
func (s *Sharding) resolve(query string, args ...any) (ftQuery, stQuery, tableName string, err error) {
ftQuery = query
stQuery = query
if len(s.configs) == 0 {
expr, err := sqlparser.NewParser(strings.NewReader(query)).ParseStatement()
if err != nil {
return ftQuery, stQuery, tableName, nil
var table *sqlparser.TableName
var condition sqlparser.Expr
var isInsert bool
var insertNames []*sqlparser.Ident
var insertExpressions []*sqlparser.Exprs
var insertStmt *sqlparser.InsertStatement
switch stmt := expr.(type) {
case *sqlparser.SelectStatement:
tbl, ok := stmt.FromItems.(*sqlparser.TableName)
if !ok {
if stmt.Hint != nil && stmt.Hint.Value == "nosharding" {
table = tbl
condition = stmt.Condition
case *sqlparser.InsertStatement:
table = stmt.TableName
isInsert = true
insertNames = stmt.ColumnNames
insertExpressions = stmt.Expressions
insertStmt = stmt
case *sqlparser.UpdateStatement:
condition = stmt.Condition
table = stmt.TableName
case *sqlparser.DeleteStatement:
condition = stmt.Condition
table = stmt.TableName
return ftQuery, stQuery, "", sqlparser.ErrNotImplemented
tableName = table.Name.Name
r, ok := s.configs[tableName]
if !ok {
var suffix string
if isInsert {
var newTable *sqlparser.TableName
for _, insertExpression := range insertExpressions {
var value any
var id int64
var keyFind bool
columnNames := insertNames
insertValues := insertExpression.Exprs
value, id, keyFind, err = s.insertValue(r.ShardingKey, insertNames, insertValues, args...)
if err != nil {
var subSuffix string
subSuffix, err = getSuffix(value, id, keyFind, r)
if err != nil {
if suffix != "" && suffix != subSuffix {
err = ErrInsertDiffSuffix
suffix = subSuffix
newTable = &sqlparser.TableName{Name: &sqlparser.Ident{Name: tableName + suffix}}
fillID := s._config.FillId
if isInsert && fillID {
for _, name := range insertNames {
if name.Name == "id" {
fillID = false
suffixWord := strings.Replace(suffix, "_", "", 1)
tblIdx, err := strconv.Atoi(suffixWord)
if err != nil {
tblIdx = slices.Index(r.ShardingSuffixs(), suffixWord)
if tblIdx == -1 {
return ftQuery, stQuery, tableName, errors.New("table suffix '" + suffixWord + "' is not in ShardingSuffixs. In order to generate the primary key, ShardingSuffixs should include all table suffixes")
//return ftQuery, stQuery, tableName, err
id := r.PrimaryKeyGeneratorFn(int64(tblIdx))
if id == 0 {
fillID = false
if fillID {
columnNames = append(insertNames, &sqlparser.Ident{Name: "id"})
insertValues = append(insertValues, &sqlparser.NumberLit{Value: strconv.FormatInt(id, 10)})
if fillID {
insertStmt.ColumnNames = columnNames
insertExpression.Exprs = insertValues
ftQuery = insertStmt.String()
insertStmt.TableName = newTable
stQuery = insertStmt.String()
} else {
var value any
var id int64
var keyFind bool
value, id, keyFind, err = s.nonInsertValue(r.ShardingKey, condition, args...)
if err != nil {
suffix, err = getSuffix(value, id, keyFind, r)
if err != nil {
newTable := &sqlparser.TableName{Name: &sqlparser.Ident{Name: tableName + suffix}}
switch stmt := expr.(type) {
case *sqlparser.SelectStatement:
ftQuery = stmt.String()
stmt.FromItems = newTable
stmt.OrderBy = replaceOrderByTableName(stmt.OrderBy, tableName, newTable.Name.Name)
stQuery = stmt.String()
case *sqlparser.UpdateStatement:
ftQuery = stmt.String()
stmt.TableName = newTable
stQuery = stmt.String()
case *sqlparser.DeleteStatement:
ftQuery = stmt.String()
stmt.TableName = newTable
stQuery = stmt.String()
func getSuffix(value any, id int64, keyFind bool, r Config) (suffix string, err error) {
if keyFind {
suffix, err = r.ShardingAlgorithm(value)
if err != nil {
} else {
if r.ShardingAlgorithmByPrimaryKey == nil {
err = fmt.Errorf("there is not sharding key and ShardingAlgorithmByPrimaryKey is not configured")
suffix = r.ShardingAlgorithmByPrimaryKey(id)
func (s *Sharding) insertValue(key string, names []*sqlparser.Ident, exprs []sqlparser.Expr, args ...any) (value any, id int64, keyFind bool, err error) {
if len(names) != len(exprs) {
return nil, 0, keyFind, errors.New("column names and expressions mismatch")
for i, name := range names {
if name.Name == key {
switch expr := exprs[i].(type) {
case *sqlparser.BindExpr:
value = args[expr.Pos]
case *sqlparser.StringLit:
value = expr.Value
case *sqlparser.NumberLit:
value = expr.Value
return nil, 0, keyFind, sqlparser.ErrNotImplemented
keyFind = true
if !keyFind {
return nil, 0, keyFind, ErrMissingShardingKey
func (s *Sharding) nonInsertValue(key string, condition sqlparser.Expr, args ...any) (value any, id int64, keyFind bool, err error) {
err = sqlparser.Walk(sqlparser.VisitFunc(func(node sqlparser.Node) error {
if n, ok := node.(*sqlparser.BinaryExpr); ok {
if x, ok := n.X.(*sqlparser.Ident); ok {
if x.Name == key && n.Op == sqlparser.EQ {
keyFind = true
switch expr := n.Y.(type) {
case *sqlparser.BindExpr:
value = args[expr.Pos]
case *sqlparser.StringLit:
value = expr.Value
case *sqlparser.NumberLit:
value = expr.Value
return sqlparser.ErrNotImplemented
return nil
} else if x.Name == "id" && n.Op == sqlparser.EQ {
switch expr := n.Y.(type) {
case *sqlparser.BindExpr:
v := args[expr.Pos]
var ok bool
if id, ok = v.(int64); !ok {
return fmt.Errorf("ID should be int64 type")
case *sqlparser.NumberLit:
id, err = strconv.ParseInt(expr.Value, 10, 64)
if err != nil {
return err
return ErrInvalidID
return nil
return nil
}), condition)
if err != nil {
if !keyFind && id == 0 {
return nil, 0, keyFind, ErrMissingShardingKey
func replaceOrderByTableName(orderBy []*sqlparser.OrderingTerm, oldName, newName string) []*sqlparser.OrderingTerm {
for i, term := range orderBy {
if x, ok := term.X.(*sqlparser.QualifiedRef); ok {
if x.Table.Name == oldName {
x.Table.Name = newName
orderBy[i].X = x
return orderBy


53164aa7 5694891 3bd8fe86 5694891