From 57b73d8f2a27b2adeb2f502a37530bd872bbaaaa Mon Sep 17 00:00:00 2001 From: Vuong Date: Thu, 20 Jun 2024 18:02:29 +0100 Subject: [PATCH] clean up spark versions methods --- access/resource_sql_permissions.go | 2 +- access/resource_sql_permissions_test.go | 16 +-- catalog/resource_sql_table.go | 2 +- catalog/resource_sql_table_test.go | 12 +-- clusters/clusters_api.go | 28 +++-- clusters/clusters_api_sdk.go | 9 ++ clusters/clusters_api_test.go | 135 ++---------------------- clusters/data_spark_version.go | 116 -------------------- exporter/exporter_test.go | 6 +- storage/mounts.go | 9 +- storage/mounts_test.go | 8 +- storage/resource_mount_test.go | 8 +- 12 files changed, 71 insertions(+), 280 deletions(-) diff --git a/access/resource_sql_permissions.go b/access/resource_sql_permissions.go index 360220219a..432fdf0bdd 100644 --- a/access/resource_sql_permissions.go +++ b/access/resource_sql_permissions.go @@ -272,7 +272,7 @@ func (ta *SqlPermissions) initCluster(ctx context.Context, d *schema.ResourceDat } func (ta *SqlPermissions) getOrCreateCluster(clustersAPI clusters.ClustersAPI) (string, error) { - sparkVersion := clustersAPI.LatestSparkVersionOrDefault(clusters.SparkVersionRequest{ + sparkVersion := clusters.LatestSparkVersionOrDefault(clustersAPI.Context(), clustersAPI.WorkspaceClient(), compute.SparkVersionRequest{ Latest: true, }) nodeType := clustersAPI.GetSmallestNodeType(compute.NodeTypeRequest{LocalDisk: true}) diff --git a/access/resource_sql_permissions_test.go b/access/resource_sql_permissions_test.go index 6e75fd109e..17a864d7e2 100644 --- a/access/resource_sql_permissions_test.go +++ b/access/resource_sql_permissions_test.go @@ -185,11 +185,11 @@ var createHighConcurrencyCluster = []qa.HTTPFixture{ Method: "GET", ReuseRequest: true, Resource: "/api/2.0/clusters/spark-versions", - Response: clusters.SparkVersionsList{ - SparkVersions: []clusters.SparkVersion{ + Response: compute.GetSparkVersionsResponse{ + Versions: []compute.SparkVersion{ { - Version: "7.1.x-cpu-ml-scala2.12", - Description: "7.1 ML (includes Apache Spark 3.0.0, Scala 2.12)", + Key: "7.1.x-cpu-ml-scala2.12", + Name: "7.1 ML (includes Apache Spark 3.0.0, Scala 2.12)", }, }, }, @@ -262,11 +262,11 @@ var createSharedCluster = []qa.HTTPFixture{ Method: "GET", ReuseRequest: true, Resource: "/api/2.0/clusters/spark-versions", - Response: clusters.SparkVersionsList{ - SparkVersions: []clusters.SparkVersion{ + Response: compute.GetSparkVersionsResponse{ + Versions: []compute.SparkVersion{ { - Version: "7.1.x-cpu-ml-scala2.12", - Description: "7.1 ML (includes Apache Spark 3.0.0, Scala 2.12)", + Key: "7.1.x-cpu-ml-scala2.12", + Name: "7.1 ML (includes Apache Spark 3.0.0, Scala 2.12)", }, }, }, diff --git a/catalog/resource_sql_table.go b/catalog/resource_sql_table.go index db4ee7ff46..0076181d7b 100644 --- a/catalog/resource_sql_table.go +++ b/catalog/resource_sql_table.go @@ -159,7 +159,7 @@ func (ti *SqlTableInfo) initCluster(ctx context.Context, d *schema.ResourceData, } func (ti *SqlTableInfo) getOrCreateCluster(clusterName string, clustersAPI clusters.ClustersAPI) (string, error) { - sparkVersion := clustersAPI.LatestSparkVersionOrDefault(clusters.SparkVersionRequest{ + sparkVersion := clusters.LatestSparkVersionOrDefault(clustersAPI.Context(), clustersAPI.WorkspaceClient(), compute.SparkVersionRequest{ Latest: true, }) nodeType := clustersAPI.GetSmallestNodeType(compute.NodeTypeRequest{LocalDisk: true}) diff --git a/catalog/resource_sql_table_test.go b/catalog/resource_sql_table_test.go index db26ccab51..58d4b0ddd6 100644 --- a/catalog/resource_sql_table_test.go +++ b/catalog/resource_sql_table_test.go @@ -1248,15 +1248,11 @@ var baseClusterFixture = []qa.HTTPFixture{ Method: "GET", ReuseRequest: true, Resource: "/api/2.0/clusters/spark-versions", - Response: clusters.SparkVersionsList{ - SparkVersions: []clusters.SparkVersion{ + Response: compute.GetSparkVersionsResponse{ + Versions: []compute.SparkVersion{ { - Version: "7.1.x-cpu-ml-scala2.12", - Description: "7.1 ML (includes Apache Spark 3.0.0, Scala 2.12)", - }, - { - Version: "7.3.x-scala2.12", - Description: "7.3 LTS (includes Apache Spark 3.0.1, Scala 2.12)", + Key: "7.1.x-cpu-ml-scala2.12", + Name: "7.1 ML (includes Apache Spark 3.0.0, Scala 2.12)", }, }, }, diff --git a/clusters/clusters_api.go b/clusters/clusters_api.go index 3434d3fdae..dd4708aec1 100644 --- a/clusters/clusters_api.go +++ b/clusters/clusters_api.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/apierr" "github.com/databricks/databricks-sdk-go/service/compute" @@ -574,6 +575,19 @@ type ClustersAPI struct { context context.Context } +// Temporary function to be used until all resources are migrated to Go SDK +// Create a workspace client +func (a ClustersAPI) WorkspaceClient() *databricks.WorkspaceClient { + client, _ := a.client.WorkspaceClient() + return client +} + +// Temporary function to be used until all resources are migrated to Go SDK +// Return a context +func (a ClustersAPI) Context() context.Context { + return a.context +} + // Create creates a new Spark cluster and waits till it's running func (a ClustersAPI) Create(cluster Cluster) (info ClusterInfo, err error) { var ci ClusterID @@ -867,6 +881,7 @@ var getOrCreateClusterMutex sync.Mutex // GetOrCreateRunningCluster creates an autoterminating cluster if it doesn't exist func (a ClustersAPI) GetOrCreateRunningCluster(name string, custom ...Cluster) (c ClusterInfo, err error) { + w, err := a.client.WorkspaceClient() getOrCreateClusterMutex.Lock() defer getOrCreateClusterMutex.Unlock() @@ -900,13 +915,14 @@ func (a ClustersAPI) GetOrCreateRunningCluster(name string, custom ...Cluster) ( LocalDisk: true, }) log.Printf("[INFO] Creating an autoterminating cluster with node type %s", smallestNodeType) + latestVersion, _ := w.Clusters.SelectSparkVersion(a.context, compute.SparkVersionRequest{ + Latest: true, + LongTermSupport: true, + }) r := Cluster{ - NumWorkers: 1, - ClusterName: name, - SparkVersion: a.LatestSparkVersionOrDefault(SparkVersionRequest{ - Latest: true, - LongTermSupport: true, - }), + NumWorkers: 1, + ClusterName: name, + SparkVersion: latestVersion, NodeTypeID: smallestNodeType, AutoterminationMinutes: 10, } diff --git a/clusters/clusters_api_sdk.go b/clusters/clusters_api_sdk.go index 388f1b80b5..a1c4b91f2a 100644 --- a/clusters/clusters_api_sdk.go +++ b/clusters/clusters_api_sdk.go @@ -35,3 +35,12 @@ func StartClusterAndGetInfo(ctx context.Context, w *databricks.WorkspaceClient, } return w.Clusters.StartByClusterIdAndWait(ctx, clusterID) } + +// LatestSparkVersionOrDefault returns Spark version matching the definition, or default in case of error +func LatestSparkVersionOrDefault(ctx context.Context, w *databricks.WorkspaceClient, svr compute.SparkVersionRequest) string { + version, err := w.Clusters.SelectSparkVersion(ctx, svr) + if err != nil { + return "7.3.x-scala2.12" + } + return version +} diff --git a/clusters/clusters_api_test.go b/clusters/clusters_api_test.go index 441d7397d5..32f8d3f407 100644 --- a/clusters/clusters_api_test.go +++ b/clusters/clusters_api_test.go @@ -6,7 +6,7 @@ import ( "fmt" // "reflect" - "strings" + "testing" "github.com/databricks/databricks-sdk-go/apierr" @@ -28,23 +28,23 @@ func TestGetOrCreateRunningCluster_AzureAuth(t *testing.T) { Method: "GET", ReuseRequest: true, Resource: "/api/2.0/clusters/spark-versions", - Response: SparkVersionsList{ - SparkVersions: []SparkVersion{ + Response: compute.GetSparkVersionsResponse{ + Versions: []compute.SparkVersion{ { - Version: "7.1.x-cpu-ml-scala2.12", - Description: "7.1 ML (includes Apache Spark 3.0.0, Scala 2.12)", + Key: "7.1.x-cpu-ml-scala2.12", + Name: "7.1 ML (includes Apache Spark 3.0.0, Scala 2.12)", }, { - Version: "apache-spark-2.4.x-scala2.11", - Description: "Light 2.4 (includes Apache Spark 2.4, Scala 2.11)", + Key: "apache-spark-2.4.x-scala2.11", + Name: "Light 2.4 (includes Apache Spark 2.4, Scala 2.11)", }, { - Version: "7.3.x-scala2.12", - Description: "7.3 LTS (includes Apache Spark 3.0.1, Scala 2.12)", + Key: "7.3.x-scala2.12", + Name: "7.3 LTS (includes Apache Spark 3.0.1, Scala 2.12)", }, { - Version: "6.4.x-scala2.11", - Description: "6.4 (includes Apache Spark 2.4.5, Scala 2.11)", + Key: "6.4.x-scala2.11", + Name: "6.4 (includes Apache Spark 2.4.5, Scala 2.11)", }, }, }, @@ -1016,119 +1016,6 @@ func TestEventsEmptyResult(t *testing.T) { assert.Equal(t, len(clusterEvents), 0) } -func TestListSparkVersions(t *testing.T) { - client, server, err := qa.HttpFixtureClient(t, []qa.HTTPFixture{ - { - Method: "GET", - Resource: "/api/2.0/clusters/spark-versions", - Response: SparkVersionsList{ - SparkVersions: []SparkVersion{ - { - Version: "7.1.x-cpu-ml-scala2.12", - Description: "7.1 ML (includes Apache Spark 3.0.0, Scala 2.12)", - }, - { - Version: "apache-spark-2.4.x-scala2.11", - Description: "Light 2.4 (includes Apache Spark 2.4, Scala 2.11)", - }, - { - Version: "7.3.x-hls-scala2.12", - Description: "7.3 LTS Genomics (includes Apache Spark 3.0.1, Scala 2.12)", - }, - { - Version: "6.4.x-scala2.11", - Description: "6.4 (includes Apache Spark 2.4.5, Scala 2.11)", - }, - }, - }, - }, - }) - defer server.Close() - require.NoError(t, err) - - ctx := context.Background() - sparkVersions, err := NewClustersAPI(ctx, client).ListSparkVersions() - require.NoError(t, err) - require.Equal(t, 4, len(sparkVersions.SparkVersions)) - require.Equal(t, "6.4.x-scala2.11", sparkVersions.SparkVersions[3].Version) -} - -func TestListSparkVersionsWithError(t *testing.T) { - client, server, err := qa.HttpFixtureClient(t, []qa.HTTPFixture{ - { - Method: "GET", - Resource: "/api/2.0/clusters/spark-versions", - Response: "{garbage....", - }, - }) - defer server.Close() - require.NoError(t, err) - - ctx := context.Background() - _, err = NewClustersAPI(ctx, client).ListSparkVersions() - require.Error(t, err) - require.Equal(t, true, strings.Contains(err.Error(), "invalid character 'g' looking")) -} - -func TestGetLatestSparkVersion(t *testing.T) { - versions := SparkVersionsList{ - SparkVersions: []SparkVersion{ - { - Version: "7.1.x-cpu-ml-scala2.12", - Description: "7.1 ML (includes Apache Spark 3.0.0, Scala 2.12)", - }, - { - Version: "apache-spark-2.4.x-scala2.11", - Description: "Light 2.4 (includes Apache Spark 2.4, Scala 2.11)", - }, - { - Version: "7.3.x-hls-scala2.12", - Description: "7.3 LTS Genomics (includes Apache Spark 3.0.1, Scala 2.12)", - }, - { - Version: "6.4.x-scala2.11", - Description: "6.4 (includes Apache Spark 2.4.5, Scala 2.11)", - }, - { - Version: "7.3.x-scala2.12", - Description: "7.3 LTS (includes Apache Spark 3.0.1, Scala 2.12)", - }, - { - Version: "7.4.x-scala2.12", - Description: "7.4 (includes Apache Spark 3.0.1, Scala 2.12)", - }, - { - Version: "7.1.x-scala2.12", - Description: "7.1 (includes Apache Spark 3.0.0, Scala 2.12)", - }, - }, - } - - version, err := versions.LatestSparkVersion(SparkVersionRequest{Scala: "2.12", Latest: true}) - require.NoError(t, err) - require.Equal(t, "7.4.x-scala2.12", version) - - version, err = versions.LatestSparkVersion(SparkVersionRequest{Scala: "2.12", LongTermSupport: true, Latest: true}) - require.NoError(t, err) - require.Equal(t, "7.3.x-scala2.12", version) - - version, err = versions.LatestSparkVersion(SparkVersionRequest{Scala: "2.12", Latest: true, SparkVersion: "3.0.0"}) - require.NoError(t, err) - require.Equal(t, "7.1.x-scala2.12", version) - - _, err = versions.LatestSparkVersion(SparkVersionRequest{Scala: "2.12"}) - require.Error(t, err) - require.Equal(t, true, strings.Contains(err.Error(), "query returned multiple results")) - - _, err = versions.LatestSparkVersion(SparkVersionRequest{Scala: "2.12", ML: true, Genomics: true}) - require.Error(t, err) - require.Equal(t, true, strings.Contains(err.Error(), "query returned no results")) - - _, err = versions.LatestSparkVersion(SparkVersionRequest{Scala: "2.12", SparkVersion: "3.10"}) - require.Error(t, err) - require.Equal(t, true, strings.Contains(err.Error(), "query returned no results")) -} - func TestClusterState_CanReach(t *testing.T) { tests := []struct { from ClusterState diff --git a/clusters/data_spark_version.go b/clusters/data_spark_version.go index 2b5e3c2fe1..dfe1541795 100644 --- a/clusters/data_spark_version.go +++ b/clusters/data_spark_version.go @@ -2,129 +2,13 @@ package clusters import ( "context" - "fmt" - "regexp" - "sort" - "strings" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/service/compute" "github.com/databricks/terraform-provider-databricks/common" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" - "golang.org/x/mod/semver" ) -// SparkVersion - contains information about specific version -type SparkVersion struct { - Version string `json:"key"` - Description string `json:"name"` -} - -// SparkVersionsList - returns a list of all currently supported Spark Versions -// https://docs.databricks.com/dev-tools/api/latest/clusters.html#runtime-versions -type SparkVersionsList struct { - SparkVersions []SparkVersion `json:"versions"` -} - -// SparkVersionRequest - filtering request -type SparkVersionRequest struct { - LongTermSupport bool `json:"long_term_support,omitempty"` - Beta bool `json:"beta,omitempty" tf:"conflicts:long_term_support"` - Latest bool `json:"latest,omitempty" tf:"default:true"` - ML bool `json:"ml,omitempty"` - Genomics bool `json:"genomics,omitempty"` - GPU bool `json:"gpu,omitempty"` - Scala string `json:"scala,omitempty" tf:"default:2.12"` - SparkVersion string `json:"spark_version,omitempty"` - Photon bool `json:"photon,omitempty"` - Graviton bool `json:"graviton,omitempty"` -} - -// ListSparkVersions returns smallest (or default) node type id given the criteria -func (a ClustersAPI) ListSparkVersions() (SparkVersionsList, error) { - var sparkVersions SparkVersionsList - err := a.client.Get(a.context, "/clusters/spark-versions", nil, &sparkVersions) - return sparkVersions, err -} - -type sparkVersionsType []string - -func (s sparkVersionsType) Len() int { - return len(s) -} -func (s sparkVersionsType) Swap(i, j int) { - s[i], s[j] = s[j], s[i] -} - -var dbrVersionRegex = regexp.MustCompile(`^(\d+\.\d+)\.x-.*`) - -func extractDbrVersions(s string) string { - m := dbrVersionRegex.FindStringSubmatch(s) - if len(m) > 1 { - return m[1] - } - return s -} - -func (s sparkVersionsType) Less(i, j int) bool { - return semver.Compare("v"+extractDbrVersions(s[i]), "v"+extractDbrVersions(s[j])) > 0 -} - -// LatestSparkVersion returns latest version matching the request parameters -func (sparkVersions SparkVersionsList) LatestSparkVersion(req SparkVersionRequest) (string, error) { - var versions []string - - for _, version := range sparkVersions.SparkVersions { - if strings.Contains(version.Version, "-scala"+req.Scala) { - matches := ((!strings.Contains(version.Version, "apache-spark-")) && - (strings.Contains(version.Version, "-ml-") == req.ML) && - (strings.Contains(version.Version, "-hls-") == req.Genomics) && - (strings.Contains(version.Version, "-gpu-") == req.GPU) && - (strings.Contains(version.Version, "-photon-") == req.Photon) && - (strings.Contains(version.Version, "-aarch64-") == req.Graviton) && - (strings.Contains(version.Description, "Beta") == req.Beta)) - if matches && req.LongTermSupport { - matches = (matches && (strings.Contains(version.Description, "LTS") || strings.Contains(version.Version, "-esr-"))) - } - if matches && len(req.SparkVersion) > 0 { - matches = (matches && strings.Contains(version.Description, "Apache Spark "+req.SparkVersion)) - } - if matches { - versions = append(versions, version.Version) - } - } - } - if len(versions) < 1 { - return "", fmt.Errorf("spark versions query returned no results. Please change your search criteria and try again") - } else if len(versions) > 1 { - if req.Latest { - sort.Sort(sparkVersionsType(versions)) - } else { - return "", fmt.Errorf("spark versions query returned multiple results. Please change your search criteria and try again") - } - } - - return versions[0], nil -} - -// LatestSparkVersion returns latest version matching the request parameters -func (a ClustersAPI) LatestSparkVersion(svr SparkVersionRequest) (string, error) { - sparkVersions, err := a.ListSparkVersions() - if err != nil { - return "", err - } - return sparkVersions.LatestSparkVersion(svr) -} - -// LatestSparkVersionOrDefault returns Spark version matching the definition, or default in case of error -func (a ClustersAPI) LatestSparkVersionOrDefault(svr SparkVersionRequest) string { - version, err := a.LatestSparkVersion(svr) - if err != nil { - return "7.3.x-scala2.12" - } - return version -} - // DataSourceSparkVersion returns DBR version matching to the specification func DataSourceSparkVersion() common.Resource { return common.WorkspaceDataWithCustomizeFunc(func(ctx context.Context, data *compute.SparkVersionRequest, w *databricks.WorkspaceClient) error { diff --git a/exporter/exporter_test.go b/exporter/exporter_test.go index 174a766409..f605bd762f 100644 --- a/exporter/exporter_test.go +++ b/exporter/exporter_test.go @@ -180,10 +180,10 @@ func TestImportingMounts(t *testing.T) { Method: "GET", ReuseRequest: true, Resource: "/api/2.0/clusters/spark-versions", - Response: clusters.SparkVersionsList{ - SparkVersions: []clusters.SparkVersion{ + Response: compute.GetSparkVersionsResponse{ + Versions: []compute.SparkVersion{ { - Version: "Foo LTS", + Key: "Foo LTS", }, }, }, diff --git a/storage/mounts.go b/storage/mounts.go index 8531eacc03..1dcf66dd60 100644 --- a/storage/mounts.go +++ b/storage/mounts.go @@ -137,11 +137,10 @@ func getCommonClusterObject(clustersAPI clusters.ClustersAPI, clusterName string return clusters.Cluster{ NumWorkers: 0, ClusterName: clusterName, - SparkVersion: clustersAPI.LatestSparkVersionOrDefault( - clusters.SparkVersionRequest{ - Latest: true, - LongTermSupport: true, - }), + SparkVersion: clusters.LatestSparkVersionOrDefault(clustersAPI.Context(), clustersAPI.WorkspaceClient(), compute.SparkVersionRequest{ + Latest: true, + LongTermSupport: true, + }), NodeTypeID: clustersAPI.GetSmallestNodeType( compute.NodeTypeRequest{ LocalDisk: true, diff --git a/storage/mounts_test.go b/storage/mounts_test.go index 401fed07ca..9d8a4a2521 100644 --- a/storage/mounts_test.go +++ b/storage/mounts_test.go @@ -174,11 +174,11 @@ func TestDeletedMountClusterRecreates(t *testing.T) { Method: "GET", ReuseRequest: true, Resource: "/api/2.0/clusters/spark-versions", - Response: clusters.SparkVersionsList{ - SparkVersions: []clusters.SparkVersion{ + Response: compute.GetSparkVersionsResponse{ + Versions: []compute.SparkVersion{ { - Version: "7.1.x-cpu-ml-scala2.12", - Description: "7.1 ML (includes Apache Spark 3.0.0, Scala 2.12)", + Key: "7.1.x-cpu-ml-scala2.12", + Name: "7.1 ML (includes Apache Spark 3.0.0, Scala 2.12)", }, }, }, diff --git a/storage/resource_mount_test.go b/storage/resource_mount_test.go index 58318fd01c..24bf5ca37c 100644 --- a/storage/resource_mount_test.go +++ b/storage/resource_mount_test.go @@ -21,11 +21,11 @@ import ( // Test interface compliance via compile time error var _ Mount = (*S3IamMount)(nil) -var sparkVersionsResponse = clusters.SparkVersionsList{ - SparkVersions: []clusters.SparkVersion{ +var sparkVersionsResponse = compute.GetSparkVersionsResponse{ + Versions: []compute.SparkVersion{ { - Version: "7.3.x-scala2.12", - Description: "7.3 LTS (includes Apache Spark 3.0.1, Scala 2.12)", + Key: "7.3.x-scala2.12", + Name: "7.3 LTS (includes Apache Spark 3.0.1, Scala 2.12)", }, }, }