From 1ec64927c82377938f212fb8310352f205769cb8 Mon Sep 17 00:00:00 2001 From: zhangting Date: Tue, 5 Mar 2024 16:35:20 +0800 Subject: [PATCH] =?UTF-8?q?insert=E8=AF=AD=E5=8F=A5=E7=94=9F=E6=88=90?= =?UTF-8?q?=E7=9A=84returning=20=E2=80=9CX=E2=80=9D=EF=BC=8CX=E8=A2=AB?= =?UTF-8?q?=E8=AF=AF=E8=AF=86=E5=88=AB=E4=B8=BA=E5=AD=97=E7=AC=A6=E4=B8=B2?= =?UTF-8?q?=EF=BC=8C=E6=95=85=E5=9C=A8jdbc=E4=B8=AD=E5=8E=BB=E6=8E=89?= =?UTF-8?q?=E5=8F=8C=E5=BC=95=E5=8F=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../main/java/org/postgresql/core/Parser.java | 30 ++++- .../org/postgresql/test/jdbc2/ParserTest.java | 112 ++++++++++++++++++ 2 files changed, 141 insertions(+), 1 deletion(-) create mode 100644 pgjdbc/src/test/java/org/postgresql/test/jdbc2/ParserTest.java diff --git a/pgjdbc/src/main/java/org/postgresql/core/Parser.java b/pgjdbc/src/main/java/org/postgresql/core/Parser.java index b31e1f2..a3ea6b3 100644 --- a/pgjdbc/src/main/java/org/postgresql/core/Parser.java +++ b/pgjdbc/src/main/java/org/postgresql/core/Parser.java @@ -532,9 +532,17 @@ public class Parser { private static boolean addReturning(StringBuilder nativeSql, SqlCommandType currentCommandType, String[] returningColumnNames, boolean isReturningPresent, boolean isQuotedReturningIdentifiers) throws SQLException { - if (isReturningPresent || returningColumnNames.length == 0) { + if (isReturningPresent) { + if (!isQuotedReturningIdentifiers) { + removeQuotation(nativeSql); + } + return false; + } + + if (returningColumnNames.length == 0) { return false; } + if (currentCommandType != SqlCommandType.INSERT && currentCommandType != SqlCommandType.UPDATE && currentCommandType != SqlCommandType.DELETE @@ -561,6 +569,26 @@ public class Parser { return true; } + private static void removeQuotation(StringBuilder nativeSql) { + String[] queryArr = nativeSql.toString().split(" "); + nativeSql.setLength(0); + if (queryArr.length > 0) { + nativeSql.append(queryArr[0]); + } + int flag = 0; + for (int k = 1; k < queryArr.length; k++) { + nativeSql.append(" "); + String queryUpper = queryArr[k-1].toUpperCase(); + if (queryUpper.equals("RETURNING") || flag == 1) { + flag = 1; + String q = queryArr[k].replaceAll("\"", ""); + nativeSql.append(q); + } else { + nativeSql.append(queryArr[k]); + } + } + } + /** * Converts {@code List} to {@code int[]}. Empty and {@code null} lists are converted to * empty array. diff --git a/pgjdbc/src/test/java/org/postgresql/test/jdbc2/ParserTest.java b/pgjdbc/src/test/java/org/postgresql/test/jdbc2/ParserTest.java new file mode 100644 index 0000000..06abda8 --- /dev/null +++ b/pgjdbc/src/test/java/org/postgresql/test/jdbc2/ParserTest.java @@ -0,0 +1,112 @@ +package org.postgresql.test.jdbc2; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.postgresql.test.TestUtil; + +import java.sql.*; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +@RunWith(JUnit4.class) +public class ParserTest extends BaseTest4{ + + @Test + public void tesCreateTable() throws Exception { + String sql = "create table parserTab(id int not null auto_increment,class_id int,name varchar(32),primary key(id,class_id));"; + PreparedStatement ps = con.prepareStatement(sql,1); + ps.execute(); + ps.close(); + } + + @Test + public void testReturningAll() throws Exception { + String sql1 = "set dolphin.sql_mode='sql_mode_full_group,no_zero_date';"; + execute(sql1, con); + String sql = "insert into parserTab(class_id,name) values(?,?);"; + int class_id = 40; + String name = "feng"; + PreparedStatement ps = con.prepareStatement(sql,1); + ps.setObject(1, class_id, Types.INTEGER); + ps.setObject(2, name, Types.VARCHAR); + ResultSet rs = ps.executeQuery(); + assertTrue(rs.next()); + assertEquals(class_id, rs.getInt(2)); + assertEquals(name, rs.getString(3)); + System.out.println(rs.getInt(1)); + System.out.println(rs.getInt(2)); + System.out.println(rs.getString(3)); + + rs.close(); + ps.close(); + } + + @Test + public void testReturning() throws Exception { + String sql1 = "set dolphin.sql_mode='sql_mode_full_group,no_zero_date';"; + execute(sql1, con); + String sql = "insert into parserTab(class_id,name) values(?,?) returning id, class_id;"; + int class_id = 40; + String name = "feng"; + PreparedStatement ps = con.prepareStatement(sql); + ps.setObject(1, class_id, Types.INTEGER); + ps.setObject(2, name, Types.VARCHAR); + ResultSet rs = ps.executeQuery(); + assertTrue(rs.next()); + assertEquals(class_id, rs.getInt(2)); + rs.close(); + ps.close(); + } + + @Test + public void testReturningQuote() throws Exception { + String sql1 = "set dolphin.sql_mode='sql_mode_full_group,no_zero_date';"; + execute(sql1, con); + String sql = "insert into parserTab(class_id,name) values(?,?) returning \"id\", \"class_id\";"; + int class_id = 40; + String name = "feng"; + PreparedStatement ps = con.prepareStatement(sql); + ps.setObject(1, class_id, Types.INTEGER); + ps.setObject(2, name, Types.VARCHAR); + ResultSet rs = ps.executeQuery(); + assertTrue(rs.next()); + assertEquals(class_id, rs.getInt(2)); + rs.close(); + ps.close(); + } + + @Test + public void testReturningColumnNames() throws Exception { + //Connection conn = DriverManager.getConnection("jdbc:opengauss://192.168.0.127:5435/target_db?quoteReturningIdentifiers=false", "zt2", "@zt7567628"); + Connection conn = TestUtil.openDB(); + String sql1 = "set dolphin.sql_mode='sql_mode_full_group,no_zero_date';"; + execute(sql1, conn); + String sql = "insert into parserTab(class_id,name) values(?,?);"; + int class_id = 40; + String name = "feng"; + String[] columnNames = {"class_id", "name"}; + PreparedStatement ps = conn.prepareStatement(sql, columnNames); + ps.setObject(1, class_id, Types.INTEGER); + ps.setObject(2, name, Types.VARCHAR); + ResultSet rs = ps.executeQuery(); + assertTrue(rs.next()); + assertEquals(class_id, rs.getInt(1)); + assertEquals(name, rs.getString(2)); + rs.close(); + ps.close(); + } + + public static void execute(String sql, Connection conn) { + try { + Statement stmt = conn.createStatement(); + Boolean rb = stmt.execute(sql); + } catch (SQLException e) { + System.out.println("ERROR:" + e.getMessage()); + e.printStackTrace(); + } + } + +} -- Gitee