Skip to content

Commit

Permalink
Refactored databricks_cluster(s) data sources to Go SDK (#3685)
Browse files Browse the repository at this point in the history
* relax cluster check

* fix

* fix

* refactor `databricks_cluster` data source to Go SDK

* refactor `databricks_clusters` data source to Go SDK
  • Loading branch information
nkvuong authored Jul 1, 2024
1 parent 1ba1772 commit df210b2
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 185 deletions.
29 changes: 14 additions & 15 deletions clusters/data_cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,24 @@ import (
"context"
"fmt"

"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/databricks/terraform-provider-databricks/common"
)

func DataSourceCluster() common.Resource {
type clusterData struct {
Id string `json:"id,omitempty" tf:"computed"`
ClusterId string `json:"cluster_id,omitempty" tf:"computed"`
Name string `json:"cluster_name,omitempty" tf:"computed"`
ClusterInfo *ClusterInfo `json:"cluster_info,omitempty" tf:"computed"`
}
return common.DataResource(clusterData{}, func(ctx context.Context, e interface{}, c *common.DatabricksClient) error {
data := e.(*clusterData)
clusterAPI := NewClustersAPI(ctx, c)
return common.WorkspaceData(func(ctx context.Context, data *struct {
Id string `json:"id,omitempty" tf:"computed"`
ClusterId string `json:"cluster_id,omitempty" tf:"computed"`
Name string `json:"cluster_name,omitempty" tf:"computed"`
ClusterInfo *compute.ClusterDetails `json:"cluster_info,omitempty" tf:"computed"`
}, w *databricks.WorkspaceClient) error {
if data.Name != "" {
clusters, err := clusterAPI.List()
clusters, err := w.Clusters.ListAll(ctx, compute.ListClustersRequest{})
if err != nil {
return err
}
namedClusters := []ClusterInfo{}
namedClusters := []compute.ClusterDetails{}
for _, clst := range clusters {
cluster := clst
if cluster.ClusterName == data.Name {
Expand All @@ -37,16 +36,16 @@ func DataSourceCluster() common.Resource {
}
data.ClusterInfo = &namedClusters[0]
} else if data.ClusterId != "" {
cls, err := clusterAPI.Get(data.ClusterId)
cls, err := w.Clusters.GetByClusterId(ctx, data.ClusterId)
if err != nil {
return err
}
data.ClusterInfo = &cls
data.ClusterInfo = cls
} else {
return fmt.Errorf("you need to specify either `cluster_name` or `cluster_id`")
}
data.Id = data.ClusterInfo.ClusterID
data.ClusterId = data.ClusterInfo.ClusterID
data.Id = data.ClusterInfo.ClusterId
data.ClusterId = data.ClusterInfo.ClusterId

return nil
})
Expand Down
171 changes: 74 additions & 97 deletions clusters/data_cluster_test.go
Original file line number Diff line number Diff line change
@@ -1,104 +1,81 @@
package clusters

import (
"fmt"
"testing"

"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/databricks/terraform-provider-databricks/qa"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/mock"
)

func TestClusterDataByID(t *testing.T) {
d, err := qa.ResourceFixture{
Fixtures: []qa.HTTPFixture{
{
Method: "GET",
Resource: "/api/2.0/clusters/get?cluster_id=abc",
Response: ClusterInfo{
ClusterID: "abc",
NumWorkers: 100,
ClusterName: "Shared Autoscaling",
SparkVersion: "7.1-scala12",
NodeTypeID: "i3.xlarge",
AutoterminationMinutes: 15,
State: ClusterStateRunning,
AutoScale: &AutoScale{
MaxWorkers: 4,
},
qa.ResourceFixture{
MockWorkspaceClientFunc: func(m *mocks.MockWorkspaceClient) {
e := m.GetMockClustersAPI().EXPECT()
e.GetByClusterId(mock.Anything, "abc").Return(&compute.ClusterDetails{
ClusterId: "abc",
NumWorkers: 100,
ClusterName: "Shared Autoscaling",
SparkVersion: "7.1-scala12",
NodeTypeId: "i3.xlarge",
AutoterminationMinutes: 15,
State: ClusterStateRunning,
Autoscale: &compute.AutoScale{
MaxWorkers: 4,
},
},
}, nil)
},
Resource: DataSourceCluster(),
HCL: `cluster_id = "abc"`,
Read: true,
NonWritable: true,
ID: "abc",
}.Apply(t)
require.NoError(t, err)
assert.Equal(t, 15, d.Get("cluster_info.0.autotermination_minutes"))
assert.Equal(t, "Shared Autoscaling", d.Get("cluster_info.0.cluster_name"))
assert.Equal(t, "i3.xlarge", d.Get("cluster_info.0.node_type_id"))
assert.Equal(t, 4, d.Get("cluster_info.0.autoscale.0.max_workers"))
assert.Equal(t, "RUNNING", d.Get("cluster_info.0.state"))

for k, v := range d.State().Attributes {
fmt.Printf("assert.Equal(t, %#v, d.Get(%#v))\n", v, k)
}
}.ApplyAndExpectData(t, map[string]any{
"cluster_info.0.autotermination_minutes": 15,
"cluster_info.0.cluster_name": "Shared Autoscaling",
"cluster_info.0.node_type_id": "i3.xlarge",
"cluster_info.0.autoscale.0.max_workers": 4,
"cluster_info.0.state": "RUNNING",
})
}

func TestClusterDataByName(t *testing.T) {
d, err := qa.ResourceFixture{
Fixtures: []qa.HTTPFixture{
{
Method: "GET",
Resource: "/api/2.0/clusters/list",

Response: ClusterList{
Clusters: []ClusterInfo{{
ClusterID: "abc",
NumWorkers: 100,
ClusterName: "Shared Autoscaling",
SparkVersion: "7.1-scala12",
NodeTypeID: "i3.xlarge",
AutoterminationMinutes: 15,
State: ClusterStateRunning,
AutoScale: &AutoScale{
MaxWorkers: 4,
},
}},
qa.ResourceFixture{
MockWorkspaceClientFunc: func(m *mocks.MockWorkspaceClient) {
e := m.GetMockClustersAPI().EXPECT()
e.ListAll(mock.Anything, compute.ListClustersRequest{}).Return([]compute.ClusterDetails{{
ClusterId: "abc",
NumWorkers: 100,
ClusterName: "Shared Autoscaling",
SparkVersion: "7.1-scala12",
NodeTypeId: "i3.xlarge",
AutoterminationMinutes: 15,
State: ClusterStateRunning,
Autoscale: &compute.AutoScale{
MaxWorkers: 4,
},
},
}}, nil)
},
Resource: DataSourceCluster(),
HCL: `cluster_name = "Shared Autoscaling"`,
Read: true,
NonWritable: true,
ID: "_",
}.Apply(t)
require.NoError(t, err)
assert.Equal(t, 15, d.Get("cluster_info.0.autotermination_minutes"))
assert.Equal(t, "Shared Autoscaling", d.Get("cluster_info.0.cluster_name"))
assert.Equal(t, "i3.xlarge", d.Get("cluster_info.0.node_type_id"))
assert.Equal(t, 4, d.Get("cluster_info.0.autoscale.0.max_workers"))
assert.Equal(t, "RUNNING", d.Get("cluster_info.0.state"))

for k, v := range d.State().Attributes {
fmt.Printf("assert.Equal(t, %#v, d.Get(%#v))\n", v, k)
}
}.ApplyAndExpectData(t, map[string]any{
"cluster_info.0.autotermination_minutes": 15,
"cluster_info.0.cluster_name": "Shared Autoscaling",
"cluster_info.0.node_type_id": "i3.xlarge",
"cluster_info.0.autoscale.0.max_workers": 4,
"cluster_info.0.state": "RUNNING",
})
}

func TestClusterDataByName_NotFound(t *testing.T) {
qa.ResourceFixture{
Fixtures: []qa.HTTPFixture{
{
Method: "GET",
Resource: "/api/2.0/clusters/list",

Response: ClusterList{
Clusters: []ClusterInfo{},
},
},
MockWorkspaceClientFunc: func(m *mocks.MockWorkspaceClient) {
e := m.GetMockClustersAPI().EXPECT()
e.ListAll(mock.Anything, compute.ListClustersRequest{}).Return([]compute.ClusterDetails{}, nil)
},
Resource: DataSourceCluster(),
HCL: `cluster_name = "Unknown"`,
Expand All @@ -110,34 +87,34 @@ func TestClusterDataByName_NotFound(t *testing.T) {

func TestClusterDataByName_DuplicateNames(t *testing.T) {
qa.ResourceFixture{
Fixtures: []qa.HTTPFixture{
{
Method: "GET",
Resource: "/api/2.0/clusters/list",

Response: ClusterList{
Clusters: []ClusterInfo{
{
ClusterID: "abc",
NumWorkers: 100,
ClusterName: "Shared Autoscaling",
SparkVersion: "7.1-scala12",
NodeTypeID: "i3.xlarge",
AutoterminationMinutes: 15,
State: ClusterStateRunning,
},
{
ClusterID: "def",
NumWorkers: 100,
ClusterName: "Shared Autoscaling",
SparkVersion: "7.1-scala12",
NodeTypeID: "i3.xlarge",
AutoterminationMinutes: 15,
State: ClusterStateRunning,
},
MockWorkspaceClientFunc: func(m *mocks.MockWorkspaceClient) {
e := m.GetMockClustersAPI().EXPECT()
e.ListAll(mock.Anything, compute.ListClustersRequest{}).Return([]compute.ClusterDetails{
{
ClusterId: "abc",
NumWorkers: 100,
ClusterName: "Shared Autoscaling",
SparkVersion: "7.1-scala12",
NodeTypeId: "i3.xlarge",
AutoterminationMinutes: 15,
State: ClusterStateRunning,
Autoscale: &compute.AutoScale{
MaxWorkers: 4,
},
},
{
ClusterId: "def",
NumWorkers: 100,
ClusterName: "Shared Autoscaling",
SparkVersion: "7.1-scala12",
NodeTypeId: "i3.xlarge",
AutoterminationMinutes: 15,
State: ClusterStateRunning,
Autoscale: &compute.AutoScale{
MaxWorkers: 4,
},
},
},
}, nil)
},
Resource: DataSourceCluster(),
HCL: `cluster_name = "Shared Autoscaling"`,
Expand Down
56 changes: 23 additions & 33 deletions clusters/data_clusters.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,42 +4,32 @@ import (
"context"
"strings"

"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/databricks/terraform-provider-databricks/common"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
)

func DataSourceClusters() common.Resource {
return common.Resource{
Read: func(ctx context.Context, d *schema.ResourceData, i *common.DatabricksClient) error {
clusters, err := NewClustersAPI(ctx, i).List()
if err != nil {
return err
return common.WorkspaceData(func(ctx context.Context, data *struct {
Id string `json:"id,omitempty" tf:"computed"`
Ids []string `json:"ids,omitempty" tf:"computed,slice_set"`
ClusterNameContains string `json:"cluster_name_contains"`
}, w *databricks.WorkspaceClient) error {
clusters, err := w.Clusters.ListAll(ctx, compute.ListClustersRequest{})
if err != nil {
return err
}
ids := make([]string, 0, len(clusters))
name_contains := strings.ToLower(data.ClusterNameContains)
for _, v := range clusters {
match_name := strings.Contains(strings.ToLower(v.ClusterName), name_contains)
if name_contains != "" && !match_name {
continue
}
ids := schema.NewSet(schema.HashString, []any{})
name_contains := strings.ToLower(d.Get("cluster_name_contains").(string))
for _, v := range clusters {
match_name := strings.Contains(strings.ToLower(v.ClusterName), name_contains)
if name_contains != "" && !match_name {
continue
}
ids.Add(v.ClusterID)
}
d.Set("ids", ids)
d.SetId("_")
return nil
},
Schema: map[string]*schema.Schema{
"ids": {
Computed: true,
Type: schema.TypeSet,
Elem: &schema.Schema{
Type: schema.TypeString,
},
},
"cluster_name_contains": {
Optional: true,
Type: schema.TypeString,
},
},
}
ids = append(ids, v.ClusterId)
}
data.Ids = ids
data.Id = "_"
return nil
})
}
Loading

0 comments on commit df210b2

Please sign in to comment.