1 Star 0 Fork 0

Wsage/go-framework

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
db_mysql.go 6.90 KB
一键复制 编辑 原始数据 按行查看 历史
sage 提交于 2022-02-19 21:38 . modify rows close
package dbs
import (
"database/sql"
"fmt"
v1config "gitee.com/scottq/go-framework/src/v1/config"
"github.com/go-sql-driver/mysql"
"log"
"net"
"strings"
"time"
)
// IDBMysql is the MySQL access contract used by this package. A value is
// bound either to a connection pool or, after BeginTrans, to a single
// transaction; the Execute* helpers run against whichever one is active.
type IDBMysql interface {
	// DB returns the underlying connection pool, if any.
	DB() *sql.DB
	// TX returns the active transaction, if any.
	TX() *sql.Tx

	// ExecuteSearch runs a paginated SELECT and passes the result rows to
	// rowsHandler; the second result is the total number of matching rows.
	ExecuteSearch(
		tableName string, fields []string,
		whereArr []string, whereArgs []interface{}, orderBy []string,
		pageNum, pageSize int64, rowsHandler SearchRowsHandler,
	) (int64, int64, error)
	// ExecuteQuery fetches at most one row into the scan destinations held
	// in fields and reports whether a row was found.
	ExecuteQuery(
		tableName string, fields map[string]interface{},
		whereArr []string, whereArgs []interface{}, orderBy []string,
	) (bool, error)
	// ExecuteCreate inserts one row and returns its last insert id.
	ExecuteCreate(tableName string, fields map[string]interface{}) (int64, error)
	// ExecuteUpdate updates the matching rows and returns the affected-row count.
	ExecuteUpdate(
		tableName string, fields map[string]interface{},
		whereArr []string, whereArgs []interface{},
	) (int64, error)

	// BeginTrans starts a transaction and returns a handle bound to it.
	BeginTrans() (IDBMysql, error)
	// CommitTrans commits the transaction held by this handle.
	CommitTrans() error
	// RollbackTrans rolls back the transaction held by this handle.
	RollbackTrans() error
}
// DBMysql is the database/sql-backed implementation of IDBMysql.
// NewDBMysql produces a pool handle (only Edb set); BeginTrans produces a
// transaction handle (only Etx set).
type DBMysql struct {
	Edb *sql.DB // connection pool; nil on handles returned by BeginTrans
	Etx *sql.Tx // active transaction; nil on pool handles
}

// DB returns the connection pool held by d (nil for a transaction handle).
func (d *DBMysql) DB() *sql.DB { return d.Edb }

