1 Star 0 Fork 0

DaMeng/Atlas

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
driver.go 14.98 KB
一键复制 编辑 原始数据 按行查看 历史
DaMeng 提交于 2024-10-24 15:32 +08:00 . :art:修改mod名称
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537
// Copyright 2021-present The Atlas Authors. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package postgres
import (
"context"
"database/sql"
"errors"
"fmt"
"hash/fnv"
"net/url"
"strconv"
"time"
"gitee.com/damengde/atlas/sql/internal/sqlx"
"gitee.com/damengde/atlas/sql/migrate"
"gitee.com/damengde/atlas/sql/schema"
"gitee.com/damengde/atlas/sql/sqlclient"
)
// Driver represents a PostgreSQL driver for introspecting database schemas,
// generating diff between schema elements and apply migrations changes.
// It embeds the connection state together with the differ, inspector and
// plan applier built for that connection in Open.
type Driver struct {
	*conn
	schema.Differ
	schema.Inspector
	migrate.PlanApplier
}

// conn wraps the database connection together with the server details
// detected when the driver is opened.
type conn struct {
	schema.ExecQuerier
	// schema is the schema named in the `search_path` URL parameter, if any.
	// An empty value means the connection is bound to the whole realm.
	schema string
	// version is the numeric server version parsed in Open; it is compared
	// against thresholds written like 10_00_00 (PostgreSQL 10).
	version int
	// crdb reports whether the server was detected as CockroachDB in Open.
	crdb bool
}
// Compile-time assertions that *Driver provides the optional migrate
// capabilities (snapshots, statement scanning, clean checks) and the
// schema type parsing/formatting interface.
var (
	_ migrate.Snapshoter        = (*Driver)(nil)
	_ migrate.StmtScanner       = (*Driver)(nil)
	_ migrate.CleanChecker      = (*Driver)(nil)
	_ schema.TypeParseFormatter = (*Driver)(nil)
)
// DriverName holds the name used for registration. It is also the
// database/sql driver name that opener passes to sql.Open.
const DriverName = "postgres"

