代码拉取完成,页面将自动刷新
package dbs
import (
"context"
"database/sql"
"fmt"
v1config "gitee.com/scottq/go-framework/src/v1/config"
"github.com/go-sql-driver/mysql"
ormmysql "gorm.io/driver/mysql"
"gorm.io/gorm"
"strings"
)
// IGormMysql is the MySQL access API built on top of gorm. It exposes the
// underlying gorm and database/sql handles, a few raw-SQL helpers, and
// explicit transaction control. GormMysql implements it.
type IGormMysql interface {
	// Underlying handles.
	DB() *sql.DB
	GormDB() *gorm.DB
	GormTX() gorm.ConnPool

	// Raw-SQL helpers. In GormMysql, table names, column names and
	// WHERE/ORDER BY fragments are written into the SQL text as-is; only
	// whereArgs and the values being inserted/updated are bound as parameters.
	ExecuteSearch(tableName string, fields, whereArr []string, whereArgs []interface{}, orderBy []string, pageNum, pageSize int64, rowsHandler SearchRowsHandler) (int64, int64, error)
	ExecuteQuery(tableName string, fields map[string]interface{}, whereArr []string, whereArgs []interface{}, orderBy []string) (bool, error)
	ExecuteCreate(tableName string, fields map[string]interface{}) (int64, error)
	ExecuteUpdate(tableName string, fields map[string]interface{}, whereArr []string, whereArgs []interface{}) (int64, error)

	// Transaction control. BeginTrans returns a new transactional instance;
	// CommitTrans/RollbackTrans are meant to be called on that instance.
	BeginTrans() (IGormMysql, error)
	CommitTrans() error
	RollbackTrans() error
}
// GormMysql is the gorm-backed implementation of IGormMysql. A value either
// wraps a plain connection (Etx == nil) or, when produced by BeginTrans, a
// transaction (Etx != nil).
type GormMysql struct {
	// Edb is the gorm handle used for normal operations; after BeginTrans
	// it is the gorm.DB returned by Begin.
	Edb *gorm.DB
	// Etx is nil initially; after a transaction is started it holds that
	// transaction's connection pool.
	// A nil Etx therefore means no transaction has been started.
	Etx gorm.ConnPool
}
// NewGormMysql opens a MySQL connection from c (fs are forwarded to
// NewMysqlConn to adjust the driver's *mysql.Config) and wraps it in a
// gorm.DB. The returned GormMysql is not in a transaction (Etx is nil).
func NewGormMysql(c *v1config.DBConfig, fs ...func(config *mysql.Config)) (*GormMysql, error) {
	conn, err := NewMysqlConn(c, fs...)
	if err != nil {
		return nil, err
	}
	db, err := gorm.Open(ormmysql.New(ormmysql.Config{Conn: conn}), &gorm.Config{})
	if err != nil {
		// Nobody keeps a reference to conn on this path, so release it here
		// rather than leak it. Some gorm versions already close it on failure;
		// Close on *sql.DB is idempotent, so closing again is harmless.
		if closer, ok := interface{}(conn).(interface{ Close() error }); ok {
			_ = closer.Close() // best effort: the Open error is the one worth reporting
		}
		return nil, err
	}
	return &GormMysql{Edb: db}, nil
}
// ExecuteSearch runs a paginated SELECT against tableName and passes the
// resulting rows to rowsHandler.
//
// Arguments:
//   - whereArr fragments are joined with " AND " and bound with whereArgs.
//   - orderBy fragments are joined with ",".
//   - fields defaults to "*" when empty.
//   - pageNum < 1 is treated as 1.
//   - pageSize outside (0, 1000] is clamped to 1000.
//
// It first counts matching rows with COUNT(id) and returns (0, total, err).
// NOTE(review): the first return value is always 0; confirm what callers
// expect there.
//
// tableName, fields, whereArr and orderBy are interpolated into the SQL
// verbatim and must not contain untrusted input; only whereArgs are bound.
// The caller's whereArgs and fields slices are never modified.
func (d *GormMysql) ExecuteSearch(tableName string, fields []string, whereArr []string, whereArgs []interface{}, orderBy []string, pageNum, pageSize int64, rowsHandler SearchRowsHandler) (int64, int64, error) {
	var total int64
	if pageNum <= 0 {
		pageNum = 1
	}
	if pageSize <= 0 || pageSize > 1000 {
		pageSize = 1000
	}

	whereStr := ""
	if len(whereArr) > 0 {
		whereStr = "WHERE " + strings.Join(whereArr, " AND ")
	}

	// Total number of matching rows.
	countSql := fmt.Sprintf(
		"SELECT COUNT(id) AS total FROM `%s` %s LIMIT 1",
		tableName, whereStr)
	stmt, err := d.sqlConn().PrepareContext(context.Background(), countSql)
	if err != nil {
		return 0, total, err
	}
	defer stmt.Close()
	err = stmt.QueryRow(whereArgs...).Scan(&total)
	if err == sql.ErrNoRows {
		return 0, 0, nil
	} else if err != nil {
		return 0, total, err
	}

	// Requested page of rows.
	fieldsStr := "*"
	if len(fields) > 0 {
		fieldsStr = strings.Join(fields, ",")
	}
	orderByStr := strings.Join(orderBy, ",")
	if orderByStr != "" {
		orderByStr = "ORDER BY " + orderByStr
	}
	searchSql := fmt.Sprintf(
		"SELECT %s FROM `%s` %s %s LIMIT ? OFFSET ?",
		fieldsStr, tableName, whereStr, orderByStr)
	stmt1, err := d.sqlConn().PrepareContext(context.Background(), searchSql)
	if err != nil {
		return 0, total, err
	}
	defer stmt1.Close()

	// Build a fresh argument slice. Appending LIMIT/OFFSET straight onto
	// whereArgs would write into the caller's backing array whenever it has
	// spare capacity.
	searchArgs := make([]interface{}, 0, len(whereArgs)+2)
	searchArgs = append(searchArgs, whereArgs...)
	searchArgs = append(searchArgs, pageSize, pageSize*(pageNum-1))

	rows, err := stmt1.Query(searchArgs...)
	if err != nil {
		return 0, total, err
	}
	defer rows.Close()

	if err = rowsHandler(rows); err != nil {
		return 0, total, err
	}
	// Report iteration errors the handler may not have checked.
	if err = rows.Err(); err != nil {
		return 0, total, err
	}
	return 0, total, nil
}
// ExecuteQuery selects at most one row from tableName and scans it into the
// values of fields.
//
// Each key of fields is a column name (wrapped in backticks in the SQL); the
// matching value is the Scan destination for that column. whereArr fragments
// are joined with " AND " and bound with whereArgs. tableName is interpolated
// verbatim.
//
// Returns:
//   - (true, nil) when a row was scanned.
//   - (false, nil) when no row matched.
//   - (false, err) on any other failure.
//
// NOTE(review): orderBy is currently ignored because the ORDER BY clause is
// disabled. With LIMIT 1, the chosen row is unspecified when several rows
// match. Confirm whether this is intentional.
func (d *GormMysql) ExecuteQuery(tableName string, fields map[string]interface{}, whereArr []string, whereArgs []interface{}, orderBy []string) (bool, error) {
	columns := make([]string, 0, len(fields))
	dests := make([]interface{}, 0, len(fields))
	for name, dest := range fields {
		columns = append(columns, "`"+name+"`")
		dests = append(dests, dest)
	}

	var where string
	if len(whereArr) != 0 {
		where = "WHERE " + strings.Join(whereArr, " AND ")
	}

	// The ORDER BY slot is deliberately left empty; see the note above.
	orderClause := ""
	query := fmt.Sprintf("SELECT %s FROM %s %s %s LIMIT 1",
		strings.Join(columns, ","), tableName, where, orderClause)

	stmt, err := d.sqlConn().PrepareContext(context.Background(), query)
	if err != nil {
		return false, err
	}
	defer stmt.Close()

	err = stmt.QueryRow(whereArgs...).Scan(dests...)
	switch {
	case err == sql.ErrNoRows:
		return false, nil
	case err != nil:
		return false, err
	default:
		return true, nil
	}
}
// ExecuteCreate inserts one row into tableName using MySQL's
// "INSERT INTO t SET `col`=?, ..." form. Every value of fields is bound as a
// parameter; column names are wrapped in backticks, and tableName is
// interpolated verbatim.
//
// It returns the LastInsertId reported by the driver.
func (d *GormMysql) ExecuteCreate(tableName string, fields map[string]interface{}) (int64, error) {
	assignments := make([]string, 0, len(fields))
	args := make([]interface{}, 0, len(fields))
	for column, value := range fields {
		assignments = append(assignments, "`"+column+"`=?")
		args = append(args, value)
	}

	query := fmt.Sprintf("INSERT INTO %s SET %s", tableName, strings.Join(assignments, ","))
	stmt, err := d.sqlConn().PrepareContext(context.Background(), query)
	if err != nil {
		return 0, err
	}
	defer stmt.Close()

	res, err := stmt.Exec(args...)
	if err != nil {
		return 0, err
	}
	return res.LastInsertId()
}
// ExecuteUpdate runs "UPDATE tableName SET `col`=?, ... WHERE ..." and
// returns the number of affected rows.
//
// Arguments:
//   - Every value of fields is bound as a parameter.
//   - whereArr fragments are joined with " AND " and bound with whereArgs,
//     which follow the SET values in argument order.
//   - tableName and whereArr are interpolated verbatim.
//
// An empty whereArr produces no WHERE clause, so every row in the table is
// updated. An empty fields map is rejected up front instead of sending
// malformed SQL to the server.
func (d *GormMysql) ExecuteUpdate(tableName string, fields map[string]interface{}, whereArr []string, whereArgs []interface{}) (int64, error) {
	if len(fields) == 0 {
		return 0, fmt.Errorf("update %s: no fields to set", tableName)
	}

	setArr := make([]string, 0, len(fields))
	args := make([]interface{}, 0, len(fields)+len(whereArgs))
	for k, v := range fields {
		setArr = append(setArr, fmt.Sprintf("`%s`=?", k))
		args = append(args, v)
	}
	// SET placeholders appear before the WHERE placeholders in the statement,
	// so their values must come first.
	args = append(args, whereArgs...)

	whereStr := ""
	if len(whereArr) > 0 {
		whereStr = "WHERE " + strings.Join(whereArr, " AND ")
	}

	updateSql := fmt.Sprintf("UPDATE %s SET %s %s", tableName, strings.Join(setArr, ","), whereStr)
	stmt, err := d.sqlConn().PrepareContext(context.Background(), updateSql)
	if err != nil {
		return 0, err
	}
	defer stmt.Close()

	ret, err := stmt.Exec(args...)
	if err != nil {
		return 0, err
	}
	return ret.RowsAffected()
}
// DB returns whatever *sql.DB Edb.DB() yields. The lookup error is discarded
// because this accessor has no error return, so callers may get nil.
func (d *GormMysql) DB() (sqlDB *sql.DB) {
	sqlDB, _ = d.Edb.DB()
	return sqlDB
}

