代码拉取完成,页面将自动刷新
package util
import (
"context"
"database/sql"
"fmt"
"math"
"strings"
"sync"
"time"
)
// MysqlClient lazily opens and caches one *sql.DB per connection option and
// wraps common insert/select/update/delete/transaction patterns with
// per-operation timeouts.
//
// ConnOptions[0] is used for writes. When more than one option is configured,
// GetDB(ctx, true) picks a pool from the remaining options (via RandRange when
// there are more than two — NOTE(review): confirm RandRange's upper bound is
// exclusive).
type MysqlClient struct {
// Listener, if non-nil, is started in a new goroutine by most helpers after a
// statement runs, receiving the query text, the resulting error and
// helper-specific extra values (parameters, result counts, ids).
Listener func(ctx context.Context, query string, err error, args ...any)
// dbMap caches opened pools keyed by "addr:dbname"; GetDB and CloseDB access it under mutex.
dbMap map[string]*sql.DB
// version caches the server version string filled in by Version.
version string
// ConnOptions lists connection options; index 0 is the primary (write) connection.
ConnOptions []*MysqlConnOption
// TimeoutInsert bounds each insert statement (InsertBatch uses 3x this for the whole batch).
TimeoutInsert time.Duration
// TimeoutSelect bounds read queries and the initial Ping performed by GetDB.
TimeoutSelect time.Duration
// TimeoutUpdate bounds UPDATE/DELETE statements issued via Update, and TRUNCATE.
TimeoutUpdate time.Duration
// TimeoutTrans bounds an entire Transaction call (BeginTx through Commit).
TimeoutTrans time.Duration
// mutex guards dbMap.
mutex sync.RWMutex
}
// NewMysqlClient builds a client over the given connection options. The first
// option is the primary (write) connection; further options are only used when
// GetDB is asked for a slave connection. Timeouts default to 3s for
// insert/select/update and 5s for transactions. No connection is opened here;
// GetDB opens pools lazily. It returns ErrMysqlWrongOption when no option is
// supplied.
func NewMysqlClient(connOptions ...*MysqlConnOption) (*MysqlClient, error) {
	if len(connOptions) < 1 {
		return nil, ErrMysqlWrongOption
	}
	const defaultTimeout = 3 * time.Second
	client := &MysqlClient{
		ConnOptions:   connOptions,
		dbMap:         make(map[string]*sql.DB),
		TimeoutInsert: defaultTimeout,
		TimeoutSelect: defaultTimeout,
		TimeoutUpdate: defaultTimeout,
		TimeoutTrans:  5 * time.Second,
	}
	return client, nil
}
// NewMysqlClientEasy builds a client directly from DSN strings. The first DSN
// becomes the primary connection and the rest are the slave candidates.
//
// Each DSN is turned into an option with NewMysqlConnOption(dsn, 0, 0, -1, -1).
// NOTE(review): the meaning of those fixed arguments is defined by
// NewMysqlConnOption — presumably pool limits; confirm there.
//
// It returns ErrMysqlWrongDsn when no DSN is given, and the first error from
// NewMysqlConnOption otherwise.
func NewMysqlClientEasy(dsn ...string) (*MysqlClient, error) {
	if len(dsn) == 0 {
		return nil, ErrMysqlWrongDsn
	}
	options := make([]*MysqlConnOption, 0, len(dsn))
	for i := range dsn {
		opt, err := NewMysqlConnOption(dsn[i], 0, 0, -1, -1)
		if err != nil {
			return nil, err
		}
		options = append(options, opt)
	}
	return NewMysqlClient(options...)
}
// GetDB returns the cached connection pool for the primary option
// (useSlave == false) or for one of the other options (useSlave == true).
// With a single configured option, useSlave is ignored.
//
// On first use the pool is opened with sql.Open("mysql", option.Dsn).
// The option's pool settings (MaxIdle, MaxOpen, MaxIdleTime, MaxLifetime) are
// applied, and the pool is verified with a Ping bounded by TimeoutSelect before
// being cached under "addr:dbname". A pool that fails its Ping is closed rather
// than leaked.
//
// It returns ErrMysqlWrongOption when no options are configured or the
// selected option is nil or lacks Addr, DbName or Dsn.
func (c *MysqlClient) GetDB(ctx context.Context, useSlave bool) (*sql.DB, error) {
	if len(c.ConnOptions) == 0 {
		return nil, ErrMysqlWrongOption
	}
	var option *MysqlConnOption
	switch {
	case !useSlave || len(c.ConnOptions) == 1:
		option = c.ConnOptions[0]
	case len(c.ConnOptions) == 2:
		option = c.ConnOptions[1]
	default:
		option = c.ConnOptions[RandRange(1, len(c.ConnOptions))]
	}
	if option == nil || option.Addr == "" || option.DbName == "" || option.Dsn == "" {
		return nil, ErrMysqlWrongOption
	}
	key := option.Addr + ":" + option.DbName

	// Fast path: the pool is already open.
	c.mutex.RLock()
	db := c.dbMap[key]
	c.mutex.RUnlock()
	if db != nil {
		return db, nil
	}

	c.mutex.Lock()
	defer c.mutex.Unlock()
	// Re-check: another goroutine may have opened the pool while we waited for the write lock.
	if db = c.dbMap[key]; db != nil {
		return db, nil
	}
	db, err := sql.Open("mysql", option.Dsn)
	if err != nil {
		return nil, err
	}
	// Configure the pool before the first connection is established by Ping.
	db.SetMaxIdleConns(option.MaxIdle)
	db.SetMaxOpenConns(option.MaxOpen)
	db.SetConnMaxIdleTime(option.MaxIdleTime)
	db.SetConnMaxLifetime(option.MaxLifetime)

	pingCtx, cancel := context.WithTimeout(ctx, c.TimeoutSelect)
	defer cancel()
	if err = db.PingContext(pingCtx); err != nil {
		// The pool is unusable and not cached; close it so its resources are released.
		_ = db.Close()
		return nil, err
	}
	if c.dbMap == nil {
		// Zero-value client (not built by NewMysqlClient): writing to a nil map would panic.
		c.dbMap = make(map[string]*sql.DB)
	}
	c.dbMap[key] = db
	return db, nil
}
// CloseDB closes every cached connection pool and empties the cache, so later
// GetDB calls reopen pools on demand. Close errors are deliberately ignored:
// this is best-effort teardown.
func (c *MysqlClient) CloseDB() {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	for key, pool := range c.dbMap {
		_ = pool.Close()
		// Deleting the current key while ranging over a map is permitted in Go.
		delete(c.dbMap, key)
	}
}
// Truncate runs TRUNCATE TABLE on table using the primary connection, bounded
// by TimeoutUpdate.
//
// The table name is placed inside backticks without escaping, so it must come
// from trusted code.
//
// Once the connection is obtained, a configured Listener is started in a
// goroutine with the query, the execution error and the table name. It is not
// called if GetDB fails.
func (c *MysqlClient) Truncate(ctx context.Context, table string) error {
	db, err := c.GetDB(ctx, false)
	if err != nil {
		return err
	}
	query := "TRUNCATE TABLE `" + table + "`"
	defer func() {
		if notify := c.Listener; notify != nil {
			go notify(ctx, query, err, table)
		}
	}()
	execCtx, cancel := context.WithTimeout(ctx, c.TimeoutUpdate)
	defer cancel()
	_, err = db.ExecContext(execCtx, query)
	return err
}
// Version returns the server version reported by `SELECT VERSION()` and caches
// it on the client. The query goes through SelectWalk, so it may run on a
// slave connection.
//
// The cached value is read and written under c.mutex so concurrent callers do
// not race. The lock is not held while the query runs, because GetDB acquires
// the same mutex. If the query yields no row, "" is returned and nothing is
// cached.
func (c *MysqlClient) Version(ctx context.Context) (string, error) {
	c.mutex.RLock()
	cached := c.version
	c.mutex.RUnlock()
	if cached != "" {
		return cached, nil
	}
	var ver string
	err := c.SelectWalk(ctx, func(_ context.Context, row *MysqlRow) error {
		if ver == "" {
			ver = row.ToStr("ver")
		}
		return nil
	}, "SELECT VERSION() AS `ver`")
	if err != nil {
		return "", err
	}
	c.mutex.Lock()
	defer c.mutex.Unlock()
	// Keep the first cached value if another goroutine stored one meanwhile.
	if c.version == "" {
		c.version = ver
	}
	return c.version, nil
}
// doInsert builds and executes a single-row INSERT, INSERT IGNORE (ignore) or
// REPLACE (replace; ignored when ignore is also set) on the primary connection.
// Columns and values come from row. The statement is bounded by TimeoutInsert.
// It returns the statement's LastInsertId.
//
// Validation errors are ErrMysqlEmptyData (empty row) and ErrMysqlColWithBQuote
// (a column name containing a backtick). The table name is not escaped.
//
// After validation, a configured Listener is started in a goroutine with the
// query, the final error, the returned id and the bound parameters. This
// includes the case where GetDB or the execution fails.
func (c *MysqlClient) doInsert(ctx context.Context, table string, row *MysqlRow, ignore, replace bool) (ret int64, err error) {
	if row.IsEmpty() {
		return 0, ErrMysqlEmptyData
	}
	cols := make([]string, 0)
	phds := make([]string, 0)
	params := make([]any, 0)
	for name, value := range *row {
		if strings.Contains(name, "`") {
			return 0, ErrMysqlColWithBQuote
		}
		cols = append(cols, name)
		params = append(params, value)
		phds = append(phds, "?")
	}

	var verb string
	switch {
	case ignore:
		verb = "INSERT IGNORE "
	case replace:
		verb = "REPLACE "
	default:
		verb = "INSERT "
	}
	query := verb + "INTO `" + table + "` (`" + strings.Join(cols, "`, `") + "`) VALUES (" + strings.Join(phds, ", ") + ")"

	defer func() {
		if c.Listener != nil {
			go c.Listener(ctx, query, err, ret, params)
		}
	}()

	var db *sql.DB
	if db, err = c.GetDB(ctx, false); err != nil {
		return 0, err
	}
	execCtx, cancel := context.WithTimeout(ctx, c.TimeoutInsert)
	defer cancel()
	var res sql.Result
	if res, err = db.ExecContext(execCtx, query, params...); err != nil {
		return 0, err
	}
	return res.LastInsertId()
}
// Insert inserts row into table with a plain INSERT and returns LastInsertId.
// Validation, execution and Listener notification are performed by doInsert.
func (c *MysqlClient) Insert(ctx context.Context, table string, row *MysqlRow) (ret int64, err error) {
	const ignore, replace = false, false
	return c.doInsert(ctx, table, row, ignore, replace)
}

