1 Star 0 Fork 0

qq51529210/golang封装的一些开发包

加入 Gitee
与超过 1400万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
gorm.go 8.81 KB
一键复制 编辑 原始数据 按行查看 历史
qq51529210 提交于 2026-02-02 11:33 +08:00 . u
package ggorm
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/glebarez/sqlite"
"github.com/go-sql-driver/mysql"
gormmysql "gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
// 初始化,自动创建数据库
// mysql dsn: mysql://root:1234@tcp(127.0.0.1:3306)/test?charset=utf8mb4&parseTime=True&loc=Local
// sqlite dsn: ./test.db
func Open(dsn string) (*gorm.DB, error) {
var cfg gorm.Config
cfg.NamingStrategy = schema.NamingStrategy{
SingularTable: true,
NoLowerCase: true,
}
// mysql
_dsn := strings.TrimPrefix(dsn, "mysql://")
if _dsn != dsn {
return openMysql(_dsn, &cfg)
}
// sqlite
return openSqlite(_dsn, &cfg)
}
// 初始化 mysql
func openMysql(dsn string, cfg *gorm.Config) (*gorm.DB, error) {
// 解析出 schema
mysqlCfg, err := mysql.ParseDSN(dsn)
if err != nil {
return nil, err
}
// 打开连接,不要数据库
_dsn := strings.Replace(dsn, mysqlCfg.DBName, "", 1)
db, err := sql.Open("mysql", _dsn)
if err != nil {
return nil, err
}
defer db.Close()
// 如果没有就创建数据库
_, err = db.Exec(fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS `%s` DEFAULT CHARACTER SET utf8mb4;", mysqlCfg.DBName))
if err != nil {
return nil, err
}
return gorm.Open(gormmysql.Open(dsn), cfg)
}
// 初始化 sqlite
func openSqlite(dsn string, cfg *gorm.Config) (*gorm.DB, error) {
// sqlite
db, err := gorm.Open(sqlite.Open(dsn), cfg)
if err != nil {
return nil, err
}
err = db.Exec("PRAGMA foreign_keys = ON;").Error
if err != nil {
return nil, err
}
//
return db, nil
}
// 是否没有数据错误
func IsDataNotFound(err error) bool {
return err == gorm.ErrRecordNotFound
}
// 判断是否唯一性错误,mysql 使用
func IsMySqlDuplicateError(err error) bool {
if e, ok := err.(*mysql.MySQLError); ok {
return e.Number == 1062
}
return false
}
// First 封装,ptr 是接收数据的结构体指针,query 是条件struct,fields 是查询的字段
func First(tx *gorm.DB, ptr any, query any, fields ...string) error {
db := tx
if len(fields) > 0 {
db = db.Select(fields)
}
return WhereQuery(db, query).First(ptr).Error
}
// Add 封装,fields 是添加的字段
func Add(tx *gorm.DB, v any, fields ...string) *gorm.DB {
db := tx
if len(fields) > 0 {
db = db.Select(fields)
}
return db.Create(v)
}
// AddOmit 封装,fields 是忽略的字段
func AddOmit(tx *gorm.DB, v any, fields ...string) *gorm.DB {
db := tx
if len(fields) > 0 {
db = db.Omit(fields...)
}
return db.Create(v)
}
// AddIgnore 封装,fields 是修改的字段
func AddIgnore(tx *gorm.DB, v any, fields ...string) *gorm.DB {
db := tx.Clauses(clause.OnConflict{DoNothing: true})
if len(fields) > 0 {
db = db.Select(fields)
}
return db.Create(v)
}
// AddIgnoreOmit 封装,fields 是修改的字段
func AddIgnoreOmit(tx *gorm.DB, v any, fields ...string) *gorm.DB {
db := tx.Clauses(clause.OnConflict{DoNothing: true})
if len(fields) > 0 {
db = db.Omit(fields...)
}
return db.Create(v)
}
// AddOrUpdate 封装,fields 是添加/修改的字段
func AddOrUpdate(tx *gorm.DB, v any, fields ...string) *gorm.DB {
return tx.Clauses(clause.OnConflict{
UpdateAll: len(fields) == 0,
DoUpdates: clause.AssignmentColumns(fields),
}).Create(v)
}
// Update 封装,fields 是添加的字段,query 是条件struct,fields 是修改的字段
func Update(tx *gorm.DB, v any, query any, fields ...string) *gorm.DB {
db := tx
if len(fields) > 0 {
db = db.Select(fields)
}
return WhereQuery(db, query).Updates(v)
}
// PageQuery 分页查询参数
type PageQuery struct {
// 偏移,小于 0 不匹配
Offset int `json:"offset,omitempty" form:"offset" binding:"omitempty,min=0"`
// 条数,小于 1 不匹配
Count int `json:"count,omitempty" form:"count" binding:"omitempty,min=1"`
// 排序,"column1 [desc], column2..."
Order string `json:"order,omitempty" form:"order"`
// 是否需要返回总数
Total string `json:"total,omitempty" form:"total" binding:"omitempty,oneof=0 1"`
}
// InitDB 根据分页参数初始化数据库查询
// 设置 offset、limit 和排序规则
// 参数:
// - db: GORM 数据库连接
//
// 返回:
// - 设置了分页参数的数据库连接
func (m *PageQuery) InitDB(db *gorm.DB) *gorm.DB {
// 分页
if m.Offset > 0 {
db = db.Offset(m.Offset)
}
// 数量
if m.Count > 0 {
db = db.Limit(m.Count)
}
// 排序
if m.Order != "" {
db = db.Order(fmt.Sprintf("`%s`", m.Order))
}
return db
}
// PageResult 是 Page 的返回值
type PageResult[M any] struct {
// 总数
Total int64 `json:"total"`
// 列表
Data []M `json:"data"`
}
// Page 用于分页查询
func Page[M any](db *gorm.DB, page *PageQuery, res *PageResult[M], fields ...string) error {
if page != nil {
// 总数
if page.Total == "1" {
if err := db.Count(&res.Total).Error; err != nil {
return err
}
}
// 分页
db = page.InitDB(db)
}
// 查询
if len(fields) > 0 {
db = db.Select(fields)
}
if err := db.Find(&res.Data).Error; err != nil {
return err
}
//
return nil
}
// Page 用于分页查询
func PageOmit[M any](db *gorm.DB, page *PageQuery, res *PageResult[M], fields ...string) error {
if page != nil {
// 总数
if page.Total == "1" {
if err := db.Count(&res.Total).Error; err != nil {
return err
}
}
// 分页
db = page.InitDB(db)
}
// 查询
if len(fields) > 0 {
db = db.Omit(fields...)
}
if err := db.Find(&res.Data).Error; err != nil {
return err
}
//
return nil
}
// All 用于查询全部,query 是条件,fields 是查询的字段
func All[M any](db *gorm.DB, query any, fields ...string) (ms []M, err error) {
// 查询的字段
if len(fields) > 0 {
db = db.Select(fields)
}
// 查询条件
if query != nil {
db = WhereQuery(db, query)
}
// 查询
err = db.Find(&ms).Error
return
}
// AllOmit 用于查询全部,query 是条件,fields 是忽略的字段
func AllOmit[M any](db *gorm.DB, query any, fields ...string) (ms []M, err error) {
// 查询的字段
if len(fields) > 0 {
db = db.Omit(fields...)
}
// 查询条件
if query != nil {
db = WhereQuery(db, query)
}
// 查询
err = db.Find(&ms).Error
return
}
// AllOrder 用于查询全部,query 是条件,order 是排序sql,fields 是查询的字段
func AllOrder[M any](db *gorm.DB, query any, order string, fields ...string) (ms []M, err error) {
// 查询条件
if query != nil {
db = WhereQuery(db, query)
}
// 排序
if order != "" {
db = db.Order(order)
}
// 查询
err = db.Find(&ms).Error
return
}
// AllOrderOmit 用于查询全部,query 是条件,order 是排序sql,fields 是忽略的字段
func AllOrderOmit[M any](db *gorm.DB, query any, order string, fields ...string) (ms []M, err error) {
// 查询的字段
if len(fields) > 0 {
db = db.Omit(fields...)
}
// 查询条件
if query != nil {
db = WhereQuery(db, query)
}
// 查询
err = db.Find(&ms).Error
return
}
// Field 用于查询单个字段,query 是条件,field 是字段名
func Field[M any](db *gorm.DB, query any, field string) (ms []M, err error) {
// 查询条件
if query != nil {
db = WhereQuery(db, query)
}
// 查询
err = db.Pluck(field, &ms).Error
return
}
// 查询锁
func LockForUpdate(db *gorm.DB) *gorm.DB {
return db.Clauses(clause.Locking{Strength: "UPDATE"})
}
// 查询锁
func GetLockForUpdate(db *gorm.DB, m any) error {
return LockForUpdate(db).Take(m).Error
}
// 查询锁
func AllLockForUpdate[M any](db *gorm.DB, query any, fields ...string) ([]M, error) {
return All[M](LockForUpdate(db), query, fields...)
}
// 使用一个新 session 来生成sql
// 在 db 被全局锁表,比如 casbin 的事务使用
func SQL(tx, clean *gorm.DB, init func(db *gorm.DB) *gorm.DB) (string, []any) {
sess := clean.Session(&gorm.Session{DryRun: true, Logger: new(DropLogger)})
db := init(sess)
return db.Statement.SQL.String(), db.Statement.Vars
}
// 执行sql
// 在 db 被全局锁表,比如 casbin 的事务使用
func Exec(tx, clean *gorm.DB, init func(db *gorm.DB) *gorm.DB) (int64, error) {
sess := clean.Session(&gorm.Session{DryRun: true, Logger: new(DropLogger)})
db := init(sess)
db = tx.Exec(db.Statement.SQL.String(), db.Statement.Vars...)
return db.RowsAffected, db.Error
}
// 执行sql
// 在 db 被全局锁表,比如 casbin 的事务使用
func Scan(tx, clean *gorm.DB, init func(db *gorm.DB) *gorm.DB, destPtr any) error {
sess := clean.Session(&gorm.Session{DryRun: true, Logger: new(DropLogger)})
db := init(sess)
return tx.Raw(db.Statement.SQL.String(), db.Statement.Vars...).Scan(destPtr).Error
}
func ScanValue(value any, ptr any) error {
switch v := value.(type) {
case []byte:
return json.Unmarshal(v, ptr)
case string:
return json.Unmarshal([]byte(v), ptr)
default:
return errors.New("unsupport scan type")
}
}
func ValueData(ptr any) (driver.Value, error) {
if ptr == nil {
return "{}", nil
}
return json.Marshal(ptr)
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Go
1
https://gitee.com/qq51529210/gutil.git
git@gitee.com:qq51529210/gutil.git
qq51529210
gutil
golang封装的一些开发包
b65ca894fc34

搜索帮助