1 Star 0 Fork 0

Wsage/go-framework

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
db_version_mag.go 4.97 KB
一键复制 编辑 原始数据 按行查看 历史
王少奇 提交于 2021-08-22 21:37 . 新增 乐观锁和版本管理
package handler
import (
	"database/sql"
	"fmt"
	"io/ioutil"
	"path/filepath"
	"sort"
	"strings"

	"gitee.com/scottq/go-framework/src/utils"
)
// VERSION_MAG_TABLE is the name of the table that records applied SQL upgrades.
const VERSION_MAG_TABLE = "migrate_version"
// CREATE_VERSION_TABLE_SQL is the MySQL-dialect DDL (InnoDB, utf8mb4) that
// creates VERSION_MAG_TABLE when it does not exist yet. Each row records one
// applied upgrade: the label it belongs to, the version reached and the
// upgrade file used. The (label, version) pair is unique, so the same version
// cannot be recorded twice for one label.
var CREATE_VERSION_TABLE_SQL = fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s
(
id bigint unsigned NOT NULL AUTO_INCREMENT,
created_at timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP,
label varchar(20) COLLATE utf8mb4_unicode_ci NOT NULL DEFAULT '' COMMENT 'label',
version varchar(20) COLLATE utf8mb4_unicode_ci NOT NULL DEFAULT '' COMMENT '版本号',
upgrade_file varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL DEFAULT '' COMMENT '升级文件',
PRIMARY KEY (id),
UNIQUE KEY idx_version (label,version)
) ENGINE = InnoDB AUTO_INCREMENT = 1 DEFAULT CHARSET = utf8mb4 COLLATE = utf8mb4_unicode_ci COMMENT ='升级表'
`, VERSION_MAG_TABLE)
type (
	// logFunc is a log sink. Within this file it is called with a message and
	// a nil error for informational output, or with an empty message and a
	// non-nil error for failures.
	logFunc = func(info string, err error)

	// ILogger is implemented by components that accept a logFunc sink.
	ILogger interface {
		AddLogger(logger logFunc)
	}

	// IVersionMag manages versioned SQL upgrades scoped by a label.
	IVersionMag interface {
		ILogger

		// Upgrade brings the label up to (at most) the given version.
		Upgrade(version string) error
		// UpgradeOne applies a single upgrade file for the given version.
		UpgradeOne(version string, sqlFile string) error
		// Version reports the currently recorded version.
		Version() (string, error)
		// SetLabel selects which label subsequent calls operate on.
		SetLabel(label string)
		// Label returns the label currently in use.
		Label() string
	}
)
// VersionMag applies versioned SQL upgrade files from a directory and records
// each applied version, per label, in the VERSION_MAG_TABLE table.
type VersionMag struct {
	// db is where upgrade files are executed and where the version table lives.
	db *sql.DB
	// versionDir is the directory Upgrade scans for upgrade files.
	versionDir string
	// label scopes both the recorded versions and the upgrade lock name, so
	// several independent version histories can share one table.
	label string
	// optLock is taken by Upgrade (keyed by lockName); when Lock reports
	// failure, Upgrade does nothing.
	optLock IOptimisticLock
	// _logger receives info/error messages; nil means print to stdout.
	_logger logFunc
}
// NewDBVersionMag builds a VersionMag that reads upgrade files from sqlDir and
// records applied versions in db.
//
// It creates the optimistic lock that Upgrade uses, passing 60 to
// NewOptimisticLock (presumably a lock timeout in seconds; TODO confirm). It
// then ensures the version table exists. An error from either step is returned
// with a nil manager.
func NewDBVersionMag(sqlDir string, db *sql.DB) (IVersionMag, error) {
	optLock, lockErr := NewOptimisticLock(db, 60)
	if lockErr != nil {
		return nil, lockErr
	}

	mag := new(VersionMag)
	mag.db = db
	mag.versionDir = sqlDir
	mag.optLock = optLock

	if initErr := mag.init(); initErr != nil {
		return nil, initErr
	}
	return mag, nil
}
// AddLogger installs logger as the sink for this manager's info and error
// messages. It replaces any previously installed logger rather than adding a
// second one. While no logger is installed, messages are printed to stdout.
func (m *VersionMag) AddLogger(logger logFunc) {
	m._logger = logger
}
// SetLabel selects the label that subsequent version queries, version records
// and the upgrade lock name are scoped to.
func (m *VersionMag) SetLabel(label string) {
	m.label = label
}
// Label returns the label currently in use (empty until SetLabel is called).
func (m *VersionMag) Label() string {
	return m.label
}
// lockName returns the optimistic-lock key guarding upgrades for the current
// label, in the form "sql_upgrade.<label>".
func (m *VersionMag) lockName() string {
	const prefix = "sql_upgrade."
	return prefix + m.label
}
// logInfo sends an informational message to the installed logger, with a nil
// error. When no logger is installed, it prints the message to stdout instead.
func (m *VersionMag) logInfo(info string) {
	if logger := m._logger; logger != nil {
		logger(info, nil)
	} else {
		fmt.Println(info)
	}
}
// logError sends err to the installed logger, with an empty info string. When
// no logger is installed, it prints err's message to stdout instead.
// err must be non-nil: the stdout path calls err.Error().
func (m *VersionMag) logError(err error) {
	if logger := m._logger; logger != nil {
		logger("", err)
		return
	}
	fmt.Println(err.Error())
}
//升级版本
func (this VersionMag) Upgrade(version string) error {
this.logInfo("version upgrade to " + version)
lockName := this.lockName()
if !this.optLock.Lock(lockName) {
this.logInfo("upgrade job has locked: " + lockName)
return nil
}
defer this.optLock.UnLock(lockName)
sqlDir := this.versionDir
files, err := ioutil.ReadDir(sqlDir)
if err != nil {
this.logError(err)
panic(err.Error())
}
//
for _, x := range files {
if x.IsDir() {
continue
}
sqlFile := x.Name()
sqlFileFull := sqlDir + "/" + sqlFile
upVersion := strings.ReplaceAll(sqlFile, ".sql", "")
upVersion = strings.TrimLeft(upVersion, "v")
if utils.CompareVersion(upVersion, version) > 0 {
continue
}
err := this.UpgradeOne(upVersion, sqlFileFull)
if err != nil {
panic(err.Error())
return err
}
}
return nil
}
//升级一次版本
func (this *VersionMag) UpgradeOne(upVersion string, sqlFile string) error {
var err error
version, err := this.Version()
if err != nil {
return err
}
//无需处理
if utils.CompareVersion(upVersion, version) <= 0 {
this.logInfo("no need upgrade: v" + upVersion)
return nil
}
//执行升级文件
bytes, err := ioutil.ReadFile(sqlFile)
if err != nil {
this.logError(err)
return err
}
execSql := string(bytes)
err = this.upgradeSql(upVersion, execSql, sqlFile)
if err != nil {
return err
}
this.logInfo(fmt.Sprintf("upgrade success v%s => v%s", version, upVersion))
return nil
}
// upgradeSql executes sqlContent against the database in a single Exec call.
// It then records upVersion for the current label in the version table,
// storing comment in the upgrade_file column.
//
// An Exec failure is logged and returned, and no version record is written.
// NOTE(review): a file holding several statements is sent in one Exec;
// presumably the driver must allow multi-statements (e.g. go-sql-driver/mysql
// "multiStatements=true"). Confirm the DSN.
func (m *VersionMag) upgradeSql(upVersion string, sqlContent string, comment string) error {
	if _, err := m.db.Exec(sqlContent); err != nil {
		m.logError(err)
		return err
	}
	// addVersionRecord already logs its own failures; returning the error
	// unlogged here avoids reporting the same failure twice.
	return m.addVersionRecord(m.label, upVersion, comment)
}
// addVersionRecord inserts a row into the version table recording that label
// has reached version by applying sqlFile. Because (label, version) is unique
// in the table, recording the same pair twice fails. Errors are logged and
// returned.
func (m *VersionMag) addVersionRecord(label, version, sqlFile string) error {
	insertSql := fmt.Sprintf("INSERT INTO %s SET `label`=?,`version`=?,`upgrade_file`=?", VERSION_MAG_TABLE)
	// db.Exec with arguments prepares and releases the statement itself.
	// An explicit db.Prepare here would need a matching stmt.Close; without
	// it, every recorded upgrade leaks a server-side prepared statement.
	if _, err := m.db.Exec(insertSql, label, version, sqlFile); err != nil {
		m.logError(err)
		return err
	}
	return nil
}
// init creates the version table (CREATE_VERSION_TABLE_SQL) if it does not
// already exist. A failure is logged and returned.
func (m VersionMag) init() error {
	if _, err := m.db.Exec(CREATE_VERSION_TABLE_SQL); err != nil {
		m.logError(err)
		return err
	}
	return nil
}
// Version returns the most recently recorded version for the current label:
// the version column of that label's row with the highest id. If the label has
// no rows, or the stored version is empty, "0.0.0" is returned. Any query
// error other than sql.ErrNoRows is logged and returned with an empty version.
func (m VersionMag) Version() (string, error) {
	query := fmt.Sprintf("SELECT version FROM %s WHERE label=? ORDER BY id DESC LIMIT 1", VERSION_MAG_TABLE)

	var current string
	if err := m.db.QueryRow(query, m.label).Scan(&current); err != nil && err != sql.ErrNoRows {
		m.logError(err)
		return "", err
	}

	if current == "" {
		return "0.0.0", nil
	}
	return current, nil
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Go
1
https://gitee.com/scottq/go-framework.git
git@gitee.com:scottq/go-framework.git
scottq
go-framework
go-framework
v1.1.25

搜索帮助

0d507c66 1850385 C8b1a773 1850385