// init registers this driver with sqlclient under DriverName, along with
// the "postgresql" flavour, the driver-level Open function, the HCL codec
// (MarshalHCL/EvalHCL) and the URL parser used for `search_path` handling.
func init() {
	sqlclient.Register(
		DriverName,
		sqlclient.OpenerFunc(opener),
		sqlclient.RegisterDriverOpener(Open),
		sqlclient.RegisterFlavours("postgresql"),
		sqlclient.RegisterCodec(MarshalHCL, EvalHCL),
		sqlclient.RegisterURLParser(parser{}),
	)
}
// opener implements sqlclient.Opener. It opens a database/sql handle for the
// URL's DSN, wraps it with Open and binds the resulting driver to the schema
// named in the `search_path` parameter (empty when absent). If Open fails,
// the handle is closed and any close error is appended to the returned error.
func opener(_ context.Context, u *url.URL) (*sqlclient.Client, error) {
	parsed := parser{}.ParseURL(u)
	db, err := sql.Open(DriverName, parsed.DSN)
	if err != nil {
		return nil, err
	}
	drv, openErr := Open(db)
	if openErr != nil {
		if closeErr := db.Close(); closeErr != nil {
			openErr = fmt.Errorf("%w: %v", openErr, closeErr)
		}
		return nil, openErr
	}
	// Open returns either a *Driver or, for CockroachDB, a noLockDriver
	// wrapping one; locate the underlying *Driver to bind the schema.
	var target *Driver
	switch v := drv.(type) {
	case *Driver:
		target = v
	case noLockDriver:
		target = v.noLocker.(*Driver)
	}
	if target != nil {
		target.schema = parsed.Schema
	}
	client := &sqlclient.Client{
		Name:   DriverName,
		DB:     db,
		URL:    parsed,
		Driver: drv,
	}
	return client, nil
}
// Open opens a new PostgreSQL driver.
//
// It runs paramsQuery and expects one or two rows back. The first row is the
// numeric server version, which must be at least 10_00_00 (PostgreSQL 10).
// A second row means the server is CockroachDB (crdb_version). In that case
// the driver uses the crdb-specific differ and inspector and is wrapped in
// noLockDriver.
func Open(db schema.ExecQuerier) (migrate.Driver, error) {
	rows, err := db.QueryContext(context.Background(), paramsQuery)
	if err != nil {
		return nil, fmt.Errorf("postgres: scanning system variables: %w", err)
	}
	params, err := sqlx.ScanStrings(rows)
	if err != nil {
		return nil, fmt.Errorf("postgres: failed scanning rows: %w", err)
	}
	switch n := len(params); n {
	case 1, 2:
	default:
		return nil, fmt.Errorf("postgres: unexpected number of rows: %d", n)
	}
	version, err := strconv.Atoi(params[0])
	if err != nil {
		return nil, fmt.Errorf("postgres: malformed version: %s: %w", params[0], err)
	}
	if version < 10_00_00 {
		return nil, fmt.Errorf("postgres: unsupported postgres version: %d", version)
	}
	c := &conn{
		ExecQuerier: db,
		version:     version,
		crdb:        len(params) == 2,
	}
	if !c.crdb {
		return &Driver{
			conn:        c,
			Differ:      &sqlx.Diff{DiffDriver: &diff{c}},
			Inspector:   &inspect{c},
			PlanApplier: &planApply{c},
		}, nil
	}
	crdb := &Driver{
		conn:        c,
		Differ:      &sqlx.Diff{DiffDriver: &crdbDiff{diff{c}}},
		Inspector:   &crdbInspect{inspect{c}},
		PlanApplier: &planApply{c},
	}
	return noLockDriver{crdb}, nil
}
// dev returns a sqlx.DevDriver backed by d. Its PatchObject hook sets the
// Schema field of enum objects to the schema they are being patched into.
func (d *Driver) dev() *sqlx.DevDriver {
	patch := func(s *schema.Schema, o schema.Object) {
		enum, ok := o.(*schema.EnumType)
		if !ok {
			return
		}
		enum.Schema = s
	}
	return &sqlx.DevDriver{Driver: d, PatchObject: patch}
}
// NormalizeRealm returns the normal representation of the given realm,
// as computed by the dev-database driver built by d.dev().
func (d *Driver) NormalizeRealm(ctx context.Context, r *schema.Realm) (*schema.Realm, error) {
	dev := d.dev()
	return dev.NormalizeRealm(ctx, r)
}
// NormalizeSchema returns the normal representation of the given schema,
// as computed by the dev-database driver built by d.dev().
func (d *Driver) NormalizeSchema(ctx context.Context, s *schema.Schema) (*schema.Schema, error) {
	dev := d.dev()
	return dev.NormalizeSchema(ctx, s)
}
// Lock implements the schema.Locker interface using PostgreSQL advisory
// locks. The lock key is the 32-bit FNV-1 hash of name. The lock is taken
// (see acquire) on the connection obtained from sqlx.SingleConn, and that
// connection is kept until the returned UnlockFunc releases the lock with
// pg_advisory_unlock.
//
// NOTE(review): the UnlockFunc reuses ctx. If ctx is canceled before unlock
// is called, the unlock query fails; confirm callers keep ctx alive.
func (d *Driver) Lock(ctx context.Context, name string, timeout time.Duration) (schema.UnlockFunc, error) {
	c, err := sqlx.SingleConn(ctx, d.ExecQuerier)
	if err != nil {
		return nil, err
	}
	hasher := fnv.New32()
	hasher.Write([]byte(name))
	key := hasher.Sum32()
	if err := acquire(ctx, c, key, timeout); err != nil {
		c.Close()
		return nil, err
	}
	unlock := func() error {
		defer c.Close()
		rows, err := c.QueryContext(ctx, "SELECT pg_advisory_unlock($1)", key)
		if err != nil {
			return err
		}
		released, err := sqlx.ScanNullBool(rows)
		if err != nil {
			return err
		}
		if released.Valid && released.Bool {
			return nil
		}
		return fmt.Errorf("sql/postgres: failed releasing lock %d", key)
	}
	return unlock, nil
}
// Snapshot implements migrate.Snapshoter.
//
// If the connection is bound to a schema (via `search_path`), that schema
// must have no tables. Otherwise the whole realm is inspected, and it is
// clean only when it has no schemas, or when its single schema is "public"
// and has no tables. When the state is clean, Snapshot returns a function
// that restores it; otherwise it returns a *migrate.NotCleanError.
func (d *Driver) Snapshot(ctx context.Context) (migrate.RestoreFunc, error) {
	if name := d.schema; name != "" {
		s, err := d.InspectSchema(ctx, name, nil)
		switch {
		case err != nil:
			return nil, err
		case len(s.Tables) > 0:
			return nil, &migrate.NotCleanError{
				State:  schema.NewRealm(s),
				Reason: fmt.Sprintf("found table %q in connected schema", s.Tables[0].Name),
			}
		default:
			return d.SchemaRestoreFunc(s), nil
		}
	}
	// Bound to the realm rather than a single schema.
	realm, err := d.InspectRealm(ctx, nil)
	if err != nil {
		return nil, err
	}
	restore := d.RealmRestoreFunc(realm)
	if len(realm.Schemas) == 0 {
		return restore, nil
	}
	public, hasPublic := realm.Schema("public")
	if !hasPublic || len(realm.Schemas) != 1 {
		return nil, &migrate.NotCleanError{
			State:  realm,
			Reason: fmt.Sprintf("found schema %q", realm.Schemas[0].Name),
		}
	}
	if len(public.Tables) > 0 {
		return nil, &migrate.NotCleanError{
			State:  realm,
			Reason: fmt.Sprintf("found table %q in schema %q", public.Tables[0].Name, public.Name),
		}
	}
	return restore, nil
}
// SchemaRestoreFunc returns a function that restores the given schema to its
// desired state. When called, it re-inspects the schema named desired.Name,
// diffs it against desired, and applies the changes after passing them
// through withCascade.
func (d *Driver) SchemaRestoreFunc(desired *schema.Schema) migrate.RestoreFunc {
	restore := func(ctx context.Context) error {
		live, err := d.InspectSchema(ctx, desired.Name, nil)
		if err != nil {
			return err
		}
		changeSet, err := d.SchemaDiff(live, desired)
		if err != nil {
			return err
		}
		cascaded := withCascade(changeSet)
		return d.ApplyChanges(ctx, cascaded)
	}
	return restore
}
// RealmRestoreFunc returns a function that restores the given realm to its
// desired state. When called, it re-inspects the whole realm, diffs it
// against desired, and applies the changes after passing them through
// withCascade.
func (d *Driver) RealmRestoreFunc(desired *schema.Realm) migrate.RestoreFunc {
	restore := func(ctx context.Context) error {
		live, err := d.InspectRealm(ctx, nil)
		if err != nil {
			return err
		}
		changeSet, err := d.RealmDiff(live, desired)
		if err != nil {
			return err
		}
		cascaded := withCascade(changeSet)
		return d.ApplyChanges(ctx, cascaded)
	}
	return restore
}
// withCascade appends an IF EXISTS clause and a CASCADE clause to the Extra
// list of every DropTable, DropView, DropProc, DropFunc and DropObject change.
// Other changes are left alone. The changes are modified in place and the
// same slice is returned.
func withCascade(changes schema.Changes) schema.Changes {
	// Each change receives its own fresh clause values.
	clauses := func() []schema.Clause {
		return []schema.Clause{&schema.IfExists{}, &Cascade{}}
	}
	for i := range changes {
		switch change := changes[i].(type) {
		case *schema.DropTable:
			change.Extra = append(change.Extra, clauses()...)
		case *schema.DropView:
			change.Extra = append(change.Extra, clauses()...)
		case *schema.DropProc:
			change.Extra = append(change.Extra, clauses()...)
		case *schema.DropFunc:
			change.Extra = append(change.Extra, clauses()...)
		case *schema.DropObject:
			change.Extra = append(change.Extra, clauses()...)
		}
	}
	return changes
}
// CheckClean implements migrate.CleanChecker.
//
// revT identifies the only table allowed to exist (presumably the migration
// revision table). A nil revT is treated as an empty migrate.TableIdent.
//
// If the connection is bound to a schema (d.schema, from `search_path`), that
// schema is clean when it has no tables, or when its only table is revT.Name.
// The schema name is matched against revT.Schema only when revT.Schema is set.
//
// Otherwise the whole realm is inspected. Every schema must be either an
// empty "public" schema, or the schema named revT.Schema holding exactly one
// table, named revT.Name. The first schema that breaks this rule is reported
// in a *migrate.NotCleanError.
func (d *Driver) CheckClean(ctx context.Context, revT *migrate.TableIdent) error {
	if revT == nil { // accept nil values
		revT = &migrate.TableIdent{}
	}
	if d.schema != "" {
		switch s, err := d.InspectSchema(ctx, d.schema, nil); {
		case err != nil:
			return err
		case len(s.Tables) == 0, (revT.Schema == "" || s.Name == revT.Schema) && len(s.Tables) == 1 && s.Tables[0].Name == revT.Name:
			return nil
		default:
			// NOTE(review): with several tables, the reason always names the
			// first one, even when that table is revT itself.
			return &migrate.NotCleanError{State: schema.NewRealm(s), Reason: fmt.Sprintf("found table %q in schema %q", s.Tables[0].Name, s.Name)}
		}
	}
	r, err := d.InspectRealm(ctx, nil)
	if err != nil {
		return err
	}
	for _, s := range r.Schemas {
		// Case order matters: each case assumes the earlier ones did not match.
		switch {
		// An empty "public" schema is allowed.
		case len(s.Tables) == 0 && s.Name == "public":
		// Any other empty schema, or any schema other than revT.Schema, is not clean.
		case len(s.Tables) == 0 || s.Name != revT.Schema:
			return &migrate.NotCleanError{State: r, Reason: fmt.Sprintf("found schema %q", s.Name)}
		// revT.Schema may hold only a single table...
		case len(s.Tables) > 1:
			return &migrate.NotCleanError{State: r, Reason: fmt.Sprintf("found %d tables in schema %q", len(s.Tables), s.Name)}
		// ...and that table must be revT.Name.
		case len(s.Tables) == 1 && s.Tables[0].Name != revT.Name:
			return &migrate.NotCleanError{State: r, Reason: fmt.Sprintf("found table %q in schema %q", s.Tables[0].Name, s.Name)}
		}
	}
	return nil
}
// Version returns the version of the connected database, as detected in
// Open, formatted as a base-10 integer string.
func (d *Driver) Version() string {
	return strconv.FormatInt(int64(d.version), 10)
}
// FormatType converts schema type to its column form in the database by
// delegating to the package-level FormatType function.
func (*Driver) FormatType(t schema.Type) (string, error) {
	formatted, err := FormatType(t)
	return formatted, err
}
// ParseType returns the schema.Type value represented by the given string,
// by delegating to the package-level ParseType function.
func (*Driver) ParseType(s string) (schema.Type, error) {
	parsed, err := ParseType(s)
	return parsed, err
}
// StmtBuilder is a helper method used to build statements with PostgreSQL
// formatting. Identifiers are quoted with double quotes, and the schema
// qualifier and indentation come from opts.
func (*Driver) StmtBuilder(opts migrate.PlanOptions) *sqlx.Builder {
	b := new(sqlx.Builder)
	b.QuoteOpening, b.QuoteClosing = '"', '"'
	b.Schema = opts.SchemaQualifier
	b.Indent = opts.Indent
	return b
}
// ScanStmts implements migrate.StmtScanner. It splits input with a
// migrate.Scanner that has BEGIN-block, BEGIN ATOMIC and dollar-quote
// matching enabled.
func (*Driver) ScanStmts(input string) ([]*migrate.Stmt, error) {
	var opts migrate.ScannerOptions
	opts.MatchBegin = true
	opts.MatchBeginAtomic = true
	opts.MatchDollarQuote = true
	scanner := &migrate.Scanner{ScannerOptions: opts}
	return scanner.Scan(input)
}
// acquire takes the advisory lock identified by id on conn.
//
// The timeout selects the strategy:
//   - timeout > 0: block in pg_advisory_lock, but give up once timeout elapses.
//   - timeout < 0: block in pg_advisory_lock until the lock is taken or ctx ends.
//   - timeout == 0: try once with pg_try_advisory_lock and do not wait.
//
// It returns schema.ErrLocked when the lock was not obtained. That happens
// when the try returned false or NULL, or when the blocking call was
// interrupted because ctx (or the derived timeout context) ended.
func acquire(ctx context.Context, conn schema.ExecQuerier, id uint32, timeout time.Duration) error {
	switch {
	// With timeout (context-based).
	case timeout > 0:
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, timeout)
		defer cancel()
		fallthrough
	// Infinite timeout.
	case timeout < 0:
		rows, err := conn.QueryContext(ctx, "SELECT pg_advisory_lock($1)", id)
		if err != nil {
			// Drivers do not agree on how an interrupted query is reported.
			// Some wrap the context error; others (e.g. lib/pq) return the
			// server's "canceling statement" error instead. Consult ctx too, so
			// a timed-out or canceled wait is always reported as ErrLocked.
			if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) || ctx.Err() != nil {
				return schema.ErrLocked
			}
			return err
		}
		return rows.Close()
	// No timeout.
	default:
		rows, err := conn.QueryContext(ctx, "SELECT pg_try_advisory_lock($1)", id)
		if err != nil {
			return err
		}
		acquired, err := sqlx.ScanNullBool(rows)
		if err != nil {
			return err
		}
		// A NULL result is treated like false: the lock was not taken.
		if !acquired.Bool {
			return schema.ErrLocked
		}
		return nil
	}
}
// supportsIndexInclude reports if the server supports the INCLUDE clause,
// i.e. whether the detected version is at least 11_00_00 (PostgreSQL 11).
func (c *conn) supportsIndexInclude() bool {
	const minVersion = 11_00_00
	return !(c.version < minVersion)
}
// supportsIndexNullsDistinct reports if the server supports the
// NULLS [NOT] DISTINCT clause, i.e. whether the detected version is at
// least 15_00_00 (PostgreSQL 15).
func (c *conn) supportsIndexNullsDistinct() bool {
	const minVersion = 15_00_00
	return !(c.version < minVersion)
}
// parser handles PostgreSQL URLs for sqlclient. The schema a connection is
// bound to is carried in the `search_path` query parameter.
type parser struct{}

