From e15abf173f553b7a92d46c2e3d691afbd18d63ef Mon Sep 17 00:00:00 2001 From: wuyujun <714166892@qq.com> Date: Sat, 13 May 2023 15:11:59 +0800 Subject: [PATCH] =?UTF-8?q?=E8=A7=A3=E5=86=B3update=E5=A4=9A=E8=A1=A8?= =?UTF-8?q?=E6=98=BE=E5=BC=8F=E5=85=B3=E8=81=94=E8=B6=85=E8=BF=87=E4=B8=89?= =?UTF-8?q?=E5=BC=A0=E8=A1=A8=E7=BB=93=E6=9E=9C=E9=94=99=E8=AF=AF=E7=9A=84?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/common/backend/parser/analyze.cpp | 42 +++++++++++++--------- src/test/regress/expected/multi_update.out | 3 ++ src/test/regress/sql/multi_update.sql | 3 ++ 3 files changed, 31 insertions(+), 17 deletions(-) diff --git a/src/common/backend/parser/analyze.cpp b/src/common/backend/parser/analyze.cpp index d8087112ef..3a2f950047 100644 --- a/src/common/backend/parser/analyze.cpp +++ b/src/common/backend/parser/analyze.cpp @@ -105,7 +105,7 @@ static const int MILLISECONDS_PER_SECONDS = 1000; static Query* transformDeleteStmt(ParseState* pstate, DeleteStmt* stmt); static Query* transformInsertStmt(ParseState* pstate, InsertStmt* stmt); static void checkUpsertTargetlist(Relation targetTable, List* updateTlist); -static UpsertExpr* transformUpsertClause(ParseState* pstate, UpsertClause* upsertClause); +static UpsertExpr* transformUpsertClause(ParseState* pstate, UpsertClause* upsertClause, List* resultRelations); static int count_rowexpr_columns(ParseState* pstate, Node* expr); static void transformVariableSetStmt(ParseState* pstate, VariableSetStmt* stmt); static Query* transformVariableMutiSetStmt(ParseState* pstate, VariableMultiSetStmt* muti_stmt); @@ -116,7 +116,7 @@ static Query* transformSetOperationStmt(ParseState* pstate, SelectStmt* stmt); static Node* transformSetOperationTree(ParseState* pstate, SelectStmt* stmt, bool isTopLevel, List** targetlist); static void determineRecursiveColTypes(ParseState* pstate, Node* larg, List* nrtargetlist); static Query* transformUpdateStmt(ParseState* pstate, UpdateStmt* stmt); -static List* transformUpdateTargetList(ParseState* pstate, List* qryTlist, List* origTlist); +static List* transformUpdateTargetList(ParseState* pstate, List* qryTlist, List* origTlist, List* resultRelations); static List* transformReturningList(ParseState* pstate, List* returningList); static Query* transformDeclareCursorStmt(ParseState* pstate, DeclareCursorStmt* stmt); static Query* transformExplainStmt(ParseState* pstate, ExplainStmt* stmt); @@ -2210,10 +2210,10 @@ static Query* transformInsertStmt(ParseState* pstate, InsertStmt* stmt) pstate->p_varnamespace = NIL; rightRefState->isUpsert = true; SetUpsertAttrnoState(pstate, stmt->upsertClause->targetList); - qry->upsertClause = transformUpsertClause(pstate, stmt->upsertClause); + qry->upsertClause = transformUpsertClause(pstate, stmt->upsertClause, qry->resultRelations); rightRefState->isUpsert = false; } else { - qry->upsertClause = transformUpsertClause(pstate, stmt->upsertClause); + qry->upsertClause = transformUpsertClause(pstate, stmt->upsertClause, qry->resultRelations); } } /* @@ -2395,7 +2395,7 @@ static bool ContainSubLink(Node* clause) } #endif /* ENABLE_MULTIPLE_NODES */ -static UpsertExpr* transformUpsertClause(ParseState* pstate, UpsertClause* upsertClause) +static UpsertExpr* transformUpsertClause(ParseState* pstate, UpsertClause* upsertClause, List* resultRelations) { UpsertExpr* result = NULL; List* updateTlist = NIL; @@ -2456,7 +2456,7 @@ static UpsertExpr* transformUpsertClause(ParseState* pstate, UpsertClause* upser updateTlist = transformTargetList(pstate, upsertClause->targetList, EXPR_KIND_UPDATE_TARGET); /* Done with select-like processing, move on transforming to match update set target column */ - updateTlist = transformUpdateTargetList(pstate, updateTlist, upsertClause->targetList); + updateTlist = transformUpdateTargetList(pstate, updateTlist, upsertClause->targetList, resultRelations); updateWhere = transformWhereClause(pstate, upsertClause->whereClause, EXPR_KIND_WHERE, "WHERE"); #ifdef ENABLE_MULTIPLE_NODES /* Do not support sublinks in update where clause for now */ @@ -4066,7 +4066,7 @@ static Query* transformUpdateStmt(ParseState* pstate, UpdateStmt* stmt) * Now we are done with SELECT-like processing, and can get on with * transforming the target list to match the UPDATE target columns. */ - qry->targetList = transformUpdateTargetList(pstate, qry->targetList, stmt->targetList); + qry->targetList = transformUpdateTargetList(pstate, qry->targetList, stmt->targetList, qry->resultRelations); transformLimitSortClause(pstate, stmt, qry, false); qry->resultRelations = remove_update_redundant_relation(qry->resultRelations, pstate->p_target_rangetblentry); @@ -4151,22 +4151,25 @@ char* checkUpdateResTargetName(Relation rd, RangeVar* rel, ResTarget* res, bool* } /* Find the attrno corresponding to ResTarget from target tables. */ -static int fixUpdateResTargetName(ParseState* pstate, ResTarget* res, int* rti, Relation* rd, RangeTblEntry** rte) +static int fixUpdateResTargetName(ParseState* pstate, List* resultRelations, ResTarget* res, int* rti, + Relation* rd, RangeTblEntry** rte) { ListCell* l1; ListCell* l2; ListCell* l3; + ListCell* l4; bool removeRelname = false, matchRelname = false; char* resname = NULL; char* resultResName = NULL; int attrno, resultAttrno = InvalidAttrNumber; bool isMatched = false; - int rtindex = 1; - forthree (l1, pstate->p_target_rangetblentry, l2, pstate->p_target_relation, l3, pstate->p_updateRangeVars) { + forfour (l1, pstate->p_target_rangetblentry, l2, pstate->p_target_relation, l3, pstate->p_updateRangeVars, + l4, resultRelations) { RangeTblEntry* target_rte = (RangeTblEntry*)lfirst(l1); Relation targetrel = (Relation)lfirst(l2); RangeVar* rangeVar = (RangeVar*)lfirst(l3); + int rtindex = lfirst_int(l4); resname = checkUpdateResTargetName(targetrel, rangeVar, res, &matchRelname); attrno = attnameAttNum(targetrel, resname, true); @@ -4195,7 +4198,6 @@ static int fixUpdateResTargetName(ParseState* pstate, ResTarget* res, int* rti, if (matchRelname == true) { removeRelname = true; } - rtindex++; } if (!isMatched) { @@ -4270,7 +4272,7 @@ static inline void checkSRFInMultiUpdate(Expr* expr, int targetRelationNum) * transformUpdateTargetList - * handle SET clause in UPDATE/INSERT ... DUPLICATE KEY UPDATE */ -static List* transformUpdateTargetList(ParseState* pstate, List* qryTlist, List* origTlist) +static List* transformUpdateTargetList(ParseState* pstate, List* qryTlist, List* origTlist, List* resultRelations) { List* tlist = NIL; RangeTblEntry* target_rte = NULL; @@ -4279,8 +4281,9 @@ static List* transformUpdateTargetList(ParseState* pstate, List* qryTlist, List* Relation targetrel = NULL; int rtindex = 0; int targetRelationNum = list_length(pstate->p_target_relation); + int rangeTableNum = list_length(pstate->p_rtable); - List** new_tle = (List**)palloc0(targetRelationNum * sizeof(List*)); + List** new_tle = (List**)palloc0(rangeTableNum * sizeof(List*)); /* Prepare to assign non-conflicting resnos to resjunk attributes */ pstate->p_next_resno = 1; @@ -4307,7 +4310,7 @@ static List* transformUpdateTargetList(ParseState* pstate, List* qryTlist, List* continue; } - attrno = fixUpdateResTargetName(pstate, origTarget, &rtindex, &targetrel, &target_rte); + attrno = fixUpdateResTargetName(pstate, resultRelations, origTarget, &rtindex, &targetrel, &target_rte); if (attrno == InvalidAttrNumber) { UndefinedColumnError(pstate, origTarget, targetRelationNum); } @@ -4323,9 +4326,14 @@ static List* transformUpdateTargetList(ParseState* pstate, List* qryTlist, List* */ transformMultiTargetList(pstate->p_target_rangetblentry, new_tle); - for (int i = 0; i < targetRelationNum; i++) { - if (new_tle[i]) { - tlist = list_concat(tlist, new_tle[i]); + if (targetRelationNum == 1) { + int i = linitial_int(resultRelations); + tlist = new_tle[i - 1]; + } else { + for (int i = 0; i < rangeTableNum; i++) { + if (new_tle[i]) { + tlist = list_concat(tlist, new_tle[i]); + } } } pfree(new_tle); diff --git a/src/test/regress/expected/multi_update.out b/src/test/regress/expected/multi_update.out index 680918b971..8486009166 100644 --- a/src/test/regress/expected/multi_update.out +++ b/src/test/regress/expected/multi_update.out @@ -1247,6 +1247,9 @@ rollback; begin; update t_t_mutil_t2 t1, t_t_mutil_t2 t2, t_t_mutil_t2 t3 set t1.col1 = 3, t3.col2 = 2, t2.col3 = 3; rollback; +begin; +update t_t_mutil_t1 t1 inner join t_t_mutil_t2 t2 on t1.col1=t2.col1 inner join t_t_mutil_t3 t3 on t1.col1=t3.col1 set t1.col2=3,t2.col2=4; +rollback; CREATE SYNONYM s_t_mutil_t1 FOR t_t_mutil_t1; CREATE SYNONYM s_t_mutil_t2 FOR t_t_mutil_t1; begin; diff --git a/src/test/regress/sql/multi_update.sql b/src/test/regress/sql/multi_update.sql index 2dc9795299..7e9f9b5839 100644 --- a/src/test/regress/sql/multi_update.sql +++ b/src/test/regress/sql/multi_update.sql @@ -563,6 +563,9 @@ rollback; begin; update t_t_mutil_t2 t1, t_t_mutil_t2 t2, t_t_mutil_t2 t3 set t1.col1 = 3, t3.col2 = 2, t2.col3 = 3; rollback; +begin; +update t_t_mutil_t1 t1 inner join t_t_mutil_t2 t2 on t1.col1=t2.col1 inner join t_t_mutil_t3 t3 on t1.col1=t3.col1 set t1.col2=3,t2.col2=4; +rollback; CREATE SYNONYM s_t_mutil_t1 FOR t_t_mutil_t1; CREATE SYNONYM s_t_mutil_t2 FOR t_t_mutil_t1; -- Gitee