From 19fdae9bca9dda5a364af00155360edb3ad98328 Mon Sep 17 00:00:00 2001 From: Mirro Mutth Date: Wed, 27 Mar 2024 19:10:31 +0900 Subject: [PATCH] Add InitFlow and move session states to context --- .../r2dbc/mysql/ConnectionContext.java | 151 +++- .../asyncer/r2dbc/mysql/ConnectionState.java | 70 -- .../java/io/asyncer/r2dbc/mysql/InitFlow.java | 747 ++++++++++++++++++ ...ava => MySqlClientConnectionMetadata.java} | 24 +- .../r2dbc/mysql/MySqlConnectionFactory.java | 191 ++--- .../r2dbc/mysql/MySqlSimpleConnection.java | 431 +--------- .../mysql/PrepareParameterizedStatement.java | 8 +- .../r2dbc/mysql/PrepareSimpleStatement.java | 7 +- .../io/asyncer/r2dbc/mysql/QueryFlow.java | 483 +++-------- .../mysql/internal/util/StringUtils.java | 43 +- .../r2dbc/mysql/ConnectionContextTest.java | 32 +- .../mysql/ConnectionIntegrationTest.java | 85 +- .../mysql/MySqlSimpleConnectionTest.java | 168 +++- .../PrepareParameterizedStatementTest.java | 3 +- .../mysql/PrepareSimpleStatementTest.java | 8 +- 15 files changed, 1372 insertions(+), 1079 deletions(-) delete mode 100644 r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/ConnectionState.java create mode 100644 r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/InitFlow.java rename r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/{MySqlSimpleConnectionMetadata.java => MySqlClientConnectionMetadata.java} (60%) diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/ConnectionContext.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/ConnectionContext.java index acf55812b..6e8b31810 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/ConnectionContext.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/ConnectionContext.java @@ -16,13 +16,16 @@ package io.asyncer.r2dbc.mysql; +import io.asyncer.r2dbc.mysql.cache.PrepareCache; import io.asyncer.r2dbc.mysql.codec.CodecContext; import io.asyncer.r2dbc.mysql.collation.CharCollation; import io.asyncer.r2dbc.mysql.constant.ServerStatuses; import io.asyncer.r2dbc.mysql.constant.ZeroDateOption; +import io.r2dbc.spi.IsolationLevel; import org.jetbrains.annotations.Nullable; import java.nio.file.Path; +import java.time.Duration; import java.time.ZoneId; import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull; @@ -37,6 +40,10 @@ public final class ConnectionContext implements CodecContext { private static final ServerVersion NONE_VERSION = ServerVersion.create(0, 0, 0); + private static final ServerVersion MYSQL_5_7_4 = ServerVersion.create(5, 7, 4); + + private static final ServerVersion MARIA_10_1_1 = ServerVersion.create(10, 1, 1, true); + private final ZeroDateOption zeroDateOption; @Nullable @@ -52,16 +59,47 @@ public final class ConnectionContext implements CodecContext { private Capability capability = Capability.DEFAULT; + private PrepareCache prepareCache; + @Nullable private ZoneId timeZone; + private String product = "Unknown"; + + /** + * Current isolation level inferred by past statements. + *

+ * Inference rules: + *

  1. In the beginning, it is also {@link #sessionIsolationLevel}.
  2. + *
  3. A transaction has began with a {@link IsolationLevel}, it will be changed to the value
  4. + *
  5. The transaction end (commit or rollback), it will recover to {@link #sessionIsolationLevel}.
+ */ + private volatile IsolationLevel currentIsolationLevel; + + /** + * Session isolation level. + * + *
  1. It is applied to all subsequent transactions performed within the current session.
  2. + *
  3. Calls {@link io.r2dbc.spi.Connection#setTransactionIsolationLevel}, it will change to the value.
  4. + *
  5. It can be changed within transactions, but does not affect the current ongoing transaction.
+ */ + private volatile IsolationLevel sessionIsolationLevel; + private boolean lockWaitTimeoutSupported = false; + /** + * Current lock wait timeout in seconds. + */ + private volatile Duration currentLockWaitTimeout; + + /** + * Session lock wait timeout in seconds. + */ + private volatile Duration sessionLockWaitTimeout; + /** * Assume that the auto commit is always turned on, it will be set after handshake V10 request message, or OK * message which means handshake V9 completed. - *

- * It would be updated multiple times, so {@code volatile} is required. */ private volatile short serverStatuses = ServerStatuses.AUTO_COMMIT; @@ -80,18 +118,50 @@ public final class ConnectionContext implements CodecContext { } /** - * Initializes this context. + * Initializes handshake information after connection is established. * * @param connectionId the connection identifier that is specified by server. * @param version the server version. * @param capability the connection capabilities. */ - void init(int connectionId, ServerVersion version, Capability capability) { + void initHandshake(int connectionId, ServerVersion version, Capability capability) { this.connectionId = connectionId; this.serverVersion = version; this.capability = capability; } + /** + * Initializes session information after logged-in. + * + * @param prepareCache the prepare cache. + * @param isolationLevel the session isolation level. + * @param lockWaitTimeoutSupported if the server supports lock wait timeout. + * @param lockWaitTimeout the lock wait timeout. + * @param product the server product name. + * @param timeZone the server timezone. + */ + void initSession( + PrepareCache prepareCache, + IsolationLevel isolationLevel, + boolean lockWaitTimeoutSupported, + Duration lockWaitTimeout, + @Nullable String product, + @Nullable ZoneId timeZone + ) { + this.prepareCache = prepareCache; + this.currentIsolationLevel = this.sessionIsolationLevel = isolationLevel; + this.lockWaitTimeoutSupported = lockWaitTimeoutSupported; + this.currentLockWaitTimeout = this.sessionLockWaitTimeout = lockWaitTimeout; + this.product = product == null ? "Unknown" : product; + + if (timeZone != null) { + if (isTimeZoneInitialized()) { + throw new IllegalStateException("Connection timezone have been initialized"); + } + this.timeZone = timeZone; + } + } + /** * Get the connection identifier that is specified by server. * @@ -128,6 +198,14 @@ public ZoneId getTimeZone() { return timeZone; } + String getProduct() { + return product; + } + + PrepareCache getPrepareCache() { + return prepareCache; + } + boolean isTimeZoneInitialized() { return timeZone != null; } @@ -138,13 +216,6 @@ public boolean isMariaDb() { return (capability != null && capability.isMariaDb()) || serverVersion.isMariaDb(); } - void initTimeZone(ZoneId timeZone) { - if (isTimeZoneInitialized()) { - throw new IllegalStateException("Connection timezone have been initialized"); - } - this.timeZone = timeZone; - } - @Override public ZeroDateOption getZeroDateOption() { return zeroDateOption; @@ -170,19 +241,23 @@ public int getLocalInfileBufferSize() { } /** - * Checks if the server supports lock wait timeout. + * Checks if the server supports InnoDB lock wait timeout. * - * @return if the server supports lock wait timeout. + * @return if the server supports InnoDB lock wait timeout. */ public boolean isLockWaitTimeoutSupported() { return lockWaitTimeoutSupported; } /** - * Enables lock wait timeout supported when loading session variables. + * Checks if the server supports statement timeout. + * + * @return if the server supports statement timeout. */ - void enableLockWaitTimeoutSupported() { - this.lockWaitTimeoutSupported = true; + public boolean isStatementTimeoutSupported() { + boolean isMariaDb = isMariaDb(); + return (isMariaDb && serverVersion.isGreaterThanOrEqualTo(MARIA_10_1_1)) || + (!isMariaDb && serverVersion.isGreaterThanOrEqualTo(MYSQL_5_7_4)); } /** @@ -202,4 +277,48 @@ public short getServerStatuses() { public void setServerStatuses(short serverStatuses) { this.serverStatuses = serverStatuses; } + + IsolationLevel getCurrentIsolationLevel() { + return currentIsolationLevel; + } + + void setCurrentIsolationLevel(IsolationLevel isolationLevel) { + this.currentIsolationLevel = isolationLevel; + } + + void resetCurrentIsolationLevel() { + this.currentIsolationLevel = this.sessionIsolationLevel; + } + + IsolationLevel getSessionIsolationLevel() { + return sessionIsolationLevel; + } + + void setSessionIsolationLevel(IsolationLevel isolationLevel) { + this.sessionIsolationLevel = isolationLevel; + } + + void setCurrentLockWaitTimeout(Duration timeoutSeconds) { + this.currentLockWaitTimeout = timeoutSeconds; + } + + void resetCurrentLockWaitTimeout() { + this.currentLockWaitTimeout = this.sessionLockWaitTimeout; + } + + boolean isLockWaitTimeoutChanged() { + return currentLockWaitTimeout != sessionLockWaitTimeout; + } + + Duration getSessionLockWaitTimeout() { + return sessionLockWaitTimeout; + } + + void setAllLockWaitTimeout(Duration timeoutSeconds) { + this.currentLockWaitTimeout = this.sessionLockWaitTimeout = timeoutSeconds; + } + + boolean isInTransaction() { + return (serverStatuses & ServerStatuses.IN_TRANSACTION) != 0; + } } diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/ConnectionState.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/ConnectionState.java deleted file mode 100644 index 73a9caf09..000000000 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/ConnectionState.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright 2023 asyncer.io projects - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.asyncer.r2dbc.mysql; - -import io.r2dbc.spi.IsolationLevel; - -/** - * An internal interface for check, set and reset connection states. - */ -interface ConnectionState { - - /** - * Sets current isolation level. - * - * @param level current level. - */ - void setIsolationLevel(IsolationLevel level); - - /** - * Returns session lock wait timeout. - * - * @return Session lock wait timeout. - */ - long getSessionLockWaitTimeout(); - - /** - * Sets current lock wait timeout. - * - * @param timeoutSeconds seconds of current lock wait timeout. - */ - void setCurrentLockWaitTimeout(long timeoutSeconds); - - /** - * Checks if lock wait timeout has been changed by {@link #setCurrentLockWaitTimeout(long)}. - * - * @return if lock wait timeout changed. - */ - boolean isLockWaitTimeoutChanged(); - - /** - * Resets current isolation level in initial state. - */ - void resetIsolationLevel(); - - /** - * Resets current isolation level in initial state. - */ - void resetCurrentLockWaitTimeout(); - - /** - * Checks if connection is processing a transaction. - * - * @return if in a transaction. - */ - boolean isInTransaction(); -} diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/InitFlow.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/InitFlow.java new file mode 100644 index 000000000..32dcc1c8a --- /dev/null +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/InitFlow.java @@ -0,0 +1,747 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql; + +import io.asyncer.r2dbc.mysql.api.MySqlResult; +import io.asyncer.r2dbc.mysql.authentication.MySqlAuthProvider; +import io.asyncer.r2dbc.mysql.cache.Caches; +import io.asyncer.r2dbc.mysql.cache.PrepareCache; +import io.asyncer.r2dbc.mysql.client.Client; +import io.asyncer.r2dbc.mysql.client.FluxExchangeable; +import io.asyncer.r2dbc.mysql.codec.Codecs; +import io.asyncer.r2dbc.mysql.codec.CodecsBuilder; +import io.asyncer.r2dbc.mysql.constant.CompressionAlgorithm; +import io.asyncer.r2dbc.mysql.constant.SslMode; +import io.asyncer.r2dbc.mysql.extension.CodecRegistrar; +import io.asyncer.r2dbc.mysql.internal.util.StringUtils; +import io.asyncer.r2dbc.mysql.message.client.AuthResponse; +import io.asyncer.r2dbc.mysql.message.client.ClientMessage; +import io.asyncer.r2dbc.mysql.message.client.HandshakeResponse; +import io.asyncer.r2dbc.mysql.message.client.InitDbMessage; +import io.asyncer.r2dbc.mysql.message.client.SslRequest; +import io.asyncer.r2dbc.mysql.message.client.SubsequenceClientMessage; +import io.asyncer.r2dbc.mysql.message.server.AuthMoreDataMessage; +import io.asyncer.r2dbc.mysql.message.server.ChangeAuthMessage; +import io.asyncer.r2dbc.mysql.message.server.CompleteMessage; +import io.asyncer.r2dbc.mysql.message.server.ErrorMessage; +import io.asyncer.r2dbc.mysql.message.server.HandshakeHeader; +import io.asyncer.r2dbc.mysql.message.server.HandshakeRequest; +import io.asyncer.r2dbc.mysql.message.server.OkMessage; +import io.asyncer.r2dbc.mysql.message.server.ServerMessage; +import io.asyncer.r2dbc.mysql.message.server.SyntheticSslResponseMessage; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import io.r2dbc.spi.IsolationLevel; +import io.r2dbc.spi.R2dbcNonTransientResourceException; +import io.r2dbc.spi.R2dbcPermissionDeniedException; +import io.r2dbc.spi.Readable; +import org.jetbrains.annotations.Nullable; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.core.publisher.SynchronousSink; +import reactor.util.concurrent.Queues; + +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.time.DateTimeException; +import java.time.Duration; +import java.time.ZoneId; +import java.time.ZoneOffset; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.function.BiConsumer; +import java.util.function.Function; + +/** + * A message flow utility that can initializes the session of {@link Client}. + *

+ * It should not use server-side prepared statements, because {@link PrepareCache} will be initialized after the session + * is initialized. + */ +final class InitFlow { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(InitFlow.class); + + private static final ServerVersion MARIA_11_1_1 = ServerVersion.create(11, 1, 1, true); + + private static final ServerVersion MYSQL_8_0_3 = ServerVersion.create(8, 0, 3); + + private static final ServerVersion MYSQL_5_7_20 = ServerVersion.create(5, 7, 20); + + private static final ServerVersion MYSQL_8 = ServerVersion.create(8, 0, 0); + + private static final BiConsumer> INIT_DB = (message, sink) -> { + if (message instanceof ErrorMessage) { + ErrorMessage msg = (ErrorMessage) message; + logger.debug("Use database failed: [{}] [{}] {}", msg.getCode(), msg.getSqlState(), msg.getMessage()); + sink.next(false); + sink.complete(); + } else if (message instanceof CompleteMessage && ((CompleteMessage) message).isDone()) { + sink.next(true); + sink.complete(); + } else { + ReferenceCountUtil.safeRelease(message); + } + }; + + private static final BiConsumer> INIT_DB_AFTER = (message, sink) -> { + if (message instanceof ErrorMessage) { + sink.error(((ErrorMessage) message).toException()); + } else if (message instanceof CompleteMessage && ((CompleteMessage) message).isDone()) { + sink.complete(); + } else { + ReferenceCountUtil.safeRelease(message); + } + }; + + /** + * Initializes handshake and login a {@link Client}. + * + * @param client the {@link Client} to exchange messages with. + * @param sslMode the {@link SslMode} defines SSL capability and behavior. + * @param database the database that will be connected. + * @param user the user that will be login. + * @param password the password of the {@code user}. + * @param compressionAlgorithms the list of compression algorithms. + * @param zstdCompressionLevel the zstd compression level. + * @return a {@link Flux} that indicates the initialization is done, or an error if the initialization failed. + */ + static Flux initHandshake(Client client, SslMode sslMode, String database, String user, + @Nullable CharSequence password, Set compressionAlgorithms, int zstdCompressionLevel) { + return client.exchange(new HandshakeExchangeable(client, sslMode, database, user, password, + compressionAlgorithms, zstdCompressionLevel)); + } + + /** + * Initializes the session and {@link Codecs} of a {@link Client}. + * + * @param client the client + * @param database the database to use after session initialization + * @param prepareCacheSize the size of prepare cache + * @param sessionVariables the session variables to set + * @param forceTimeZone if the timezone should be set to session + * @param lockWaitTimeout the lock wait timeout that should be set to session + * @param statementTimeout the statement timeout that should be set to session + * @return a {@link Mono} that indicates the {@link Codecs}, or an error if the initialization failed + */ + static Mono initSession( + Client client, + String database, + int prepareCacheSize, + List sessionVariables, + boolean forceTimeZone, + @Nullable Duration lockWaitTimeout, + @Nullable Duration statementTimeout, + Extensions extensions + ) { + return Mono.defer(() -> { + ByteBufAllocator allocator = client.getByteBufAllocator(); + CodecsBuilder builder = Codecs.builder(); + + extensions.forEach(CodecRegistrar.class, registrar -> + registrar.register(allocator, builder)); + + Codecs codecs = builder.build(); + + List variables = mergeSessionVariables(client, sessionVariables, forceTimeZone, statementTimeout); + + logger.debug("Initializing client session: {}", variables); + + return QueryFlow.setSessionVariables(client, variables) + .then(loadSessionVariables(client, codecs)) + .flatMap(data -> loadAndInitInnoDbEngineStatus(data, client, codecs, lockWaitTimeout)) + .flatMap(data -> { + ConnectionContext context = client.getContext(); + + logger.debug("Initializing connection {} context: {}", context.getConnectionId(), data); + context.initSession( + Caches.createPrepareCache(prepareCacheSize), + data.level, + data.lockWaitTimeoutSupported, + data.lockWaitTimeout, + data.product, + data.timeZone + ); + + if (!data.lockWaitTimeoutSupported) { + logger.info( + "Lock wait timeout is not supported by server, all related operations will be ignored"); + } + + return database.isEmpty() ? Mono.just(codecs) : + initDatabase(client, database).then(Mono.just(codecs)); + }); + }); + } + + private static Mono loadAndInitInnoDbEngineStatus( + SessionState data, + Client client, + Codecs codecs, + @Nullable Duration lockWaitTimeout + ) { + return new TextSimpleStatement(client, codecs, "SHOW VARIABLES LIKE 'innodb\\\\_lock\\\\_wait\\\\_timeout'") + .execute() + .flatMap(r -> r.map(readable -> { + String value = readable.get(1, String.class); + + if (value == null || value.isEmpty()) { + return data; + } else { + return data.lockWaitTimeout(Duration.ofSeconds(Long.parseLong(value))); + } + })) + .single(data) + .flatMap(d -> { + if (lockWaitTimeout != null) { + // Do not use context.isLockWaitTimeoutSupported() here, because its session variable is not set + if (d.lockWaitTimeoutSupported) { + return QueryFlow.executeVoid(client, StringUtils.lockWaitTimeoutStatement(lockWaitTimeout)) + .then(Mono.fromSupplier(() -> d.lockWaitTimeout(lockWaitTimeout))); + } + + logger.warn("Lock wait timeout is not supported by server, ignore initial setting"); + return Mono.just(d); + } + return Mono.just(d); + }); + } + + private static Mono loadSessionVariables(Client client, Codecs codecs) { + ConnectionContext context = client.getContext(); + StringBuilder query = new StringBuilder(128) + .append("SELECT ") + .append(transactionIsolationColumn(context)) + .append(",@@version_comment AS v"); + + Function> handler; + + if (context.isTimeZoneInitialized()) { + handler = r -> convertSessionData(r, false); + } else { + query.append(",@@system_time_zone AS s,@@time_zone AS t"); + handler = r -> convertSessionData(r, true); + } + + return new TextSimpleStatement(client, codecs, query.toString()) + .execute() + .flatMap(handler) + .last(); + } + + private static Mono initDatabase(Client client, String database) { + return client.exchange(new InitDbMessage(database), INIT_DB) + .last() + .flatMap(success -> { + if (success) { + return Mono.empty(); + } + + String sql = "CREATE DATABASE IF NOT EXISTS " + StringUtils.quoteIdentifier(database); + + return QueryFlow.executeVoid(client, sql) + .then(client.exchange(new InitDbMessage(database), INIT_DB_AFTER).then()); + }); + } + + private static List mergeSessionVariables( + Client client, + List sessionVariables, + boolean forceTimeZone, + @Nullable Duration statementTimeout + ) { + ConnectionContext context = client.getContext(); + + if ((!forceTimeZone || !context.isTimeZoneInitialized()) && statementTimeout == null) { + return sessionVariables; + } + + List variables = new ArrayList<>(sessionVariables.size() + 2); + + variables.addAll(sessionVariables); + + if (forceTimeZone && context.isTimeZoneInitialized()) { + variables.add(timeZoneVariable(context.getTimeZone())); + } + + if (statementTimeout != null) { + if (context.isStatementTimeoutSupported()) { + variables.add(StringUtils.statementTimeoutVariable(statementTimeout, context.isMariaDb())); + } else { + logger.warn("Statement timeout is not supported in {}, ignore initial setting", + context.getServerVersion()); + } + } + + return variables; + } + + private static String timeZoneVariable(ZoneId timeZone) { + String offerStr = timeZone instanceof ZoneOffset && "Z".equalsIgnoreCase(timeZone.getId()) ? + "+00:00" : timeZone.getId(); + + return "time_zone='" + offerStr + "'"; + } + + private static Flux convertSessionData(MySqlResult r, boolean timeZone) { + return r.map(readable -> { + IsolationLevel level = convertIsolationLevel(readable.get(0, String.class)); + String product = readable.get(1, String.class); + + return new SessionState(level, product, timeZone ? readZoneId(readable) : null); + }); + } + + /** + * Resolves the column of session isolation level, the {@literal @@tx_isolation} has been marked as deprecated. + *

+ * If server is MariaDB, {@literal @@transaction_isolation} is used starting from {@literal 11.1.1}. + *

+ * If the server is MySQL, use {@literal @@transaction_isolation} starting from {@literal 8.0.3}, or between + * {@literal 5.7.20} and {@literal 8.0.0} (exclusive). + */ + private static String transactionIsolationColumn(ConnectionContext context) { + ServerVersion version = context.getServerVersion(); + + if (context.isMariaDb()) { + return version.isGreaterThanOrEqualTo(MARIA_11_1_1) ? "@@transaction_isolation AS i" : + "@@tx_isolation AS i"; + } + + return version.isGreaterThanOrEqualTo(MYSQL_8_0_3) || + (version.isGreaterThanOrEqualTo(MYSQL_5_7_20) && version.isLessThan(MYSQL_8)) ? + "@@transaction_isolation AS i" : "@@tx_isolation AS i"; + } + + private static ZoneId readZoneId(Readable readable) { + String systemTimeZone = readable.get(2, String.class); + String timeZone = readable.get(3, String.class); + + if (timeZone == null || timeZone.isEmpty() || "SYSTEM".equalsIgnoreCase(timeZone)) { + if (systemTimeZone == null || systemTimeZone.isEmpty()) { + logger.warn("MySQL does not return any timezone, trying to use system default timezone"); + return ZoneId.systemDefault().normalized(); + } else { + return convertZoneId(systemTimeZone); + } + } else { + return convertZoneId(timeZone); + } + } + + private static ZoneId convertZoneId(String id) { + try { + return StringUtils.parseZoneId(id); + } catch (DateTimeException e) { + logger.warn("The server timezone is unknown <{}>, trying to use system default timezone", id, e); + + return ZoneId.systemDefault().normalized(); + } + } + + private static IsolationLevel convertIsolationLevel(@Nullable String name) { + if (name == null) { + logger.warn("Isolation level is null in current session, fallback to repeatable read"); + + return IsolationLevel.REPEATABLE_READ; + } + + switch (name) { + case "READ-UNCOMMITTED": + return IsolationLevel.READ_UNCOMMITTED; + case "READ-COMMITTED": + return IsolationLevel.READ_COMMITTED; + case "REPEATABLE-READ": + return IsolationLevel.REPEATABLE_READ; + case "SERIALIZABLE": + return IsolationLevel.SERIALIZABLE; + } + + logger.warn("Unknown isolation level {} in current session, fallback to repeatable read", name); + + return IsolationLevel.REPEATABLE_READ; + } + + private InitFlow() { + } + + private static final class SessionState { + + private final IsolationLevel level; + + @Nullable + private final String product; + + @Nullable + private final ZoneId timeZone; + + private final Duration lockWaitTimeout; + + private final boolean lockWaitTimeoutSupported; + + SessionState(IsolationLevel level, @Nullable String product, @Nullable ZoneId timeZone) { + this(level, product, timeZone, Duration.ZERO, false); + } + + private SessionState( + IsolationLevel level, + @Nullable String product, + @Nullable ZoneId timeZone, + Duration lockWaitTimeout, + boolean lockWaitTimeoutSupported + ) { + this.level = level; + this.product = product; + this.timeZone = timeZone; + this.lockWaitTimeout = lockWaitTimeout; + this.lockWaitTimeoutSupported = lockWaitTimeoutSupported; + } + + SessionState lockWaitTimeout(Duration timeout) { + return new SessionState(level, product, timeZone, timeout, true); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof SessionState)) { + return false; + } + + SessionState that = (SessionState) o; + + return lockWaitTimeoutSupported == that.lockWaitTimeoutSupported && + level.equals(that.level) && + Objects.equals(product, that.product) && + Objects.equals(timeZone, that.timeZone) && + lockWaitTimeout.equals(that.lockWaitTimeout); + } + + @Override + public int hashCode() { + int result = level.hashCode(); + result = 31 * result + (product != null ? product.hashCode() : 0); + result = 31 * result + (timeZone != null ? timeZone.hashCode() : 0); + result = 31 * result + lockWaitTimeout.hashCode(); + return 31 * result + (lockWaitTimeoutSupported ? 1 : 0); + } + + @Override + public String toString() { + return "SessionState{level=" + level + + ", product='" + product + + "', timeZone=" + timeZone + + ", lockWaitTimeout=" + lockWaitTimeout + + ", lockWaitTimeoutSupported=" + lockWaitTimeoutSupported + + '}'; + } + } +} + +/** + * An implementation of {@link FluxExchangeable} that considers login to the database. + *

+ * Not like other {@link FluxExchangeable}s, it is started by a server-side message, which should be an implementation + * of {@link HandshakeRequest}. + */ +final class HandshakeExchangeable extends FluxExchangeable { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(HandshakeExchangeable.class); + + private static final Map ATTRIBUTES = Collections.emptyMap(); + + private static final String CLI_SPECIFIC = "HY000"; + + private static final int HANDSHAKE_VERSION = 10; + + private final Sinks.Many requests = Sinks.many().unicast() + .onBackpressureBuffer(Queues.one().get()); + + private final Client client; + + private final SslMode sslMode; + + private final String database; + + private final String user; + + @Nullable + private final CharSequence password; + + private final Set compressions; + + private final int zstdCompressionLevel; + + private boolean handshake = true; + + private MySqlAuthProvider authProvider; + + private byte[] salt; + + private boolean sslCompleted; + + HandshakeExchangeable(Client client, SslMode sslMode, String database, String user, + @Nullable CharSequence password, Set compressions, + int zstdCompressionLevel) { + this.client = client; + this.sslMode = sslMode; + this.database = database; + this.user = user; + this.password = password; + this.compressions = compressions; + this.zstdCompressionLevel = zstdCompressionLevel; + this.sslCompleted = sslMode == SslMode.TUNNEL; + } + + @Override + public void subscribe(CoreSubscriber actual) { + requests.asFlux().subscribe(actual); + } + + @Override + public void accept(ServerMessage message, SynchronousSink sink) { + if (message instanceof ErrorMessage) { + sink.error(((ErrorMessage) message).toException()); + return; + } + + // Ensures it will be initialized only once. + if (handshake) { + handshake = false; + if (message instanceof HandshakeRequest) { + HandshakeRequest request = (HandshakeRequest) message; + Capability capability = initHandshake(request); + + if (capability.isSslEnabled()) { + emitNext(SslRequest.from(capability, client.getContext().getClientCollation().getId()), sink); + } else { + emitNext(createHandshakeResponse(capability), sink); + } + } else { + sink.error(new R2dbcPermissionDeniedException("Unexpected message type '" + + message.getClass().getSimpleName() + "' in init phase")); + } + + return; + } + + if (message instanceof OkMessage) { + logger.trace("Connection (id {}) login success", client.getContext().getConnectionId()); + client.loginSuccess(); + sink.complete(); + } else if (message instanceof SyntheticSslResponseMessage) { + sslCompleted = true; + emitNext(createHandshakeResponse(client.getContext().getCapability()), sink); + } else if (message instanceof AuthMoreDataMessage) { + AuthMoreDataMessage msg = (AuthMoreDataMessage) message; + + if (msg.isFailed()) { + if (logger.isDebugEnabled()) { + logger.debug("Connection (id {}) fast authentication failed, use full authentication", + client.getContext().getConnectionId()); + } + + emitNext(createAuthResponse("full authentication"), sink); + } + // Otherwise success, wait until OK message or Error message. + } else if (message instanceof ChangeAuthMessage) { + ChangeAuthMessage msg = (ChangeAuthMessage) message; + + authProvider = MySqlAuthProvider.build(msg.getAuthType()); + salt = msg.getSalt(); + emitNext(createAuthResponse("change authentication"), sink); + } else { + sink.error(new R2dbcPermissionDeniedException("Unexpected message type '" + + message.getClass().getSimpleName() + "' in login phase")); + } + } + + @Override + public void dispose() { + // No particular error condition handling for complete signal. + this.requests.tryEmitComplete(); + } + + private void emitNext(SubsequenceClientMessage message, SynchronousSink sink) { + Sinks.EmitResult result = requests.tryEmitNext(message); + + if (result != Sinks.EmitResult.OK) { + sink.error(new IllegalStateException("Fail to emit a login request due to " + result)); + } + } + + private AuthResponse createAuthResponse(String phase) { + MySqlAuthProvider authProvider = getAndNextProvider(); + + if (authProvider.isSslNecessary() && !sslCompleted) { + throw new R2dbcPermissionDeniedException(authFails(authProvider.getType(), phase), CLI_SPECIFIC); + } + + return new AuthResponse(authProvider.authentication(password, salt, client.getContext().getClientCollation())); + } + + private Capability clientCapability(Capability serverCapability) { + Capability.Builder builder = serverCapability.mutate(); + + builder.disableSessionTrack(); + builder.disableDatabasePinned(); + builder.disableIgnoreAmbiguitySpace(); + builder.disableInteractiveTimeout(); + + if (sslMode == SslMode.TUNNEL) { + // Tunnel does not use MySQL SSL protocol, disable it. + builder.disableSsl(); + } else if (!serverCapability.isSslEnabled()) { + // Server unsupported SSL. + if (sslMode.requireSsl()) { + // Before handshake, Client.context does not be initialized + throw new R2dbcPermissionDeniedException("Server does not support SSL but mode '" + sslMode + + "' requires SSL", CLI_SPECIFIC); + } else if (sslMode.startSsl()) { + // SSL has start yet, and client can disable SSL, disable now. + client.sslUnsupported(); + } + } else { + // The server supports SSL, but the user does not want to use SSL, disable it. + if (!sslMode.startSsl()) { + builder.disableSsl(); + } + } + + if (isZstdAllowed(serverCapability)) { + if (isZstdSupported()) { + builder.disableZlibCompression(); + } else { + logger.warn("Server supports zstd, but zstd-jni dependency is missing"); + + if (isZlibAllowed(serverCapability)) { + builder.disableZstdCompression(); + } else if (compressions.contains(CompressionAlgorithm.UNCOMPRESSED)) { + builder.disableCompression(); + } else { + throw new R2dbcNonTransientResourceException( + "Environment does not support a compression algorithm in " + compressions + + ", config does not allow uncompressed mode", CLI_SPECIFIC); + } + } + } else if (isZlibAllowed(serverCapability)) { + builder.disableZstdCompression(); + } else if (compressions.contains(CompressionAlgorithm.UNCOMPRESSED)) { + builder.disableCompression(); + } else { + throw new R2dbcPermissionDeniedException( + "Environment does not support a compression algorithm in " + compressions + + ", config does not allow uncompressed mode", CLI_SPECIFIC); + } + + if (database.isEmpty()) { + builder.disableConnectWithDatabase(); + } + + if (client.getContext().getLocalInfilePath() == null) { + builder.disableLoadDataLocalInfile(); + } + + if (ATTRIBUTES.isEmpty()) { + builder.disableConnectAttributes(); + } + + return builder.build(); + } + + private Capability initHandshake(HandshakeRequest message) { + HandshakeHeader header = message.getHeader(); + int handshakeVersion = header.getProtocolVersion(); + ServerVersion serverVersion = header.getServerVersion(); + + if (handshakeVersion < HANDSHAKE_VERSION) { + logger.warn("MySQL use handshake V{}, server version is {}, maybe most features are unavailable", + handshakeVersion, serverVersion); + } + + Capability capability = clientCapability(message.getServerCapability()); + + // No need initialize server statuses because it has initialized by read filter. + this.client.getContext().initHandshake(header.getConnectionId(), serverVersion, capability); + this.authProvider = MySqlAuthProvider.build(message.getAuthType()); + this.salt = message.getSalt(); + + return capability; + } + + private MySqlAuthProvider getAndNextProvider() { + MySqlAuthProvider authProvider = this.authProvider; + this.authProvider = authProvider.next(); + return authProvider; + } + + private HandshakeResponse createHandshakeResponse(Capability capability) { + MySqlAuthProvider authProvider = getAndNextProvider(); + + if (authProvider.isSslNecessary() && !sslCompleted) { + throw new R2dbcPermissionDeniedException(authFails(authProvider.getType(), "handshake"), + CLI_SPECIFIC); + } + + byte[] authorization = authProvider.authentication(password, salt, client.getContext().getClientCollation()); + String authType = authProvider.getType(); + + if (MySqlAuthProvider.NO_AUTH_PROVIDER.equals(authType)) { + // Authentication type is not matter because of it has no authentication type. + // Server need send a Change Authentication Message after handshake response. + authType = MySqlAuthProvider.CACHING_SHA2_PASSWORD; + } + + return HandshakeResponse.from(capability, client.getContext().getClientCollation().getId(), user, authorization, + authType, database, ATTRIBUTES, zstdCompressionLevel); + } + + private boolean isZstdAllowed(Capability capability) { + return capability.isZstdCompression() && compressions.contains(CompressionAlgorithm.ZSTD); + } + + private boolean isZlibAllowed(Capability capability) { + return capability.isZlibCompression() && compressions.contains(CompressionAlgorithm.ZLIB); + } + + private static String authFails(String authType, String phase) { + return "Authentication type '" + authType + "' must require SSL in " + phase + " phase"; + } + + private static boolean isZstdSupported() { + try { + ClassLoader loader = AccessController.doPrivileged((PrivilegedAction) () -> { + ClassLoader cl = Thread.currentThread().getContextClassLoader(); + return cl == null ? ClassLoader.getSystemClassLoader() : cl; + }); + Class.forName("com.github.luben.zstd.Zstd", false, loader); + return true; + } catch (ClassNotFoundException e) { + return false; + } + } +} diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnectionMetadata.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlClientConnectionMetadata.java similarity index 60% rename from r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnectionMetadata.java rename to r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlClientConnectionMetadata.java index ee7faf42d..61cb1d0b8 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnectionMetadata.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlClientConnectionMetadata.java @@ -17,39 +17,31 @@ package io.asyncer.r2dbc.mysql; import io.asyncer.r2dbc.mysql.api.MySqlConnectionMetadata; -import org.jetbrains.annotations.Nullable; - -import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull; +import io.asyncer.r2dbc.mysql.client.Client; /** * Connection metadata for a connection connected to MySQL database. */ -final class MySqlSimpleConnectionMetadata implements MySqlConnectionMetadata { - - private final String version; - - private final String product; +final class MySqlClientConnectionMetadata implements MySqlConnectionMetadata { - private final boolean isMariaDb; + private final Client client; - MySqlSimpleConnectionMetadata(String version, @Nullable String product, boolean isMariaDb) { - this.version = requireNonNull(version, "version must not be null"); - this.product = product == null ? "Unknown" : product; - this.isMariaDb = isMariaDb; + MySqlClientConnectionMetadata(Client client) { + this.client = client; } @Override public String getDatabaseVersion() { - return version; + return client.getContext().getServerVersion().toString(); } @Override public boolean isMariaDb() { - return isMariaDb; + return client.getContext().isMariaDb(); } @Override public String getDatabaseProductName() { - return product; + return client.getContext().getProduct(); } } diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java index 6d76a8bed..d003db2b0 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java @@ -18,16 +18,9 @@ import io.asyncer.r2dbc.mysql.api.MySqlConnection; import io.asyncer.r2dbc.mysql.cache.Caches; -import io.asyncer.r2dbc.mysql.cache.PrepareCache; import io.asyncer.r2dbc.mysql.cache.QueryCache; import io.asyncer.r2dbc.mysql.client.Client; -import io.asyncer.r2dbc.mysql.codec.Codecs; -import io.asyncer.r2dbc.mysql.codec.CodecsBuilder; -import io.asyncer.r2dbc.mysql.constant.CompressionAlgorithm; -import io.asyncer.r2dbc.mysql.constant.SslMode; -import io.asyncer.r2dbc.mysql.extension.CodecRegistrar; import io.asyncer.r2dbc.mysql.internal.util.StringUtils; -import io.netty.buffer.ByteBufAllocator; import io.netty.channel.unix.DomainSocketAddress; import io.r2dbc.spi.ConnectionFactory; import io.r2dbc.spi.ConnectionFactoryMetadata; @@ -38,13 +31,9 @@ import java.net.InetSocketAddress; import java.net.SocketAddress; import java.time.ZoneId; -import java.time.ZoneOffset; -import java.util.ArrayList; -import java.util.List; import java.util.Objects; -import java.util.Set; import java.util.concurrent.locks.ReentrantLock; -import java.util.function.Predicate; +import java.util.function.Supplier; import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull; @@ -93,102 +82,103 @@ public static MySqlConnectionFactory from(MySqlConnectionConfiguration configura address = new DomainSocketAddress(configuration.getDomain()); } - String database = configuration.getDatabase(); - boolean createDbIfNotExist = configuration.isCreateDatabaseIfNotExist(); String user = configuration.getUser(); CharSequence password = configuration.getPassword(); - SslMode sslMode = ssl.getSslMode(); - int zstdCompressionLevel = configuration.getZstdCompressionLevel(); - ZoneId connectionTimeZone = retrieveZoneId(configuration.getConnectionTimeZone()); - ConnectionContext context = new ConnectionContext( - configuration.getZeroDateOption(), - configuration.getLoadLocalInfilePath(), - configuration.getLocalInfileBufferSize(), - configuration.isPreserveInstants(), - connectionTimeZone - ); - Set compressionAlgorithms = configuration.getCompressionAlgorithms(); - Extensions extensions = configuration.getExtensions(); - Predicate prepare = configuration.getPreferPrepareStatement(); - int prepareCacheSize = configuration.getPrepareCacheSize(); Publisher passwordPublisher = configuration.getPasswordPublisher(); - boolean forceTimeZone = configuration.isForceConnectionTimeZoneToSession(); - List sessionVariables = forceTimeZone && connectionTimeZone != null ? - mergeSessionVariables(configuration.getSessionVariables(), connectionTimeZone) : - configuration.getSessionVariables(); if (Objects.nonNull(passwordPublisher)) { return Mono.from(passwordPublisher).flatMap(token -> getMySqlConnection( - configuration, queryCache, - ssl, address, - database, createDbIfNotExist, - user, sslMode, - compressionAlgorithms, zstdCompressionLevel, - context, extensions, sessionVariables, prepare, - prepareCacheSize, token + configuration, ssl, + queryCache, + address, + user, + token )); } return getMySqlConnection( - configuration, queryCache, - ssl, address, - database, createDbIfNotExist, - user, sslMode, - compressionAlgorithms, zstdCompressionLevel, - context, extensions, sessionVariables, prepare, - prepareCacheSize, password + configuration, ssl, + queryCache, + address, + user, + password ); })); } + /** + * Gets an initialized {@link MySqlConnection} from authentication credential and configurations. + *

+ * It contains following steps: + *

  1. Create connection context
  2. + *
  3. Connect to MySQL server with TCP or Unix Domain Socket
  4. + *
  5. Handshake/login and init handshake states
  6. + *
  7. Init session states
+ * + * @param configuration the connection configuration. + * @param ssl the SSL configuration. + * @param queryCache lazy-init query cache, it is shared among all connections from the same factory. + * @param address TCP or Unix Domain Socket address. + * @param user the user of the authentication. + * @param password the password of the authentication. + * @return a {@link MySqlConnection}. + */ private static Mono getMySqlConnection( - final MySqlConnectionConfiguration configuration, - final LazyQueryCache queryCache, - final MySqlSslConfiguration ssl, - final SocketAddress address, - final String database, - final boolean createDbIfNotExist, - final String user, - final SslMode sslMode, - final Set compressionAlgorithms, - final int zstdLevel, - final ConnectionContext context, - final Extensions extensions, - final List sessionVariables, - @Nullable final Predicate prepare, - final int prepareCacheSize, - @Nullable final CharSequence password) { - return Client.connect(ssl, address, configuration.isTcpKeepAlive(), configuration.isTcpNoDelay(), - context, configuration.getConnectTimeout(), configuration.getLoopResources()) - .flatMap(client -> { - // Lazy init database after handshake/login - String db = createDbIfNotExist ? "" : database; - return QueryFlow.login(client, sslMode, db, user, password, compressionAlgorithms, zstdLevel); - }) - .flatMap(client -> { - ByteBufAllocator allocator = client.getByteBufAllocator(); - CodecsBuilder builder = Codecs.builder(); - PrepareCache prepareCache = Caches.createPrepareCache(prepareCacheSize); - String db = createDbIfNotExist ? database : ""; - - extensions.forEach(CodecRegistrar.class, registrar -> - registrar.register(allocator, builder)); - - Mono c = MySqlSimpleConnection.init(client, builder.build(), db, queryCache.get(), - prepareCache, sessionVariables, prepare); - - if (configuration.getLockWaitTimeout() != null) { - c = c.flatMap(connection -> connection.setLockWaitTimeout(configuration.getLockWaitTimeout()) - .thenReturn(connection)); - } - - if (configuration.getStatementTimeout() != null) { - c = c.flatMap(connection -> connection.setStatementTimeout(configuration.getStatementTimeout()) - .thenReturn(connection)); - } - - return c; - }); + final MySqlConnectionConfiguration configuration, + final MySqlSslConfiguration ssl, + final LazyQueryCache queryCache, + final SocketAddress address, + final String user, + @Nullable final CharSequence password + ) { + return Mono.fromSupplier(() -> { + ZoneId connectionTimeZone = retrieveZoneId(configuration.getConnectionTimeZone()); + return new ConnectionContext( + configuration.getZeroDateOption(), + configuration.getLoadLocalInfilePath(), + configuration.getLocalInfileBufferSize(), + configuration.isPreserveInstants(), + connectionTimeZone + ); + }).flatMap(context -> Client.connect( + ssl, + address, + configuration.isTcpKeepAlive(), + configuration.isTcpNoDelay(), + context, + configuration.getConnectTimeout(), + configuration.getLoopResources() + )).flatMap(client -> { + // Lazy init database after handshake/login + boolean deferDatabase = configuration.isCreateDatabaseIfNotExist(); + String database = configuration.getDatabase(); + String loginDb = deferDatabase ? "" : database; + String sessionDb = deferDatabase ? database : ""; + + return InitFlow.initHandshake( + client, + ssl.getSslMode(), + loginDb, + user, + password, + configuration.getCompressionAlgorithms(), + configuration.getZstdCompressionLevel() + ).then(InitFlow.initSession( + client, + sessionDb, + configuration.getPrepareCacheSize(), + configuration.getSessionVariables(), + configuration.isForceConnectionTimeZoneToSession(), + configuration.getLockWaitTimeout(), + configuration.getStatementTimeout(), + configuration.getExtensions() + )).map(codecs -> new MySqlSimpleConnection( + client, + codecs, + queryCache.get(), + configuration.getPreferPrepareStatement() + )).onErrorResume(e -> client.forceClose().then(Mono.error(e))); + }); } @Nullable @@ -202,19 +192,7 @@ private static ZoneId retrieveZoneId(String timeZone) { return StringUtils.parseZoneId(timeZone); } - private static List mergeSessionVariables(List sessionVariables, ZoneId timeZone) { - List res = new ArrayList<>(sessionVariables.size() + 1); - - String offerStr = timeZone instanceof ZoneOffset && "Z".equalsIgnoreCase(timeZone.getId()) ? - "+00:00" : timeZone.getId(); - - res.addAll(sessionVariables); - res.add("time_zone='" + offerStr + "'"); - - return res; - } - - private static final class LazyQueryCache { + private static final class LazyQueryCache implements Supplier { private final int capacity; @@ -227,6 +205,7 @@ private LazyQueryCache(int capacity) { this.capacity = capacity; } + @Override public QueryCache get() { QueryCache cache = this.cache; if (cache == null) { diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnection.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnection.java index 660e25e06..ce5ba41e4 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnection.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnection.java @@ -19,16 +19,13 @@ import io.asyncer.r2dbc.mysql.api.MySqlBatch; import io.asyncer.r2dbc.mysql.api.MySqlConnection; import io.asyncer.r2dbc.mysql.api.MySqlConnectionMetadata; -import io.asyncer.r2dbc.mysql.api.MySqlResult; import io.asyncer.r2dbc.mysql.api.MySqlStatement; import io.asyncer.r2dbc.mysql.api.MySqlTransactionDefinition; -import io.asyncer.r2dbc.mysql.cache.PrepareCache; import io.asyncer.r2dbc.mysql.cache.QueryCache; import io.asyncer.r2dbc.mysql.client.Client; import io.asyncer.r2dbc.mysql.codec.Codecs; import io.asyncer.r2dbc.mysql.constant.ServerStatuses; import io.asyncer.r2dbc.mysql.internal.util.StringUtils; -import io.asyncer.r2dbc.mysql.message.client.InitDbMessage; import io.asyncer.r2dbc.mysql.message.client.PingMessage; import io.asyncer.r2dbc.mysql.message.server.CompleteMessage; import io.asyncer.r2dbc.mysql.message.server.ErrorMessage; @@ -38,18 +35,15 @@ import io.netty.util.internal.logging.InternalLoggerFactory; import io.r2dbc.spi.IsolationLevel; import io.r2dbc.spi.R2dbcNonTransientResourceException; -import io.r2dbc.spi.Readable; import io.r2dbc.spi.TransactionDefinition; import io.r2dbc.spi.ValidationDepth; import org.jetbrains.annotations.Nullable; +import org.jetbrains.annotations.TestOnly; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.SynchronousSink; -import java.time.DateTimeException; import java.time.Duration; -import java.time.ZoneId; -import java.util.List; import java.util.function.BiConsumer; import java.util.function.Function; import java.util.function.Predicate; @@ -60,24 +54,12 @@ /** * An implementation of {@link MySqlConnection} for connecting to the MySQL database. */ -final class MySqlSimpleConnection implements MySqlConnection, ConnectionState { +final class MySqlSimpleConnection implements MySqlConnection { private static final InternalLogger logger = InternalLoggerFactory.getInstance(MySqlSimpleConnection.class); private static final String PING_MARKER = "/* ping */"; - private static final ServerVersion MARIA_11_1_1 = ServerVersion.create(11, 1, 1, true); - - private static final ServerVersion MYSQL_8_0_3 = ServerVersion.create(8, 0, 3); - - private static final ServerVersion MYSQL_5_7_20 = ServerVersion.create(5, 7, 20); - - private static final ServerVersion MYSQL_8 = ServerVersion.create(8, 0, 0); - - private static final ServerVersion MYSQL_5_7_4 = ServerVersion.create(5, 7, 4); - - private static final ServerVersion MARIA_10_1_1 = ServerVersion.create(10, 1, 1, true); - private static final Function VALIDATE = message -> { if (message instanceof CompleteMessage && ((CompleteMessage) message).isDone()) { return true; @@ -106,87 +88,29 @@ final class MySqlSimpleConnection implements MySqlConnection, ConnectionState { } }; - private static final BiConsumer> INIT_DB = (message, sink) -> { - if (message instanceof ErrorMessage) { - ErrorMessage msg = (ErrorMessage) message; - logger.debug("Use database failed: [{}] [{}] {}", msg.getCode(), msg.getSqlState(), - msg.getMessage()); - sink.next(false); - sink.complete(); - } else if (message instanceof CompleteMessage && ((CompleteMessage) message).isDone()) { - sink.next(true); - sink.complete(); - } else { - ReferenceCountUtil.safeRelease(message); - } - }; - - private static final BiConsumer> INIT_DB_AFTER = (message, sink) -> { - if (message instanceof ErrorMessage) { - sink.error(((ErrorMessage) message).toException()); - } else if (message instanceof CompleteMessage && ((CompleteMessage) message).isDone()) { - sink.complete(); - } else { - ReferenceCountUtil.safeRelease(message); - } - }; - private final Client client; private final Codecs codecs; - private final boolean batchSupported; - private final MySqlConnectionMetadata metadata; - private volatile IsolationLevel sessionLevel; - private final QueryCache queryCache; - private final PrepareCache prepareCache; - @Nullable private final Predicate prepare; - /** - * Current isolation level inferred by past statements. - *

- * Inference rules: - *

  1. In the beginning, it is also {@link #sessionLevel}.
  2. - *
  3. After the user calls {@link #setTransactionIsolationLevel(IsolationLevel)}, it will change to - * the user-specified value.
  4. - *
  5. After the end of a transaction (commit or rollback), it will recover to {@link #sessionLevel}.
  6. - *
- */ - private volatile IsolationLevel currentLevel; - - /** - * Session lock wait timeout. - */ - private volatile long lockWaitTimeout; - - /** - * Current transaction lock wait timeout. - */ - private volatile long currentLockWaitTimeout; + // TODO: Check it when executing + private final boolean batchSupported; - MySqlSimpleConnection(Client client, Codecs codecs, IsolationLevel level, - long lockWaitTimeout, QueryCache queryCache, PrepareCache prepareCache, @Nullable String product, - @Nullable Predicate prepare) { + MySqlSimpleConnection(Client client, Codecs codecs, QueryCache queryCache, @Nullable Predicate prepare) { ConnectionContext context = client.getContext(); this.client = client; - this.sessionLevel = level; - this.currentLevel = level; this.codecs = codecs; - this.lockWaitTimeout = lockWaitTimeout; - this.currentLockWaitTimeout = lockWaitTimeout; + this.metadata = new MySqlClientConnectionMetadata(client); this.queryCache = queryCache; - this.prepareCache = prepareCache; - this.metadata = new MySqlSimpleConnectionMetadata(context.getServerVersion().toString(), product, - context.isMariaDb()); - this.batchSupported = context.getCapability().isMultiStatementsAllowed(); this.prepare = prepare; + this.batchSupported = context.getCapability().isMultiStatementsAllowed(); if (this.batchSupported) { logger.debug("Batch is supported by server"); @@ -202,7 +126,7 @@ public Mono beginTransaction() { @Override public Mono beginTransaction(TransactionDefinition definition) { - return Mono.defer(() -> QueryFlow.beginTransaction(client, this, batchSupported, definition)); + return Mono.defer(() -> QueryFlow.beginTransaction(client, batchSupported, definition)); } @Override @@ -219,7 +143,7 @@ public Mono close() { @Override public Mono commitTransaction() { - return Mono.defer(() -> QueryFlow.doneTransaction(client, this, true, batchSupported)); + return Mono.defer(() -> QueryFlow.doneTransaction(client, true, batchSupported)); } @Override @@ -231,7 +155,7 @@ public MySqlBatch createBatch() { public Mono createSavepoint(String name) { requireNonEmpty(name, "Savepoint name must not be empty"); - return QueryFlow.createSavepoint(client, this, name, batchSupported); + return QueryFlow.createSavepoint(client, name, batchSupported); } @Override @@ -247,7 +171,7 @@ public MySqlStatement createStatement(String sql) { if (query.isSimple()) { if (prepare != null && prepare.test(sql)) { logger.debug("Create a simple statement provided by prepare query"); - return new PrepareSimpleStatement(client, codecs, sql, prepareCache); + return new PrepareSimpleStatement(client, codecs, sql); } logger.debug("Create a simple statement provided by text query"); @@ -262,7 +186,7 @@ public MySqlStatement createStatement(String sql) { logger.debug("Create a parameterized statement provided by prepare query"); - return new PrepareParameterizedStatement(client, codecs, query, prepareCache); + return new PrepareParameterizedStatement(client, codecs, query); } @Override @@ -285,7 +209,7 @@ public Mono releaseSavepoint(String name) { @Override public Mono rollbackTransaction() { - return Mono.defer(() -> QueryFlow.doneTransaction(client, this, false, batchSupported)); + return Mono.defer(() -> QueryFlow.doneTransaction(client, false, batchSupported)); } @Override @@ -301,7 +225,7 @@ public MySqlConnectionMetadata getMetadata() { } /** - * MySQL does not have any way to query the isolation level of the current transaction, only inferred from past + * MySQL does not have a way to query the isolation level of the current transaction, only inferred from past * statements, so driver can not make sure the result is right. *

* See MySQL Bug 53341 @@ -310,16 +234,7 @@ public MySqlConnectionMetadata getMetadata() { */ @Override public IsolationLevel getTransactionIsolationLevel() { - return currentLevel; - } - - /** - * Gets session transaction isolation level(Only for testing). - * - * @return session transaction isolation level. - */ - IsolationLevel getSessionTransactionIsolationLevel() { - return sessionLevel; + return client.getContext().getCurrentIsolationLevel(); } @Override @@ -330,9 +245,11 @@ public Mono setTransactionIsolationLevel(IsolationLevel isolationLevel) { return QueryFlow.executeVoid(client, "SET SESSION TRANSACTION ISOLATION LEVEL " + isolationLevel.asSql()) .doOnSuccess(ignored -> { - this.sessionLevel = isolationLevel; - if (!this.isInTransaction()) { - this.currentLevel = isolationLevel; + ConnectionContext context = client.getContext(); + + context.setSessionIsolationLevel(isolationLevel); + if (!context.isInTransaction()) { + context.setCurrentIsolationLevel(isolationLevel); } }); } @@ -366,12 +283,13 @@ public Mono validate(ValidationDepth depth) { public boolean isAutoCommit() { // Within transaction, autocommit remains disabled until end the transaction with COMMIT or ROLLBACK. // The autocommit mode then reverts to its previous state. - return !isInTransaction() && isSessionAutoCommit(); + return !client.getContext().isInTransaction() && isSessionAutoCommit(); } @Override public Mono setAutoCommit(boolean autoCommit) { return Mono.defer(() -> { + // TODO: remove the check or checking when executing if (autoCommit == isSessionAutoCommit()) { return Mono.empty(); } @@ -380,321 +298,58 @@ public Mono setAutoCommit(boolean autoCommit) { }); } - @Override - public void setIsolationLevel(IsolationLevel level) { - this.currentLevel = level; - } - - @Override - public long getSessionLockWaitTimeout() { - return lockWaitTimeout; - } - - @Override - public void setCurrentLockWaitTimeout(long timeoutSeconds) { - this.currentLockWaitTimeout = timeoutSeconds; - } - - @Override - public void resetIsolationLevel() { - this.currentLevel = this.sessionLevel; - } - - @Override - public boolean isLockWaitTimeoutChanged() { - return currentLockWaitTimeout != lockWaitTimeout; - } - - @Override - public void resetCurrentLockWaitTimeout() { - this.currentLockWaitTimeout = this.lockWaitTimeout; - } - - @Override - public boolean isInTransaction() { - return (client.getContext().getServerStatuses() & ServerStatuses.IN_TRANSACTION) != 0; - } - @Override public Mono setLockWaitTimeout(Duration timeout) { requireNonNull(timeout, "timeout must not be null"); - if (!client.getContext().isLockWaitTimeoutSupported()) { - logger.warn("Lock wait timeout is not supported by server, setLockWaitTimeout operation is ignored"); - return Mono.empty(); + if (client.getContext().isLockWaitTimeoutSupported()) { + return QueryFlow.executeVoid(client, StringUtils.lockWaitTimeoutStatement(timeout)) + .doOnSuccess(ignored -> client.getContext().setAllLockWaitTimeout(timeout)); } - long timeoutSeconds = timeout.getSeconds(); - return QueryFlow.executeVoid(client, "SET innodb_lock_wait_timeout=" + timeoutSeconds) - .doOnSuccess(ignored -> this.lockWaitTimeout = this.currentLockWaitTimeout = timeoutSeconds); + logger.warn("Lock wait timeout is not supported by server, setLockWaitTimeout operation is ignored"); + return Mono.empty(); + } @Override public Mono setStatementTimeout(Duration timeout) { requireNonNull(timeout, "timeout must not be null"); - final ConnectionContext context = client.getContext(); - final boolean isMariaDb = context.isMariaDb(); - final ServerVersion serverVersion = context.getServerVersion(); - final long timeoutMs = timeout.toMillis(); - final String sql = isMariaDb ? "SET max_statement_time=" + timeoutMs / 1000.0 - : "SET SESSION MAX_EXECUTION_TIME=" + timeoutMs; + ConnectionContext context = client.getContext(); // mariadb: https://mariadb.com/kb/en/aborting-statements/ // mysql: https://dev.mysql.com/blog-archive/server-side-select-statement-timeouts/ // ref: https://github.com/mariadb-corporation/mariadb-connector-r2dbc - if (isMariaDb && serverVersion.isGreaterThanOrEqualTo(MARIA_10_1_1) - || !isMariaDb && serverVersion.isGreaterThanOrEqualTo(MYSQL_5_7_4)) { - return QueryFlow.executeVoid(client, sql); + if (context.isStatementTimeoutSupported()) { + String variable = StringUtils.statementTimeoutVariable(timeout, context.isMariaDb()); + return QueryFlow.setSessionVariable(client, variable); } return Mono.error( new R2dbcNonTransientResourceException( - "Statement timeout is not supported by server version " + serverVersion, + "Statement timeout is not supported by server version " + context.getServerVersion(), "HY000", - -1, - sql + -1 ) ); } - private boolean isSessionAutoCommit() { - return (client.getContext().getServerStatuses() & ServerStatuses.AUTO_COMMIT) != 0; - } - - static Flux doPingInternal(Client client) { - return client.exchange(PingMessage.INSTANCE, PING); - } - /** - * Initialize a {@link MySqlConnection} after login. + * Visible only for testing. * - * @param client must be logged-in. - * @param codecs the {@link Codecs}. - * @param database the database that should be lazy init. - * @param queryCache the cache of {@link Query}. - * @param prepareCache the cache of server-preparing result. - * @param sessionVariables the session variables to set. - * @param prepare judging for prefer use prepare statement to execute simple query. - * @return a {@link Mono} will emit an initialized {@link MySqlConnection}. + * @return current connection context */ - static Mono init( - Client client, Codecs codecs, String database, - QueryCache queryCache, PrepareCache prepareCache, - List sessionVariables, @Nullable Predicate prepare - ) { - Mono connection = initSessionVariables(client, sessionVariables) - .then(loadSessionVariables(client, codecs)) - .flatMap(data -> loadInnoDbEngineStatus(data, client, codecs)) - .map(data -> { - ConnectionContext context = client.getContext(); - ZoneId timeZone = data.timeZone; - if (timeZone != null) { - logger.debug("Got server time zone {} from loading session variables", timeZone); - context.initTimeZone(timeZone); - } - - if (data.lockWaitTimeoutSupported) { - context.enableLockWaitTimeoutSupported(); - } else { - logger.info("Lock wait timeout is not supported by server, all related operations will be ignored"); - } - - return new MySqlSimpleConnection(client, codecs, data.level, data.lockWaitTimeout, - queryCache, prepareCache, data.product, prepare); - }); - - if (database.isEmpty()) { - return connection; - } - - return connection.flatMap(c -> initDatabase(client, database).thenReturn(c)); - } - - private static Mono initSessionVariables(Client client, List sessionVariables) { - if (sessionVariables.isEmpty()) { - return Mono.empty(); - } - - StringBuilder query = new StringBuilder(sessionVariables.size() * 32 + 16).append("SET "); - boolean comma = false; - - for (String variable : sessionVariables) { - if (variable.isEmpty()) { - continue; - } - - if (comma) { - query.append(','); - } else { - comma = true; - } - - if (variable.startsWith("@")) { - query.append(variable); - } else { - query.append("SESSION ").append(variable); - } - } - - return QueryFlow.executeVoid(client, query.toString()); - } - - private static Mono loadSessionVariables(Client client, Codecs codecs) { - ConnectionContext context = client.getContext(); - StringBuilder query = new StringBuilder(128) - .append("SELECT ") - .append(transactionIsolationColumn(context)) - .append(",@@version_comment AS v"); - - Function> handler; - - if (context.isTimeZoneInitialized()) { - handler = r -> convertSessionData(r, false); - } else { - query.append(",@@system_time_zone AS s,@@time_zone AS t"); - handler = r -> convertSessionData(r, true); - } - - return new TextSimpleStatement(client, codecs, query.toString()) - .execute() - .flatMap(handler) - .last(); - } - - private static Mono loadInnoDbEngineStatus(SessionData data, Client client, Codecs codecs) { - return new TextSimpleStatement(client, codecs, "SHOW VARIABLES LIKE 'innodb\\\\_lock\\\\_wait\\\\_timeout'") - .execute() - .flatMap(r -> r.map(readable -> { - String value = readable.get(1, String.class); - - if (value == null || value.isEmpty()) { - return data; - } else { - return data.lockWaitTimeout(Long.parseLong(value)); - } - })) - .single(data); - } - - private static Mono initDatabase(Client client, String database) { - return client.exchange(new InitDbMessage(database), INIT_DB) - .last() - .flatMap(success -> { - if (success) { - return Mono.empty(); - } - - String sql = "CREATE DATABASE IF NOT EXISTS " + StringUtils.quoteIdentifier(database); - - return QueryFlow.executeVoid(client, sql) - .then(client.exchange(new InitDbMessage(database), INIT_DB_AFTER).then()); - }); - } - - private static Flux convertSessionData(MySqlResult r, boolean timeZone) { - return r.map(readable -> { - IsolationLevel level = convertIsolationLevel(readable.get(0, String.class)); - String product = readable.get(1, String.class); - - return new SessionData(level, product, timeZone ? readZoneId(readable) : null); - }); - } - - private static ZoneId readZoneId(Readable readable) { - String systemTimeZone = readable.get(2, String.class); - String timeZone = readable.get(3, String.class); - - if (timeZone == null || timeZone.isEmpty() || "SYSTEM".equalsIgnoreCase(timeZone)) { - if (systemTimeZone == null || systemTimeZone.isEmpty()) { - logger.warn("MySQL does not return any timezone, trying to use system default timezone"); - return ZoneId.systemDefault().normalized(); - } else { - return convertZoneId(systemTimeZone); - } - } else { - return convertZoneId(timeZone); - } - } - - private static ZoneId convertZoneId(String id) { - try { - return StringUtils.parseZoneId(id); - } catch (DateTimeException e) { - logger.warn("The server timezone is unknown <{}>, trying to use system default timezone", id, e); - - return ZoneId.systemDefault().normalized(); - } - } - - private static IsolationLevel convertIsolationLevel(@Nullable String name) { - if (name == null) { - logger.warn("Isolation level is null in current session, fallback to repeatable read"); - - return IsolationLevel.REPEATABLE_READ; - } - - switch (name) { - case "READ-UNCOMMITTED": - return IsolationLevel.READ_UNCOMMITTED; - case "READ-COMMITTED": - return IsolationLevel.READ_COMMITTED; - case "REPEATABLE-READ": - return IsolationLevel.REPEATABLE_READ; - case "SERIALIZABLE": - return IsolationLevel.SERIALIZABLE; - } - - logger.warn("Unknown isolation level {} in current session, fallback to repeatable read", name); - - return IsolationLevel.REPEATABLE_READ; + @TestOnly + ConnectionContext context() { + return client.getContext(); } - /** - * Resolves the column of session isolation level, the {@literal @@tx_isolation} has been marked as deprecated. - *

- * If server is MariaDB, {@literal @@transaction_isolation} is used starting from {@literal 11.1.1}. - *

- * If the server is MySQL, use {@literal @@transaction_isolation} starting from {@literal 8.0.3}, or between - * {@literal 5.7.20} and {@literal 8.0.0} (exclusive). - */ - private static String transactionIsolationColumn(ConnectionContext context) { - ServerVersion version = context.getServerVersion(); - - if (context.isMariaDb()) { - return version.isGreaterThanOrEqualTo(MARIA_11_1_1) ? "@@transaction_isolation AS i" : - "@@tx_isolation AS i"; - } - - return version.isGreaterThanOrEqualTo(MYSQL_8_0_3) || - (version.isGreaterThanOrEqualTo(MYSQL_5_7_20) && version.isLessThan(MYSQL_8)) ? - "@@transaction_isolation AS i" : "@@tx_isolation AS i"; + private boolean isSessionAutoCommit() { + return (client.getContext().getServerStatuses() & ServerStatuses.AUTO_COMMIT) != 0; } - private static final class SessionData { - - private final IsolationLevel level; - - @Nullable - private final String product; - - @Nullable - private final ZoneId timeZone; - - private long lockWaitTimeout = -1; - - private boolean lockWaitTimeoutSupported; - - private SessionData(IsolationLevel level, @Nullable String product, @Nullable ZoneId timeZone) { - this.level = level; - this.product = product; - this.timeZone = timeZone; - } - - SessionData lockWaitTimeout(long timeout) { - this.lockWaitTimeoutSupported = true; - this.lockWaitTimeout = timeout; - return this; - } + static Flux doPingInternal(Client client) { + return client.exchange(PingMessage.INSTANCE, PING); } } diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/PrepareParameterizedStatement.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/PrepareParameterizedStatement.java index d9e290811..44edd9509 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/PrepareParameterizedStatement.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/PrepareParameterizedStatement.java @@ -18,7 +18,6 @@ import io.asyncer.r2dbc.mysql.api.MySqlResult; import io.asyncer.r2dbc.mysql.api.MySqlStatement; -import io.asyncer.r2dbc.mysql.cache.PrepareCache; import io.asyncer.r2dbc.mysql.client.Client; import io.asyncer.r2dbc.mysql.codec.Codecs; import io.asyncer.r2dbc.mysql.internal.util.StringUtils; @@ -33,20 +32,17 @@ */ final class PrepareParameterizedStatement extends ParameterizedStatementSupport { - private final PrepareCache prepareCache; - private int fetchSize = 0; - PrepareParameterizedStatement(Client client, Codecs codecs, Query query, PrepareCache prepareCache) { + PrepareParameterizedStatement(Client client, Codecs codecs, Query query) { super(client, codecs, query); - this.prepareCache = prepareCache; } @Override public Flux execute(List bindings) { return Flux.defer(() -> QueryFlow.execute(client, StringUtils.extendReturning(query.getFormattedSql(), returningIdentifiers()), - bindings, fetchSize, prepareCache + bindings, fetchSize )) .map(messages -> MySqlSegmentResult.toResult(true, client, codecs, syntheticKeyName(), messages)); } diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatement.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatement.java index d78bb3488..7ff6b06f6 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatement.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatement.java @@ -36,19 +36,16 @@ final class PrepareSimpleStatement extends SimpleStatementSupport { private static final List BINDINGS = Collections.singletonList(new Binding(0)); - private final PrepareCache prepareCache; - private int fetchSize = 0; - PrepareSimpleStatement(Client client, Codecs codecs, String sql, PrepareCache prepareCache) { + PrepareSimpleStatement(Client client, Codecs codecs, String sql) { super(client, codecs, sql); - this.prepareCache = prepareCache; } @Override public Flux execute() { return Flux.defer(() -> QueryFlow.execute(client, - StringUtils.extendReturning(sql, returningIdentifiers()), BINDINGS, fetchSize, prepareCache)) + StringUtils.extendReturning(sql, returningIdentifiers()), BINDINGS, fetchSize)) .map(messages -> MySqlSegmentResult.toResult(true, client, codecs, syntheticKeyName(), messages)); } diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java index e7a5de4bc..23ce5e806 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java @@ -18,19 +18,12 @@ import io.asyncer.r2dbc.mysql.api.MySqlBatch; import io.asyncer.r2dbc.mysql.api.MySqlTransactionDefinition; -import io.asyncer.r2dbc.mysql.authentication.MySqlAuthProvider; -import io.asyncer.r2dbc.mysql.cache.PrepareCache; import io.asyncer.r2dbc.mysql.client.Client; import io.asyncer.r2dbc.mysql.client.FluxExchangeable; -import io.asyncer.r2dbc.mysql.constant.CompressionAlgorithm; import io.asyncer.r2dbc.mysql.constant.ServerStatuses; -import io.asyncer.r2dbc.mysql.constant.SslMode; import io.asyncer.r2dbc.mysql.internal.util.StringUtils; -import io.asyncer.r2dbc.mysql.message.client.AuthResponse; import io.asyncer.r2dbc.mysql.message.client.ClientMessage; -import io.asyncer.r2dbc.mysql.message.client.HandshakeResponse; import io.asyncer.r2dbc.mysql.message.client.LocalInfileResponse; -import io.asyncer.r2dbc.mysql.message.client.SubsequenceClientMessage; import io.asyncer.r2dbc.mysql.message.client.PingMessage; import io.asyncer.r2dbc.mysql.message.client.PrepareQueryMessage; import io.asyncer.r2dbc.mysql.message.client.PreparedCloseMessage; @@ -38,29 +31,21 @@ import io.asyncer.r2dbc.mysql.message.client.PreparedFetchMessage; import io.asyncer.r2dbc.mysql.message.client.PreparedResetMessage; import io.asyncer.r2dbc.mysql.message.client.PreparedTextQueryMessage; -import io.asyncer.r2dbc.mysql.message.client.SslRequest; import io.asyncer.r2dbc.mysql.message.client.TextQueryMessage; -import io.asyncer.r2dbc.mysql.message.server.AuthMoreDataMessage; -import io.asyncer.r2dbc.mysql.message.server.ChangeAuthMessage; import io.asyncer.r2dbc.mysql.message.server.CompleteMessage; import io.asyncer.r2dbc.mysql.message.server.EofMessage; import io.asyncer.r2dbc.mysql.message.server.ErrorMessage; -import io.asyncer.r2dbc.mysql.message.server.HandshakeHeader; -import io.asyncer.r2dbc.mysql.message.server.HandshakeRequest; import io.asyncer.r2dbc.mysql.message.server.LocalInfileRequest; import io.asyncer.r2dbc.mysql.message.server.OkMessage; import io.asyncer.r2dbc.mysql.message.server.PreparedOkMessage; import io.asyncer.r2dbc.mysql.message.server.ServerMessage; import io.asyncer.r2dbc.mysql.message.server.ServerStatusMessage; import io.asyncer.r2dbc.mysql.message.server.SyntheticMetadataMessage; -import io.asyncer.r2dbc.mysql.message.server.SyntheticSslResponseMessage; import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCounted; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; import io.r2dbc.spi.IsolationLevel; -import io.r2dbc.spi.R2dbcNonTransientResourceException; -import io.r2dbc.spi.R2dbcPermissionDeniedException; import io.r2dbc.spi.TransactionDefinition; import org.jetbrains.annotations.Nullable; import reactor.core.CoreSubscriber; @@ -72,15 +57,10 @@ import reactor.core.publisher.SynchronousSink; import reactor.util.concurrent.Queues; -import java.security.AccessController; -import java.security.PrivilegedAction; import java.time.Duration; import java.util.ArrayList; -import java.util.Collections; import java.util.Iterator; import java.util.List; -import java.util.Map; -import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; @@ -116,18 +96,16 @@ final class QueryFlow { * @param sql the statement for exception tracing. * @param bindings the data of bindings. * @param fetchSize the size of fetching, if it less than or equal to {@literal 0} means fetch all rows. - * @param cache the cache of server-preparing result. * @return the messages received in response to this exchange. */ - static Flux> execute(Client client, String sql, List bindings, int fetchSize, - PrepareCache cache) { + static Flux> execute(Client client, String sql, List bindings, int fetchSize) { return Flux.defer(() -> { if (bindings.isEmpty()) { return Flux.empty(); } // Note: the prepared SQL may not be sent when the cache matches. - return client.exchange(new PrepareExchangeable(cache, sql, bindings.iterator(), fetchSize)) + return client.exchange(new PrepareExchangeable(client, sql, bindings.iterator(), fetchSize)) .windowUntil(RESULT_DONE); }); } @@ -194,29 +172,6 @@ static Flux> execute(Client client, List statements) }); } - /** - * Login a {@link Client} and receive the {@code client} after logon. It will emit an exception when client receives - * a {@link ErrorMessage}. - * - * @param client the {@link Client} to exchange messages with. - * @param sslMode the {@link SslMode} defines SSL capability and behavior. - * @param database the database that will be connected. - * @param user the user that will be login. - * @param password the password of the {@code user}. - * @param compressionAlgorithms the list of compression algorithms. - * @param zstdCompressionLevel the zstd compression level. - * @param context the {@link ConnectionContext} for initialization. - * @return the messages received in response to the login exchange. - */ - static Mono login(Client client, SslMode sslMode, String database, String user, - @Nullable CharSequence password, - Set compressionAlgorithms, int zstdCompressionLevel) { - return client.exchange(new LoginExchangeable(client, sslMode, database, user, password, - compressionAlgorithms, zstdCompressionLevel)) - .onErrorResume(e -> client.forceClose().then(Mono.error(e))) - .then(Mono.just(client)); - } - /** * Execute a simple query and return a {@link Mono} for the complete signal or error. Query execution terminates * with the last {@link CompleteMessage} or a {@link ErrorMessage}. The {@link ErrorMessage} will emit an exception. @@ -245,17 +200,15 @@ static Mono executeVoid(Client client, String sql) { /** * Begins a new transaction with a {@link TransactionDefinition}. It will change current transaction statuses of - * the {@link ConnectionState}. + * the {@link ConnectionContext}. * * @param client the {@link Client} to exchange messages with. - * @param state the connection state for checks and sets transaction statuses. * @param batchSupported if connection supports batch query. * @param definition the {@link TransactionDefinition}. * @return receives complete signal. */ - static Mono beginTransaction(Client client, ConnectionState state, boolean batchSupported, - TransactionDefinition definition) { - final StartTransactionState startState = new StartTransactionState(state, definition, client); + static Mono beginTransaction(Client client, boolean batchSupported, TransactionDefinition definition) { + final StartTransactionState startState = new StartTransactionState(client, definition); if (batchSupported) { return client.exchange(new TransactionBatchExchangeable(startState)).then(); @@ -265,18 +218,15 @@ static Mono beginTransaction(Client client, ConnectionState state, boolean } /** - * Commits or rollbacks current transaction. It will recover statuses of the {@link ConnectionState} in the initial - * connection state. + * Commits or rollbacks current transaction. It will recover statuses of the {@link ConnectionContext}. * * @param client the {@link Client} to exchange messages with. - * @param state the connection state for checks and resets transaction statuses. * @param commit if it is commit, otherwise rollback. * @param batchSupported if connection supports batch query. * @return receives complete signal. */ - static Mono doneTransaction(Client client, ConnectionState state, boolean commit, - boolean batchSupported) { - final CommitRollbackState commitState = new CommitRollbackState(state, commit); + static Mono doneTransaction(Client client, boolean commit, boolean batchSupported) { + final CommitRollbackState commitState = new CommitRollbackState(client, commit); if (batchSupported) { return client.exchange(new TransactionBatchExchangeable(commitState)).then(); @@ -285,15 +235,80 @@ static Mono doneTransaction(Client client, ConnectionState state, boolean return client.exchange(new TransactionMultiExchangeable(commitState)).then(); } - static Mono createSavepoint(Client client, ConnectionState state, String name, - boolean batchSupported) { - final CreateSavepointState savepointState = new CreateSavepointState(state, name); + /** + * Creates a savepoint with a name. It will begin a new transaction before creating a savepoint if the connection is + * not in a transaction. + * + * @param client the {@link Client} to exchange messages with. + * @param name the name of the savepoint. + * @param batchSupported if connection supports batch query. + * @return a {@link Mono} receives complete signal. + */ + static Mono createSavepoint(Client client, String name, boolean batchSupported) { + final CreateSavepointState savepointState = new CreateSavepointState(client, name); if (batchSupported) { return client.exchange(new TransactionBatchExchangeable(savepointState)).then(); } return client.exchange(new TransactionMultiExchangeable(savepointState)).then(); } + /** + * Sets a session variable to the server. + * + * @param client the {@link Client} to exchange messages with. + * @param variable the session variable to set, e.g. {@code "sql_mode='ANSI'"}. + * @return a {@link Mono} receives complete signal. + */ + static Mono setSessionVariable(Client client, String variable) { + if (variable.isEmpty()) { + return Mono.empty(); + } else if (variable.startsWith("@")) { + return executeVoid(client, "SET " + variable); + } + + return executeVoid(client, "SET SESSION " + variable); + } + + /** + * Sets multiple session variables to the server. + * + * @param client the {@link Client} to exchange messages with. + * @param sessionVariables the session variables to set, e.g. {@code ["sql_mode='ANSI'", "time_zone='+09:00'"]}. + * @return a {@link Mono} receives complete signal. + */ + static Mono setSessionVariables(Client client, List sessionVariables) { + switch (sessionVariables.size()) { + case 0: + return Mono.empty(); + case 1: + return setSessionVariable(client, sessionVariables.get(0)); + default: { + StringBuilder query = new StringBuilder(sessionVariables.size() * 32 + 16).append("SET "); + boolean comma = false; + + for (String variable : sessionVariables) { + if (variable.isEmpty()) { + continue; + } + + if (comma) { + query.append(','); + } else { + comma = true; + } + + if (variable.startsWith("@")) { + query.append(variable); + } else { + query.append("SESSION ").append(variable); + } + } + + return executeVoid(client, query.toString()); + } + } + } + /** * Execute a simple query statement. Query execution terminates with the last {@link CompleteMessage} or a * {@link ErrorMessage}. The {@link ErrorMessage} will emit an exception. The exchange will be completed by @@ -544,7 +559,7 @@ final class PrepareExchangeable extends FluxExchangeable { private final Sinks.Many requests = Sinks.many().unicast() .onBackpressureBuffer(Queues.one().get()); - private final PrepareCache cache; + private final Client client; private final String sql; @@ -559,8 +574,8 @@ final class PrepareExchangeable extends FluxExchangeable { private boolean shouldClose; - PrepareExchangeable(PrepareCache cache, String sql, Iterator bindings, int fetchSize) { - this.cache = cache; + PrepareExchangeable(Client client, String sql, Iterator bindings, int fetchSize) { + this.client = client; this.sql = sql; this.bindings = bindings; this.fetchSize = fetchSize; @@ -572,7 +587,7 @@ public void subscribe(CoreSubscriber actual) { requests.asFlux().subscribe(actual); // After subscribe. - Integer statementId = cache.getIfPresent(sql); + Integer statementId = client.getContext().getPrepareCache().getIfPresent(sql); if (statementId == null) { logger.debug("Prepare cache mismatch, try to preparing"); this.shouldClose = true; @@ -713,7 +728,7 @@ private void putToCache(Integer statementId) { boolean putSucceed; try { - putSucceed = cache.putIfAbsent(sql, statementId, evictId -> { + putSucceed = client.getContext().getPrepareCache().putIfAbsent(sql, statementId, evictId -> { logger.debug("Prepare cache evicts statement {} when putting", evictId); Sinks.EmitResult result = requests.tryEmitNext(new PreparedCloseMessage(evictId)); @@ -809,292 +824,9 @@ private void onCompleteMessage(CompleteMessage message, SynchronousSink - * Not like other {@link FluxExchangeable}s, it is started by a server-side message, which should be an implementation - * of {@link HandshakeRequest}. - */ -final class LoginExchangeable extends FluxExchangeable { - - private static final InternalLogger logger = InternalLoggerFactory.getInstance(LoginExchangeable.class); - - private static final Map ATTRIBUTES = Collections.emptyMap(); - - private static final String CLI_SPECIFIC = "HY000"; - - private static final int HANDSHAKE_VERSION = 10; - - private final Sinks.Many requests = Sinks.many().unicast() - .onBackpressureBuffer(Queues.one().get()); - - private final Client client; - - private final SslMode sslMode; - - private final String database; - - private final String user; - - @Nullable - private final CharSequence password; - - private final Set compressions; - - private final int zstdCompressionLevel; - - private boolean handshake = true; - - private MySqlAuthProvider authProvider; - - private byte[] salt; - - private boolean sslCompleted; - - LoginExchangeable(Client client, SslMode sslMode, String database, String user, - @Nullable CharSequence password, Set compressions, - int zstdCompressionLevel) { - this.client = client; - this.sslMode = sslMode; - this.database = database; - this.user = user; - this.password = password; - this.compressions = compressions; - this.zstdCompressionLevel = zstdCompressionLevel; - this.sslCompleted = sslMode == SslMode.TUNNEL; - } - - @Override - public void subscribe(CoreSubscriber actual) { - requests.asFlux().subscribe(actual); - } - - @Override - public void accept(ServerMessage message, SynchronousSink sink) { - if (message instanceof ErrorMessage) { - sink.error(((ErrorMessage) message).toException()); - return; - } - - // Ensures it will be initialized only once. - if (handshake) { - handshake = false; - if (message instanceof HandshakeRequest) { - HandshakeRequest request = (HandshakeRequest) message; - Capability capability = initHandshake(request); - - if (capability.isSslEnabled()) { - emitNext(SslRequest.from(capability, client.getContext().getClientCollation().getId()), sink); - } else { - emitNext(createHandshakeResponse(capability), sink); - } - } else { - sink.error(new R2dbcPermissionDeniedException("Unexpected message type '" + - message.getClass().getSimpleName() + "' in init phase")); - } - - return; - } - - if (message instanceof OkMessage) { - client.loginSuccess(); - sink.complete(); - } else if (message instanceof SyntheticSslResponseMessage) { - sslCompleted = true; - emitNext(createHandshakeResponse(client.getContext().getCapability()), sink); - } else if (message instanceof AuthMoreDataMessage) { - AuthMoreDataMessage msg = (AuthMoreDataMessage) message; - - if (msg.isFailed()) { - if (logger.isDebugEnabled()) { - logger.debug("Connection (id {}) fast authentication failed, use full authentication", - client.getContext().getConnectionId()); - } - - emitNext(createAuthResponse("full authentication"), sink); - } - // Otherwise success, wait until OK message or Error message. - } else if (message instanceof ChangeAuthMessage) { - ChangeAuthMessage msg = (ChangeAuthMessage) message; - - authProvider = MySqlAuthProvider.build(msg.getAuthType()); - salt = msg.getSalt(); - emitNext(createAuthResponse("change authentication"), sink); - } else { - sink.error(new R2dbcPermissionDeniedException("Unexpected message type '" + - message.getClass().getSimpleName() + "' in login phase")); - } - } - - @Override - public void dispose() { - // No particular error condition handling for complete signal. - this.requests.tryEmitComplete(); - } - - private void emitNext(SubsequenceClientMessage message, SynchronousSink sink) { - Sinks.EmitResult result = requests.tryEmitNext(message); - - if (result != Sinks.EmitResult.OK) { - sink.error(new IllegalStateException("Fail to emit a login request due to " + result)); - } - } - - private AuthResponse createAuthResponse(String phase) { - MySqlAuthProvider authProvider = getAndNextProvider(); - - if (authProvider.isSslNecessary() && !sslCompleted) { - throw new R2dbcPermissionDeniedException(authFails(authProvider.getType(), phase), CLI_SPECIFIC); - } - - return new AuthResponse(authProvider.authentication(password, salt, client.getContext().getClientCollation())); - } - - private Capability clientCapability(Capability serverCapability) { - Capability.Builder builder = serverCapability.mutate(); - - builder.disableSessionTrack(); - builder.disableDatabasePinned(); - builder.disableIgnoreAmbiguitySpace(); - builder.disableInteractiveTimeout(); - - if (sslMode == SslMode.TUNNEL) { - // Tunnel does not use MySQL SSL protocol, disable it. - builder.disableSsl(); - } else if (!serverCapability.isSslEnabled()) { - // Server unsupported SSL. - if (sslMode.requireSsl()) { - // Before handshake, Client.context does not be initialized - throw new R2dbcPermissionDeniedException("Server does not support SSL but mode '" + sslMode + - "' requires SSL", CLI_SPECIFIC); - } else if (sslMode.startSsl()) { - // SSL has start yet, and client can disable SSL, disable now. - client.sslUnsupported(); - } - } else { - // The server supports SSL, but the user does not want to use SSL, disable it. - if (!sslMode.startSsl()) { - builder.disableSsl(); - } - } - - if (isZstdAllowed(serverCapability)) { - if (isZstdSupported()) { - builder.disableZlibCompression(); - } else { - logger.warn("Server supports zstd, but zstd-jni dependency is missing"); - - if (isZlibAllowed(serverCapability)) { - builder.disableZstdCompression(); - } else if (compressions.contains(CompressionAlgorithm.UNCOMPRESSED)) { - builder.disableCompression(); - } else { - throw new R2dbcNonTransientResourceException( - "Environment does not support a compression algorithm in " + compressions + - ", config does not allow uncompressed mode", CLI_SPECIFIC); - } - } - } else if (isZlibAllowed(serverCapability)) { - builder.disableZstdCompression(); - } else if (compressions.contains(CompressionAlgorithm.UNCOMPRESSED)) { - builder.disableCompression(); - } else { - throw new R2dbcPermissionDeniedException( - "Environment does not support a compression algorithm in " + compressions + - ", config does not allow uncompressed mode", CLI_SPECIFIC); - } - - if (database.isEmpty()) { - builder.disableConnectWithDatabase(); - } - - if (client.getContext().getLocalInfilePath() == null) { - builder.disableLoadDataLocalInfile(); - } - - if (ATTRIBUTES.isEmpty()) { - builder.disableConnectAttributes(); - } - - return builder.build(); - } - - private Capability initHandshake(HandshakeRequest message) { - HandshakeHeader header = message.getHeader(); - int handshakeVersion = header.getProtocolVersion(); - ServerVersion serverVersion = header.getServerVersion(); - - if (handshakeVersion < HANDSHAKE_VERSION) { - logger.warn("MySQL use handshake V{}, server version is {}, maybe most features are unavailable", - handshakeVersion, serverVersion); - } - - Capability capability = clientCapability(message.getServerCapability()); - - // No need initialize server statuses because it has initialized by read filter. - this.client.getContext().init(header.getConnectionId(), serverVersion, capability); - this.authProvider = MySqlAuthProvider.build(message.getAuthType()); - this.salt = message.getSalt(); - - return capability; - } - - private MySqlAuthProvider getAndNextProvider() { - MySqlAuthProvider authProvider = this.authProvider; - this.authProvider = authProvider.next(); - return authProvider; - } - - private HandshakeResponse createHandshakeResponse(Capability capability) { - MySqlAuthProvider authProvider = getAndNextProvider(); - - if (authProvider.isSslNecessary() && !sslCompleted) { - throw new R2dbcPermissionDeniedException(authFails(authProvider.getType(), "handshake"), - CLI_SPECIFIC); - } - - byte[] authorization = authProvider.authentication(password, salt, client.getContext().getClientCollation()); - String authType = authProvider.getType(); - - if (MySqlAuthProvider.NO_AUTH_PROVIDER.equals(authType)) { - // Authentication type is not matter because of it has no authentication type. - // Server need send a Change Authentication Message after handshake response. - authType = MySqlAuthProvider.CACHING_SHA2_PASSWORD; - } - - return HandshakeResponse.from(capability, client.getContext().getClientCollation().getId(), user, authorization, - authType, database, ATTRIBUTES, zstdCompressionLevel); - } - - private boolean isZstdAllowed(Capability capability) { - return capability.isZstdCompression() && compressions.contains(CompressionAlgorithm.ZSTD); - } - - private boolean isZlibAllowed(Capability capability) { - return capability.isZlibCompression() && compressions.contains(CompressionAlgorithm.ZLIB); - } - - private static String authFails(String authType, String phase) { - return "Authentication type '" + authType + "' must require SSL in " + phase + " phase"; - } - - private static boolean isZstdSupported() { - try { - ClassLoader loader = AccessController.doPrivileged((PrivilegedAction) () -> { - ClassLoader cl = Thread.currentThread().getContextClassLoader(); - return cl == null ? ClassLoader.getSystemClassLoader() : cl; - }); - Class.forName("com.github.luben.zstd.Zstd", false, loader); - return true; - } catch (ClassNotFoundException e) { - return false; - } - } -} - abstract class AbstractTransactionState { - final ConnectionState state; + final Client client; final List statements = new ArrayList<>(5); @@ -1106,8 +838,8 @@ abstract class AbstractTransactionState { @Nullable private String sql; - protected AbstractTransactionState(ConnectionState state) { - this.state = state; + protected AbstractTransactionState(Client client) { + this.client = client; } final void setSql(String sql) { @@ -1165,22 +897,24 @@ final class CommitRollbackState extends AbstractTransactionState { private final boolean commit; - CommitRollbackState(ConnectionState state, boolean commit) { - super(state); + CommitRollbackState(Client client, boolean commit) { + super(client); this.commit = commit; } @Override boolean cancelTasks() { - if (!state.isInTransaction()) { + ConnectionContext context = client.getContext(); + + if (!context.isInTransaction()) { tasks |= CANCEL; return true; } - if (state.isLockWaitTimeoutChanged()) { + if (context.isLockWaitTimeoutChanged()) { // If server does not support lock wait timeout, the state will not be changed, so it is safe. tasks |= LOCK_WAIT_TIMEOUT; - statements.add("SET innodb_lock_wait_timeout=" + state.getSessionLockWaitTimeout()); + statements.add(StringUtils.lockWaitTimeoutStatement(context.getSessionLockWaitTimeout())); } tasks |= COMMIT_OR_ROLLBACK; @@ -1193,10 +927,10 @@ boolean cancelTasks() { protected boolean process(int task, SynchronousSink sink) { switch (task) { case LOCK_WAIT_TIMEOUT: - state.resetCurrentLockWaitTimeout(); + client.getContext().resetCurrentLockWaitTimeout(); return true; case COMMIT_OR_ROLLBACK: - state.resetIsolationLevel(); + client.getContext().resetCurrentIsolationLevel(); sink.complete(); return false; case CANCEL: @@ -1222,26 +956,24 @@ final class StartTransactionState extends AbstractTransactionState { private final TransactionDefinition definition; - private final Client client; - - StartTransactionState(ConnectionState state, TransactionDefinition definition, Client client) { - super(state); + StartTransactionState(Client client, TransactionDefinition definition) { + super(client); this.definition = definition; - this.client = client; } @Override boolean cancelTasks() { - if (state.isInTransaction()) { + final ConnectionContext context = client.getContext(); + if (context.isInTransaction()) { tasks |= CANCEL; return true; } + final Duration timeout = definition.getAttribute(TransactionDefinition.LOCK_WAIT_TIMEOUT); if (timeout != null) { - if (client.getContext().isLockWaitTimeoutSupported()) { - long lockWaitTimeout = timeout.getSeconds(); + if (context.isLockWaitTimeoutSupported()) { tasks |= LOCK_WAIT_TIMEOUT; - statements.add("SET innodb_lock_wait_timeout=" + lockWaitTimeout); + statements.add(StringUtils.lockWaitTimeoutStatement(timeout)); } else { QueryFlow.logger.warn( "Lock wait timeout is not supported by server, transaction definition lockWaitTimeout is ignored"); @@ -1267,22 +999,19 @@ protected boolean process(int task, SynchronousSink sink) { case LOCK_WAIT_TIMEOUT: final Duration timeout = definition.getAttribute(TransactionDefinition.LOCK_WAIT_TIMEOUT); if (timeout != null) { - final long lockWaitTimeout = timeout.getSeconds(); - state.setCurrentLockWaitTimeout(lockWaitTimeout); + client.getContext().setCurrentLockWaitTimeout(timeout); } return true; case ISOLATION_LEVEL: - final IsolationLevel isolationLevel = - definition.getAttribute(TransactionDefinition.ISOLATION_LEVEL); + final IsolationLevel isolationLevel = definition.getAttribute(TransactionDefinition.ISOLATION_LEVEL); if (isolationLevel != null) { - state.setIsolationLevel(isolationLevel); + client.getContext().setCurrentIsolationLevel(isolationLevel); } return true; case START_TRANSACTION: case CANCEL: sink.complete(); return false; - } sink.error(new IllegalStateException("Undefined transaction task: " + task + ", remain: " + tasks)); @@ -1352,14 +1081,14 @@ final class CreateSavepointState extends AbstractTransactionState { private final String name; - CreateSavepointState(final ConnectionState state, final String name) { - super(state); + CreateSavepointState(final Client client, final String name) { + super(client); this.name = name; } @Override boolean cancelTasks() { - if (!state.isInTransaction()) { + if (!client.getContext().isInTransaction()) { tasks |= START_TRANSACTION; statements.add("BEGIN"); } diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/internal/util/StringUtils.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/internal/util/StringUtils.java index e5c3596b6..1a96e2d79 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/internal/util/StringUtils.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/internal/util/StringUtils.java @@ -16,13 +16,14 @@ package io.asyncer.r2dbc.mysql.internal.util; +import java.time.Duration; import java.time.ZoneId; import java.time.ZoneOffset; import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonEmpty; /** - * A utility for processing {@link String} in MySQL/MariaDB. + * A utility for processing {@link String} and simple statements in MySQL/MariaDB. */ public final class StringUtils { @@ -79,16 +80,48 @@ public static String extendReturning(String sql, String returning) { return returning.isEmpty() ? sql : sql + " RETURNING " + returning; } + /** + * Generates a {@link String} indicating the statement timeout variable. e.g. {@code "max_statement_time=1.5"} for + * MariaDB or {@code "max_execution_time=1500"} for MySQL. + * + * @param timeout the statement timeout + * @param isMariaDb whether the current server is MariaDB + * @return the statement timeout variable + */ + public static String statementTimeoutVariable(Duration timeout, boolean isMariaDb) { + // mariadb: https://mariadb.com/kb/en/aborting-statements/ + // mysql: https://dev.mysql.com/blog-archive/server-side-select-statement-timeouts/ + // ref: https://github.com/mariadb-corporation/mariadb-connector-r2dbc + if (isMariaDb) { + // MariaDB supports fractional seconds with microsecond precision + double seconds = (timeout.getSeconds() + timeout.getNano() / 1_000_000_000.0); + return "max_statement_time=" + seconds; + } + + return "max_execution_time=" + timeout.toMillis(); + } + + /** + * Generates a statement to set the lock wait timeout for the current session. It is using InnoDB-specific session + * variable {@code innodb_lock_wait_timeout}. + * + * @param timeout the lock wait timeout + * @return the lock wait timeout statement + */ + public static String lockWaitTimeoutStatement(Duration timeout) { + return "SET innodb_lock_wait_timeout=" + timeout.getSeconds(); + } + /** * Parses a normalized {@link ZoneId} from a time zone string of MySQL. *

- * Note: since java 14.0.2, 11.0.8, 8u261 and 7u271, America/Nuuk is already renamed from America/Godthab. - * See also tzdata2020a + * Note: since java 14.0.2, 11.0.8, 8u261 and 7u271, America/Nuuk is already renamed from America/Godthab. See also + * tzdata2020a * * @param zoneId the time zone string * @return the normalized {@link ZoneId} - * @throws IllegalArgumentException if the time zone string is {@code null} or empty - * @throws java.time.DateTimeException if the time zone string has an invalid format + * @throws IllegalArgumentException if the time zone string is {@code null} or empty + * @throws java.time.DateTimeException if the time zone string has an invalid format * @throws java.time.zone.ZoneRulesException if the time zone string cannot be found */ public static ZoneId parseZoneId(String zoneId) { diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/ConnectionContextTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/ConnectionContextTest.java index 5e2be6114..5d0635412 100644 --- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/ConnectionContextTest.java +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/ConnectionContextTest.java @@ -16,10 +16,13 @@ package io.asyncer.r2dbc.mysql; +import io.asyncer.r2dbc.mysql.cache.Caches; import io.asyncer.r2dbc.mysql.constant.ServerStatuses; import io.asyncer.r2dbc.mysql.constant.ZeroDateOption; +import io.r2dbc.spi.IsolationLevel; import org.junit.jupiter.api.Test; +import java.time.Duration; import java.time.ZoneId; import static org.assertj.core.api.Assertions.assertThat; @@ -46,15 +49,36 @@ void getTimeZone() { void setTwiceTimeZone() { ConnectionContext context = new ConnectionContext(ZeroDateOption.USE_NULL, null, 8192, true, null); - context.initTimeZone(ZoneId.systemDefault()); - assertThatIllegalStateException().isThrownBy(() -> context.initTimeZone(ZoneId.systemDefault())); + + context.initSession( + Caches.createPrepareCache(0), + IsolationLevel.REPEATABLE_READ, + false, Duration.ZERO, + null, + ZoneId.systemDefault() + ); + assertThatIllegalStateException().isThrownBy(() -> context.initSession( + Caches.createPrepareCache(0), + IsolationLevel.REPEATABLE_READ, + false, + Duration.ZERO, + null, + ZoneId.systemDefault() + )); } @Test void badSetTimeZone() { ConnectionContext context = new ConnectionContext(ZeroDateOption.USE_NULL, null, 8192, true, ZoneId.systemDefault()); - assertThatIllegalStateException().isThrownBy(() -> context.initTimeZone(ZoneId.systemDefault())); + assertThatIllegalStateException().isThrownBy(() -> context.initSession( + Caches.createPrepareCache(0), + IsolationLevel.REPEATABLE_READ, + false, + Duration.ZERO, + null, + ZoneId.systemDefault() + )); } public static ConnectionContext mock() { @@ -69,7 +93,7 @@ public static ConnectionContext mock(boolean isMariaDB, ZoneId zoneId) { ConnectionContext context = new ConnectionContext(ZeroDateOption.USE_NULL, null, 8192, true, zoneId); - context.init(1, ServerVersion.parse(isMariaDB ? "11.2.22.MOCKED" : "8.0.11.MOCKED"), + context.initHandshake(1, ServerVersion.parse(isMariaDB ? "11.2.22.MOCKED" : "8.0.11.MOCKED"), Capability.of(~(isMariaDB ? 1 : 0))); context.setServerStatuses(ServerStatuses.AUTO_COMMIT); diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java index b45d7f91c..8fa06f1f9 100644 --- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java @@ -68,16 +68,16 @@ class ConnectionIntegrationTest extends IntegrationTestSupport { @Test void isInTransaction() { - castedComplete(connection -> Mono.fromRunnable(() -> assertThat(connection.isInTransaction()) + castedComplete(connection -> Mono.fromRunnable(() -> assertThat(connection.context().isInTransaction()) .isFalse()) .then(connection.beginTransaction()) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue()) + .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue()) .then(connection.commitTransaction()) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isFalse()) + .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isFalse()) .then(connection.beginTransaction()) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue()) + .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue()) .then(connection.rollbackTransaction()) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isFalse())); + .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isFalse())); } @DisabledIf("envIsLessThanMySql56") @@ -88,16 +88,16 @@ void startTransaction() { TransactionDefinition readWriteConsistent = MySqlTransactionDefinition.mutability(true) .consistent(); - castedComplete(connection -> Mono.fromRunnable(() -> assertThat(connection.isInTransaction()) + castedComplete(connection -> Mono.fromRunnable(() -> assertThat(connection.context().isInTransaction()) .isFalse()) .then(connection.beginTransaction(readOnlyConsistent)) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue()) + .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue()) .then(connection.rollbackTransaction()) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isFalse()) + .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isFalse()) .then(connection.beginTransaction(readWriteConsistent)) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue()) + .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue()) .then(connection.rollbackTransaction()) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isFalse())); + .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isFalse())); } @Test @@ -115,9 +115,9 @@ void autoRollbackPreRelease() { .flatMap(MySqlResult::getRowsUpdated) .single() .doOnNext(it -> assertThat(it).isEqualTo(1)) - .doOnSuccess(ignored -> assertThat(conn.isInTransaction()).isTrue()) + .doOnSuccess(ignored -> assertThat(conn.context().isInTransaction()).isTrue()) .then(conn.preRelease()) - .doOnSuccess(ignored -> assertThat(conn.isInTransaction()).isFalse()) + .doOnSuccess(ignored -> assertThat(conn.context().isInTransaction()).isFalse()) .then(conn.postAllocate()) .thenMany(conn.createStatement("SELECT * FROM test") .execute()) @@ -143,7 +143,7 @@ void shouldNotRollbackCommittedPreRelease() { .doOnNext(it -> assertThat(it).isEqualTo(1)) .then(conn.commitTransaction()) .then(conn.preRelease()) - .doOnSuccess(ignored -> assertThat(conn.isInTransaction()).isFalse()) + .doOnSuccess(ignored -> assertThat(conn.context().isInTransaction()).isFalse()) .then(conn.postAllocate()) .thenMany(conn.createStatement("SELECT * FROM test") .execute()) @@ -158,15 +158,15 @@ void transactionDefinitionLockWaitTimeout() { .beginTransaction(MySqlTransactionDefinition.empty() .lockWaitTimeout(Duration.ofSeconds(345))) .doOnSuccess(ignored -> { - assertThat(connection.isInTransaction()).isTrue(); + assertThat(connection.context().isInTransaction()).isTrue(); assertThat(connection.getTransactionIsolationLevel()).isEqualTo(REPEATABLE_READ); - assertThat(connection.isLockWaitTimeoutChanged()).isTrue(); + assertThat(connection.context().isLockWaitTimeoutChanged()).isTrue(); }) .then(connection.rollbackTransaction()) .doOnSuccess(ignored -> { - assertThat(connection.isInTransaction()).isFalse(); + assertThat(connection.context().isInTransaction()).isFalse(); assertThat(connection.getTransactionIsolationLevel()).isEqualTo(REPEATABLE_READ); - assertThat(connection.isLockWaitTimeoutChanged()).isFalse(); + assertThat(connection.context().isLockWaitTimeoutChanged()).isFalse(); })); } @@ -175,15 +175,15 @@ void transactionDefinitionIsolationLevel() { castedComplete(connection -> connection .beginTransaction(MySqlTransactionDefinition.from(READ_COMMITTED)) .doOnSuccess(ignored -> { - assertThat(connection.isInTransaction()).isTrue(); + assertThat(connection.context().isInTransaction()).isTrue(); assertThat(connection.getTransactionIsolationLevel()).isEqualTo(READ_COMMITTED); - assertThat(connection.isLockWaitTimeoutChanged()).isFalse(); + assertThat(connection.context().isLockWaitTimeoutChanged()).isFalse(); }) .then(connection.rollbackTransaction()) .doOnSuccess(ignored -> { - assertThat(connection.isInTransaction()).isFalse(); + assertThat(connection.context().isInTransaction()).isFalse(); assertThat(connection.getTransactionIsolationLevel()).isEqualTo(REPEATABLE_READ); - assertThat(connection.isLockWaitTimeoutChanged()).isFalse(); + assertThat(connection.context().isLockWaitTimeoutChanged()).isFalse(); })); } @@ -194,7 +194,7 @@ void setTransactionLevelNotInTransaction() { Mono.fromSupplier(connection::getTransactionIsolationLevel) .doOnSuccess(it -> assertThat(it).isEqualTo(REPEATABLE_READ)) .then(connection.beginTransaction()) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue()) + .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue()) .then(Mono.fromSupplier(connection::getTransactionIsolationLevel)) .doOnSuccess(it -> assertThat(it).isEqualTo(REPEATABLE_READ)) .then(connection.rollbackTransaction()) @@ -203,7 +203,7 @@ void setTransactionLevelNotInTransaction() { .then(Mono.fromSupplier(connection::getTransactionIsolationLevel)) .doOnSuccess(it -> assertThat(it).isEqualTo(READ_COMMITTED)) .then(connection.beginTransaction()) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue()) + .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue()) // ensure transaction isolation level applies to subsequent transactions .then(Mono.fromSupplier(connection::getTransactionIsolationLevel)) .doOnSuccess(it -> assertThat(it).isEqualTo(READ_COMMITTED)) @@ -222,13 +222,13 @@ void setTransactionLevelInTransaction() { .then(Mono.fromSupplier(connection::getTransactionIsolationLevel)) .doOnSuccess(it -> assertThat(it).isNotEqualTo(READ_COMMITTED)) .then(connection.rollbackTransaction()) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isFalse()) + .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isFalse()) // ensure that session isolation level is changed after rollback .then(Mono.fromSupplier(connection::getTransactionIsolationLevel)) .doOnSuccess(it -> assertThat(it).isEqualTo(READ_COMMITTED)) // ensure transaction isolation level applies to subsequent transactions .then(connection.beginTransaction()) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue()) + .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue()) ); } @@ -240,15 +240,15 @@ void transactionDefinition() { .lockWaitTimeout(Duration.ofSeconds(112)) .consistent()) .doOnSuccess(ignored -> { - assertThat(connection.isInTransaction()).isTrue(); + assertThat(connection.context().isInTransaction()).isTrue(); assertThat(connection.getTransactionIsolationLevel()).isEqualTo(REPEATABLE_READ); - assertThat(connection.isLockWaitTimeoutChanged()).isTrue(); + assertThat(connection.context().isLockWaitTimeoutChanged()).isTrue(); }) .then(connection.rollbackTransaction()) .doOnSuccess(ignored -> { - assertThat(connection.isInTransaction()).isFalse(); + assertThat(connection.context().isInTransaction()).isFalse(); assertThat(connection.getTransactionIsolationLevel()).isEqualTo(REPEATABLE_READ); - assertThat(connection.isLockWaitTimeoutChanged()).isFalse(); + assertThat(connection.context().isLockWaitTimeoutChanged()).isFalse(); })); } @@ -290,7 +290,7 @@ void createSavepointAndRollbackToSavepoint(String savepoint) { "CREATE TEMPORARY TABLE test (id INT NOT NULL PRIMARY KEY, name VARCHAR(50))").execute()) .flatMap(IntegrationTestSupport::extractRowsUpdated) .then(connection.beginTransaction()) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue()) + .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue()) .then(Mono.from(connection.createStatement("INSERT INTO test VALUES (1, 'test1')") .execute())) .flatMap(IntegrationTestSupport::extractRowsUpdated) @@ -301,7 +301,7 @@ void createSavepointAndRollbackToSavepoint(String savepoint) { .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class)))) .doOnSuccess(count -> assertThat(count).isEqualTo(2)) .then(connection.createSavepoint(savepoint)) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue()) + .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue()) .then(Mono.from(connection.createStatement("INSERT INTO test VALUES (3, 'test3')") .execute())) .flatMap(IntegrationTestSupport::extractRowsUpdated) @@ -312,12 +312,12 @@ void createSavepointAndRollbackToSavepoint(String savepoint) { .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class)))) .doOnSuccess(count -> assertThat(count).isEqualTo(4)) .then(connection.rollbackTransactionToSavepoint(savepoint)) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue()) + .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue()) .then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute())) .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class)))) .doOnSuccess(count -> assertThat(count).isEqualTo(2)) .then(connection.rollbackTransaction()) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isFalse()) + .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isFalse()) .then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute())) .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class)))) .doOnSuccess(count -> assertThat(count).isEqualTo(0)) @@ -331,7 +331,7 @@ void createSavepointAndRollbackEntireTransaction(String savepoint) { "CREATE TEMPORARY TABLE test (id INT NOT NULL PRIMARY KEY, name VARCHAR(50))").execute()) .flatMap(IntegrationTestSupport::extractRowsUpdated) .then(connection.beginTransaction()) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue()) + .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue()) .then(Mono.from(connection.createStatement("INSERT INTO test VALUES (1, 'test1')") .execute())) .flatMap(IntegrationTestSupport::extractRowsUpdated) @@ -342,7 +342,7 @@ void createSavepointAndRollbackEntireTransaction(String savepoint) { .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class)))) .doOnSuccess(count -> assertThat(count).isEqualTo(2)) .then(connection.createSavepoint(savepoint)) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue()) + .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue()) .then(Mono.from(connection.createStatement("INSERT INTO test VALUES (3, 'test3')") .execute())) .flatMap(IntegrationTestSupport::extractRowsUpdated) @@ -353,7 +353,7 @@ void createSavepointAndRollbackEntireTransaction(String savepoint) { .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class)))) .doOnSuccess(count -> assertThat(count).isEqualTo(4)) .then(connection.rollbackTransaction()) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isFalse()) + .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isFalse()) .then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute())) .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class)))) .doOnSuccess(count -> assertThat(count).isEqualTo(0)) @@ -374,8 +374,7 @@ void rollbackTransactionWithoutBegin() { void setTransactionIsolationLevel() { complete(connection -> Flux.just(READ_UNCOMMITTED, READ_COMMITTED, REPEATABLE_READ, SERIALIZABLE) .concatMap(level -> connection.setTransactionIsolationLevel(level) - .map(ignored -> assertThat(level)) - .doOnNext(a -> a.isEqualTo(connection.getTransactionIsolationLevel())))); + .doOnSuccess(ignored -> assertThat(level).isEqualTo(connection.getTransactionIsolationLevel())))); } @Test @@ -400,7 +399,7 @@ void commitTransactionShouldRespectQueuedMessages() { .execute(), connection.commitTransaction() )) - .doOnComplete(() -> assertThat(connection.isInTransaction()).isFalse()) + .doOnComplete(() -> assertThat(connection.context().isInTransaction()).isFalse()) .thenMany(connection.createStatement("SELECT COUNT(*) FROM test").execute()) .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))) @@ -421,7 +420,7 @@ void rollbackTransactionShouldRespectQueuedMessages() { .execute(), connection.rollbackTransaction() )) - .doOnComplete(() -> assertThat(connection.isInTransaction()).isFalse()) + .doOnComplete(() -> assertThat(connection.context().isInTransaction()).isFalse()) .thenMany(connection.createStatement("SELECT COUNT(*) FROM test").execute()) .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))) .doOnNext(count -> assertThat(count).isEqualTo(0L))) @@ -435,15 +434,15 @@ void beginTransactionShouldRespectQueuedMessages() { Mono.from(connection.createStatement(tdl).execute()) .flatMap(IntegrationTestSupport::extractRowsUpdated) .then(Mono.from(connection.beginTransaction())) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue()) + .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue()) .thenMany(Flux.merge( connection.createStatement("INSERT INTO test VALUES (1, 'test1')").execute(), connection.commitTransaction(), connection.beginTransaction() )) - .doOnComplete(() -> assertThat(connection.isInTransaction()).isTrue()) + .doOnComplete(() -> assertThat(connection.context().isInTransaction()).isTrue()) .then(Mono.from(connection.rollbackTransaction())) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isFalse()) + .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isFalse()) .thenMany(connection.createStatement("SELECT COUNT(*) FROM test").execute()) .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))) .doOnNext(count -> assertThat(count).isEqualTo(1L))) diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnectionTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnectionTest.java index c8d50c633..b2847c20d 100644 --- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnectionTest.java +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnectionTest.java @@ -16,21 +16,36 @@ package io.asyncer.r2dbc.mysql; +import io.asyncer.r2dbc.mysql.api.MySqlTransactionDefinition; import io.asyncer.r2dbc.mysql.cache.Caches; +import io.asyncer.r2dbc.mysql.cache.PrepareCache; import io.asyncer.r2dbc.mysql.client.Client; +import io.asyncer.r2dbc.mysql.client.FluxExchangeable; import io.asyncer.r2dbc.mysql.codec.Codecs; +import io.asyncer.r2dbc.mysql.constant.ServerStatuses; import io.asyncer.r2dbc.mysql.message.client.ClientMessage; import io.asyncer.r2dbc.mysql.message.client.TextQueryMessage; +import io.asyncer.r2dbc.mysql.message.server.CompleteMessage; import io.r2dbc.spi.IsolationLevel; import org.assertj.core.api.ThrowableTypeAssert; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.CoreSubscriber; import reactor.core.publisher.Flux; +import reactor.core.publisher.SynchronousSink; import reactor.test.StepVerifier; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -39,39 +54,24 @@ */ class MySqlSimpleConnectionTest { - private final Client client; - - private final Codecs codecs = mock(Codecs.class); - - private final IsolationLevel level = IsolationLevel.REPEATABLE_READ; - - private final String product = "MockConnection"; - - private final MySqlSimpleConnection noPrepare; - - MySqlSimpleConnectionTest() { - Client client = mock(Client.class); - - when(client.getContext()).thenReturn(ConnectionContextTest.mock()); - - this.client = client; - this.noPrepare = new MySqlSimpleConnection(client, - codecs, level, 50, Caches.createQueryCache(0), - Caches.createPrepareCache(0), product, null); - } + private static final Codecs CODECS = mock(Codecs.class); @Test void createStatement() { String condition = "SELECT * FROM test"; - MySqlSimpleConnection allPrepare = new MySqlSimpleConnection(client, - codecs, level, 50, Caches.createQueryCache(0), - Caches.createPrepareCache(0), product, sql -> true); - MySqlSimpleConnection halfPrepare = new MySqlSimpleConnection(client, - codecs, level, 50, Caches.createQueryCache(0), - Caches.createPrepareCache(0), product, sql -> false); - MySqlSimpleConnection conditionPrepare = new MySqlSimpleConnection(client, - codecs, level, 50, Caches.createQueryCache(0), - Caches.createPrepareCache(0), product, sql -> sql.equals(condition)); + MySqlSimpleConnection allPrepare = new MySqlSimpleConnection( + mockClient(), + CODECS, + Caches.createQueryCache(0), sql -> true); + MySqlSimpleConnection halfPrepare = new MySqlSimpleConnection( + mockClient(), + CODECS, + Caches.createQueryCache(0), sql -> false); + MySqlSimpleConnection conditionPrepare = new MySqlSimpleConnection( + mockClient(), + CODECS, + Caches.createQueryCache(0), sql -> sql.equals(condition)); + MySqlSimpleConnection noPrepare = newNoPrepare(mockClient()); assertThat(noPrepare.createStatement("SELECT * FROM test WHERE id=1")) .isExactlyInstanceOf(TextSimpleStatement.class); @@ -105,12 +105,14 @@ void createStatement() { @SuppressWarnings("ConstantConditions") @Test void badCreateStatement() { + MySqlSimpleConnection noPrepare = newNoPrepare(mockClient()); assertThatIllegalArgumentException().isThrownBy(() -> noPrepare.createStatement(null)); } @SuppressWarnings("ConstantConditions") @Test void badCreateSavepoint() { + MySqlSimpleConnection noPrepare = newNoPrepare(mockClient()); ThrowableTypeAssert asserted = assertThatIllegalArgumentException(); asserted.isThrownBy(() -> noPrepare.createSavepoint("")); @@ -120,6 +122,7 @@ void badCreateSavepoint() { @SuppressWarnings("ConstantConditions") @Test void badReleaseSavepoint() { + MySqlSimpleConnection noPrepare = newNoPrepare(mockClient()); ThrowableTypeAssert asserted = assertThatIllegalArgumentException(); asserted.isThrownBy(() -> noPrepare.releaseSavepoint("")); @@ -129,6 +132,7 @@ void badReleaseSavepoint() { @SuppressWarnings("ConstantConditions") @Test void badRollbackTransactionToSavepoint() { + MySqlSimpleConnection noPrepare = newNoPrepare(mockClient()); ThrowableTypeAssert asserted = assertThatIllegalArgumentException(); asserted.isThrownBy(() -> noPrepare.rollbackTransactionToSavepoint("")); @@ -138,24 +142,120 @@ void badRollbackTransactionToSavepoint() { @SuppressWarnings("ConstantConditions") @Test void badSetTransactionIsolationLevel() { + MySqlSimpleConnection noPrepare = newNoPrepare(mockClient()); assertThatIllegalArgumentException().isThrownBy(() -> noPrepare.setTransactionIsolationLevel(null)); } - @Test - void shouldSetTransactionIsolationLevelSuccessfully() { - ClientMessage message = new TextQueryMessage("SET SESSION TRANSACTION ISOLATION LEVEL SERIALIZABLE"); + @ParameterizedTest + @ValueSource(strings = { "READ UNCOMMITTED", "READ COMMITTED", "REPEATABLE READ", "SERIALIZABLE" }) + void shouldSetTransactionIsolationLevelSuccessfully(String levelSql) { + Client client = mockClient(); + IsolationLevel level = IsolationLevel.valueOf(levelSql); + ClientMessage message = new TextQueryMessage("SET SESSION TRANSACTION ISOLATION LEVEL " + levelSql); + when(client.exchange(eq(message), any())).thenReturn(Flux.empty()); - noPrepare.setTransactionIsolationLevel(IsolationLevel.SERIALIZABLE) + MySqlSimpleConnection noPrepare = newNoPrepare(client); + noPrepare.setTransactionIsolationLevel(level) .as(StepVerifier::create) .verifyComplete(); - assertThat(noPrepare.getSessionTransactionIsolationLevel()).isEqualTo(IsolationLevel.SERIALIZABLE); + assertThat(client.getContext().getCurrentIsolationLevel()).isEqualTo(level); + assertThat(client.getContext().getSessionIsolationLevel()).isEqualTo(level); + } + + @ParameterizedTest + @ValueSource(strings = { + "READ UNCOMMITTED,SERIALIZABLE", + "READ COMMITTED,REPEATABLE READ", + "REPEATABLE READ,READ UNCOMMITTED" + }) + void shouldSetTransactionIsolationLevelInTransaction(String levels) { + String[] levelStatements = levels.split(","); + IsolationLevel currentLevel = IsolationLevel.valueOf(levelStatements[0]); + IsolationLevel sessionLevel = IsolationLevel.valueOf(levelStatements[1]); + Client client = mockClient(); + ClientMessage session = new TextQueryMessage("SET SESSION TRANSACTION ISOLATION LEVEL " + sessionLevel.asSql()); + CompleteMessage mockDone = mock(CompleteMessage.class); + @SuppressWarnings("unchecked") + SynchronousSink sink = (SynchronousSink) mock(SynchronousSink.class); + AtomicBoolean completed = new AtomicBoolean(false); + + doAnswer(it -> { + throw it.getArgument(0, Exception.class); + }).when(sink).error(any()); + doAnswer(it -> { + completed.set(true); + return null; + }).when(sink).complete(); + when(mockDone.isDone()).thenReturn(true); + when(client.exchange(eq(session), any())).thenReturn(Flux.empty()); + when(client.exchange(any())).thenAnswer(it -> { + FluxExchangeable exchangeable = it.getArgument(0); + @SuppressWarnings("unchecked") + CoreSubscriber subscriber = mock(CoreSubscriber.class); + exchangeable.subscribe(subscriber); + + while (!completed.get()) { + exchangeable.accept(mockDone, sink); + } + + // Mock server status to be in transaction + client.getContext().setServerStatuses(ServerStatuses.IN_TRANSACTION); + + return Flux.empty(); + }); + + IsolationLevel mockLevel = IsolationLevel.valueOf("DEFAULT"); + client.getContext().initSession( + mock(PrepareCache.class), + mockLevel, + false, + Duration.ZERO, + null, + null + ); + MySqlSimpleConnection noPrepare = newNoPrepare(client); + + assertThat(client.getContext().getCurrentIsolationLevel()).isEqualTo(mockLevel); + assertThat(client.getContext().getSessionIsolationLevel()).isEqualTo(mockLevel); + + noPrepare.beginTransaction(MySqlTransactionDefinition.from(currentLevel)) + .as(StepVerifier::create) + .verifyComplete(); + + assertThat(client.getContext().getCurrentIsolationLevel()).isEqualTo(currentLevel); + assertThat(client.getContext().getSessionIsolationLevel()).isEqualTo(mockLevel); + + noPrepare.setTransactionIsolationLevel(sessionLevel) + .as(StepVerifier::create) + .verifyComplete(); + + assertThat(client.getContext().getCurrentIsolationLevel()).isEqualTo(currentLevel); + assertThat(client.getContext().getSessionIsolationLevel()).isEqualTo(sessionLevel); } @SuppressWarnings("ConstantConditions") @Test void badValidate() { + MySqlSimpleConnection noPrepare = newNoPrepare(mockClient()); assertThatIllegalArgumentException().isThrownBy(() -> noPrepare.validate(null)); } + + private static Client mockClient() { + Client client = mock(Client.class); + + when(client.getContext()).thenReturn(ConnectionContextTest.mock()); + + return client; + } + + private static MySqlSimpleConnection newNoPrepare(Client client) { + return new MySqlSimpleConnection( + client, + CODECS, + Caches.createQueryCache(0), + null + ); + } } diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/PrepareParameterizedStatementTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/PrepareParameterizedStatementTest.java index 345704af5..94e1591f4 100644 --- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/PrepareParameterizedStatementTest.java +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/PrepareParameterizedStatementTest.java @@ -52,8 +52,7 @@ public PrepareParameterizedStatement makeInstance(boolean isMariaDB, String sql, return new PrepareParameterizedStatement( client, codecs, - Query.parse(sql), - Caches.createPrepareCache(0) + Query.parse(sql) ); } diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatementTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatementTest.java index 0e18e7233..56d5ac907 100644 --- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatementTest.java +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatementTest.java @@ -16,7 +16,6 @@ package io.asyncer.r2dbc.mysql; -import io.asyncer.r2dbc.mysql.cache.Caches; import io.asyncer.r2dbc.mysql.client.Client; import io.asyncer.r2dbc.mysql.codec.Codecs; @@ -64,12 +63,7 @@ public PrepareSimpleStatement makeInstance(boolean isMariaDB, String ignored, St when(client.getContext()).thenReturn(ConnectionContextTest.mock(isMariaDB)); - return new PrepareSimpleStatement( - client, - codecs, - sql, - Caches.createPrepareCache(0) - ); + return new PrepareSimpleStatement(client, codecs, sql); } @Override