1 Star 0 Fork 0

ltotal/seata-go

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
basic_undo_log_builder.go 7.89 KB
一键复制 编辑 原始数据 按行查看 历史
ltotal 提交于 2024-05-30 10:15 . 初始化提交
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builder
import (
"bytes"
"database/sql"
"database/sql/driver"
"fmt"
"io"
"strings"
"github.com/arana-db/parser/ast"
"github.com/arana-db/parser/test_driver"
gxsort "github.com/dubbogo/gost/sort"
"gitee.com/ltotal/seata-go/pkg/datasource/sql/types"
)
// BasicUndoLogBuilder bundles the helper methods defined on it in this file:
// scan-slice construction, select-argument extraction, record-image building,
// primary-key WHERE clause generation and lock-key construction. It has no
// fields, so its zero value is ready to use.
// todo the executor should be stateful
type BasicUndoLogBuilder struct{}
// GetScanSlice returns one scan destination per entry of columnNames, in the
// same order, suitable for passing to driver.Rows.Next.
//
// Each element is a pointer to an interface{} pre-seeded with a zero value
// chosen from the column's DatabaseTypeString in tableMeta:
//   - text and binary types (and anything unrecognised) -> sql.RawBytes
//   - integer types -> int64, or sql.NullInt64 when IsNullable != 0
//   - temporal types -> sql.NullTime
//   - DECIMAL/DOUBLE/FLOAT -> float64, or sql.NullFloat64 when IsNullable != 0
//
// NOTE(review): drivers typically overwrite dest entries in Next, so these
// seeds act mainly as type hints.
//
// todo to use ColumnInfo get slice
func (*BasicUndoLogBuilder) GetScanSlice(columnNames []string, tableMeta *types.TableMeta) []driver.Value {
	scanSlice := make([]driver.Value, 0, len(columnNames))
	for _, columnName := range columnNames {
		var (
			scanVal interface{}
			// metadata of this column, looked up from the table meta
			columnMeta = tableMeta.Columns[columnName]
		)
		switch strings.ToUpper(columnMeta.DatabaseTypeString) {
		case "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT", "JSON", "TINYTEXT",
			"TINYBLOB", "BLOB", "MEDIUMBLOB", "LONGBLOB":
			// LONGBLOB is binary data; it used to be listed with the integer
			// types by mistake and was seeded with an int64 placeholder.
			scanVal = sql.RawBytes{}
		case "BIT", "INT", "SMALLINT", "TINYINT", "BIGINT", "MEDIUMINT":
			if columnMeta.IsNullable == 0 {
				scanVal = int64(0)
			} else {
				scanVal = sql.NullInt64{}
			}
		case "DATE", "DATETIME", "TIME", "TIMESTAMP", "YEAR":
			scanVal = sql.NullTime{}
		case "DECIMAL", "DOUBLE", "FLOAT":
			if columnMeta.IsNullable == 0 {
				scanVal = float64(0)
			} else {
				scanVal = sql.NullFloat64{}
			}
		default:
			scanVal = sql.RawBytes{}
		}
		scanSlice = append(scanSlice, &scanVal)
	}
	return scanSlice
}
// buildSelectArgs picks, out of the full statement argument list args, the
// arguments bound to placeholders that appear in stmt's WHERE, ORDER BY and
// LIMIT clauses. They are returned in ascending placeholder order, presumably
// so they match the order of the "?" markers in the rebuilt select SQL.
func (b *BasicUndoLogBuilder) buildSelectArgs(stmt *ast.SelectStmt, args []driver.Value) []driver.Value {
	selectArgsIndexs := make([]int32, 0)

	b.traversalArgs(stmt.Where, &selectArgsIndexs)
	if stmt.OrderBy != nil {
		for _, item := range stmt.OrderBy.Items {
			// *ast.ByItem only wraps the ordering expression; traverse the
			// expression itself so placeholders such as "ORDER BY ?" are found.
			if item != nil {
				b.traversalArgs(item.Expr, &selectArgsIndexs)
			}
		}
	}
	if stmt.Limit != nil {
		if stmt.Limit.Offset != nil {
			b.traversalArgs(stmt.Limit.Offset, &selectArgsIndexs)
		}
		if stmt.Limit.Count != nil {
			b.traversalArgs(stmt.Limit.Count, &selectArgsIndexs)
		}
	}

	// sort selectArgs index array
	gxsort.Int32(selectArgsIndexs)

	selectArgs := make([]driver.Value, 0, len(selectArgsIndexs))
	for _, index := range selectArgsIndexs {
		// An out-of-range placeholder order would otherwise panic. Skipping it
		// leaves the argument count short, which should surface as an
		// argument-count error when the select is executed.
		if index < 0 || int(index) >= len(args) {
			continue
		}
		selectArgs = append(selectArgs, args[index])
	}
	return selectArgs
}
// traversalArgs walks the expression tree rooted at node and appends the
// Order of every placeholder ("?") it finds to argsIndex. Indexes are appended
// in traversal order; callers sort them afterwards.
//
// Nodes descended into:
//   - binary and unary operations, parenthesised expressions
//   - BETWEEN (operand and both bounds), IN (operand and list)
//   - LIKE (operand and pattern)
//   - row constructors such as (a, b) IN ((?, ?)), function-call arguments
//   - ORDER BY items
//
// Other node types (e.g. subqueries) are not descended into.
//
// todo perfect all sql operation
func (b *BasicUndoLogBuilder) traversalArgs(node ast.Node, argsIndex *[]int32) {
	if node == nil {
		return
	}
	switch expr := node.(type) {
	case *ast.BinaryOperationExpr:
		b.traversalArgs(expr.L, argsIndex)
		b.traversalArgs(expr.R, argsIndex)
	case *ast.UnaryOperationExpr:
		b.traversalArgs(expr.V, argsIndex)
	case *ast.ParenthesesExpr:
		b.traversalArgs(expr.Expr, argsIndex)
	case *ast.BetweenExpr:
		b.traversalArgs(expr.Expr, argsIndex)
		b.traversalArgs(expr.Left, argsIndex)
		b.traversalArgs(expr.Right, argsIndex)
	case *ast.PatternInExpr:
		b.traversalArgs(expr.Expr, argsIndex)
		for _, item := range expr.List {
			b.traversalArgs(item, argsIndex)
		}
	case *ast.PatternLikeExpr:
		b.traversalArgs(expr.Expr, argsIndex)
		b.traversalArgs(expr.Pattern, argsIndex)
	case *ast.RowExpr:
		// e.g. the tuples of "(id, code) IN ((?, ?), (?, ?))"
		for _, value := range expr.Values {
			b.traversalArgs(value, argsIndex)
		}
	case *ast.FuncCallExpr:
		for _, arg := range expr.Args {
			b.traversalArgs(arg, argsIndex)
		}
	case *ast.ByItem:
		b.traversalArgs(expr.Expr, argsIndex)
	case *test_driver.ParamMarkerExpr:
		*argsIndex = append(*argsIndex, int32(expr.Order))
	}
}
// buildRecordImages drains rowsi and converts every row into a types.RowImage,
// returning them as a types.RecordImage for tableMetaData.TableName.
//
// Each column image carries:
//   - the column name
//   - its Java/JDBC type, from types.MySQLStrToJavaType applied to the
//     column's DatabaseTypeString
//   - whether the column is in the table's primary-key map
//   - the scanned value
//
// Any error from rowsi.Next other than io.EOF is returned unchanged. rowsi is
// not closed here.
func (b *BasicUndoLogBuilder) buildRecordImages(rowsi driver.Rows, tableMetaData *types.TableMeta) (*types.RecordImage, error) {
	// select column names
	columnNames := rowsi.Columns()

	// Key type, name and JDBC type depend only on the column, not on the row,
	// so resolve them once and copy this template for every row.
	pkMap := tableMetaData.GetPrimaryKeyMap()
	template := make([]types.ColumnImage, len(columnNames))
	for i, name := range columnNames {
		keyType := types.IndexTypeNull
		if _, ok := pkMap[name]; ok {
			keyType = types.IndexTypePrimaryKey
		}
		template[i] = types.ColumnImage{
			KeyType:    keyType,
			ColumnName: name,
			ColumnType: types.MySQLStrToJavaType(tableMetaData.Columns[name].DatabaseTypeString),
		}
	}

	rowImages := make([]types.RowImage, 0)
	ss := b.GetScanSlice(columnNames, tableMetaData)
	for {
		err := rowsi.Next(ss)
		if err == io.EOF {
			break
		}
		if err != nil {
			return nil, err
		}

		// build record image
		columns := make([]types.ColumnImage, len(template))
		copy(columns, template)
		for i := range columns {
			value := ss[i]
			// Drivers may hand back []byte values that point into an internal
			// read buffer reused by the next call to Next (go-sql-driver/mysql
			// does this). Copy them so earlier row images are not overwritten.
			// A nil slice (SQL NULL) is kept as nil.
			if raw, ok := value.([]byte); ok && raw != nil {
				owned := make([]byte, len(raw))
				copy(owned, raw)
				value = owned
			}
			columns[i].Value = value
		}
		rowImages = append(rowImages, types.RowImage{Columns: columns})
	}
	return &types.RecordImage{TableName: tableMetaData.TableName, Rows: rowImages}, nil
}
// buildWhereConditionByPKs builds a parameterised WHERE condition that matches
// rowSize rows by the primary-key columns in pkNameList.
//
// Rows are split into batches of at most maxInSize tuples. Each batch becomes
// one "(pk1,pk2) IN ((?,?),(?,?))" term, and terms are joined with " OR ":
//
//	(`id`,`userCode`) IN ((?,?),(?,?)) OR (`id`,`userCode`) IN ((?,?))
//
// Edge cases:
//   - Column names are quoted with backticks; embedded backticks are doubled.
//   - A non-positive maxInSize disables batching (all rows go in one IN list)
//     instead of dividing by zero.
//   - A non-positive rowSize yields "".
//   - dbType is currently unused.
func (b *BasicUndoLogBuilder) buildWhereConditionByPKs(pkNameList []string, rowSize int, dbType string, maxInSize int) string {
	if rowSize <= 0 {
		return ""
	}
	if maxInSize <= 0 {
		maxInSize = rowSize
	}

	// The quoted column list and the placeholder tuple are identical for every
	// batch, so build them once.
	quoted := make([]string, len(pkNameList))
	placeholders := make([]string, len(pkNameList))
	for i, name := range pkNameList {
		quoted[i] = "`" + strings.ReplaceAll(name, "`", "``") + "`"
		placeholders[i] = "?"
	}
	columns := "(" + strings.Join(quoted, ",") + ")"
	tuple := "(" + strings.Join(placeholders, ",") + ")"

	var whereStr strings.Builder
	for batch, remaining := 0, rowSize; remaining > 0; batch, remaining = batch+1, remaining-maxInSize {
		if batch > 0 {
			whereStr.WriteString(" OR ")
		}
		eachSize := maxInSize
		if remaining < eachSize {
			eachSize = remaining
		}
		whereStr.WriteString(columns)
		whereStr.WriteString(" IN (")
		for i := 0; i < eachSize; i++ {
			if i > 0 {
				whereStr.WriteString(",")
			}
			whereStr.WriteString(tuple)
		}
		whereStr.WriteString(")")
	}
	return whereStr.String()
}
// buildPKParams flattens the primary-key values of rows into one argument
// list. Rows are visited in order, and for each row the value of every column
// named in pkNameList is appended in pkNameList order. A key column that is
// absent from a row's column map is skipped rather than reported.
func (b *BasicUndoLogBuilder) buildPKParams(rows []types.RowImage, pkNameList []string) []driver.Value {
	params := make([]driver.Value, 0, len(rows)*len(pkNameList))
	for _, row := range rows {
		byName := row.GetColumnMap()
		for _, pkName := range pkNameList {
			if column := byName[pkName]; column != nil {
				params = append(params, column.Value)
			}
		}
	}
	return params
}
// buildLockKey drains rows and builds the global lock key for them.
//
// The key has the form "<table>:<row>,<row>,...", where each row part is that
// row's scanned primary-key values, formatted with %v and joined by "_".
// Example for a two-column primary key: "t_user:1_a,2_b".
//
// The scan destinations come from meta.GetPrimaryKeyOnlyName(), so rows is
// expected to yield exactly the primary-key columns in that order
// (NOTE(review): confirm with callers).
//
// Reading stops at io.EOF or at the first other error from rows.Next. Since
// there is no error return, a failed read produces the key for the rows read
// before the failure.
func (b *BasicUndoLogBuilder) buildLockKey(rows driver.Rows, meta types.TableMeta) string {
	var (
		lockKeys bytes.Buffer
		rowCount int
	)
	lockKeys.WriteString(meta.TableName)
	lockKeys.WriteString(":")
	pks := b.GetScanSlice(meta.GetPrimaryKeyOnlyName(), &meta)
	for {
		// io.EOF is the normal end of the result set, but any other error must
		// end the loop too. Treating only io.EOF as terminal could leave the
		// loop spinning and re-appending stale values if the driver keeps
		// failing.
		if err := rows.Next(pks); err != nil {
			break
		}
		if rowCount > 0 {
			lockKeys.WriteString(",")
		}
		for i, value := range pks {
			if i > 0 {
				lockKeys.WriteString("_")
			}
			lockKeys.WriteString(fmt.Sprintf("%v", value))
		}
		rowCount++
	}
	return lockKeys.String()
}
// buildLockKey2 builds the global lock key for the rows of an existing record
// image. The key has the form "<table>:<row>,<row>,...", for example
// "t_user:1_a,2_b" for a two-column primary key.
//
// Within a row, each column whose name equals one of meta's primary-key names
// contributes its value formatted with %v. Values are joined by "_" in the
// order the columns appear in the row.
//
// A "," is written before a row whenever some earlier row has contributed a
// key value.
// NOTE(review): a row without any primary-key column still receives that
// separator.
func (b *BasicUndoLogBuilder) buildLockKey2(records *types.RecordImage, meta types.TableMeta) string {
	pkNames := meta.GetPrimaryKeyOnlyName()

	var out strings.Builder
	out.WriteString(meta.TableName)
	out.WriteString(":")

	wroteKey := false
	for _, row := range records.Rows {
		if wroteKey {
			out.WriteString(",")
		}
		first := true
		for _, column := range row.Columns {
			for _, pkName := range pkNames {
				if column.ColumnName != pkName {
					continue
				}
				if !first {
					out.WriteString("_")
				}
				fmt.Fprintf(&out, "%v", column.Value)
				first = false
				wroteKey = true
			}
		}
	}
	return out.String()
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/ltotal/seata-go.git
git@gitee.com:ltotal/seata-go.git
ltotal
seata-go
seata-go
v1.2.1

搜索帮助