Skip to content

Commit

Permalink
Added support for warehouse_id, options, partitions and liquid …
Browse files Browse the repository at this point in the history
…clustering for `databricks_sql_table` resource (#2789)

* additional parameters for `databricks_sql_table`

* fix timeout & tests

* fix test

* fix properties drift

* more suppress diff

* refactor

* feedback

* feedback

* fix test
  • Loading branch information
nkvuong authored Oct 23, 2023
1 parent 7d8d851 commit cb4a6f0
Show file tree
Hide file tree
Showing 4 changed files with 419 additions and 21 deletions.
115 changes: 100 additions & 15 deletions catalog/resource_sql_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,19 @@ import (
"log"
"reflect"
"strings"
"time"

"github.com/databricks/databricks-sdk-go/apierr"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/databricks/databricks-sdk-go/service/sql"
"github.com/databricks/terraform-provider-databricks/clusters"
"github.com/databricks/terraform-provider-databricks/common"

"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
)

// MaxSqlExecWaitTimeout is the wait timeout, in seconds, applied when running
// statements through the SQL statement execution API (50s is the maximum
// wait the statement execution API allows). Declared as a var so tests can
// shorten it.
var MaxSqlExecWaitTimeout = 50

type SqlColumnInfo struct {
Name string `json:"name"`
Type string `json:"type_text,omitempty" tf:"suppress_diff,alias:type"`
Expand All @@ -29,14 +33,19 @@ type SqlTableInfo struct {
TableType string `json:"table_type" tf:"force_new"`
DataSourceFormat string `json:"data_source_format,omitempty" tf:"force_new"`
ColumnInfos []SqlColumnInfo `json:"columns,omitempty" tf:"alias:column,computed,force_new"`
Partitions []string `json:"partitions,omitempty" tf:"force_new"`
ClusterKeys []string `json:"cluster_keys,omitempty" tf:"force_new"`
StorageLocation string `json:"storage_location,omitempty" tf:"suppress_diff"`
StorageCredentialName string `json:"storage_credential_name,omitempty" tf:"force_new"`
ViewDefinition string `json:"view_definition,omitempty"`
Comment string `json:"comment,omitempty"`
Properties map[string]string `json:"properties,omitempty" tf:"computed"`
Options map[string]string `json:"options,omitempty" tf:"force_new"`
ClusterID string `json:"cluster_id,omitempty" tf:"computed"`
WarehouseID string `json:"warehouse_id,omitempty"`

exec common.CommandExecutor
exec common.CommandExecutor
sqlExec *sql.StatementExecutionAPI
}

type SqlTablesAPI struct {
Expand Down Expand Up @@ -75,6 +84,17 @@ func sqlTableIsManagedProperty(key string) bool {
"delta.lastUpdateVersion": true,
"delta.minReaderVersion": true,
"delta.minWriterVersion": true,
"delta.enableDeletionVectors": true,
"delta.enableRowTracking": true,
"delta.feature.deletionVectors": true,
"delta.feature.domainMetadata": true,
"delta.feature.liquid": true,
"delta.feature.rowTracking": true,
"delta.liquid.clusteringColumns": true,
"delta.rowTracking.materializedRowCommitVersionColumnName": true,
"delta.rowTracking.materializedRowIdColumnName": true,
"delta.checkpoint.writeStatsAsJson": true,
"delta.checkpoint.writeStatsAsStruct": true,
"view.catalogAndNamespace.numParts": true,
"view.catalogAndNamespace.part.0": true,
"view.catalogAndNamespace.part.1": true,
Expand All @@ -96,27 +116,37 @@ func sqlTableIsManagedProperty(key string) bool {
// initCluster resolves the compute used to run SQL statements for this table.
// Exactly one of three paths is taken, in priority order:
//  1. an explicit cluster_id from config (restarted/recreated if it was deleted),
//  2. an explicit warehouse_id from config,
//  3. a shared default cluster named "terraform-sql-table" (created on demand).
// It also initializes both executors (command executor for clusters, statement
// execution API for warehouses) so applySql can pick the right one later.
// NOTE: this fixes the diff-scrape corruption of the original span, which
// interleaved removed and added lines (duplicate StartAndGetInfo/exec wiring).
func (ti *SqlTableInfo) initCluster(ctx context.Context, d *schema.ResourceData, c *common.DatabricksClient) (err error) {
	defaultClusterName := "terraform-sql-table"
	clustersAPI := clusters.NewClustersAPI(ctx, c)
	// if a cluster id is specified, start the cluster
	if ci, ok := d.GetOk("cluster_id"); ok {
		ti.ClusterID = ci.(string)
		_, err = clustersAPI.StartAndGetInfo(ti.ClusterID)
		if apierr.IsMissing(err) {
			// cluster that was previously in a tfstate was deleted
			ti.ClusterID, err = ti.getOrCreateCluster(defaultClusterName, clustersAPI)
			if err != nil {
				return
			}
			_, err = clustersAPI.StartAndGetInfo(ti.ClusterID)
		}
		if err != nil {
			return
		}
		// if a warehouse id is specified, use the warehouse
	} else if wi, ok := d.GetOk("warehouse_id"); ok {
		ti.WarehouseID = wi.(string)
		// else, create a default cluster
	} else {
		ti.ClusterID, err = ti.getOrCreateCluster(defaultClusterName, clustersAPI)
		if err != nil {
			return
		}
		_, err = clustersAPI.StartAndGetInfo(ti.ClusterID)
	}
	w, err := c.WorkspaceClient()
	if err != nil {
		return err
	}
	ti.exec = c.CommandExecutor(ctx)
	ti.sqlExec = w.StatementExecution
	return nil
}

Expand Down Expand Up @@ -177,6 +207,16 @@ func (ti *SqlTableInfo) serializeProperties() string {
return strings.Join(propsMap[:], ", ") // 'foo'='bar', 'this'='that'
}

// serializeOptions renders the table's user-supplied options as a
// comma-separated list of 'key'='value' pairs suitable for an OPTIONS (...)
// clause, skipping any key that is a Databricks-managed property.
func (ti *SqlTableInfo) serializeOptions() string {
	pairs := make([]string, 0, len(ti.Options))
	for k, v := range ti.Options {
		if sqlTableIsManagedProperty(k) {
			continue
		}
		pairs = append(pairs, fmt.Sprintf("'%s'='%s'", k, v))
	}
	return strings.Join(pairs, ", ") // 'foo'='bar', 'this'='that'
}

func (ti *SqlTableInfo) buildLocationStatement() string {
statements := make([]string, 0, 10)
statements = append(statements, fmt.Sprintf("LOCATION '%s'", ti.StorageLocation)) // LOCATION '/mnt/csv_files'
Expand Down Expand Up @@ -218,6 +258,14 @@ func (ti *SqlTableInfo) buildTableCreateStatement() string {
}
}

if len(ti.Partitions) > 0 {
statements = append(statements, fmt.Sprintf("\nPARTITIONED BY (%s)", strings.Join(ti.Partitions, ", "))) // PARTITIONED BY (university, major)
}

if len(ti.ClusterKeys) > 0 {
statements = append(statements, fmt.Sprintf("\nCLUSTER BY (%s)", strings.Join(ti.ClusterKeys, ", "))) // CLUSTER BY (university, major)
}

if ti.Comment != "" {
statements = append(statements, fmt.Sprintf("\nCOMMENT '%s'", parseComment(ti.Comment))) // COMMENT 'this is a comment'
}
Expand All @@ -226,6 +274,10 @@ func (ti *SqlTableInfo) buildTableCreateStatement() string {
statements = append(statements, fmt.Sprintf("\nTBLPROPERTIES (%s)", ti.serializeProperties())) // TBLPROPERTIES ('foo'='bar')
}

if len(ti.Options) > 0 {
statements = append(statements, fmt.Sprintf("\nOPTIONS (%s)", ti.serializeOptions())) // OPTIONS ('foo'='bar')
}

if !isView {
if ti.StorageLocation != "" {
statements = append(statements, "\n"+ti.buildLocationStatement())
Expand Down Expand Up @@ -253,6 +305,9 @@ func (ti *SqlTableInfo) diff(oldti *SqlTableInfo) ([]string, error) {
if ti.StorageLocation != oldti.StorageLocation {
statements = append(statements, fmt.Sprintf("ALTER TABLE %s SET %s", ti.SQLFullName(), ti.buildLocationStatement()))
}
if !reflect.DeepEqual(ti.ClusterKeys, oldti.ClusterKeys) {
statements = append(statements, fmt.Sprintf("ALTER TABLE %s CLUSTER BY (%s)", ti.SQLFullName(), strings.Join(ti.ClusterKeys, ", ")))
}
}

// Attributes common to both views and tables
Expand Down Expand Up @@ -302,12 +357,29 @@ func (ti *SqlTableInfo) deleteTable() error {

func (ti *SqlTableInfo) applySql(sqlQuery string) error {
log.Printf("[INFO] Executing Sql: %s", sqlQuery)
r := ti.exec.Execute(ti.ClusterID, "sql", sqlQuery)

if !r.Failed() {
if ti.WarehouseID != "" {
execCtx, cancel := context.WithTimeout(context.Background(), time.Duration(MaxSqlExecWaitTimeout)*time.Second)
defer cancel()
sqlRes, err := ti.sqlExec.ExecuteStatement(execCtx, sql.ExecuteStatementRequest{
Statement: sqlQuery,
WaitTimeout: fmt.Sprintf("%ds", MaxSqlExecWaitTimeout), //max allowed by sql exec
WarehouseId: ti.WarehouseID,
OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutCancel,
})
if err != nil {
return err
}
if sqlRes.Status.State != "SUCCEEDED" {
return fmt.Errorf("statement failed to execute: %s", sqlRes.Status.State)
}
return nil
}
return fmt.Errorf("cannot execute %s: %s", sqlQuery, r.Error())

r := ti.exec.Execute(ti.ClusterID, "sql", sqlQuery)
if r.Failed() {
return fmt.Errorf("cannot execute %s: %s", sqlQuery, r.Error())
}
return nil
}

func ResourceSqlTable() *schema.Resource {
Expand All @@ -320,6 +392,12 @@ func ResourceSqlTable() *schema.Resource {
return strings.EqualFold(strings.ToLower(old), strings.ToLower(new))
}
s["storage_location"].DiffSuppressFunc = ucDirectoryPathSlashAndEmptySuppressDiff

s["cluster_id"].ConflictsWith = []string{"warehouse_id"}
s["warehouse_id"].ConflictsWith = []string{"cluster_id"}

s["partitions"].ConflictsWith = []string{"cluster_keys"}
s["cluster_keys"].ConflictsWith = []string{"partitions"}
return s
})
return common.Resource{
Expand All @@ -329,9 +407,16 @@ func ResourceSqlTable() *schema.Resource {
old, new := d.GetChange("properties")
oldProps := old.(map[string]any)
newProps := new.(map[string]any)
old, _ = d.GetChange("options")
options := old.(map[string]any)
for key := range oldProps {
if _, ok := newProps[key]; !ok {
if sqlTableIsManagedProperty(key) {
//options also gets exposed as properties
if _, ok := options[key]; ok {
newProps[key] = oldProps[key]
}
//some options are exposed as option.[...] properties
if sqlTableIsManagedProperty(key) || strings.HasPrefix(key, "option.") {
newProps[key] = oldProps[key]
}
}
Expand Down
Loading

0 comments on commit cb4a6f0

Please sign in to comment.