diff --git a/exporter/exporter_test.go b/exporter/exporter_test.go
index e07d4c881e..ca07f3c6db 100644
--- a/exporter/exporter_test.go
+++ b/exporter/exporter_test.go
@@ -2309,18 +2309,36 @@ func TestImportingModelServing(t *testing.T) {
                         Name: "abc",
                         Id:   "1234",
                         Config: &serving.EndpointCoreConfigOutput{
-                            ServedModels: []serving.ServedModelOutput{
-                                {
-                                    ModelName:    "def",
-                                    ModelVersion: "1",
-                                    Name:         "def",
-                                },
+                            AutoCaptureConfig: &serving.AutoCaptureConfigOutput{
+                                Enabled:         true,
+                                CatalogName:     "main",
+                                SchemaName:      "tmp",
+                                TableNamePrefix: "test",
                             },
                             ServedEntities: []serving.ServedEntityOutput{
                                 {
-                                    EntityName:    "def",
-                                    EntityVersion: "1",
-                                    Name:          "def",
+                                    EntityName:         "main.tmp.model",
+                                    EntityVersion:      "1",
+                                    Name:               "def",
+                                    ScaleToZeroEnabled: true,
+                                },
+                                {
+                                    EntityName:         "def",
+                                    EntityVersion:      "1",
+                                    Name:               "def",
+                                    ScaleToZeroEnabled: false,
+                                    InstanceProfileArn: "arn:aws:iam::123456789012:instance-profile/MyInstanceProfile",
+                                },
+                                {
+                                    ExternalModel: &serving.ExternalModel{
+                                        Provider: "databricks",
+                                        Task:     "llm/v1/embeddings",
+                                        Name:     "e5_small_v2",
+                                        DatabricksModelServingConfig: &serving.DatabricksModelServingConfig{
+                                            DatabricksApiToken:     "dapi",
+                                            DatabricksWorkspaceUrl: "https://adb-1234.azuredatabricks.net",
+                                        },
+                                    },
                                 },
                             },
                         },
@@ -2334,9 +2352,25 @@ func TestImportingModelServing(t *testing.T) {
 
             ic := newImportContext(client)
             ic.Directory = tmpDir
             ic.enableListing("model-serving")
+            ic.enableServices("model-serving")
 
             err := ic.Run()
             assert.NoError(t, err)
+
+            content, err := os.ReadFile(tmpDir + "/model-serving.tf")
+            assert.NoError(t, err)
+            contentStr := string(content)
+            assert.True(t, strings.Contains(contentStr, `resource "databricks_model_serving" "abc_90015098"`))
+            assert.True(t, strings.Contains(contentStr, `scale_to_zero_enabled = false`))
+            assert.True(t, strings.Contains(contentStr, `instance_profile_arn = "arn:aws:iam::123456789012:instance-profile/MyInstanceProfile"`))
+            assert.True(t, strings.Contains(contentStr, `databricks_api_token = "dapi"`))
+            assert.True(t, strings.Contains(contentStr, `databricks_workspace_url = "https://adb-1234.azuredatabricks.net"`))
+            assert.True(t, strings.Contains(contentStr, `served_entities {
+    scale_to_zero_enabled = true
+    name = "def"
+    entity_version = "1"
+    entity_name = "main.tmp.model"
+  }`))
         })
 }
diff --git a/exporter/importables.go b/exporter/importables.go
index 0948f9a601..a32c84e9d3 100644
--- a/exporter/importables.go
+++ b/exporter/importables.go
@@ -20,6 +20,7 @@ import (
     "github.com/databricks/databricks-sdk-go/service/iam"
     sdk_jobs "github.com/databricks/databricks-sdk-go/service/jobs"
     "github.com/databricks/databricks-sdk-go/service/ml"
+    "github.com/databricks/databricks-sdk-go/service/serving"
     "github.com/databricks/databricks-sdk-go/service/settings"
     "github.com/databricks/databricks-sdk-go/service/sharing"
     "github.com/databricks/databricks-sdk-go/service/sql"
@@ -46,21 +47,23 @@ import (
 )
 
 var (
-    adlsGen2Regex                = regexp.MustCompile(`^(abfss?)://([^@]+)@([^.]+)\.(?:[^/]+)(/.*)?$`)
-    adlsGen1Regex                = regexp.MustCompile(`^(adls?)://([^.]+)\.(?:[^/]+)(/.*)?$`)
-    wasbsRegex                   = regexp.MustCompile(`^(wasbs?)://([^@]+)@([^.]+)\.(?:[^/]+)(/.*)?$`)
-    s3Regex                      = regexp.MustCompile(`^(s3a?)://([^/]+)(/.*)?$`)
-    gsRegex                      = regexp.MustCompile(`^gs://([^/]+)(/.*)?$`)
-    globalWorkspaceConfName      = "global_workspace_conf"
-    nameNormalizationRegex       = regexp.MustCompile(`\W+`)
-    fileNameNormalizationRegex   = regexp.MustCompile(`[^-_\w/.@]`)
-    jobClustersRegex             = regexp.MustCompile(`^((job_cluster|task)\.\d+\.new_cluster\.\d+\.)`)
-    dltClusterRegex              = regexp.MustCompile(`^(cluster\.\d+\.)`)
-    secretPathRegex              = regexp.MustCompile(`^\{\{secrets\/([^\/]+)\/([^}]+)\}\}$`)
-    sqlParentRegexp              = regexp.MustCompile(`^folders/(\d+)$`)
-    dltDefaultStorageRegex       = regexp.MustCompile(`^dbfs:/pipelines/[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
-    ignoreIdeFolderRegex         = regexp.MustCompile(`^/Users/[^/]+/\.ide/.*$`)
-    fileExtensionLanguageMapping = map[string]string{
+    adlsGen2Regex                    = regexp.MustCompile(`^(abfss?)://([^@]+)@([^.]+)\.(?:[^/]+)(/.*)?$`)
+    adlsGen1Regex                    = regexp.MustCompile(`^(adls?)://([^.]+)\.(?:[^/]+)(/.*)?$`)
+    wasbsRegex                       = regexp.MustCompile(`^(wasbs?)://([^@]+)@([^.]+)\.(?:[^/]+)(/.*)?$`)
+    s3Regex                          = regexp.MustCompile(`^(s3a?)://([^/]+)(/.*)?$`)
+    gsRegex                          = regexp.MustCompile(`^gs://([^/]+)(/.*)?$`)
+    globalWorkspaceConfName          = "global_workspace_conf"
+    nameNormalizationRegex           = regexp.MustCompile(`\W+`)
+    fileNameNormalizationRegex       = regexp.MustCompile(`[^-_\w/.@]`)
+    jobClustersRegex                 = regexp.MustCompile(`^((job_cluster|task)\.\d+\.new_cluster\.\d+\.)`)
+    dltClusterRegex                  = regexp.MustCompile(`^(cluster\.\d+\.)`)
+    secretPathRegex                  = regexp.MustCompile(`^\{\{secrets\/([^\/]+)\/([^}]+)\}\}$`)
+    sqlParentRegexp                  = regexp.MustCompile(`^folders/(\d+)$`)
+    dltDefaultStorageRegex           = regexp.MustCompile(`^dbfs:/pipelines/[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
+    ignoreIdeFolderRegex             = regexp.MustCompile(`^/Users/[^/]+/\.ide/.*$`)
+    servedEntityFieldExtractionRegex = regexp.MustCompile(`^config\.[0-9]+\.served_entities\.([0-9]+)\.(.*)$`)
+    uc3LevelIdRegex                  = regexp.MustCompile(`^([^.]+\.[^.]+\.[^.]+)$`)
+    fileExtensionLanguageMapping     = map[string]string{
         "SCALA":  ".scala",
         "PYTHON": ".py",
         "SQL":    ".sql",
@@ -2012,11 +2015,11 @@ var resourcesMap map[string]importable = map[string]importable{
                     })
                 }
                 ic.emitInitScriptsLegacy(cluster.InitScripts)
-                ic.emitSecretsFromSecretsPath(cluster.SparkConf)
-                ic.emitSecretsFromSecretsPath(cluster.SparkEnvVars)
+                ic.emitSecretsFromSecretsPathMap(cluster.SparkConf)
+                ic.emitSecretsFromSecretsPathMap(cluster.SparkEnvVars)
             }
             ic.emitFilesFromMap(pipeline.Configuration)
-            ic.emitSecretsFromSecretsPath(pipeline.Configuration)
+            ic.emitSecretsFromSecretsPathMap(pipeline.Configuration)
             ic.emitPermissionsIfNotIgnored(r, fmt.Sprintf("/pipelines/%s", r.ID),
                 "pipeline_"+ic.Importables["databricks_pipeline"].Name(ic, r.Data))
             return nil
@@ -2145,7 +2148,6 @@ var resourcesMap map[string]importable = map[string]importable{
             if err != nil {
                 return err
             }
-
             for offset, endpoint := range endpointsList {
                 ic.EmitIfUpdatedAfterMillis(&resource{
                     Resource: "databricks_model_serving",
@@ -2160,16 +2162,96 @@ var resourcesMap map[string]importable = map[string]importable{
         Import: func(ic *importContext, r *resource) error {
             ic.emitPermissionsIfNotIgnored(r, fmt.Sprintf("/serving-endpoints/%s", r.Data.Get("serving_endpoint_id").(string)),
                 "serving_endpoint_"+ic.Importables["databricks_model_serving"].Name(ic, r.Data))
+            s := ic.Resources["databricks_model_serving"].Schema
+            var mse serving.CreateServingEndpoint
+            common.DataToStructPointer(r.Data, s, &mse)
+            if mse.Config.ServedEntities != nil {
+                for _, se := range mse.Config.ServedEntities {
+                    if se.EntityName != "" {
+                        if se.EntityVersion != "" { // we have a UC model or a model from the model registry
+                            if uc3LevelIdRegex.MatchString(se.EntityName) {
+                                ic.Emit(&resource{
+                                    Resource: "databricks_registered_model",
+                                    ID:       se.EntityName,
+                                })
+                            }
+                            // TODO: add else branch to emit databricks_model when we have support for it
+                        }
+                        // TODO: add else branch to emit UC functions when we add support for them...
+                    }
+                    if se.InstanceProfileArn != "" {
+                        ic.Emit(&resource{
+                            Resource: "databricks_instance_profile",
+                            ID:       se.InstanceProfileArn,
+                        })
+                    }
+                    ic.emitSecretsFromSecretsPathMap(se.EnvironmentVars)
+                    if se.ExternalModel != nil {
+                        if se.ExternalModel.DatabricksModelServingConfig != nil {
+                            ic.emitSecretsFromSecretPathString(se.ExternalModel.DatabricksModelServingConfig.DatabricksApiToken)
+                        }
+                        if se.ExternalModel.Ai21labsConfig != nil {
+                            ic.emitSecretsFromSecretPathString(se.ExternalModel.Ai21labsConfig.Ai21labsApiKey)
+                        }
+                        if se.ExternalModel.AnthropicConfig != nil {
+                            ic.emitSecretsFromSecretPathString(se.ExternalModel.AnthropicConfig.AnthropicApiKey)
+                        }
+                        if se.ExternalModel.AmazonBedrockConfig != nil {
+                            ic.emitSecretsFromSecretPathString(se.ExternalModel.AmazonBedrockConfig.AwsAccessKeyId)
+                            ic.emitSecretsFromSecretPathString(se.ExternalModel.AmazonBedrockConfig.AwsSecretAccessKey)
+                        }
+                        if se.ExternalModel.CohereConfig != nil {
+                            ic.emitSecretsFromSecretPathString(se.ExternalModel.CohereConfig.CohereApiKey)
+                        }
+                        if se.ExternalModel.OpenaiConfig != nil {
+                            ic.emitSecretsFromSecretPathString(se.ExternalModel.OpenaiConfig.OpenaiApiKey)
+                        }
+                        if se.ExternalModel.PalmConfig != nil {
+                            ic.emitSecretsFromSecretPathString(se.ExternalModel.PalmConfig.PalmApiKey)
+                        }
+                    }
+                }
+            }
+            if mse.Config.AutoCaptureConfig != nil && mse.Config.AutoCaptureConfig.CatalogName != "" &&
+                mse.Config.AutoCaptureConfig.SchemaName != "" {
+                ic.Emit(&resource{
+                    Resource: "databricks_schema",
+                    ID:       mse.Config.AutoCaptureConfig.CatalogName + "." + mse.Config.AutoCaptureConfig.SchemaName,
+                })
+            }
             return nil
         },
         ShouldOmitField: func(ic *importContext, pathString string, as *schema.Schema, d *schema.ResourceData) bool {
-            if pathString == "config.0.traffic_config" ||
-                (strings.HasPrefix(pathString, "config.0.served_models.") &&
-                    strings.HasSuffix(pathString, ".scale_to_zero_enabled")) {
+            if pathString == "config.0.traffic_config" || pathString == "config.0.auto_capture_config.0.enabled" ||
+                (pathString == "config.0.auto_capture_config.0.table_name_prefix" && d.Get(pathString).(string) != "") {
                 return false
             }
+            if res := servedEntityFieldExtractionRegex.FindStringSubmatch(pathString); res != nil {
+                field := res[2]
+                log.Printf("[DEBUG] ShouldOmitField: extracted field from %s: '%s'", pathString, field)
+                switch field {
+                case "scale_to_zero_enabled", "name":
+                    return false
+                case "workload_size", "workload_type":
+                    return d.Get(pathString).(string) == ""
+                }
+            }
             return defaultShouldOmitFieldFunc(ic, pathString, as, d)
         },
+        ShouldGenerateField: func(ic *importContext, pathString string, as *schema.Schema, d *schema.ResourceData) bool {
+            // We need to generate some fields even if they have zero value...
+            if strings.HasSuffix(pathString, ".scale_to_zero_enabled") {
+                extModelBlockCoordinate := strings.Replace(pathString, ".scale_to_zero_enabled", ".external_model", 1)
+                return d.Get(extModelBlockCoordinate+".#").(int) == 0
+            }
+            return pathString == "config.0.auto_capture_config.0.enabled"
+        },
+        Depends: []reference{
+            {Path: "config.served_entities.entity_name", Resource: "databricks_registered_model"},
+            {Path: "config.auto_capture_config.catalog_name", Resource: "databricks_catalog"},
+            {Path: "config.auto_capture_config.schema_name", Resource: "databricks_schema", Match: "name",
+                IsValidApproximation: isMatchingCatalogAndSchemaInModelServing, SkipDirectLookup: true},
+        },
     },
     "databricks_mlflow_webhook": {
         WorkspaceLevel: true,
diff --git a/exporter/util.go b/exporter/util.go
index e8ffacea8c..269fbf6199 100644
--- a/exporter/util.go
+++ b/exporter/util.go
@@ -118,8 +118,8 @@ func (ic *importContext) importClusterLegacy(c *clusters.Cluster) {
         })
     }
     ic.emitInitScriptsLegacy(c.InitScripts)
-    ic.emitSecretsFromSecretsPath(c.SparkConf)
-    ic.emitSecretsFromSecretsPath(c.SparkEnvVars)
+    ic.emitSecretsFromSecretsPathMap(c.SparkConf)
+    ic.emitSecretsFromSecretsPathMap(c.SparkEnvVars)
     ic.emitUserOrServicePrincipal(c.SingleUserName)
 }
 
@@ -153,19 +153,23 @@ func (ic *importContext) importCluster(c *compute.ClusterSpec) {
         })
     }
     ic.emitInitScripts(c.InitScripts)
-    ic.emitSecretsFromSecretsPath(c.SparkConf)
-    ic.emitSecretsFromSecretsPath(c.SparkEnvVars)
+    ic.emitSecretsFromSecretsPathMap(c.SparkConf)
+    ic.emitSecretsFromSecretsPathMap(c.SparkEnvVars)
     ic.emitUserOrServicePrincipal(c.SingleUserName)
 }
 
-func (ic *importContext) emitSecretsFromSecretsPath(m map[string]string) {
+func (ic *importContext) emitSecretsFromSecretPathString(v string) {
+    if res := secretPathRegex.FindStringSubmatch(v); res != nil {
+        ic.Emit(&resource{
+            Resource: "databricks_secret_scope",
+            ID:       res[1],
+        })
+    }
+}
+
+func (ic *importContext) emitSecretsFromSecretsPathMap(m map[string]string) {
     for _, v := range m {
-        if res := secretPathRegex.FindStringSubmatch(v); res != nil {
-            ic.Emit(&resource{
-                Resource: "databricks_secret_scope",
-                ID:       res[1],
-            })
-        }
+        ic.emitSecretsFromSecretPathString(v)
     }
 }
 
@@ -1436,6 +1440,21 @@ func isMatchingCatalogAndSchema(ic *importContext, res *resource, ra *resourceApproximation, origPath string) bool {
     return result
 }
 
+func isMatchingCatalogAndSchemaInModelServing(ic *importContext, res *resource, ra *resourceApproximation, origPath string) bool {
+    res_catalog_name := res.Data.Get("config.0.auto_capture_config.0.catalog_name").(string)
+    res_schema_name := res.Data.Get("config.0.auto_capture_config.0.schema_name").(string)
+    ra_catalog_name, cat_found := ra.Get("catalog_name")
+    ra_schema_name, schema_found := ra.Get("name")
+    if !cat_found || !schema_found {
+        log.Printf("[WARN] Can't find attributes in approximation: %s %s, catalog='%v' (found? %v) schema='%v' (found? %v). Resource: %s, catalog='%s', schema='%s'",
+            ra.Type, ra.Name, ra_catalog_name, cat_found, ra_schema_name, schema_found, res.Resource, res_catalog_name, res_schema_name)
+        return true
+    }
+
+    result := ra_catalog_name.(string) == res_catalog_name && ra_schema_name.(string) == res_schema_name
+    return result
+}
+
 func isMatchingShareRecipient(ic *importContext, res *resource, ra *resourceApproximation, origPath string) bool {
     shareName, ok := res.Data.GetOk("share")
     // principal := res.Data.Get(origPath)
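
Below is a minimal, standalone Go sketch (reviewer illustration, not part of the patch) of how the three regular expressions used above classify their inputs: secretPathRegex gates which string values cause a databricks_secret_scope to be emitted, uc3LevelIdRegex decides whether a served entity name is a three-level Unity Catalog identifier (and therefore a databricks_registered_model), and servedEntityFieldExtractionRegex pulls out the field name that ShouldOmitField switches on. The sample inputs are hypothetical:

package main

import (
    "fmt"
    "regexp"
)

// Copies of the patterns from exporter/importables.go.
var (
    secretPathRegex                  = regexp.MustCompile(`^\{\{secrets\/([^\/]+)\/([^}]+)\}\}$`)
    uc3LevelIdRegex                  = regexp.MustCompile(`^([^.]+\.[^.]+\.[^.]+)$`)
    servedEntityFieldExtractionRegex = regexp.MustCompile(`^config\.[0-9]+\.served_entities\.([0-9]+)\.(.*)$`)
)

func main() {
    // emitSecretsFromSecretPathString emits a databricks_secret_scope only for
    // values of the form {{secrets/<scope>/<key>}}.
    for _, v := range []string{"{{secrets/prod/openai_key}}", "dapi-plain-token"} {
        if res := secretPathRegex.FindStringSubmatch(v); res != nil {
            fmt.Printf("%q -> emit secret scope %q\n", v, res[1])
        } else {
            fmt.Printf("%q -> not a secret reference, nothing emitted\n", v)
        }
    }

    // A served entity is emitted as databricks_registered_model only when its
    // name has the three-level catalog.schema.model form.
    for _, name := range []string{"main.tmp.model", "def"} {
        fmt.Printf("%q is a UC 3-level name: %v\n", name, uc3LevelIdRegex.MatchString(name))
    }

    // ShouldOmitField extracts the entity index and trailing field name from a
    // flattened resource path.
    if res := servedEntityFieldExtractionRegex.FindStringSubmatch("config.0.served_entities.1.workload_size"); res != nil {
        fmt.Printf("entity index %s, field %q\n", res[1], res[2])
    }
}

Running it prints the scope name for the secret reference, true/false for the two entity names, and the index/field pair for the flattened path.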