// ParseURL implements the sqlclient.URLParser interface. The DSN is the full
// URL string, and the schema is the value of `search_path` (empty if absent).
func (parser) ParseURL(u *url.URL) *sqlclient.URL {
	return &sqlclient.URL{
		URL:    u,
		DSN:    u.String(),
		Schema: u.Query().Get("search_path"),
	}
}

// ChangeSchema implements the sqlclient.SchemaChanger interface. It returns
// a shallow copy of u whose `search_path` parameter is set to s; u itself is
// not modified. The copy's query string is re-encoded with keys sorted.
func (parser) ChangeSchema(u *url.URL, s string) *url.URL {
	values := u.Query()
	values.Set("search_path", s)
	clone := *u
	clone.RawQuery = values.Encode()
	return &clone
}
// Standard column types (and their aliases) as defined in the PostgreSQL
// codebase and documentation. The unexported names at the end are internal,
// object-identifier and pseudo types.
const (
	// Bit-string and boolean types.
	TypeBit = "bit"
	TypeBitVar = "bit varying"
	TypeBoolean = "boolean"
	TypeBool = "bool" // alias of boolean.
	// Binary and character types.
	TypeBytea = "bytea"
	TypeCharacter = "character"
	TypeChar = "char" // alias of character.
	TypeCharVar = "character varying"
	TypeVarChar = "varchar" // alias of character varying.
	TypeText = "text"
	typeName = "name" // internal type for object names.
	// Integer types.
	TypeSmallInt = "smallint"
	TypeInteger = "integer"
	TypeBigInt = "bigint"
	TypeInt = "int" // alias of integer.
	TypeInt2 = "int2" // alias of smallint.
	TypeInt4 = "int4" // alias of integer.
	TypeInt8 = "int8" // alias of bigint.
	// Network address types.
	TypeCIDR = "cidr"
	TypeInet = "inet"
	TypeMACAddr = "macaddr"
	TypeMACAddr8 = "macaddr8"
	// Geometric types.
	TypeCircle = "circle"
	TypeLine = "line"
	TypeLseg = "lseg"
	TypeBox = "box"
	TypePath = "path"
	TypePolygon = "polygon"
	TypePoint = "point"
	// Date and time types.
	TypeDate = "date"
	TypeTime = "time" // time without time zone.
	TypeTimeTZ = "timetz" // time with time zone.
	TypeTimeWTZ = "time with time zone"
	TypeTimeWOTZ = "time without time zone"
	TypeTimestamp = "timestamp" // timestamp without time zone.
	TypeTimestampTZ = "timestamptz" // timestamp with time zone.
	TypeTimestampWTZ = "timestamp with time zone"
	TypeTimestampWOTZ = "timestamp without time zone"
	// Floating-point and arbitrary-precision numeric types.
	TypeDouble = "double precision"
	TypeReal = "real"
	TypeFloat8 = "float8" // alias of double precision.
	TypeFloat4 = "float4" // alias of real.
	TypeFloat = "float" // float(p).
	TypeNumeric = "numeric"
	TypeDecimal = "decimal" // alias of numeric.
	// Serial (auto-incrementing integer) types.
	TypeSmallSerial = "smallserial" // smallint with auto_increment.
	TypeSerial = "serial" // integer with auto_increment.
	TypeBigSerial = "bigserial" // bigint with auto_increment.
	TypeSerial2 = "serial2" // alias of smallserial.
	TypeSerial4 = "serial4" // alias of serial.
	TypeSerial8 = "serial8" // alias of bigserial.
	// Other built-in types.
	TypeArray = "array"
	TypeXML = "xml"
	TypeJSON = "json"
	TypeJSONB = "jsonb"
	TypeUUID = "uuid"
	TypeMoney = "money"
	TypeInterval = "interval"
	TypeTSQuery = "tsquery"
	TypeTSVector = "tsvector"
	TypeUserDefined = "user-defined"
	// Range and multirange types.
	TypeInt4Range = "int4range"
	TypeInt4MultiRange = "int4multirange"
	TypeInt8Range = "int8range"
	TypeInt8MultiRange = "int8multirange"
	TypeNumRange = "numrange"
	TypeNumMultiRange = "nummultirange"
	TypeTSRange = "tsrange"
	TypeTSMultiRange = "tsmultirange"
	TypeTSTZRange = "tstzrange"
	TypeTSTZMultiRange = "tstzmultirange"
	TypeDateRange = "daterange"
	TypeDateMultiRange = "datemultirange"
	// PostgreSQL internal object identifier type (oid) and its reg* alias types.
	typeOID = "oid"
	typeRegClass = "regclass"
	typeRegCollation = "regcollation"
	typeRegConfig = "regconfig"
	typeRegDictionary = "regdictionary"
	typeRegNamespace = "regnamespace"
	typeRegOper = "regoper"
	typeRegOperator = "regoperator"
	typeRegProc = "regproc"
	typeRegProcedure = "regprocedure"
	typeRegRole = "regrole"
	typeRegType = "regtype"
	// Supported PostgreSQL pseudo-types.
	typeAny = "any"
	typeAnyElement = "anyelement"
	typeAnyArray = "anyarray"
	typeAnyNonArray = "anynonarray"
	typeAnyEnum = "anyenum"
	typeInternal = "internal"
	typeRecord = "record"
	typeTrigger = "trigger"
	typeVoid = "void"
	typeUnknown = "unknown"
)
// List of supported index types.
const (
	IndexTypeBTree  = "BTREE"
	IndexTypeBRIN   = "BRIN"
	IndexTypeHash   = "HASH"
	IndexTypeGIN    = "GIN"
	IndexTypeGiST   = "GIST"
	IndexTypeSPGiST = "SPGIST"
)

// defaultPagePerRange is an untyped integer constant equal to 128.
// NOTE(review): presumably the default BRIN pages_per_range storage
// parameter; confirm against its uses.
const defaultPagePerRange = 128
// GeneratedTypeAlways is the "GENERATED ALWAYS" kind.
const GeneratedTypeAlways = "ALWAYS"

// GeneratedTypeByDefault is the "GENERATED BY DEFAULT" kind. Its value uses
// an underscore in place of the space between BY and DEFAULT.
const GeneratedTypeByDefault = "BY_DEFAULT"
// PartitionTypeRange is the RANGE strategy of a PARTITION KEY.
const PartitionTypeRange = "RANGE"

// PartitionTypeList is the LIST strategy of a PARTITION KEY.
const PartitionTypeList = "LIST"

// PartitionTypeHash is the HASH strategy of a PARTITION KEY.
const PartitionTypeHash = "HASH"
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/damengde/atlas.git
git@gitee.com:damengde/atlas.git
damengde
atlas
Atlas
v0.0.2

搜索帮助