Skip to content

Commit

Permalink
Add client support for SASL extensions (#231)
Browse files Browse the repository at this point in the history
Signed-off-by: Marko Strukelj <marko.strukelj@gmail.com>
  • Loading branch information
mstruk authored Mar 1, 2024
1 parent 10d3c0e commit 229daee
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 6 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1065,6 +1065,16 @@ You may also specify a pause time between requests in order not to flood the aut
The default value is '0', meaning 'no pause'. Provide the value greater than '0' to set the pause time between attempts in milliseconds:
- `oauth.http.retry.pause.millis` (e.g.: "500" - if a retry is attempted, there will first be a half-a-second pause)

### Configuring the SASL extensions

If your Kafka Broker uses some other custom `OAUTHBEARER` implementation, you may need to pass it SASL extensions options.
These are key:value pairs representing a client context, that are sent to the Kafka Broker when the new session is started.

You can pass SASL extensions options by using `oauth.sasl.extension.` as a key prefix:
- `oauth.sasl.extension.KEY` (e.g.: "VALUE" - replace KEY with the actual SASL extension key name, and VALUE with the actual value)

For example, you could add multiple sasl extensions options: `oauth.sasl.extension.key1="value1" oauth.sasl.extension.key2="value2"`


### Configuring the re-authentication on the client

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ public class ClientConfig extends Config {
*/
public static final String OAUTH_CLIENT_ASSERTION_TYPE = "oauth.client.assertion.type";

/**
* A prefix to use to pass SASL extensions options
*/
public static final String OAUTH_SASL_EXTENSION_PREFIX = "oauth.sasl.extension.";

/**
* Create a new instance
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import io.strimzi.kafka.oauth.services.OAuthMetrics;
import io.strimzi.kafka.oauth.services.Services;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.auth.SaslExtensions;
import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;
Expand All @@ -33,10 +35,13 @@
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.regex.Pattern;

import static io.strimzi.kafka.oauth.common.ConfigUtil.getConnectTimeout;
import static io.strimzi.kafka.oauth.common.ConfigUtil.getReadTimeout;
Expand Down Expand Up @@ -89,15 +94,24 @@ public class JaasClientOauthLoginCallbackHandler implements AuthenticateCallback
private final ClientMetricsHandler authenticatorMetrics = new ClientMetricsHandler();
private boolean includeAcceptHeader;

// Using ordered map helps with predictable logging output which can be tested
private final Map<String, String> saslExtensions = new LinkedHashMap<>();
private static final Pattern SASL_KEY_VALIDATION_PATTERN = Pattern.compile("[A-Za-z]+");
private static final Pattern SASL_VALUE_VALIDATION_PATTERN = Pattern.compile("[\\x21-\\x7E \t\r\n]+");


@Override
public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
if (!OAuthBearerLoginModule.OAUTHBEARER_MECHANISM.equals(saslMechanism)) {
throw new IllegalArgumentException("Unexpected SASL mechanism: " + saslMechanism);
}

for (AppConfigurationEntry e: jaasConfigEntries) {
Map<String, ?> options = Collections.emptyMap();

if (!jaasConfigEntries.isEmpty()) {
options = jaasConfigEntries.get(0).getOptions();
Properties p = new Properties();
p.putAll(e.getOptions());
p.putAll(options);
config = new ClientConfig(p);
}

Expand Down Expand Up @@ -166,6 +180,17 @@ public void configure(Map<String, ?> configs, String saslMechanism, List<AppConf

String configId = configureMetrics(configs);

// Process extensions configuration
for (String key: options.keySet()) {
if (key.startsWith(ClientConfig.OAUTH_SASL_EXTENSION_PREFIX)) {
String value = config.getValue(key, "");
key = key.substring(ClientConfig.OAUTH_SASL_EXTENSION_PREFIX.length());

validateSaslExtension(key, value);
saslExtensions.put(key, value);
}
}

if (LOG.isDebugEnabled()) {
LOG.debug("Configured JaasClientOauthLoginCallbackHandler:"
+ "\n configId: " + configId
Expand All @@ -191,7 +216,17 @@ public void configure(Map<String, ?> configs, String saslMechanism, List<AppConf
+ "\n retries: " + retries
+ "\n retryPauseMillis: " + retryPauseMillis
+ "\n enableMetrics: " + enableMetrics
+ "\n includeAcceptHeader: " + includeAcceptHeader);
+ "\n includeAcceptHeader: " + includeAcceptHeader
+ "\n saslExtensions: " + saslExtensions);
}
}

private void validateSaslExtension(String key, String value) {
if (!SASL_KEY_VALIDATION_PATTERN.matcher(key).matches() || "auth".equals(key)) {
throw new ConfigException("Invalid sasl extension key: '" + key + "' ('" + ClientConfig.OAUTH_SASL_EXTENSION_PREFIX + key + "')");
}
if (!SASL_VALUE_VALIDATION_PATTERN.matcher(value).matches()) {
throw new ConfigException("Invalid sasl extension value for key: '" + key + "' ('" + ClientConfig.OAUTH_SASL_EXTENSION_PREFIX + key + "')");
}
}

Expand Down Expand Up @@ -332,12 +367,19 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback
for (Callback callback : callbacks) {
if (callback instanceof OAuthBearerTokenCallback) {
handleCallback((OAuthBearerTokenCallback) callback);
} else if (callback instanceof SaslExtensionsCallback) {
handleExtensionsCallback((SaslExtensionsCallback) callback);
} else {
throw new UnsupportedCallbackException(callback);
}
}
}

private void handleExtensionsCallback(SaslExtensionsCallback callback) {
SaslExtensions extensions = new SaslExtensions(saslExtensions);
callback.extensions(extensions);
}

private void handleCallback(OAuthBearerTokenCallback callback) throws IOException {
if (callback.token() != null) {
throw new IllegalArgumentException("Callback had a token already");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ public void doTest() throws Exception {

testAllConfigOptions();

testSaslExtensions();

testAccessTokenLocation();

testRefreshTokenLocation();
Expand All @@ -93,7 +95,7 @@ private void testAllConfigOptions() throws IOException {
attrs.put(ClientConfig.OAUTH_PASSWORD_GRANT_PASSWORD, "password");
attrs.put(ClientConfig.OAUTH_USERNAME_CLAIM, "username-claim");
attrs.put(ClientConfig.OAUTH_FALLBACK_USERNAME_CLAIM, "fallback-username-claim");
attrs.put(ClientConfig.OAUTH_FALLBACK_USERNAME_PREFIX, "username-prefix");
attrs.put(ClientConfig.OAUTH_FALLBACK_USERNAME_PREFIX, "fallback-username-prefix");
attrs.put(ClientConfig.OAUTH_SCOPE, "scope");
attrs.put(ClientConfig.OAUTH_AUDIENCE, "audience");
attrs.put(ClientConfig.OAUTH_ACCESS_TOKEN_IS_JWT, "false");
Expand All @@ -104,6 +106,8 @@ private void testAllConfigOptions() throws IOException {
attrs.put(ClientConfig.OAUTH_HTTP_RETRY_PAUSE_MILLIS, "500");
attrs.put(ClientConfig.OAUTH_ENABLE_METRICS, "true");
attrs.put(ClientConfig.OAUTH_INCLUDE_ACCEPT_HEADER, "false");
attrs.put(ClientConfig.OAUTH_SASL_EXTENSION_PREFIX + "poolid", "poolid-value");
attrs.put(ClientConfig.OAUTH_SASL_EXTENSION_PREFIX + "group.ref", "group-ref-value");


AppConfigurationEntry jaasConfig = new AppConfigurationEntry("org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule", AppConfigurationEntry.LoginModuleControlFlag.REQUIRED, attrs);
Expand All @@ -116,6 +120,18 @@ private void testAllConfigOptions() throws IOException {
LogLineReader logReader = new LogLineReader(Common.LOG_PATH);
logReader.readNext();

try {
loginHandler.configure(clientProps, "OAUTHBEARER", Collections.singletonList(jaasConfig));
} catch (Exception e) {
Assert.assertTrue("Is a ConfigException", e instanceof ConfigException);
Assert.assertTrue("Invalid sasl extension key: " + e.getMessage(), e.getMessage().contains("Invalid sasl extension key: 'group.ref'"));
}

logReader.readNext();

attrs.remove(ClientConfig.OAUTH_SASL_EXTENSION_PREFIX + "group.ref");
attrs.put(ClientConfig.OAUTH_SASL_EXTENSION_PREFIX + "group", "group-ref-value");

loginHandler.configure(clientProps, "OAUTHBEARER", Collections.singletonList(jaasConfig));

Common.checkLog(logReader, "configId", "config-id",
Expand All @@ -139,7 +155,8 @@ private void testAllConfigOptions() throws IOException {
"retries", "3",
"retryPauseMillis", "500",
"enableMetrics", "true",
"includeAcceptHeader", "false");
"includeAcceptHeader", "false",
"saslExtensions", "\\{poolid=poolid-value, group=group-ref-value\\}");


// we could not check tokenEndpointUri and token in the same run
Expand Down Expand Up @@ -357,6 +374,34 @@ private void testValidConfigurations() {
}
}

private void testSaslExtensions() throws Exception {
String testClient = "testclient";
String testSecret = "testsecret";

changeAuthServerMode("jwks", "mode_200");
changeAuthServerMode("token", "mode_200");
createOAuthClient(testClient, testSecret);

Map<String, String> oauthConfig = new HashMap<>();
oauthConfig.put(ClientConfig.OAUTH_TOKEN_ENDPOINT_URI, TOKEN_ENDPOINT_URI);
oauthConfig.put(ClientConfig.OAUTH_CLIENT_ID, testClient);
oauthConfig.put(ClientConfig.OAUTH_CLIENT_SECRET, testSecret);
oauthConfig.put(ClientConfig.OAUTH_SSL_TRUSTSTORE_LOCATION, "../docker/target/kafka/certs/ca-truststore.p12");
oauthConfig.put(ClientConfig.OAUTH_SSL_TRUSTSTORE_PASSWORD, "changeit");
oauthConfig.put(ClientConfig.OAUTH_SASL_EXTENSION_PREFIX + "extoption", "optionvalue");

LogLineReader logReader = new LogLineReader(Common.LOG_PATH);
logReader.readNext();

// If it fails with 'Unknown signing key' it means that Kafka has not managed to load JWKS keys yet
// due to jwks endpoint returning status 404 initially
initJaasWithRetry(oauthConfig);

List<String> lines = logReader.readNext();
// Check in the log that SASL extensions have been properly set
checkLogForRegex(lines, ".*LoginManager.*extensionsMap=\\{extoption=optionvalue\\}.*");
}

private void testAccessTokenLocation() throws Exception {

String testClient = "testclient";
Expand All @@ -369,7 +414,7 @@ private void testAccessTokenLocation() throws Exception {
String accessToken = loginWithClientSecret(TOKEN_ENDPOINT_URI, testClient, testSecret, "../docker/target/kafka/certs/ca-truststore.p12", "changeit");

Path accessTokenFilePath = Paths.get("target/access_token_file");
Files.write(accessTokenFilePath, accessToken.getBytes(StandardCharsets.UTF_8), StandardOpenOption.CREATE_NEW);
Files.write(accessTokenFilePath, accessToken.getBytes(StandardCharsets.UTF_8), StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING);
try {
LogLineReader logReader = new LogLineReader(Common.LOG_PATH);
logReader.readNext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
org.slf4j.simpleLogger.logFile=target/test.log
org.slf4j.simpleLogger.showDateTime=true
org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss:SSS
org.slf4j.simpleLogger.log.org.apache.kafka.common.security=TRACE
org.slf4j.simpleLogger.log.org.apache.kafka=OFF
org.slf4j.simpleLogger.log.io.strimzi=TRACE

0 comments on commit 229daee

Please sign in to comment.