2 Star 0 Fork 4

Ryan / MysqlGo

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
sql_builder.go 7.16 KB
AI 代码解读
一键复制 编辑 原始数据 按行查看 历史
zhengmingming 提交于 2019-11-06 01:44 . fix: insert sql build bug
package mysqlgo
import (
	"fmt"
	"sort"
	"strings"
)
// Statement templates used by the parse*SQL methods. Each %NAME% placeholder
// is substituted with text produced by the builder:
//   - selectSQL: each clause placeholder receives the output of the matching
//     *Format helper (an empty string when the clause is unused). A space
//     follows SELECT so that the DISTINCT keyword is not glued onto it.
//   - insertSQL: %MARK% receives one "?" per column listed in %FIELD%.
//   - updateSQL: the template already contains WHERE, so %ARGS% expects only
//     the bare condition.
//   - deleteSQL: the template has no WHERE keyword of its own, so %ARGS%
//     expects the complete clause as produced by whereFormat.
var (
	selectSQL = "SELECT %DISTINCT% %FIELD% FROM %TABLE%%JOIN%%WHERE%%GROUP%%HAVING%%ORDER%%LIMIT% %UNION%%COMMENT%"
	insertSQL = "INSERT INTO %TABLE%(%FIELD%) VALUE(%MARK%)"
	updateSQL = "UPDATE %TABLE% SET %FIELD% WHERE %ARGS%"
	deleteSQL = "DELETE FROM %TABLE%%ARGS%"
)
// SQLBuilder is the SQL generation helper: it accumulates the parts of one
// statement (table, columns, clauses and bind arguments) and renders them
// through BuildSQL.
type SQLBuilder struct {
	// table is the target table; tableFormat renders its Name and, when
	// set, its Alias.
	table Table
	// distinct makes distinctFormat emit the DISTINCT keyword.
	distinct bool
	// field is the comma-separated column list built by Field; fieldFormat
	// falls back to "*" when it is empty.
	field string
	// join holds the join clauses rendered by joinFormat.
	join []Join
	// where is the accumulated condition: each Where call adds a
	// parenthesised condition, combined with earlier ones using AND.
	where string
	// whereArgs are the bind values passed to Where, in call order.
	whereArgs []interface{}
	// group lists the GROUP BY columns (see groupFormat).
	group []string
	// having is the HAVING condition (see havingFormat).
	having string
	// order maps a column to its sort direction: true renders "desc" and
	// false renders "asc" (see orderFormat). Being a map, it records no
	// priority between columns.
	order map[string]bool
	// limit supplies RowCount and Offset for the LIMIT clause (see limitFormat).
	limit Limit
	// union holds the SELECT statements and the All flag used by unionFormat.
	union Union
	// comment is emitted as a /* ... */ SQL comment (see commentFormat).
	comment string
	// page is not referenced by the code shown here; its purpose cannot be
	// determined from this file.
	page string
	// args collects bind arguments produced while building; parseUpdateSQL
	// appends the SET values followed by whereArgs.
	args []interface{}
	// lastSQL is the statement produced by the most recent BuildSQL call.
	lastSQL string
	// priviewSQL ("preview") makes printSQL print lastSQL to stdout.
	priviewSQL bool
}
// SQLTYPE identifies which kind of statement BuildSQL should produce.
type SQLTYPE int