// InsertIgnore inserts row into table with INSERT IGNORE and returns LastInsertId.
// Validation, execution and Listener notification are performed by doInsert.
func (c *MysqlClient) InsertIgnore(ctx context.Context, table string, row *MysqlRow) (ret int64, err error) {
	const ignore, replace = true, false
	return c.doInsert(ctx, table, row, ignore, replace)
}

// InsertReplace writes row into table with REPLACE and returns LastInsertId.
// Validation, execution and Listener notification are performed by doInsert.
func (c *MysqlClient) InsertReplace(ctx context.Context, table string, row *MysqlRow) (ret int64, err error) {
	const ignore, replace = false, true
	return c.doInsert(ctx, table, row, ignore, replace)
}
// InsertDuplicate inserts row into table on the primary connection. On a
// duplicate-key conflict it instead sets every supplied column to the supplied
// value (INSERT ... ON DUPLICATE KEY UPDATE). Each value is therefore bound
// twice: once for VALUES and once for the UPDATE list. The statement is bounded
// by TimeoutInsert and its LastInsertId is returned.
//
// Validation errors are ErrMysqlEmptyData and ErrMysqlColWithBQuote. The table
// name is not escaped.
//
// After validation, a configured Listener is started in a goroutine with the
// query, the final error, the returned id and the (doubled) parameters.
func (c *MysqlClient) InsertDuplicate(ctx context.Context, table string, row *MysqlRow) (ret int64, err error) {
	if row.IsEmpty() {
		return 0, ErrMysqlEmptyData
	}
	cols := make([]string, 0)
	phds := make([]string, 0)
	sets := make([]string, 0)
	values := make([]any, 0)
	for name, value := range *row {
		if strings.Contains(name, "`") {
			return 0, ErrMysqlColWithBQuote
		}
		cols = append(cols, name)
		values = append(values, value)
		phds = append(phds, "?")
		sets = append(sets, "`"+name+"` = ?")
	}
	// The same values feed the VALUES list and then the UPDATE assignments, in the same column order.
	params := append(values, values...)
	query := "INSERT INTO `" + table + "` (`" + strings.Join(cols, "`, `") + "`) VALUES (" +
		strings.Join(phds, ", ") + ") ON DUPLICATE KEY UPDATE " + strings.Join(sets, ", ")

	defer func() {
		if c.Listener != nil {
			go c.Listener(ctx, query, err, ret, params)
		}
	}()

	var db *sql.DB
	if db, err = c.GetDB(ctx, false); err != nil {
		return 0, err
	}
	execCtx, cancel := context.WithTimeout(ctx, c.TimeoutInsert)
	defer cancel()
	var res sql.Result
	if res, err = db.ExecContext(execCtx, query, params...); err != nil {
		return 0, err
	}
	return res.LastInsertId()
}
// InsertBatch inserts rows into table one statement at a time, using a single
// prepared INSERT IGNORE on the primary connection. It returns the total
// number of affected rows.
//
// The column list is taken from rows[0]. For every row, the values of those
// columns are read with row.Get, so extra columns in later rows are ignored.
// Preparation and all executions share one timeout of 3*TimeoutInsert.
//
// Rows are not wrapped in a transaction. On error, rows inserted so far stay
// inserted and the returned count includes them.
//
// A configured Listener is started in a goroutine once per executed row, with
// the query, an error, the row's LastInsertId (0 when the insert failed) and
// the row's parameters. The error is the execution error for a failed row, or
// the LastInsertId error for a successful one.
//
// Validation errors are ErrMysqlEmptyData and ErrMysqlColWithBQuote. The table
// name is not escaped.
func (c *MysqlClient) InsertBatch(ctx context.Context, table string, rows []*MysqlRow) (ret int64, err error) {
	if len(rows) == 0 || rows[0].IsEmpty() {
		return 0, ErrMysqlEmptyData
	}
	cols := make([]string, 0)
	phds := make([]string, 0)
	for col := range *(rows[0]) {
		if strings.Contains(col, "`") {
			return 0, ErrMysqlColWithBQuote
		}
		cols = append(cols, col)
		phds = append(phds, "?")
	}
	query := "INSERT IGNORE INTO `" + table + "` (`" + strings.Join(cols, "`, `") + "`) VALUES (" + strings.Join(phds, ", ") + ")"

	var db *sql.DB
	if db, err = c.GetDB(ctx, false); err != nil {
		return 0, err
	}
	ctx1, cancel := context.WithTimeout(ctx, 3*c.TimeoutInsert)
	defer cancel()
	var stmt *sql.Stmt
	if stmt, err = db.PrepareContext(ctx1, query); err != nil {
		return 0, err
	}
	defer func() {
		_ = stmt.Close()
	}()

	for _, row := range rows {
		params := make([]any, 0, len(cols))
		for _, col := range cols {
			params = append(params, row.Get(col))
		}
		var res sql.Result
		if res, err = stmt.ExecContext(ctx1, params...); err != nil {
			if c.Listener != nil {
				go c.Listener(ctx, query, err, int64(0), params)
			}
			return ret, err
		}
		var affected int64
		if affected, err = res.RowsAffected(); err != nil {
			return ret, err
		}
		ret += affected
		if c.Listener != nil {
			// Keep the LastInsertId error local. It is only reported to the
			// listener and must not change InsertBatch's own result.
			newID, idErr := res.LastInsertId()
			go c.Listener(ctx, query, idErr, newID, params)
		}
	}
	return ret, nil
}
// SelectWalk runs a read query on a slave connection (the primary when only
// one option is configured) and calls fn once per result row, in order.
//
// The query must be at least 8 bytes long and start with "select " or "show "
// (case-insensitive); otherwise ErrMysqlWrongSql is returned.
//
// Every column value is delivered as a string via row.Set. A SQL NULL scans to
// a nil sql.RawBytes and therefore becomes "". Iteration stops at the first
// error from fn, and that error is returned. Errors raised while iterating the
// result set (reported by rows.Err) are also returned, so a truncated result is
// never mistaken for a complete one. The query and all fn calls share the
// TimeoutSelect deadline.
//
// Once the connection is obtained, a configured Listener is started in a
// goroutine with the query, the final error and params.
func (c *MysqlClient) SelectWalk(
	ctx context.Context,
	fn func(ctx context.Context, row *MysqlRow) error,
	query string,
	params ...any,
) (err error) {
	// Check the length before slicing so short queries are rejected instead of panicking.
	if len(query) < 8 {
		return ErrMysqlWrongSql
	}
	if head := strings.ToLower(query[:7]); !strings.HasPrefix(head, "select ") && !strings.HasPrefix(head, "show ") {
		return ErrMysqlWrongSql
	}
	var db *sql.DB
	if db, err = c.GetDB(ctx, true); err != nil {
		return err
	}
	defer func() {
		if c.Listener != nil {
			go c.Listener(ctx, query, err, params)
		}
	}()
	ctx1, cancel := context.WithTimeout(ctx, c.TimeoutSelect)
	defer cancel()
	var rows *sql.Rows
	if rows, err = db.QueryContext(ctx1, query, params...); err != nil {
		return err
	}
	defer func() {
		_ = rows.Close()
	}()
	var cols []string
	if cols, err = rows.Columns(); err != nil {
		return err
	}
	// Scan targets are reused across rows: RawBytes are only valid until the
	// next Scan, and each value is copied out with string() before that.
	rawBuffers := make([]sql.RawBytes, len(cols))
	scanArgs := make([]any, len(cols))
	for i := range rawBuffers {
		scanArgs[i] = &rawBuffers[i]
	}
	for rows.Next() {
		if err = rows.Scan(scanArgs...); err != nil {
			return err
		}
		row := NewMysqlRow()
		for i, bs := range rawBuffers {
			row.Set(cols[i], string(bs))
		}
		if err = fn(ctx, row); err != nil {
			return err
		}
	}
	// Next returns false both at the end of the results and on failure; Err tells them apart.
	err = rows.Err()
	return err
}
// ShowColumns returns a map from each column name of table to its declared
// type, read from `SHOW COLUMNS` (through SelectWalk, so possibly on a slave).
// The table name is placed in backticks without escaping. On error the map is
// nil.
func (c *MysqlClient) ShowColumns(ctx context.Context, table string) (map[string]string, error) {
	columns := map[string]string{}
	collect := func(_ context.Context, row *MysqlRow) error {
		columns[row.ToStr("Field")] = row.ToStr("Type")
		return nil
	}
	query := fmt.Sprintf("SHOW COLUMNS FROM `%s`", table)
	if err := c.SelectWalk(ctx, collect, query); err != nil {
		return nil, err
	}
	return columns, nil
}
// Select runs query through SelectWalk and collects every row in order. When
// an error occurs, the rows gathered before it are returned together with the
// error.
func (c *MysqlClient) Select(ctx context.Context, query string, params ...any) (rows []*MysqlRow, err error) {
	collect := func(_ context.Context, row *MysqlRow) error {
		rows = append(rows, row)
		return nil
	}
	err = c.SelectWalk(ctx, collect, query, params...)
	return rows, err
}
// Count runs a single-value counting query on a slave connection and scans the
// result into cnt, bounded by TimeoutSelect.
//
// The query must start with "select count(" (case-insensitive) and be at least
// 14 bytes long; otherwise ErrMysqlWrongSql is returned (short queries no
// longer panic).
//
// Once the connection is obtained, a configured Listener is started in a
// goroutine with the query, the final error, cnt and params.
func (c *MysqlClient) Count(ctx context.Context, query string, params ...any) (cnt int64, err error) {
	const prefix = "select count("
	// Check the length before slicing so short queries are rejected instead of panicking.
	if len(query) < len(prefix)+1 || strings.ToLower(query[:len(prefix)]) != prefix {
		return 0, ErrMysqlWrongSql
	}
	var db *sql.DB
	if db, err = c.GetDB(ctx, true); err != nil {
		return 0, err
	}
	defer func() {
		if c.Listener != nil {
			go c.Listener(ctx, query, err, cnt, params)
		}
	}()
	ctx1, cancel := context.WithTimeout(ctx, c.TimeoutSelect)
	defer cancel()
	err = db.QueryRowContext(ctx1, query, params...).Scan(&cnt)
	return cnt, err
}
// SelectPage counts the rows of table matching where, then walks one page of
// them with fn.
//
// Paging rules:
//   - An empty where means "1 = 1". An empty cols means "*".
//   - A non-empty order is prefixed with "ORDER BY ".
//   - size < 1 becomes 10, and page < 1 becomes 1.
//
// Security: where, order and cols are inserted into the SQL verbatim and must
// come from trusted code. Pass values through params, which are bound to the
// placeholders in where.
//
// Return values:
//   - When nothing matches, all results are zero and err is nil.
//   - When page exceeds totalPages, ErrMysqlPageTooLarge is returned with
//     totalRows, totalPages and currentPage (= page) filled in.
//
// The LIMIT offset and size are appended to a copy of params. The caller's
// variadic backing array is never written to.
func (c *MysqlClient) SelectPage(
	ctx context.Context,
	fn func(ctx context.Context, row *MysqlRow) error,
	page, size int64,
	table, where, order, cols string,
	params ...any,
) (totalRows, totalPages, currentPage int64, err error) {
	if where == "" {
		where = "1 = 1"
	}
	totalRows, err = c.Count(ctx, fmt.Sprintf("SELECT COUNT(*) FROM `%s` WHERE %s", table, where), params...)
	if err != nil || totalRows == 0 {
		return totalRows, 0, 0, err
	}
	if size < 1 {
		size = 10
	}
	totalPages = int64(math.Ceil(float64(totalRows) / float64(size)))
	if page < 1 {
		page = 1
	}
	currentPage = page
	if page > totalPages {
		return totalRows, totalPages, currentPage, ErrMysqlPageTooLarge
	}
	if cols == "" {
		cols = "*"
	}
	if order != "" {
		order = "ORDER BY " + order
	}
	// The full slice expression caps capacity at len, so append always allocates
	// a fresh array. The caller's spare capacity is never overwritten.
	pageParams := append(params[:len(params):len(params)], (page-1)*size, size)
	query := fmt.Sprintf("SELECT %s FROM `%s` WHERE %s %s LIMIT ?, ?", cols, table, where, order)
	err = c.SelectWalk(ctx, fn, query, pageParams...)
	return totalRows, totalPages, currentPage, err
}
// SelectByIds walks the rows of table whose `id` equals one of ids, calling fn
// for each. It uses `= ?` for a single id and `IN (...)` for several, and
// selects cols (default "*"). It returns ErrMysqlEmptyData when ids is empty.
// cols and table are inserted into the SQL unescaped.
func (c *MysqlClient) SelectByIds(
	ctx context.Context,
	fn func(ctx context.Context, row *MysqlRow) error,
	table, cols string,
	ids ...any,
) error {
	if len(ids) == 0 {
		return ErrMysqlEmptyData
	}
	if cols == "" {
		cols = "*"
	}
	cond := "`id` = ?"
	if len(ids) > 1 {
		cond = "`id` IN (" + strings.TrimRight(strings.Repeat(" ?,", len(ids)), ",") + ")"
	}
	query := fmt.Sprintf("SELECT %s FROM `%s` WHERE %s", cols, table, cond)
	return c.SelectWalk(ctx, fn, query, ids...)
}
// Update executes an UPDATE or DELETE statement on the primary connection,
// bounded by TimeoutUpdate, and returns the number of affected rows.
//
// The query must be at least 8 bytes long and start with "update " or
// "delete " (case-insensitive); otherwise ErrMysqlWrongSql is returned.
// Short queries are rejected instead of panicking.
//
// Once the connection is obtained, a configured Listener is started in a
// goroutine with the query, the final error, the affected-row count and
// params.
func (c *MysqlClient) Update(ctx context.Context, query string, params ...any) (ret int64, err error) {
	// Check the length before slicing so short queries cannot cause an out-of-range panic.
	if len(query) < 8 {
		return 0, ErrMysqlWrongSql
	}
	if head := strings.ToLower(query[:7]); head != "update " && head != "delete " {
		return 0, ErrMysqlWrongSql
	}
	var db *sql.DB
	if db, err = c.GetDB(ctx, false); err != nil {
		return 0, err
	}
	defer func() {
		if c.Listener != nil {
			go c.Listener(ctx, query, err, ret, params)
		}
	}()
	ctx1, cancel := context.WithTimeout(ctx, c.TimeoutUpdate)
	defer cancel()
	var res sql.Result
	if res, err = db.ExecContext(ctx1, query, params...); err != nil {
		return 0, err
	}
	return res.RowsAffected()
}
// UpdateById sets the columns of row on the record of table whose `id` equals
// id, via Update, and returns the number of affected rows.
//
// Side effect: row.Drop("id") is called on the caller's row, which is modified
// in place, presumably removing the "id" key so the primary key is never
// SET. Confirm against MysqlRow.Drop.
//
// It returns ErrMysqlEmptyData when row is empty (before or after the drop)
// and ErrMysqlColWithBQuote for a column name containing a backtick.
func (c *MysqlClient) UpdateById(ctx context.Context, table string, row *MysqlRow, id any) (int64, error) {
	if row.IsEmpty() {
		return 0, ErrMysqlEmptyData
	}
	row.Drop("id")
	var assignments []string
	var args []any
	for col, val := range *row {
		if strings.Contains(col, "`") {
			return 0, ErrMysqlColWithBQuote
		}
		assignments = append(assignments, "`"+col+"` = ?")
		args = append(args, val)
	}
	if len(assignments) == 0 {
		return 0, ErrMysqlEmptyData
	}
	args = append(args, id)
	query := "UPDATE `" + table + "` SET " + strings.Join(assignments, ", ") + " WHERE `id` = ?"
	return c.Update(ctx, query, args...)
}
// Delete executes a DELETE statement through Update and returns the number of
// affected rows. The query must be at least 8 bytes long and start with
// "delete " (case-insensitive); otherwise ErrMysqlWrongSql is returned.
// Short queries are rejected instead of panicking.
func (c *MysqlClient) Delete(ctx context.Context, query string, params ...any) (int64, error) {
	// The length check must come first: slicing a shorter string would panic.
	if len(query) < 8 || strings.ToLower(query[:7]) != "delete " {
		return 0, ErrMysqlWrongSql
	}
	return c.Update(ctx, query, params...)
}
// DeleteByIds deletes the rows of table whose `id` equals one of ids and
// returns the number of affected rows. It uses `= ?` for a single id and
// `IN (...)` for several, and returns ErrMysqlEmptyData when ids is empty.
// The table name is placed in backticks without escaping.
func (c *MysqlClient) DeleteByIds(ctx context.Context, table string, ids ...any) (int64, error) {
	if len(ids) == 0 {
		return 0, ErrMysqlEmptyData
	}
	where := "`id` = ?"
	if len(ids) > 1 {
		where = "`id` IN (" + strings.TrimRight(strings.Repeat(" ?,", len(ids)), ",") + ")"
	}
	return c.Delete(ctx, "DELETE FROM `"+table+"` WHERE "+where, ids...)
}
// Transaction runs fn inside a transaction on the primary connection.
//
// Timeout: BeginTx receives a context bounded by TimeoutTrans. Per
// database/sql, that context governs the transaction until it is committed or
// rolled back, and the transaction is rolled back if the context is canceled.
// fn itself receives the caller's ctx.
//
// Outcome:
//   - If fn returns an error, the transaction is rolled back and that error is
//     returned. Otherwise the Commit result is returned.
//   - If fn panics, the transaction is rolled back and the panic is re-raised.
//     Previously the panic was swallowed and Transaction returned nil, so a
//     crashed transaction looked like a success.
func (c *MysqlClient) Transaction(
	ctx context.Context,
	fn func(ctx context.Context, tx *sql.Tx) error,
	opts *sql.TxOptions,
) error {
	db, err := c.GetDB(ctx, false)
	if err != nil {
		return err
	}
	ctx1, cancel := context.WithTimeout(ctx, c.TimeoutTrans)
	defer cancel()
	tx, err := db.BeginTx(ctx1, opts)
	if err != nil {
		return err
	}
	defer func() {
		if r := recover(); r != nil {
			_ = tx.Rollback()
			panic(r)
		}
	}()
	if err = fn(ctx, tx); err != nil {
		_ = tx.Rollback()
		return err
	}
	// Commit always ends the transaction, even when it fails, so no rollback follows it.
	return tx.Commit()
}
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。