diff --git a/clusters/clusters_api.go b/clusters/clusters_api.go index 6a08a4a60..308016bda 100644 --- a/clusters/clusters_api.go +++ b/clusters/clusters_api.go @@ -434,22 +434,6 @@ type Cluster struct { ClusterMounts []MountInfo `json:"cluster_mount_infos,omitempty" tf:"alias:cluster_mount_info"` } -// TODO: Remove this once all the resources using clusters are migrated to Go SDK. -// They would then be using Validate(cluster compute.CreateCluster) defined in resource_cluster.go that is a duplicate of this method but uses Go SDK. -func (cluster Cluster) Validate() error { - // TODO: rewrite with CustomizeDiff - if cluster.NumWorkers > 0 || cluster.Autoscale != nil { - return nil - } - profile := cluster.SparkConf["spark.databricks.cluster.profile"] - master := cluster.SparkConf["spark.master"] - resourceClass := cluster.CustomTags["ResourceClass"] - if profile == "singleNode" && strings.HasPrefix(master, "local") && resourceClass == "SingleNode" { - return nil - } - return errors.New(numWorkerErr) -} - // TODO: Remove this once all the resources using clusters are migrated to Go SDK. // They would then be using ModifyRequestOnInstancePool(cluster *compute.CreateCluster) defined in resource_cluster.go that is a duplicate of this method but uses Go SDK. // ModifyRequestOnInstancePool helps remove all request fields that should not be submitted when instance pool is selected. diff --git a/clusters/resource_cluster.go b/clusters/resource_cluster.go index 3c0350202..ae2240dee 100644 --- a/clusters/resource_cluster.go +++ b/clusters/resource_cluster.go @@ -130,9 +130,19 @@ func ZoneDiffSuppress(k, old, new string, d *schema.ResourceData) bool { return false } -// This method is a duplicate of Validate() in clusters/clusters_api.go that uses Go SDK. -// Long term, Validate() in clusters_api.go will be removed once all the resources using clusters are migrated to Go SDK. -func Validate(cluster any) error { +// The clusters API does not provide a great way to create a single node cluster. +// This function validates that a user has their cluster configured correctly IF +// they were trying to create a single node cluster. It does this by: +// +// 1. Asserting the correct cluster tags and spark conf are set when num_workers is 0 +// and autoscaling is not enabled. +// 2. Skip the validation if a policy is configured on the cluster. This is done to allow +// users to configure spark_conf, custom_tags, num_workers, etc. in the policy itself. +// +// TODO: Once the clusters resource is migrated to the TF plugin framework, we should +// make this a warning instead of an error. +func ValidateIfSingleNode(cluster any) error { + var hasPolicyConfigured bool var profile, master, resourceClass string switch c := cluster.(type) { case compute.CreateCluster: @@ -142,6 +152,7 @@ func Validate(cluster any) error { profile = c.SparkConf["spark.databricks.cluster.profile"] master = c.SparkConf["spark.master"] resourceClass = c.CustomTags["ResourceClass"] + hasPolicyConfigured = c.PolicyId != "" case compute.EditCluster: if c.NumWorkers > 0 || c.Autoscale != nil { return nil @@ -149,6 +160,7 @@ func Validate(cluster any) error { profile = c.SparkConf["spark.databricks.cluster.profile"] master = c.SparkConf["spark.master"] resourceClass = c.CustomTags["ResourceClass"] + hasPolicyConfigured = c.PolicyId != "" case compute.ClusterSpec: if c.NumWorkers > 0 || c.Autoscale != nil { return nil @@ -156,9 +168,16 @@ func Validate(cluster any) error { profile = c.SparkConf["spark.databricks.cluster.profile"] master = c.SparkConf["spark.master"] resourceClass = c.CustomTags["ResourceClass"] + hasPolicyConfigured = c.PolicyId != "" default: return fmt.Errorf(unsupportedExceptCreateEditClusterSpecErr, cluster, "", "", "") } + // If a cluster has a policy configured then we skip validation regarding whether + // the single node cluster configuration is valid or not. This is done to allow + // users to configure spark_conf, custom_tags, num_workers, etc. in the policy itself. + if hasPolicyConfigured { + return nil + } if profile == "singleNode" && strings.HasPrefix(master, "local") && resourceClass == "SingleNode" { return nil } @@ -445,7 +464,7 @@ func resourceClusterCreate(ctx context.Context, d *schema.ResourceData, c *commo clusters := w.Clusters var createClusterRequest compute.CreateCluster common.DataToStructPointer(d, clusterSchema, &createClusterRequest) - if err := Validate(createClusterRequest); err != nil { + if err := ValidateIfSingleNode(createClusterRequest); err != nil { return err } if err = ModifyRequestOnInstancePool(&createClusterRequest); err != nil { @@ -595,7 +614,7 @@ func resourceClusterUpdate(ctx context.Context, d *schema.ResourceData, c *commo if hasClusterConfigChanged(d) { log.Printf("[DEBUG] Cluster state has changed!") - if err := Validate(cluster); err != nil { + if err := ValidateIfSingleNode(cluster); err != nil { return err } if err = ModifyRequestOnInstancePool(&cluster); err != nil { diff --git a/clusters/resource_cluster_test.go b/clusters/resource_cluster_test.go index 240b62cb4..b8004bab3 100644 --- a/clusters/resource_cluster_test.go +++ b/clusters/resource_cluster_test.go @@ -1863,6 +1863,77 @@ func TestResourceClusterCreate_SingleNodeFail(t *testing.T) { assert.EqualError(t, err, numWorkerErr) } +func TestResourceClusterCreate_SingleNodeWithPolicy(t *testing.T) { + d, err := qa.ResourceFixture{ + Fixtures: []qa.HTTPFixture{ + { + Method: "POST", + Resource: "/api/2.1/clusters/create", + ExpectedRequest: compute.CreateCluster{ + NumWorkers: 0, + ClusterName: "Single Node Cluster", + SparkVersion: "7.3.x-scala12", + NodeTypeId: "Standard_F4s", + AutoterminationMinutes: 120, + ForceSendFields: []string{"NumWorkers"}, + PolicyId: "policy-123", + }, + Response: compute.ClusterDetails{ + ClusterId: "abc", + State: compute.StateRunning, + }, + }, + { + Method: "POST", + Resource: "/api/2.1/clusters/events", + ExpectedRequest: compute.GetEvents{ + ClusterId: "abc", + Limit: 1, + Order: compute.GetEventsOrderDesc, + EventTypes: []compute.EventType{compute.EventTypePinned, compute.EventTypeUnpinned}, + }, + Response: compute.GetEventsResponse{ + Events: []compute.ClusterEvent{}, + TotalCount: 0, + }, + }, + { + Method: "GET", + ReuseRequest: true, + Resource: "/api/2.1/clusters/get?cluster_id=abc", + Response: compute.ClusterDetails{ + ClusterId: "abc", + ClusterName: "Single Node Cluster", + SparkVersion: "7.3.x-scala12", + NodeTypeId: "Standard_F4s", + AutoterminationMinutes: 120, + State: compute.StateRunning, + PolicyId: "policy-123", + }, + }, + { + Method: "GET", + Resource: "/api/2.0/libraries/cluster-status?cluster_id=abc", + Response: compute.ClusterLibraryStatuses{ + LibraryStatuses: []compute.LibraryFullStatus{}, + }, + }, + }, + Create: true, + Resource: ResourceCluster(), + State: map[string]any{ + "autotermination_minutes": 120, + "cluster_name": "Single Node Cluster", + "spark_version": "7.3.x-scala12", + "node_type_id": "Standard_F4s", + "is_pinned": false, + "policy_id": "policy-123", + }, + }.Apply(t) + assert.NoError(t, err) + assert.Equal(t, 0, d.Get("num_workers")) +} + func TestResourceClusterCreate_NegativeNumWorkers(t *testing.T) { _, err := qa.ResourceFixture{ Create: true, @@ -1902,6 +1973,76 @@ func TestResourceClusterUpdate_FailNumWorkersZero(t *testing.T) { assert.EqualError(t, err, numWorkerErr) } +func TestResourceClusterUpdate_NumWorkersZeroWithPolicy(t *testing.T) { + _, err := qa.ResourceFixture{ + Fixtures: []qa.HTTPFixture{ + { + Method: "GET", + Resource: "/api/2.1/clusters/get?cluster_id=abc", + ReuseRequest: true, + Response: compute.ClusterDetails{ + ClusterId: "abc", + NumWorkers: 0, + ClusterName: "Shared Autoscaling", + SparkVersion: "7.1-scala12", + NodeTypeId: "i3.xlarge", + AutoterminationMinutes: 15, + State: compute.StateTerminated, + PolicyId: "policy-123", + }, + }, + { + Method: "POST", + Resource: "/api/2.1/clusters/events", + ExpectedRequest: compute.GetEvents{ + ClusterId: "abc", + Limit: 1, + Order: compute.GetEventsOrderDesc, + EventTypes: []compute.EventType{compute.EventTypePinned, compute.EventTypeUnpinned}, + }, + Response: compute.GetEventsResponse{ + Events: []compute.ClusterEvent{}, + TotalCount: 0, + }, + }, + { + Method: "POST", + Resource: "/api/2.1/clusters/edit", + ExpectedRequest: compute.ClusterDetails{ + AutoterminationMinutes: 15, + ClusterId: "abc", + NumWorkers: 0, + ClusterName: "Shared Autoscaling", + SparkVersion: "7.1-scala12", + NodeTypeId: "i3.xlarge", + PolicyId: "policy-123", + ForceSendFields: []string{"NumWorkers"}, + }, + }, + }, + ID: "abc", + Update: true, + Resource: ResourceCluster(), + InstanceState: map[string]string{ + "autotermination_minutes": "15", + "cluster_name": "Shared Autoscaling", + "spark_version": "7.1-scala12", + "node_type_id": "i3.xlarge", + "num_workers": "100", + "policy_id": "policy-123", + }, + State: map[string]any{ + "autotermination_minutes": 15, + "cluster_name": "Shared Autoscaling", + "spark_version": "7.1-scala12", + "node_type_id": "i3.xlarge", + "num_workers": 0, + "policy_id": "policy-123", + }, + }.Apply(t) + assert.NoError(t, err) +} + func TestModifyClusterRequestAws(t *testing.T) { c := compute.CreateCluster{ InstancePoolId: "a", diff --git a/jobs/jobs_api_go_sdk.go b/jobs/jobs_api_go_sdk.go index 6051bafae..915c48054 100644 --- a/jobs/jobs_api_go_sdk.go +++ b/jobs/jobs_api_go_sdk.go @@ -157,7 +157,7 @@ func (c controlRunStateLifecycleManagerGoSdk) OnUpdate(ctx context.Context) erro } func updateAndValidateJobClusterSpec(clusterSpec *compute.ClusterSpec, d *schema.ResourceData) error { - err := clusters.Validate(*clusterSpec) + err := clusters.ValidateIfSingleNode(*clusterSpec) if err != nil { return err } diff --git a/jobs/resource_job.go b/jobs/resource_job.go index be2b982a7..6626e4bc3 100644 --- a/jobs/resource_job.go +++ b/jobs/resource_job.go @@ -1072,12 +1072,12 @@ func ResourceJob() common.Resource { if task.NewCluster == nil { continue } - if err := clusters.Validate(*task.NewCluster); err != nil { + if err := clusters.ValidateIfSingleNode(*task.NewCluster); err != nil { return fmt.Errorf("task %s invalid: %w", task.TaskKey, err) } } if js.NewCluster != nil { - if err := clusters.Validate(*js.NewCluster); err != nil { + if err := clusters.ValidateIfSingleNode(*js.NewCluster); err != nil { return fmt.Errorf("invalid job cluster: %w", err) } } diff --git a/jobs/resource_job_test.go b/jobs/resource_job_test.go index 75a780c00..12f257663 100644 --- a/jobs/resource_job_test.go +++ b/jobs/resource_job_test.go @@ -2073,6 +2073,55 @@ is defined in a policy used by the cluster. Please define this in the cluster co itself to create a single node cluster.`) } +func TestResourceJobCreate_SingleNodeJobClustersWithPolicy(t *testing.T) { + d, err := qa.ResourceFixture{ + Fixtures: []qa.HTTPFixture{ + { + Method: "POST", + Resource: "/api/2.0/jobs/create", + ExpectedRequest: JobSettings{ + Name: "single node cluster", + MaxConcurrentRuns: 1, + Libraries: []compute.Library{ + { + Jar: "dbfs://ff/gg/hh.jar", + }, + }, + NewCluster: &clusters.Cluster{ + NumWorkers: 0, + PolicyID: "policy-123", + SparkVersion: "7.3.x-scala2.12", + }, + }, + Response: Job{ + JobID: 17, + }, + }, + { + Method: "GET", + Resource: "/api/2.0/jobs/get?job_id=17", + Response: Job{ + Settings: &JobSettings{}, + }, + }, + }, + Create: true, + Resource: ResourceJob(), + HCL: ` + name = "single node cluster" + new_cluster { + spark_version = "7.3.x-scala2.12" + policy_id = "policy-123" + } + max_concurrent_runs = 1 + library { + jar = "dbfs://ff/gg/hh.jar" + }`, + }.Apply(t) + assert.NoError(t, err) + assert.Equal(t, "17", d.Id()) +} + func TestResourceJobRead(t *testing.T) { d, err := qa.ResourceFixture{ Fixtures: []qa.HTTPFixture{ @@ -2976,6 +3025,59 @@ is defined in a policy used by the cluster. Please define this in the cluster co itself to create a single node cluster.`) } +func TestResourceJobUpdate_SingleNodeJobClustersWithPolicy(t *testing.T) { + d, err := qa.ResourceFixture{ + ID: "17", + Fixtures: []qa.HTTPFixture{ + { + Method: "POST", + Resource: "/api/2.0/jobs/reset", + ExpectedRequest: UpdateJobRequest{ + JobID: 17, + NewSettings: &JobSettings{ + Name: "single node cluster", + MaxConcurrentRuns: 1, + Libraries: []compute.Library{ + { + Jar: "dbfs://ff/gg/hh.jar", + }, + }, + NewCluster: &clusters.Cluster{ + NumWorkers: 0, + PolicyID: "policy-123", + SparkVersion: "7.3.x-scala2.12", + }, + }, + }, + Response: Job{ + JobID: 17, + }, + }, + { + Method: "GET", + Resource: "/api/2.0/jobs/get?job_id=17", + Response: Job{ + Settings: &JobSettings{}, + }, + }, + }, + Update: true, + Resource: ResourceJob(), + HCL: ` + name = "single node cluster" + new_cluster { + spark_version = "7.3.x-scala2.12" + policy_id = "policy-123" + } + max_concurrent_runs = 1 + library { + jar = "dbfs://ff/gg/hh.jar" + }`, + }.Apply(t) + assert.NoError(t, err) + assert.Equal(t, "17", d.Id()) +} + func TestJobsAPIList(t *testing.T) { qa.HTTPFixturesApply(t, []qa.HTTPFixture{ {