From 19c54758409331e372bbc63fbd09f5cc65e73627 Mon Sep 17 00:00:00 2001 From: sdy Date: Tue, 10 Feb 2026 17:19:24 +0800 Subject: [PATCH 1/3] add pgwire extend --- pgx/v5/conn.go | 36 ++ pgx/v5/pgconn/pgconn.go | 162 ++++++++- pgx/v5/pgproto3/bind.go | 406 ++++++++++++++++++++++- pgx/v5/pgproto3/frontend.go | 19 ++ pgx/v5/pgproto3/parameter_description.go | 85 +++++ pgx/v5/pgproto3/parse.go | 59 ++++ pgx/v5/pgproto3/trace.go | 11 + pkg/targets/kwdb/process_prepare.go | 34 +- 8 files changed, 805 insertions(+), 7 deletions(-) diff --git a/pgx/v5/conn.go b/pgx/v5/conn.go index cc783f2..06cd13e 100644 --- a/pgx/v5/conn.go +++ b/pgx/v5/conn.go @@ -329,6 +329,42 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.Statem return sd, nil } +func (c *Conn) PrepareEx(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) { + if c.prepareTracer != nil { + ctx = c.prepareTracer.TracePrepareStart(ctx, c, TracePrepareStartData{Name: name, SQL: sql}) + } + + if name != "" { + var ok bool + if sd, ok = c.preparedStatements[name]; ok { + if sd.SQL == sql { + if c.prepareTracer != nil { + c.prepareTracer.TracePrepareEnd(ctx, c, TracePrepareEndData{AlreadyPrepared: true}) + } + return sd, nil + } + panic(fmt.Sprintf("prepare le repeat table %s", name)) + } + } + + if c.prepareTracer != nil { + defer func() { + c.prepareTracer.TracePrepareEnd(ctx, c, TracePrepareEndData{Err: err}) + }() + } + + sd, err = c.pgConn.PrepareEx(ctx, name, sql) + if err != nil { + return nil, err + } + + if name != "" { + c.preparedStatements[name] = sd + } + + return sd, nil +} + // Deallocate released a prepared statement func (c *Conn) Deallocate(ctx context.Context, name string) error { delete(c.preparedStatements, name) diff --git a/pgx/v5/pgconn/pgconn.go b/pgx/v5/pgconn/pgconn.go index f783cad..1608c54 100644 --- a/pgx/v5/pgconn/pgconn.go +++ b/pgx/v5/pgconn/pgconn.go @@ -13,6 +13,7 @@ import ( "net" "strconv" "strings" + "sync" "time" "github.com/jackc/pgx/v5/internal/iobufpool" @@ -86,6 +87,8 @@ type PgConn struct { fieldDescriptions [16]FieldDescription cleanupDone chan struct{} + + pool sync.Pool } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) @@ -250,6 +253,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn := new(PgConn) pgConn.config = config pgConn.cleanupDone = make(chan struct{}) + pgConn.pool.New = func() interface{} { + return &pgproto3.PayloadBuffer{} + } var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) @@ -759,10 +765,15 @@ func (pgConn *PgConn) convertRowDescription(dst []FieldDescription, rd *pgproto3 } type StatementDescription struct { - Name string - SQL string - ParamOIDs []uint32 - Fields []FieldDescription + Name string + SQL string + ParamOIDs []uint32 + Fields []FieldDescription + TableName string + TagIndex int16 // index of tag start position + PtagIDs []uint16 // ptag index and sequence + OtherTagIDs []uint16 + StorageLen []uint32 // ts field storage length } // Prepare creates a prepared statement. If the name is empty, the anonymous prepared statement will be used. This @@ -823,6 +834,76 @@ readloop: return psd, nil } +func (pgConn *PgConn) PrepareEx(ctx context.Context, name, tablename string) (*StatementDescription, error) { + if err := pgConn.lock(); err != nil { + return nil, err + } + defer pgConn.unlock() + + if ctx != context.Background() { + select { + case <-ctx.Done(): + return nil, newContextAlreadyDoneError(ctx) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + pgConn.frontend.SendParseEx(&pgproto3.ParseEx{Name: name, TableName: tablename}) + pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name}) + pgConn.frontend.SendSync(&pgproto3.Sync{}) + err := pgConn.frontend.Flush() + if err != nil { + pgConn.asyncClose() + return nil, err + } + psd := &StatementDescription{Name: name, TableName: tablename} + + var parseErr error + +readloop: + for { + msg, err := pgConn.receiveMessage() + if err != nil { + pgConn.asyncClose() + return nil, normalizeTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ParameterDescription: + psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) + copy(psd.ParamOIDs, msg.ParameterOIDs) + case *pgproto3.RowDescription: + psd.Fields = pgConn.convertRowDescription(nil, msg) + case *pgproto3.ErrorResponse: + parseErr = ErrorResponseToPgError(msg) + case *pgproto3.ReadyForQuery: + break readloop + case *pgproto3.ParameterDescriptionEx: + psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) + copy(psd.ParamOIDs, msg.ParameterOIDs) + psd.TagIndex = msg.TagIndex + psd.PtagIDs = make([]uint16, len(msg.PtagIDs)) + copy(psd.PtagIDs, msg.PtagIDs) + psd.StorageLen = make([]uint32, len(msg.StorageLen)) + copy(psd.StorageLen, msg.StorageLen) + /*posPtag := 0 + for i := uint16(psd.TagIndex); i < uint16(len(psd.ParamOIDs)); i++ { + if i == psd.PtagIDs[posPtag] { + posPtag++ + continue + } + psd.OtherTagIDs = append(psd.OtherTagIDs, i) + }*/ + } + } + + if parseErr != nil { + return nil, parseErr + } + return psd, nil +} + // ErrorResponseToPgError converts a wire protocol error message to a *PgError. func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { return &PgError{ @@ -1065,6 +1146,79 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return result } +func (pgConn *PgConn) ExecPreparedEx(ctx context.Context, stmtName string, sd *StatementDescription, args [][]byte, colCountPerRow int) *ResultReader { + result := pgConn.execExtendedPrefix(ctx, args) + if result.closed { + return result + } + + payloads := make(map[string]*pgproto3.PayloadBuffer) + var keyBuilder strings.Builder + keyBuilder.Grow(64) + // transform args to paylaod + rowColCount := colCountPerRow + maxRowlen := 43 + len(sd.StorageLen) + var payload *pgproto3.PayloadBuffer + for _, stlen := range sd.StorageLen { + maxRowlen += int(stlen) + } + + for pos := 0; pos < len(args); { + // get per row ptag + // by ptag find buffer + // buffer including head and tag and body + keyBuilder.Reset() + for _, ptag := range sd.PtagIDs { + keyBuilder.Write(args[pos+int(ptag)]) + } + key := keyBuilder.String() + payload = payloads[key] + if payload == nil { + payloads[key] = pgConn.pool.Get().(*pgproto3.PayloadBuffer) + payload = payloads[key] + } + if (payload.Cap - payload.Tail) < maxRowlen { + payload.Extend(4096) // temp 4k extend + } + + row := args[pos : pos+rowColCount] + payload.FillOneRow(row, sd.ParamOIDs, sd.PtagIDs, sd.TagIndex, sd.StorageLen, 0) + payload.RowNum++ + + pos += rowColCount + } + + // fill header rowNum + for _, pd := range payloads { + pd.WriteRowNum() + } + + pgConn.frontend.SendBindEx(&pgproto3.BindEx{PreparedStatement: stmtName, PtagToPayload: payloads}) + + // pgConn.execExtendedSuffix(result) + pgConn.frontend.SendSync(&pgproto3.Sync{}) + + err := pgConn.frontend.Flush() + if err != nil { + pgConn.asyncClose() + result.concludeCommand(CommandTag{}, err) + pgConn.contextWatcher.Unwatch() + result.closed = true + pgConn.unlock() + return nil + } + + // recycle + for _, pd := range payloads { + pd.Reset() + pgConn.pool.Put(pd) + } + + result.readUntilRowDescription() + + return result +} + func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader { pgConn.resultReader = ResultReader{ pgConn: pgConn, diff --git a/pgx/v5/pgproto3/bind.go b/pgx/v5/pgproto3/bind.go index fdd2d3b..0a45d46 100644 --- a/pgx/v5/pgproto3/bind.go +++ b/pgx/v5/pgproto3/bind.go @@ -5,9 +5,13 @@ import ( "encoding/binary" "encoding/hex" "encoding/json" + "errors" "fmt" - "github.com/jackc/pgx/v5/internal/pgio" + "github.com/jackc/pgx/v5/pgtype" + "math" + "strconv" + "time" ) type Bind struct { @@ -214,3 +218,403 @@ func (dst *Bind) UnmarshalJSON(data []byte) error { } return nil } + +const ( + // OsnIDOffset offset of osn_id in the payload header + OsnIDOffset = 0 + // OsnIDSize length of osn_id in the payload header (Only 8 bytes have been used). + OsnIDSize = 16 + // RangeGroupIDOffset offset of range_group_id in the payload header + RangeGroupIDOffset = 16 + // RangeGroupIDSize length of range_group_id in the payload header + RangeGroupIDSize = 2 + // PayloadVersionOffset offset of payload_version in the payload header + PayloadVersionOffset = 18 + // PayloadVersionSize length of payload_version in the payload header + PayloadVersionSize = 4 + // DBIDOffset offset of db_id in the payload header + DBIDOffset = 22 + // DBIDSize length of db_id in the payload header + DBIDSize = 4 + // TableIDOffset offset of table_id in the payload header + TableIDOffset = 26 + // TableIDSize length of table_id in the payload header + TableIDSize = 8 + // TSVersionOffset offset of ts_version in the payload header + TSVersionOffset = 34 + // TSVersionSize length of ts_version in the payload header + TSVersionSize = 4 + // RowNumOffset offset of row_num in the payload header + RowNumOffset = 38 + // RowNumSize length of row_num in the payload header + RowNumSize = 4 + // RowTypeOffset offset of row_type in the payload header + RowTypeOffset = 42 + // RowTypeSize length of row_type in the payload header + RowTypeSize = 1 + // HeadSize is the payload fixed header length of insert ts table + HeadSize = RowTypeOffset + RowTypeSize + // PTagLenSize length of primary tag + PTagLenSize = 2 + // AllTagLenSize length of ordinary tag + AllTagLenSize = 4 + // DataLenSize length of datalen size + DataLenSize = 4 + // VarDataLenSize length of not fixed datalen + VarDataLenSize = 2 + // VarColumnSize is the fixed length memory taken by var-length data type + VarColumnSize = 8 +) + +type PayloadBuffer struct { + Tail int + Data []byte + Cap int + HeadTail int + rowTail int + RowNum uint32 +} + +func (pd *PayloadBuffer) Extend(size int) error { + if size <= 0 { + return nil + } + newSize := pd.Cap + size + dataNew := make([]byte, newSize) + if dataNew == nil { + return errors.New("extend memory failed.") + } + + if pd.Data == nil { + pd.Tail = 0 + pd.Data = dataNew + } else { + useData := pd.Data[:pd.Tail] + copy(dataNew, useData) + pd.Data = dataNew + } + + pd.Cap = newSize + return nil +} + +func (pd *PayloadBuffer) WriteRowNum() error { + if pd.Data != nil && len(pd.Data) > RowNumOffset { + binary.LittleEndian.PutUint32(pd.Data[RowNumOffset:], pd.RowNum) + binary.LittleEndian.PutUint32(pd.Data[pd.HeadTail:], uint32(pd.Tail-pd.HeadTail-4)) // data len + } + + return nil +} + +func (pd *PayloadBuffer) Reset() error { + pd.Tail = 0 + pd.HeadTail = 0 + pd.rowTail = 0 + pd.RowNum = 0 + + return nil +} + +const ( + minTimeDuration time.Duration = -1 << 63 + maxTimeDuration time.Duration = 1<<63 - 1 +) + +func AddMicros(t time.Time, d int64) time.Time { + negMult := time.Duration(1) + if d < 0 { + negMult = -1 + d = -d + } + const maxMicroDur = int64(maxTimeDuration / time.Microsecond) + for d > maxMicroDur { + const maxWholeNanoDur = time.Duration(maxMicroDur) * time.Microsecond + t = t.Add(negMult * maxWholeNanoDur) + d -= maxMicroDur + } + return t.Add(negMult * time.Duration(d) * time.Microsecond) +} + +func PgBinaryToTime(i int64) time.Time { + return AddMicros(PGEpochJDate, i) +} + +var ( + // PGEpochJDate represents the pg epoch. + PGEpochJDate = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) +) + +// fill tag and metric data +func FillColData(oid uint32, val []byte, dst []byte, storelen uint32, dstVarStart []byte, + offsetVar uint16, isPtag bool) (int, int, error) { + switch oid { // not all support + case pgtype.Int8OID: + // val, err := strconv.Atoi(string(val)) + //if err != nil { + // return -1, -1, err + //} + binary.LittleEndian.PutUint64(dst, binary.BigEndian.Uint64(val)) + return 8, 0, nil + case pgtype.TimestamptzOID: + //loc := time.FixedZone("UTC", 8*3600) // UTC+8 + // This will look for the name CEST in the Europe/Berlin time zone. + //const longForm = "2006-01-02 15:04:05" + //t, _ := time.ParseInLocation(longForm, string(val), loc) + //milliseconds := t.UnixNano() / int64(time.Millisecond) + i := binary.BigEndian.Uint64(val) + tm := PgBinaryToTime(int64(i)) + nanosecond := tm.Nanosecond() + second := tm.Unix() + tum := second*1000 + int64(nanosecond/1000000) + binary.LittleEndian.PutUint64(dst, uint64(tum)) + return 8, 0, nil + case pgtype.Int4OID: + num := binary.BigEndian.Uint64(val) + if num > uint64(math.MaxInt32) || int64(num) < int64(math.MinInt32) { + return -1, -1, strconv.ErrRange + } + binary.LittleEndian.PutUint32(dst, uint32(num)) + return 4, 0, nil + case pgtype.BPCharOID: + copy(dst, val) + return int(storelen), 0, nil + case pgtype.VarcharOID: + if len(val) > int(storelen) { + return -1, 0, strconv.ErrRange + } + if isPtag { + copy(dst, val) + return int(storelen), 0, nil + } + currentVarPos := offsetVar + binary.LittleEndian.PutUint64(dst, uint64(currentVarPos)) + // var part + binary.LittleEndian.PutUint16(dstVarStart[currentVarPos:], uint16(len(val)+1)) + currentVarPos += 2 + copy(dstVarStart[currentVarPos:], val) + return 8, int(offsetVar) + len(val) + 2 + 1, nil // 8 real 4 + default: + panic(0) + } +} + +func CalcVarPosHead(ptagDataLen int, posBaseHead int, idxTag int, + StorageLen []uint32) (int, int) { + posTagStart := 0 + posVarStart := 0 + posVarStart = posBaseHead + posVarStart += 2 + posVarStart += ptagDataLen + posVarStart += 4 + posTagStart = posVarStart + // bitmap + allColCount := len(StorageLen) + posVarStart += int((allColCount-idxTag)/8) + 1 + for i := idxTag; i < len(StorageLen); i++ { + posVarStart += int(StorageLen[i]) + } + return posTagStart, posVarStart +} + +// calc metric col row size +func computeRowSize(ParamOIDs []uint32, StorageLen []uint32) int { + rowSize := 0 + for i, oid := range ParamOIDs { + switch oid { + case pgtype.Int8OID: + rowSize += 8 + case pgtype.Int2OID: + rowSize += 2 + case pgtype.Int4OID: + rowSize += 4 + case pgtype.TimestamptzOID: + rowSize += 8 + case pgtype.BPCharOID: + rowSize += int(StorageLen[i]) + case pgtype.VarcharOID: + rowSize += VarColumnSize + default: + panic(0) + } + } + return rowSize +} + +const ( + OnlyData = 1 + OnlyTag = 2 + BothTagAndData = 0 +) + +func (pd *PayloadBuffer) FillOneRow(args [][]byte, ParamOIDs []uint32, + PtagIDs []uint16, TagIndex int16, StorageLen []uint32, rowNum uint32) error { + var err error + lenCol := 0 + pos := 0 + usedVarLen := 0 + usingVarLen := 0 + if pd.HeadTail == 0 { + pos = HeadSize + // add head + // header(43)+ptaglen(2)+ptag+taglen(4)+tag + // tag feild len with define coldes + ptaglen := 0 + // write ptag value and len + ptagDataStart := pos + PTagLenSize + for _, pIdx := range PtagIDs { + ptaglen += int(StorageLen[pIdx]) + tagLen, _, err := FillColData(ParamOIDs[pIdx], args[pIdx], pd.Data[ptagDataStart:], + StorageLen[pIdx], nil, 0, true) + if err != nil { + return err + } + ptagDataStart += tagLen + } + + taglen := 0 + + // rowNum + // binary.LittleEndian.PutUint32(pd.Data[RowNumOffset:], rowNum) + posTagStart, posVar := CalcVarPosHead(ptaglen, pos, int(TagIndex), StorageLen) + // ptag data + binary.LittleEndian.PutUint16(pd.Data[pos:], uint16(ptaglen)) + pos += 2 + // copy(pd.Data[pos:], ptag) + + // tag data + tagLenPos := posTagStart - 4 + allColCount := len(StorageLen) + pos = posTagStart + int((allColCount-int(TagIndex))/8) + 1 // bitmap + taglen += ((allColCount - int(TagIndex)) / 8) + 1 + payloadFlag := BothTagAndData + for i := int(TagIndex); i < len(StorageLen); i++ { + if (i+1) > len(args) || args[i] == nil { + pd.Data[posTagStart+((i-int(TagIndex))/8)] |= 1 << ((i - int(TagIndex)) % 8) + pos += int(StorageLen[int(i)]) + taglen += int(StorageLen[int(i)]) + continue + } + payloadFlag = BothTagAndData + lenCol, usingVarLen, err = FillColData(ParamOIDs[i], args[i], pd.Data[pos:], + StorageLen[i], pd.Data[posVar:], uint16(usedVarLen), false) + if err != nil { + return err + } + if usingVarLen > 0 { + usedVarLen += usingVarLen + } + taglen += lenCol + } + pd.Data[RowNumOffset+4] = byte(payloadFlag) + binary.LittleEndian.PutUint32(pd.Data[tagLenPos:], uint32(taglen)) + pd.HeadTail = posVar + usedVarLen + pd.Tail = pd.HeadTail + 4 // data len + } + // data part + rowStart := pd.Tail + 4 // row length + usedVarLen = 0 + rowDataStart := rowStart + int(TagIndex/8) + 1 // bitmap + lenTuple := computeRowSize(ParamOIDs[:TagIndex], StorageLen) + posRowVarStart := lenTuple + rowDataStart + pos = rowDataStart + for c := 0; c < int(TagIndex); c++ { + lenCol, usingVarLen, err = FillColData(ParamOIDs[c], args[c], + pd.Data[pos:], StorageLen[c], pd.Data[posRowVarStart:], uint16(usedVarLen), false) + if err != nil { + return err + } + // lenTuple += lenCol + usedVarLen += usingVarLen + pos += lenCol + } + + // write row len other this row invalid + rowLen := usedVarLen + (posRowVarStart - rowStart) // var len and data len and bitmap + binary.LittleEndian.PutUint32(pd.Data[pd.Tail:], uint32(rowLen)) + pd.Tail = posRowVarStart + usedVarLen + return nil +} + +type BindEx struct { + DestinationPortal string + PreparedStatement string + PtagToPayload map[string]*PayloadBuffer +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*BindEx) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *BindEx) Decode(src []byte) error { + *dst = BindEx{} + + idx := bytes.IndexByte(src, 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "BindEx"} + } + dst.DestinationPortal = string(src[:idx]) + rp := idx + 1 + + idx = bytes.IndexByte(src[rp:], 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "BindEx"} + } + dst.PreparedStatement = string(src[rp : rp+idx]) + rp += idx + 1 + + if len(src[rp:]) < 2 { + return &invalidMessageFormatErr{messageType: "BindEx"} + } + // payload count + // all payload + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *BindEx) Encode(dst []byte) []byte { + dst = append(dst, 'W') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.DestinationPortal...) + dst = append(dst, 0) + dst = append(dst, src.PreparedStatement...) + dst = append(dst, 0) + + // all payloads + dst = pgio.AppendInt16(dst, int16(len(src.PtagToPayload))) // payload count + + for ptag, payload := range src.PtagToPayload { + dst = append(dst, ptag...) + dst = append(dst, 0) + + // data + dst = pgio.AppendInt32(dst, int32(payload.Tail)) + dst = append(dst, payload.Data[:payload.Tail]...) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src BindEx) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + DestinationPortal string + PreparedStatement string + }{ + Type: "BindEx", + DestinationPortal: src.DestinationPortal, + PreparedStatement: src.PreparedStatement, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *BindEx) UnmarshalJSON(data []byte) error { + + return nil +} diff --git a/pgx/v5/pgproto3/frontend.go b/pgx/v5/pgproto3/frontend.go index 9f0aab2..f666c5d 100644 --- a/pgx/v5/pgproto3/frontend.go +++ b/pgx/v5/pgproto3/frontend.go @@ -52,6 +52,7 @@ type Frontend struct { rowDescription RowDescription portalSuspended PortalSuspended aEParameter AEParameter + parameterDescriptionEx ParameterDescriptionEx bodyLen int msgType byte @@ -132,6 +133,14 @@ func (f *Frontend) SendBind(msg *Bind) { } } +func (f *Frontend) SendBindEx(msg *BindEx) { + prevLen := len(f.wbuf) + f.wbuf = msg.Encode(f.wbuf) + if f.tracer != nil { + f.tracer.traceBindEx('F', int32(len(f.wbuf)-prevLen), msg) + } +} + // SendParse sends a Parse message to the backend (i.e. the server). The message is not guaranteed to be written until // Flush is called. func (f *Frontend) SendParse(msg *Parse) { @@ -142,6 +151,14 @@ func (f *Frontend) SendParse(msg *Parse) { } } +func (f *Frontend) SendParseEx(msg *ParseEx) { + prevLen := len(f.wbuf) + f.wbuf = msg.Encode(f.wbuf) + if f.tracer != nil { + f.tracer.traceParseEx('F', int32(len(f.wbuf)-prevLen), msg) + } +} + // SendClose sends a Close message to the backend (i.e. the server). The message is not guaranteed to be written until // Flush is called. func (f *Frontend) SendClose(msg *Close) { @@ -309,6 +326,8 @@ func (f *Frontend) Receive() (BackendMessage, error) { msg = &f.copyBothResponse case 'Z': msg = &f.readyForQuery + case 'X': + msg = &f.parameterDescriptionEx default: return nil, fmt.Errorf("unknown message type: %c", f.msgType) } diff --git a/pgx/v5/pgproto3/parameter_description.go b/pgx/v5/pgproto3/parameter_description.go index 374d38a..20713c0 100644 --- a/pgx/v5/pgproto3/parameter_description.go +++ b/pgx/v5/pgproto3/parameter_description.go @@ -64,3 +64,88 @@ func (src ParameterDescription) MarshalJSON() ([]byte, error) { ParameterOIDs: src.ParameterOIDs, }) } + +type ParameterDescriptionEx struct { + ParameterOIDs []uint32 // ts table create sequence + TagIndex int16 // index of tag start position + PtagIDs []uint16 // ptag index and sequence + StorageLen []uint32 +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*ParameterDescriptionEx) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *ParameterDescriptionEx) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 2 { + return &invalidMessageFormatErr{messageType: "ParameterDescriptionEx"} + } + + // Instead infer parameter count by remaining size of message + parameterCount := binary.BigEndian.Uint16(buf.Next(2)) + + *dst = ParameterDescriptionEx{ParameterOIDs: make([]uint32, parameterCount)} + + for i := uint16(0); i < parameterCount; i++ { + dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4)) + } + + dst.TagIndex = int16(binary.BigEndian.Uint16(buf.Next(2))) + ptagCount := binary.BigEndian.Uint16(buf.Next(2)) + for i := uint16(0); i < ptagCount; i++ { + dst.PtagIDs = append(dst.PtagIDs, binary.BigEndian.Uint16(buf.Next(2))) + } + + lenCount := binary.BigEndian.Uint16(buf.Next(2)) + for i := uint16(0); i < lenCount; i++ { + dst.StorageLen = append(dst.StorageLen, binary.BigEndian.Uint32(buf.Next(4))) + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *ParameterDescriptionEx) Encode(dst []byte) []byte { + dst = append(dst, 'X') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) + for _, oid := range src.ParameterOIDs { + dst = pgio.AppendUint32(dst, oid) + } + + dst = pgio.AppendInt16(dst, src.TagIndex) + + dst = pgio.AppendUint16(dst, uint16(len(src.PtagIDs))) + for _, id := range src.PtagIDs { + dst = pgio.AppendUint16(dst, id) + } + + dst = pgio.AppendUint16(dst, uint16(len(src.StorageLen))) + for _, len := range src.StorageLen { + dst = pgio.AppendUint32(dst, len) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src ParameterDescriptionEx) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ParameterOIDs []uint32 + TagIndex int16 + PtagIDs []uint16 + }{ + Type: "ParameterDescriptionEx", + ParameterOIDs: src.ParameterOIDs, + TagIndex: src.TagIndex, + PtagIDs: src.PtagIDs, + }) +} diff --git a/pgx/v5/pgproto3/parse.go b/pgx/v5/pgproto3/parse.go index b53200d..ddae5c0 100644 --- a/pgx/v5/pgproto3/parse.go +++ b/pgx/v5/pgproto3/parse.go @@ -86,3 +86,62 @@ func (src Parse) MarshalJSON() ([]byte, error) { ParameterOIDs: src.ParameterOIDs, }) } + +type ParseEx struct { + Name string + TableName string +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*ParseEx) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *ParseEx) Decode(src []byte) error { + *dst = ParseEx{} + + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + dst.Name = string(b[:len(b)-1]) + + b, err = buf.ReadBytes(0) + if err != nil { + return err + } + dst.TableName = string(b[:len(b)-1]) + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *ParseEx) Encode(dst []byte) []byte { + dst = append(dst, 'R') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.Name...) + dst = append(dst, 0) + dst = append(dst, src.TableName...) + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src ParseEx) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Name string + TableName string + }{ + Type: "Parse", + Name: src.Name, + TableName: src.TableName, + }) +} diff --git a/pgx/v5/pgproto3/trace.go b/pgx/v5/pgproto3/trace.go index a73fce7..655e55a 100644 --- a/pgx/v5/pgproto3/trace.go +++ b/pgx/v5/pgproto3/trace.go @@ -191,6 +191,11 @@ func (t *tracer) traceBind(sender byte, encodedLen int32, msg *Bind) { t.finishTrace() } +func (t *tracer) traceBindEx(sender byte, encodedLen int32, msg *BindEx) { + t.beginTrace(sender, encodedLen, "BindEx") + t.finishTrace() +} + func (t *tracer) traceBindComplete(sender byte, encodedLen int32, msg *BindComplete) { t.beginTrace(sender, encodedLen, "BindComplete") t.finishTrace() @@ -339,6 +344,12 @@ func (t *tracer) traceParse(sender byte, encodedLen int32, msg *Parse) { t.finishTrace() } +func (t *tracer) traceParseEx(sender byte, encodedLen int32, msg *ParseEx) { + t.beginTrace(sender, encodedLen, "Parse") + fmt.Fprintf(t.buf, "\t %s %s", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.TableName))) + t.finishTrace() +} + func (t *tracer) traceParseComplete(sender byte, encodedLen int32, msg *ParseComplete) { t.beginTrace(sender, encodedLen, "ParseComplete") t.finishTrace() diff --git a/pkg/targets/kwdb/process_prepare.go b/pkg/targets/kwdb/process_prepare.go index a44acad..e8f11a3 100644 --- a/pkg/targets/kwdb/process_prepare.go +++ b/pkg/targets/kwdb/process_prepare.go @@ -4,6 +4,7 @@ import ( "context" "encoding/binary" "fmt" + "github.com/jackc/pgx/v5/pgconn" "strconv" "strings" @@ -70,6 +71,8 @@ type prepareProcessor struct { buffer map[string]*fixedArgList // tableName, fixedArgList buffInited bool formatBuf []int16 + + sd *pgconn.StatementDescription } func newProcessorPrepare(opts *LoadingOptions, dbName string) *prepareProcessor { @@ -178,6 +181,15 @@ func (p *prepareProcessor) ProcessBatch(b targets.Batch, doLoad bool) (metricCou } } + /*if p.sd != nil && p.sd.ParamOIDs != nil { + nvalCount := len(values) + rowNeedCount := len(p.sd.ParamOIDs) + for ; rowNeedCount > nvalCount; nvalCount++ { + // fill nill other tag for cpu all field + tableBuffer.Append(nil) + } + }*/ + // check buffer is full if tableBuffer.Length() == tableBuffer.Capacity() { // init prepareStmt @@ -185,9 +197,18 @@ func (p *prepareProcessor) ProcessBatch(b targets.Batch, doLoad bool) (metricCou if !ok { p.createPrepareSql("cpu") p.preparedSql["cpu"] = struct{}{} + /*if p.sd != nil && p.sd.ParamOIDs != nil { + nvalCount := len(values) + rowNeedCount := len(p.sd.ParamOIDs) + for ; rowNeedCount > nvalCount; nvalCount++ { + // fill nill other tag for cpu all field + tableBuffer.Append(nil) + } + }*/ } - p.execPrepareStmt("cpu", tableBuffer.args) + //p.execPrepareStmt("cpu", tableBuffer.args) + p.execPrepareStmtEx("cpu", tableBuffer.args, 12) // reuse buffer: reset tableBuffer's write position tableBuffer.Reset() } @@ -251,7 +272,9 @@ func (p *prepareProcessor) createPrepareSql(deviecName string) { query := fmt.Sprintf("insert into %s.cpu (k_timestamp,usage_user,usage_system,usage_idle,usage_nice,usage_iowait,usage_irq,usage_softirq,usage_steal,usage_guest,usage_guest_nice,hostname) values ", p.opts.DBName) insertsql.WriteString(query) sql := insertsql.String() + p.prepareStmt.String() - _, err1 := p._db.Connection.Prepare(context.Background(), "insertall"+deviecName, sql) + // _, err1 := p._db.Connection.Prepare(context.Background(), "insertall"+deviecName, sql) + var err1 error + p.sd, err1 = p._db.Connection.PrepareEx(context.Background(), "insertall"+deviecName, "benchmark.public.cpu") if err1 != nil { panic(fmt.Sprintf("kwdb Prepare failed,err :%s, sql :%s", err1, sql)) } @@ -263,3 +286,10 @@ func (p *prepareProcessor) execPrepareStmt(tableName string, args [][]byte) { panic(res.Err) } } + +func (p *prepareProcessor) execPrepareStmtEx(tableName string, args [][]byte, colCountPerRow int) { + res := p._db.Connection.PgConn().ExecPreparedEx(context.Background(), "insertall"+tableName, p.sd, args, colCountPerRow).Read() + if res.Err != nil { + panic(res.Err) + } +} -- Gitee From ac930dc8ff6088dfe7ce40c941d78bb896a3e0e5 Mon Sep 17 00:00:00 2001 From: mmhuge123 Date: Fri, 8 May 2026 00:52:26 +0000 Subject: [PATCH 2/3] fix pgwire extend --- pgx/v5/pgconn/pgconn.go | 8 +++++++- pgx/v5/pgproto3/bind.go | 11 +++++------ pgx/v5/pgproto3/parse.go | 2 +- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/pgx/v5/pgconn/pgconn.go b/pgx/v5/pgconn/pgconn.go index 44be252..6244fca 100644 --- a/pgx/v5/pgconn/pgconn.go +++ b/pgx/v5/pgconn/pgconn.go @@ -1182,7 +1182,13 @@ func (pgConn *PgConn) ExecPreparedEx(ctx context.Context, stmtName string, sd *S } row := args[pos : pos+rowColCount] - payload.FillOneRow(row, sd.ParamOIDs, sd.PtagIDs, sd.TagIndex, sd.StorageLen, 0) + if err := payload.FillOneRow(row, sd.ParamOIDs, sd.PtagIDs, sd.TagIndex, sd.StorageLen, 0); err != nil { + result.concludeCommand(CommandTag{}, err) + result.closed = true + pgConn.contextWatcher.Unwatch() + pgConn.unlock() + return result + } payload.RowNum++ pos += rowColCount diff --git a/pgx/v5/pgproto3/bind.go b/pgx/v5/pgproto3/bind.go index 0a45d46..f4fe4da 100644 --- a/pgx/v5/pgproto3/bind.go +++ b/pgx/v5/pgproto3/bind.go @@ -9,7 +9,6 @@ import ( "fmt" "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/pgtype" - "math" "strconv" "time" ) @@ -370,11 +369,11 @@ func FillColData(oid uint32, val []byte, dst []byte, storelen uint32, dstVarStar binary.LittleEndian.PutUint64(dst, uint64(tum)) return 8, 0, nil case pgtype.Int4OID: - num := binary.BigEndian.Uint64(val) - if num > uint64(math.MaxInt32) || int64(num) < int64(math.MinInt32) { - return -1, -1, strconv.ErrRange - } - binary.LittleEndian.PutUint32(dst, uint32(num)) + num := binary.BigEndian.Uint32(val) + //if num > math.MaxInt32 || num < math.MinInt32 { + // return -1, -1, strconv.ErrRange + //} + binary.LittleEndian.PutUint32(dst, num) return 4, 0, nil case pgtype.BPCharOID: copy(dst, val) diff --git a/pgx/v5/pgproto3/parse.go b/pgx/v5/pgproto3/parse.go index ddae5c0..eb223f1 100644 --- a/pgx/v5/pgproto3/parse.go +++ b/pgx/v5/pgproto3/parse.go @@ -140,7 +140,7 @@ func (src ParseEx) MarshalJSON() ([]byte, error) { Name string TableName string }{ - Type: "Parse", + Type: "ParseEx", Name: src.Name, TableName: src.TableName, }) -- Gitee From 0d6ea2fd4166307a011c36825c710bba95b5fb18 Mon Sep 17 00:00:00 2001 From: mmhuge123 Date: Fri, 8 May 2026 06:48:40 +0000 Subject: [PATCH 3/3] add --insert-type=prepare-extend --- docs/kwdb.md | 3 ++- docs/kwdb_zh.md | 3 ++- internal/mcp_tools/tools.go | 15 ++++++++------- internal/service/execution_service.go | 2 +- pkg/targets/kwdb/benchmark.go | 9 ++++++--- pkg/targets/kwdb/process_prepare.go | 21 +++++++++++++++++---- scripts/tsbs_kwdb.sh | 3 ++- 7 files changed, 38 insertions(+), 18 deletions(-) diff --git a/docs/kwdb.md b/docs/kwdb.md index a2375d0..8bb3dc5 100644 --- a/docs/kwdb.md +++ b/docs/kwdb.md @@ -87,7 +87,7 @@ Hostname of the kwdb server. Port of the kwdb server. #### `-insert-type` (type: `string`) -Optional as `insert, prepare, prepareiot` +Optional as `insert, prepare, prepare-extend, prepareiot` Note: The correspondence between case and insert-type is as follows: @@ -95,6 +95,7 @@ Note: The correspondence between case and insert-type is as follows: |----------|-------------| | cpu-only | insert | | cpu-only | prepare | +| cpu-only | prepare-extend | | IoT | insert | | IoT | prepareiot | diff --git a/docs/kwdb_zh.md b/docs/kwdb_zh.md index c19bcd7..ef627fe 100644 --- a/docs/kwdb_zh.md +++ b/docs/kwdb_zh.md @@ -85,7 +85,7 @@ KWDB 服务器地址。 KWDB 服务器端口。 #### `-insert-type` (类型:`string`) -可选值:insert、prepare 或 prepareiot。 +可选值:insert、prepare、prepare-extend 或 prepareiot。 注:case 与 insert-type 的对应关系如下: @@ -93,6 +93,7 @@ KWDB 服务器端口。 |----------|-------------| | cpu-only | insert | | cpu-only | prepare | +| cpu-only | prepare-extend | | IoT | insert | | IoT | prepareiot | diff --git a/internal/mcp_tools/tools.go b/internal/mcp_tools/tools.go index d7c775c..2bfb7c3 100644 --- a/internal/mcp_tools/tools.go +++ b/internal/mcp_tools/tools.go @@ -107,8 +107,8 @@ func RegisterTools( }, "insert_type": map[string]interface{}{ "type": "string", - "description": "Insert type. Must be one of: 'insert' (regular insert), 'prepare' (prepared statement), 'prepareiot' (IoT prepared). Default: 'insert'", - "enum": []string{"insert", "prepare", "prepareiot"}, + "description": "Insert type. Must be one of: 'insert' (regular insert), 'prepare' (prepared statement), 'prepare-extend' (prepared statement with extended execution), 'prepareiot' (IoT prepared). Default: 'insert'", + "enum": []string{"insert", "prepare", "prepare-extend", "prepareiot"}, "default": "insert", }, "case": map[string]interface{}{ @@ -121,7 +121,7 @@ func RegisterTools( }, "preparesize": map[string]interface{}{ "type": "integer", - "description": "Prepared statement size (optional). Only used for prepare and prepareiot insert types.", + "description": "Prepared statement size (optional). Only used for prepare, prepare-extend, and prepareiot insert types.", }, "workers": map[string]interface{}{ "type": "integer", @@ -703,15 +703,16 @@ func handleLoadData( input.InsertType = "insert" } else { validTypes := map[string]bool{ - "insert": true, - "prepare": true, - "prepareiot": true, + "insert": true, + "prepare": true, + "prepare-extend": true, + "prepareiot": true, } if !validTypes[input.InsertType] { return nil, LoadDataOutput{ TaskID: "", Status: "error", - Message: fmt.Sprintf("参数验证失败: invalid insert_type: %s, must be one of: insert, prepare, prepareiot", input.InsertType), + Message: fmt.Sprintf("参数验证失败: invalid insert_type: %s, must be one of: insert, prepare, prepare-extend, prepareiot", input.InsertType), }, nil } } diff --git a/internal/service/execution_service.go b/internal/service/execution_service.go index 6acc49d..88e01dd 100644 --- a/internal/service/execution_service.go +++ b/internal/service/execution_service.go @@ -260,7 +260,7 @@ func (s *ExecutionService) ExecuteLoadData(ctx context.Context, taskID string, i "--case=" + input.Case, } - if input.InsertType == "prepare" || input.InsertType == "prepareiot" { + if input.InsertType == "prepare" || input.InsertType == "prepare-extend" || input.InsertType == "prepareiot" { if input.Preparesize != nil { args = append(args, fmt.Sprintf("--preparesize=%d", *input.Preparesize)) } else if input.BatchSize != nil { diff --git a/pkg/targets/kwdb/benchmark.go b/pkg/targets/kwdb/benchmark.go index 12217cf..84bc3d6 100644 --- a/pkg/targets/kwdb/benchmark.go +++ b/pkg/targets/kwdb/benchmark.go @@ -9,9 +9,10 @@ import ( ) var ( - KWDBINSERT = "insert" - KWDBPREPARE = "prepare" - KWDBPREPAREIOT = "prepareiot" + KWDBINSERT = "insert" + KWDBPREPARE = "prepare" + KWDBPREPAREEXTEND = "prepare-extend" + KWDBPREPAREIOT = "prepareiot" ) func NewBenchmark(dbName string, opts *LoadingOptions, dataSourceConfig *source.DataSourceConfig) (targets.Benchmark, error) { @@ -71,6 +72,8 @@ func (b *benchmark) GetProcessor() targets.Processor { return newProcessorInsert(b.opts, b.dbName) case KWDBPREPARE: return newProcessorPrepare(b.opts, b.dbName) + case KWDBPREPAREEXTEND: + return newProcessorPrepare(b.opts, b.dbName) case KWDBPREPAREIOT: return newProcessorPrepareiot(b.opts, b.dbName) default: diff --git a/pkg/targets/kwdb/process_prepare.go b/pkg/targets/kwdb/process_prepare.go index 95eb656..4cc7ae9 100644 --- a/pkg/targets/kwdb/process_prepare.go +++ b/pkg/targets/kwdb/process_prepare.go @@ -87,6 +87,10 @@ func newProcessorPrepare(opts *LoadingOptions, dbName string) *prepareProcessor } } +func (p *prepareProcessor) useExtend() bool { + return p.opts.Type == KWDBPREPAREEXTEND +} + func (p *prepareProcessor) Init(workerNum int, doLoad, _ bool) { if !doLoad { return @@ -187,8 +191,12 @@ func (p *prepareProcessor) ProcessBatch(b targets.Batch, doLoad bool) (metricCou cpuPrepared = true } - //p.execPrepareStmt("cpu", tableBuffer.args) - p.execPrepareStmtEx("cpu", tableBuffer.args, 12) + if p.useExtend() { + p.execPrepareStmtEx("cpu", tableBuffer.args, 12) + } else { + p.execPrepareStmt("cpu", tableBuffer.args) + } + // reuse buffer: reset tableBuffer's write position tableBuffer.Reset() } @@ -299,9 +307,14 @@ func (p *prepareProcessor) createPrepareSql(deviecName string) { query := fmt.Sprintf("insert into %s.cpu (k_timestamp,usage_user,usage_system,usage_idle,usage_nice,usage_iowait,usage_irq,usage_softirq,usage_steal,usage_guest,usage_guest_nice,hostname) values ", p.opts.DBName) insertsql.WriteString(query) sql := insertsql.String() + p.prepareStmt.String() - // _, err1 := p._db.Connection.Prepare(context.Background(), "insertall"+deviecName, sql) var err1 error - p.sd, err1 = p._db.Connection.PrepareEx(context.Background(), "insertall"+deviecName, "benchmark.public.cpu") + + if p.useExtend() { + p.sd, err1 = p._db.Connection.PrepareEx(context.Background(), "insertall"+deviecName, "benchmark.public.cpu") + } else { + _, err1 = p._db.Connection.Prepare(context.Background(), "insertall"+deviecName, sql) + } + if err1 != nil { panic(fmt.Sprintf("kwdb Prepare failed,err :%s, sql :%s", err1, sql)) } diff --git a/scripts/tsbs_kwdb.sh b/scripts/tsbs_kwdb.sh index a98d00a..38a0f62 100755 --- a/scripts/tsbs_kwdb.sh +++ b/scripts/tsbs_kwdb.sh @@ -25,7 +25,7 @@ enable_perf=${enable_perf:-false} parallel_degree=${parallel_degree:-8} # 查询并行性,默认为 8 ## 数据写入配置 -insert_type=${insert_type:-insert} # 写入方式,默认为 insert,可设置:insert、prepare、prepareiot +insert_type=${insert_type:-insert} # 写入方式,默认为 insert,可设置:insert、prepare、prepare-extend、prepareiot tsbs_case=${tsbs_case:-cpu-only} # case 类型,默认为 cpu-only, 可设置:iot 注:case 与 insert-type 的对应关系如下: @@ -33,6 +33,7 @@ tsbs_case=${tsbs_case:-cpu-only} # case 类型,默认为 cpu-only, 可设 |----------|-------------| | cpu-only | insert | | cpu-only | prepare | + | cpu-only | prepare-extend | | IoT | insert | | IoT | prepareiot | -- Gitee