[Exporter] Improve exporting of databricks_model_serving (#3821)
## Changes

This change improves exporting of `databricks_model_serving` resources:

- emitting references to UC models
- emitting references to the UC catalog and schema used for auto-capture
- emitting secrets found in environment variables and in external model configurations
- generating some required fields even when they have zero values
- expanding test coverage

## Tests

- [x] `make test` run locally
- [ ] relevant change in `docs/` folder
- [ ] covered with integration tests in `internal/acceptance`
- [ ] relevant acceptance tests are passing
- [ ] using Go SDK
alexott authored Jul 26, 2024 (1 parent: a75696a, commit: 1df8285)
Showing 3 changed files with 177 additions and 42 deletions.
52 changes: 43 additions & 9 deletions exporter/exporter_test.go
@@ -2309,18 +2309,36 @@ func TestImportingModelServing(t *testing.T) {
 				Name: "abc",
 				Id:   "1234",
 				Config: &serving.EndpointCoreConfigOutput{
-					ServedModels: []serving.ServedModelOutput{
-						{
-							ModelName:    "def",
-							ModelVersion: "1",
-							Name:         "def",
-						},
+					AutoCaptureConfig: &serving.AutoCaptureConfigOutput{
+						Enabled:         true,
+						CatalogName:     "main",
+						SchemaName:      "tmp",
+						TableNamePrefix: "test",
 					},
 					ServedEntities: []serving.ServedEntityOutput{
 						{
-							EntityName:    "def",
-							EntityVersion: "1",
-							Name:          "def",
+							EntityName:         "main.tmp.model",
+							EntityVersion:      "1",
+							Name:               "def",
+							ScaleToZeroEnabled: true,
+						},
+						{
+							EntityName:         "def",
+							EntityVersion:      "1",
+							Name:               "def",
+							ScaleToZeroEnabled: false,
+							InstanceProfileArn: "arn:aws:iam::123456789012:instance-profile/MyInstanceProfile",
+						},
+						{
+							ExternalModel: &serving.ExternalModel{
+								Provider: "databricks",
+								Task:     "llm/v1/embeddings",
+								Name:     "e5_small_v2",
+								DatabricksModelServingConfig: &serving.DatabricksModelServingConfig{
+									DatabricksApiToken:     "dapi",
+									DatabricksWorkspaceUrl: "https://adb-1234.azuredatabricks.net",
+								},
+							},
 						},
 					},
 				},
@@ -2334,9 +2352,25 @@ func TestImportingModelServing(t *testing.T) {
 		ic := newImportContext(client)
 		ic.Directory = tmpDir
 		ic.enableListing("model-serving")
+		ic.enableServices("model-serving")
 
 		err := ic.Run()
 		assert.NoError(t, err)
+
+		content, err := os.ReadFile(tmpDir + "/model-serving.tf")
+		assert.NoError(t, err)
+		contentStr := string(content)
+		assert.True(t, strings.Contains(contentStr, `resource "databricks_model_serving" "abc_90015098"`))
+		assert.True(t, strings.Contains(contentStr, `scale_to_zero_enabled = false`))
+		assert.True(t, strings.Contains(contentStr, `instance_profile_arn = "arn:aws:iam::123456789012:instance-profile/MyInstanceProfile"`))
+		assert.True(t, strings.Contains(contentStr, `databricks_api_token = "dapi"`))
+		assert.True(t, strings.Contains(contentStr, `databricks_workspace_url = "https://adb-1234.azuredatabricks.net"`))
+		assert.True(t, strings.Contains(contentStr, `served_entities {
+scale_to_zero_enabled = true
+name = "def"
+entity_version = "1"
+entity_name = "main.tmp.model"
+}`))
 	})
 }

126 changes: 104 additions & 22 deletions exporter/importables.go
@@ -20,6 +20,7 @@ import (
 	"github.com/databricks/databricks-sdk-go/service/iam"
 	sdk_jobs "github.com/databricks/databricks-sdk-go/service/jobs"
 	"github.com/databricks/databricks-sdk-go/service/ml"
+	"github.com/databricks/databricks-sdk-go/service/serving"
 	"github.com/databricks/databricks-sdk-go/service/settings"
 	"github.com/databricks/databricks-sdk-go/service/sharing"
 	"github.com/databricks/databricks-sdk-go/service/sql"
@@ -46,21 +47,23 @@
 )
 
 var (
 	adlsGen2Regex                    = regexp.MustCompile(`^(abfss?)://([^@]+)@([^.]+)\.(?:[^/]+)(/.*)?$`)
 	adlsGen1Regex                    = regexp.MustCompile(`^(adls?)://([^.]+)\.(?:[^/]+)(/.*)?$`)
 	wasbsRegex                       = regexp.MustCompile(`^(wasbs?)://([^@]+)@([^.]+)\.(?:[^/]+)(/.*)?$`)
 	s3Regex                          = regexp.MustCompile(`^(s3a?)://([^/]+)(/.*)?$`)
 	gsRegex                          = regexp.MustCompile(`^gs://([^/]+)(/.*)?$`)
 	globalWorkspaceConfName          = "global_workspace_conf"
 	nameNormalizationRegex           = regexp.MustCompile(`\W+`)
 	fileNameNormalizationRegex       = regexp.MustCompile(`[^-_\w/.@]`)
 	jobClustersRegex                 = regexp.MustCompile(`^((job_cluster|task)\.\d+\.new_cluster\.\d+\.)`)
 	dltClusterRegex                  = regexp.MustCompile(`^(cluster\.\d+\.)`)
 	secretPathRegex                  = regexp.MustCompile(`^\{\{secrets\/([^\/]+)\/([^}]+)\}\}$`)
 	sqlParentRegexp                  = regexp.MustCompile(`^folders/(\d+)$`)
 	dltDefaultStorageRegex           = regexp.MustCompile(`^dbfs:/pipelines/[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
 	ignoreIdeFolderRegex             = regexp.MustCompile(`^/Users/[^/]+/\.ide/.*$`)
+	servedEntityFieldExtractionRegex = regexp.MustCompile(`^config\.[0-9]+\.served_entities\.([0-9]+)\.(.*)$`)
+	uc3LevelIdRegex                  = regexp.MustCompile(`^([^.]+\.[^.]+\.[^.]+)$`)
 	fileExtensionLanguageMapping     = map[string]string{
 		"SCALA":  ".scala",
 		"PYTHON": ".py",
 		"SQL":    ".sql",
@@ -2012,11 +2015,11 @@ var resourcesMap map[string]importable = map[string]importable{
 					})
 				}
 				ic.emitInitScriptsLegacy(cluster.InitScripts)
-				ic.emitSecretsFromSecretsPath(cluster.SparkConf)
-				ic.emitSecretsFromSecretsPath(cluster.SparkEnvVars)
+				ic.emitSecretsFromSecretsPathMap(cluster.SparkConf)
+				ic.emitSecretsFromSecretsPathMap(cluster.SparkEnvVars)
 			}
 			ic.emitFilesFromMap(pipeline.Configuration)
-			ic.emitSecretsFromSecretsPath(pipeline.Configuration)
+			ic.emitSecretsFromSecretsPathMap(pipeline.Configuration)
 			ic.emitPermissionsIfNotIgnored(r, fmt.Sprintf("/pipelines/%s", r.ID),
 				"pipeline_"+ic.Importables["databricks_pipeline"].Name(ic, r.Data))
 			return nil
@@ -2145,7 +2148,6 @@ var resourcesMap map[string]importable = map[string]importable{
 			if err != nil {
 				return err
 			}
-
 			for offset, endpoint := range endpointsList {
 				ic.EmitIfUpdatedAfterMillis(&resource{
 					Resource: "databricks_model_serving",
@@ -2160,16 +2162,96 @@
 		Import: func(ic *importContext, r *resource) error {
 			ic.emitPermissionsIfNotIgnored(r, fmt.Sprintf("/serving-endpoints/%s", r.Data.Get("serving_endpoint_id").(string)),
 				"serving_endpoint_"+ic.Importables["databricks_model_serving"].Name(ic, r.Data))
+			s := ic.Resources["databricks_model_serving"].Schema
+			var mse serving.CreateServingEndpoint
+			common.DataToStructPointer(r.Data, s, &mse)
+			if mse.Config.ServedEntities != nil {
+				for _, se := range mse.Config.ServedEntities {
+					if se.EntityName != "" {
+						if se.EntityVersion != "" { // we have a UC model or a model from the model registry
+							if uc3LevelIdRegex.MatchString(se.EntityName) {
+								ic.Emit(&resource{
+									Resource: "databricks_registered_model",
+									ID:       se.EntityName,
+								})
+							}
+							// TODO: add else branch to emit databricks_model when we have support for it
+						}
+						// TODO: add else branch to emit UC functions when we add support for them...
+					}
+					if se.InstanceProfileArn != "" {
+						ic.Emit(&resource{
+							Resource: "databricks_instance_profile",
+							ID:       se.InstanceProfileArn,
+						})
+					}
+					ic.emitSecretsFromSecretsPathMap(se.EnvironmentVars)
+					if se.ExternalModel != nil {
+						if se.ExternalModel.DatabricksModelServingConfig != nil {
+							ic.emitSecretsFromSecretPathString(se.ExternalModel.DatabricksModelServingConfig.DatabricksApiToken)
+						}
+						if se.ExternalModel.Ai21labsConfig != nil {
+							ic.emitSecretsFromSecretPathString(se.ExternalModel.Ai21labsConfig.Ai21labsApiKey)
+						}
+						if se.ExternalModel.AnthropicConfig != nil {
+							ic.emitSecretsFromSecretPathString(se.ExternalModel.AnthropicConfig.AnthropicApiKey)
+						}
+						if se.ExternalModel.AmazonBedrockConfig != nil {
+							ic.emitSecretsFromSecretPathString(se.ExternalModel.AmazonBedrockConfig.AwsAccessKeyId)
+							ic.emitSecretsFromSecretPathString(se.ExternalModel.AmazonBedrockConfig.AwsSecretAccessKey)
+						}
+						if se.ExternalModel.CohereConfig != nil {
+							ic.emitSecretsFromSecretPathString(se.ExternalModel.CohereConfig.CohereApiKey)
+						}
+						if se.ExternalModel.OpenaiConfig != nil {
+							ic.emitSecretsFromSecretPathString(se.ExternalModel.OpenaiConfig.OpenaiApiKey)
+						}
+						if se.ExternalModel.PalmConfig != nil {
+							ic.emitSecretsFromSecretPathString(se.ExternalModel.PalmConfig.PalmApiKey)
+						}
+					}
+				}
+			}
+			if mse.Config.AutoCaptureConfig != nil && mse.Config.AutoCaptureConfig.CatalogName != "" &&
+				mse.Config.AutoCaptureConfig.SchemaName != "" {
+				ic.Emit(&resource{
+					Resource: "databricks_schema",
+					ID:       mse.Config.AutoCaptureConfig.CatalogName + "." + mse.Config.AutoCaptureConfig.SchemaName,
+				})
+			}
 			return nil
 		},
 		ShouldOmitField: func(ic *importContext, pathString string, as *schema.Schema, d *schema.ResourceData) bool {
-			if pathString == "config.0.traffic_config" ||
-				(strings.HasPrefix(pathString, "config.0.served_models.") &&
-					strings.HasSuffix(pathString, ".scale_to_zero_enabled")) {
+			if pathString == "config.0.traffic_config" || pathString == "config.0.auto_capture_config.0.enabled" ||
+				(pathString == "config.0.auto_capture_config.0.table_name_prefix" && d.Get(pathString).(string) != "") {
 				return false
 			}
+			if res := servedEntityFieldExtractionRegex.FindStringSubmatch(pathString); res != nil {
+				field := res[2]
+				log.Printf("[DEBUG] ShouldOmitField: extracted field from %s: '%s'", pathString, field)
+				switch field {
+				case "scale_to_zero_enabled", "name":
+					return false
+				case "workload_size", "workload_type":
+					return d.Get(pathString).(string) == ""
+				}
+			}
 			return defaultShouldOmitFieldFunc(ic, pathString, as, d)
 		},
+		ShouldGenerateField: func(ic *importContext, pathString string, as *schema.Schema, d *schema.ResourceData) bool {
+			// We need to generate some fields even if they have zero value...
+			if strings.HasSuffix(pathString, ".scale_to_zero_enabled") {
+				extModelBlockCoordinate := strings.Replace(pathString, ".scale_to_zero_enabled", ".external_model", 1)
+				return d.Get(extModelBlockCoordinate+".#").(int) == 0
+			}
+			return pathString == "config.0.auto_capture_config.0.enabled"
+		},
+		Depends: []reference{
+			{Path: "config.served_entities.entity_name", Resource: "databricks_registered_model"},
+			{Path: "config.auto_capture_config.catalog_name", Resource: "databricks_catalog"},
+			{Path: "config.auto_capture_config.schema_name", Resource: "databricks_schema", Match: "name",
+				IsValidApproximation: isMatchingCatalogAndSchemaInModelServing, SkipDirectLookup: true},
+		},
 	},
 	"databricks_mlflow_webhook": {
 		WorkspaceLevel: true,
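The `ShouldOmitField` override above is the subtle part of this change: `scale_to_zero_enabled` and `name` are always kept, even at their zero values, while `workload_size` and `workload_type` are kept only when set. A minimal sketch of that dispatch, decoupled from Terraform's `schema.ResourceData` (`keepServedEntityField` is a hypothetical helper, not part of the provider):

```go
package main

import (
	"fmt"
	"regexp"
)

// Copied from the diff above.
var servedEntityFieldExtractionRegex = regexp.MustCompile(`^config\.[0-9]+\.served_entities\.([0-9]+)\.(.*)$`)

// keepServedEntityField is a simplified, hypothetical stand-in for the
// ShouldOmitField logic: it reports whether a served-entity field should
// appear in the generated HCL, given its flattened path and string value.
// The real code also falls back to defaultShouldOmitFieldFunc for other fields.
func keepServedEntityField(pathString, value string) bool {
	res := servedEntityFieldExtractionRegex.FindStringSubmatch(pathString)
	if res == nil {
		return false // not a served-entity field
	}
	switch res[2] {
	case "scale_to_zero_enabled", "name":
		return true // always emitted, even at zero value
	case "workload_size", "workload_type":
		return value != "" // emitted only when explicitly set
	}
	return false
}

func main() {
	fmt.Println(keepServedEntityField("config.0.served_entities.0.scale_to_zero_enabled", "false")) // true
	fmt.Println(keepServedEntityField("config.0.served_entities.0.workload_size", ""))              // false
}
```

The companion `ShouldGenerateField` hook relies on Terraform's flatmap count convention: `external_model.#` is the number of `external_model` blocks, so `scale_to_zero_enabled` is force-generated only for served entities that are not external models.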
41 changes: 30 additions & 11 deletions exporter/util.go
@@ -118,8 +118,8 @@ func (ic *importContext) importClusterLegacy(c *clusters.Cluster) {
 		})
 	}
 	ic.emitInitScriptsLegacy(c.InitScripts)
-	ic.emitSecretsFromSecretsPath(c.SparkConf)
-	ic.emitSecretsFromSecretsPath(c.SparkEnvVars)
+	ic.emitSecretsFromSecretsPathMap(c.SparkConf)
+	ic.emitSecretsFromSecretsPathMap(c.SparkEnvVars)
 	ic.emitUserOrServicePrincipal(c.SingleUserName)
 }
@@ -153,19 +153,23 @@ func (ic *importContext) importCluster(c *compute.ClusterSpec) {
 		})
 	}
 	ic.emitInitScripts(c.InitScripts)
-	ic.emitSecretsFromSecretsPath(c.SparkConf)
-	ic.emitSecretsFromSecretsPath(c.SparkEnvVars)
+	ic.emitSecretsFromSecretsPathMap(c.SparkConf)
+	ic.emitSecretsFromSecretsPathMap(c.SparkEnvVars)
 	ic.emitUserOrServicePrincipal(c.SingleUserName)
 }
 
-func (ic *importContext) emitSecretsFromSecretsPath(m map[string]string) {
+func (ic *importContext) emitSecretsFromSecretPathString(v string) {
+	if res := secretPathRegex.FindStringSubmatch(v); res != nil {
+		ic.Emit(&resource{
+			Resource: "databricks_secret_scope",
+			ID:       res[1],
+		})
+	}
+}
+
+func (ic *importContext) emitSecretsFromSecretsPathMap(m map[string]string) {
 	for _, v := range m {
-		if res := secretPathRegex.FindStringSubmatch(v); res != nil {
-			ic.Emit(&resource{
-				Resource: "databricks_secret_scope",
-				ID:       res[1],
-			})
-		}
+		ic.emitSecretsFromSecretPathString(v)
 	}
 }

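Both helpers key off `secretPathRegex`, which recognizes Databricks secret references of the form `{{secrets/<scope>/<key>}}`. A quick standalone illustration (the scope and key names are invented):

```go
package main

import (
	"fmt"
	"regexp"
)

// Copied from the exporter's pattern shown earlier in this commit.
var secretPathRegex = regexp.MustCompile(`^\{\{secrets\/([^\/]+)\/([^}]+)\}\}$`)

func main() {
	// A value like this in an environment variable or Spark conf...
	v := "{{secrets/prod-scope/openai-key}}"
	if res := secretPathRegex.FindStringSubmatch(v); res != nil {
		// ...causes the exporter to emit the owning databricks_secret_scope.
		fmt.Println("emit databricks_secret_scope with ID:", res[1]) // "prod-scope"
	}
}
```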
@@ -1436,6 +1440,21 @@ func isMatchingCatalogAndSchema(ic *importContext, res *resource, ra *resourceApproximation, origPath string) bool {
 	return result
 }
 
+func isMatchingCatalogAndSchemaInModelServing(ic *importContext, res *resource, ra *resourceApproximation, origPath string) bool {
+	res_catalog_name := res.Data.Get("config.0.auto_capture_config.0.catalog_name").(string)
+	res_schema_name := res.Data.Get("config.0.auto_capture_config.0.schema_name").(string)
+	ra_catalog_name, cat_found := ra.Get("catalog_name")
+	ra_schema_name, schema_found := ra.Get("name")
+	if !cat_found || !schema_found {
+		log.Printf("[WARN] Can't find attributes in approximation: %s %s, catalog='%v' (found? %v) schema='%v' (found? %v). Resource: %s, catalog='%s', schema='%s'",
+			ra.Type, ra.Name, ra_catalog_name, cat_found, ra_schema_name, schema_found, res.Resource, res_catalog_name, res_schema_name)
+		return true
+	}
+
+	result := ra_catalog_name.(string) == res_catalog_name && ra_schema_name.(string) == res_schema_name
+	return result
+}
+
 func isMatchingShareRecipient(ic *importContext, res *resource, ra *resourceApproximation, origPath string) bool {
 	shareName, ok := res.Data.GetOk("share")
 	// principal := res.Data.Get(origPath)