From 5bfac65d16423e8eec17233ea6911702f2903698 Mon Sep 17 00:00:00 2001 From: ZhangCheng Date: Fri, 8 Dec 2023 18:26:06 +0800 Subject: [PATCH] fix: round-robin statement-level read-write splitting issue --- .../hostchooser/MultiHostChooser.java | 8 + .../PgConnectionManager.java | 9 + .../ReadWriteSplittingHostSpec.java | 9 + .../ReadWriteSplittingPgStatement.java | 10 +- .../readwritesplitting/SqlRouteEngine.java | 44 +- ...iteSplittingConnectionMultiThreadTest.java | 53 +++ .../ReadWriteSplittingConnectionTest.java | 389 +++++++++++++----- 7 files changed, 406 insertions(+), 116 deletions(-) create mode 100644 pgjdbc/src/test/java/org/postgresql/test/readwritesplitting/ReadWriteSplittingConnectionMultiThreadTest.java diff --git a/pgjdbc/src/main/java/org/postgresql/hostchooser/MultiHostChooser.java b/pgjdbc/src/main/java/org/postgresql/hostchooser/MultiHostChooser.java index 2947297..e3109a7 100644 --- a/pgjdbc/src/main/java/org/postgresql/hostchooser/MultiHostChooser.java +++ b/pgjdbc/src/main/java/org/postgresql/hostchooser/MultiHostChooser.java @@ -26,6 +26,7 @@ import java.util.List; import java.util.Properties; import java.util.*; +import java.util.concurrent.atomic.AtomicInteger; /** @@ -49,12 +50,16 @@ public class MultiHostChooser implements HostChooser { private static Map roundRobinCounter = new HashMap<>(); + private final boolean isEnableStatementLoadBalance; + + private final AtomicInteger statementLoadBalanceCount = new AtomicInteger(0); MultiHostChooser(HostSpec[] hostSpecs, HostRequirement targetServerType, Properties info) { this.hostSpecs = hostSpecs; this.targetServerType = targetServerType; this.loadBalanceType = initLoadBalanceType(info); + this.isEnableStatementLoadBalance = PGProperty.ENABLE_STATEMENT_LOAD_BALANCE.getBoolean(info); this.URLIdentifier = QueryCNListUtils.keyFromURL(info); this.info = info; try { @@ -118,6 +123,9 @@ public class MultiHostChooser implements HostChooser { // Returns a counter and increments it by one. // Because it is possible to use it in multiple instances, use synchronized (MultiHostChooser.class). private int getRRIndex() { + if (isEnableStatementLoadBalance) { + return Math.abs(statementLoadBalanceCount.getAndIncrement()); + } synchronized (roundRobinCounter) { int value = roundRobinCounter.getOrDefault(URLIdentifier, 0); value = (value + 1) % MAX_CONNECT_NUM; diff --git a/pgjdbc/src/main/java/org/postgresql/readwritesplitting/PgConnectionManager.java b/pgjdbc/src/main/java/org/postgresql/readwritesplitting/PgConnectionManager.java index e500afa..97483dc 100644 --- a/pgjdbc/src/main/java/org/postgresql/readwritesplitting/PgConnectionManager.java +++ b/pgjdbc/src/main/java/org/postgresql/readwritesplitting/PgConnectionManager.java @@ -199,4 +199,13 @@ public class PgConnectionManager implements AutoCloseable { } return true; } + + /** + * Get properties. + * + * @return properties. + */ + public Properties getProps() { + return props; + } } diff --git a/pgjdbc/src/main/java/org/postgresql/readwritesplitting/ReadWriteSplittingHostSpec.java b/pgjdbc/src/main/java/org/postgresql/readwritesplitting/ReadWriteSplittingHostSpec.java index e851346..36a3fbc 100644 --- a/pgjdbc/src/main/java/org/postgresql/readwritesplitting/ReadWriteSplittingHostSpec.java +++ b/pgjdbc/src/main/java/org/postgresql/readwritesplitting/ReadWriteSplittingHostSpec.java @@ -87,4 +87,13 @@ public class ReadWriteSplittingHostSpec { public HostSpec readLoadBalance() { return readChooser.iterator().next().hostSpec; } + + /** + * Get read chooser. + * + * @return read host chooser + */ + public HostChooser getReadChooser() { + return readChooser; + } } diff --git a/pgjdbc/src/main/java/org/postgresql/readwritesplitting/ReadWriteSplittingPgStatement.java b/pgjdbc/src/main/java/org/postgresql/readwritesplitting/ReadWriteSplittingPgStatement.java index 4d5f7a6..bf26149 100644 --- a/pgjdbc/src/main/java/org/postgresql/readwritesplitting/ReadWriteSplittingPgStatement.java +++ b/pgjdbc/src/main/java/org/postgresql/readwritesplitting/ReadWriteSplittingPgStatement.java @@ -4,6 +4,9 @@ package org.postgresql.readwritesplitting; +import org.postgresql.hostchooser.HostChooser; +import org.postgresql.hostchooser.HostChooserFactory; + import java.sql.Connection; import java.sql.ResultSet; import java.sql.SQLException; @@ -31,6 +34,8 @@ public class ReadWriteSplittingPgStatement implements Statement { private final Integer resultSetHoldability; + private final HostChooser readChooser; + private Statement currentStatement; private ResultSet currentResultSet; @@ -51,6 +56,9 @@ public class ReadWriteSplittingPgStatement implements Statement { this.resultSetType = resultSetType; this.resultSetConcurrency = resultSetConcurrency; this.resultSetHoldability = resultSetHoldability; + ReadWriteSplittingHostSpec hostSpec = readWriteSplittingPgConnection.getReadWriteSplittingHostSpec(); + this.readChooser = HostChooserFactory.createHostChooser(hostSpec.getReadHostSpecs(), + hostSpec.getTargetServerType(), readWriteSplittingPgConnection.getConnectionManager().getProps()); } @Override @@ -62,7 +70,7 @@ public class ReadWriteSplittingPgStatement implements Statement { } private Statement createPgStatement(String sql) throws SQLException { - Connection connection = SqlRouteEngine.getRoutedConnection(readWriteSplittingPgConnection, sql); + Connection connection = SqlRouteEngine.getRoutedConnection(readWriteSplittingPgConnection, sql, readChooser); Statement statement = connection.createStatement(resultSetType, resultSetConcurrency, resultSetHoldability); statements.add(statement); currentStatement = statement; diff --git a/pgjdbc/src/main/java/org/postgresql/readwritesplitting/SqlRouteEngine.java b/pgjdbc/src/main/java/org/postgresql/readwritesplitting/SqlRouteEngine.java index 2f3bba0..9d0d803 100644 --- a/pgjdbc/src/main/java/org/postgresql/readwritesplitting/SqlRouteEngine.java +++ b/pgjdbc/src/main/java/org/postgresql/readwritesplitting/SqlRouteEngine.java @@ -11,6 +11,7 @@ import org.apache.shardingsphere.sql.parser.core.ParseASTNode; import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement; import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement; import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.SelectStatementHandler; +import org.postgresql.hostchooser.HostChooser; import org.postgresql.hostchooser.HostRequirement; import org.postgresql.log.Log; import org.postgresql.log.Logger; @@ -34,32 +35,38 @@ public class SqlRouteEngine { /** * Route SQL. * - * @param readWriteSplittingPgConnection read write splitting PG Connection + * @param conn read write splitting PG Connection * @param sql SQL * @return routed connection * @throws SQLException SQL exception */ - public static Connection getRoutedConnection(ReadWriteSplittingPgConnection readWriteSplittingPgConnection, + public static Connection getRoutedConnection(ReadWriteSplittingPgConnection conn, String sql) throws SQLException { - HostSpec hostSpec = SqlRouteEngine.route(sql, readWriteSplittingPgConnection); - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Routed connection host spec: " + hostSpec); - } - return readWriteSplittingPgConnection.getConnectionManager().getConnection(hostSpec); + return getRoutedConnection(conn, sql, conn.getReadWriteSplittingHostSpec().getReadChooser()); } /** * Route SQL. * - * @param sql SQL * @param readWriteSplittingPgConnection read write splitting PG Connection - * @return host spec - * @throws SQLException sql exception + * @param sql SQL + * @param readHostChooser read host chooser + * @return routed connection + * @throws SQLException SQL exception */ - public static HostSpec route(String sql, ReadWriteSplittingPgConnection readWriteSplittingPgConnection) + public static Connection getRoutedConnection(ReadWriteSplittingPgConnection readWriteSplittingPgConnection, + String sql, HostChooser readHostChooser) throws SQLException { + HostSpec hostSpec = SqlRouteEngine.route(sql, readWriteSplittingPgConnection, readHostChooser); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Routed connection host spec: " + hostSpec); + } + return readWriteSplittingPgConnection.getConnectionManager().getConnection(hostSpec); + } + + private static HostSpec route(String sql, ReadWriteSplittingPgConnection connection, HostChooser readHostChooser) throws SQLException { - ReadWriteSplittingHostSpec hostSpec = readWriteSplittingPgConnection.getReadWriteSplittingHostSpec(); - if (!readWriteSplittingPgConnection.getAutoCommit()) { + ReadWriteSplittingHostSpec hostSpec = connection.getReadWriteSplittingHostSpec(); + if (!connection.getAutoCommit()) { return hostSpec.getWriteHostSpec(); } try { @@ -67,7 +74,10 @@ public class SqlRouteEngine { return hostSpec.getWriteHostSpec(); } if (HostRequirement.secondary == hostSpec.getTargetServerType()) { - return hostSpec.readLoadBalance(); + return readLoadBalance(readHostChooser); + } + if (hostSpec.getReadHostSpecs().length == 0) { + return hostSpec.getWriteHostSpec(); } ParseASTNode parseASTNode = PARSE_ENGINE.parse(sql, true); SQLStatement sqlStatement = new SQLStatementVisitorEngine(DATABASE_TYPE, false).visit(parseASTNode); @@ -77,7 +87,11 @@ public class SqlRouteEngine { } catch (final Exception ignored) { return hostSpec.getWriteHostSpec(); } - return hostSpec.readLoadBalance(); + return readLoadBalance(readHostChooser); + } + + private static HostSpec readLoadBalance(HostChooser readChooser) { + return readChooser.iterator().next().hostSpec; } private static boolean isWriteRouteStatement(final SQLStatement sqlStatement) { diff --git a/pgjdbc/src/test/java/org/postgresql/test/readwritesplitting/ReadWriteSplittingConnectionMultiThreadTest.java b/pgjdbc/src/test/java/org/postgresql/test/readwritesplitting/ReadWriteSplittingConnectionMultiThreadTest.java new file mode 100644 index 0000000..354a788 --- /dev/null +++ b/pgjdbc/src/test/java/org/postgresql/test/readwritesplitting/ReadWriteSplittingConnectionMultiThreadTest.java @@ -0,0 +1,53 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + */ + +package org.postgresql.test.readwritesplitting; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +/** + * Read write splitting connection multi thread test. + * + * @since 2023-12-08 + */ +public class ReadWriteSplittingConnectionMultiThreadTest { + private static final int THREAD_COUNT = 100; + + @Test + public void test() throws Exception { + ThreadPoolExecutor executor = new ThreadPoolExecutor(THREAD_COUNT, THREAD_COUNT, 0, + TimeUnit.SECONDS, new ArrayBlockingQueue<>(100)); + List> futures = new ArrayList<>(THREAD_COUNT); + ReadWriteSplittingConnectionTest.setUp(); + for (int i = 0; i < THREAD_COUNT; i++) { + Future future = executor.submit(() -> { + ReadWriteSplittingConnectionTest tester = new ReadWriteSplittingConnectionTest(); + tester.roundRobinLoadBalanceTest(); + tester.roundRobinLoadBalanceWithPreparedStatementTest(); + tester.shuffleLoadBalanceTest(); + tester.shuffleLoadBalanceWithPreparedStatementTest(); + tester.priorityLoadBalanceTest(); + tester.targetServerTypeOfMasterTest(); + tester.targetServerTypeOfSlaveTest(); + tester.roundRobinLoadBalanceByMultiStatementsTest(); + tester.shuffleLoadBalanceByMultiStatementsTest(); + tester.priorityLoadBalanceByMultiStatementsTest(); + tester.priorityLoadBalanceByMultiConnectionsTest(); + return null; + }); + futures.add(future); + } + for (Future future : futures) { + future.get(); + } + ReadWriteSplittingConnectionTest.tearDown(); + } +} diff --git a/pgjdbc/src/test/java/org/postgresql/test/readwritesplitting/ReadWriteSplittingConnectionTest.java b/pgjdbc/src/test/java/org/postgresql/test/readwritesplitting/ReadWriteSplittingConnectionTest.java index 0cfab69..d8ca9b8 100644 --- a/pgjdbc/src/test/java/org/postgresql/test/readwritesplitting/ReadWriteSplittingConnectionTest.java +++ b/pgjdbc/src/test/java/org/postgresql/test/readwritesplitting/ReadWriteSplittingConnectionTest.java @@ -8,6 +8,7 @@ import org.junit.AfterClass; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; +import org.postgresql.jdbc.PgConnection; import org.postgresql.readwritesplitting.ReadWriteSplittingHostSpec; import org.postgresql.readwritesplitting.ReadWriteSplittingPgConnection; import org.postgresql.test.TestUtil; @@ -20,6 +21,7 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; import java.util.Properties; +import java.util.concurrent.atomic.AtomicInteger; import static org.hamcrest.CoreMatchers.instanceOf; @@ -39,8 +41,6 @@ public class ReadWriteSplittingConnectionTest { private static HostSpec[] readHostSpecs; - private int currentIndex; - private static HostSpec[] initHostSpecs() { HostSpec[] result = new HostSpec[DN_NUM]; result[0] = getMasterHostSpec(); @@ -100,43 +100,59 @@ public class ReadWriteSplittingConnectionTest { String urlParams = String.format("?enableStatementLoadBalance=true&autoBalance=roundrobin" + "&writeDataSourceAddress=%s", getMasterHostSpec()); try (Connection connection = getConnection(urlParams)) { - ReadWriteSplittingPgConnection readWriteSplittingPgConnection = - getReadWriteSplittingPgConnection(connection); - ReadWriteSplittingHostSpec readWriteSplittingHostSpec = - readWriteSplittingPgConnection.getReadWriteSplittingHostSpec(); + ReadWriteSplittingPgConnection conn = getReadWriteSplittingPgConnection(connection); + ReadWriteSplittingHostSpec readWriteSplittingHostSpec = conn.getReadWriteSplittingHostSpec(); Assert.assertEquals(readWriteSplittingHostSpec.getWriteHostSpec(), getMasterHostSpec()); Assert.assertEquals(readWriteSplittingHostSpec.getReadHostSpecs(), getReadHostSpecs()); try (Statement statement = connection.createStatement()) { Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); - HostSpec actual = getRoutedReadHostSpec(readWriteSplittingPgConnection); + HostSpec actual = getRoutedReadHostSpec(conn); + AtomicInteger count = new AtomicInteger(0); for (int i = 0; i < readHostSpecs.length; i++) { - HostSpec firstExpected = getNextExpectedRoundRobinSpec(); + HostSpec firstExpected = getNextExpectedRoundRobinSpec(count); if (firstExpected.equals(actual)) { break; } } - Assert.assertTrue(isRoutedToReadHostSpecs(readWriteSplittingPgConnection)); + Assert.assertTrue(isRoutedToReadHostSpecs(conn)); } - for (int i = 0; i < 10; i++) { - try (Statement statement = connection.createStatement()) { + try (Statement statement = connection.createStatement()) { + AtomicInteger count = new AtomicInteger(0); + for (int i = 0; i < 10; i++) { Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); - HostSpec actual = getRoutedReadHostSpec(readWriteSplittingPgConnection); - Assert.assertEquals(getNextExpectedRoundRobinSpec(), actual); - } - } - for (int i = 0; i < 10; i++) { - String sql = "SELECT * FROM account WHERE id = ?"; - try (PreparedStatement statement = connection.prepareStatement(sql)) { - statement.setString(1, "1"); - Assert.assertTrue(statement.execute()); - HostSpec actual = getRoutedReadHostSpec(readWriteSplittingPgConnection); - Assert.assertEquals(getNextExpectedRoundRobinSpec(), actual); + HostSpec actual = getRoutedReadHostSpec(conn); + Assert.assertEquals(getNextExpectedRoundRobinSpec(count), actual); } } for (int i = 0; i < 3; i++) { try (Statement statement = connection.createStatement()) { statement.execute("UPDATE account SET balance = 11 WHERE id = 1"); - Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteSplittingPgConnection)); + Assert.assertTrue(isRoutedToWriteHostSpecs(conn)); + } + } + } + } + + @Test + public void roundRobinLoadBalanceWithPreparedStatementTest() throws SQLException { + String urlParams = String.format("?enableStatementLoadBalance=true&autoBalance=roundrobin" + + "&writeDataSourceAddress=%s", getMasterHostSpec()); + try (Connection conn = getConnection(urlParams)) { + ReadWriteSplittingPgConnection readWriteConn = getReadWriteSplittingPgConnection(conn); + ReadWriteSplittingHostSpec readHosts = readWriteConn.getReadWriteSplittingHostSpec(); + Assert.assertEquals(readHosts.getWriteHostSpec(), getMasterHostSpec()); + Assert.assertEquals(readHosts.getReadHostSpecs(), getReadHostSpecs()); + AtomicInteger count = new AtomicInteger(); + for (int i = 0; i < 10; i++) { + try (PreparedStatement statement = conn.prepareStatement("SELECT * FROM account WHERE id = ?")) { + statement.setObject(1, 1); + Assert.assertTrue(statement.execute()); + HostSpec actual = getRoutedReadHostSpec(readWriteConn); + statement.setObject(1, 1); + Assert.assertTrue(statement.execute()); + HostSpec actual2 = getRoutedReadHostSpec(readWriteConn); + Assert.assertEquals(actual, actual2); + Assert.assertEquals(getNextExpectedRoundRobinSpec(count), getRoutedReadHostSpec(readWriteConn)); } } } @@ -155,16 +171,14 @@ public class ReadWriteSplittingConnectionTest { String urlParams = String.format("?enableStatementLoadBalance=true&autoBalance=shuffle&writeDataSourceAddress" + "=%s", getMasterHostSpec()); try (Connection connection = getConnection(urlParams)) { - ReadWriteSplittingPgConnection readWriteSplittingPgConnection = - getReadWriteSplittingPgConnection(connection); - ReadWriteSplittingHostSpec readWriteSplittingHostSpec = - readWriteSplittingPgConnection.getReadWriteSplittingHostSpec(); + ReadWriteSplittingPgConnection readWriteConn = getReadWriteSplittingPgConnection(connection); + ReadWriteSplittingHostSpec readWriteSplittingHostSpec = readWriteConn.getReadWriteSplittingHostSpec(); Assert.assertEquals(readWriteSplittingHostSpec.getWriteHostSpec(), getMasterHostSpec()); Assert.assertEquals(readWriteSplittingHostSpec.getReadHostSpecs(), getReadHostSpecs()); for (int i = 0; i < 10; i++) { try (Statement statement = connection.createStatement()) { Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); - Assert.assertTrue(isRoutedToReadHostSpecs(readWriteSplittingPgConnection)); + Assert.assertTrue(isRoutedToReadHostSpecs(readWriteConn)); } } for (int i = 0; i < 10; i++) { @@ -172,13 +186,44 @@ public class ReadWriteSplittingConnectionTest { try (PreparedStatement statement = connection.prepareStatement(sql)) { statement.setString(1, "1"); Assert.assertTrue(statement.execute()); - Assert.assertTrue(isRoutedToReadHostSpecs(readWriteSplittingPgConnection)); + Assert.assertTrue(isRoutedToReadHostSpecs(readWriteConn)); } } for (int i = 0; i < 3; i++) { try (Statement statement = connection.createStatement()) { statement.execute("UPDATE account SET balance = 11 WHERE id = 1"); - Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteSplittingPgConnection)); + Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteConn)); + } + } + } + } + + @Test + public void shuffleLoadBalanceWithPreparedStatementTest() throws SQLException { + String urlParams = String.format("?enableStatementLoadBalance=true&autoBalance=shuffle" + + "&writeDataSourceAddress=%s", getMasterHostSpec()); + try (Connection connection = getConnection(urlParams)) { + ReadWriteSplittingPgConnection conn = getReadWriteSplittingPgConnection(connection); + ReadWriteSplittingHostSpec readHosts = conn.getReadWriteSplittingHostSpec(); + Assert.assertEquals(readHosts.getWriteHostSpec(), getMasterHostSpec()); + Assert.assertEquals(readHosts.getReadHostSpecs(), getReadHostSpecs()); + for (int i = 0; i < 10; i++) { + try (PreparedStatement stmt = connection.prepareStatement("SELECT * FROM account WHERE id = ?")) { + stmt.setObject(1, 1); + Assert.assertTrue(stmt.execute()); + HostSpec actual = getRoutedReadHostSpec(conn); + Assert.assertTrue(isRoutedToReadHostSpecs(conn)); + stmt.setObject(1, 1); + Assert.assertTrue(stmt.execute()); + HostSpec actual2 = getRoutedReadHostSpec(conn); + Assert.assertEquals(actual, actual2); + Assert.assertTrue(isRoutedToReadHostSpecs(conn)); + } + String updateSQL = "UPDATE account SET balance = 11 WHERE id = ?"; + try (PreparedStatement stmt = connection.prepareStatement(updateSQL)) { + stmt.setObject(1, 1); + stmt.execute(); + Assert.assertTrue(isRoutedToWriteHostSpecs(conn)); } } } @@ -188,38 +233,36 @@ public class ReadWriteSplittingConnectionTest { public void priorityLoadBalanceTest() throws SQLException { String urlParams = String.format("?enableStatementLoadBalance=true&autoBalance=priority2" + "&writeDataSourceAddress=%s", getMasterHostSpec()); - try (Connection connection = getConnection(urlParams)) { - ReadWriteSplittingPgConnection readWriteSplittingPgConnection = - getReadWriteSplittingPgConnection(connection); - ReadWriteSplittingHostSpec readWriteSplittingHostSpec = - readWriteSplittingPgConnection.getReadWriteSplittingHostSpec(); + try (Connection conn = getConnection(urlParams)) { + ReadWriteSplittingPgConnection readWriteConn = getReadWriteSplittingPgConnection(conn); + ReadWriteSplittingHostSpec readWriteSplittingHostSpec = readWriteConn.getReadWriteSplittingHostSpec(); Assert.assertEquals(readWriteSplittingHostSpec.getWriteHostSpec(), getMasterHostSpec()); Assert.assertEquals(readWriteSplittingHostSpec.getReadHostSpecs(), getReadHostSpecs()); HostSpec firstActual; - try (Statement statement = connection.createStatement()) { + try (Statement statement = conn.createStatement()) { Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); - firstActual = getRoutedReadHostSpec(readWriteSplittingPgConnection); + firstActual = getRoutedReadHostSpec(readWriteConn); } for (int i = 0; i < 10; i++) { - try (Statement statement = connection.createStatement()) { + try (Statement statement = conn.createStatement()) { Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); - HostSpec actual = getRoutedReadHostSpec(readWriteSplittingPgConnection); + HostSpec actual = getRoutedReadHostSpec(readWriteConn); Assert.assertEquals(firstActual, actual); } } for (int i = 0; i < 10; i++) { String sql = "SELECT * FROM account WHERE id = ?"; - try (PreparedStatement statement = connection.prepareStatement(sql)) { + try (PreparedStatement statement = conn.prepareStatement(sql)) { statement.setString(1, "1"); Assert.assertTrue(statement.execute()); - HostSpec actual = getRoutedReadHostSpec(readWriteSplittingPgConnection); + HostSpec actual = getRoutedReadHostSpec(readWriteConn); Assert.assertEquals(firstActual, actual); } } for (int i = 0; i < 3; i++) { - try (Statement statement = connection.createStatement()) { + try (Statement statement = conn.createStatement()) { statement.execute("UPDATE account SET balance = 11 WHERE id = 1"); - Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteSplittingPgConnection)); + Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteConn)); } } } @@ -229,31 +272,29 @@ public class ReadWriteSplittingConnectionTest { public void targetServerTypeOfMasterTest() throws SQLException { String urlParams = String.format("?enableStatementLoadBalance=true&autoBalance=shuffle&writeDataSourceAddress" + "=%s&targetServerType=master", getMasterHostSpec()); - try (Connection connection = getConnection(urlParams)) { - ReadWriteSplittingPgConnection readWriteSplittingPgConnection = - getReadWriteSplittingPgConnection(connection); - ReadWriteSplittingHostSpec readWriteSplittingHostSpec = - readWriteSplittingPgConnection.getReadWriteSplittingHostSpec(); - Assert.assertEquals(readWriteSplittingHostSpec.getWriteHostSpec(), getMasterHostSpec()); - Assert.assertEquals(readWriteSplittingHostSpec.getReadHostSpecs(), getReadHostSpecs()); + try (Connection conn = getConnection(urlParams)) { + ReadWriteSplittingPgConnection readWriteConn = getReadWriteSplittingPgConnection(conn); + ReadWriteSplittingHostSpec readWriteHostSpec = readWriteConn.getReadWriteSplittingHostSpec(); + Assert.assertEquals(readWriteHostSpec.getWriteHostSpec(), getMasterHostSpec()); + Assert.assertEquals(readWriteHostSpec.getReadHostSpecs(), getReadHostSpecs()); for (int i = 0; i < 10; i++) { - try (Statement statement = connection.createStatement()) { + try (Statement statement = conn.createStatement()) { Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); - Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteSplittingPgConnection)); + Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteConn)); } } for (int i = 0; i < 10; i++) { String sql = "SELECT * FROM account WHERE id = ?"; - try (PreparedStatement statement = connection.prepareStatement(sql)) { + try (PreparedStatement statement = conn.prepareStatement(sql)) { statement.setString(1, "1"); Assert.assertTrue(statement.execute()); - Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteSplittingPgConnection)); + Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteConn)); } } for (int i = 0; i < 3; i++) { - try (Statement statement = connection.createStatement()) { + try (Statement statement = conn.createStatement()) { statement.execute("UPDATE account SET balance = 11 WHERE id = 1"); - Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteSplittingPgConnection)); + Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteConn)); } } } @@ -263,29 +304,27 @@ public class ReadWriteSplittingConnectionTest { public void targetServerTypeOfSlaveTest() throws SQLException { String urlParams = String.format("?enableStatementLoadBalance=true&autoBalance=shuffle&writeDataSourceAddress" + "=%s&targetServerType=slave", getMasterHostSpec()); - try (Connection connection = getConnection(urlParams)) { - ReadWriteSplittingPgConnection readWriteSplittingPgConnection = - getReadWriteSplittingPgConnection(connection); - ReadWriteSplittingHostSpec readWriteSplittingHostSpec = - readWriteSplittingPgConnection.getReadWriteSplittingHostSpec(); - Assert.assertEquals(readWriteSplittingHostSpec.getWriteHostSpec(), getMasterHostSpec()); - Assert.assertEquals(readWriteSplittingHostSpec.getReadHostSpecs(), getReadHostSpecs()); + try (Connection conn = getConnection(urlParams)) { + ReadWriteSplittingPgConnection readWriteConn = getReadWriteSplittingPgConnection(conn); + ReadWriteSplittingHostSpec readWriteHostSpec = readWriteConn.getReadWriteSplittingHostSpec(); + Assert.assertEquals(readWriteHostSpec.getWriteHostSpec(), getMasterHostSpec()); + Assert.assertEquals(readWriteHostSpec.getReadHostSpecs(), getReadHostSpecs()); for (int i = 0; i < 10; i++) { - try (Statement statement = connection.createStatement()) { + try (Statement statement = conn.createStatement()) { Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); - Assert.assertTrue(isRoutedToReadHostSpecs(readWriteSplittingPgConnection)); + Assert.assertTrue(isRoutedToReadHostSpecs(readWriteConn)); } } for (int i = 0; i < 10; i++) { String sql = "SELECT * FROM account WHERE id = ?"; - try (PreparedStatement statement = connection.prepareStatement(sql)) { + try (PreparedStatement statement = conn.prepareStatement(sql)) { statement.setString(1, "1"); Assert.assertTrue(statement.execute()); - Assert.assertTrue(isRoutedToReadHostSpecs(readWriteSplittingPgConnection)); + Assert.assertTrue(isRoutedToReadHostSpecs(readWriteConn)); } } for (int i = 0; i < 3; i++) { - try (Statement statement = connection.createStatement()) { + try (Statement statement = conn.createStatement()) { try { statement.execute("UPDATE account SET balance = 11 WHERE id = 1"); } catch (SQLException e) { @@ -297,6 +336,118 @@ public class ReadWriteSplittingConnectionTest { } } + @Test + public void roundRobinLoadBalanceByMultiStatementsTest() throws Exception { + String urlParams = String.format("?enableStatementLoadBalance=true&autoBalance=roundrobin" + + "&writeDataSourceAddress=%s", getMasterHostSpec()); + try (Connection conn = getConnection(urlParams)) { + ReadWriteSplittingPgConnection readWriteConn = getReadWriteSplittingPgConnection(conn); + ReadWriteSplittingHostSpec readWriteSplittingHostSpec = readWriteConn.getReadWriteSplittingHostSpec(); + Assert.assertEquals(readWriteSplittingHostSpec.getWriteHostSpec(), getMasterHostSpec()); + Assert.assertEquals(readWriteSplittingHostSpec.getReadHostSpecs(), getReadHostSpecs()); + AtomicInteger count = new AtomicInteger(); + AtomicInteger count2 = new AtomicInteger(); + AtomicInteger count3 = new AtomicInteger(); + try (Statement statement = conn.createStatement(); + Statement statement2 = conn.createStatement(); + Statement statement3 = conn.createStatement()) { + for (int i = 0; i < 10; i++) { + Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); + Assert.assertEquals(getNextExpectedRoundRobinSpec(count), getRoutedReadHostSpec(readWriteConn)); + Assert.assertTrue(statement2.execute("SELECT * FROM account WHERE id = 1")); + Assert.assertEquals(getNextExpectedRoundRobinSpec(count2), getRoutedReadHostSpec(readWriteConn)); + Assert.assertTrue(statement3.execute("SELECT * FROM account WHERE id = 1")); + Assert.assertEquals(getNextExpectedRoundRobinSpec(count3), getRoutedReadHostSpec(readWriteConn)); + } + } + } + } + + @Test + public void shuffleLoadBalanceByMultiStatementsTest() throws Exception { + String urlParams = String.format("?enableStatementLoadBalance=true&autoBalance=shuffle&writeDataSourceAddress" + + "=%s", getMasterHostSpec()); + try (Connection conn = getConnection(urlParams)) { + ReadWriteSplittingPgConnection readWriteConn = getReadWriteSplittingPgConnection(conn); + ReadWriteSplittingHostSpec readWriteHostSpec = readWriteConn.getReadWriteSplittingHostSpec(); + Assert.assertEquals(readWriteHostSpec.getWriteHostSpec(), getMasterHostSpec()); + Assert.assertEquals(readWriteHostSpec.getReadHostSpecs(), getReadHostSpecs()); + try (Statement statement = conn.createStatement(); + Statement statement2 = conn.createStatement(); + Statement statement3 = conn.createStatement()) { + for (int i = 0; i < 10; i++) { + Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); + Assert.assertTrue(isRoutedToReadHostSpecs(readWriteConn)); + Assert.assertTrue(statement2.execute("SELECT * FROM account WHERE id = 1")); + Assert.assertTrue(isRoutedToReadHostSpecs(readWriteConn)); + Assert.assertTrue(statement3.execute("SELECT * FROM account WHERE id = 1")); + Assert.assertTrue(isRoutedToReadHostSpecs(readWriteConn)); + } + } + } + } + + @Test + public void priorityLoadBalanceByMultiStatementsTest() throws Exception { + String urlParams = String.format("?enableStatementLoadBalance=true&autoBalance=priority2" + + "&writeDataSourceAddress=%s", getMasterHostSpec()); + try (Connection conn = getConnection(urlParams)) { + ReadWriteSplittingPgConnection readWriteConn = getReadWriteSplittingPgConnection(conn); + ReadWriteSplittingHostSpec readWriteHostSpec = readWriteConn.getReadWriteSplittingHostSpec(); + Assert.assertEquals(readWriteHostSpec.getWriteHostSpec(), getMasterHostSpec()); + Assert.assertEquals(readWriteHostSpec.getReadHostSpecs(), getReadHostSpecs()); + try (Statement statement = conn.createStatement(); + Statement statement2 = conn.createStatement(); + Statement statement3 = conn.createStatement()) { + for (int i = 0; i < 10; i++) { + Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); + Assert.assertTrue(isRoutedToReadHostSpecs(readWriteConn)); + Assert.assertTrue(statement2.execute("SELECT * FROM account WHERE id = 1")); + Assert.assertTrue(isRoutedToReadHostSpecs(readWriteConn)); + Assert.assertTrue(statement3.execute("SELECT * FROM account WHERE id = 1")); + Assert.assertTrue(isRoutedToReadHostSpecs(readWriteConn)); + } + } + } + } + + @Test + public void priorityLoadBalanceByMultiConnectionsTest() throws Exception { + String urlParams = String.format("?enableStatementLoadBalance=true&autoBalance=roundrobin" + + "&writeDataSourceAddress=%s", getMasterHostSpec()); + try (Connection connection = getConnection(urlParams); Connection connection2 = getConnection(urlParams)) { + ReadWriteSplittingPgConnection conn = getReadWriteSplittingPgConnection(connection); + ReadWriteSplittingPgConnection conn2 = getReadWriteSplittingPgConnection(connection2); + ReadWriteSplittingHostSpec readWriteHostSpecs = conn.getReadWriteSplittingHostSpec(); + Assert.assertEquals(readWriteHostSpecs.getWriteHostSpec(), getMasterHostSpec()); + Assert.assertEquals(readWriteHostSpecs.getReadHostSpecs(), getReadHostSpecs()); + AtomicInteger count = new AtomicInteger(); + try (Statement conn1Statement = connection.createStatement(); + Statement conn2Statement = connection2.createStatement(); + Statement conn1Statement2 = connection.createStatement(); + Statement conn2Statement2 = connection2.createStatement(); + Statement conn1Statement3 = connection.createStatement(); + Statement conn2Statement3 = connection2.createStatement()) { + for (int i = 0; i < 10; i++) { + Assert.assertTrue(conn1Statement.execute("SELECT * FROM account WHERE id = 1")); + Assert.assertEquals(getCurrentExpectedRoundRobinSpec(count), getRoutedReadHostSpec(conn)); + Assert.assertTrue(conn2Statement.execute("SELECT * FROM account WHERE id = 1")); + Assert.assertEquals(getCurrentExpectedRoundRobinSpec(count), getRoutedReadHostSpec(conn2)); + Assert.assertTrue(conn1Statement2.execute("SELECT * FROM account WHERE id = 1")); + Assert.assertEquals(getCurrentExpectedRoundRobinSpec(count), getRoutedReadHostSpec(conn)); + Assert.assertTrue(conn2Statement2.execute("SELECT * FROM account WHERE id = 1")); + Assert.assertEquals(getCurrentExpectedRoundRobinSpec(count), getRoutedReadHostSpec(conn2)); + Assert.assertTrue(conn1Statement3.execute("SELECT * FROM account WHERE id = 1")); + Assert.assertEquals(getCurrentExpectedRoundRobinSpec(count), getRoutedReadHostSpec(conn)); + Assert.assertTrue(conn2Statement3.execute("SELECT * FROM account WHERE id = 1")); + Assert.assertEquals(getCurrentExpectedRoundRobinSpec(count), getRoutedReadHostSpec(conn2)); + getCurrentExpectedRoundRobinSpec(count); + count.getAndIncrement(); + } + } + } + } + private static Properties getProperties() { Properties properties = new Properties(); properties.setProperty("user", TestUtil.getUser()); @@ -309,37 +460,36 @@ public class ReadWriteSplittingConnectionTest { String params = "?enableStatementLoadBalance=true&autoBalance=shuffle&writeDataSourceAddress=%s"; String urlParams = String.format(params, getMasterHostSpec()); try (Connection connection = getConnection(urlParams)) { - ReadWriteSplittingPgConnection readWriteSplittingPgConnection = - getReadWriteSplittingPgConnection(connection); + ReadWriteSplittingPgConnection readWriteConn = getReadWriteSplittingPgConnection(connection); connection.setAutoCommit(false); for (int i = 0; i < 3; i++) { try (Statement statement = connection.createStatement()) { Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); - Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteSplittingPgConnection)); + Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteConn)); } } try (Statement statement = connection.createStatement()) { statement.execute("UPDATE account SET balance = 11 WHERE id = 1"); - Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteSplittingPgConnection)); + Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteConn)); } connection.commit(); try (Statement statement = connection.createStatement()) { Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); - Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteSplittingPgConnection)); + Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteConn)); } connection.rollback(); try (Statement statement = connection.createStatement()) { Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); - Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteSplittingPgConnection)); + Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteConn)); } connection.setAutoCommit(true); try (Statement statement = connection.createStatement()) { Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); - Assert.assertTrue(isRoutedToReadHostSpecs(readWriteSplittingPgConnection)); + Assert.assertTrue(isRoutedToReadHostSpecs(readWriteConn)); } try (Statement statement = connection.createStatement()) { statement.execute("UPDATE account SET balance = 11 WHERE id = 1"); - Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteSplittingPgConnection)); + Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteConn)); } } } @@ -349,8 +499,7 @@ public class ReadWriteSplittingConnectionTest { String params = "?enableStatementLoadBalance=true&autoBalance=shuffle&writeDataSourceAddress=%s"; String urlParams = String.format(params, getMasterHostSpec()); try (Connection connection = getConnection(urlParams)) { - ReadWriteSplittingPgConnection readWriteSplittingPgConnection = - getReadWriteSplittingPgConnection(connection); + ReadWriteSplittingPgConnection readWriteConn = getReadWriteSplittingPgConnection(connection); try (Statement statement = connection.createStatement()) { for (int i = 0; i < 3; i++) { ResultSet resultSet = statement.executeQuery("SELECT * FROM account WHERE id = 1"); @@ -359,7 +508,7 @@ public class ReadWriteSplittingConnectionTest { Assert.assertTrue(resultSet.next()); Assert.assertTrue(resultSet2.next()); Assert.assertTrue(resultSet3.next()); - Assert.assertTrue(isRoutedToReadHostSpecs(readWriteSplittingPgConnection)); + Assert.assertTrue(isRoutedToReadHostSpecs(readWriteConn)); } } } @@ -369,26 +518,54 @@ public class ReadWriteSplittingConnectionTest { public void executeQueryWithoutWriteDataSourceAddressParamTest() throws Exception { String urlParams = "?enableStatementLoadBalance=true&autoBalance=shuffle"; try (Connection connection = getConnection(urlParams)) { - ReadWriteSplittingPgConnection readWriteSplittingPgConnection = - getReadWriteSplittingPgConnection(connection); + ReadWriteSplittingPgConnection readWriteConn = getReadWriteSplittingPgConnection(connection); for (int i = 0; i < 3; i++) { try (Statement statement = connection.createStatement()) { ResultSet resultSet = statement.executeQuery("SELECT * FROM account WHERE id = 1"); Assert.assertTrue(resultSet.next()); - Assert.assertTrue(isRoutedToReadHostSpecs(readWriteSplittingPgConnection)); + Assert.assertTrue(isRoutedToReadHostSpecs(readWriteConn)); } } } } - private HostSpec getNextExpectedRoundRobinSpec() { - return readHostSpecs[(currentIndex++) % readHostSpecs.length]; + @Test + public void executeWithOnlyWriteDataSource() throws Exception { + String writeHost = hostSpecs[0].getHost() + ":" + hostSpecs[0].getPort(); + String writeDatabase = TestUtil.getDatabase(); + String url = String.format("jdbc:postgresql://%s/%s?enableStatementLoadBalance=true" + + "&autoBalance=roundrobin&writeDataSourceAddress=%s", writeHost, writeDatabase, writeHost); + Properties props = getProperties(); + try (Connection conn = DriverManager.getConnection(url, props)) { + ReadWriteSplittingPgConnection readWriteConn = getReadWriteSplittingPgConnection(conn); + ReadWriteSplittingHostSpec readWriteSplittingHostSpec = readWriteConn.getReadWriteSplittingHostSpec(); + Assert.assertEquals(readWriteSplittingHostSpec.getWriteHostSpec(), getMasterHostSpec()); + Assert.assertTrue(readWriteSplittingHostSpec.getReadHostSpecs().length == 0); + try (Statement statement = conn.createStatement()) { + Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); + for (int i = 0; i < readHostSpecs.length; i++) { + Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteConn)); + } + } + for (int i = 0; i < 3; i++) { + try (Statement statement = conn.createStatement()) { + statement.execute("UPDATE account SET balance = 11 WHERE id = 1"); + Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteConn)); + } + } + } } - private static boolean isRoutedToReadHostSpecs(ReadWriteSplittingPgConnection readWriteSplittingPgConnection) - throws SQLException { - String socketAddress = - readWriteSplittingPgConnection.getConnectionManager().getCurrentConnection().getSocketAddress(); + private HostSpec getCurrentExpectedRoundRobinSpec(AtomicInteger count) { + return readHostSpecs[(count.get()) % readHostSpecs.length]; + } + + private HostSpec getNextExpectedRoundRobinSpec(AtomicInteger count) { + return readHostSpecs[(count.getAndIncrement()) % readHostSpecs.length]; + } + + private static boolean isRoutedToReadHostSpecs(ReadWriteSplittingPgConnection conn) throws SQLException { + String socketAddress = conn.getConnectionManager().getCurrentConnection().getSocketAddress(); for (HostSpec readHostSpec : getReadHostSpecs()) { if (socketAddress.endsWith(getHostOrAlias(readHostSpec) + ":" + readHostSpec.getPort())) { return true; @@ -397,22 +574,34 @@ public class ReadWriteSplittingConnectionTest { return false; } - private static HostSpec getRoutedReadHostSpec(ReadWriteSplittingPgConnection readWriteSplittingPgConnection) - throws SQLException { - String socketAddress = - readWriteSplittingPgConnection.getConnectionManager().getCurrentConnection().getSocketAddress(); - for (HostSpec readHostSpec : getReadHostSpecs()) { - if (socketAddress.endsWith(getHostOrAlias(readHostSpec) + ":" + readHostSpec.getPort())) { - return readHostSpec; + private static HostSpec getRoutedReadHostSpec(Connection connection) throws SQLException { + if (connection instanceof ReadWriteSplittingPgConnection) { + String socketAddress = ((ReadWriteSplittingPgConnection) connection).getConnectionManager() + .getCurrentConnection().getSocketAddress(); + for (HostSpec readHostSpec : getReadHostSpecs()) { + if (socketAddress.endsWith(getHostOrAlias(readHostSpec) + ":" + readHostSpec.getPort())) { + return readHostSpec; + } + } + throw new IllegalStateException("Must routed to one read host spec"); + } + if (connection instanceof PgConnection) { + String socketAddress = ((PgConnection) connection).getSocketAddress(); + for (HostSpec readHostSpec : getReadHostSpecs()) { + if (socketAddress.endsWith(getHostOrAlias(readHostSpec) + ":" + readHostSpec.getPort())) { + return readHostSpec; + } + } + if (socketAddress.endsWith(getHostOrAlias(getMasterHostSpec()) + ":" + getMasterHostSpec().getPort())) { + return getMasterHostSpec(); } + throw new IllegalStateException("Must routed to one host spec"); } - throw new IllegalStateException("Must routed to one read host spec"); + throw new IllegalStateException("Unexpected connection type"); } - private static boolean isRoutedToWriteHostSpecs(ReadWriteSplittingPgConnection readWriteSplittingPgConnection) - throws SQLException { - String socketAddress = - readWriteSplittingPgConnection.getConnectionManager().getCurrentConnection().getSocketAddress(); + private static boolean isRoutedToWriteHostSpecs(ReadWriteSplittingPgConnection conn) throws SQLException { + String socketAddress = conn.getConnectionManager().getCurrentConnection().getSocketAddress(); return socketAddress.endsWith(getHostOrAlias(writeHostSpec) + ":" + writeHostSpec.getPort()); } -- Gitee