// GormDB returns the wrapped gorm handle (Edb). On an instance created by
// BeginTrans, this is the transactional gorm.DB.
func (d *GormMysql) GormDB() *gorm.DB { return d.Edb }

// GormTX returns the transaction connection pool (Etx). It is nil unless the
// instance was created by BeginTrans.
func (d *GormMysql) GormTX() gorm.ConnPool { return d.Etx }
// BeginTrans starts a transaction on Edb and returns a new instance bound to
// it. Its Edb is the transactional gorm.DB, and its Etx is the transaction's
// connection pool, which sqlConn prefers for the raw Execute* helpers. The
// receiver itself is left unchanged.
//
// Finish the transaction with CommitTrans or RollbackTrans on the returned
// instance. If the transaction cannot be started (including when the
// receiver is already transactional), the error is returned and no
// instance is created.
func (d *GormMysql) BeginTrans() (IGormMysql, error) {
	tx := d.Edb.Begin()
	if tx.Error != nil {
		// gorm reports Begin failures through tx.Error; returning tx anyway
		// would postpone the failure to the first query on it.
		return nil, tx.Error
	}
	return &GormMysql{
		Edb: tx,
		Etx: tx.Statement.ConnPool,
	}, nil
}
// CommitTrans commits the transaction held by this instance. It fails with
// "not begin trans" when the instance is not transactional (Etx is nil).
func (d *GormMysql) CommitTrans() error {
	if d.Etx != nil {
		return d.Edb.Commit().Error
	}
	return fmt.Errorf("not begin trans")
}
// RollbackTrans rolls back the transaction held by this instance. It fails
// with "not begin trans" when the instance is not transactional (Etx is nil).
func (d *GormMysql) RollbackTrans() error {
	if d.Etx != nil {
		return d.Edb.Rollback().Error
	}
	return fmt.Errorf("not begin trans")
}
// WithDB returns a new GormMysql wrapping db. The receiver's state is neither
// used nor copied, and the result has no transaction context (Etx is nil).
//
// NOTE(review): if db is a transaction handle, Etx is still left nil.
// CommitTrans/RollbackTrans on the result will therefore return
// "not begin trans", and the raw Execute* helpers will likely run outside
// that transaction. Confirm this is intended.
func (d *GormMysql) WithDB(db *gorm.DB) *GormMysql {
	wrapped := new(GormMysql)
	wrapped.Edb = db
	return wrapped
}
// sqlConn returns the connection used by the raw Execute* helpers. It uses
// the transaction pool Etx when set, otherwise the *sql.DB behind Edb.
//
// If gorm cannot expose a *sql.DB, the original code wrapped a nil *sql.DB in
// a non-nil gorm.ConnPool, which only failed later with a nil-pointer panic
// inside database/sql. Instead, this falls back to the pool gorm itself uses
// for Edb. That pool provides the same PrepareContext/ExecContext/
// QueryContext methods the helpers need.
func (d *GormMysql) sqlConn() gorm.ConnPool {
	if d.Etx != nil {
		return d.Etx
	}
	if db, _ := d.Edb.DB(); db != nil {
		return db
	}
	if d.Edb.Statement != nil && d.Edb.Statement.ConnPool != nil {
		return d.Edb.Statement.ConnPool
	}
	return d.Edb.ConnPool
}
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。