Make new workspace client from account client (#218)
## Changes
Ported to the Java SDK from
https://github.com/databricks/databricks-sdk-go/pull/792/files and
databricks/databricks-sdk-go#700.

## Tests
<!-- How is this tested? -->
mgyucht authored Feb 2, 2024
1 parent 2032f61 commit dc07492
Showing 12 changed files with 292 additions and 125 deletions.
9 changes: 9 additions & 0 deletions .codegen/account.java.tmpl
@@ -5,6 +5,8 @@ package com.databricks.sdk;
import com.databricks.sdk.core.ApiClient;
import com.databricks.sdk.core.ConfigLoader;
import com.databricks.sdk.core.DatabricksConfig;
import com.databricks.sdk.core.utils.AzureUtils;
import com.databricks.sdk.service.provisioning.*;
{{range .Services}}{{if .IsAccounts}}
import com.databricks.sdk.service.{{.Package.Name}}.{{.PascalName}}API;
import com.databricks.sdk.service.{{.Package.Name}}.{{.PascalName}}Service;{{end}}{{end}}
@@ -63,4 +65,11 @@ public class AccountClient {
public DatabricksConfig config() {
return config;
}

public WorkspaceClient getWorkspaceClient(Workspace workspace) {
String host = this.config.getDatabricksEnvironment().getDeploymentUrl(workspace.getDeploymentName());
DatabricksConfig config = this.config.newWithWorkspaceHost(host);
AzureUtils.getAzureWorkspaceResourceId(workspace).map(config::setAzureWorkspaceResourceId);
return new WorkspaceClient(config);
}
}
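
For context, a hedged sketch of how the generated method could be used end to end. The `workspaces().list()` call, the no-argument `AccountClient` constructor, and the `config().getHost()` accessor are assumptions based on the existing account-level API, not part of this diff.

```java
import com.databricks.sdk.AccountClient;
import com.databricks.sdk.WorkspaceClient;
import com.databricks.sdk.service.provisioning.Workspace;

public class WorkspaceClientFromAccountExample {
  public static void main(String[] args) {
    // Account-level client; credentials and the accounts host are assumed to
    // come from the environment (DATABRICKS_HOST, DATABRICKS_ACCOUNT_ID, ...).
    AccountClient account = new AccountClient();

    // For every workspace in the account, derive a workspace-scoped client.
    // getWorkspaceClient() builds the host from the workspace's deployment name
    // and, on Azure, carries over the workspace resource ID.
    for (Workspace workspace : account.workspaces().list()) {
      WorkspaceClient workspaceClient = account.getWorkspaceClient(workspace);
      System.out.println(workspaceClient.config().getHost());
    }
  }
}
```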
3 changes: 3 additions & 0 deletions Makefile
@@ -1,3 +1,6 @@
fmt:
mvn spotless:apply

test:
mvn test


AzureCliCredentialsProvider.java
@@ -7,7 +7,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class AzureCliCredentialsProvider implements CredentialsProvider, AzureUtils {
public class AzureCliCredentialsProvider implements CredentialsProvider {
private final ObjectMapper mapper = new ObjectMapper();
private static final Logger LOG = LoggerFactory.getLogger(AzureCliCredentialsProvider.class);

@@ -18,7 +18,6 @@ public String authType() {
return AZURE_CLI;
}

@Override
public CliTokenSource tokenSourceFor(DatabricksConfig config, String resource) {
List<String> cmd =
new ArrayList<>(
@@ -72,7 +71,7 @@ public HeaderFactory configure(DatabricksConfig config) {
}

try {
ensureHostPresent(config, mapper);
AzureUtils.ensureHostPresent(config, mapper, this::tokenSourceFor);
String resource = config.getEffectiveAzureLoginAppId();
CliTokenSource tokenSource = tokenSourceFor(config, resource);
CliTokenSource mgmtTokenSource;
@@ -89,9 +88,9 @@ public HeaderFactory configure(DatabricksConfig config) {
Map<String, String> headers = new HashMap<>();
headers.put("Authorization", token.getTokenType() + " " + token.getAccessToken());
if (finalMgmtTokenSource != null) {
addSpManagementToken(finalMgmtTokenSource, headers);
AzureUtils.addSpManagementToken(finalMgmtTokenSource, headers);
}
return addWorkspaceResourceId(config, headers);
return AzureUtils.addWorkspaceResourceId(config, headers);
};
} catch (DatabricksException e) {
String stderr = e.getMessage();
AzureEnvironment.java
@@ -53,28 +53,28 @@ public String getActiveDirectoryEndpoint() {
ENVIRONMENTS.put(
"PUBLIC",
new AzureEnvironment(
"AzurePublicCloud",
"PUBLIC",
"https://management.core.windows.net/",
"https://management.azure.com/",
"https://login.microsoftonline.com/"));
ENVIRONMENTS.put(
"GERMAN",
new AzureEnvironment(
"AzureGermanCloud",
"GERMAN",
"https://management.core.cloudapi.de/",
"https://management.microsoftazure.de/",
"https://login.microsoftonline.de/"));
ENVIRONMENTS.put(
"USGOVERNMENT",
new AzureEnvironment(
"AzureUSGovernmentCloud",
"USGOVERNMENT",
"https://management.core.usgovcloudapi.net/",
"https://management.usgovcloudapi.net/",
"https://login.microsoftonline.us/"));
ENVIRONMENTS.put(
"CHINA",
new AzureEnvironment(
"AzureChinaCloud",
"CHINA",
"https://management.core.chinacloudapi.cn/",
"https://management.chinacloudapi.cn/",
"https://login.chinacloudapi.cn/"));
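
The display names are shortened to match the map keys, so a lookup key and the environment's reported name are now the same identifier. A small sketch, relying only on the `getEnvironment` and `getName` accessors that appear elsewhere in this diff:

```java
import com.databricks.sdk.core.AzureEnvironment;

public class AzureEnvironmentLookupExample {
  public static void main(String[] args) {
    // Lookup key and name now coincide: prints "PUBLIC" rather than
    // the previous "AzurePublicCloud".
    AzureEnvironment azurePublic = AzureEnvironment.getEnvironment("PUBLIC");
    System.out.println(azurePublic.getName());
  }
}
```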
DatabricksConfig.java
@@ -5,13 +5,13 @@
import com.databricks.sdk.core.http.Request;
import com.databricks.sdk.core.http.Response;
import com.databricks.sdk.core.oauth.OpenIDConnectEndpoints;
import com.databricks.sdk.core.utils.Cloud;
import com.databricks.sdk.core.utils.Environment;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.lang.reflect.Field;
import java.util.*;
import org.apache.http.HttpMessage;

public class DatabricksConfig {
@@ -83,9 +83,6 @@ public class DatabricksConfig {
@ConfigAttribute(env = "ARM_ENVIRONMENT")
private String azureEnvironment;

@ConfigAttribute(env = "DATABRICKS_AZURE_LOGIN_APP_ID", auth = "azure")
private String azureLoginAppId;

@ConfigAttribute(env = "DATABRICKS_CLI_PATH")
private String databricksCliPath;

@@ -124,6 +121,8 @@ public class DatabricksConfig {

private Environment env;

private DatabricksEnvironment databricksEnvironment;

public Environment getEnv() {
return env;
}
@@ -397,11 +396,7 @@ public DatabricksConfig setAzureEnvironment(String azureEnvironment) {
}

public String getEffectiveAzureLoginAppId() {
if (azureLoginAppId != null) {
return azureLoginAppId;
}

return AzureEnvironment.ARM_DATABRICKS_RESOURCE_ID;
return getDatabricksEnvironment().getAzureApplicationId();
}

public String getAuthType() {
@@ -468,15 +463,7 @@ public DatabricksConfig setHttpClient(HttpClient httpClient) {
}

public boolean isAzure() {
if (azureWorkspaceResourceId != null) {
return true;
}
if (host == null) {
return false;
}
return host.contains(".azuredatabricks.net")
|| host.contains("databricks.azure.cn")
|| host.contains(".databricks.azure.us");
return this.getDatabricksEnvironment().getCloud() == Cloud.AZURE;
}

public synchronized void authenticate(HttpMessage request) {
@@ -487,17 +474,11 @@ public synchronized void authenticate(HttpMessage request) {
}

public boolean isGcp() {
if (host == null) {
return false;
}
return host.contains(".gcp.databricks.com");
return this.getDatabricksEnvironment().getCloud() == Cloud.GCP;
}

public boolean isAws() {
if (host == null) {
return false;
}
return (!isAzure() && !isGcp());
return this.getDatabricksEnvironment().getCloud() == Cloud.AWS;
}

public boolean isAccountClient() {
@@ -540,4 +521,64 @@ public OpenIDConnectEndpoints getOidcEndpoints() throws IOException {
public String toString() {
return ConfigLoader.debugString(this);
}

public DatabricksConfig setDatabricksEnvironment(DatabricksEnvironment databricksEnvironment) {
this.databricksEnvironment = databricksEnvironment;
return this;
}

public DatabricksEnvironment getDatabricksEnvironment() {
ConfigLoader.fixHostIfNeeded(this);

if (this.databricksEnvironment != null) {
return this.databricksEnvironment;
}

if (this.host != null) {
for (DatabricksEnvironment env : DatabricksEnvironment.ALL_ENVIRONMENTS) {
if (this.host.endsWith(env.getDnsZone())) {
return env;
}
}
}

if (this.azureWorkspaceResourceId != null) {
String azureEnv = "PUBLIC";
if (this.azureEnvironment != null) {
azureEnv = this.azureEnvironment;
}
for (DatabricksEnvironment env : DatabricksEnvironment.ALL_ENVIRONMENTS) {
if (env.getCloud() != Cloud.AZURE) {
continue;
}
if (!env.getAzureEnvironment().getName().equals(azureEnv)) {
continue;
}
if (env.getDnsZone().startsWith(".dev") || env.getDnsZone().startsWith(".staging")) {
continue;
}
return env;
}
}

return DatabricksEnvironment.DEFAULT_ENVIRONMENT;
}

public DatabricksConfig newWithWorkspaceHost(String host) {
Set<String> fieldsToSkip =
new HashSet<>(Arrays.asList("host", "accountId", "azureWorkspaceResourceId"));
DatabricksConfig newConfig = new DatabricksConfig();
for (Field f : DatabricksConfig.class.getDeclaredFields()) {
if (fieldsToSkip.contains(f.getName())) {
continue;
}
try {
f.set(newConfig, f.get(this));
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
}
}
newConfig.setHost(host);
return newConfig;
}
}
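
A hedged sketch of the new copy-with-new-host behavior: every field except `host`, `accountId`, and `azureWorkspaceResourceId` is carried over, and cloud detection now goes through the resolved `DatabricksEnvironment`. The fluent setters and hostnames below are illustrative assumptions, not values from this diff.

```java
import com.databricks.sdk.core.DatabricksConfig;

public class NewWithWorkspaceHostExample {
  public static void main(String[] args) {
    // Hypothetical account-level configuration.
    DatabricksConfig accountConfig = new DatabricksConfig()
        .setHost("https://accounts.cloud.databricks.com")
        .setAccountId("abc-123")                  // hypothetical account ID
        .setClientId("my-oauth-client-id")        // hypothetical OAuth client
        .setClientSecret("my-oauth-client-secret");

    // Copy every field except host, accountId and azureWorkspaceResourceId,
    // then point the copy at a single workspace deployment.
    DatabricksConfig workspaceConfig =
        accountConfig.newWithWorkspaceHost("https://my-workspace.cloud.databricks.com");

    System.out.println(workspaceConfig.getHost());      // the workspace host
    System.out.println(workspaceConfig.getClientId());  // copied from the account config
    System.out.println(workspaceConfig.isAws());        // true: resolved via the environment's DNS zone
  }
}
```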
DatabricksEnvironment.java (new file)
@@ -0,0 +1,82 @@
package com.databricks.sdk.core;

import com.databricks.sdk.core.utils.Cloud;
import java.util.Arrays;
import java.util.List;

public class DatabricksEnvironment {
private Cloud cloud;
private String dnsZone;
private String azureApplicationId;
private AzureEnvironment azureEnvironment;

private DatabricksEnvironment(Cloud cloud, String dnsZone) {
this(cloud, dnsZone, null, null);
}

private DatabricksEnvironment(
Cloud cloud, String dnsZone, String azureApplicationId, AzureEnvironment azureEnvironment) {
this.cloud = cloud;
this.dnsZone = dnsZone;
this.azureApplicationId = azureApplicationId;
this.azureEnvironment = azureEnvironment;
}

public Cloud getCloud() {
return cloud;
}

public String getDnsZone() {
return dnsZone;
}

public String getAzureApplicationId() {
return azureApplicationId;
}

public AzureEnvironment getAzureEnvironment() {
return azureEnvironment;
}

public String getDeploymentUrl(String name) {
return String.format("https://%s%s", name, dnsZone);
}

public static final DatabricksEnvironment DEFAULT_ENVIRONMENT =
new DatabricksEnvironment(Cloud.AWS, ".cloud.databricks.com");

public static final List<DatabricksEnvironment> ALL_ENVIRONMENTS =
Arrays.asList(
new DatabricksEnvironment(Cloud.AWS, ".dev.databricks.com"),
new DatabricksEnvironment(Cloud.AWS, ".staging.cloud.databricks.com"),
new DatabricksEnvironment(Cloud.AWS, ".cloud.databricks.us"),
DEFAULT_ENVIRONMENT,
new DatabricksEnvironment(
Cloud.AZURE,
".dev.azuredatabricks.net",
"62a912ac-b58e-4c1d-89ea-b2dbfc7358fc",
AzureEnvironment.getEnvironment("PUBLIC")),
new DatabricksEnvironment(
Cloud.AZURE,
".staging.azuredatabricks.net",
"4a67d088-db5c-48f1-9ff2-0aace800ae68",
AzureEnvironment.getEnvironment("PUBLIC")),
new DatabricksEnvironment(
Cloud.AZURE,
".azuredatabricks.net",
"2ff814a6-3304-4ab8-85cb-cd0e6f879c1d",
AzureEnvironment.getEnvironment("PUBLIC")),
new DatabricksEnvironment(
Cloud.AZURE,
".databricks.azure.us",
"2ff814a6-3304-4ab8-85cb-cd0e6f879c1d",
AzureEnvironment.getEnvironment("USGOVERNMENT")),
new DatabricksEnvironment(
Cloud.AZURE,
".databricks.azure.cn",
"2ff814a6-3304-4ab8-85cb-cd0e6f879c1d",
AzureEnvironment.getEnvironment("CHINA")),
new DatabricksEnvironment(Cloud.GCP, ".dev.gcp.databricks.com"),
new DatabricksEnvironment(Cloud.GCP, ".staging.gcp.databricks.com"),
new DatabricksEnvironment(Cloud.GCP, ".gcp.databricks.com"));
}
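
A short sketch of how an environment turns a deployment name into a workspace URL, which is what `AccountClient.getWorkspaceClient` relies on above; only members defined in this new file are used.

```java
import com.databricks.sdk.core.DatabricksEnvironment;
import com.databricks.sdk.core.utils.Cloud;

public class DatabricksEnvironmentExample {
  public static void main(String[] args) {
    // The default environment is AWS production, DNS zone ".cloud.databricks.com".
    DatabricksEnvironment env = DatabricksEnvironment.DEFAULT_ENVIRONMENT;

    System.out.println(env.getCloud() == Cloud.AWS);          // true
    System.out.println(env.getDeploymentUrl("my-workspace"));
    // -> https://my-workspace.cloud.databricks.com
  }
}
```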