代码拉取完成,页面将自动刷新
package gweb
import (
"database/sql"
"fmt"
"reflect"
"slices"
"strings"
"gitee.com/makitdone/gweb/v2/conf"
"gitee.com/makitdone/gx/maps"
"gitee.com/makitdone/gx/refl"
"github.com/gookit/goutil/strutil"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/schema"
)
// Connect opens a MySQL connection described by config and wraps it in a *DB.
//
// Table names are not pluralized (schema.NamingStrategy.SingularTable) and,
// when config.Debug is set, the handle is switched to gorm's Debug mode.
// Connect panics if gorm.Open reports an error.
func Connect(config conf.DbConfig) *DB {
	// NOTE(review): MySQL's "utf8" is the 3-byte utf8mb3 charset and cannot
	// store 4-byte characters such as emoji; consider charset=utf8mb4.
	dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8&parseTime=True&loc=Local",
		config.UserName, config.Password, config.Host, config.Port, config.DbName)
	db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
		NamingStrategy: schema.NamingStrategy{
			SingularTable: true, // do not pluralize table names
		},
	})
	// Check the error before configuring the handle any further.
	if err != nil {
		panic("连接数据库失败," + err.Error())
	}
	db = db.Session(&gorm.Session{NewDB: false})
	if config.Debug {
		db = db.Debug()
	}
	return newDB(db)
}
// Join kinds recorded by (*DB).Joins.
const (
	JoinTypeDB  = iota + 1 // association join by name, e.g. Joins("Project"); its selects are merged
	JoinTypeStr            // raw join clause, detected by "JOIN " in the query string
)
// joinArg records one join collected by (*DB).Joins so that joins can be
// merged and replayed onto a fresh statement.
type joinArg struct {
	Name    string   // join name as parsed by gorm
	Type    int      // JoinTypeDB or JoinTypeStr
	Query   string   // query string passed to Joins
	Selects []string // columns to select; empty means all columns
	Args    []any    // original arguments (kept for raw JoinTypeStr joins)
}
// Pagination is one page of results plus paging metadata, as returned by
// Paginate and PaginateScan.
type Pagination[T any] struct {
	PageIndex int   `json:"pi"`    // requested page number (Paginate treats it as 1-based)
	PageSize  int   `json:"ps"`    // requested page size
	List      T     `json:"list"`  // records of this page
	Total     int64 `json:"total"` // number of records matched across all pages
}
// CloneDB returns a new gorm session derived from db. NewDB is left false, so
// the session keeps the conditions built so far, and PrepareStmt enables
// prepared-statement caching.
func CloneDB(db *gorm.DB) *gorm.DB {
	opts := gorm.Session{PrepareStmt: true}
	return db.Session(&opts)
}
// func (db *DB) Model(value any) *DB {
// return &DB{
// DB: db.DB.Model(value),
// joins: []joinArg{},
// PanicIfError: false,
// }
// }
// DB wraps *gorm.DB with chainable helpers. Builder methods return a cloned
// *DB instead of mutating the receiver, and terminal methods can optionally
// panic instead of returning errors.
type DB struct {
	*gorm.DB // the query chain built so far

	tx          *gorm.DB  // root handle that Model, Joins and UpdateOrCreate build fresh statements from
	joins       []joinArg // joins collected by Joins, replayed with merged selects
	groupFields []string  // GROUP BY columns collected by Group
	// PanicIfError makes terminal methods panic on error instead of returning it.
	PanicIfError bool
}
// newDB wraps base in a fresh *DB with no recorded joins and with errors
// returned rather than panicking. base also becomes the root handle (tx).
func newDB(base *gorm.DB) *DB {
	wrapped := &DB{DB: base, tx: base}
	wrapped.joins = []joinArg{}
	wrapped.PanicIfError = false
	return wrapped
}
// Clone returns a copy of db whose gorm handle is a new session (see CloneDB).
// The root handle, the recorded joins (shared slice) and PanicIfError carry
// over.
//
// NOTE(review): groupFields is not copied, so every Group call starts from an
// empty list — confirm this is intended.
func (db *DB) Clone() *DB {
	copied := &DB{
		DB:           CloneDB(db.DB),
		tx:           db.tx,
		joins:        db.joins,
		PanicIfError: db.PanicIfError,
	}
	return copied
}
// Model starts a new chain for value from the root handle. The new chain
// discards the current conditions and recorded joins, and resets
// PanicIfError to false.
func (db *DB) Model(value any) *DB {
	root := db.tx
	return &DB{
		PanicIfError: false,
		joins:        []joinArg{},
		tx:           root,
		DB:           root.Model(value),
	}
}
// Panic returns a copy of db whose terminal methods panic on error instead of
// returning it. Unless you are writing low-level code, checking every error
// return mostly adds code without adding value.
func (db *DB) Panic() *DB {
	strict := db.Clone()
	strict.PanicIfError = true
	return strict
}
// Table returns a copy of db that targets the table called name.
func (db *DB) Table(name string) *DB {
	next := db.Clone()
	next.DB = next.DB.Table(name)
	return next
}
// Where adds WHERE conditions. On top of gorm's Where it accepts a
// map[string]any whose entries are translated one by one, e.g.:
//
//	db.Where(map[string]any{
//		"age": 20, // `age`=?
//	})
//
//	db.Where(map[string]any{
//		"age > ?": 20, // key used verbatim as the condition
//	})
//
// A key containing a space or a dot is used verbatim as an expression, and any
// other key is quoted as a column name. So "age > ?" must keep its spaces:
// "age>?" would be quoted as a column.
//
// An empty-string value makes the key itself a raw condition. A nil value
// (including a nil pointer) becomes `col` IS NULL, so a map prepared for
// Updates can usually be reused directly as the lookup condition.
// Map keys are applied in sorted order so the generated SQL is deterministic.
// Any other query type is forwarded unchanged to gorm's Where.
func (db *DB) Where(query any, args ...any) *DB {
	db = db.Clone()
	m, ok := query.(map[string]interface{})
	if !ok {
		db.DB = db.DB.Where(query, args...)
		return db
	}
	// Go randomizes map iteration; sort so identical maps always produce
	// identical SQL (stable logs, reusable prepared statements).
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	slices.Sort(keys)
	for _, k := range keys {
		v := m[k]
		switch {
		case v == "":
			db.DB = db.DB.Where(k)
		case isNilWhereValue(v) && !strings.Contains(k, " "):
			db.DB = db.DB.Where(quoteWhereColumn(k) + " IS NULL")
		case strings.ContainsAny(k, " ."):
			// Expression key such as "age > ?" (also reached with a nil value,
			// which is then bound as the placeholder argument).
			db.DB = db.DB.Where(k, v)
		default:
			db.DB = db.DB.Where(fmt.Sprintf("`%s`=?", k), v)
		}
	}
	return db
}

