diff --git a/pgjdbc/pom.xml b/pgjdbc/pom.xml index db745e3ce9de5be10ac06b2d6274b4dc3a10b50f..f1d4f9dc0706643f58117c584ffbae730b44f8e4 100644 --- a/pgjdbc/pom.xml +++ b/pgjdbc/pom.xml @@ -32,6 +32,7 @@ 42 false 8.5 + 5.4.0 @@ -47,6 +48,18 @@ 1.7.30 provided + + org.apache.shardingsphere + shardingsphere-parser-sql-engine + ${shardingsphere.version} + provided + + + org.apache.shardingsphere + shardingsphere-parser-sql-opengauss + ${shardingsphere.version} + provided + diff --git a/pgjdbc/src/main/java/org/postgresql/Driver.java b/pgjdbc/src/main/java/org/postgresql/Driver.java index c324cc5d18fae623c8889fe2a6fcb703379930a1..e5c4cf32e2b563a13c470a5170ad779c3da70b40 100755 --- a/pgjdbc/src/main/java/org/postgresql/Driver.java +++ b/pgjdbc/src/main/java/org/postgresql/Driver.java @@ -13,6 +13,7 @@ import org.postgresql.log.Log; import org.postgresql.log.Tracer; import org.postgresql.quickautobalance.ConnectionManager; import org.postgresql.quickautobalance.LoadBalanceHeartBeating; +import org.postgresql.readwritesplitting.ReadWriteSplittingPgConnection; import org.postgresql.util.DriverInfo; import org.postgresql.util.GT; import org.postgresql.util.HostSpec; @@ -561,6 +562,9 @@ public class Driver implements java.sql.Driver { */ private static Connection makeConnection(String url, Properties props) throws SQLException { ConnectionManager.getInstance().setCluster(props); + if (PGProperty.ENABLE_STATEMENT_LOAD_BALANCE.getBoolean(props)) { + return new ReadWriteSplittingPgConnection(hostSpecs(props), props, user(props), database(props), url); + } PgConnection pgConnection = new PgConnection(hostSpecs(props), user(props), database(props), props, url); GlobalConnectionTracker.possessConnectionReference(pgConnection.getQueryExecutor(), props); LoadBalanceHeartBeating.setConnection(pgConnection, props); diff --git a/pgjdbc/src/main/java/org/postgresql/PGProperty.java b/pgjdbc/src/main/java/org/postgresql/PGProperty.java index 6a11f0c51fca8e2e382fed7628e0aead37f378dc..22eedf91123e569a2fc69ad4997b6689c5094178 100644 --- a/pgjdbc/src/main/java/org/postgresql/PGProperty.java +++ b/pgjdbc/src/main/java/org/postgresql/PGProperty.java @@ -480,7 +480,7 @@ public enum PGProperty { ENABLE_QUICK_AUTO_BALANCE("enableQuickAutoBalance", "false", "If the connection enable quickAutoBalance, this parameter only takes effect when autoBalance=leastconn." + "value: true or false.", - false, "true", "false"), + false, "true", "false"), /** * Idle time threshold of connections when quick auto balancing filters connections. @@ -508,6 +508,20 @@ public enum PGProperty { + "jdbc will retain minReservedConPerCluster percent of the connections pre data node that meet the closing conditions during quick auto balancing." + "Value range: int && [0, 100]." + "This parameter only takes effect when autoBalance=leastconn and enableQuickAutoBalance=true"), + + /** + * Enable statement load balance. + */ + ENABLE_STATEMENT_LOAD_BALANCE("enableStatementLoadBalance", "false", + "Enable statement-level load balancing configuration, " + + "so that load balancing routing will be performed when each SQL statement is executed." + + "Optional values: true or false.", + false, "true", "false"), + + /** + * Write data source address. + */ + WRITE_DATA_SOURCE_ADDRESS("writeDataSourceAddress", "", "Specify the host and port for write database", false), /** * Supported TLS cipher suites diff --git a/pgjdbc/src/main/java/org/postgresql/jdbc/ReadWriteSplittingPgPreparedStatement.java b/pgjdbc/src/main/java/org/postgresql/jdbc/ReadWriteSplittingPgPreparedStatement.java new file mode 100644 index 0000000000000000000000000000000000000000..f4ac019491310604a812efa2a4111a9f900ad925 --- /dev/null +++ b/pgjdbc/src/main/java/org/postgresql/jdbc/ReadWriteSplittingPgPreparedStatement.java @@ -0,0 +1,624 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + */ + +package org.postgresql.jdbc; + +import org.postgresql.readwritesplitting.ReadWriteSplittingPgConnection; +import org.postgresql.readwritesplitting.SqlRouteEngine; + +import java.io.InputStream; +import java.io.Reader; +import java.math.BigDecimal; +import java.net.URL; +import java.sql.Array; +import java.sql.Blob; +import java.sql.Clob; +import java.sql.Connection; +import java.sql.Date; +import java.sql.NClob; +import java.sql.ParameterMetaData; +import java.sql.PreparedStatement; +import java.sql.Ref; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.RowId; +import java.sql.SQLException; +import java.sql.SQLWarning; +import java.sql.SQLXML; +import java.sql.Time; +import java.sql.Timestamp; +import java.util.Calendar; + +/** + * Read write splitting PG prepared statement + * + * @since 2023-11-20 + */ +public class ReadWriteSplittingPgPreparedStatement implements PreparedStatement { + private final PreparedStatement pgPreparedStatement; + + /** + * Constructor. + * + * @param readWriteSplittingPgConnection read write splitting pg connection + * @param sql SQL + * @throws SQLException SQL exception + */ + public ReadWriteSplittingPgPreparedStatement(ReadWriteSplittingPgConnection readWriteSplittingPgConnection, + String sql) throws SQLException { + Connection pgConnection = SqlRouteEngine.getRoutedConnection(readWriteSplittingPgConnection, sql); + pgPreparedStatement = pgConnection.prepareStatement(sql); + } + + /** + * Constructor. + * + * @param readWriteSplittingPgConnection read write splitting pg connection + * @param sql SQL + * @param autoGeneratedKeys auto generated keys + * @throws SQLException SQL exception + */ + public ReadWriteSplittingPgPreparedStatement(ReadWriteSplittingPgConnection readWriteSplittingPgConnection, + String sql, int autoGeneratedKeys) throws SQLException { + Connection pgConnection = SqlRouteEngine.getRoutedConnection(readWriteSplittingPgConnection, sql); + pgPreparedStatement = pgConnection.prepareStatement(sql, autoGeneratedKeys); + } + + /** + * Constructor. + * + * @param readWriteSplittingPgConnection read write splitting pg connection + * @param sql SQL + * @param resultSetType result set type + * @param resultSetConcurrency result set concurrency + * @throws SQLException SQL exception + */ + public ReadWriteSplittingPgPreparedStatement(ReadWriteSplittingPgConnection readWriteSplittingPgConnection, + String sql, int resultSetType, int resultSetConcurrency) throws SQLException { + Connection pgConnection = SqlRouteEngine.getRoutedConnection(readWriteSplittingPgConnection, sql); + pgPreparedStatement = pgConnection.prepareStatement(sql, resultSetType, resultSetConcurrency); + } + + /** + * Constructor. + * + * @param readWriteSplittingPgConnection read write splitting pg connection + * @param sql SQL + * @param columnIndexes column indexes + * @throws SQLException SQL exception + */ + public ReadWriteSplittingPgPreparedStatement(ReadWriteSplittingPgConnection readWriteSplittingPgConnection, + String sql, int[] columnIndexes) throws SQLException { + Connection pgConnection = SqlRouteEngine.getRoutedConnection(readWriteSplittingPgConnection, sql); + pgPreparedStatement = pgConnection.prepareStatement(sql, columnIndexes); + } + + /** + * Constructor. + * + * @param readWriteSplittingPgConnection read write splitting pg connection + * @param sql SQL + * @param columnNames column names + * @throws SQLException SQL exception + */ + public ReadWriteSplittingPgPreparedStatement(ReadWriteSplittingPgConnection readWriteSplittingPgConnection, + String sql, String[] columnNames) throws SQLException { + Connection pgConnection = SqlRouteEngine.getRoutedConnection(readWriteSplittingPgConnection, sql); + pgPreparedStatement = pgConnection.prepareStatement(sql, columnNames); + } + + /** + * Constructor. + * + * @param readWriteSplittingPgConnection read write splitting pg connection + * @param sql SQL + * @param resultSetType result set type + * @param resultSetConcurrency result set concurrency + * @param resultSetHoldability result set holdability + * @throws SQLException SQL exception + */ + public ReadWriteSplittingPgPreparedStatement(ReadWriteSplittingPgConnection readWriteSplittingPgConnection, + String sql, int resultSetType, int resultSetConcurrency, + int resultSetHoldability) throws SQLException { + Connection pgConnection = SqlRouteEngine.getRoutedConnection(readWriteSplittingPgConnection, sql); + pgPreparedStatement = pgConnection.prepareStatement(sql, resultSetType, resultSetConcurrency, + resultSetHoldability); + } + + @Override + public ResultSet executeQuery() throws SQLException { + return pgPreparedStatement.executeQuery(); + } + + @Override + public int executeUpdate() throws SQLException { + return pgPreparedStatement.executeUpdate(); + } + + @Override + public void setNull(int parameterIndex, int sqlType) throws SQLException { + pgPreparedStatement.setNull(parameterIndex, sqlType); + } + + @Override + public void setBoolean(int parameterIndex, boolean isTrue) throws SQLException { + pgPreparedStatement.setBoolean(parameterIndex, isTrue); + } + + @Override + public void setByte(int parameterIndex, byte x) throws SQLException { + pgPreparedStatement.setByte(parameterIndex, x); + } + + @Override + public void setShort(int parameterIndex, short x) throws SQLException { + pgPreparedStatement.setShort(parameterIndex, x); + } + + @Override + public void setInt(int parameterIndex, int x) throws SQLException { + pgPreparedStatement.setInt(parameterIndex, x); + } + + @Override + public void setLong(int parameterIndex, long x) throws SQLException { + pgPreparedStatement.setLong(parameterIndex, x); + } + + @Override + public void setFloat(int parameterIndex, float x) throws SQLException { + pgPreparedStatement.setFloat(parameterIndex, x); + } + + @Override + public void setDouble(int parameterIndex, double x) throws SQLException { + pgPreparedStatement.setDouble(parameterIndex, x); + } + + @Override + public void setBigDecimal(int parameterIndex, BigDecimal x) throws SQLException { + pgPreparedStatement.setBigDecimal(parameterIndex, x); + } + + @Override + public void setString(int parameterIndex, String x) throws SQLException { + pgPreparedStatement.setString(parameterIndex, x); + } + + @Override + public void setBytes(int parameterIndex, byte[] x) throws SQLException { + pgPreparedStatement.setBytes(parameterIndex, x); + } + + @Override + public void setDate(int parameterIndex, Date x) throws SQLException { + pgPreparedStatement.setDate(parameterIndex, x); + } + + @Override + public void setTime(int parameterIndex, Time x) throws SQLException { + pgPreparedStatement.setTime(parameterIndex, x); + } + + @Override + public void setTimestamp(int parameterIndex, Timestamp x) throws SQLException { + pgPreparedStatement.setTimestamp(parameterIndex, x); + } + + @Override + public void setAsciiStream(int parameterIndex, InputStream x, int length) throws SQLException { + pgPreparedStatement.setAsciiStream(parameterIndex, x, length); + } + + @Override + public void setUnicodeStream(int parameterIndex, InputStream x, int length) throws SQLException { + pgPreparedStatement.setUnicodeStream(parameterIndex, x, length); + } + + @Override + public void setBinaryStream(int parameterIndex, InputStream x, int length) throws SQLException { + pgPreparedStatement.setBinaryStream(parameterIndex, x, length); + } + + @Override + public void clearParameters() throws SQLException { + pgPreparedStatement.clearParameters(); + } + + @Override + public void setObject(int parameterIndex, Object x, int targetSqlType) throws SQLException { + pgPreparedStatement.setObject(parameterIndex, x, targetSqlType); + } + + @Override + public void setObject(int parameterIndex, Object x) throws SQLException { + pgPreparedStatement.setObject(parameterIndex, x); + } + + @Override + public boolean execute() throws SQLException { + return pgPreparedStatement.execute(); + } + + + @Override + public void addBatch() throws SQLException { + pgPreparedStatement.addBatch(); + } + + @Override + public void setCharacterStream(int parameterIndex, Reader reader, int length) throws SQLException { + pgPreparedStatement.setCharacterStream(parameterIndex, reader, length); + } + + @Override + public void setRef(int parameterIndex, Ref x) throws SQLException { + pgPreparedStatement.setRef(parameterIndex, x); + } + + @Override + public void setBlob(int parameterIndex, Blob x) throws SQLException { + pgPreparedStatement.setBlob(parameterIndex, x); + } + + @Override + public void setClob(int parameterIndex, Clob x) throws SQLException { + pgPreparedStatement.setClob(parameterIndex, x); + } + + @Override + public void setArray(int parameterIndex, Array x) throws SQLException { + pgPreparedStatement.setArray(parameterIndex, x); + } + + @Override + public ResultSetMetaData getMetaData() throws SQLException { + return pgPreparedStatement.getMetaData(); + } + + @Override + public void setDate(int parameterIndex, Date x, Calendar cal) throws SQLException { + pgPreparedStatement.setDate(parameterIndex, x, cal); + } + + @Override + public void setTime(int parameterIndex, Time x, Calendar cal) throws SQLException { + pgPreparedStatement.setTime(parameterIndex, x, cal); + } + + @Override + public void setTimestamp(int parameterIndex, Timestamp x, Calendar cal) throws SQLException { + pgPreparedStatement.setTimestamp(parameterIndex, x, cal); + } + + @Override + public void setNull(int parameterIndex, int sqlType, String typeName) throws SQLException { + pgPreparedStatement.setNull(parameterIndex, sqlType, typeName); + } + + @Override + public void setURL(int parameterIndex, URL x) throws SQLException { + pgPreparedStatement.setURL(parameterIndex, x); + } + + @Override + public ParameterMetaData getParameterMetaData() throws SQLException { + return pgPreparedStatement.getParameterMetaData(); + } + + @Override + public void setRowId(int parameterIndex, RowId x) throws SQLException { + pgPreparedStatement.setRowId(parameterIndex, x); + } + + @Override + public void setNString(int parameterIndex, String value) throws SQLException { + pgPreparedStatement.setNString(parameterIndex, value); + } + + @Override + public void setNCharacterStream(int parameterIndex, Reader value, long length) throws SQLException { + pgPreparedStatement.setNCharacterStream(parameterIndex, value, length); + } + + @Override + public void setNClob(int parameterIndex, NClob value) throws SQLException { + pgPreparedStatement.setNClob(parameterIndex, value); + } + + @Override + public void setClob(int parameterIndex, Reader reader, long length) throws SQLException { + pgPreparedStatement.setClob(parameterIndex, reader, length); + } + + @Override + public void setBlob(int parameterIndex, InputStream inputStream, long length) throws SQLException { + pgPreparedStatement.setBlob(parameterIndex, inputStream, length); + } + + @Override + public void setNClob(int parameterIndex, Reader reader, long length) throws SQLException { + pgPreparedStatement.setNClob(parameterIndex, reader, length); + } + + @Override + public void setSQLXML(int parameterIndex, SQLXML xmlObject) throws SQLException { + pgPreparedStatement.setSQLXML(parameterIndex, xmlObject); + } + + @Override + public void setObject(int parameterIndex, Object x, int targetSqlType, int scaleOrLength) throws SQLException { + pgPreparedStatement.setObject(parameterIndex, x, targetSqlType, scaleOrLength); + } + + @Override + public void setAsciiStream(int parameterIndex, InputStream x, long length) throws SQLException { + pgPreparedStatement.setAsciiStream(parameterIndex, x, length); + } + + @Override + public void setBinaryStream(int parameterIndex, InputStream x, long length) throws SQLException { + pgPreparedStatement.setBinaryStream(parameterIndex, x, length); + } + + @Override + public void setCharacterStream(int parameterIndex, Reader reader, long length) throws SQLException { + pgPreparedStatement.setCharacterStream(parameterIndex, reader, length); + } + + @Override + public void setAsciiStream(int parameterIndex, InputStream x) throws SQLException { + pgPreparedStatement.setAsciiStream(parameterIndex, x); + } + + @Override + public void setBinaryStream(int parameterIndex, InputStream x) throws SQLException { + pgPreparedStatement.setBinaryStream(parameterIndex, x); + } + + @Override + public void setCharacterStream(int parameterIndex, Reader reader) throws SQLException { + pgPreparedStatement.setCharacterStream(parameterIndex, reader); + } + + @Override + public void setNCharacterStream(int parameterIndex, Reader value) throws SQLException { + pgPreparedStatement.setNCharacterStream(parameterIndex, value); + } + + @Override + public void setClob(int parameterIndex, Reader reader) throws SQLException { + pgPreparedStatement.setClob(parameterIndex, reader); + } + + @Override + public void setBlob(int parameterIndex, InputStream inputStream) throws SQLException { + pgPreparedStatement.setBlob(parameterIndex, inputStream); + } + + @Override + public void setNClob(int parameterIndex, Reader reader) throws SQLException { + pgPreparedStatement.setNClob(parameterIndex, reader); + } + + @Override + public ResultSet executeQuery(String sql) throws SQLException { + return pgPreparedStatement.executeQuery(sql); + } + + @Override + public int executeUpdate(String sql) throws SQLException { + return pgPreparedStatement.executeUpdate(sql); + } + + @Override + public void close() throws SQLException { + pgPreparedStatement.close(); + } + + @Override + public int getMaxFieldSize() throws SQLException { + return pgPreparedStatement.getMaxFieldSize(); + } + + @Override + public void setMaxFieldSize(int max) throws SQLException { + pgPreparedStatement.setMaxFieldSize(max); + } + + @Override + public int getMaxRows() throws SQLException { + return pgPreparedStatement.getMaxRows(); + } + + @Override + public void setMaxRows(int max) throws SQLException { + pgPreparedStatement.setMaxRows(max); + } + + @Override + public void setEscapeProcessing(boolean isEnable) throws SQLException { + pgPreparedStatement.setEscapeProcessing(isEnable); + } + + @Override + public int getQueryTimeout() throws SQLException { + return pgPreparedStatement.getQueryTimeout(); + } + + @Override + public void setQueryTimeout(int seconds) throws SQLException { + pgPreparedStatement.setQueryTimeout(seconds); + } + + @Override + public void cancel() throws SQLException { + pgPreparedStatement.cancel(); + } + + @Override + public SQLWarning getWarnings() throws SQLException { + return pgPreparedStatement.getWarnings(); + } + + @Override + public void clearWarnings() throws SQLException { + pgPreparedStatement.clearWarnings(); + } + + @Override + public void setCursorName(String name) throws SQLException { + pgPreparedStatement.setCursorName(name); + } + + @Override + public boolean execute(String sql) throws SQLException { + return pgPreparedStatement.execute(sql); + } + + @Override + public ResultSet getResultSet() throws SQLException { + return pgPreparedStatement.getResultSet(); + } + + @Override + public int getUpdateCount() throws SQLException { + return pgPreparedStatement.getUpdateCount(); + } + + @Override + public boolean getMoreResults() throws SQLException { + return pgPreparedStatement.getMoreResults(); + } + + @Override + public int getFetchDirection() throws SQLException { + return pgPreparedStatement.getFetchDirection(); + } + + @Override + public void setFetchDirection(int direction) throws SQLException { + pgPreparedStatement.setFetchDirection(direction); + } + + @Override + public int getFetchSize() throws SQLException { + return pgPreparedStatement.getFetchSize(); + } + + @Override + public void setFetchSize(int rows) throws SQLException { + pgPreparedStatement.setFetchSize(rows); + } + + @Override + public int getResultSetConcurrency() throws SQLException { + return pgPreparedStatement.getResultSetConcurrency(); + } + + @Override + public int getResultSetType() throws SQLException { + return pgPreparedStatement.getResultSetType(); + } + + @Override + public void addBatch(String sql) throws SQLException { + pgPreparedStatement.addBatch(sql); + } + + @Override + public void clearBatch() throws SQLException { + pgPreparedStatement.clearBatch(); + } + + @Override + public int[] executeBatch() throws SQLException { + return pgPreparedStatement.executeBatch(); + } + + @Override + public Connection getConnection() throws SQLException { + return pgPreparedStatement.getConnection(); + } + + @Override + public boolean getMoreResults(int current) throws SQLException { + return pgPreparedStatement.getMoreResults(current); + } + + @Override + public ResultSet getGeneratedKeys() throws SQLException { + return pgPreparedStatement.getGeneratedKeys(); + } + + @Override + public int executeUpdate(String sql, int autoGeneratedKeys) throws SQLException { + return pgPreparedStatement.executeUpdate(sql, autoGeneratedKeys); + } + + @Override + public int executeUpdate(String sql, int[] columnIndexes) throws SQLException { + return pgPreparedStatement.executeUpdate(sql, columnIndexes); + } + + @Override + public int executeUpdate(String sql, String[] columnNames) throws SQLException { + return pgPreparedStatement.executeUpdate(sql, columnNames); + } + + @Override + public boolean execute(String sql, int autoGeneratedKeys) throws SQLException { + return pgPreparedStatement.execute(sql, autoGeneratedKeys); + } + + @Override + public boolean execute(String sql, int[] columnIndexes) throws SQLException { + return pgPreparedStatement.execute(sql, columnIndexes); + } + + @Override + public boolean execute(String sql, String[] columnNames) throws SQLException { + return pgPreparedStatement.execute(sql, columnNames); + } + + @Override + public int getResultSetHoldability() throws SQLException { + return pgPreparedStatement.getResultSetHoldability(); + } + + @Override + public boolean isClosed() throws SQLException { + return pgPreparedStatement.isClosed(); + } + + @Override + public boolean isPoolable() throws SQLException { + return pgPreparedStatement.isPoolable(); + } + + @Override + public void setPoolable(boolean isPoolable) throws SQLException { + pgPreparedStatement.setPoolable(isPoolable); + } + + @Override + public void closeOnCompletion() throws SQLException { + pgPreparedStatement.closeOnCompletion(); + } + + @Override + public boolean isCloseOnCompletion() throws SQLException { + return pgPreparedStatement.isCloseOnCompletion(); + } + + @Override + public T unwrap(Class iface) throws SQLException { + return pgPreparedStatement.unwrap(iface); + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + return pgPreparedStatement.isWrapperFor(iface); + } +} diff --git a/pgjdbc/src/main/java/org/postgresql/readwritesplitting/ForceExecuteCallback.java b/pgjdbc/src/main/java/org/postgresql/readwritesplitting/ForceExecuteCallback.java new file mode 100644 index 0000000000000000000000000000000000000000..2a42b60fd24f868429ea770ce1c41037f70ef7d7 --- /dev/null +++ b/pgjdbc/src/main/java/org/postgresql/readwritesplitting/ForceExecuteCallback.java @@ -0,0 +1,23 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + */ + +package org.postgresql.readwritesplitting; + +import java.sql.SQLException; + +/** + * Force execute callback. + * + * @since 2023-11-20 + * @param type of target to be executed + */ +public interface ForceExecuteCallback { + /** + * Execute. + * + * @param target target to be executed + * @throws SQLException SQL exception + */ + void execute(T target) throws SQLException; +} diff --git a/pgjdbc/src/main/java/org/postgresql/readwritesplitting/ForceExecuteTemplate.java b/pgjdbc/src/main/java/org/postgresql/readwritesplitting/ForceExecuteTemplate.java new file mode 100644 index 0000000000000000000000000000000000000000..c6ddc6f130c24d1ca4ef971086ad40ccf550949f --- /dev/null +++ b/pgjdbc/src/main/java/org/postgresql/readwritesplitting/ForceExecuteTemplate.java @@ -0,0 +1,45 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + */ + +package org.postgresql.readwritesplitting; + +import java.sql.SQLException; +import java.util.Collection; +import java.util.LinkedList; + +/** + * Force execute template. + * + * @since 2023-11-20 + * @param type of targets to be executed + */ +public final class ForceExecuteTemplate { + /** + * Force execute. + * + * @param targets targets to be executed + * @param callback force execute callback + * @throws SQLException throw SQL exception after all targets are executed + */ + public void execute(final Collection targets, final ForceExecuteCallback callback) throws SQLException { + Collection exceptions = new LinkedList<>(); + for (T each : targets) { + try { + callback.execute(each); + } catch (final SQLException ex) { + exceptions.add(ex); + } + } + throwSQLExceptionIfNecessary(exceptions); + } + + private void throwSQLExceptionIfNecessary(final Collection exceptions) throws SQLException { + if (exceptions.isEmpty()) { + return; + } + SQLException ex = new SQLException(""); + exceptions.forEach(ex::setNextException); + throw ex; + } +} diff --git a/pgjdbc/src/main/java/org/postgresql/readwritesplitting/MethodInvocationRecorder.java b/pgjdbc/src/main/java/org/postgresql/readwritesplitting/MethodInvocationRecorder.java new file mode 100644 index 0000000000000000000000000000000000000000..9efbc61d9ade0f40853dc9c686e4332b6c9e4705 --- /dev/null +++ b/pgjdbc/src/main/java/org/postgresql/readwritesplitting/MethodInvocationRecorder.java @@ -0,0 +1,41 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + */ + +package org.postgresql.readwritesplitting; + +import java.sql.SQLException; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Method invocation recorder. + * + * @since 2023-11-20 + * @param type of target + */ +public final class MethodInvocationRecorder { + private final Map> methodInvocations = new LinkedHashMap<>(); + + /** + * Record method invocation. + * + * @param methodName method name + * @param callback callback + */ + public void record(final String methodName, final ForceExecuteCallback callback) { + methodInvocations.put(methodName, callback); + } + + /** + * Replay methods invocation. + * + * @param target target object + * @throws SQLException SQL Exception + */ + public void replay(final T target) throws SQLException { + for (ForceExecuteCallback each : methodInvocations.values()) { + each.execute(target); + } + } +} diff --git a/pgjdbc/src/main/java/org/postgresql/readwritesplitting/PgConnectionManager.java b/pgjdbc/src/main/java/org/postgresql/readwritesplitting/PgConnectionManager.java new file mode 100644 index 0000000000000000000000000000000000000000..e500afaef997d9a60161c6bce9cda557cb95409c --- /dev/null +++ b/pgjdbc/src/main/java/org/postgresql/readwritesplitting/PgConnectionManager.java @@ -0,0 +1,202 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + */ + +package org.postgresql.readwritesplitting; + +import org.postgresql.hostchooser.HostRequirement; +import org.postgresql.jdbc.PgConnection; +import org.postgresql.util.HostSpec; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; + +/** + * PG connection manager. + * + * @since 2023-11-20 + */ +public class PgConnectionManager implements AutoCloseable { + private final MethodInvocationRecorder methodInvocationRecorder = new MethodInvocationRecorder<>(); + + private final ForceExecuteTemplate forceExecuteTemplate = new ForceExecuteTemplate<>(); + + private final Map cachedConnections = new ConcurrentHashMap<>(); + + private final AtomicReference currentConnection = new AtomicReference<>(); + + private final Properties props; + + private final String user; + + private final String database; + + private final String url; + + private final ReadWriteSplittingPgConnection readWriteSplittingPgConnection; + + /** + * Constructor. + * + * @param props props + * @param user user + * @param database database + * @param url url + * @param connection read write splitting pg connection + */ + public PgConnectionManager(Properties props, String user, String database, String url, + ReadWriteSplittingPgConnection connection) { + this.props = props; + this.user = user; + this.database = database; + this.url = url; + this.readWriteSplittingPgConnection = connection; + } + + /** + * Get connection. + * + * @param hostSpec host spec + * @return connection + * @throws SQLException SQL exception + */ + public synchronized PgConnection getConnection(HostSpec hostSpec) throws SQLException { + String cacheKey = getCacheKey(hostSpec); + PgConnection result = cachedConnections.get(cacheKey); + if (result == null) { + result = createConnection(hostSpec, cacheKey); + } + setCurrentConnection(result); + return result; + } + + private PgConnection createConnection(HostSpec hostSpec, String cacheKey) throws SQLException { + PgConnection result = new PgConnection(new HostSpec[]{hostSpec}, user, database, props, url); + methodInvocationRecorder.replay(result); + cachedConnections.put(cacheKey, result); + return result; + } + + private void setCurrentConnection(PgConnection result) { + currentConnection.set(result); + } + + /** + * Get current connection. + * + * @return current connection + * @throws SQLException SQL exception + */ + public PgConnection getCurrentConnection() throws SQLException { + PgConnection result = currentConnection.get(); + return result == null ? getConnection(selectCurrentHostSpec()) : result; + } + + private HostSpec selectCurrentHostSpec() { + ReadWriteSplittingHostSpec readWriteHostSpec = readWriteSplittingPgConnection.getReadWriteSplittingHostSpec(); + if (HostRequirement.master == readWriteHostSpec.getTargetServerType()) { + return readWriteHostSpec.getWriteHostSpec(); + } + if (HostRequirement.secondary == readWriteHostSpec.getTargetServerType()) { + return readWriteHostSpec.readLoadBalance(); + } + return readWriteHostSpec.getWriteHostSpec(); + } + + private String getCacheKey(HostSpec hostSpec) { + return hostSpec.getHost() + ":" + hostSpec.getPort(); + } + + @Override + public void close() throws SQLException { + try { + forceExecuteTemplate.execute(cachedConnections.values(), PgConnection::close); + } finally { + cachedConnections.clear(); + } + } + + /** + * Set auto commit. + * + * @param isAutoCommit auto commit + * @throws SQLException SQL exception + */ + public void setAutoCommit(final boolean isAutoCommit) throws SQLException { + methodInvocationRecorder.record("setAutoCommit", target -> target.setAutoCommit(isAutoCommit)); + forceExecuteTemplate.execute(cachedConnections.values(), connection -> connection.setAutoCommit(isAutoCommit)); + } + + /** + * Set transaction isolation. + * + * @param level transaction isolation level + * @throws SQLException SQL exception + */ + public void setTransactionIsolation(int level) throws SQLException { + methodInvocationRecorder.record("setTransactionIsolation", + connection -> connection.setTransactionIsolation(level)); + forceExecuteTemplate.execute(cachedConnections.values(), + connection -> connection.setTransactionIsolation(level)); + } + + /** + * Set schema. + * + * @param schema schema + * @throws SQLException SQL exception + */ + public void setSchema(String schema) throws SQLException { + methodInvocationRecorder.record("setSchema", connection -> connection.setSchema(schema)); + forceExecuteTemplate.execute(cachedConnections.values(), connection -> connection.setSchema(schema)); + } + + /** + * Commit. + * + * @throws SQLException SQL exception + */ + public void commit() throws SQLException { + forceExecuteTemplate.execute(cachedConnections.values(), Connection::commit); + } + + /** + * Rollback. + * + * @throws SQLException SQL exception + */ + public void rollback() throws SQLException { + forceExecuteTemplate.execute(cachedConnections.values(), Connection::rollback); + } + + /** + * Set read only. + * + * @param isReadOnly read only + * @throws SQLException SQL exception + */ + public void setReadOnly(final boolean isReadOnly) throws SQLException { + methodInvocationRecorder.record("setReadOnly", connection -> connection.setReadOnly(isReadOnly)); + forceExecuteTemplate.execute(cachedConnections.values(), connection -> connection.setReadOnly(isReadOnly)); + } + + /** + * Whether connection valid. + * + * @param timeout timeout + * @return connection valid or not + * @throws SQLException SQL exception + */ + public boolean isValid(final int timeout) throws SQLException { + for (Connection each : cachedConnections.values()) { + if (!each.isValid(timeout)) { + return false; + } + } + return true; + } +} diff --git a/pgjdbc/src/main/java/org/postgresql/readwritesplitting/ReadWriteSplittingHostSpec.java b/pgjdbc/src/main/java/org/postgresql/readwritesplitting/ReadWriteSplittingHostSpec.java new file mode 100644 index 0000000000000000000000000000000000000000..e851346e759c415d15934792bd5fadb1f12a10e0 --- /dev/null +++ b/pgjdbc/src/main/java/org/postgresql/readwritesplitting/ReadWriteSplittingHostSpec.java @@ -0,0 +1,90 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + */ + +package org.postgresql.readwritesplitting; + +import org.postgresql.hostchooser.HostChooser; +import org.postgresql.hostchooser.HostChooserFactory; +import org.postgresql.hostchooser.HostRequirement; +import org.postgresql.util.HostSpec; + +import java.util.Properties; + +/** + * Read write splitting host spec. + * + * @since 2023-11-20 + */ +public class ReadWriteSplittingHostSpec { + private final HostSpec writeHostSpec; + + private final HostSpec[] readHostSpecs; + + private final HostRequirement targetServerType; + + private final HostChooser readChooser; + + /** + * Constructor. + * + * @param writeHostSpec write host spec + * @param hostSpecs host specs + * @param targetServerType target server type + * @param props props + */ + public ReadWriteSplittingHostSpec(HostSpec writeHostSpec, HostSpec[] hostSpecs, HostRequirement targetServerType, + Properties props) { + this.writeHostSpec = writeHostSpec; + this.readHostSpecs = createReadHostSpecs(hostSpecs, writeHostSpec); + this.targetServerType = targetServerType; + readChooser = HostChooserFactory.createHostChooser(readHostSpecs, targetServerType, props); + } + + private HostSpec[] createReadHostSpecs(HostSpec[] hostSpecs, HostSpec writeHostSpec) { + int index = 0; + HostSpec[] result = new HostSpec[hostSpecs.length - 1]; + for (HostSpec each : hostSpecs) { + if (!each.equals(writeHostSpec)) { + result[index++] = each; + } + } + return result; + } + + /** + * Get write host spec. + * + * @return write host spec + */ + public HostSpec getWriteHostSpec() { + return writeHostSpec; + } + + /** + * Get read host specs. + * + * @return read host specs + */ + public HostSpec[] getReadHostSpecs() { + return readHostSpecs; + } + + /** + * Get target server type. + * + * @return target server type + */ + public HostRequirement getTargetServerType() { + return targetServerType; + } + + /** + * Read load balance. + * + * @return routed host spec + */ + public HostSpec readLoadBalance() { + return readChooser.iterator().next().hostSpec; + } +} diff --git a/pgjdbc/src/main/java/org/postgresql/readwritesplitting/ReadWriteSplittingPgConnection.java b/pgjdbc/src/main/java/org/postgresql/readwritesplitting/ReadWriteSplittingPgConnection.java new file mode 100644 index 0000000000000000000000000000000000000000..af6327d80ad12fa605c46643845e389ec51753ac --- /dev/null +++ b/pgjdbc/src/main/java/org/postgresql/readwritesplitting/ReadWriteSplittingPgConnection.java @@ -0,0 +1,449 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + */ + +package org.postgresql.readwritesplitting; + +import org.postgresql.PGProperty; +import org.postgresql.core.v3.ConnectionFactoryImpl; +import org.postgresql.hostchooser.HostRequirement; +import org.postgresql.jdbc.PgConnection; +import org.postgresql.jdbc.ReadWriteSplittingPgPreparedStatement; +import org.postgresql.log.Log; +import org.postgresql.log.Logger; +import org.postgresql.util.GT; +import org.postgresql.util.HostSpec; +import org.postgresql.util.PSQLException; +import org.postgresql.util.PSQLState; + +import java.io.IOException; +import java.sql.Array; +import java.sql.Blob; +import java.sql.CallableStatement; +import java.sql.Clob; +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.NClob; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLClientInfoException; +import java.sql.SQLException; +import java.sql.SQLWarning; +import java.sql.SQLXML; +import java.sql.Savepoint; +import java.sql.Statement; +import java.sql.Struct; +import java.util.Collections; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.Executor; + +/** + * Read write splitting pg connection. + * + * @since 2023-11-20 + */ +public class ReadWriteSplittingPgConnection implements Connection { + private final ReadWriteSplittingHostSpec readWriteSplittingHostSpec; + + private final PgConnectionManager connectionManager; + + private final Log LOGGER = Logger.getLogger(ReadWriteSplittingPgConnection.class.getName()); + + private volatile boolean isClosed; + + private boolean isAutoCommit = true; + + /** + * Constructor. + * + * @param hostSpecs host specs + * @param props props + * @param user user + * @param database database + * @param url url + * @throws SQLException SQL exception + */ + public ReadWriteSplittingPgConnection(HostSpec[] hostSpecs, Properties props, String user, String database, + String url) throws SQLException { + checkRequiredDependencies(); + connectionManager = new PgConnectionManager(props, user, database, url, this); + readWriteSplittingHostSpec = new ReadWriteSplittingHostSpec(getWriteDataSourceAddress(props, hostSpecs), + hostSpecs, getTargetServerTypeParam(props), props); + } + + private static void checkRequiredDependencies() throws PSQLException { + if (!isClassPresent("org.apache.shardingsphere.sql.parser.api.SQLParserEngine")) { + throw new PSQLException("When enableStatementLoadBalance=true, the dependency " + + "shardingsphere-parser-sql-engine does not exist and this function cannot be used.", + PSQLState.UNEXPECTED_ERROR); + } + if (!isClassPresent("org.apache.shardingsphere.sql.parser.opengauss.parser.OpenGaussParserFacade")) { + throw new PSQLException("When enableStatementLoadBalance=true, the dependency " + + "shardingsphere-parser-sql-opengauss does not exist and this function cannot be used.", + PSQLState.UNEXPECTED_ERROR); + } + } + + /** + * Get target server type param. + * + * @param className Class name + * @return Whether class is present + */ + public static boolean isClassPresent(String className) { + try { + Class.forName(className); + return true; + } catch (ClassNotFoundException ignored) { + // Class or one of its dependencies is not present + return false; + } + } + + private HostSpec getWriteDataSourceAddress(Properties props, HostSpec[] hostSpecs) throws SQLException { + String writeDataSourceAddress = PGProperty.WRITE_DATA_SOURCE_ADDRESS.get(props); + if (writeDataSourceAddress.trim().isEmpty()) { + return getWriteAddressByEstablishingConnections(hostSpecs); + } + String[] hostSpec = writeDataSourceAddress.split(":"); + return new HostSpec(hostSpec[0], Integer.parseInt(hostSpec[1])); + } + + private HostSpec getWriteAddressByEstablishingConnections(HostSpec[] hostSpecs) throws SQLException { + for (HostSpec each : hostSpecs) { + PgConnection connection = getConnectionManager().getConnection(each); + ConnectionFactoryImpl connectionFactory = new ConnectionFactoryImpl(); + try { + if (connectionFactory.isMaster(connection.getQueryExecutor())) { + return each; + } + } catch (IOException ex) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Error obtaining node role " + ex.getMessage()); + LOGGER.debug(ex.getStackTrace()); + } + } + } + throw new PSQLException(GT.tr("No write address found"), PSQLState.CONNECTION_UNABLE_TO_CONNECT); + } + + private HostRequirement getTargetServerTypeParam(Properties info) throws PSQLException { + HostRequirement targetServerType; + String targetServerTypeStr = PGProperty.TARGET_SERVER_TYPE.get(info); + try { + targetServerType = HostRequirement.getTargetServerType(targetServerTypeStr); + } catch (IllegalArgumentException ex) { + throw new PSQLException( + GT.tr("Invalid targetServerType value: {0}", targetServerTypeStr), + PSQLState.CONNECTION_UNABLE_TO_CONNECT); + } + return targetServerType; + } + + /** + * Get read write splitting host spec. + * + * @return read write splitting host spec + */ + public ReadWriteSplittingHostSpec getReadWriteSplittingHostSpec() { + return readWriteSplittingHostSpec; + } + + /** + * Get connection manager. + * + * @return the connectionManager + */ + public PgConnectionManager getConnectionManager() { + return connectionManager; + } + + @Override + public Statement createStatement() throws SQLException { + return createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); + } + + @Override + public Statement createStatement(int resultSetType, int resultSetConcurrency) throws SQLException { + return createStatement(resultSetType, resultSetConcurrency, getHoldability()); + } + + @Override + public Statement createStatement(int resultSetType, int resultSetConcurrency, int resultSetHoldability) + throws SQLException { + return new ReadWriteSplittingPgStatement(this, resultSetType, + resultSetConcurrency, resultSetHoldability); + } + + @Override + public PreparedStatement prepareStatement(String sql) throws SQLException { + return new ReadWriteSplittingPgPreparedStatement(this, sql); + } + + @Override + public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException { + return new ReadWriteSplittingPgPreparedStatement(this, sql, autoGeneratedKeys); + } + + @Override + public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency) + throws SQLException { + return new ReadWriteSplittingPgPreparedStatement(this, sql, resultSetType, + resultSetConcurrency); + } + + @Override + public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException { + return new ReadWriteSplittingPgPreparedStatement(this, sql, columnIndexes); + } + + @Override + public PreparedStatement prepareStatement(String sql, String[] columnNames) throws SQLException { + return new ReadWriteSplittingPgPreparedStatement(this, sql, columnNames); + } + + @Override + public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency, + int resultSetHoldability) throws SQLException { + return new ReadWriteSplittingPgPreparedStatement(this, sql, resultSetType, + resultSetConcurrency, resultSetHoldability); + } + + @Override + public CallableStatement prepareCall(String sql) throws SQLException { + return connectionManager.getCurrentConnection().prepareCall(sql); + } + + @Override + public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency) throws SQLException { + return connectionManager.getCurrentConnection().prepareCall(sql, resultSetType, resultSetConcurrency); + } + + @Override + public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency, + int resultSetHoldability) throws SQLException { + return connectionManager.getCurrentConnection().prepareCall(sql, resultSetType, resultSetConcurrency, + resultSetHoldability); + } + + @Override + public void setAutoCommit(boolean isAutoCommit) throws SQLException { + this.isAutoCommit = isAutoCommit; + connectionManager.setAutoCommit(isAutoCommit); + } + + @Override + public boolean getAutoCommit() throws SQLException { + return isAutoCommit; + } + + @Override + public void commit() throws SQLException { + connectionManager.commit(); + } + + @Override + public void rollback() throws SQLException { + connectionManager.rollback(); + } + + @Override + public Savepoint setSavepoint() throws SQLException { + return connectionManager.getCurrentConnection().setSavepoint(); + } + + @Override + public Savepoint setSavepoint(String name) throws SQLException { + return connectionManager.getCurrentConnection().setSavepoint(name); + } + + @Override + public void rollback(Savepoint savepoint) throws SQLException { + connectionManager.getCurrentConnection().rollback(savepoint); + } + + @Override + public void releaseSavepoint(Savepoint savepoint) throws SQLException { + connectionManager.getCurrentConnection().releaseSavepoint(savepoint); + } + + @Override + public void close() throws SQLException { + isClosed = true; + connectionManager.close(); + } + + @Override + public boolean isClosed() throws SQLException { + return isClosed; + } + + @Override + public boolean isValid(int timeout) throws SQLException { + return connectionManager.isValid(timeout); + } + + @Override + public void setSchema(String schema) throws SQLException { + connectionManager.setSchema(schema); + } + + @Override + public void setReadOnly(boolean isReadOnly) throws SQLException { + connectionManager.setReadOnly(isReadOnly); + } + + @Override + public boolean isReadOnly() throws SQLException { + return connectionManager.getCurrentConnection().isReadOnly(); + } + + @Override + public void setTransactionIsolation(int level) throws SQLException { + connectionManager.setTransactionIsolation(level); + } + + @Override + public int getTransactionIsolation() throws SQLException { + return connectionManager.getCurrentConnection().getTransactionIsolation(); + } + + @Override + public String nativeSQL(String sql) throws SQLException { + return connectionManager.getCurrentConnection().nativeSQL(sql); + } + + @Override + public DatabaseMetaData getMetaData() throws SQLException { + return connectionManager.getCurrentConnection().getMetaData(); + } + + @Override + public void setCatalog(String catalog) throws SQLException { + connectionManager.getCurrentConnection().setCatalog(catalog); + } + + @Override + public String getCatalog() throws SQLException { + return connectionManager.getCurrentConnection().getCatalog(); + } + + @Override + public SQLWarning getWarnings() throws SQLException { + return connectionManager.getCurrentConnection().getWarnings(); + } + + @Override + public void clearWarnings() throws SQLException { + connectionManager.getCurrentConnection().clearWarnings(); + } + + @Override + public Map> getTypeMap() throws SQLException { + return connectionManager.getCurrentConnection().getTypeMap(); + } + + @Override + public void setTypeMap(Map> map) throws SQLException { + connectionManager.getCurrentConnection().setTypeMap(map); + } + + @Override + public void setHoldability(int holdability) throws SQLException { + connectionManager.getCurrentConnection().setHoldability(holdability); + } + + @Override + public int getHoldability() throws SQLException { + return connectionManager.getCurrentConnection().getHoldability(); + } + + @Override + public Clob createClob() throws SQLException { + return connectionManager.getCurrentConnection().createClob(); + } + + @Override + public Blob createBlob() throws SQLException { + return connectionManager.getCurrentConnection().createBlob(); + } + + @Override + public NClob createNClob() throws SQLException { + return connectionManager.getCurrentConnection().createNClob(); + } + + @Override + public SQLXML createSQLXML() throws SQLException { + return connectionManager.getCurrentConnection().createSQLXML(); + } + + @Override + public void setClientInfo(String name, String value) throws SQLClientInfoException { + try { + connectionManager.getCurrentConnection().setClientInfo(name, value); + } catch (SQLException e) { + throw new SQLClientInfoException(Collections.emptyMap(), e); + } + } + + @Override + public void setClientInfo(Properties properties) throws SQLClientInfoException { + try { + connectionManager.getCurrentConnection().setClientInfo(properties); + } catch (SQLException e) { + throw new SQLClientInfoException(Collections.emptyMap(), e); + } + } + + @Override + public String getClientInfo(String name) throws SQLException { + return connectionManager.getCurrentConnection().getClientInfo(name); + } + + @Override + public Properties getClientInfo() throws SQLException { + return connectionManager.getCurrentConnection().getClientInfo(); + } + + @Override + public Array createArrayOf(String typeName, Object[] elements) throws SQLException { + return connectionManager.getCurrentConnection().createArrayOf(typeName, elements); + } + + @Override + public Struct createStruct(String typeName, Object[] attributes) throws SQLException { + return connectionManager.getCurrentConnection().createStruct(typeName, attributes); + } + + @Override + public String getSchema() throws SQLException { + return connectionManager.getCurrentConnection().getSchema(); + } + + @Override + public void abort(Executor executor) throws SQLException { + connectionManager.getCurrentConnection().abort(executor); + } + + @Override + public void setNetworkTimeout(Executor executor, int milliseconds) throws SQLException { + connectionManager.getCurrentConnection().setNetworkTimeout(executor, milliseconds); + } + + @Override + public int getNetworkTimeout() throws SQLException { + return connectionManager.getCurrentConnection().getNetworkTimeout(); + } + + @Override + public T unwrap(Class iface) throws SQLException { + return connectionManager.getCurrentConnection().unwrap(iface); + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + return connectionManager.getCurrentConnection().isWrapperFor(iface); + } +} diff --git a/pgjdbc/src/main/java/org/postgresql/readwritesplitting/ReadWriteSplittingPgStatement.java b/pgjdbc/src/main/java/org/postgresql/readwritesplitting/ReadWriteSplittingPgStatement.java new file mode 100644 index 0000000000000000000000000000000000000000..4d5f7a6fd441d1320873504d76d76216d2f48b4c --- /dev/null +++ b/pgjdbc/src/main/java/org/postgresql/readwritesplitting/ReadWriteSplittingPgStatement.java @@ -0,0 +1,321 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + */ + +package org.postgresql.readwritesplitting; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.SQLWarning; +import java.sql.Statement; +import java.util.Collection; +import java.util.LinkedList; +import java.util.List; + +/** + * Read write splitting statement. + * + * @since 2023-11-20 + */ +public class ReadWriteSplittingPgStatement implements Statement { + private final List statements = new LinkedList<>(); + + private final ForceExecuteTemplate forceExecuteTemplate = new ForceExecuteTemplate<>(); + + private final ReadWriteSplittingPgConnection readWriteSplittingPgConnection; + + private final Integer resultSetType; + + private final Integer resultSetConcurrency; + + private final Integer resultSetHoldability; + + private Statement currentStatement; + + private ResultSet currentResultSet; + + private boolean isClosed; + + /** + * Constructor. + * + * @param readWriteSplittingPgConnection read write splitting connection + * @param resultSetType result set type + * @param resultSetConcurrency result set concurrency + * @param resultSetHoldability result set holdability + */ + public ReadWriteSplittingPgStatement(ReadWriteSplittingPgConnection readWriteSplittingPgConnection, + int resultSetType, int resultSetConcurrency, int resultSetHoldability) { + this.readWriteSplittingPgConnection = readWriteSplittingPgConnection; + this.resultSetType = resultSetType; + this.resultSetConcurrency = resultSetConcurrency; + this.resultSetHoldability = resultSetHoldability; + } + + @Override + public ResultSet executeQuery(String sql) throws SQLException { + Statement pgStatement = createPgStatement(sql); + ResultSet result = pgStatement.executeQuery(sql); + currentResultSet = result; + return result; + } + + private Statement createPgStatement(String sql) throws SQLException { + Connection connection = SqlRouteEngine.getRoutedConnection(readWriteSplittingPgConnection, sql); + Statement statement = connection.createStatement(resultSetType, resultSetConcurrency, resultSetHoldability); + statements.add(statement); + currentStatement = statement; + return statement; + } + + /** + * Get current result set. + * + * @return current result set + * @throws SQLException SQL exception + */ + public Statement getCurrentStatement() throws SQLException { + if (currentStatement == null) { + Statement statement = + readWriteSplittingPgConnection.getConnectionManager().getCurrentConnection().createStatement(); + statements.add(statement); + currentStatement = statement; + return statement; + } else { + return currentStatement; + } + } + + @Override + public boolean execute(String sql) throws SQLException { + Statement pgStatement = createPgStatement(sql); + return pgStatement.execute(sql); + } + + @Override + public boolean execute(String sql, int autoGeneratedKeys) throws SQLException { + Statement pgStatement = createPgStatement(sql); + return pgStatement.execute(sql, autoGeneratedKeys); + } + + @Override + public boolean execute(String sql, int[] columnIndexes) throws SQLException { + Statement pgStatement = createPgStatement(sql); + return pgStatement.execute(sql, columnIndexes); + } + + @Override + public boolean execute(String sql, String[] columnNames) throws SQLException { + Statement pgStatement = createPgStatement(sql); + return pgStatement.execute(sql, columnNames); + } + + @Override + public int executeUpdate(String sql) throws SQLException { + Statement pgStatement = createPgStatement(sql); + return pgStatement.executeUpdate(sql); + } + + @Override + public int executeUpdate(String sql, int autoGeneratedKeys) throws SQLException { + Statement pgStatement = createPgStatement(sql); + return pgStatement.executeUpdate(sql, autoGeneratedKeys); + } + + @Override + public int executeUpdate(String sql, int[] columnIndexes) throws SQLException { + Statement pgStatement = createPgStatement(sql); + return pgStatement.executeUpdate(sql, columnIndexes); + } + + @Override + public int executeUpdate(String sql, String[] columnNames) throws SQLException { + Statement pgStatement = createPgStatement(sql); + return pgStatement.executeUpdate(sql, columnNames); + } + + public Collection getRoutedStatements() { + return statements; + } + + @Override + public void close() throws SQLException { + isClosed = true; + try { + forceExecuteTemplate.execute(getRoutedStatements(), Statement::close); + } finally { + getRoutedStatements().clear(); + } + } + + @Override + public boolean isClosed() throws SQLException { + return isClosed; + } + + @Override + public ResultSet getResultSet() throws SQLException { + return currentResultSet; + } + + @Override + public int[] executeBatch() throws SQLException { + return getCurrentStatement().executeBatch(); + } + + @Override + public int getMaxFieldSize() throws SQLException { + return getCurrentStatement().getMaxFieldSize(); + } + + @Override + public void setMaxFieldSize(int max) throws SQLException { + getCurrentStatement().setMaxFieldSize(max); + } + + @Override + public int getMaxRows() throws SQLException { + return getCurrentStatement().getMaxRows(); + } + + @Override + public void setMaxRows(int max) throws SQLException { + getCurrentStatement().setMaxRows(max); + } + + @Override + public void setEscapeProcessing(boolean isEnabled) throws SQLException { + getCurrentStatement().setEscapeProcessing(isEnabled); + } + + @Override + public int getQueryTimeout() throws SQLException { + return getCurrentStatement().getQueryTimeout(); + } + + @Override + public void setQueryTimeout(int seconds) throws SQLException { + getCurrentStatement().setQueryTimeout(seconds); + } + + @Override + public void cancel() throws SQLException { + forceExecuteTemplate.execute(getRoutedStatements(), Statement::cancel); + } + + @Override + public SQLWarning getWarnings() throws SQLException { + return getCurrentStatement().getWarnings(); + } + + @Override + public void clearWarnings() throws SQLException { + getCurrentStatement().clearWarnings(); + } + + @Override + public void setCursorName(String name) throws SQLException { + getCurrentStatement().setCursorName(name); + } + + @Override + public int getUpdateCount() throws SQLException { + return getCurrentStatement().getUpdateCount(); + } + + @Override + public boolean getMoreResults() throws SQLException { + return getCurrentStatement().getMoreResults(); + } + + @Override + public void setFetchDirection(int direction) throws SQLException { + getCurrentStatement().setFetchDirection(direction); + } + + @Override + public int getFetchDirection() throws SQLException { + return getCurrentStatement().getFetchDirection(); + } + + @Override + public void setFetchSize(int rows) throws SQLException { + getCurrentStatement().setFetchSize(rows); + } + + @Override + public int getFetchSize() throws SQLException { + return getCurrentStatement().getFetchSize(); + } + + @Override + public int getResultSetConcurrency() throws SQLException { + return getCurrentStatement().getResultSetConcurrency(); + } + + @Override + public int getResultSetType() throws SQLException { + return getCurrentStatement().getResultSetType(); + } + + @Override + public void addBatch(String sql) throws SQLException { + getCurrentStatement().addBatch(sql); + } + + @Override + public void clearBatch() throws SQLException { + getCurrentStatement().clearBatch(); + } + + @Override + public Connection getConnection() throws SQLException { + return readWriteSplittingPgConnection; + } + + @Override + public boolean getMoreResults(int current) throws SQLException { + return getCurrentStatement().getMoreResults(); + } + + @Override + public ResultSet getGeneratedKeys() throws SQLException { + return getCurrentStatement().getGeneratedKeys(); + } + + @Override + public int getResultSetHoldability() throws SQLException { + return getCurrentStatement().getResultSetHoldability(); + } + + @Override + public void setPoolable(boolean isPoolable) throws SQLException { + getCurrentStatement().setPoolable(isPoolable); + } + + @Override + public boolean isPoolable() throws SQLException { + return getCurrentStatement().isPoolable(); + } + + @Override + public void closeOnCompletion() throws SQLException { + getCurrentStatement().closeOnCompletion(); + } + + @Override + public boolean isCloseOnCompletion() throws SQLException { + return getCurrentStatement().isCloseOnCompletion(); + } + + @Override + public T unwrap(Class iface) throws SQLException { + return getCurrentStatement().unwrap(iface); + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + return getCurrentStatement().isWrapperFor(iface); + } +} diff --git a/pgjdbc/src/main/java/org/postgresql/readwritesplitting/SqlRouteEngine.java b/pgjdbc/src/main/java/org/postgresql/readwritesplitting/SqlRouteEngine.java new file mode 100644 index 0000000000000000000000000000000000000000..2f3bba02f455c8ae42bd815b7c3f03fe52e8110b --- /dev/null +++ b/pgjdbc/src/main/java/org/postgresql/readwritesplitting/SqlRouteEngine.java @@ -0,0 +1,91 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + */ + +package org.postgresql.readwritesplitting; + +import org.apache.shardingsphere.sql.parser.api.CacheOption; +import org.apache.shardingsphere.sql.parser.api.SQLParserEngine; +import org.apache.shardingsphere.sql.parser.api.SQLStatementVisitorEngine; +import org.apache.shardingsphere.sql.parser.core.ParseASTNode; +import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement; +import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement; +import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.SelectStatementHandler; +import org.postgresql.hostchooser.HostRequirement; +import org.postgresql.log.Log; +import org.postgresql.log.Logger; +import org.postgresql.util.HostSpec; + +import java.sql.Connection; +import java.sql.SQLException; + +/** + * SQL route engine. + * + * @since 2023-11-20 + */ +public class SqlRouteEngine { + private static final String DATABASE_TYPE = "openGauss"; + + private static final SQLParserEngine PARSE_ENGINE = new SQLParserEngine(DATABASE_TYPE, new CacheOption(128, 1024L)); + + private static Log LOGGER = Logger.getLogger(SqlRouteEngine.class.getName()); + + /** + * Route SQL. + * + * @param readWriteSplittingPgConnection read write splitting PG Connection + * @param sql SQL + * @return routed connection + * @throws SQLException SQL exception + */ + public static Connection getRoutedConnection(ReadWriteSplittingPgConnection readWriteSplittingPgConnection, + String sql) throws SQLException { + HostSpec hostSpec = SqlRouteEngine.route(sql, readWriteSplittingPgConnection); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Routed connection host spec: " + hostSpec); + } + return readWriteSplittingPgConnection.getConnectionManager().getConnection(hostSpec); + } + + /** + * Route SQL. + * + * @param sql SQL + * @param readWriteSplittingPgConnection read write splitting PG Connection + * @return host spec + * @throws SQLException sql exception + */ + public static HostSpec route(String sql, ReadWriteSplittingPgConnection readWriteSplittingPgConnection) + throws SQLException { + ReadWriteSplittingHostSpec hostSpec = readWriteSplittingPgConnection.getReadWriteSplittingHostSpec(); + if (!readWriteSplittingPgConnection.getAutoCommit()) { + return hostSpec.getWriteHostSpec(); + } + try { + if (HostRequirement.master == hostSpec.getTargetServerType()) { + return hostSpec.getWriteHostSpec(); + } + if (HostRequirement.secondary == hostSpec.getTargetServerType()) { + return hostSpec.readLoadBalance(); + } + ParseASTNode parseASTNode = PARSE_ENGINE.parse(sql, true); + SQLStatement sqlStatement = new SQLStatementVisitorEngine(DATABASE_TYPE, false).visit(parseASTNode); + if (isWriteRouteStatement(sqlStatement)) { + return hostSpec.getWriteHostSpec(); + } + } catch (final Exception ignored) { + return hostSpec.getWriteHostSpec(); + } + return hostSpec.readLoadBalance(); + } + + private static boolean isWriteRouteStatement(final SQLStatement sqlStatement) { + return containsLockSegment(sqlStatement) || !(sqlStatement instanceof SelectStatement); + } + + private static boolean containsLockSegment(final SQLStatement sqlStatement) { + return sqlStatement instanceof SelectStatement + && SelectStatementHandler.getLockSegment((SelectStatement) sqlStatement).isPresent(); + } +} diff --git a/pgjdbc/src/test/java/org/postgresql/test/readwritesplitting/ReadWriteSplittingConnectionTest.java b/pgjdbc/src/test/java/org/postgresql/test/readwritesplitting/ReadWriteSplittingConnectionTest.java new file mode 100644 index 0000000000000000000000000000000000000000..0cfab690d3409563f50661d965606ba8b8c535c7 --- /dev/null +++ b/pgjdbc/src/test/java/org/postgresql/test/readwritesplitting/ReadWriteSplittingConnectionTest.java @@ -0,0 +1,423 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + */ + +package org.postgresql.test.readwritesplitting; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.postgresql.readwritesplitting.ReadWriteSplittingHostSpec; +import org.postgresql.readwritesplitting.ReadWriteSplittingPgConnection; +import org.postgresql.test.TestUtil; +import org.postgresql.util.HostSpec; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Properties; + +import static org.hamcrest.CoreMatchers.instanceOf; + +/** + * Read write splitting connection test. + * + * @since 2023-11-20 + */ +public class ReadWriteSplittingConnectionTest { + private static final int DN_NUM = 3; + + private static final String ACCOUNT_TABLE = "account"; + + private static HostSpec[] hostSpecs; + + private static HostSpec writeHostSpec; + + private static HostSpec[] readHostSpecs; + + private int currentIndex; + + private static HostSpec[] initHostSpecs() { + HostSpec[] result = new HostSpec[DN_NUM]; + result[0] = getMasterHostSpec(); + result[1] = new HostSpec(TestUtil.getSecondaryServer(), TestUtil.getSecondaryPort()); + result[2] = new HostSpec(TestUtil.getSecondaryServer2(), TestUtil.getSecondaryServerPort2()); + return result; + } + + private static HostSpec[] initReadSpecs() { + HostSpec[] result = new HostSpec[DN_NUM - 1]; + result[0] = new HostSpec(TestUtil.getSecondaryServer(), TestUtil.getSecondaryPort()); + result[1] = new HostSpec(TestUtil.getSecondaryServer2(), TestUtil.getSecondaryServerPort2()); + return result; + } + + private static HostSpec[] getReadHostSpecs() { + return readHostSpecs; + } + + @BeforeClass + public static void setUp() throws Exception { + hostSpecs = initHostSpecs(); + readHostSpecs = initReadSpecs(); + writeHostSpec = hostSpecs[0]; + try (Connection connection = TestUtil.openDB()) { + TestUtil.createTable(connection, ACCOUNT_TABLE, "id int, balance float, transaction_id int"); + TestUtil.execute("insert into account values(1, 1, 1)", connection); + } + } + + @AfterClass + public static void tearDown() throws Exception { + try (Connection connection = TestUtil.openDB()) { + TestUtil.dropTable(connection, ACCOUNT_TABLE); + } + } + + private static HostSpec getMasterHostSpec() { + return new HostSpec(TestUtil.getServer(), TestUtil.getPort()); + } + + private String initURL(HostSpec[] hostSpecs) { + String host1 = hostSpecs[0].getHost() + ":" + hostSpecs[0].getPort(); + String host2 = hostSpecs[1].getHost() + ":" + hostSpecs[1].getPort(); + String host3 = hostSpecs[2].getHost() + ":" + hostSpecs[2].getPort(); + return "jdbc:postgresql://" + host1 + "," + host2 + "," + host3 + "/" + TestUtil.getDatabase(); + } + + private Connection getConnection(String urlParams) throws SQLException { + String url = initURL(hostSpecs) + urlParams; + Properties props = getProperties(); + return DriverManager.getConnection(url, props); + } + + @Test + public void roundRobinLoadBalanceTest() throws SQLException { + String urlParams = String.format("?enableStatementLoadBalance=true&autoBalance=roundrobin" + + "&writeDataSourceAddress=%s", getMasterHostSpec()); + try (Connection connection = getConnection(urlParams)) { + ReadWriteSplittingPgConnection readWriteSplittingPgConnection = + getReadWriteSplittingPgConnection(connection); + ReadWriteSplittingHostSpec readWriteSplittingHostSpec = + readWriteSplittingPgConnection.getReadWriteSplittingHostSpec(); + Assert.assertEquals(readWriteSplittingHostSpec.getWriteHostSpec(), getMasterHostSpec()); + Assert.assertEquals(readWriteSplittingHostSpec.getReadHostSpecs(), getReadHostSpecs()); + try (Statement statement = connection.createStatement()) { + Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); + HostSpec actual = getRoutedReadHostSpec(readWriteSplittingPgConnection); + for (int i = 0; i < readHostSpecs.length; i++) { + HostSpec firstExpected = getNextExpectedRoundRobinSpec(); + if (firstExpected.equals(actual)) { + break; + } + } + Assert.assertTrue(isRoutedToReadHostSpecs(readWriteSplittingPgConnection)); + } + for (int i = 0; i < 10; i++) { + try (Statement statement = connection.createStatement()) { + Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); + HostSpec actual = getRoutedReadHostSpec(readWriteSplittingPgConnection); + Assert.assertEquals(getNextExpectedRoundRobinSpec(), actual); + } + } + for (int i = 0; i < 10; i++) { + String sql = "SELECT * FROM account WHERE id = ?"; + try (PreparedStatement statement = connection.prepareStatement(sql)) { + statement.setString(1, "1"); + Assert.assertTrue(statement.execute()); + HostSpec actual = getRoutedReadHostSpec(readWriteSplittingPgConnection); + Assert.assertEquals(getNextExpectedRoundRobinSpec(), actual); + } + } + for (int i = 0; i < 3; i++) { + try (Statement statement = connection.createStatement()) { + statement.execute("UPDATE account SET balance = 11 WHERE id = 1"); + Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteSplittingPgConnection)); + } + } + } + } + + private static ReadWriteSplittingPgConnection getReadWriteSplittingPgConnection(Connection connection) { + Assert.assertThat(connection, instanceOf(ReadWriteSplittingPgConnection.class)); + if (connection instanceof ReadWriteSplittingPgConnection) { + return (ReadWriteSplittingPgConnection) connection; + } + throw new IllegalStateException("Unexpected connection type"); + } + + @Test + public void shuffleLoadBalanceTest() throws SQLException { + String urlParams = String.format("?enableStatementLoadBalance=true&autoBalance=shuffle&writeDataSourceAddress" + + "=%s", getMasterHostSpec()); + try (Connection connection = getConnection(urlParams)) { + ReadWriteSplittingPgConnection readWriteSplittingPgConnection = + getReadWriteSplittingPgConnection(connection); + ReadWriteSplittingHostSpec readWriteSplittingHostSpec = + readWriteSplittingPgConnection.getReadWriteSplittingHostSpec(); + Assert.assertEquals(readWriteSplittingHostSpec.getWriteHostSpec(), getMasterHostSpec()); + Assert.assertEquals(readWriteSplittingHostSpec.getReadHostSpecs(), getReadHostSpecs()); + for (int i = 0; i < 10; i++) { + try (Statement statement = connection.createStatement()) { + Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); + Assert.assertTrue(isRoutedToReadHostSpecs(readWriteSplittingPgConnection)); + } + } + for (int i = 0; i < 10; i++) { + String sql = "SELECT * FROM account WHERE id = ?"; + try (PreparedStatement statement = connection.prepareStatement(sql)) { + statement.setString(1, "1"); + Assert.assertTrue(statement.execute()); + Assert.assertTrue(isRoutedToReadHostSpecs(readWriteSplittingPgConnection)); + } + } + for (int i = 0; i < 3; i++) { + try (Statement statement = connection.createStatement()) { + statement.execute("UPDATE account SET balance = 11 WHERE id = 1"); + Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteSplittingPgConnection)); + } + } + } + } + + @Test + public void priorityLoadBalanceTest() throws SQLException { + String urlParams = String.format("?enableStatementLoadBalance=true&autoBalance=priority2" + + "&writeDataSourceAddress=%s", getMasterHostSpec()); + try (Connection connection = getConnection(urlParams)) { + ReadWriteSplittingPgConnection readWriteSplittingPgConnection = + getReadWriteSplittingPgConnection(connection); + ReadWriteSplittingHostSpec readWriteSplittingHostSpec = + readWriteSplittingPgConnection.getReadWriteSplittingHostSpec(); + Assert.assertEquals(readWriteSplittingHostSpec.getWriteHostSpec(), getMasterHostSpec()); + Assert.assertEquals(readWriteSplittingHostSpec.getReadHostSpecs(), getReadHostSpecs()); + HostSpec firstActual; + try (Statement statement = connection.createStatement()) { + Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); + firstActual = getRoutedReadHostSpec(readWriteSplittingPgConnection); + } + for (int i = 0; i < 10; i++) { + try (Statement statement = connection.createStatement()) { + Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); + HostSpec actual = getRoutedReadHostSpec(readWriteSplittingPgConnection); + Assert.assertEquals(firstActual, actual); + } + } + for (int i = 0; i < 10; i++) { + String sql = "SELECT * FROM account WHERE id = ?"; + try (PreparedStatement statement = connection.prepareStatement(sql)) { + statement.setString(1, "1"); + Assert.assertTrue(statement.execute()); + HostSpec actual = getRoutedReadHostSpec(readWriteSplittingPgConnection); + Assert.assertEquals(firstActual, actual); + } + } + for (int i = 0; i < 3; i++) { + try (Statement statement = connection.createStatement()) { + statement.execute("UPDATE account SET balance = 11 WHERE id = 1"); + Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteSplittingPgConnection)); + } + } + } + } + + @Test + public void targetServerTypeOfMasterTest() throws SQLException { + String urlParams = String.format("?enableStatementLoadBalance=true&autoBalance=shuffle&writeDataSourceAddress" + + "=%s&targetServerType=master", getMasterHostSpec()); + try (Connection connection = getConnection(urlParams)) { + ReadWriteSplittingPgConnection readWriteSplittingPgConnection = + getReadWriteSplittingPgConnection(connection); + ReadWriteSplittingHostSpec readWriteSplittingHostSpec = + readWriteSplittingPgConnection.getReadWriteSplittingHostSpec(); + Assert.assertEquals(readWriteSplittingHostSpec.getWriteHostSpec(), getMasterHostSpec()); + Assert.assertEquals(readWriteSplittingHostSpec.getReadHostSpecs(), getReadHostSpecs()); + for (int i = 0; i < 10; i++) { + try (Statement statement = connection.createStatement()) { + Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); + Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteSplittingPgConnection)); + } + } + for (int i = 0; i < 10; i++) { + String sql = "SELECT * FROM account WHERE id = ?"; + try (PreparedStatement statement = connection.prepareStatement(sql)) { + statement.setString(1, "1"); + Assert.assertTrue(statement.execute()); + Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteSplittingPgConnection)); + } + } + for (int i = 0; i < 3; i++) { + try (Statement statement = connection.createStatement()) { + statement.execute("UPDATE account SET balance = 11 WHERE id = 1"); + Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteSplittingPgConnection)); + } + } + } + } + + @Test + public void targetServerTypeOfSlaveTest() throws SQLException { + String urlParams = String.format("?enableStatementLoadBalance=true&autoBalance=shuffle&writeDataSourceAddress" + + "=%s&targetServerType=slave", getMasterHostSpec()); + try (Connection connection = getConnection(urlParams)) { + ReadWriteSplittingPgConnection readWriteSplittingPgConnection = + getReadWriteSplittingPgConnection(connection); + ReadWriteSplittingHostSpec readWriteSplittingHostSpec = + readWriteSplittingPgConnection.getReadWriteSplittingHostSpec(); + Assert.assertEquals(readWriteSplittingHostSpec.getWriteHostSpec(), getMasterHostSpec()); + Assert.assertEquals(readWriteSplittingHostSpec.getReadHostSpecs(), getReadHostSpecs()); + for (int i = 0; i < 10; i++) { + try (Statement statement = connection.createStatement()) { + Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); + Assert.assertTrue(isRoutedToReadHostSpecs(readWriteSplittingPgConnection)); + } + } + for (int i = 0; i < 10; i++) { + String sql = "SELECT * FROM account WHERE id = ?"; + try (PreparedStatement statement = connection.prepareStatement(sql)) { + statement.setString(1, "1"); + Assert.assertTrue(statement.execute()); + Assert.assertTrue(isRoutedToReadHostSpecs(readWriteSplittingPgConnection)); + } + } + for (int i = 0; i < 3; i++) { + try (Statement statement = connection.createStatement()) { + try { + statement.execute("UPDATE account SET balance = 11 WHERE id = 1"); + } catch (SQLException e) { + Assert.assertTrue(e.getMessage().contains("ERROR: cannot execute UPDATE in a read-only " + + "transaction")); + } + } + } + } + } + + private static Properties getProperties() { + Properties properties = new Properties(); + properties.setProperty("user", TestUtil.getUser()); + properties.setProperty("password", TestUtil.getPassword()); + return properties; + } + + @Test + public void transactionRouteTest() throws SQLException { + String params = "?enableStatementLoadBalance=true&autoBalance=shuffle&writeDataSourceAddress=%s"; + String urlParams = String.format(params, getMasterHostSpec()); + try (Connection connection = getConnection(urlParams)) { + ReadWriteSplittingPgConnection readWriteSplittingPgConnection = + getReadWriteSplittingPgConnection(connection); + connection.setAutoCommit(false); + for (int i = 0; i < 3; i++) { + try (Statement statement = connection.createStatement()) { + Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); + Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteSplittingPgConnection)); + } + } + try (Statement statement = connection.createStatement()) { + statement.execute("UPDATE account SET balance = 11 WHERE id = 1"); + Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteSplittingPgConnection)); + } + connection.commit(); + try (Statement statement = connection.createStatement()) { + Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); + Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteSplittingPgConnection)); + } + connection.rollback(); + try (Statement statement = connection.createStatement()) { + Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); + Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteSplittingPgConnection)); + } + connection.setAutoCommit(true); + try (Statement statement = connection.createStatement()) { + Assert.assertTrue(statement.execute("SELECT * FROM account WHERE id = 1")); + Assert.assertTrue(isRoutedToReadHostSpecs(readWriteSplittingPgConnection)); + } + try (Statement statement = connection.createStatement()) { + statement.execute("UPDATE account SET balance = 11 WHERE id = 1"); + Assert.assertTrue(isRoutedToWriteHostSpecs(readWriteSplittingPgConnection)); + } + } + } + + @Test + public void executeMultiQueryByOneStatementTest() throws Exception { + String params = "?enableStatementLoadBalance=true&autoBalance=shuffle&writeDataSourceAddress=%s"; + String urlParams = String.format(params, getMasterHostSpec()); + try (Connection connection = getConnection(urlParams)) { + ReadWriteSplittingPgConnection readWriteSplittingPgConnection = + getReadWriteSplittingPgConnection(connection); + try (Statement statement = connection.createStatement()) { + for (int i = 0; i < 3; i++) { + ResultSet resultSet = statement.executeQuery("SELECT * FROM account WHERE id = 1"); + ResultSet resultSet2 = statement.executeQuery("SELECT * FROM account WHERE id = 1"); + ResultSet resultSet3 = statement.executeQuery("SELECT * FROM account WHERE id = 1"); + Assert.assertTrue(resultSet.next()); + Assert.assertTrue(resultSet2.next()); + Assert.assertTrue(resultSet3.next()); + Assert.assertTrue(isRoutedToReadHostSpecs(readWriteSplittingPgConnection)); + } + } + } + } + + @Test + public void executeQueryWithoutWriteDataSourceAddressParamTest() throws Exception { + String urlParams = "?enableStatementLoadBalance=true&autoBalance=shuffle"; + try (Connection connection = getConnection(urlParams)) { + ReadWriteSplittingPgConnection readWriteSplittingPgConnection = + getReadWriteSplittingPgConnection(connection); + for (int i = 0; i < 3; i++) { + try (Statement statement = connection.createStatement()) { + ResultSet resultSet = statement.executeQuery("SELECT * FROM account WHERE id = 1"); + Assert.assertTrue(resultSet.next()); + Assert.assertTrue(isRoutedToReadHostSpecs(readWriteSplittingPgConnection)); + } + } + } + } + + private HostSpec getNextExpectedRoundRobinSpec() { + return readHostSpecs[(currentIndex++) % readHostSpecs.length]; + } + + private static boolean isRoutedToReadHostSpecs(ReadWriteSplittingPgConnection readWriteSplittingPgConnection) + throws SQLException { + String socketAddress = + readWriteSplittingPgConnection.getConnectionManager().getCurrentConnection().getSocketAddress(); + for (HostSpec readHostSpec : getReadHostSpecs()) { + if (socketAddress.endsWith(getHostOrAlias(readHostSpec) + ":" + readHostSpec.getPort())) { + return true; + } + } + return false; + } + + private static HostSpec getRoutedReadHostSpec(ReadWriteSplittingPgConnection readWriteSplittingPgConnection) + throws SQLException { + String socketAddress = + readWriteSplittingPgConnection.getConnectionManager().getCurrentConnection().getSocketAddress(); + for (HostSpec readHostSpec : getReadHostSpecs()) { + if (socketAddress.endsWith(getHostOrAlias(readHostSpec) + ":" + readHostSpec.getPort())) { + return readHostSpec; + } + } + throw new IllegalStateException("Must routed to one read host spec"); + } + + private static boolean isRoutedToWriteHostSpecs(ReadWriteSplittingPgConnection readWriteSplittingPgConnection) + throws SQLException { + String socketAddress = + readWriteSplittingPgConnection.getConnectionManager().getCurrentConnection().getSocketAddress(); + return socketAddress.endsWith(getHostOrAlias(writeHostSpec) + ":" + writeHostSpec.getPort()); + } + + private static String getHostOrAlias(HostSpec readHostSpec) { + String host = readHostSpec.getHost(); + return host.equalsIgnoreCase("localhost") ? "127.0.0.1" : host; + } +}