1 Star 0 Fork 0

tomatomeatman/GolangRepository

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
SqlFactory.go 18.44 KB
一键复制 编辑 原始数据 按行查看 历史
laowei 提交于 2024-08-30 17:29 +08:00 . 1
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720
package gorm
import (
"database/sql"
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"time"
"unsafe"
Log "github.com/cihub/seelog"
"github.com/shopspring/decimal"
)
// // 将查询结果转换成map数组,常用于原生sql查询
// func ScanRows2map(rows *sql.Rows) []map[string]string {
// if nil == rows {
// return nil
// }
// res := make([]map[string]string, 0) // 定义结果 map
// colTypes, _ := rows.ColumnTypes() // 列信息
// var rowParam = make([]interface{}, len(colTypes)) // 传入到 rows.Scan 的参数 数组
// var rowValue = make([]interface{}, len(colTypes)) // 接收数据一行列的数组
// for i, colType := range colTypes {
// rowValue[i] = reflect.New(colType.ScanType()) // 跟据数据库参数类型,创建默认值 和类型
// rowParam[i] = reflect.ValueOf(&rowValue[i]).Interface() // 跟据接收的数据的类型反射出值的地址
// }
// // 遍历每行
// for rows.Next() {
// //rows.Scan(rowParam) // 赋值到 rowValue 中 go 1.20
// rows.Scan(rowParam...) // 赋值到 rowValue 中 go 1.16
// record := make(map[string]string)
// for i, colType := range colTypes {
// if rowValue[i] == nil {
// record[colType.Name()] = ""
// continue
// }
// //如果字段类型为int,则需要进一步判断
// //并且1.如果获得值类型为int64,则需要按int64处理,常用于类似以下的查询:
// //rows, _ := SqlFactory{}.GetDB().Raw("select * from table where uId=@GuId and sName=@GsName", &where).Rows()
// //并且2.如果数据库类型虽然为INT,但获取的值被以string进行接收, 则要按字符串的方式进行,常用于类似以下的查询:
// //rows, _ := SqlFactory{}.GetDB().Raw("select * from Student").Rows()
// if colType.DatabaseTypeName() == "INT" { //
// switch rowValue[i].(type) {
// case int64: //
// record[colType.Name()] = strconv.FormatInt(int64(rowValue[i].(int64)), 10)
// continue
// }
// }
// if colType.DatabaseTypeName() == "BIGINT" {
// record[colType.Name()] = toStr(rowValue[i])
// continue
// }
// if colType.DatabaseTypeName() == "FLOAT" {
// record[colType.Name()] = toStr(rowValue[i])
// continue
// }
// if colType.DatabaseTypeName() == "DOUBLE" {
// record[colType.Name()] = toStr(rowValue[i])
// continue
// }
// if colType.DatabaseTypeName() == "DECIMAL" {
// record[colType.Name()] = toStr(rowValue[i])
// continue
// }
// if colType.DatabaseTypeName() == "DATETIME" {
// record[colType.Name()] = rowValue[i].(time.Time).Format("2006-01-02 15:04:05")
// continue
// }
// record[colType.Name()] = byte2Str(rowValue[i].([]byte))
// }
// res = append(res, record)
// }
// return res
// }
// 将查询结果转换成map数组,常用于原生sql查询
func ScanRows2mapI(rows *sql.Rows) []map[string]interface{} {
defer rows.Close()
columns, err := rows.Columns()
if err != nil {
return nil
}
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
}
colNames := make(map[string]string)
temp, _ := rows.ColumnTypes() // 列信息
for _, colType := range temp {
colNames[colType.Name()] = strings.ToUpper(colType.DatabaseTypeName()) // 跟据数据库参数类型,创建默认值 和类型
}
maps := []map[string]interface{}{} //就算没数据也不会返回nil
for rows.Next() {
err := rows.Scan(valuePtrs...)
if err != nil {
return nil
}
m := make(map[string]interface{})
for i, col := range columns {
m[col] = dbValueCorrect(colNames[col], values[i])
// val := values[i]
// m[col] = val //先设置为从数据库获取的值
// if val == nil {
// continue
// }
// //数据调整,因从数据库获取的值并不一定符合数据类型,例如Decimal字段返回的数据类型不是decimal.Decimal而是string
// switch colNames[col] {
// case "INT":
// value := toInt(val, -99999)
// if value != -99999 {
// m[col] = value
// }
// case "NUMERIC":
// value := toInt(val, -99999)
// if value != -99999 {
// m[col] = value
// }
// case "TINYINT":
// value := toInt(val, -99999)
// if value != -99999 {
// m[col] = value
// }
// case "BIGINT":
// value := toInt(val, -99999)
// if value != -99999 {
// m[col] = value
// }
// case "FLOAT":
// value := toFloat(val, 64, -99999999.99999999)
// if value != -99999999.99999999 {
// m[col] = value
// }
// case "DOUBLE":
// value := toFloat(val, 64, -99999999.99999999)
// if value != -99999999.99999999 {
// m[col] = value
// }
// case "DECIMAL":
// if reflect.TypeOf(val).Elem().Name() == "Decimal" {
// m[col] = val.(*decimal.Decimal)
// continue
// }
// if reflect.TypeOf(val).Elem().Name() == "uint8" {
// m[col] = toStr(val)
// continue
// }
// m[col] = val
// case "DATETIME":
// m[col] = val.(time.Time).Format("2006-01-02 15:04:05")
// case "DATE":
// m[col] = val.(time.Time).Format("2006-01-02 15:04:05")
// case "TIME":
// m[col] = val.(time.Time).Format("2006-01-02 15:04:05")
// case "TIMESTAMP":
// m[col] = toInt(val, -99999)
// default:
// b, ok := val.([]byte)
// if ok {
// m[col] = string(b)
// continue
// }
// m[col] = val
// }
}
maps = append(maps, m)
}
return maps
}
/**
* 将查询结果修正成符合字段类型的数据
* 比如DECIMAL类型的值如果不进行修正返回会是字符串
* @param colType 数据库字段类型
* @param val 数据库字段值
* @return 修正后的数据
*/
func dbValueCorrect(colType string, val interface{}) interface{} {
if val == nil {
return nil
}
switch colType {
case "INT":
value := toInt(val, -99999)
if value != -99999 {
return value
}
case "NUMERIC":
value := toInt(val, -99999)
if value != -99999 {
return value
}
case "TINYINT":
value := toInt(val, -99999)
if value != -99999 {
return value
}
case "BIGINT":
value := toInt(val, -99999)
if value != -99999 {
return value
}
case "FLOAT":
value := toFloat(val, 64, -99999999.99999999)
if value != -99999999.99999999 {
return value
}
case "DOUBLE":
value := toFloat(val, 64, -99999999.99999999)
if value != -99999999.99999999 {
return value
}
case "DECIMAL":
switch reflect.TypeOf(val).Elem().Name() {
case "Decimal":
return val.(*decimal.Decimal)
case "uint8":
temp, _ := decimal.NewFromString(toStr(val))
return temp
case "string":
temp, _ := decimal.NewFromString(val.(string))
return temp
default:
return val
}
case "DATETIME":
return val.(time.Time).Format("2006-01-02 15:04:05")
case "DATE":
return val.(time.Time).Format("2006-01-02 15:04:05")
case "TIME":
return val.(time.Time).Format("2006-01-02 15:04:05")
case "TIMESTAMP":
return toInt(val, -99999)
default:
b, ok := val.([]byte)
if ok {
return string(b)
}
return val
}
return val //都不在范围内,不转换
}
// func ScanRows2mapI(rows *sql.Rows) []map[string]interface{} {
// if nil == rows {
// return nil
// }
// res := make([]map[string]interface{}, 0) // 定义结果 map
// colTypes, _ := rows.ColumnTypes() // 列信息
// var rowParam = make([]interface{}, len(colTypes)) // 传入到 rows.Scan 的参数 数组
// var rowValue = make([]interface{}, len(colTypes)) // 接收数据一行列的数组
// for i, colType := range colTypes {
// rowValue[i] = reflect.New(colType.ScanType()) // 跟据数据库参数类型,创建默认值 和类型
// rowParam[i] = reflect.ValueOf(&rowValue[i]).Interface() // 跟据接收的数据的类型反射出值的地址
// }
// // 遍历
// for rows.Next() {
// rows.Scan(rowParam) // 赋值到 rowValue 中 go 1.20
// //rows.Scan(rowParam...) // 赋值到 rowValue 中 go 1.16
// record := make(map[string]interface{})
// for i, colType := range colTypes {
// if rowValue[i] == nil {
// record[colType.Name()] = ""
// continue
// }
// //如果字段类型为int,则需要进一步判断
// //并且1.如果获得值类型为int64,则需要按int64处理,常用于类似以下的查询:
// //rows, _ := SqlFactory{}.GetDB().Raw("select * from table where uId=@GuId and sName=@GsName", &where).Rows()
// //并且2.如果数据库类型虽然为INT,但获取的值被以string进行接收, 则要按字符串的方式进行,常用于类似以下的查询:
// //rows, _ := SqlFactory{}.GetDB().Raw("select * from Student").Rows()
// if colType.DatabaseTypeName() == "INT" { //
// switch value := rowValue[i].(type) {
// case int64:
// record[colType.Name()] = rowValue[i]
// //record[colType.Name()] = strconv.FormatInt(int64(rowValue[i].(int64)), 10)
// continue
// case string:
// record[colType.Name()] = toInt64(rowValue[i], -99999)
// continue
// case []uint8:
// record[colType.Name()] = toInt64(rowValue[i], -99999)
// continue
// default:
// fmt.Println(value)
// }
// }
// if colType.DatabaseTypeName() == "BIGINT" {
// switch value := rowValue[i].(type) {
// case int64:
// record[colType.Name()] = rowValue[i]
// //record[colType.Name()] = strconv.FormatInt(int64(rowValue[i].(int64)), 10)
// continue
// case string:
// record[colType.Name()] = toInt64(rowValue[i], -99999)
// continue
// case []uint8:
// record[colType.Name()] = toInt64(rowValue[i], -99999)
// continue
// default:
// fmt.Println(value)
// }
// }
// if colType.DatabaseTypeName() == "DATETIME" {
// record[colType.Name()] = rowValue[i]
// continue
// }
// record[colType.Name()] = byte2Str(rowValue[i].([]byte))
// }
// res = append(res, record)
// }
// return res
// }
// Byte转Str
func byte2Str(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}
// 转换字符串
func toStr(data interface{}) string {
switch obj := data.(type) {
case []uint8:
return byte2Str(obj)
default:
return fmt.Sprintf("%v", data)
}
}
// // 对象(字符串)转64整型
// func toInt64(data interface{}, iDefault int64) int64 {
// var str string
// switch obj := data.(type) {
// case []uint8:
// str = byte2Str(obj)
// default:
// str = fmt.Sprintf("%v", obj)
// }
// if str == "" { //字符串不能判断nil
// return iDefault
// }
// result, err := strconv.ParseInt(str, 10, 64)
// if err != nil {
// return iDefault
// }
// return result
// }
// 对象(字符串)转整型
func toInt(data interface{}, iDefault int) int {
var str string
switch obj := data.(type) {
case []uint8:
str = byte2Str(obj)
default:
str = fmt.Sprintf("%v", obj)
}
if str == "" { //字符串不能判断nil
return iDefault
}
result, err := strconv.ParseInt(str, 10, 64)
if err != nil {
return iDefault
}
return int(result)
}
// 对象(字符串)转64整型
func toFloat(data interface{}, bitSize int, iDefault float64) float64 {
var str string
switch obj := data.(type) {
case []uint8:
str = byte2Str(obj)
default:
str = fmt.Sprintf("%v", obj)
}
if str == "" { //字符串不能判断nil
return iDefault
}
result, err := strconv.ParseFloat(str, bitSize)
if err != nil {
return iDefault
}
return result
}
// 取数据库名称
func GetDbName(name string) string {
return GetVariable(name)
}
// 取数据库全局变量
func GetVariable(name string) string {
if name == "" {
return ""
}
for key := range dbVariables {
if name == key {
return dbVariables[key]
}
}
return ""
}
// 替换字符串中的所有全局变量
func ReplaceVariable(sqlstr string) string {
if sqlstr == "" {
return ""
}
result := sqlstr
for key, val := range dbVariables {
if !strings.Contains(result, "${") {
return result
}
if !strings.Contains(result, "${"+key+"}") {
continue
}
result = strings.Replace(result, "${"+key+"}", val, -1)
}
return result
}
// 添加记录,返回影响行数及错误信息
func Add(entity interface{}) (int64, error) {
result := GetDB().Create(entity)
return result.RowsAffected, result.Error
}
// 调用数据查询数量
func Count(sql string, params ...interface{}) (int, error) {
var iCount int
dbResult := doDb(sql, params, globGormDB.Raw).Scan(&iCount)
if dbResult.Error != nil {
return 0, dbResult.Error
}
return iCount, nil
}
// 调用数据查询值
// 警告:只能是明确知道只有一条记录且只返回一个值的时候使用
func Value(sql string, params ...interface{}) (interface{}, error) {
rows, iCode, err := FindToMap(sql, params...)
if err != nil {
return nil, err
}
if iCode == 0 {
return nil, nil
}
for _, v := range rows[0] { //取第一行第一列的值
return v, nil
}
return nil, nil //没有数据
}
// 调用数据查询
func Query(sql string, dest interface{}, where ...interface{}) (interface{}, error) {
var dbResult GormDB
if len(where) < 1 {
dbResult = GetDB().Raw(sql).Scan(dest)
} else {
dbResult = GetDB().Raw(sql, where...).Scan(dest)
}
if dbResult.Error != nil {
Log.Error("查询发生异常:", dbResult.Error)
return nil, dbResult.Error
}
return &dest, dbResult.Error
}
// 调用数据查询
func Find(sql string, params ...interface{}) (tx GormDB) {
return doDb(sql, params, globGormDB.Raw)
}
// 调用数据查询
// 返回规则: 发生错误时: {数据为nil,错误码值,错误信息}
// 正确时: {数据,数据数量,nil}
func FindToMap(text string, params ...interface{}) ([]map[string]interface{}, int, error) {
rows, err := doDb(text, params, globGormDB.Raw).Rows()
if err != nil {
Log.Error("查询发生异常:", err)
return nil, 1002, err
}
defer rows.Close()
res := ScanRows2mapI(rows)
if res == nil {
Log.Error("查询成功后进行数据转换时发生异常,:无法正确转换")
return nil, 1003, errors.New("查询发生异常")
}
rowCount := len(res)
if rowCount < 1 {
return res, rowCount, nil //没有数据
}
return res, rowCount, nil
}
// 调用数据查询一条记录
// 返回规则: 发生错误时: {数据为nil,错误码值,错误信息}
// 正确时: {数据,数据数量,nil}
func FindOneMap(text string, params ...interface{}) (map[string]interface{}, int, error) {
rows, err := doDb(text, params, globGormDB.Raw).Rows()
if err != nil {
Log.Error("查询发生异常:", err)
return nil, 1001, err
}
defer rows.Close()
res := ScanRows2mapI(rows)
if res == nil {
Log.Error("查询成功后进行数据转换时发生异常,:无法正确转换")
return nil, 1002, errors.New("查询发生异常")
}
rowCount := len(res)
if rowCount < 1 {
return nil, 1003, errors.New("没有数据")
}
return res[0], rowCount, nil
}
// 格式化结果集(备用代码)
// func FormatScan(results *[]map[string]interface{}, colTypes map[string]*sql.ColumnType) {
// for _, row := range *results {
// for key, value := range row {
// if value == nil {
// continue
// }
// switch colTypes[key].DatabaseTypeName() {
// case "INT":
// if reflect.TypeOf(value).String() == "int32" {
// row[key] = int(value.(int32))
// continue
// }
// if reflect.TypeOf(value).String() == "int64" {
// row[key] = int(value.(int64))
// continue
// }
// tmp := toInt(value, -99999)
// if tmp != -99999 {
// row[key] = tmp
// }
// continue
// case "NUMERIC":
// tmp := toInt64(value, -99999)
// if tmp != -99999 {
// row[key] = tmp
// }
// continue
// case "TINYINT":
// tmp := toInt64(value, -99999)
// if tmp != -99999 {
// row[key] = tmp
// }
// continue
// case "BIGINT":
// tmp := toInt64(value, -99999)
// if tmp != -99999 {
// row[key] = tmp
// }
// continue
// case "FLOAT":
// tmp := toFloat(value, 64, -99999999.99999999)
// if tmp != -99999999.99999999 {
// row[key] = tmp
// }
// continue
// case "DOUBLE":
// tmp := toFloat(value, 64, -99999999.99999999)
// if tmp != -99999999.99999999 {
// row[key] = tmp
// }
// continue
// case "DECIMAL":
// if reflect.TypeOf(value).Elem().Name() == "Decimal" {
// row[key] = value.(*decimal.Decimal)
// continue
// }
// if reflect.TypeOf(value).Elem().Name() == "uint8" {
// row[key] = toStr(value)
// continue
// }
// continue
// case "DATETIME":
// row[key] = value.(time.Time).Format("2006-01-02 15:04:05")
// continue
// case "DATE":
// row[key] = value.(time.Time).Format("2006-01-02 15:04:05")
// continue
// case "TIME":
// row[key] = value.(time.Time).Format("15:04:05")
// continue
// case "TIMESTAMP":
// row[key] = toInt64(value, -99999)
// continue
// }
// }
// }
// }
// 调用数据查询
func Raw(sql string, params ...interface{}) (tx GormDB) {
return doDb(sql, params, globGormDB.Raw)
}
// 调用数据查询
func RawRows(sql string, params ...interface{}) (*sql.Rows, error) {
return doDb(sql, params, globGormDB.Raw).Rows()
}
// 调用数据更新
func Exec(sql string, params ...interface{}) (tx GormDB) {
return doDb(sql, params, globGormDB.Exec)
}
// 调用数据库操作
func doDb(sql string, param []interface{}, dbFunc func(sql string, values ...interface{}) (tx GormDB)) (tx GormDB) {
if (nil == param) || (len(param) < 1) {
return dbFunc(sql)
}
iCount := len(param)
if iCount > 1 {
return dbFunc(sql, param...)
}
rtk := reflect.TypeOf(param[0]).Kind()
if rtk == reflect.Map {
s := reflect.ValueOf(param[0])
if s.Len() < 1 {
return dbFunc(sql)
}
return dbFunc(sql, param[0])
}
if (rtk != reflect.Slice) && (rtk != reflect.Array) {
return dbFunc(sql, param[0])
}
params := []interface{}{}
s := reflect.ValueOf(param[0])
for i := 0; i < s.Len(); i++ {
params = append(params, s.Index(i).Interface())
}
return dbFunc(sql, params...)
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Go
1
https://gitee.com/tomatomeatman/golang-repository.git
git@gitee.com:tomatomeatman/golang-repository.git
tomatomeatman
golang-repository
GolangRepository
61e401b0d628

搜索帮助