diff --git a/clusters/data_cluster.go b/clusters/data_cluster.go index 8a45b7afdf..73ae4a1e19 100644 --- a/clusters/data_cluster.go +++ b/clusters/data_cluster.go @@ -4,25 +4,24 @@ import ( "context" "fmt" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/compute" "github.com/databricks/terraform-provider-databricks/common" ) func DataSourceCluster() common.Resource { - type clusterData struct { - Id string `json:"id,omitempty" tf:"computed"` - ClusterId string `json:"cluster_id,omitempty" tf:"computed"` - Name string `json:"cluster_name,omitempty" tf:"computed"` - ClusterInfo *ClusterInfo `json:"cluster_info,omitempty" tf:"computed"` - } - return common.DataResource(clusterData{}, func(ctx context.Context, e interface{}, c *common.DatabricksClient) error { - data := e.(*clusterData) - clusterAPI := NewClustersAPI(ctx, c) + return common.WorkspaceData(func(ctx context.Context, data *struct { + Id string `json:"id,omitempty" tf:"computed"` + ClusterId string `json:"cluster_id,omitempty" tf:"computed"` + Name string `json:"cluster_name,omitempty" tf:"computed"` + ClusterInfo *compute.ClusterDetails `json:"cluster_info,omitempty" tf:"computed"` + }, w *databricks.WorkspaceClient) error { if data.Name != "" { - clusters, err := clusterAPI.List() + clusters, err := w.Clusters.ListAll(ctx, compute.ListClustersRequest{}) if err != nil { return err } - namedClusters := []ClusterInfo{} + namedClusters := []compute.ClusterDetails{} for _, clst := range clusters { cluster := clst if cluster.ClusterName == data.Name { @@ -37,16 +36,16 @@ func DataSourceCluster() common.Resource { } data.ClusterInfo = &namedClusters[0] } else if data.ClusterId != "" { - cls, err := clusterAPI.Get(data.ClusterId) + cls, err := w.Clusters.GetByClusterId(ctx, data.ClusterId) if err != nil { return err } - data.ClusterInfo = &cls + data.ClusterInfo = cls } else { return fmt.Errorf("you need to specify either `cluster_name` or `cluster_id`") } - data.Id = data.ClusterInfo.ClusterID - data.ClusterId = data.ClusterInfo.ClusterID + data.Id = data.ClusterInfo.ClusterId + data.ClusterId = data.ClusterInfo.ClusterId return nil }) diff --git a/clusters/data_cluster_test.go b/clusters/data_cluster_test.go index 9945634fcc..cd20edec0d 100644 --- a/clusters/data_cluster_test.go +++ b/clusters/data_cluster_test.go @@ -1,104 +1,81 @@ package clusters import ( - "fmt" "testing" + "github.com/databricks/databricks-sdk-go/experimental/mocks" + "github.com/databricks/databricks-sdk-go/service/compute" "github.com/databricks/terraform-provider-databricks/qa" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/stretchr/testify/mock" ) func TestClusterDataByID(t *testing.T) { - d, err := qa.ResourceFixture{ - Fixtures: []qa.HTTPFixture{ - { - Method: "GET", - Resource: "/api/2.0/clusters/get?cluster_id=abc", - Response: ClusterInfo{ - ClusterID: "abc", - NumWorkers: 100, - ClusterName: "Shared Autoscaling", - SparkVersion: "7.1-scala12", - NodeTypeID: "i3.xlarge", - AutoterminationMinutes: 15, - State: ClusterStateRunning, - AutoScale: &AutoScale{ - MaxWorkers: 4, - }, + qa.ResourceFixture{ + MockWorkspaceClientFunc: func(m *mocks.MockWorkspaceClient) { + e := m.GetMockClustersAPI().EXPECT() + e.GetByClusterId(mock.Anything, "abc").Return(&compute.ClusterDetails{ + ClusterId: "abc", + NumWorkers: 100, + ClusterName: "Shared Autoscaling", + SparkVersion: "7.1-scala12", + NodeTypeId: "i3.xlarge", + AutoterminationMinutes: 15, + State: ClusterStateRunning, + Autoscale: &compute.AutoScale{ + MaxWorkers: 4, }, - }, + }, nil) }, Resource: DataSourceCluster(), HCL: `cluster_id = "abc"`, Read: true, NonWritable: true, ID: "abc", - }.Apply(t) - require.NoError(t, err) - assert.Equal(t, 15, d.Get("cluster_info.0.autotermination_minutes")) - assert.Equal(t, "Shared Autoscaling", d.Get("cluster_info.0.cluster_name")) - assert.Equal(t, "i3.xlarge", d.Get("cluster_info.0.node_type_id")) - assert.Equal(t, 4, d.Get("cluster_info.0.autoscale.0.max_workers")) - assert.Equal(t, "RUNNING", d.Get("cluster_info.0.state")) - - for k, v := range d.State().Attributes { - fmt.Printf("assert.Equal(t, %#v, d.Get(%#v))\n", v, k) - } + }.ApplyAndExpectData(t, map[string]any{ + "cluster_info.0.autotermination_minutes": 15, + "cluster_info.0.cluster_name": "Shared Autoscaling", + "cluster_info.0.node_type_id": "i3.xlarge", + "cluster_info.0.autoscale.0.max_workers": 4, + "cluster_info.0.state": "RUNNING", + }) } func TestClusterDataByName(t *testing.T) { - d, err := qa.ResourceFixture{ - Fixtures: []qa.HTTPFixture{ - { - Method: "GET", - Resource: "/api/2.0/clusters/list", - - Response: ClusterList{ - Clusters: []ClusterInfo{{ - ClusterID: "abc", - NumWorkers: 100, - ClusterName: "Shared Autoscaling", - SparkVersion: "7.1-scala12", - NodeTypeID: "i3.xlarge", - AutoterminationMinutes: 15, - State: ClusterStateRunning, - AutoScale: &AutoScale{ - MaxWorkers: 4, - }, - }}, + qa.ResourceFixture{ + MockWorkspaceClientFunc: func(m *mocks.MockWorkspaceClient) { + e := m.GetMockClustersAPI().EXPECT() + e.ListAll(mock.Anything, compute.ListClustersRequest{}).Return([]compute.ClusterDetails{{ + ClusterId: "abc", + NumWorkers: 100, + ClusterName: "Shared Autoscaling", + SparkVersion: "7.1-scala12", + NodeTypeId: "i3.xlarge", + AutoterminationMinutes: 15, + State: ClusterStateRunning, + Autoscale: &compute.AutoScale{ + MaxWorkers: 4, }, - }, + }}, nil) }, Resource: DataSourceCluster(), HCL: `cluster_name = "Shared Autoscaling"`, Read: true, NonWritable: true, ID: "_", - }.Apply(t) - require.NoError(t, err) - assert.Equal(t, 15, d.Get("cluster_info.0.autotermination_minutes")) - assert.Equal(t, "Shared Autoscaling", d.Get("cluster_info.0.cluster_name")) - assert.Equal(t, "i3.xlarge", d.Get("cluster_info.0.node_type_id")) - assert.Equal(t, 4, d.Get("cluster_info.0.autoscale.0.max_workers")) - assert.Equal(t, "RUNNING", d.Get("cluster_info.0.state")) - - for k, v := range d.State().Attributes { - fmt.Printf("assert.Equal(t, %#v, d.Get(%#v))\n", v, k) - } + }.ApplyAndExpectData(t, map[string]any{ + "cluster_info.0.autotermination_minutes": 15, + "cluster_info.0.cluster_name": "Shared Autoscaling", + "cluster_info.0.node_type_id": "i3.xlarge", + "cluster_info.0.autoscale.0.max_workers": 4, + "cluster_info.0.state": "RUNNING", + }) } func TestClusterDataByName_NotFound(t *testing.T) { qa.ResourceFixture{ - Fixtures: []qa.HTTPFixture{ - { - Method: "GET", - Resource: "/api/2.0/clusters/list", - - Response: ClusterList{ - Clusters: []ClusterInfo{}, - }, - }, + MockWorkspaceClientFunc: func(m *mocks.MockWorkspaceClient) { + e := m.GetMockClustersAPI().EXPECT() + e.ListAll(mock.Anything, compute.ListClustersRequest{}).Return([]compute.ClusterDetails{}, nil) }, Resource: DataSourceCluster(), HCL: `cluster_name = "Unknown"`, @@ -110,34 +87,34 @@ func TestClusterDataByName_NotFound(t *testing.T) { func TestClusterDataByName_DuplicateNames(t *testing.T) { qa.ResourceFixture{ - Fixtures: []qa.HTTPFixture{ - { - Method: "GET", - Resource: "/api/2.0/clusters/list", - - Response: ClusterList{ - Clusters: []ClusterInfo{ - { - ClusterID: "abc", - NumWorkers: 100, - ClusterName: "Shared Autoscaling", - SparkVersion: "7.1-scala12", - NodeTypeID: "i3.xlarge", - AutoterminationMinutes: 15, - State: ClusterStateRunning, - }, - { - ClusterID: "def", - NumWorkers: 100, - ClusterName: "Shared Autoscaling", - SparkVersion: "7.1-scala12", - NodeTypeID: "i3.xlarge", - AutoterminationMinutes: 15, - State: ClusterStateRunning, - }, + MockWorkspaceClientFunc: func(m *mocks.MockWorkspaceClient) { + e := m.GetMockClustersAPI().EXPECT() + e.ListAll(mock.Anything, compute.ListClustersRequest{}).Return([]compute.ClusterDetails{ + { + ClusterId: "abc", + NumWorkers: 100, + ClusterName: "Shared Autoscaling", + SparkVersion: "7.1-scala12", + NodeTypeId: "i3.xlarge", + AutoterminationMinutes: 15, + State: ClusterStateRunning, + Autoscale: &compute.AutoScale{ + MaxWorkers: 4, + }, + }, + { + ClusterId: "def", + NumWorkers: 100, + ClusterName: "Shared Autoscaling", + SparkVersion: "7.1-scala12", + NodeTypeId: "i3.xlarge", + AutoterminationMinutes: 15, + State: ClusterStateRunning, + Autoscale: &compute.AutoScale{ + MaxWorkers: 4, }, }, - }, + }, nil) }, Resource: DataSourceCluster(), HCL: `cluster_name = "Shared Autoscaling"`, diff --git a/clusters/data_clusters.go b/clusters/data_clusters.go index 2628c4968d..da637762b5 100644 --- a/clusters/data_clusters.go +++ b/clusters/data_clusters.go @@ -4,42 +4,32 @@ import ( "context" "strings" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/compute" "github.com/databricks/terraform-provider-databricks/common" - "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" ) func DataSourceClusters() common.Resource { - return common.Resource{ - Read: func(ctx context.Context, d *schema.ResourceData, i *common.DatabricksClient) error { - clusters, err := NewClustersAPI(ctx, i).List() - if err != nil { - return err + return common.WorkspaceData(func(ctx context.Context, data *struct { + Id string `json:"id,omitempty" tf:"computed"` + Ids []string `json:"ids,omitempty" tf:"computed,slice_set"` + ClusterNameContains string `json:"cluster_name_contains"` + }, w *databricks.WorkspaceClient) error { + clusters, err := w.Clusters.ListAll(ctx, compute.ListClustersRequest{}) + if err != nil { + return err + } + ids := make([]string, 0, len(clusters)) + name_contains := strings.ToLower(data.ClusterNameContains) + for _, v := range clusters { + match_name := strings.Contains(strings.ToLower(v.ClusterName), name_contains) + if name_contains != "" && !match_name { + continue } - ids := schema.NewSet(schema.HashString, []any{}) - name_contains := strings.ToLower(d.Get("cluster_name_contains").(string)) - for _, v := range clusters { - match_name := strings.Contains(strings.ToLower(v.ClusterName), name_contains) - if name_contains != "" && !match_name { - continue - } - ids.Add(v.ClusterID) - } - d.Set("ids", ids) - d.SetId("_") - return nil - }, - Schema: map[string]*schema.Schema{ - "ids": { - Computed: true, - Type: schema.TypeSet, - Elem: &schema.Schema{ - Type: schema.TypeString, - }, - }, - "cluster_name_contains": { - Optional: true, - Type: schema.TypeString, - }, - }, - } + ids = append(ids, v.ClusterId) + } + data.Ids = ids + data.Id = "_" + return nil + }) } diff --git a/clusters/data_clusters_test.go b/clusters/data_clusters_test.go index ddabc295fe..48d80afdfe 100644 --- a/clusters/data_clusters_test.go +++ b/clusters/data_clusters_test.go @@ -6,69 +6,59 @@ import ( "github.com/databricks/databricks-sdk-go/client" "github.com/databricks/databricks-sdk-go/config" + "github.com/databricks/databricks-sdk-go/experimental/mocks" + "github.com/databricks/databricks-sdk-go/service/compute" "github.com/databricks/terraform-provider-databricks/common" "github.com/databricks/terraform-provider-databricks/qa" - "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/stretchr/testify/mock" ) func TestClustersDataSource(t *testing.T) { qa.ResourceFixture{ - Fixtures: []qa.HTTPFixture{ - { - Method: "GET", - Resource: "/api/2.0/clusters/list", - - Response: ClusterList{ - Clusters: []ClusterInfo{ - { - ClusterID: "b", - }, - { - ClusterID: "a", - }, - }, + MockWorkspaceClientFunc: func(m *mocks.MockWorkspaceClient) { + e := m.GetMockClustersAPI().EXPECT() + e.ListAll(mock.Anything, compute.ListClustersRequest{}).Return([]compute.ClusterDetails{ + { + ClusterId: "b", + }, + { + ClusterId: "a", }, - }, + }, nil) }, Resource: DataSourceClusters(), NonWritable: true, Read: true, ID: "_", - }.ApplyNoError(t) + }.ApplyAndExpectData(t, map[string]any{ + "ids": []string{"a", "b"}, + }) } func TestClustersDataSourceContainsName(t *testing.T) { - d, err := qa.ResourceFixture{ - Fixtures: []qa.HTTPFixture{ - { - Method: "GET", - Resource: "/api/2.0/clusters/list", - Response: ClusterList{ - Clusters: []ClusterInfo{ - { - ClusterID: "b", - ClusterName: "THIS NAME", - }, - { - ClusterID: "a", - ClusterName: "that name", - }, - }, + qa.ResourceFixture{ + MockWorkspaceClientFunc: func(m *mocks.MockWorkspaceClient) { + e := m.GetMockClustersAPI().EXPECT() + e.ListAll(mock.Anything, compute.ListClustersRequest{}).Return([]compute.ClusterDetails{ + { + ClusterId: "b", + ClusterName: "THIS NAME", + }, + { + ClusterId: "a", + ClusterName: "that name", }, - }, + }, nil) }, Resource: DataSourceClusters(), NonWritable: true, Read: true, ID: "_", HCL: `cluster_name_contains = "this"`, - }.Apply(t) - require.NoError(t, err) - ids := d.Get("ids").(*schema.Set) - assert.True(t, ids.Contains("b")) - assert.Equal(t, 1, ids.Len()) + }.ApplyAndExpectData(t, map[string]any{ + "ids": []string{"b"}, + }) } func TestClustersDataSourceErrorsOut(t *testing.T) {