// Hosting page notice (not part of the source): code pull complete, the page will refresh automatically.
// Copyright 2021-present The Atlas Authors. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package migrate
import (
"errors"
"fmt"
"io"
"regexp"
"strings"
"unicode"
"unicode/utf8"
)
// Stmt represents a scanned statement text along with its
// position in the file and associated comments group.
type Stmt struct {
	// Pos is the statement position. The scanner computes it as the number
	// of bytes scanned so far minus the length of the raw statement text.
	Pos int // statement position
	// Text is the statement text. The scanner trims surrounding spaces and,
	// when a custom delimiter is configured, the trailing delimiter.
	Text string // statement text
	// Comments holds the comments collected by the scanner right before
	// the statement. Directive looks up directives in this group.
	Comments []string // associated comments
}
// Directive returns the values of all directive comments with the given
// name that are attached to the statement, in the order they were found.
// See: pkg.go.dev/cmd/compile#hdr-Compiler_Directives.
func (s *Stmt) Directive(name string) (ds []string) {
	for _, comment := range s.Comments {
		// A block comment that fits on one line is parsed with its
		// closing marker removed and "/*" as the only accepted prefix.
		if strings.HasPrefix(comment, "/*") && !strings.Contains(comment, "\n") {
			body := strings.TrimSuffix(comment, "*/")
			if value, found := directive(body, name, "/*"); found {
				ds = append(ds, value)
			}
			continue
		}
		// Any other comment is tried against each line-comment prefix.
		for _, prefix := range []string{"#", "--", "-- "} {
			value, found := directive(comment, name, prefix)
			if !found {
				continue
			}
			ds = append(ds, value)
		}
	}
	return ds
}
// Stmts provides a generic implementation for extracting SQL statements
// from the given file contents. It scans the input with a Scanner that
// matches BEGIN ATOMIC blocks and dollar-quoted strings, but not plain
// BEGIN ... END blocks.
func Stmts(input string) ([]*Stmt, error) {
	// Default options for backward compatibility.
	opts := ScannerOptions{
		MatchBegin:       false,
		MatchBeginAtomic: true,
		MatchDollarQuote: true,
	}
	scanner := &Scanner{ScannerOptions: opts}
	return scanner.Scan(input)
}
// FileStmtDecls scans atlas-format file statements. The driver scanner is
// used only if the driver implements StmtScanner and the file is a
// *LocalFile; in any other case the file's own StmtDecls is returned.
func FileStmtDecls(drv Driver, f File) ([]*Stmt, error) {
	if scanner, ok := drv.(StmtScanner); ok {
		if _, local := f.(*LocalFile); local {
			return scanner.ScanStmts(string(f.Bytes()))
		}
	}
	return f.StmtDecls()
}
// FileStmts is like FileStmtDecls but returns only the
// text of each statement, dropping position and comments.
func FileStmts(drv Driver, f File) ([]string, error) {
	decls, err := FileStmtDecls(drv, f)
	if err != nil {
		return nil, err
	}
	texts := make([]string, 0, len(decls))
	for _, decl := range decls {
		texts = append(texts, decl.Text)
	}
	return texts, nil
}
type (
	// StmtScanner interface for scanning SQL statements from migration
	// and schema files and can be optionally implemented by drivers.
	// FileStmtDecls calls it when the driver implements it and the
	// given file is a *LocalFile.
	StmtScanner interface {
		ScanStmts(input string) ([]*Stmt, error)
	}
	// Scanner scanning SQL statements from migration and schema files.
	// Its state is reset by init on every call to Scan.
	Scanner struct {
		ScannerOptions
		// scanner state.
		// src keeps the text given to init (used for error positions), and
		// input is the text that is still being scanned; consumed prefixes
		// are cut from it when a statement or a leading comment is emitted.
		src, input string // src and current input text
		// pos is the byte offset of the scanner inside input.
		pos int // current phase position
		// total counts the bytes consumed since init, including skipped
		// spaces, and is used to compute Stmt.Pos.
		total int // total bytes scanned so far
		// width is the byte size of the last rune decoded by next.
		width int // size of latest rune
		// delim is the active statement delimiter. It defaults to ";" and
		// can be changed by a delimiter directive or a DELIMITER command.
		delim string // configured delimiter
		// comments holds the comments that precede the statement being
		// scanned; they are attached to it when it is emitted.
		comments []string // collected comments
	}
	// ScannerOptions controls the behavior of the scanner.
	ScannerOptions struct {
		// MatchBegin enables matching for BEGIN ... END statements block.
		// It is applied only while the default delimiter is in use.
		MatchBegin bool
		// MatchBeginAtomic enables matching for BEGIN ATOMIC ... END statements block.
		// It is applied only while the default delimiter is in use.
		MatchBeginAtomic bool
		// MatchDollarQuote enables the PostgreSQL dollar-quoted string syntax.
		MatchDollarQuote bool
	}
)
// Scan resets the scanner state and scans all statements in the given
// input. It stops with the collected statements once the input is
// exhausted, or with the first scanning error.
func (s *Scanner) Scan(input string) ([]*Stmt, error) {
	if err := s.init(input); err != nil {
		return nil, err
	}
	var stmts []*Stmt
	for {
		stmt, err := s.stmt()
		switch {
		case err == io.EOF:
			// io.EOF marks the end of the input, not a failure.
			return stmts, nil
		case err != nil:
			return nil, err
		}
		stmts = append(stmts, stmt)
	}
}
// init initializes the scanner state for the given input. If the input
// carries a delimiter directive, the delimiter is configured from it and
// scanning starts on the line that follows the directive.
func (s *Scanner) init(input string) error {
	s.comments = nil
	s.pos, s.total, s.width = 0, 0, 0
	s.src, s.input, s.delim = input, input, delimiter
	d, ok := directive(input, directiveDelimiter, directivePrefixSQL)
	if !ok {
		return nil
	}
	if err := s.setDelim(d); err != nil {
		return err
	}
	// Everything up to and including the first newline is dropped.
	_, rest, found := strings.Cut(input, "\n")
	if !found {
		return s.error(s.pos, "no input found after delimiter %q", d)
	}
	s.input = rest
	return nil
}
const (
	// eos is the sentinel returned by Scanner.next when the input is exhausted.
	eos = -1
	// delimiter is the default statement delimiter set by Scanner.init.
	delimiter = ";"
	// delimiterCmd is the name of the MySQL client DELIMITER command. It is
	// compared case-insensitively against the start of a statement.
	delimiterCmd = "delimiter"
)
var (
	// Dollar-quoted string as defined by the PostgreSQL scanner.
	// It matches the opening tag only, e.g. "$$" or "$tag$".
	reDollarQuote = regexp.MustCompile(`^\$([A-Za-zÈ-ÿ_][\wÈ-ÿ]*)*\$`)
	// The 'BEGIN ATOMIC' syntax as specified in the SQL 2003 standard.
	// Case-insensitive; allows leading spaces and requires a trailing one.
	reBeginAtomic = regexp.MustCompile(`(?i)^\s*BEGIN\s+ATOMIC\s+`)
	// reBegin matches a case-insensitive BEGIN keyword followed by at
	// least one space, optionally preceded by spaces.
	reBegin = regexp.MustCompile(`(?i)^\s*BEGIN\s+`)
	// reEnd matches a case-insensitive END keyword at the start of a
	// statement, including the spaces around it.
	reEnd = regexp.MustCompile(`(?i)^\s*END\s*`)
)
// stmt scans the next statement from the remaining input and emits it.
//
// It returns io.EOF when the input is exhausted and no statement text is
// left, and a positioned error for an unbalanced parenthesis or for an
// error reported while skipping a quoted or dollar-quoted string.
// A statement ends at the configured delimiter (outside of parentheses),
// at the end of a BEGIN [ATOMIC] ... END block when matching is enabled,
// or at the end of the input.
func (s *Scanner) stmt() (*Stmt, error) {
	var (
		// depth is the current nesting level of parentheses, and openingPos
		// is the position of the outermost '(' used for error reporting.
		depth, openingPos int
		text              string
	)
	s.skipSpaces()
Scan:
	for {
		switch r := s.next(); {
		case r == eos:
			switch {
			case depth > 0:
				return nil, s.error(openingPos, "unclosed '('")
			// The input ended without a delimiter; emit what was scanned.
			case s.pos > 0:
				text = s.input
				break Scan
			default:
				return nil, io.EOF
			}
		case r == '(':
			if depth == 0 {
				openingPos = s.pos
			}
			depth++
		case r == ')':
			if depth == 0 {
				return nil, s.error(s.pos, "unexpected ')'")
			}
			depth--
		case r == '\'', r == '"', r == '`':
			if err := s.skipQuote(r); err != nil {
				return nil, err
			}
		// Check if the start of the statement is the MySQL DELIMITER command.
		// See https://dev.mysql.com/doc/refman/8.0/en/mysql-commands.html.
		case s.pos == 1 && len(s.input) > len(delimiterCmd) && strings.EqualFold(s.input[:len(delimiterCmd)], delimiterCmd):
			s.addPos(len(delimiterCmd) - 1)
			if err := s.delimCmd(); err != nil {
				return nil, err
			}
			s.skipSpaces()
		// Delimiters take precedence over comments.
		case depth == 0 && strings.HasPrefix(s.input[s.pos-s.width:], s.delim):
			s.addPos(len(s.delim) - s.width)
			text = s.input[:s.pos]
			break Scan
		case s.MatchDollarQuote && r == '$' && reDollarQuote.MatchString(s.input[s.pos-1:]):
			if err := s.skipDollarQuote(); err != nil {
				return nil, err
			}
		// Skip non-standard MySQL comments if they are inside
		// expressions until we make the lexer driver-aware.
		case depth == 0 && r == '#':
			s.comment("#", "\n")
		// The second character of a comment opener is inspected without
		// being consumed. Consuming it unconditionally would swallow the
		// character that follows a lone '-' or '/', e.g. the '(' in "1/(2)"
		// or the opening quote in "-'a'", and corrupt the scanning state.
		case r == '-' && strings.HasPrefix(s.input[s.pos:], "-"):
			s.next()
			s.comment("--", "\n")
		case r == '/' && strings.HasPrefix(s.input[s.pos:], "*"):
			s.next()
			s.comment("/*", "*/")
		case s.delim == delimiter && s.MatchBeginAtomic && reBeginAtomic.MatchString(s.input[s.pos-1:]):
			if err := s.skipBeginAtomic(); err == nil {
				text = s.input[:s.pos]
				break Scan
			}
			// Not a "BEGIN ATOMIC" block.
		case s.delim == delimiter && s.MatchBegin &&
			// Either the current scanned statement starts with BEGIN, or we inside a statement and expects at least one ~space before).
			(s.pos == 1 && reBegin.MatchString(s.input[s.pos-1:]) || s.pos > 1 && reBegin.MatchString(s.input[s.pos-2:])):
			if err := s.skipBegin(); err == nil {
				text = s.input[:s.pos]
				break Scan
			}
			// Not a "BEGIN" block.
		}
	}
	return s.emit(text), nil
}
// next decodes the rune at the current position, advances the scanner
// past it and returns it. It returns eos, without changing the scanner
// state, when the input is exhausted.
func (s *Scanner) next() rune {
	if len(s.input) <= s.pos {
		return eos
	}
	r, size := utf8.DecodeRuneInString(s.input[s.pos:])
	// Remember the rune size so callers can step back to its first byte.
	s.width = size
	s.pos += size
	s.total += size
	return r
}
// pick returns the next rune without consuming it, or eos if the input
// is exhausted. The whole cursor state is restored after the lookahead:
// next advances pos, total and width together, so total must be rolled
// back as well. Otherwise every peek would inflate the number of scanned
// bytes and shift the position of the statements that follow.
func (s *Scanner) pick() rune {
	pos, total, width := s.pos, s.total, s.width
	r := s.next()
	s.pos, s.total, s.width = pos, total, width
	return r
}
// addPos moves the scanner p bytes forward, keeping the position in the
// current input and the total number of scanned bytes in sync.
func (s *Scanner) addPos(p int) {
	s.pos, s.total = s.pos+p, s.total+p
}
// skipQuote consumes the input up to and including the closing quote
// rune. The rune following a backslash is skipped and never treated as
// the closing quote. An error positioned at the opening quote is
// returned if the input ends first.
func (s *Scanner) skipQuote(quote rune) error {
	start := s.pos
	for {
		switch s.next() {
		case eos:
			return s.error(start, "unclosed quote %q", quote)
		case '\\':
			// Escaped character: consume it blindly.
			s.next()
		case quote:
			return nil
		}
	}
}
// skipDollarQuote consumes a dollar-quoted string. It expects the scanner
// to be positioned right after the first '$' of the opening tag, and stops
// after the closing tag, which must be identical to the opening one.
func (s *Scanner) skipDollarQuote() error {
	tag := reDollarQuote.FindString(s.input[s.pos-1:])
	if tag == "" {
		return s.error(s.pos, "unexpected dollar quote")
	}
	// The first '$' was already consumed by the caller.
	s.addPos(len(tag) - 1)
	for r := s.next(); ; r = s.next() {
		if r == eos {
			// Fail only if a delimiter was not set.
			// NOTE(review): init and setDelim never leave the delimiter
			// empty, so this error appears to be unreachable.
			if s.delim == "" {
				return s.error(s.pos, "unclosed dollar-quoted string")
			}
			return nil
		}
		if r == '$' && strings.HasPrefix(s.input[s.pos-1:], tag) {
			s.addPos(len(tag) - 1)
			return nil
		}
	}
}
// skipBeginAtomic consumes a BEGIN ATOMIC ... END block. The body is
// scanned statement by statement with a nested scanner that shares the
// options of s, until a statement that starts with END is found. The
// scanner is then advanced by the number of bytes the nested one consumed.
func (s *Scanner) skipBeginAtomic() error {
	prefix := reBeginAtomic.FindString(s.input[s.pos-1:])
	if prefix == "" {
		return s.error(s.pos, "unexpected missing BEGIN ATOMIC block")
	}
	// The first character of the match was already consumed by the caller.
	s.addPos(len(prefix) - 1)
	inner := &Scanner{ScannerOptions: s.ScannerOptions}
	if err := inner.init(s.input[s.pos:]); err != nil {
		return err
	}
	for closed := false; !closed; {
		stmt, err := inner.stmt()
		switch {
		case err == io.EOF:
			return s.error(s.pos, "unexpected eof when scanning sql body")
		case err != nil:
			return s.error(s.pos, "scan sql body: %v", err)
		}
		closed = reEnd.MatchString(stmt.Text)
	}
	s.addPos(inner.total)
	return nil
}
// skipBegin consumes a BEGIN ... END compound statement. The body is
// scanned with a nested scanner that shares the options of s, until a
// statement that consists only of END, optionally followed by the
// delimiter, is found. The scanner is then advanced by the number of
// bytes the nested one consumed.
func (s *Scanner) skipBegin() error {
	prefix := reBegin.FindString(s.input[s.pos-1:])
	if prefix == "" {
		return s.error(s.pos, "unexpected missing BEGIN block")
	}
	// The first character of the match was already consumed by the caller.
	s.addPos(len(prefix) - 1)
	nested := &Scanner{ScannerOptions: s.ScannerOptions}
	if err := nested.init(s.input[s.pos:]); err != nil {
		return err
	}
	for open := 1; open > 0; {
		stmt, err := nested.stmt()
		if err == io.EOF {
			return s.error(s.pos, "unexpected eof when scanning compound statements")
		}
		if err != nil {
			return s.error(s.pos, "scan compound statements: %v", err)
		}
		// A match of reEnd is never empty, as it contains the END keyword.
		end := reEnd.FindString(stmt.Text)
		if end == "" {
			continue
		}
		// Statements that merely start with END do not close the block.
		if len(end) == len(stmt.Text) || strings.TrimPrefix(stmt.Text, end) == s.delim {
			open--
		}
	}
	s.addPos(nested.total)
	return nil
}
// comment consumes a comment that was opened by left and is closed by
// right. The opening marker is expected to be consumed already. If the
// closing marker is missing, the text is not treated as a comment and
// nothing is consumed.
func (s *Scanner) comment(left, right string) {
	end := strings.Index(s.input[s.pos:], right)
	// Not a comment.
	if end < 0 {
		return
	}
	// Only the opening marker was scanned so far, i.e. the
	// comment comes before any statement character.
	leading := s.pos == len(left)
	s.addPos(end + len(right))
	// A comment that resides inside a statement stays part of its text.
	if !leading {
		return
	}
	// A leading comment is cut from the input and stored in the comments group.
	s.comments = append(s.comments, s.input[:s.pos])
	s.input, s.pos = s.input[s.pos:], 0
	// Double \n separate the comments group from the statement. For
	// line comments, the first \n was consumed as the closing marker.
	switch {
	case strings.HasPrefix(s.input, "\n\n"),
		right == "\n" && strings.HasPrefix(s.input, "\n"):
		s.comments = nil
	}
	s.skipSpaces()
}
// skipSpaces drops leading whitespace from the current input and adds
// the number of dropped bytes to the total scanned bytes.
func (s *Scanner) skipSpaces() {
	trimmed := strings.TrimLeftFunc(s.input, unicode.IsSpace)
	s.total += len(s.input) - len(trimmed)
	s.input = trimmed
}
// emit builds a statement from the given raw text and the collected
// comments, then drops the consumed input and resets the position and
// the comments group for the next statement.
func (s *Scanner) emit(text string) *Stmt {
	// The position is computed from the raw text, before any trimming.
	pos, comments := s.total-len(text), s.comments
	s.input, s.pos, s.comments = s.input[s.pos:], 0, nil
	// Trim custom delimiter.
	if s.delim != delimiter {
		text = strings.TrimSuffix(text, s.delim)
	}
	return &Stmt{
		Pos:      pos,
		Text:     strings.TrimSpace(text),
		Comments: comments,
	}
}
// delimCmd checks if the scanned "DELIMITER"
// text represents an actual delimiter command.
//
// If it does, the rest of the line is consumed, the delimiter it defines
// is configured, and the command text is dropped from the input so that
// it is not reported as a statement. An error is returned if the command
// defines an empty delimiter.
func (s *Scanner) delimCmd() error {
	// A space must come after the delimiter.
	if s.pick() != ' ' {
		return nil
	}
	// Scan delimiter.
	for r := s.pick(); r != eos && r != '\n'; r = s.next() {
	}
	delim := strings.TrimSpace(s.input[len(delimiterCmd):s.pos])
	// MySQL client allows quoting delimiters. The length check makes sure
	// there are two distinct quotes: a delimiter that is a lone quote has
	// it as both prefix and suffix, and unquoting it would slice out of range.
	if len(delim) >= 2 && strings.HasPrefix(delim, "'") && strings.HasSuffix(delim, "'") {
		delim = strings.ReplaceAll(delim[1:len(delim)-1], "''", "'")
	}
	if err := s.setDelim(delim); err != nil {
		return err
	}
	// Skip all we saw until now.
	s.emit(s.input[:s.pos])
	return nil
}
// setDelim configures the statement delimiter. An empty delimiter is
// rejected with an error. The escape sequences \n, \r and \t in the
// given text are replaced with the characters they represent.
func (s *Scanner) setDelim(d string) error {
	if len(d) == 0 {
		return errors.New("empty delimiter")
	}
	// Unescape delimiters. e.g. "\\n" => "\n".
	unescape := strings.NewReplacer(`\n`, "\n", `\r`, "\r", `\t`, "\t")
	s.delim = unescape.Replace(d)
	return nil
}
// error returns a formatted error prefixed with "line:column: ". The
// location is derived from pos, which is relative to the current input,
// translated to an offset in the source text given to init. Lines are
// counted from 1; the column is the byte offset from the line start.
func (s *Scanner) error(pos int, format string, args ...any) error {
	offset := len(s.src) - len(s.input) + pos
	scanned := s.src[:offset]
	line := strings.Count(scanned, "\n") + 1
	// LastIndex returns -1 on the first line, which makes
	// the column equal to the offset itself.
	col := offset - strings.LastIndex(scanned, "\n") - 1
	params := make([]any, 0, len(args)+2)
	params = append(params, line, col)
	params = append(params, args...)
	return fmt.Errorf("%d:%d: "+format, params...)
}
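// Hosting page notice (not part of the source): content here may be unsuitable for display and is hidden by the page. You can review and modify it with the relevant editing functions.
// If you confirm the content contains no inappropriate language, pure advertising, violence, vulgar or pornographic material, infringement, piracy, false or worthless content, or anything violating national laws and regulations, you can click submit to appeal and it will be handled as soon as possible.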