Skip to content

Commit

Permalink
Implement different db users to seperate the tenants from each other
Browse files Browse the repository at this point in the history
  • Loading branch information
RandomTannenbaum committed Jan 14, 2025
1 parent f723ffd commit e4b87cd
Show file tree
Hide file tree
Showing 18 changed files with 836 additions and 200 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import jakarta.persistence.EntityNotFoundException;
import org.flywaydb.core.Flyway;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

Expand All @@ -10,6 +12,8 @@ public class FlywayMultitenantMigrationInitializer {
private final TenantConfigProviderInterface tenantConfigProvider;
private final String[] scriptLocations;

private static final Logger logger = LoggerFactory.getLogger(FlywayMultitenantMigrationInitializer.class);

public FlywayMultitenantMigrationInitializer(TenantConfigProviderInterface tenantConfigProvider,
final @Value("${spring.flyway.locations}")
String[] scriptLocations) {
Expand All @@ -19,21 +23,27 @@ public FlywayMultitenantMigrationInitializer(TenantConfigProviderInterface tenan

public void migrateFlyway() {
this.tenantConfigProvider.getTenantConfigs().forEach(tenantConfig -> {
TenantConfigProvider.DataSourceConfig dataSourceConfig = this.tenantConfigProvider
TenantConfigProvider.DataSourceConfig dataSourceConfigFlyway = this.tenantConfigProvider
.getTenantConfigById(tenantConfig.tenantId())
.map(TenantConfigProvider.TenantConfig::dataSourceConfig)
.map(TenantConfigProvider.TenantConfig::dataSourceConfigFlyway)
.orElseThrow(() -> new EntityNotFoundException("Cannot find tenant for configuring flyway migration"));

logUsedHibernateConfig(dataSourceConfigFlyway);

Flyway tenantSchemaFlyway = Flyway
.configure() //
.dataSource(dataSourceConfig.url(), dataSourceConfig.name(), dataSourceConfig.password()) //
.dataSource(dataSourceConfigFlyway.url(), dataSourceConfigFlyway.name(),
dataSourceConfigFlyway.password()) //
.locations(scriptLocations) //
.baselineOnMigrate(Boolean.TRUE) //
.schemas(dataSourceConfig.schema()) //
.schemas(dataSourceConfigFlyway.schema()) //
.load();

tenantSchemaFlyway.migrate();
});
}

private void logUsedHibernateConfig(TenantConfigProvider.DataSourceConfig dataSourceConfig) {
logger.info("use DbConfig: user={}", dataSourceConfig.name());
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
package ch.puzzle.okr.multitenancy;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ch.puzzle.okr.exception.HibernateContextException;
import java.util.Properties;
import org.springframework.core.env.ConfigurableEnvironment;

/**
* Reads the (not tenant specific) hibernate configuration form the "hibernate.x" properties in the
* applicationX.properties file. It then caches the configuration as DbConfig object. The data from the DbConfig object
* is used by the SchemaMultiTenantConnectionProvider via getHibernateConfig() and getHibernateConfig(tenantId).
*
* <pre>
* getHibernateConfig() returns the cached DbConfig as properties.
* </pre>
*
* <pre>
* getHibernateConfig(tenantId) patches the DbConfig data with tenant specific data (from
* TenantConfigProvider) and returns the patched data as properties
* </pre>
*/
public class HibernateContext {
public static final String HIBERNATE_CONNECTION_URL = "hibernate.connection.url";
public static final String HIBERNATE_CONNECTION_USERNAME = "hibernate.connection.username";
Expand All @@ -14,6 +30,8 @@ public class HibernateContext {
public static final String SPRING_DATASOURCE_USERNAME = "spring.datasource.username";
public static final String SPRING_DATASOURCE_PASSWORD = "spring.datasource.password";

private static final Logger logger = LoggerFactory.getLogger(HibernateContext.class);

public record DbConfig(String url, String username, String password, String multiTenancy) {

public boolean isValid() {
Expand All @@ -29,20 +47,22 @@ private boolean hasEmptyValues() {
}
}

// general (not tenant specific) hibernate config
private static DbConfig cachedHibernateConfig;

public static void extractAndSetHibernateConfig(ConfigurableEnvironment environment) {
DbConfig dbConfig = extractHibernateConfig(environment);
setHibernateConfig(dbConfig);
logUsedHibernateConfig(dbConfig);
}

public static void setHibernateConfig(DbConfig dbConfig) {
if (dbConfig == null || !dbConfig.isValid()) {
throw new HibernateContextException("Invalid hibernate configuration " + dbConfig);
}
cachedHibernateConfig = dbConfig;
}

public static void extractAndSetHibernateConfig(ConfigurableEnvironment environment) {
DbConfig dbConfig = extractHibernateConfig(environment);
HibernateContext.setHibernateConfig(dbConfig);
}

private static DbConfig extractHibernateConfig(ConfigurableEnvironment environment) {
String url = environment.getProperty(HibernateContext.HIBERNATE_CONNECTION_URL);
String username = environment.getProperty(HibernateContext.HIBERNATE_CONNECTION_USERNAME);
Expand All @@ -60,7 +80,9 @@ public static Properties getHibernateConfig() {
if (cachedHibernateConfig == null) {
throw new HibernateContextException("No cached hibernate configuration found");
}
return getConfigAsProperties(cachedHibernateConfig);
var config = getConfigAsProperties(cachedHibernateConfig);
logUsedHibernateConfig(config);
return config;
}

private static Properties getConfigAsProperties(DbConfig dbConfig) {
Expand All @@ -74,4 +96,48 @@ private static Properties getConfigAsProperties(DbConfig dbConfig) {
properties.put(HibernateContext.SPRING_DATASOURCE_PASSWORD, dbConfig.password());
return properties;
}

public static Properties getHibernateConfig(String tenantIdentifier) {
if (cachedHibernateConfig == null) {
throw new RuntimeException("No cached hibernate configuration found (for tenant " + tenantIdentifier + ")");
}
var config = getConfigAsPropertiesAndPatch(cachedHibernateConfig, tenantIdentifier);
logUsedHibernateConfig(tenantIdentifier, config);
return config;
}

private static Properties getConfigAsPropertiesAndPatch(DbConfig dbConfig, String tenantIdentifier) {
Properties properties = getConfigAsProperties(dbConfig);
return patchConfigAppForTenant(properties, tenantIdentifier);
}

private static Properties patchConfigAppForTenant(Properties properties, String tenantIdentifier) {
TenantConfigProvider.TenantConfig cachedTenantConfig = TenantConfigProvider
.getCachedTenantConfig(tenantIdentifier);
if (cachedTenantConfig == null) {
throw new RuntimeException("No cached tenant configuration found (for tenant " + tenantIdentifier + ")");
}

TenantConfigProvider.DataSourceConfig dataSourceConfigApp = cachedTenantConfig.dataSourceConfigApp();
properties.put(HibernateContext.HIBERNATE_CONNECTION_USERNAME, dataSourceConfigApp.name());
properties.put(HibernateContext.HIBERNATE_CONNECTION_PASSWORD, dataSourceConfigApp.password());
properties.put(HibernateContext.SPRING_DATASOURCE_USERNAME, dataSourceConfigApp.name());
properties.put(HibernateContext.SPRING_DATASOURCE_PASSWORD, dataSourceConfigApp.password());
return properties;
}

private static void logUsedHibernateConfig(DbConfig hibernateConfig) {
logger.info("set DbConfig: user={}", hibernateConfig.username());
}

private static void logUsedHibernateConfig(Properties hibernateConfig) {
logger.info("use DbConfig: user={}",
hibernateConfig.getProperty(HibernateContext.HIBERNATE_CONNECTION_USERNAME)); //
}

private static void logUsedHibernateConfig(String tenantId, Properties hibernateConfig) {
logger.info("use DbConfig: tenant={} user={}", tenantId,
hibernateConfig.getProperty(HibernateContext.HIBERNATE_CONNECTION_USERNAME));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,42 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.Connection;
import java.sql.SQLException;
import java.text.MessageFormat;
import java.util.*;

import static ch.puzzle.okr.multitenancy.TenantContext.DEFAULT_TENANT_ID;

/**
* The central piece of code of multitenancy.
*
* <pre>
* getConnection(tenantId) sets in each tenant request the specific db schema for the
* tenant. This guarantees that each tenant always works in its own DB schema.
*
* getConnection(tenantId) -> Connection calls in the abstract super class the
* getConnection(tenantId) -> Connection which calls the abstract
* selectConnectionProvider(tenantIdentifier) -> ConnectionProvider which is implemented
* in SchemaMultiTenantConnectionProvider.
* </pre>
*
* <pre>
* Some coding details:
*
* selectConnectionProvider(tenantId) -> ConnectionProvider returns for a tenant a
* ConnectionProvider. It first checks if the ConnectionProvider for the tenant is already
* cached (in connectionProviderMap). If the ConnectionProvider is cached, it returns it.
* Otherwise it creates a ConnectionProvider for the tenant, cache it and return it.
*
* To create a ConnectionProvider for the tenant, it tries to load the configuration from
* the hibernate properties. For this it uses 2 methods of HibernateContext:
* getHibernateConfig() if the tenant is the DEFAULT_TENANT_ID (public) and
* getHibernateConfig(tenantId) for all other tenants. With this information its then
* possible to create and cache a ConnectionProvider for the tenant. If no matching
* hibernate properties are found, then an exception is thrown.
* </pre>
*/
public class SchemaMultiTenantConnectionProvider extends AbstractMultiTenantConnectionProvider<String> {

private static final Logger logger = LoggerFactory.getLogger(SchemaMultiTenantConnectionProvider.class);
Expand All @@ -31,7 +67,7 @@ public Connection getConnection(String tenantIdentifier) throws SQLException {
return getConnection(tenantIdentifier, connection);
}

protected Connection getConnection(String tenantIdentifier, Connection connection) throws SQLException {
Connection getConnection(String tenantIdentifier, Connection connection) throws SQLException {
String schema = convertTenantIdToSchemaName(tenantIdentifier);
logger.debug("Setting schema to {}", schema);

Expand All @@ -42,7 +78,7 @@ protected Connection getConnection(String tenantIdentifier, Connection connectio
return connection;
}

private String convertTenantIdToSchemaName(String tenantIdentifier) {
String convertTenantIdToSchemaName(String tenantIdentifier) {
return Objects.equals(tenantIdentifier, DEFAULT_TENANT_ID) ? tenantIdentifier
: MessageFormat.format("okr_{0}", tenantIdentifier);
}
Expand All @@ -57,14 +93,14 @@ protected ConnectionProvider selectConnectionProvider(String tenantIdentifier) {
return getConnectionProvider(tenantIdentifier);
}

protected ConnectionProvider getConnectionProvider(String tenantIdentifier) {
ConnectionProvider getConnectionProvider(String tenantIdentifier) {
return Optional
.ofNullable(tenantIdentifier) //
.map(connectionProviderMap::get) //
.orElseGet(() -> createNewConnectionProvider(tenantIdentifier));
.orElseGet(() -> createAndCacheNewConnectionProvider(tenantIdentifier));
}

private ConnectionProvider createNewConnectionProvider(String tenantIdentifier) {
ConnectionProvider createAndCacheNewConnectionProvider(String tenantIdentifier) {
return Optional
.ofNullable(tenantIdentifier) //
.map(this::createConnectionProvider) //
Expand All @@ -84,29 +120,25 @@ private ConnectionProvider createConnectionProvider(String tenantIdentifier) {
.orElse(null);
}

protected Properties getHibernatePropertiesForTenantIdentifier(String tenantIdentifier) {
Properties properties = getHibernateProperties();
if (properties == null || properties.isEmpty()) {
throw new ConnectionProviderException("Cannot load hibernate properties from application.properties)");
Properties getHibernatePropertiesForTenantIdentifier(String tenantIdentifier) {
Properties properties = getHibernateProperties(tenantIdentifier);
if (properties.isEmpty()) {
throw new ConnectionProviderException("Cannot load hibernate properties from application.properties");
}
if (!Objects.equals(tenantIdentifier, DEFAULT_TENANT_ID)) {
properties.put(MappingSettings.DEFAULT_SCHEMA, MessageFormat.format("okr_{0}", tenantIdentifier));
}
return properties;
}

private ConnectionProvider initConnectionProvider(Properties hibernateProperties) {
ConnectionProvider initConnectionProvider(Properties hibernateProperties) {
Map<String, Object> configProperties = convertPropertiesToMap(hibernateProperties);
DriverManagerConnectionProviderImpl connectionProvider = getDriverManagerConnectionProviderImpl();
DriverManagerConnectionProviderImpl connectionProvider = new DriverManagerConnectionProviderImpl();
connectionProvider.configure(configProperties);
return connectionProvider;
}

protected DriverManagerConnectionProviderImpl getDriverManagerConnectionProviderImpl() {
return new DriverManagerConnectionProviderImpl();
}

private Map<String, Object> convertPropertiesToMap(Properties properties) {
Map<String, Object> convertPropertiesToMap(Properties properties) {
Map<String, Object> configProperties = new HashMap<>();
for (String key : properties.stringPropertyNames()) {
String value = properties.getProperty(key);
Expand All @@ -115,7 +147,17 @@ private Map<String, Object> convertPropertiesToMap(Properties properties) {
return configProperties;
}

protected Properties getHibernateProperties() {
return HibernateContext.getHibernateConfig();
private Properties getHibernateProperties(String tenantIdentifier) {
if (tenantIdentifier == null) {
throw new ConnectionProviderException("No hibernate configuration found for tenant: " + tenantIdentifier);
}
try {
if (tenantIdentifier.equals(DEFAULT_TENANT_ID)) {
return HibernateContext.getHibernateConfig();
}
return HibernateContext.getHibernateConfig(tenantIdentifier);
} catch (RuntimeException e) {
throw new ConnectionProviderException(e.getMessage());
}
}
}
Loading

0 comments on commit e4b87cd

Please sign in to comment.