// Statement kinds accepted by BuildSQL; their numeric values are 0 through 3.
// Note that the constant named delete shadows the built-in delete function
// everywhere in this package.
const (
	query  SQLTYPE = iota // SELECT
	insert                // INSERT
	update                // UPDATE
	delete                // DELETE
)
// printSQL writes the most recently built statement to stdout, but only
// when preview mode (priviewSQL) is switched on.
func (sql *SQLBuilder) printSQL() {
	if !sql.priviewSQL {
		return
	}
	fmt.Printf("[SQLBuilder Preview]: %s \n", sql.lastSQL)
}
// tableFormat renders the table reference padded with one space on each
// side, appending "as <alias>" when an alias is configured.
func (sql *SQLBuilder) tableFormat() string {
	name := sql.table.Name
	if alias := sql.table.Alias; alias != "" {
		return fmt.Sprintf(" %s as %s ", name, alias)
	}
	return fmt.Sprintf(" %s ", name)
}
// distinctFormat yields the DISTINCT keyword when de-duplication was
// requested, and an empty string otherwise.
func (sql *SQLBuilder) distinctFormat() string {
	keyword := ""
	if sql.distinct {
		keyword = "DISTINCT"
	}
	return keyword
}
// fieldFormat returns the configured column list, falling back to "*"
// (all columns) when no field has been set.
func (sql *SQLBuilder) fieldFormat() string {
	if sql.field == "" {
		return "*"
	}
	return sql.field
}
// joinFormat renders every configured join as "<TYPE> JOIN <statement>".
//
// A Join's Statement may be given with or without the JOIN keyword
// ("b ON a.id = b.a_id" or "JOIN b ON a.id = b.a_id"). The keyword is added
// only when it is absent, matched case-insensitively. The type prefix is
// chosen from Type and defaults to INNER for unrecognised values.
//
// Consecutive join clauses are separated by whitespace only, because a comma
// between JOIN clauses is not valid SQL.
func (sql *SQLBuilder) joinFormat() string {
	var joins []string
	for _, value := range sql.join {
		statement := value.Statement
		// Case-insensitive so that both "JOIN b ..." and "join b ..." are
		// recognised and do not receive a second JOIN keyword.
		if !strings.Contains(strings.ToUpper(statement), "JOIN") {
			statement = fmt.Sprintf(" JOIN %s ", statement)
		}
		var prefix string
		switch value.Type {
		case JOINLEFT:
			prefix = "LEFT"
		case JOINRIGHT:
			prefix = "RIGHT"
		case JOINFULL:
			// NOTE(review): MySQL has no FULL [OUTER] JOIN; kept for
			// compatibility with existing callers.
			prefix = "FULL"
		default: // JOININNER and any unrecognised type
			prefix = "INNER"
		}
		joins = append(joins, fmt.Sprintf(" %s %s ", prefix, statement))
	}
	return strings.Join(joins, "")
}
// whereFormat renders the accumulated condition as a " WHERE ... " clause,
// or returns an empty string when no condition has been added.
func (sql *SQLBuilder) whereFormat() string {
	if sql.where == "" {
		return ""
	}
	return fmt.Sprintf(" WHERE %s ", sql.where)
}
// groupFormat renders " GROUP BY a,b" from the configured grouping columns,
// or returns an empty string when no grouping is configured.
func (sql *SQLBuilder) groupFormat() string {
	if len(sql.group) == 0 {
		return ""
	}
	columns := strings.Join(sql.group, ",")
	return fmt.Sprintf(" GROUP BY %s", columns)
}
// havingFormat renders the " HAVING <condition>" clause, or returns an empty
// string when no HAVING condition is set.
func (sql *SQLBuilder) havingFormat() string {
	if sql.having == "" {
		return ""
	}
	return fmt.Sprintf(" HAVING %s", sql.having)
}
// commentFormat renders the optional statement comment as " /* text */",
// or returns an empty string when no comment is set.
//
// Any "*/" inside the text is broken up into "* /". Otherwise the comment
// could be closed early, and the rest of the text would run as SQL.
func (sql *SQLBuilder) commentFormat() string {
	if sql.comment == "" {
		return ""
	}
	text := strings.Replace(sql.comment, "*/", "* /", -1)
	return fmt.Sprintf(" /* %s */", text)
}
// orderFormat renders the ORDER BY clause from sql.order, where a true value
// means descending ("desc") and false means ascending ("asc"). It returns an
// empty string when no ordering is configured.
//
// sql.order is a map, so it carries no priority between columns, and Go map
// iteration order is unspecified. Columns are therefore emitted in sorted
// order so that the generated SQL is deterministic.
// NOTE(review): if callers need a specific column priority, the field must
// become an ordered type such as a slice.
func (sql *SQLBuilder) orderFormat() string {
	if len(sql.order) == 0 {
		return ""
	}
	keys := make([]string, 0, len(sql.order))
	for key := range sql.order {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	orderStr := make([]string, 0, len(keys))
	for _, key := range keys {
		direction := "asc"
		if sql.order[key] {
			direction = "desc"
		}
		orderStr = append(orderStr, fmt.Sprintf(" %s %s", key, direction))
	}
	return fmt.Sprintf(" ORDER BY %s ", strings.Join(orderStr, ","))
}
// limitFormat renders " LIMIT n " or " LIMIT n OFFSET m ". Nothing is
// emitted unless RowCount is positive, and Offset is used only when it is
// positive too.
func (sql *SQLBuilder) limitFormat() string {
	rows, offset := sql.limit.RowCount, sql.limit.Offset
	switch {
	case rows <= 0:
		return ""
	case offset <= 0:
		return fmt.Sprintf(" LIMIT %d ", rows)
	default:
		return fmt.Sprintf(" LIMIT %d OFFSET %d ", rows, offset)
	}
}
// unionFormat appends the configured SELECT statements using UNION, or
// UNION ALL when union.All is set. It returns an empty string when there is
// nothing to union.
//
// Every statement gets its own UNION keyword. Joining them with commas would
// produce invalid SQL as soon as more than one statement is configured.
func (sql *SQLBuilder) unionFormat() string {
	if len(sql.union.SelectSQL) == 0 {
		return ""
	}
	keyword := " UNION "
	if sql.union.All {
		keyword = " UNION ALL "
	}
	return keyword + strings.Join(sql.union.SelectSQL, keyword) + " "
}
// fieldMarkFormat builds the "?,?,..." placeholder list for INSERT, with one
// mark per comma-separated entry in sql.field. It returns "" when there is
// no explicit column list (empty or "*").
// NOTE(review): every comma is counted, so a column expression that itself
// contains a comma would produce too many marks.
func (sql *SQLBuilder) fieldMarkFormat() string {
	switch sql.field {
	case "", "*":
		return ""
	}
	marks := strings.Count(sql.field, ",") + 1
	return strings.TrimSuffix(strings.Repeat("?,", marks), ",")
}
// parseQuerySQL renders selectSQL by substituting every clause placeholder
// with the output of the matching *Format helper.
//
// All placeholders are substituted in a single pass with strings.Replacer,
// which never re-scans text it has already inserted. Chained
// strings.Replace calls would re-scan it, so a condition such as
// "name LIKE '%ORDER%'" would have its pattern overwritten by the ORDER BY
// clause.
func (sql *SQLBuilder) parseQuerySQL() string {
	replacer := strings.NewReplacer(
		"%TABLE%", sql.tableFormat(),
		"%DISTINCT%", sql.distinctFormat(),
		"%FIELD%", sql.fieldFormat(),
		"%JOIN%", sql.joinFormat(),
		"%WHERE%", sql.whereFormat(),
		"%GROUP%", sql.groupFormat(),
		"%HAVING%", sql.havingFormat(),
		"%ORDER%", sql.orderFormat(),
		"%LIMIT%", sql.limitFormat(),
		"%UNION%", sql.unionFormat(),
		"%COMMENT%", sql.commentFormat(),
	)
	return replacer.Replace(selectSQL)
}
// parseInsertSQL renders insertSQL with one "?" placeholder per configured
// column. It returns "" when no explicit column list was set (as reported by
// fieldMarkFormat), because the INSERT template requires named columns.
//
// Only the bare table name is used. MySQL's INSERT syntax does not accept a
// table alias, so "INSERT INTO t as a(...)" would be rejected.
func (sql *SQLBuilder) parseInsertSQL() string {
	fieldMark := sql.fieldMarkFormat()
	if fieldMark == "" {
		return ""
	}
	return strings.NewReplacer(
		"%TABLE%", fmt.Sprintf(" %s ", sql.table.Name),
		"%FIELD%", sql.fieldFormat(),
		"%MARK%", fieldMark,
	).Replace(insertSQL)
}
// parseUpdateSQL renders updateSQL as "UPDATE <table> SET a = ?,b = ? WHERE
// <condition>". It appends the bind arguments to sql.args: first the SET
// values, in the same order as their columns, then the WHERE arguments.
//
// updateSQL already contains the WHERE keyword, so only the bare condition
// is substituted for %ARGS%. Substituting the full clause from whereFormat
// would produce "WHERE  WHERE". Columns are emitted in sorted order because
// map iteration order is unspecified, which would otherwise make the
// statement text vary between calls.
//
// It returns "" and leaves sql.args untouched when data is empty or no WHERE
// condition has been added. In both cases the template would yield
// malformed SQL ("SET  WHERE" or a dangling "WHERE ").
func (sql *SQLBuilder) parseUpdateSQL(data map[string]interface{}) string {
	if len(data) == 0 || sql.where == "" {
		return ""
	}
	keys := make([]string, 0, len(data))
	for key := range data {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	assignments := make([]string, 0, len(keys))
	values := make([]interface{}, 0, len(keys))
	for _, key := range keys {
		assignments = append(assignments, fmt.Sprintf("%s = ?", key))
		values = append(values, data[key])
	}

	sqlText := strings.NewReplacer(
		"%TABLE%", sql.tableFormat(),
		"%FIELD%", strings.Join(assignments, ","),
		"%ARGS%", sql.where,
	).Replace(updateSQL)
	sql.args = append(sql.args, values...)
	sql.args = append(sql.args, sql.whereArgs...)
	return sqlText
}
// parseDeleteSQL renders deleteSQL for the configured table, substituting
// the complete WHERE clause from whereFormat for %ARGS%. It returns "" when
// no condition has been added, so an unconditioned DELETE (one that would
// affect every row) is never generated.
//
// The guard checks the condition itself rather than its bind arguments. A
// condition without placeholders, such as Where("id IN (1, 2)"), is still a
// valid restriction and must not be rejected.
func (sql *SQLBuilder) parseDeleteSQL() string {
	if sql.where == "" {
		return ""
	}
	return strings.NewReplacer(
		"%TABLE%", sql.tableFormat(),
		"%ARGS%", sql.whereFormat(),
	).Replace(deleteSQL)
}
// Where adds a condition to the WHERE clause. Each condition is wrapped in
// parentheses and combined with earlier ones using AND. args are the bind
// values for the condition's "?" placeholders; they are appended, in order,
// to the builder's WHERE arguments.
//
// args is always copied into a slice owned by the builder. Storing the
// caller's slice directly, which Where(cond, vals...) passes through, would
// let later appends write into the caller's backing array.
func (sql *SQLBuilder) Where(condition string, args ...interface{}) {
	if sql.where == "" {
		sql.where = fmt.Sprintf("(%s)", condition)
	} else {
		sql.where = fmt.Sprintf("%s AND (%s)", sql.where, condition)
	}
	sql.whereArgs = append(sql.whereArgs, args...)
}
// Field appends column names to the column list used for SELECT and
// INSERT, separating them with ", ". Blank names, and names already present
// in the list, are skipped.
//
// Existing columns are compared by whole (trimmed) name. A substring search
// would wrongly treat "id" as present once "user_id" had been added.
func (sql *SQLBuilder) Field(fields ...string) {
	present := make(map[string]bool)
	if sql.field != "" {
		for _, existing := range strings.Split(sql.field, ",") {
			present[strings.TrimSpace(existing)] = true
		}
	}
	for _, field := range fields {
		name := strings.TrimSpace(field)
		if name == "" || present[name] {
			continue
		}
		present[name] = true
		if sql.field == "" {
			sql.field = name
		} else {
			sql.field += ", " + name
		}
	}
}
// BuildSQL builds a statement of kind t, stores it in lastSQL, prints it
// when preview mode is on, and returns it. data is only consulted for
// update, where it supplies the SET columns and values. For an unrecognised
// kind, lastSQL keeps its previous value, and that value is what gets
// printed and returned.
func (sql *SQLBuilder) BuildSQL(t SQLTYPE, data map[string]interface{}) string {
	var build func() string
	switch t {
	case query:
		build = sql.parseQuerySQL
	case insert:
		build = sql.parseInsertSQL
	case update:
		build = func() string { return sql.parseUpdateSQL(data) }
	case delete:
		build = sql.parseDeleteSQL
	}
	if build != nil {
		sql.lastSQL = build()
	}
	sql.printSQL()
	return sql.lastSQL
}
Go
1
https://gitee.com/ironCoffers/MysqlGo.git
git@gitee.com:ironCoffers/MysqlGo.git
ironCoffers
MysqlGo
MysqlGo
master

搜索帮助