// isNilWhereValue reports whether v is nil or a nil pointer/chan/func value,
// i.e. a value Where renders as IS NULL.
func isNilWhereValue(v any) bool {
	if v == nil {
		return true
	}
	rv := reflect.ValueOf(v)
	switch rv.Kind() {
	case reflect.Pointer, reflect.Chan, reflect.Func, reflect.UnsafePointer:
		return rv.IsNil()
	}
	return false
}

// quoteWhereColumn backquotes each dot-separated segment of a column reference,
// so "user.id" becomes `user`.`id` rather than one identifier named "user.id".
func quoteWhereColumn(name string) string {
	parts := strings.Split(name, ".")
	for i, p := range parts {
		parts[i] = "`" + p + "`"
	}
	return strings.Join(parts, ".")
}
// Or returns a copy of db with an OR condition added (gorm's Or).
func (db *DB) Or(query interface{}, args ...interface{}) *DB {
	next := db.Clone()
	next.DB = next.DB.Or(query, args...)
	return next
}
// Joins is an enhanced gorm Joins. Repeated joins on the same association
// have their selected columns merged, so a later call does not overwrite an
// earlier one. For example:
//
//	db.Joins("Project", DB.Select("id", "name"))
//
//	db.Joins("Project", DB.Select("fullname"))
//
// is equivalent to db.Joins("Project", DB.Select("id", "name", "fullname")).
//
// This matters in complex queries where one code path needs the join without
// extra columns and another needs specific columns; neither may clobber the
// other.
//
// When query contains "JOIN " (case-insensitive), it is treated as a raw join
// clause (JoinTypeStr). A later raw join with the same gorm-parsed name
// replaces the earlier one's query and args. Arguments whose struct type
// prints as "gweb.DB" are replaced by their "DB" field before being handed to
// gorm. The merged join list is re-applied to the cloned statement on every
// call.
func (db *DB) Joins(query string, args ...interface{}) *DB {
	db = db.Clone()
	joins := []joinArg{}
	_args := []any{}
	// Unwrap gweb.DB arguments to their embedded DB field.
	// NOTE(review): assumes refl.StructType resolves pointers to the struct
	// type; the error from refl.GetFieldValue is ignored — confirm both.
	for _, arg := range args {
		t := refl.StructType(arg)
		if t.Kind() == reflect.Struct {
			structTypeStr := t.String()
			if structTypeStr == "gweb.DB" {
				argDB, _ := refl.GetFieldValue(arg, "DB")
				_args = append(_args, argDB)
				continue
			}
		}
		_args = append(_args, arg)
	}
	// Let gorm parse the join on a throwaway statement built from the root
	// handle to learn its name and selected columns.
	// NOTE(review): index 0 assumes db.tx carries no joins of its own.
	join := db.tx.Model(db.DB.Statement.Model).Joins(query, _args...).Statement.Joins[0]
	found := false
	joinType := JoinTypeDB
	if strings.Contains(strings.ToUpper(query), "JOIN ") {
		joinType = JoinTypeStr
	}
	// Rebuild the join list: an entry with the same name is merged
	// (association join) or replaced (raw join); other entries are kept.
	for _, item := range db.joins {
		j := item
		if j.Name == join.Name {
			found = true
			if joinType == JoinTypeDB {
				// Merge the selected columns.
				// NOTE(review): append may write into j.Selects' backing array,
				// which other *DB values can share, and slices.Compact only
				// removes adjacent duplicates. Also, an earlier join with no
				// Selects (all columns) becomes restricted to the new columns.
				selects := append(j.Selects, join.Selects...)
				selects = slices.Compact(selects)
				joins = append(joins, joinArg{Name: j.Name, Query: query, Type: JoinTypeDB, Selects: selects})
			} else {
				joins = append(joins, joinArg{Name: join.Name, Type: JoinTypeStr, Query: query, Args: _args})
			}
		} else {
			if j.Type == JoinTypeDB {
				joins = append(joins, joinArg{Name: j.Name, Type: JoinTypeDB, Query: j.Query, Selects: j.Selects})
			} else {
				joins = append(joins, joinArg{Name: j.Name, Type: JoinTypeStr, Query: j.Query, Args: j.Args})
			}
		}
	}
	// First join with this name: append it.
	if !found {
		if joinType == JoinTypeDB {
			joins = append(joins, joinArg{Name: join.Name, Query: query, Type: JoinTypeDB, Selects: join.Selects})
		} else {
			joins = append(joins, joinArg{Name: join.Name, Query: query, Type: JoinTypeStr, Args: _args})
		}
	}
	// Drop the joins already on the cloned statement and replay the merged list.
	db.DB.Statement.Joins = nil
	db.joins = joins
	for _, item := range joins {
		j := item
		if j.Type == JoinTypeDB {
			db.DB = db.DB.Joins(j.Query, db.tx.Select(j.Selects))
		} else {
			db.DB = db.DB.Joins(j.Query, j.Args...)
		}
	}
	return db
}
// Select adds columns to the SELECT list. Unlike gorm's Select, which replaces
// the list, columns accumulate across calls, so they can be contributed from
// different places. This is useful in complex logic such as a statistics API
// whose fields are chosen at runtime. Use NewSelect to start over.
//
// NOTE(review): only the plain column list (Statement.Selects) is carried
// over; a Select form that gorm stores as a SELECT clause expression (e.g. a
// string with "?" placeholders plus args) is lost — confirm callers avoid it.
func (db *DB) Select(query interface{}, args ...interface{}) *DB {
	db = db.Clone()
	selects := db.DB.Select(query, args...).Statement.Selects
	// Clip so append always copies. The existing slice may share its backing
	// array with the *DB this one was cloned from, and appending in place could
	// let sibling chains overwrite each other's columns.
	db.DB.Statement.Selects = append(slices.Clip(db.DB.Statement.Selects), selects...)
	return db
}
// NewSelect replaces the SELECT list instead of extending it: columns from
// earlier Select calls are not inherited.
func (db *DB) NewSelect(query interface{}, args ...interface{}) *DB {
	next := db.Clone()
	next.DB = next.DB.Select(query, args...)
	return next
}
// Omit returns a copy of db that ignores the given columns when creating,
// updating and querying (gorm's Omit).
func (db *DB) Omit(columns ...string) *DB {
	next := db.Clone()
	next.DB = next.DB.Omit(columns...)
	return next
}
// Create inserts each value with its own gorm Create call, in order, stopping
// at the first error. Values inserted before the failure are not rolled back
// by this method; use Transaction for all-or-nothing behaviour. Calling it
// with no values is a no-op.
func (db *DB) Create(value ...any) error {
	db = db.Clone()
	for i := range value {
		if err := db.DB.Create(value[i]).Error; err != nil {
			if db.PanicIfError {
				panic(err)
			}
			return err
		}
	}
	return nil
}
// CreateInBatches inserts value (typically a slice) in batches of batchSize
// via gorm's CreateInBatches.
func (db *DB) CreateInBatches(value any, batchSize int) error {
	db = db.Clone()
	err := db.DB.CreateInBatches(value, batchSize).Error
	if err == nil {
		return nil
	}
	if db.PanicIfError {
		panic(err)
	}
	return err
}
// Update sets a single column to value on the rows matched by the chain.
func (db *DB) Update(column string, value interface{}) error {
	db = db.Clone()
	err := db.DB.Update(column, value).Error
	if err == nil {
		return nil
	}
	if db.PanicIfError {
		panic(err)
	}
	return err
}
// Updates applies data (a struct or map, per gorm's Updates) to the rows
// matched by the chain.
func (db *DB) Updates(data any) error {
	db = db.Clone()
	err := db.DB.Updates(data).Error
	if err == nil {
		return nil
	}
	if db.PanicIfError {
		panic(err)
	}
	return err
}
// Delete removes the records described by value and conds (gorm's Delete).
func (db *DB) Delete(value interface{}, conds ...interface{}) error {
	db = db.Clone()
	err := db.DB.Delete(value, conds...).Error
	if err == nil {
		return nil
	}
	if db.PanicIfError {
		panic(err)
	}
	return err
}
// Save behaves like gorm's Save. When omitFK is true, it also omits every
// struct field whose gorm tag yields no column, i.e. a foreignKey field
// without an explicit column (see AnalyzeGormTags). Callers then need not
// list such fields with Omit by hand.
func (db *DB) Save(value any, omitFK bool) error {
	db = db.Clone()
	if omitFK {
		var fkFields []string
		refl.IterFields(value, func(f reflect.StructField) error {
			// AnalyzeGormTags returns nil for fields without a gorm tag; those
			// map to ordinary columns and must not be dereferenced or omitted.
			if field := AnalyzeGormTags(f); field != nil && field.Column == "" {
				fkFields = append(fkFields, field.FieldName)
			}
			return nil
		})
		if len(fkFields) > 0 {
			// gorm's Omit replaces the omit list. Calling it once per field kept
			// only the last one; merge with columns the caller already omitted
			// and apply everything at once.
			omits := append(slices.Clone(db.DB.Statement.Omits), fkFields...)
			db.DB = db.DB.Omit(omits...)
		}
	}
	if err := db.DB.Save(value).Error; err != nil {
		if db.PanicIfError {
			panic(err)
		}
		return err
	}
	return nil
}
// First loads the first matching record into dest via gorm's First.
func (db *DB) First(dest any, conds ...any) error {
	db = db.Clone()
	err := db.DB.First(dest, conds...).Error
	if err == nil {
		return nil
	}
	if db.PanicIfError {
		panic(err)
	}
	return err
}
// Last loads the last matching record into dest via gorm's Last.
func (db *DB) Last(dest any, conds ...any) error {
	db = db.Clone()
	err := db.DB.Last(dest, conds...).Error
	if err == nil {
		return nil
	}
	if db.PanicIfError {
		panic(err)
	}
	return err
}
// Find loads all matching records into dest via gorm's Find.
func (db *DB) Find(dest any, conds ...any) error {
	db = db.Clone()
	err := db.DB.Find(dest, conds...).Error
	if err == nil {
		return nil
	}
	if db.PanicIfError {
		panic(err)
	}
	return err
}
// Scan scans the query result into dest via gorm's Scan.
// Note: do not Scan a SUM directly, to avoid unexpected errors; use Sum,
// which scans into a sql.NullFloat64.
func (db *DB) Scan(dest any) error {
	db = db.Clone()
	err := db.DB.Scan(dest).Error
	if err == nil {
		return nil
	}
	if db.PanicIfError {
		panic(err)
	}
	return err
}
// Preload returns a copy of db that preloads the association query (gorm's Preload).
func (db *DB) Preload(query string, args ...interface{}) *DB {
	next := db.Clone()
	next.DB = next.DB.Preload(query, args...)
	return next
}
// Group adds GROUP BY columns. It is meant to let group columns be
// contributed from different places in the query-building logic.
//
// NOTE(review): Clone does not copy groupFields, so the list here only holds
// this call's fields; accumulation across calls presumably relies on gorm
// merging successive GROUP BY clauses. slices.Compact only drops adjacent
// duplicates.
func (db *DB) Group(fields ...string) *DB {
	next := db.Clone()
	if next.groupFields == nil {
		next.groupFields = []string{}
	}
	next.groupFields = slices.Compact(append(next.groupFields, fields...))
	next.DB = next.DB.Group(strings.Join(next.groupFields, ", "))
	return next
}
// NewGroup discards the GROUP BY built so far and starts a new one from
// fields. With no fields, the query is left without a GROUP BY.
//
// gorm keeps HAVING conditions inside the same GROUP BY clause, so any HAVING
// added before this call is discarded too.
func (db *DB) NewGroup(fields ...string) *DB {
	db = db.Clone()
	db.groupFields = make([]string, 0)
	// gorm merges successive Group calls into one clause, so the old columns
	// must be removed from the (cloned) statement explicitly.
	delete(db.DB.Statement.Clauses, "GROUP BY")
	if len(fields) == 0 {
		return db
	}
	return db.Group(fields...)
}
// Distinct returns a copy of db that selects distinct values (gorm's Distinct).
func (db *DB) Distinct(args ...interface{}) *DB {
	next := db.Clone()
	next.DB = next.DB.Distinct(args...)
	return next
}
// Limit returns a copy of db limited to limit rows.
func (db *DB) Limit(limit int) *DB {
	next := db.Clone()
	next.DB = next.DB.Limit(limit)
	return next
}
// Offset returns a copy of db that skips the first offset rows.
func (db *DB) Offset(offset int) *DB {
	next := db.Clone()
	next.DB = next.DB.Offset(offset)
	return next
}
// Count returns the number of records matched by the chain. On error it
// returns 0, or panics when PanicIfError is set.
func (db *DB) Count() (int64, error) {
	db = db.Clone()
	var total int64
	err := db.DB.Count(&total).Error
	if err == nil {
		return total, nil
	}
	if db.PanicIfError {
		panic(err)
	}
	return 0, err
}
// Order returns a copy of db ordered by fields, joined as "f1, f2, ...".
func (db *DB) Order(fields ...string) *DB {
	next := db.Clone()
	orderBy := strings.Join(fields, ", ")
	next.DB = next.DB.Order(orderBy)
	return next
}
// Raw returns a copy of db that runs the raw SQL query sql with values bound
// to its placeholders.
func (db *DB) Raw(sql string, values ...interface{}) *DB {
	next := db.Clone()
	next.DB = next.DB.Raw(sql, values...)
	return next
}
// Exec executes the SQL statement sql with value bound to its placeholders.
func (db *DB) Exec(sql string, value ...any) error {
	db = db.Clone()
	err := db.DB.Exec(sql, value...).Error
	if err == nil {
		return nil
	}
	if db.PanicIfError {
		panic(err)
	}
	return err
}
// RowsAffected returns the RowsAffected count held by the wrapped *gorm.DB.
//
// NOTE(review): terminal methods such as Updates and Delete run on clones and
// discard the resulting *gorm.DB, so this does not reflect their row counts.
func (db *DB) RowsAffected() int64 {
	affected := db.DB.RowsAffected
	return affected
}
// Paginate counts the records matched by the chain, then loads page pi
// (1-based) of size ps into dest with gorm's Find. The result carries dest as
// List together with the total count.
func (db *DB) Paginate(dest any, pi int, ps int) (*Pagination[any], error) {
	db = db.Clone()
	total, err := db.Count()
	if err == nil {
		err = CloneDB(db.DB).Limit(ps).Offset((pi - 1) * ps).Find(dest).Error
	}
	if err != nil {
		if db.PanicIfError {
			panic(err)
		}
		return nil, err
	}
	return &Pagination[any]{PageIndex: pi, PageSize: ps, Total: total, List: dest}, nil
}
// FindOrPaginate supports endpoints that may or may not be paginated. With
// pi == 0 it loads every match into dest and returns dest; otherwise it
// returns the *Pagination[any] produced by Paginate.
func (db *DB) FindOrPaginate(dest any, pi int, ps int) (any, error) {
	db = db.Clone()
	if pi != 0 {
		page, err := db.Paginate(dest, pi, ps)
		if err != nil {
			if db.PanicIfError {
				panic(err)
			}
			return nil, err
		}
		return page, nil
	}
	if err := db.Find(dest); err != nil {
		if db.PanicIfError {
			panic(err)
		}
		return nil, err
	}
	return dest, nil
}
// PaginateScan is Paginate using gorm's Scan instead of Find to fill dest,
// e.g. for results that are not model structs.
func (db *DB) PaginateScan(dest any, pi int, ps int) (*Pagination[any], error) {
	db = db.Clone()
	total, err := db.Count()
	if err == nil {
		err = CloneDB(db.DB).Limit(ps).Offset((pi - 1) * ps).Scan(dest).Error
	}
	if err != nil {
		if db.PanicIfError {
			panic(err)
		}
		return nil, err
	}
	return &Pagination[any]{PageIndex: pi, PageSize: ps, Total: total, List: dest}, nil
}
// ScanOrPaginate is FindOrPaginate built on Scan/PaginateScan. With pi == 0
// it scans everything into dest and returns dest; otherwise it returns the
// *Pagination[any] from PaginateScan.
func (db *DB) ScanOrPaginate(dest any, pi int, ps int) (any, error) {
	db = db.Clone()
	var (
		result any
		err    error
	)
	if pi == 0 {
		err = db.Scan(dest)
		result = dest
	} else {
		result, err = db.PaginateScan(dest, pi, ps)
	}
	if err != nil {
		if db.PanicIfError {
			panic(err)
		}
		return nil, err
	}
	return result, nil
}
// FirstOrInit loads the first match into dest or, when none exists,
// initializes dest; attrs are passed to gorm's Attrs. Nothing is saved.
func (db *DB) FirstOrInit(dest any, attrs ...any) error {
	db = db.Clone()
	err := db.DB.Attrs(attrs...).FirstOrInit(dest).Error
	if err == nil {
		return nil
	}
	if db.PanicIfError {
		panic(err)
	}
	return err
}
// FirstAndInit is FirstOrInit with attrs passed to gorm's Assign instead of
// Attrs. Nothing is saved.
func (db *DB) FirstAndInit(dest any, attrs ...any) error {
	db = db.Clone()
	err := db.DB.Assign(attrs...).FirstOrInit(dest).Error
	if err == nil {
		return nil
	}
	if db.PanicIfError {
		panic(err)
	}
	return err
}
// FirstOrCreate loads the first match into dest or, when none exists, creates
// it; attrs are passed to gorm's Attrs.
func (db *DB) FirstOrCreate(dest any, attrs ...any) error {
	db = db.Clone()
	err := db.DB.Attrs(attrs...).FirstOrCreate(dest).Error
	if err == nil {
		return nil
	}
	if db.PanicIfError {
		panic(err)
	}
	return err
}
// func (db *DB) UpdateOrCreate(dest any, attrs ...any) error {
// db = db.Clone()
// if err := db.DB.Assign(attrs...).FirstOrCreate(dest).Error; err != nil {
// if db.PanicIfError {
// panic(err)
// }
// return err
// }
// return nil
// }
// UpdateOrCreate updates the rows matching conds with data or, when no row
// matches, creates one from data merged with conds (maps.Merge). In both
// cases dest is then reloaded with gorm's First.
//
// dest is expected to point to the model struct: it receives the row and is
// passed to Model for the insert. Lookup, write and reload errors are all
// returned (or panic when PanicIfError is set); only "no matching row" leads
// to the insert.
func (db *DB) UpdateOrCreate(dest any, conds map[string]any, data map[string]any) error {
	db = db.Where(conds)
	// Limit(1).Find reports "no match" as RowsAffected == 0 rather than as an
	// error, so a real query failure is returned instead of triggering a
	// (possibly duplicate) insert. dest is passed as-is, like in the reload
	// below, not as a *any.
	lookup := CloneDB(db.DB).Limit(1).Find(dest)
	if err := lookup.Error; err != nil {
		if db.PanicIfError {
			panic(err)
		}
		return err
	}
	if lookup.RowsAffected == 0 {
		if err := db.tx.Model(dest).Create(maps.Merge(data, conds)).Error; err != nil {
			if db.PanicIfError {
				panic(err)
			}
			return err
		}
	} else if err := CloneDB(db.DB).Updates(data).Error; err != nil {
		if db.PanicIfError {
			panic(err)
		}
		return err
	}
	// Report reload failures too; returning nil here would hide them.
	if err := db.DB.First(dest).Error; err != nil {
		if db.PanicIfError {
			panic(err)
		}
		return err
	}
	return nil
}
// Run calls fc with db and returns db. It wraps a block of logic inline, for
// those who struggle to name intermediate variables.
//
// NOTE(review): builder methods return new *DB values, so chained calls made
// inside fc do not affect the returned db unless fc modifies *db directly.
func (db *DB) Run(fc func(db *DB)) *DB {
	self := db
	fc(self)
	return self
}
// Sum evaluates the aggregate expression query (e.g. "SUM(amount)") over the
// current chain and returns it as a float64; a NULL result is returned as 0.
// The expression replaces any previously selected columns, because the row is
// scanned into a single value.
func (db *DB) Sum(query string) (float64, error) {
	db = db.Clone()
	var sum sql.NullFloat64
	// NewSelect rather than Select: Select appends to earlier columns, which
	// would make the row wider than the single value scanned below.
	db = db.NewSelect(query)
	if err := db.Scan(&sum); err != nil {
		if db.PanicIfError {
			panic(err)
		}
		return 0, err
	}
	if !sum.Valid {
		return 0, nil
	}
	return sum.Float64, nil
}
// Transaction runs fc inside a transaction started with gorm's Transaction on
// the current chain. fc receives a fresh *DB wrapping the transaction handle;
// the error it returns is handed back to gorm's Transaction.
//
// NOTE(review): PanicIfError is not carried into the *DB passed to fc.
func (db *DB) Transaction(fc func(tx *DB) error) error {
	run := func(gtx *gorm.DB) error {
		return fc(newDB(gtx))
	}
	return db.DB.Transaction(run)
}
// GormColumn summarizes the gorm tag of one struct field, as produced by
// AnalyzeGormTags.
type GormColumn struct {
	IsPrimaryKey bool   `json:"primary_key" form:"primary_key"` // tag marks the field as primary key
	Column       string `json:"column" form:"column"`           // database column name
	FieldName    string `json:"field_name" form:"field_name"`   // Go struct field name
	FieldType    string `json:"field_type" form:"field_type"`   // Go type of the field, as text
	// Size and DefaultValue are not filled in by AnalyzeGormTags.
	Size         int    `json:"size" form:"size"`
	IsFKField    bool   `json:"is_fk_field"` // tag contains a foreignKey option
	DefaultValue string `json:"default_value" form:"default_value"`
}
// parseGormOption splits a single gorm tag option such as `column:"name"` at
// its first colon. A bare flag (no colon) yields the value "true"; otherwise
// surrounding single or double quotes are trimmed from the value.
func parseGormOption(option string) (key string, value string) {
	name, raw, hasValue := strings.Cut(option, ":")
	if !hasValue {
		return name, "true"
	}
	return name, strings.Trim(raw, `"'`)
}
// AnalyzeGormTags parses the `gorm` struct tag of f into a GormColumn, or
// returns nil when f has no gorm tag.
//
// Options are separated by ";". Surrounding whitespace is ignored and keys are
// matched case-insensitively, as gorm tag names are, so "column:id; primaryKey",
// "PRIMARYKEY", "primary_key" and "primaryKey:true" are all recognized.
// Column falls back to the snake_case field name unless the field carries a
// foreignKey option and no explicit column.
func AnalyzeGormTags(f reflect.StructField) *GormColumn {
	gormTag := f.Tag.Get("gorm")
	if gormTag == "" {
		return nil
	}
	config := &GormColumn{
		FieldName: f.Name,
		FieldType: fmt.Sprintf("%v", f.Type),
	}
	for _, option := range strings.Split(gormTag, ";") {
		key, value := parseGormOption(strings.TrimSpace(option))
		switch strings.ToLower(strings.TrimSpace(key)) {
		case "primarykey", "primary_key":
			// Like gorm, treat any non-empty value other than "false" as true.
			if value != "" && !strings.EqualFold(value, "false") {
				config.IsPrimaryKey = true
			}
		case "column":
			config.Column = value
		case "foreignkey":
			config.IsFKField = true
		}
	}
	if config.Column == "" && !config.IsFKField {
		config.Column = strutil.SnakeCase(f.Name, "_")
	}
	return config
}
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。