// TX returns the transaction held by d (nil when d is not inside one).
func (d *DBMysql) TX() *sql.Tx { return d.Etx }
// NewDBMysql builds a MySQL connection pool from c and wraps it in a
// DBMysql that is not bound to any transaction.
//
// The DSN uses TCP to c.DbHost:c.DbPort, enables multiStatements and, when
// c.MaxConcatLen is non-empty, sets the group_concat_max_len session
// parameter. Pool limits are applied only for positive config values;
// otherwise the database/sql defaults are kept.
//
// sql.Open only validates its arguments and does not dial the server, so an
// unreachable database is reported by the first query, not here.
func NewDBMysql(c v1config.DBConfig) (*DBMysql, error) {
	dbConfig := mysql.NewConfig()
	dbConfig.User = c.DbUser
	dbConfig.Passwd = c.DbPassword
	dbConfig.Net = "tcp"
	dbConfig.Addr = net.JoinHostPort(c.DbHost, c.DbPort)
	dbConfig.DBName = c.DbName
	dbConfig.MultiStatements = true
	dbConfig.RejectReadOnly = false

	extParam := make(map[string]string)
	if c.MaxConcatLen != "" {
		extParam["group_concat_max_len"] = c.MaxConcatLen
	}
	dbConfig.Params = extParam

	newDb, err := sql.Open("mysql", dbConfig.FormatDSN())
	if err != nil {
		// Not log.Fatalf: that would terminate the whole process (making the
		// error return unreachable) and it printed the full DSN, password
		// included. Log only non-secret coordinates and let the caller decide.
		log.Printf("connect to db %s@%s/%s failed: %v", dbConfig.User, dbConfig.Addr, dbConfig.DBName, err)
		return nil, fmt.Errorf("open mysql %s/%s: %w", dbConfig.Addr, dbConfig.DBName, err)
	}

	if c.MaxIdleConns > 0 {
		// Number of idle connections kept ready for concurrent use.
		newDb.SetMaxIdleConns(c.MaxIdleConns)
	}
	if c.MaxOpenConns > 0 {
		// Upper bound on open connections.
		newDb.SetMaxOpenConns(c.MaxOpenConns)
	}
	// NOTE(review): time.Duration(x) treats x as nanoseconds. If DBConfig
	// stores MaxLifetime / MaxIdleTime in seconds, these should be multiplied
	// by time.Second — confirm against v1config.DBConfig.
	if c.MaxLifetime > 0 {
		// Maximum lifetime of each connection.
		newDb.SetConnMaxLifetime(time.Duration(c.MaxLifetime))
	}
	if c.MaxIdleTime > 0 {
		// Maximum idle time of each connection.
		newDb.SetConnMaxIdleTime(time.Duration(c.MaxIdleTime))
	}

	return &DBMysql{
		Edb: newDb,
	}, nil
}
// SearchRowsHandler consumes the result set produced by ExecuteSearch.
// The rows are closed by ExecuteSearch after the handler returns, and a
// non-nil error returned by the handler is propagated to ExecuteSearch's caller.
type SearchRowsHandler func(rows *sql.Rows) error
// ExecuteSearch runs a paginated SELECT against tableName and hands the
// page's rows to rowsHandler.
//
// It first counts matching rows with "SELECT COUNT(id) ..." and then selects
// the requested page with LIMIT/OFFSET. fields, whereArr and orderBy are
// spliced into the SQL verbatim, so they must come from trusted code; only
// whereArgs and the paging values are bound as parameters. whereArr entries
// are joined with AND.
//
// pageNum < 1 is treated as 1; pageSize outside (0, 1000] becomes 1000. An
// empty fields slice selects "*". A nil rowsHandler skips row processing.
//
// The statements run inside d's transaction when one is active, otherwise on
// the connection pool. rows are closed when ExecuteSearch returns, so the
// handler must not retain them.
//
// Returns (0, total, err). The first value is always 0 in this
// implementation — NOTE(review): its intended meaning is not visible here;
// confirm with callers. total is the COUNT result.
func (d *DBMysql) ExecuteSearch(tableName string, fields []string, whereArr []string, whereArgs []interface{}, orderBy []string, pageNum, pageSize int64, rowsHandler SearchRowsHandler) (int64, int64, error) {
	if pageNum <= 0 {
		pageNum = 1
	}
	if pageSize <= 0 || pageSize > 1000 {
		pageSize = 1000
	}

	whereStr := ""
	if len(whereArr) > 0 {
		whereStr = "WHERE " + strings.Join(whereArr, " AND ")
	}

	// prepare uses the active transaction when there is one, else the pool.
	prepare := func(query string) (*sql.Stmt, error) {
		if tx := d.TX(); tx != nil {
			return tx.Prepare(query)
		}
		return d.DB().Prepare(query)
	}

	// Total number of matching rows.
	countSql := fmt.Sprintf(
		"SELECT COUNT(id) AS total FROM `%s` %s LIMIT 1",
		tableName, whereStr)
	countStmt, err := prepare(countSql)
	if err != nil {
		return 0, 0, err
	}
	defer countStmt.Close()

	var total int64
	err = countStmt.QueryRow(whereArgs...).Scan(&total)
	if err == sql.ErrNoRows {
		return 0, 0, nil
	}
	if err != nil {
		return 0, total, err
	}

	// Build the select list without appending to the caller's slice, which
	// could write into its backing array when it has spare capacity.
	fieldsStr := "*"
	if len(fields) > 0 {
		fieldsStr = strings.Join(fields, ",")
	}
	orderByStr := strings.Join(orderBy, ",")
	if orderByStr != "" {
		orderByStr = "ORDER BY " + orderByStr
	}

	searchSql := fmt.Sprintf(
		"SELECT %s FROM `%s` %s %s LIMIT ? OFFSET ?",
		fieldsStr, tableName, whereStr, orderByStr)
	searchStmt, err := prepare(searchSql)
	if err != nil {
		return 0, total, err
	}
	defer searchStmt.Close()

	// Fresh slice for the same reason as above: appending LIMIT/OFFSET to
	// whereArgs directly could clobber (or race on) the caller's array.
	args := make([]interface{}, 0, len(whereArgs)+2)
	args = append(args, whereArgs...)
	args = append(args, pageSize, pageSize*(pageNum-1))

	rows, err := searchStmt.Query(args...)
	if err != nil {
		return 0, total, err
	}
	defer rows.Close()

	if rowsHandler != nil {
		if err := rowsHandler(rows); err != nil {
			return 0, total, err
		}
	}
	// Surface iteration errors the handler may not have checked.
	if err := rows.Err(); err != nil {
		return 0, total, err
	}
	return 0, total, nil
}
// ExecuteQuery selects at most one row from tableName and scans it into the
// values of fields.
//
// fields maps column name -> scan destination. Each value must be a valid
// (*sql.Row).Scan destination, i.e. a pointer. Column names are
// backtick-quoted. tableName, whereArr and orderBy are spliced into the SQL
// verbatim (trusted input only), and whereArgs are bound as parameters.
// whereArr entries are joined with AND.
//
// Returns (true, nil) when a row was found and scanned, (false, nil) when
// no row matched, and (false, err) on any other failure. An empty fields map
// is rejected up front because it would produce "SELECT  FROM ...".
//
// NOTE(review): unlike ExecuteSearch, tableName is not backtick-quoted here;
// left as-is so callers passing "schema.table" keep working.
func (d *DBMysql) ExecuteQuery(tableName string, fields map[string]interface{}, whereArr []string, whereArgs []interface{}, orderBy []string) (bool, error) {
	if len(fields) == 0 {
		return false, fmt.Errorf("select from %s: no fields given", tableName)
	}

	// Column list and scan targets are built in the same pass, so their
	// order matches whatever order the map iteration yields.
	fieldArr := make([]string, 0, len(fields))
	scanArr := make([]interface{}, 0, len(fields))
	for k, v := range fields {
		fieldArr = append(fieldArr, fmt.Sprintf("`%s`", k))
		scanArr = append(scanArr, v)
	}

	whereStr := ""
	if len(whereArr) > 0 {
		whereStr = "WHERE " + strings.Join(whereArr, " AND ")
	}
	orderByStr := strings.Join(orderBy, ",")
	if orderByStr != "" {
		orderByStr = "ORDER BY " + orderByStr
	}
	selectSql := fmt.Sprintf("SELECT %s FROM %s %s %s LIMIT 1",
		strings.Join(fieldArr, ","), tableName, whereStr, orderByStr)

	var stmt *sql.Stmt
	var err error
	if tx := d.TX(); tx != nil {
		stmt, err = tx.Prepare(selectSql)
	} else {
		stmt, err = d.DB().Prepare(selectSql)
	}
	if err != nil {
		return false, err
	}
	defer stmt.Close()

	err = stmt.QueryRow(whereArgs...).Scan(scanArr...)
	if err == sql.ErrNoRows {
		return false, nil
	}
	if err != nil {
		return false, err
	}
	return true, nil
}
// ExecuteCreate inserts one row into tableName using MySQL's
// "INSERT INTO t SET col=?,..." form and returns the LastInsertId reported
// for the statement.
//
// fields maps column name -> value. Column names are backtick-quoted and the
// values are bound as parameters. tableName is spliced verbatim (trusted
// input only). The insert runs inside d's transaction when one is active,
// otherwise on the connection pool. An empty fields map is rejected because
// it would produce invalid SQL ("INSERT INTO t SET ").
func (d *DBMysql) ExecuteCreate(tableName string, fields map[string]interface{}) (int64, error) {
	if len(fields) == 0 {
		return 0, fmt.Errorf("insert into %s: no fields given", tableName)
	}

	// Assignments and values are built in the same pass so their order matches.
	fieldArr := make([]string, 0, len(fields))
	valueArr := make([]interface{}, 0, len(fields))
	for k, v := range fields {
		fieldArr = append(fieldArr, fmt.Sprintf("`%s`=?", k))
		valueArr = append(valueArr, v)
	}
	insertSql := fmt.Sprintf("INSERT INTO %s SET %s", tableName, strings.Join(fieldArr, ","))

	var stmt *sql.Stmt
	var err error
	if tx := d.TX(); tx != nil {
		stmt, err = tx.Prepare(insertSql)
	} else {
		stmt, err = d.DB().Prepare(insertSql)
	}
	if err != nil {
		return 0, err
	}
	defer stmt.Close()

	ret, err := stmt.Exec(valueArr...)
	if err != nil {
		return 0, err
	}
	return ret.LastInsertId()
}
// ExecuteUpdate runs "UPDATE tableName SET col=?,... WHERE ..." and returns
// the number of affected rows as reported by the driver.
//
// fields maps column name -> new value. Column names are backtick-quoted and
// the values are bound as parameters. whereArr entries are joined with AND
// and spliced verbatim (trusted input only). whereArgs are bound after the
// SET values, in order. The statement runs inside d's transaction when one
// is active, otherwise on the connection pool.
//
// WARNING: an empty whereArr produces an UPDATE with no WHERE clause, which
// modifies every row of the table. This is kept for backward compatibility.
// An empty fields map is rejected because it would produce invalid SQL.
func (d *DBMysql) ExecuteUpdate(tableName string, fields map[string]interface{}, whereArr []string, whereArgs []interface{}) (int64, error) {
	if len(fields) == 0 {
		return 0, fmt.Errorf("update %s: no fields given", tableName)
	}

	fieldArr := make([]string, 0, len(fields))
	valueArr := make([]interface{}, 0, len(fields)+len(whereArgs))
	for k, v := range fields {
		fieldArr = append(fieldArr, fmt.Sprintf("`%s`=?", k))
		valueArr = append(valueArr, v)
	}
	// SET placeholders precede the WHERE placeholders in the statement, so
	// their values must come first in the argument list.
	valueArr = append(valueArr, whereArgs...)

	whereStr := ""
	if len(whereArr) > 0 {
		whereStr = "WHERE " + strings.Join(whereArr, " AND ")
	}
	updateSql := fmt.Sprintf("UPDATE %s SET %s %s", tableName, strings.Join(fieldArr, ","), whereStr)

	var stmt *sql.Stmt
	var err error
	if tx := d.TX(); tx != nil {
		stmt, err = tx.Prepare(updateSql)
	} else {
		stmt, err = d.DB().Prepare(updateSql)
	}
	if err != nil {
		return 0, err
	}
	defer stmt.Close()

	ret, err := stmt.Exec(valueArr...)
	if err != nil {
		return 0, err
	}
	return ret.RowsAffected()
}
// BeginTrans starts a transaction on d's connection pool. It returns a new
// handle that carries only the *sql.Tx (its DB() is nil). Execute* calls on
// that handle run inside the transaction, which is finished with CommitTrans
// or RollbackTrans.
//
// Calling BeginTrans on a handle without a pool, such as one returned by
// BeginTrans itself, returns an error instead of panicking.
func (d *DBMysql) BeginTrans() (IDBMysql, error) {
	if d.Edb == nil {
		// d.Edb.Begin() would dereference a nil *sql.DB and panic.
		return nil, fmt.Errorf("begin trans: no db connection pool on this handle")
	}
	tx, err := d.Edb.Begin()
	if err != nil {
		return nil, err
	}
	return &DBMysql{
		Edb: nil,
		Etx: tx,
	}, nil
}
// CommitTrans commits the transaction held by d. It returns an error when
// d holds no transaction (Etx is nil), e.g. on a pool handle.
func (d *DBMysql) CommitTrans() error {
	tx := d.Etx
	if tx != nil {
		return tx.Commit()
	}
	return fmt.Errorf("not begin trans")
}
// RollbackTrans aborts the transaction held by d. It returns an error when
// d holds no transaction (Etx is nil), e.g. on a pool handle.
func (d *DBMysql) RollbackTrans() error {
	tx := d.Etx
	if tx != nil {
		return tx.Rollback()
	}
	return fmt.Errorf("not begin trans")
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Go
1
https://gitee.com/scottq/go-framework.git
git@gitee.com:scottq/go-framework.git
scottq
go-framework
go-framework
v1.1.25

搜索帮助