1 Star 0 Fork 0

寻根 / goweb

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
dbx.go 9.28 KB
一键复制 编辑 原始数据 按行查看 历史
lemonzheng(郑刚) 提交于 2024-04-12 17:41 . 代码优化
package dbx
import (
	"database/sql"
	"errors"
	"fmt"
	"reflect"
	"strings"
	"sync"
	"time"

	"gitee.com/xungen/goweb/logx"
	"gitee.com/xungen/goweb/utils"
)
// Config holds the settings Open uses to create a database handle.
type Config struct {
// Port is the TCP port written into the data source name for non-sqlite types.
Port int
// Host is the server address written into the data source name for non-sqlite types.
Host string
// Type is passed to sql.Open as the driver name; a "sqlite" prefix selects the sqlite branch in Open.
Type string
// Name is the database name placed after the "/" in the data source name.
Name string
// User is the login name written into the data source name.
User string
// Charset, when non-empty, is appended to the data source name as "?charset=<value>".
Charset string
// Password is the login password written into the data source name.
Password string
}
// DBConnect bundles a database handle, an optional transaction and a logger.
type DBConnect struct {
// Logger is embedded; trace reports executed statements through its Debug and Error methods.
logx.Logger
// db is the underlying handle; Close closes it and resets it to nil.
db *sql.DB
// tx is set on the connection ExecBatch passes to its callback; Exec runs on it when it is non-nil.
tx *sql.Tx
}
// Keys is the list of field names Update turns into its WHERE conditions.
type Keys = []string
// Exclude is a list of field names that Insert and Update leave out of the statement.
type Exclude = []string
// poolmap stores *DBConnect values by name; Set writes to it, Get reads from it and Delete removes entries.
var poolmap sync.Map
// placeholderlist is "?," repeated 1000 times; placeholder slices its result out of this string.
var placeholderlist = strings.Repeat("?,", 1000)
// excludetimefields is the base list that ExcludeTimeFields copies and extends.
var excludetimefields = Exclude{"AddAt", "CreateAt", "UpdateAt", "AddTime", "CreateTime", "UpdateTime"}
func Clear(name string) {
if name == "" {
poolmap.Range(func(key, value interface{}) bool {
Delete(key.(string))
return true
})
} else {
Delete(name)
}
}
// Delete removes the named connection from the pool. If an entry existed, it
// is closed from a background goroutine one minute after removal; nothing
// happens when the name is not pooled.
func Delete(name string) {
	entry, found := poolmap.LoadAndDelete(name)
	if !found {
		return
	}
	go func() {
		// Delay the close; presumably this lets callers that already obtained
		// the handle finish their work - confirm with the pool's users.
		time.Sleep(time.Minute)
		entry.(*DBConnect).Close()
	}()
}
// Name derives a name from the dynamic type of data: it takes the type's
// string form, keeps what follows the last "." (or strips leading "*" when
// there is no "."), and passes the result through format.
func Name(data interface{}) string {
	full := reflect.TypeOf(data).String()
	dot := strings.LastIndex(full, ".")
	if dot < 0 {
		// No package qualifier: only pointer markers need to be removed.
		return format(strings.TrimLeft(full, "*"))
	}
	return format(full[dot+1:])
}
// NewConnect wraps an existing *sql.DB in a DBConnect that logs through
// logx.Instance().
func NewConnect(db *sql.DB) *DBConnect {
	conn := &DBConnect{db: db}
	conn.Logger = logx.Instance()
	return conn
}
// Open builds a data source name from cfg, opens it with sql.Open using
// cfg.Type as the driver name, and wraps the handle with NewConnect.
//
// For types not starting with "sqlite" the data source name has the form
// "user:password@tcp(host:port)/name", with "?charset=<Charset>" appended when
// Charset is set. Any error from sql.Open is returned unchanged.
func Open(cfg *Config) (*DBConnect, error) {
	// NOTE(review): for sqlite types the data source name is the fixed string
	// "sqlite3" and cfg.Name is not used - confirm this is intended.
	source := "sqlite3"
	if !strings.HasPrefix(cfg.Type, "sqlite") {
		source = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Name)
		if cfg.Charset != "" {
			source += "?charset=" + cfg.Charset
		}
	}
	handle, err := sql.Open(cfg.Type, source)
	if err != nil {
		return nil, err
	}
	return NewConnect(handle), nil
}
// ExcludeTimeFields returns a new Exclude list holding the entries of
// excludetimefields followed by args. The package-level list itself is not
// modified.
func ExcludeTimeFields(args ...string) Exclude {
	merged := make(Exclude, 0, len(excludetimefields)+len(args))
	merged = append(merged, excludetimefields...)
	merged = append(merged, args...)
	return merged
}
// Get looks up the pooled connection whose key is utils.Sprint(name...) and
// returns a fresh DBConnect sharing its database handle. It returns nil when
// no such entry exists.
//
// The returned connection logs through logger; when logger is nil the pooled
// entry's logger is used, and logx.Instance() when that is nil too.
func Get(logger logx.Logger, name ...interface{}) *DBConnect {
	entry, found := poolmap.Load(utils.Sprint(name...))
	if !found {
		return nil
	}
	pooled := entry.(*DBConnect)
	if logger == nil {
		logger = pooled.Logger
	}
	if logger == nil {
		logger = logx.Instance()
	}
	return &DBConnect{Logger: logger, db: pooled.db}
}
// Set assigns logger to pool, stores pool in the pool map under name, and
// returns the result of Get(logger, name) for the stored entry.
func Set(name string, pool *DBConnect, logger logx.Logger) *DBConnect {
	// The logger must be attached before the entry becomes visible in the map.
	pool.Logger = logger
	poolmap.Store(name, pool)

	return Get(logger, name)
}
// Close closes the underlying database handle and clears both the handle and
// the transaction reference. It does nothing when the handle is already nil;
// the error returned by the handle's Close is not reported.
func (db *DBConnect) Close() {
	if db.db != nil {
		db.db.Close()
		db.db, db.tx = nil, nil
	}
}
// DB exposes the underlying *sql.DB handle held by the connection; it is nil
// after Close.
func (db DBConnect) DB() *sql.DB {
	handle := db.db
	return handle
}
// EmptyResultError reports whether err is sql.ErrNoRows, either directly or
// anywhere in its wrap chain (for example an error built with fmt.Errorf and
// %w). A nil err yields false.
func EmptyResultError(err error) bool {
	return errors.Is(err, sql.ErrNoRows)
}
// Query runs sqlcmd through QueryList and returns only the first row as a
// column-name to string map. It returns (nil, nil) when the statement yields
// no rows and (nil, err) when QueryList fails.
func (db DBConnect) Query(sqlcmd string, args ...interface{}) (map[string]string, error) {
	rows, err := db.QueryList(sqlcmd, args...)
	switch {
	case err != nil:
		return nil, err
	case len(rows) == 0:
		return nil, nil
	default:
		return rows[0], nil
	}
}
// QueryList runs sqlcmd with args and returns every row as a map from column
// name to the column's string value (NULL columns become "").
//
// The statement runs on the transaction when the connection carries one (as
// the connection passed to an ExecBatch callback does) and on the database
// handle otherwise. The outcome is reported through trace. On any error,
// including one that ends row iteration early, the result is nil and the error
// is returned.
func (db DBConnect) QueryList(sqlcmd string, args ...interface{}) ([]map[string]string, error) {
	var err error
	var list []map[string]string
	defer func() {
		db.trace(err, len(list), sqlcmd, args...)
	}()

	// A transactional connection has no db handle of its own, so the query
	// must go through tx to avoid a nil dereference and to stay in the
	// transaction.
	var rows *sql.Rows
	if db.tx != nil {
		rows, err = db.tx.Query(sqlcmd, args...)
	} else {
		rows, err = db.db.Query(sqlcmd, args...)
	}
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var cols []string
	if cols, err = rows.Columns(); err != nil {
		return nil, err
	}

	// One NullString per column, reused as the scan target for every row.
	num := len(cols)
	vals := make([]sql.NullString, num)
	arr := make([]interface{}, num)
	for i := range vals {
		arr[i] = &vals[i]
	}

	for rows.Next() {
		if err = rows.Scan(arr...); err != nil {
			return nil, err
		}
		data := make(map[string]string, num)
		for i, v := range vals {
			data[cols[i]] = v.String
		}
		list = append(list, data)
	}
	// Next returning false can also mean the iteration failed.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return list, nil
}
// Select runs sqlcmd and stores the outcome in result, which must be a
// non-nil pointer.
//
// When result points to a slice, all rows are appended to it. Otherwise the
// first row is stored in the pointed-to value and sql.ErrNoRows is returned
// when the statement yields no rows. A result that is not a non-nil pointer
// produces an error instead of a reflect panic.
func (db DBConnect) Select(result interface{}, sqlcmd string, args ...interface{}) error {
	ptr := reflect.ValueOf(result)
	if ptr.Kind() != reflect.Ptr || ptr.IsNil() {
		return fmt.Errorf("dbx: select result must be a non-nil pointer, got %T", result)
	}
	dest := ptr.Elem()
	if dest.Kind() == reflect.Slice {
		return db.query(&db, dest, sqlcmd, args...)
	}

	// Single value: collect into a temporary slice of the same element type
	// and copy the first row out.
	res := reflect.New(reflect.SliceOf(dest.Type())).Elem()
	if err := db.query(&db, res, sqlcmd, args...); err != nil {
		return err
	}
	if res.Len() == 0 {
		return sql.ErrNoRows
	}
	dest.Set(res.Index(0))
	return nil
}
// SelectWhere is SelectFrom with the table name derived from result's type
// via Name.
func (db DBConnect) SelectWhere(result interface{}, cond string, args ...interface{}) error {
	table := Name(result)
	return db.SelectFrom(result, table, cond, args...)
}
// SelectFrom builds "select ${fields} from <table>", adds " where <cond>"
// when cond is non-empty, and hands the statement to Select together with
// result and args.
func (db DBConnect) SelectFrom(result interface{}, table string, cond string, args ...interface{}) error {
	sqlcmd := "select ${fields} from " + table
	if cond != "" {
		sqlcmd += " where " + cond
	}
	return db.Select(result, sqlcmd, args...)
}
// ExecBatch runs callback inside a transaction started on the database
// handle. The callback receives a connection that carries the transaction and
// this connection's logger.
//
// The transaction is committed when callback returns nil and rolled back when
// it returns an error or panics (the panic is re-raised after the rollback).
// The returned error is the Begin error, the callback's error, or the Commit
// error, whichever occurs first.
func (db DBConnect) ExecBatch(callback func(*DBConnect) error) (err error) {
	tx, beginErr := db.db.Begin()
	if beginErr != nil {
		return beginErr
	}
	defer func() {
		if e := recover(); e != nil {
			_ = tx.Rollback() // the panic is what gets reported
			panic(e)
		}
		if err != nil {
			_ = tx.Rollback() // the callback's error is what gets reported
			return
		}
		// err is the named result, so a failed commit reaches the caller.
		err = tx.Commit()
	}()
	return callback(&DBConnect{tx: tx, Logger: db.Logger})
}
// Exec executes sqlcmd with args on the transaction when the connection has
// one and on the database handle otherwise, and returns the driver's result
// and error.
//
// A failed statement is traced with zero rows. A successful one is traced
// with its affected-row count; if that count cannot be read, the count error
// is logged through Error and the statement is still reported as successful.
func (db DBConnect) Exec(sqlcmd string, args ...interface{}) (sql.Result, error) {
	var (
		res sql.Result
		err error
	)
	if db.tx != nil {
		res, err = db.tx.Exec(sqlcmd, args...)
	} else {
		res, err = db.db.Exec(sqlcmd, args...)
	}
	if err != nil {
		db.trace(err, 0, sqlcmd, args...)
		return res, err
	}

	affected, countErr := res.RowsAffected()
	if countErr != nil {
		db.Error(utils.String(countErr))
	} else {
		db.trace(nil, int(affected), sqlcmd, args...)
	}
	return res, nil
}
// Insert writes data as one row into the table named by Name(data).
//
// Column names come from utils.GetFieldList(data) passed through
// modifyNameList, and values from utils.GetFieldValues(data). Fields whose
// normalised name appears in any exclude list are left out. The statement is
// run through Exec.
//
// An error is returned, without touching the database, when no column remains
// to insert.
func (db DBConnect) Insert(data interface{}, exclude ...Exclude) (sql.Result, error) {
	table := Name(data)
	fields := modifyNameList(utils.GetFieldList(data))
	values := utils.GetFieldValues(data)

	// combineSlice returns a fresh slice, so normalising it in place does not
	// touch the caller's exclude lists.
	skip := make(map[string]struct{})
	for _, name := range modifyNameList(combineSlice(exclude...)) {
		skip[name] = struct{}{}
	}

	cols := make([]string, 0, len(fields))
	args := make([]interface{}, 0, len(fields))
	for i, name := range fields {
		if name == "" {
			continue
		}
		if _, excluded := skip[name]; excluded {
			continue
		}
		cols = append(cols, name)
		args = append(args, values[i].Interface())
	}
	if len(cols) == 0 {
		return nil, fmt.Errorf("dbx: insert into %s: no columns to insert", table)
	}

	// "values" is accepted by every SQL dialect; the singular "value" is a
	// MySQL-only synonym.
	sqlcmd := fmt.Sprintf("insert into %s(%s) values(%s)", table, strings.Join(cols, ","), placeholder(len(args)))
	return db.Exec(sqlcmd, args...)
}
// Update writes data's fields to the table named by Name(data).
//
// Column names come from utils.GetFieldList(data) passed through
// modifyNameList, and values from utils.GetFieldValues(data). Fields whose
// normalised name appears in any exclude list are skipped. Fields named in
// keys become "col=?" conditions joined with " and " in the WHERE clause; all
// other fields become "col=?" assignments in the SET clause. With no keys the
// statement has no WHERE clause. The statement is run through Exec.
//
// The caller's keys slice is not modified. An error is returned, without
// touching the database, when nothing is left to assign or when a non-empty
// key matches no updatable field (which would otherwise silently widen the
// update).
func (db DBConnect) Update(data interface{}, keys Keys, exclude ...Exclude) (sql.Result, error) {
	table := Name(data)
	fields := modifyNameList(utils.GetFieldList(data))
	values := utils.GetFieldValues(data)

	// combineSlice returns a fresh slice, so normalising it in place does not
	// touch the caller's exclude lists.
	skip := make(map[string]struct{})
	for _, name := range modifyNameList(combineSlice(exclude...)) {
		skip[name] = struct{}{}
	}

	// Normalise a private copy of keys; the map value records whether the key
	// was matched by a field.
	normKeys := modifyNameList(append(Keys(nil), keys...))
	matched := make(map[string]bool, len(normKeys))
	for _, k := range normKeys {
		if k != "" {
			matched[k] = false
		}
	}

	sets := make([]string, 0, len(fields))
	conds := make([]string, 0, len(normKeys))
	vargs := make([]interface{}, 0, len(fields))
	cargs := make([]interface{}, 0, len(normKeys))
	for i, name := range fields {
		if name == "" {
			continue
		}
		if _, excluded := skip[name]; excluded {
			continue
		}
		if _, isKey := matched[name]; isKey {
			matched[name] = true
			conds = append(conds, name+"=?")
			cargs = append(cargs, values[i].Interface())
		} else {
			sets = append(sets, name+"=?")
			vargs = append(vargs, values[i].Interface())
		}
	}

	for _, k := range normKeys {
		if k != "" && !matched[k] {
			return nil, fmt.Errorf("dbx: update %s: key %q matches no updatable field", table, k)
		}
	}
	if len(sets) == 0 {
		return nil, fmt.Errorf("dbx: update %s: no columns to update", table)
	}

	sqlcmd := fmt.Sprintf("update %s set %s", table, strings.Join(sets, ","))
	if len(conds) > 0 {
		sqlcmd += " where " + strings.Join(conds, " and ")
	}
	// SET arguments come first, then WHERE arguments, matching placeholder order.
	return db.Exec(sqlcmd, append(vargs, cargs...)...)
}
// format converts an identifier with utils.UnderScoreCase; Name and
// modifyNameList use it to normalise table and column names.
func format(str string) string {
	converted := utils.UnderScoreCase(str)
	return converted
}
// placeholder returns num "?" markers separated by commas, e.g. "?,?,?" for
// num == 3. It returns "" when num is zero or negative.
//
// Results that fit are sliced out of the precomputed placeholderlist; larger
// ones are built on demand instead of slicing past the end of that string.
func placeholder(num int) string {
	if num <= 0 {
		return ""
	}
	// num markers plus num-1 separators.
	if need := 2*num - 1; need <= len(placeholderlist) {
		return placeholderlist[:need]
	}
	return strings.TrimSuffix(strings.Repeat("?,", num), ",")
}
// modifyNameList replaces every entry of fields with its format() form. The
// slice is rewritten in place and the same slice is returned.
func modifyNameList(fields []string) []string {
	for i := range fields {
		fields[i] = format(fields[i])
	}
	return fields
}
// combineSlice concatenates the given string slices, in order, into one newly
// allocated slice. It returns nil when there are no elements at all.
func combineSlice(fields ...[]string) []string {
	total := 0
	for _, part := range fields {
		total += len(part)
	}
	if total == 0 {
		return nil
	}
	merged := make([]string, 0, total)
	for _, part := range fields {
		merged = append(merged, part...)
	}
	return merged
}
// trace logs the outcome of a statement through the embedded logger: Debug
// with the row count when err is nil, Error with the error text otherwise.
// When args is non-empty the message ends with " with param" followed by each
// argument rendered by utils.String in square brackets.
func (log DBConnect) trace(err error, rows int, sqlcmd string, args ...interface{}) {
	var params strings.Builder
	if len(args) > 0 {
		params.WriteString(" with param")
	}
	for _, arg := range args {
		params.WriteString("[" + utils.String(arg) + "]")
	}

	if err != nil {
		log.Error("execute sqlcmd[%s] failed[%s]%s", sqlcmd, err.Error(), params.String())
		return
	}
	log.Debug("execute sqlcmd[%s] success[%d]%s", sqlcmd, rows, params.String())
}
// query runs sqlcmd on db and appends one element per row to result, which
// must be a settable slice value.
//
// Before execution the first "${name}" in sqlcmd is replaced by Name of the
// element type and the first "${fields}" by the comma-joined, normalised field
// list of the element type. The statement runs on db's transaction when it has
// one and on its database handle otherwise.
//
// For element kinds that utils.IsBaseValue accepts, the first column is stored
// in the element. Otherwise each column is matched to a struct field by exact
// name, then by utils.CamelCase of the name; unmatched columns are ignored.
// Values are assigned with utils.SetFieldValue from their string form.
//
// The outcome is reported through the receiver's trace. On error, including
// one that ends row iteration early, result is left unchanged.
func (log DBConnect) query(db *DBConnect, result reflect.Value, sqlcmd string, args ...interface{}) error {
	var res = 0
	var err error
	var cols []string
	var rows *sql.Rows

	// dest points to one scratch element that is refilled for every row;
	// item is that element itself.
	var dest = reflect.New(result.Type().Elem())
	var item = dest.Elem()

	if strings.Contains(sqlcmd, "${name}") {
		sqlcmd = strings.Replace(sqlcmd, "${name}", Name(dest.Interface()), 1)
	}
	if strings.Contains(sqlcmd, "${fields}") {
		fields := modifyNameList(utils.GetFieldList(dest.Interface()))
		sqlcmd = strings.Replace(sqlcmd, "${fields}", strings.Join(fields, ","), 1)
	}

	defer func() {
		log.trace(err, res, sqlcmd, args...)
	}()

	// A transactional connection has no db handle of its own, so the query
	// must go through tx to avoid a nil dereference and to stay in the
	// transaction.
	if db.tx != nil {
		rows, err = db.tx.Query(sqlcmd, args...)
	} else {
		rows, err = db.db.Query(sqlcmd, args...)
	}
	if err != nil {
		return err
	}
	defer rows.Close()

	if cols, err = rows.Columns(); err != nil {
		return err
	}

	vec := result
	num := len(cols)
	typ := item.Type()
	arr := make([]interface{}, num)
	vals := make([]sql.NullString, num)
	// destmap maps a column index to the value that receives that column.
	destmap := make(map[int]reflect.Value, num)
	if utils.IsBaseValue(item.Kind()) {
		destmap[0] = item
		arr[0] = &vals[0]
	} else {
		for i := 0; i < num; i++ {
			name := cols[i]
			if _, ok := typ.FieldByName(name); ok {
				destmap[i] = item.FieldByName(name)
			} else {
				name := utils.CamelCase(name)
				if _, ok := typ.FieldByName(name); ok {
					destmap[i] = item.FieldByName(name)
				}
			}
			arr[i] = &vals[i]
		}
	}

	for rows.Next() {
		if err = rows.Scan(arr...); err != nil {
			return err
		}
		for idx, field := range destmap {
			utils.SetFieldValue(field, vals[idx].String)
		}
		// Append copies the scratch element, so it can be reused next row.
		vec = reflect.Append(vec, dest.Elem())
		res++
	}
	// Next returning false can also mean the iteration failed.
	if err = rows.Err(); err != nil {
		return err
	}
	result.Set(vec)
	return nil
}
Go
1
https://gitee.com/xungen/goweb.git
git@gitee.com:xungen/goweb.git
xungen
goweb
goweb
v0.0.8

搜索帮助

53164aa7 5694891 3bd8fe86 5694891