clean up spark versions methods
nkvuong committed Jul 2, 2024
1 parent 8d6e92d commit 57b73d8
Showing 12 changed files with 71 additions and 280 deletions.
2 changes: 1 addition & 1 deletion access/resource_sql_permissions.go
@@ -272,7 +272,7 @@ func (ta *SqlPermissions) initCluster(ctx context.Context, d *schema.ResourceData,
 }
 
 func (ta *SqlPermissions) getOrCreateCluster(clustersAPI clusters.ClustersAPI) (string, error) {
-    sparkVersion := clustersAPI.LatestSparkVersionOrDefault(clusters.SparkVersionRequest{
+    sparkVersion := clusters.LatestSparkVersionOrDefault(clustersAPI.Context(), clustersAPI.WorkspaceClient(), compute.SparkVersionRequest{
        Latest: true,
    })
    nodeType := clustersAPI.GetSmallestNodeType(compute.NodeTypeRequest{LocalDisk: true})
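
The call-site change above is the migration pattern this commit applies to every resource: the method on the legacy ClustersAPI struct is swapped for a package-level helper that takes an explicit context and a Go SDK workspace client. A minimal before/after fragment, using only names that appear in this diff:

    // Before: helper method on the legacy ClustersAPI struct.
    sparkVersion := clustersAPI.LatestSparkVersionOrDefault(clusters.SparkVersionRequest{
        Latest: true,
    })

    // After: package-level helper backed by the Go SDK; the context and
    // *databricks.WorkspaceClient come from the temporary bridge methods
    // added to ClustersAPI in this commit.
    sparkVersion := clusters.LatestSparkVersionOrDefault(
        clustersAPI.Context(),
        clustersAPI.WorkspaceClient(),
        compute.SparkVersionRequest{Latest: true},
    )
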
16 changes: 8 additions & 8 deletions access/resource_sql_permissions_test.go
@@ -185,11 +185,11 @@ var createHighConcurrencyCluster = []qa.HTTPFixture{
         Method:       "GET",
         ReuseRequest: true,
         Resource:     "/api/2.0/clusters/spark-versions",
-        Response: clusters.SparkVersionsList{
-            SparkVersions: []clusters.SparkVersion{
+        Response: compute.GetSparkVersionsResponse{
+            Versions: []compute.SparkVersion{
                 {
-                    Version:     "7.1.x-cpu-ml-scala2.12",
-                    Description: "7.1 ML (includes Apache Spark 3.0.0, Scala 2.12)",
+                    Key:  "7.1.x-cpu-ml-scala2.12",
+                    Name: "7.1 ML (includes Apache Spark 3.0.0, Scala 2.12)",
                 },
             },
         },
@@ -262,11 +262,11 @@ var createSharedCluster = []qa.HTTPFixture{
         Method:       "GET",
         ReuseRequest: true,
         Resource:     "/api/2.0/clusters/spark-versions",
-        Response: clusters.SparkVersionsList{
-            SparkVersions: []clusters.SparkVersion{
+        Response: compute.GetSparkVersionsResponse{
+            Versions: []compute.SparkVersion{
                 {
-                    Version:     "7.1.x-cpu-ml-scala2.12",
-                    Description: "7.1 ML (includes Apache Spark 3.0.0, Scala 2.12)",
+                    Key:  "7.1.x-cpu-ml-scala2.12",
+                    Name: "7.1 ML (includes Apache Spark 3.0.0, Scala 2.12)",
                 },
             },
         },
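
The fixture rewrite tracks the wire format: /api/2.0/clusters/spark-versions returns each runtime as {"key": ..., "name": ...}, which the SDK's compute types map directly, whereas the legacy clusters types used Version/Description field names. For reference, a paraphrased sketch of the SDK-side shape the fixtures now marshal into (field names are from databricks-sdk-go; the JSON tags shown are the point and are an assumption about the exact struct tags):

    // Paraphrased from databricks-sdk-go/service/compute; the tags match
    // the /api/2.0/clusters/spark-versions payload.
    type SparkVersion struct {
        Key  string `json:"key,omitempty"`
        Name string `json:"name,omitempty"`
    }

    type GetSparkVersionsResponse struct {
        Versions []SparkVersion `json:"versions,omitempty"`
    }
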
2 changes: 1 addition & 1 deletion catalog/resource_sql_table.go
@@ -159,7 +159,7 @@ func (ti *SqlTableInfo) initCluster(ctx context.Context, d *schema.ResourceData,
 }
 
 func (ti *SqlTableInfo) getOrCreateCluster(clusterName string, clustersAPI clusters.ClustersAPI) (string, error) {
-    sparkVersion := clustersAPI.LatestSparkVersionOrDefault(clusters.SparkVersionRequest{
+    sparkVersion := clusters.LatestSparkVersionOrDefault(clustersAPI.Context(), clustersAPI.WorkspaceClient(), compute.SparkVersionRequest{
        Latest: true,
    })
    nodeType := clustersAPI.GetSmallestNodeType(compute.NodeTypeRequest{LocalDisk: true})
12 changes: 4 additions & 8 deletions catalog/resource_sql_table_test.go
@@ -1248,15 +1248,11 @@ var baseClusterFixture = []qa.HTTPFixture{
         Method:       "GET",
         ReuseRequest: true,
         Resource:     "/api/2.0/clusters/spark-versions",
-        Response: clusters.SparkVersionsList{
-            SparkVersions: []clusters.SparkVersion{
+        Response: compute.GetSparkVersionsResponse{
+            Versions: []compute.SparkVersion{
                 {
-                    Version:     "7.1.x-cpu-ml-scala2.12",
-                    Description: "7.1 ML (includes Apache Spark 3.0.0, Scala 2.12)",
-                },
-                {
-                    Version:     "7.3.x-scala2.12",
-                    Description: "7.3 LTS (includes Apache Spark 3.0.1, Scala 2.12)",
+                    Key:  "7.1.x-cpu-ml-scala2.12",
+                    Name: "7.1 ML (includes Apache Spark 3.0.0, Scala 2.12)",
                 },
             },
         },
28 changes: 22 additions & 6 deletions clusters/clusters_api.go
@@ -9,6 +9,7 @@ import (
     "sync"
     "time"
 
+    "github.com/databricks/databricks-sdk-go"
     "github.com/databricks/databricks-sdk-go/apierr"
     "github.com/databricks/databricks-sdk-go/service/compute"
 
@@ -574,6 +575,19 @@ type ClustersAPI struct {
     context context.Context
 }
 
+// Temporary function to be used until all resources are migrated to Go SDK
+// Create a workspace client
+func (a ClustersAPI) WorkspaceClient() *databricks.WorkspaceClient {
+    client, _ := a.client.WorkspaceClient()
+    return client
+}
+
+// Temporary function to be used until all resources are migrated to Go SDK
+// Return a context
+func (a ClustersAPI) Context() context.Context {
+    return a.context
+}
+
 // Create creates a new Spark cluster and waits till it's running
 func (a ClustersAPI) Create(cluster Cluster) (info ClusterInfo, err error) {
     var ci ClusterID
@@ -867,6 +881,7 @@ var getOrCreateClusterMutex sync.Mutex
 
 // GetOrCreateRunningCluster creates an autoterminating cluster if it doesn't exist
 func (a ClustersAPI) GetOrCreateRunningCluster(name string, custom ...Cluster) (c ClusterInfo, err error) {
+    w, err := a.client.WorkspaceClient()
     getOrCreateClusterMutex.Lock()
     defer getOrCreateClusterMutex.Unlock()
 
@@ -900,13 +915,14 @@ func (a ClustersAPI) GetOrCreateRunningCluster(name string, custom ...Cluster) (
        LocalDisk: true,
    })
    log.Printf("[INFO] Creating an autoterminating cluster with node type %s", smallestNodeType)
+   latestVersion, _ := w.Clusters.SelectSparkVersion(a.context, compute.SparkVersionRequest{
+       Latest:          true,
+       LongTermSupport: true,
+   })
    r := Cluster{
-       NumWorkers:  1,
-       ClusterName: name,
-       SparkVersion: a.LatestSparkVersionOrDefault(SparkVersionRequest{
-           Latest:          true,
-           LongTermSupport: true,
-       }),
+       NumWorkers:             1,
+       ClusterName:            name,
+       SparkVersion:           latestVersion,
        NodeTypeID:             smallestNodeType,
        AutoterminationMinutes: 10,
    }
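
With the bridge methods in place, GetOrCreateRunningCluster resolves the runtime through the Go SDK's SelectSparkVersion rather than the legacy list-and-filter helper. A standalone sketch of that lookup, assuming workspace credentials in the environment; the databricks.Must(databricks.NewWorkspaceClient()) boilerplate is illustrative and not part of this commit:

    package main

    import (
        "context"
        "log"

        "github.com/databricks/databricks-sdk-go"
        "github.com/databricks/databricks-sdk-go/service/compute"
    )

    func main() {
        ctx := context.Background()
        // Assumes workspace auth is configured (env vars, profile, etc.).
        w := databricks.Must(databricks.NewWorkspaceClient())

        // Same request the new code builds: newest long-term-support runtime.
        version, err := w.Clusters.SelectSparkVersion(ctx, compute.SparkVersionRequest{
            Latest:          true,
            LongTermSupport: true,
        })
        if err != nil {
            log.Fatal(err)
        }
        log.Printf("selected spark_version: %s", version)
    }

Note that the diff itself discards the error (latestVersion, _ := ...), so a failed lookup there yields an empty version string rather than an error.
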
9 changes: 9 additions & 0 deletions clusters/clusters_api_sdk.go
@@ -35,3 +35,12 @@ func StartClusterAndGetInfo(ctx context.Context, w *databricks.WorkspaceClient,
     }
     return w.Clusters.StartByClusterIdAndWait(ctx, clusterID)
 }
+
+// LatestSparkVersionOrDefault returns Spark version matching the definition, or default in case of error
+func LatestSparkVersionOrDefault(ctx context.Context, w *databricks.WorkspaceClient, svr compute.SparkVersionRequest) string {
+    version, err := w.Clusters.SelectSparkVersion(ctx, svr)
+    if err != nil {
+        return "7.3.x-scala2.12"
+    }
+    return version
+}
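
The replacement helper deliberately swallows the error: any failure in the spark-versions lookup falls back to the hard-coded "7.3.x-scala2.12" runtime, so call sites such as getOrCreateCluster above never branch on an error. A usage fragment (client construction is again illustrative, not part of this commit):

    ctx := context.Background()
    w := databricks.Must(databricks.NewWorkspaceClient())

    // Resolves to the newest LTS runtime, or "7.3.x-scala2.12" if the
    // /clusters/spark-versions lookup fails for any reason.
    version := clusters.LatestSparkVersionOrDefault(ctx, w, compute.SparkVersionRequest{
        Latest:          true,
        LongTermSupport: true,
    })
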
135 changes: 11 additions & 124 deletions clusters/clusters_api_test.go
@@ -6,7 +6,7 @@ import (
     "fmt"
 
     // "reflect"
-    "strings"
+
     "testing"
 
     "github.com/databricks/databricks-sdk-go/apierr"
@@ -28,23 +28,23 @@ func TestGetOrCreateRunningCluster_AzureAuth(t *testing.T) {
             Method:       "GET",
             ReuseRequest: true,
             Resource:     "/api/2.0/clusters/spark-versions",
-            Response: SparkVersionsList{
-                SparkVersions: []SparkVersion{
+            Response: compute.GetSparkVersionsResponse{
+                Versions: []compute.SparkVersion{
                     {
-                        Version:     "7.1.x-cpu-ml-scala2.12",
-                        Description: "7.1 ML (includes Apache Spark 3.0.0, Scala 2.12)",
+                        Key:  "7.1.x-cpu-ml-scala2.12",
+                        Name: "7.1 ML (includes Apache Spark 3.0.0, Scala 2.12)",
                     },
                     {
-                        Version:     "apache-spark-2.4.x-scala2.11",
-                        Description: "Light 2.4 (includes Apache Spark 2.4, Scala 2.11)",
+                        Key:  "apache-spark-2.4.x-scala2.11",
+                        Name: "Light 2.4 (includes Apache Spark 2.4, Scala 2.11)",
                     },
                     {
-                        Version:     "7.3.x-scala2.12",
-                        Description: "7.3 LTS (includes Apache Spark 3.0.1, Scala 2.12)",
+                        Key:  "7.3.x-scala2.12",
+                        Name: "7.3 LTS (includes Apache Spark 3.0.1, Scala 2.12)",
                     },
                     {
-                        Version:     "6.4.x-scala2.11",
-                        Description: "6.4 (includes Apache Spark 2.4.5, Scala 2.11)",
+                        Key:  "6.4.x-scala2.11",
+                        Name: "6.4 (includes Apache Spark 2.4.5, Scala 2.11)",
                     },
                 },
             },
@@ -1016,119 +1016,6 @@ func TestEventsEmptyResult(t *testing.T) {
     assert.Equal(t, len(clusterEvents), 0)
 }
 
-func TestListSparkVersions(t *testing.T) {
-    client, server, err := qa.HttpFixtureClient(t, []qa.HTTPFixture{
-        {
-            Method:   "GET",
-            Resource: "/api/2.0/clusters/spark-versions",
-            Response: SparkVersionsList{
-                SparkVersions: []SparkVersion{
-                    {
-                        Version:     "7.1.x-cpu-ml-scala2.12",
-                        Description: "7.1 ML (includes Apache Spark 3.0.0, Scala 2.12)",
-                    },
-                    {
-                        Version:     "apache-spark-2.4.x-scala2.11",
-                        Description: "Light 2.4 (includes Apache Spark 2.4, Scala 2.11)",
-                    },
-                    {
-                        Version:     "7.3.x-hls-scala2.12",
-                        Description: "7.3 LTS Genomics (includes Apache Spark 3.0.1, Scala 2.12)",
-                    },
-                    {
-                        Version:     "6.4.x-scala2.11",
-                        Description: "6.4 (includes Apache Spark 2.4.5, Scala 2.11)",
-                    },
-                },
-            },
-        },
-    })
-    defer server.Close()
-    require.NoError(t, err)
-
-    ctx := context.Background()
-    sparkVersions, err := NewClustersAPI(ctx, client).ListSparkVersions()
-    require.NoError(t, err)
-    require.Equal(t, 4, len(sparkVersions.SparkVersions))
-    require.Equal(t, "6.4.x-scala2.11", sparkVersions.SparkVersions[3].Version)
-}
-
-func TestListSparkVersionsWithError(t *testing.T) {
-    client, server, err := qa.HttpFixtureClient(t, []qa.HTTPFixture{
-        {
-            Method:   "GET",
-            Resource: "/api/2.0/clusters/spark-versions",
-            Response: "{garbage....",
-        },
-    })
-    defer server.Close()
-    require.NoError(t, err)
-
-    ctx := context.Background()
-    _, err = NewClustersAPI(ctx, client).ListSparkVersions()
-    require.Error(t, err)
-    require.Equal(t, true, strings.Contains(err.Error(), "invalid character 'g' looking"))
-}
-
-func TestGetLatestSparkVersion(t *testing.T) {
-    versions := SparkVersionsList{
-        SparkVersions: []SparkVersion{
-            {
-                Version:     "7.1.x-cpu-ml-scala2.12",
-                Description: "7.1 ML (includes Apache Spark 3.0.0, Scala 2.12)",
-            },
-            {
-                Version:     "apache-spark-2.4.x-scala2.11",
-                Description: "Light 2.4 (includes Apache Spark 2.4, Scala 2.11)",
-            },
-            {
-                Version:     "7.3.x-hls-scala2.12",
-                Description: "7.3 LTS Genomics (includes Apache Spark 3.0.1, Scala 2.12)",
-            },
-            {
-                Version:     "6.4.x-scala2.11",
-                Description: "6.4 (includes Apache Spark 2.4.5, Scala 2.11)",
-            },
-            {
-                Version:     "7.3.x-scala2.12",
-                Description: "7.3 LTS (includes Apache Spark 3.0.1, Scala 2.12)",
-            },
-            {
-                Version:     "7.4.x-scala2.12",
-                Description: "7.4 (includes Apache Spark 3.0.1, Scala 2.12)",
-            },
-            {
-                Version:     "7.1.x-scala2.12",
-                Description: "7.1 (includes Apache Spark 3.0.0, Scala 2.12)",
-            },
-        },
-    }
-
-    version, err := versions.LatestSparkVersion(SparkVersionRequest{Scala: "2.12", Latest: true})
-    require.NoError(t, err)
-    require.Equal(t, "7.4.x-scala2.12", version)
-
-    version, err = versions.LatestSparkVersion(SparkVersionRequest{Scala: "2.12", LongTermSupport: true, Latest: true})
-    require.NoError(t, err)
-    require.Equal(t, "7.3.x-scala2.12", version)
-
-    version, err = versions.LatestSparkVersion(SparkVersionRequest{Scala: "2.12", Latest: true, SparkVersion: "3.0.0"})
-    require.NoError(t, err)
-    require.Equal(t, "7.1.x-scala2.12", version)
-
-    _, err = versions.LatestSparkVersion(SparkVersionRequest{Scala: "2.12"})
-    require.Error(t, err)
-    require.Equal(t, true, strings.Contains(err.Error(), "query returned multiple results"))
-
-    _, err = versions.LatestSparkVersion(SparkVersionRequest{Scala: "2.12", ML: true, Genomics: true})
-    require.Error(t, err)
-    require.Equal(t, true, strings.Contains(err.Error(), "query returned no results"))
-
-    _, err = versions.LatestSparkVersion(SparkVersionRequest{Scala: "2.12", SparkVersion: "3.10"})
-    require.Error(t, err)
-    require.Equal(t, true, strings.Contains(err.Error(), "query returned no results"))
-}
-
 func TestClusterState_CanReach(t *testing.T) {
     tests := []struct {
         from ClusterState