diff --git a/common/customizable_schema_plugin_framework.go b/common/customizable_schema_plugin_framework.go index 6d2376d8d7..b3e5e7c815 100644 --- a/common/customizable_schema_plugin_framework.go +++ b/common/customizable_schema_plugin_framework.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/hashicorp/terraform-plugin-framework/provider/schema" + "github.com/hashicorp/terraform-plugin-framework/schema/validator" ) type CustomizableSchemaPluginFramework struct { @@ -15,24 +16,24 @@ func ConstructCustomizableSchema(attributes map[string]schema.Attribute) *Custom return &CustomizableSchemaPluginFramework{attr: attr} } -func (s *CustomizableSchemaPluginFramework) SchemaPath(path ...string) *CustomizableSchemaPluginFramework { - attr, err := navigateSchema(s.attr, path...) - if err != nil { - panic(err) - } +// func (s *CustomizableSchemaPluginFramework) SchemaPath(path ...string) *CustomizableSchemaPluginFramework { +// // The attr returned here is not the original object, it's a copy. +// attr, err := navigateSchema(&s.attr, path...) +// if err != nil { +// panic(err) +// } - return &CustomizableSchemaPluginFramework{attr} -} +// return &CustomizableSchemaPluginFramework{attr} +// } // Converts CustomizableSchema into a map from string to Attribute. func (s *CustomizableSchemaPluginFramework) ToAttributeMap() map[string]schema.Attribute { - return attributeToMap(s.attr) + return attributeToMap(&s.attr) } -func attributeToMap(attr schema.Attribute) map[string]schema.Attribute { +func attributeToMap(attr *schema.Attribute) map[string]schema.Attribute { var m map[string]schema.Attribute - - switch attr := attr.(type) { + switch attr := (*attr).(type) { case schema.SingleNestedAttribute: m = attr.Attributes case schema.ListNestedAttribute: @@ -46,74 +47,271 @@ func attributeToMap(attr schema.Attribute) map[string]schema.Attribute { return m } -func (s *CustomizableSchemaPluginFramework) AddNewField(key string, newField schema.Attribute) *CustomizableSchemaPluginFramework { - switch attr := s.attr.(type) { - case schema.SingleNestedAttribute: - _, exists := attr.Attributes[key] - if exists { - panic("Cannot add new field, " + key + " already exists in the schema") +func (s *CustomizableSchemaPluginFramework) AddNewField(key string, newField schema.Attribute, path ...string) *CustomizableSchemaPluginFramework { + cb := func(a schema.Attribute) schema.Attribute { + switch attr := a.(type) { + case schema.SingleNestedAttribute: + _, exists := attr.Attributes[key] + if exists { + panic("Cannot add new field, " + key + " already exists in the schema") + } + attr.Attributes[key] = newField + return attr + case schema.ListNestedAttribute: + _, exists := attr.NestedObject.Attributes[key] + if exists { + panic("Cannot add new field, " + key + " already exists in the schema") + } + attr.NestedObject.Attributes[key] = newField + return attr + case schema.MapNestedAttribute: + _, exists := attr.NestedObject.Attributes[key] + if exists { + panic("Cannot add new field, " + key + " already exists in the schema") + } + attr.NestedObject.Attributes[key] = newField + return attr + default: + panic("attribute is not nested, cannot add field") } - attr.Attributes[key] = newField - attr.Required = true - s.attr = attr - case schema.ListNestedAttribute: - _, exists := attr.NestedObject.Attributes[key] - if exists { - panic("Cannot add new field, " + key + " already exists in the schema") + } + + if len(path) == 0 { + s.attr = cb(s.attr) + } else { + navigateSchemaWithCallback(&s.attr, cb, path...) + } + return s +} + +func (s *CustomizableSchemaPluginFramework) RemoveField(key string, path ...string) *CustomizableSchemaPluginFramework { + cb := func(a schema.Attribute) schema.Attribute { + switch attr := a.(type) { + case schema.SingleNestedAttribute: + _, exists := attr.Attributes[key] + if !exists { + panic("Cannot remove field, " + key + " does not exist in the schema") + } + delete(attr.Attributes, key) + return attr + case schema.ListNestedAttribute: + _, exists := attr.NestedObject.Attributes[key] + if !exists { + panic("Cannot remove field, " + key + " does not exist in the schema") + } + delete(attr.NestedObject.Attributes, key) + return attr + case schema.MapNestedAttribute: + _, exists := attr.NestedObject.Attributes[key] + if !exists { + panic("Cannot remove field, " + key + " does not exist in the schema") + } + delete(attr.NestedObject.Attributes, key) + return attr + default: + panic("attribute is not nested, cannot add field") } - attr.NestedObject.Attributes[key] = newField - s.attr = attr - case schema.MapNestedAttribute: - _, exists := attr.NestedObject.Attributes[key] - if exists { - panic("Cannot add new field, " + key + " already exists in the schema") + } + + if len(path) == 0 { + s.attr = cb(s.attr) + } else { + navigateSchemaWithCallback(&s.attr, cb, path...) + } + return s +} + +func (s *CustomizableSchemaPluginFramework) AddValidator(v any, path ...string) *CustomizableSchemaPluginFramework { + cb := func(a schema.Attribute) schema.Attribute { + switch attr := a.(type) { + case schema.SingleNestedAttribute: + attr.Validators = append(attr.Validators, v.(validator.Object)) + return attr + case schema.ListNestedAttribute: + attr.Validators = append(attr.Validators, v.(validator.List)) + return attr + case schema.MapNestedAttribute: + attr.Validators = append(attr.Validators, v.(validator.Map)) + return attr + case schema.BoolAttribute: + attr.Validators = append(attr.Validators, v.(validator.Bool)) + return attr + case schema.Float64Attribute: + attr.Validators = append(attr.Validators, v.(validator.Float64)) + return attr + case schema.StringAttribute: + attr.Validators = append(attr.Validators, v.(validator.String)) + return attr + case schema.Int64Attribute: + attr.Validators = append(attr.Validators, v.(validator.Int64)) + return attr + case schema.ListAttribute: + attr.Validators = append(attr.Validators, v.(validator.List)) + return attr + case schema.MapAttribute: + attr.Validators = append(attr.Validators, v.(validator.Map)) + return attr + default: + panic(fmt.Sprintf("Unsupported type %T", s.attr)) } - attr.NestedObject.Attributes[key] = newField - s.attr = attr - default: - panic("attribute is not nested, cannot add field") } + navigateSchemaWithCallback(&s.attr, cb, path...) + return s } -func (s *CustomizableSchemaPluginFramework) RemoveField(key string) *CustomizableSchemaPluginFramework { - switch attr := s.attr.(type) { - case schema.SingleNestedAttribute: - _, exists := attr.Attributes[key] - if !exists { - panic("Cannot remove field, " + key + " does not exist in the schema") +func (s *CustomizableSchemaPluginFramework) SetOptional(path ...string) *CustomizableSchemaPluginFramework { + cb := func(a schema.Attribute) schema.Attribute { + switch attr := a.(type) { + case schema.SingleNestedAttribute: + attr.Optional = true + attr.Required = false + return attr + case schema.ListNestedAttribute: + attr.Optional = true + attr.Required = false + return attr + case schema.MapNestedAttribute: + attr.Optional = true + attr.Required = false + return attr + case schema.BoolAttribute: + attr.Optional = true + attr.Required = false + return attr + case schema.Float64Attribute: + attr.Optional = true + attr.Required = false + return attr + case schema.StringAttribute: + attr.Optional = true + attr.Required = false + return attr + case schema.Int64Attribute: + attr.Optional = true + attr.Required = false + return attr + case schema.ListAttribute: + attr.Optional = true + attr.Required = false + return attr + case schema.MapAttribute: + attr.Optional = true + attr.Required = false + return attr + default: + panic(fmt.Sprintf("Unsupported type %T", s.attr)) } - delete(attr.Attributes, key) - s.attr = attr - case schema.ListNestedAttribute: - _, exists := attr.NestedObject.Attributes[key] - if !exists { - panic("Cannot remove field, " + key + " does not exist in the schema") + } + + navigateSchemaWithCallback(&s.attr, cb, path...) + + return s +} + +func (s *CustomizableSchemaPluginFramework) SetRequired(path ...string) *CustomizableSchemaPluginFramework { + cb := func(a schema.Attribute) schema.Attribute { + switch attr := a.(type) { + case schema.SingleNestedAttribute: + attr.Optional = false + attr.Required = true + return attr + case schema.ListNestedAttribute: + attr.Optional = false + attr.Required = true + return attr + case schema.MapNestedAttribute: + attr.Optional = false + attr.Required = true + return attr + case schema.BoolAttribute: + attr.Optional = false + attr.Required = true + return attr + case schema.Float64Attribute: + attr.Optional = false + attr.Required = true + return attr + case schema.StringAttribute: + attr.Optional = false + attr.Required = true + return attr + case schema.Int64Attribute: + attr.Optional = false + attr.Required = true + return attr + case schema.ListAttribute: + attr.Optional = false + attr.Required = true + return attr + case schema.MapAttribute: + attr.Optional = false + attr.Required = true + return attr + default: + panic(fmt.Sprintf("Unsupported type %T", s.attr)) } - delete(attr.NestedObject.Attributes, key) - s.attr = attr - case schema.MapNestedAttribute: - _, exists := attr.NestedObject.Attributes[key] - if !exists { - panic("Cannot remove field, " + key + " does not exist in the schema") + } + + navigateSchemaWithCallback(&s.attr, cb, path...) + + return s +} + +func (s *CustomizableSchemaPluginFramework) SetSensitive(path ...string) *CustomizableSchemaPluginFramework { + cb := func(a schema.Attribute) schema.Attribute { + switch attr := a.(type) { + case schema.SingleNestedAttribute: + attr.Sensitive = true + return attr + case schema.ListNestedAttribute: + attr.Sensitive = true + return attr + case schema.MapNestedAttribute: + attr.Sensitive = true + return attr + case schema.BoolAttribute: + attr.Sensitive = true + return attr + case schema.Float64Attribute: + attr.Sensitive = true + return attr + case schema.StringAttribute: + attr.Sensitive = true + return attr + case schema.Int64Attribute: + attr.Sensitive = true + return attr + case schema.ListAttribute: + attr.Sensitive = true + return attr + case schema.MapAttribute: + attr.Sensitive = true + return attr + default: + panic(fmt.Sprintf("Unsupported type %T", s.attr)) } - delete(attr.NestedObject.Attributes, key) - s.attr = attr - default: - panic("attribute is not nested, cannot add field") } + navigateSchemaWithCallback(&s.attr, cb, path...) return s } // Given a attribute map, navigate through the given path, panics if the path is not valid. func MustSchemaAttributePath(attrs map[string]schema.Attribute, path ...string) schema.Attribute { - return ConstructCustomizableSchema(attrs).SchemaPath(path...).attr + attr := ConstructCustomizableSchema(attrs).attr + + res, err := navigateSchema(&attr, path...) + if err != nil { + panic(err) + } + + return res } // Helper function for navigating through schema attributes, panics if path does not exist or invalid. -func navigateSchema(s schema.Attribute, path ...string) (schema.Attribute, error) { +func navigateSchema(s *schema.Attribute, path ...string) (schema.Attribute, error) { cs := s for i, p := range path { m := attributeToMap(cs) @@ -122,10 +320,31 @@ func navigateSchema(s schema.Attribute, path ...string) (schema.Attribute, error if !ok { return nil, fmt.Errorf("missing key %s", p) } + if i == len(path)-1 { return v, nil } - cs = v + cs = &v + } + return nil, fmt.Errorf("path %v is incomplete", path) +} + +// Helper function for navigating through schema attributes, panics if path does not exist or invalid. +func navigateSchemaWithCallback(s *schema.Attribute, cb func(schema.Attribute) schema.Attribute, path ...string) (schema.Attribute, error) { + cs := s + for i, p := range path { + m := attributeToMap(cs) + + v, ok := m[p] + if !ok { + return nil, fmt.Errorf("missing key %s", p) + } + + if i == len(path)-1 { + m[p] = cb(v) + return m[p], nil + } + cs = &v } return nil, fmt.Errorf("path %v is incomplete", path) } diff --git a/common/customizable_schema_plugin_framework_test.go b/common/customizable_schema_plugin_framework_test.go index 11327d152d..c764235b95 100644 --- a/common/customizable_schema_plugin_framework_test.go +++ b/common/customizable_schema_plugin_framework_test.go @@ -1,23 +1,67 @@ package common import ( + "context" + "fmt" "testing" "github.com/hashicorp/terraform-plugin-framework/provider/schema" + "github.com/hashicorp/terraform-plugin-framework/schema/validator" "github.com/stretchr/testify/assert" ) +type stringLengthBetweenValidator struct { + Max int + Min int +} + +// Description returns a plain text description of the validator's behavior, suitable for a practitioner to understand its impact. +func (v stringLengthBetweenValidator) Description(ctx context.Context) string { + return fmt.Sprintf("string length must be between %d and %d", v.Min, v.Max) +} + +// MarkdownDescription returns a markdown formatted description of the validator's behavior, suitable for a practitioner to understand its impact. +func (v stringLengthBetweenValidator) MarkdownDescription(ctx context.Context) string { + return fmt.Sprintf("string length must be between `%d` and `%d`", v.Min, v.Max) +} + +// Validate runs the main validation logic of the validator, reading configuration data out of `req` and updating `resp` with diagnostics. +func (v stringLengthBetweenValidator) ValidateString(ctx context.Context, req validator.StringRequest, resp *validator.StringResponse) { + // If the value is unknown or null, there is nothing to validate. + if req.ConfigValue.IsUnknown() || req.ConfigValue.IsNull() { + return + } + + strLen := len(req.ConfigValue.ValueString()) + + if strLen < v.Min || strLen > v.Max { + resp.Diagnostics.AddAttributeError( + req.Path, + "Invalid String Length", + fmt.Sprintf("String length must be between %d and %d, got: %d.", v.Min, v.Max, strLen), + ) + + return + } +} + func TestCustomizeSchema(t *testing.T) { scm := pluginFrameworkStructToSchema(DummyTfSdk{}, func(c CustomizableSchemaPluginFramework) CustomizableSchemaPluginFramework { c.AddNewField("new_field", schema.StringAttribute{Required: true}) - c.SchemaPath("nested").AddNewField("new_field", schema.StringAttribute{Required: true}) - c.SchemaPath("nested").AddNewField("to_be_removed", schema.StringAttribute{Required: true}) - c.SchemaPath("nested").RemoveField("to_be_removed") + c.AddNewField("new_field", schema.StringAttribute{Required: true}, "nested") + c.AddNewField("to_be_removed", schema.StringAttribute{Required: true}, "nested") + c.RemoveField("to_be_removed", "nested") + c.SetRequired("nested", "enabled") + c.SetSensitive("nested", "name") + c.AddValidator(stringLengthBetweenValidator{}, "description") return c }) assert.True(t, scm.Attributes["new_field"].IsRequired()) assert.True(t, MustSchemaAttributePath(scm.Attributes, "nested", "new_field").IsRequired()) + assert.True(t, MustSchemaAttributePath(scm.Attributes, "nested", "enabled").IsRequired()) + assert.True(t, MustSchemaAttributePath(scm.Attributes, "nested", "name").IsSensitive()) attr := MustSchemaAttributePath(scm.Attributes, "nested").(schema.SingleNestedAttribute).Attributes _, ok := attr["to_be_removed"] + assert.True(t, len(MustSchemaAttributePath(scm.Attributes, "description").(schema.StringAttribute).Validators) == 1) assert.True(t, !ok) }