Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Aliases function for ResourceProvider to handle recursive structures more gracefully #3338

Merged
merged 14 commits into from
Mar 7, 2024
4 changes: 2 additions & 2 deletions clusters/resource_cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ type ClusterSpec struct {
compute.ClusterSpec
}

func (ClusterSpec) Aliases() map[string]string {
return map[string]string{}
func (ClusterSpec) Aliases() map[string]map[string]string {
return map[string]map[string]string{}
}

func (ClusterSpec) CustomizeSchema(s map[string]*schema.Schema) map[string]*schema.Schema {
Expand Down
15 changes: 8 additions & 7 deletions common/recursion_tracking.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,16 @@ type recursionTrackingContext struct {
}

func (rt recursionTrackingContext) depthExceeded(typeField reflect.StructField) bool {
typeName := rt.getNameForTypeField(typeField)
typeName := getNameForType(typeField.Type)
if maxDepth, ok := rt.maxDepthForTypes[typeName]; ok {
println("in if!!")
edwardfeng-db marked this conversation as resolved.
Show resolved Hide resolved
return rt.timesVisited[typeName]+1 > maxDepth
}
return false
}

func (rt recursionTrackingContext) getNameForTypeField(typeField reflect.StructField) string {
return strings.TrimPrefix(typeField.Type.String(), "*")
}

func (rt recursionTrackingContext) getMaxDepthForTypeField(typeField reflect.StructField) int {
typeName := rt.getNameForTypeField(typeField)
typeName := getNameForType(typeField.Type)
return rt.maxDepthForTypes[typeName]
}

Expand All @@ -38,7 +35,7 @@ func (rt recursionTrackingContext) copy() recursionTrackingContext {
}

func (rt recursionTrackingContext) visit(v reflect.Value) {
rt.timesVisited[strings.TrimPrefix(v.Type().String(), "*")] += 1
rt.timesVisited[getNameForType(v.Type())] += 1
}

func getEmptyRecursionTrackingContext() recursionTrackingContext {
Expand All @@ -54,3 +51,7 @@ func getRecursionTrackingContext(rp RecursiveResourceProvider) recursionTracking
rp.MaxDepthForTypes(),
}
}

func getNameForType(t reflect.Type) string {
return strings.TrimPrefix(t.String(), "*")
}
88 changes: 39 additions & 49 deletions common/reflect_resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,18 @@ var kindMap = map[reflect.Kind]string{

// Generic interface for ResourceProvider. Using CustomizeSchema and Aliases functions to keep track of additional information
// on top of the generated go-sdk struct. This is used to replace manually maintained structs with `tf` tags.
//
// Aliases() returns a two dimensional map where the top level key is the name of the struct, the second level key is the name of the field,
// the values are the alias for the corresponding field under the specified struct.
// Example:
//
// {
// "compute.ClusterSpec": {
// "libraries": "library"
// }
// }
type ResourceProvider interface {
Aliases() map[string]string
Aliases() map[string]map[string]string
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add the doccomment here? What are the keys and values?

By the way, we may want to make Aliases optional for ResourceProvider. If a resource has no aliases, we shouldn't be forced to add an empty Aliases function, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added comments with an example.

For making Aliases optional, I played around with the idea of adding another interface without Aliases() but I think it's a bit more complicated since we also have RecursiveResourceProvider. I think I'm probably going to leave it as it is and try to resolve this later on. For now I think let's just merge this to unblock the next steps for the jobs migration

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed offline. We'll do this.

CustomizeSchema(map[string]*schema.Schema) map[string]*schema.Schema
}

Expand Down Expand Up @@ -79,21 +89,25 @@ func reflectKind(k reflect.Kind) string {
return n
}

func chooseFieldNameWithAliases(typeField reflect.StructField, aliases map[string]string) string {
func chooseFieldNameWithAliases(typeField reflect.StructField, parentType reflect.Type, aliases map[string]map[string]string) string {
parentTypeName := getNameForType(parentType)
fieldNameWithAliasTag := chooseFieldName(typeField)
// If nothing in the aliases map, return the field name from plain chooseFieldName method.
if len(aliases) == 0 {
return chooseFieldName(typeField)
return fieldNameWithAliasTag
}

jsonFieldName := getJsonFieldName(typeField)
if jsonFieldName == "-" {
return "-"
}

if value, ok := aliases[jsonFieldName]; ok {
return value
if parentMap, ok := aliases[parentTypeName]; ok {
if value, ok := parentMap[jsonFieldName]; ok {
return value
}
}
return jsonFieldName
return fieldNameWithAliasTag
}

func getJsonFieldName(typeField reflect.StructField) string {
Expand Down Expand Up @@ -144,7 +158,7 @@ func StructToSchema(v any, customize func(map[string]*schema.Schema) map[string]
return resourceProviderStructToSchema(rp)
}
rv := reflect.ValueOf(v)
scm := typeToSchema(rv, map[string]string{}, getEmptyRecursionTrackingContext())
scm := typeToSchema(rv, map[string]map[string]string{}, getEmptyRecursionTrackingContext())
if customize != nil {
scm = customize(scm)
}
Expand Down Expand Up @@ -295,7 +309,7 @@ func listAllFields(v reflect.Value) []field {
return fields
}

func typeToSchema(v reflect.Value, aliases map[string]string, rt recursionTrackingContext) map[string]*schema.Schema {
func typeToSchema(v reflect.Value, aliases map[string]map[string]string, rt recursionTrackingContext) map[string]*schema.Schema {
scm := map[string]*schema.Schema{}
rk := v.Kind()
if rk == reflect.Ptr {
Expand All @@ -312,13 +326,12 @@ func typeToSchema(v reflect.Value, aliases map[string]string, rt recursionTracki
typeField := field.sf
if rt.depthExceeded(typeField) {
// Skip the field if recursion depth is over the limit.
log.Printf("[TRACE] over recursion limit, skipping field: %s, max depth: %d", rt.getNameForTypeField(typeField), rt.getMaxDepthForTypeField(typeField))
log.Printf("[TRACE] over recursion limit, skipping field: %s, max depth: %d", getNameForType(typeField.Type), rt.getMaxDepthForTypeField(typeField))
continue
}
tfTag := typeField.Tag.Get("tf")

fieldName := chooseFieldNameWithAliases(typeField, aliases)
unwrappedAliases := unwrapAliasesMap(fieldName, aliases)
fieldName := chooseFieldNameWithAliases(typeField, v.Type(), aliases)
if fieldName == "-" {
continue
}
Expand Down Expand Up @@ -381,7 +394,7 @@ func typeToSchema(v reflect.Value, aliases map[string]string, rt recursionTracki
scm[fieldName].Type = schema.TypeList
elem := typeField.Type.Elem()
sv := reflect.New(elem).Elem()
nestedSchema := typeToSchema(sv, unwrappedAliases, rt)
nestedSchema := typeToSchema(sv, aliases, rt)
if strings.Contains(tfTag, "suppress_diff") {
scm[fieldName].DiffSuppressFunc = diffSuppressor(fieldName, scm[fieldName])
for k, v := range nestedSchema {
Expand All @@ -398,7 +411,7 @@ func typeToSchema(v reflect.Value, aliases map[string]string, rt recursionTracki
elem := typeField.Type // changed from ptr
sv := reflect.New(elem) // changed from ptr

nestedSchema := typeToSchema(sv, unwrappedAliases, rt)
nestedSchema := typeToSchema(sv, aliases, rt)
if strings.Contains(tfTag, "suppress_diff") {
scm[fieldName].DiffSuppressFunc = diffSuppressor(fieldName, scm[fieldName])
for k, v := range nestedSchema {
Expand Down Expand Up @@ -427,7 +440,7 @@ func typeToSchema(v reflect.Value, aliases map[string]string, rt recursionTracki
case reflect.Struct:
sv := reflect.New(elem).Elem()
scm[fieldName].Elem = &schema.Resource{
Schema: typeToSchema(sv, unwrappedAliases, rt),
Schema: typeToSchema(sv, aliases, rt),
}
}
default:
Expand All @@ -446,7 +459,7 @@ func IsRequestEmpty(v any) (bool, error) {
return false, fmt.Errorf("value of Struct is expected, but got %s: %#v", reflectKind(rv.Kind()), rv)
}
var isNotEmpty bool
err := iterFields(rv, []string{}, StructToSchema(v, nil), map[string]string{}, func(fieldSchema *schema.Schema, path []string, valueField *reflect.Value) error {
err := iterFields(rv, []string{}, StructToSchema(v, nil), map[string]map[string]string{}, func(fieldSchema *schema.Schema, path []string, valueField *reflect.Value) error {
if isNotEmpty {
return nil
}
Expand All @@ -472,29 +485,9 @@ func isGoSdk(v reflect.Value) bool {
return false
}

// Unwraps aliases map given a fieldname. Should be called everytime we recursively call iterFields.
//
// NOTE: If the target field has an alias, we expect `fieldname` argument to be the alias.
// For example
//
// fieldName = "cluster"
// aliases = {"cluster.clusterName": "name", "libraries": "library"}
// would return: {"clusterName": "name"}
func unwrapAliasesMap(fieldName string, aliases map[string]string) map[string]string {
result := make(map[string]string)
prefix := fieldName + "."
for key, value := range aliases {
// Only keep the keys that have the prefix.
if strings.HasPrefix(key, prefix) && key != prefix {
result[key] = value
}
}
return result
}

// Iterate through each field of the given reflect.Value object and execute a callback function with the corresponding
// terraform schema object as the input.
func iterFields(rv reflect.Value, path []string, s map[string]*schema.Schema, aliases map[string]string,
func iterFields(rv reflect.Value, path []string, s map[string]*schema.Schema, aliases map[string]map[string]string,
cb func(fieldSchema *schema.Schema, path []string, valueField *reflect.Value) error) error {
rk := rv.Kind()
if rk != reflect.Struct {
Expand All @@ -507,7 +500,7 @@ func iterFields(rv reflect.Value, path []string, s map[string]*schema.Schema, al
fields := listAllFields(rv)
for _, field := range fields {
typeField := field.sf
fieldName := chooseFieldNameWithAliases(typeField, aliases)
fieldName := chooseFieldNameWithAliases(typeField, rv.Type(), aliases)
if fieldName == "-" {
continue
}
Expand All @@ -533,7 +526,7 @@ func iterFields(rv reflect.Value, path []string, s map[string]*schema.Schema, al
return nil
}

func collectionToMaps(v any, s *schema.Schema, aliases map[string]string) ([]any, error) {
func collectionToMaps(v any, s *schema.Schema, aliases map[string]map[string]string) ([]any, error) {
resultList := []any{}
if sl, ok := v.([]string); ok {
// most likely list of parameters to job task
Expand Down Expand Up @@ -568,12 +561,11 @@ func collectionToMaps(v any, s *schema.Schema, aliases map[string]string) ([]any
err := iterFields(v, []string{}, r.Schema, aliases, func(fieldSchema *schema.Schema,
path []string, valueField *reflect.Value) error {
fieldName := path[len(path)-1]
newAliases := unwrapAliasesMap(fieldName, aliases)
fieldValue := valueField.Interface()
fieldPath := strings.Join(path, ".")
switch fieldSchema.Type {
case schema.TypeList, schema.TypeSet:
nv, err := collectionToMaps(fieldValue, fieldSchema, newAliases)
nv, err := collectionToMaps(fieldValue, fieldSchema, aliases)
if err != nil {
return fmt.Errorf("%s: %v", path, err)
}
Expand Down Expand Up @@ -704,25 +696,23 @@ func DataToStructPointer(d *schema.ResourceData, scm map[string]*schema.Schema,
// DataToReflectValue reads reflect value from data
func DataToReflectValue(d *schema.ResourceData, s map[string]*schema.Schema, rv reflect.Value) error {
// TODO: Pass in the right aliases map.
return readReflectValueFromData([]string{}, d, rv, s, map[string]string{})
return readReflectValueFromData([]string{}, d, rv, s, map[string]map[string]string{})
}

// Get the aliases map from the given struct if it is an instance of ResourceProvider.
// NOTE: This does not return aliases defined on `tf` tags.
func getAliasesMapFromStruct(s any) map[string]string {
func getAliasesMapFromStruct(s any) map[string]map[string]string {
if v, ok := s.(ResourceProvider); ok {
return v.Aliases()
}
return map[string]string{}
return map[string]map[string]string{}
}

func readReflectValueFromData(path []string, d attributeGetter,
rv reflect.Value, s map[string]*schema.Schema, aliases map[string]string) error {
rv reflect.Value, s map[string]*schema.Schema, aliases map[string]map[string]string) error {
return iterFields(rv, path, s, aliases, func(fieldSchema *schema.Schema,
path []string, valueField *reflect.Value) error {
fieldPath := strings.Join(path, ".")
fieldName := path[len(path)-1]
newAliases := unwrapAliasesMap(fieldName, aliases)
raw, ok := d.GetOk(fieldPath)
if !ok {
return nil
Expand Down Expand Up @@ -759,13 +749,13 @@ func readReflectValueFromData(path []string, d attributeGetter,
rawSet := raw.(*schema.Set)
rawList := rawSet.List()
return readListFromData(path, d, rawList, valueField,
fieldSchema, newAliases, func(i int) string {
fieldSchema, aliases, func(i int) string {
return strconv.Itoa(rawSet.F(rawList[i]))
})
case schema.TypeList:
// here we rely on Terraform SDK to perform validation, so we don't to it twice
rawList := raw.([]any)
return readListFromData(path, d, rawList, valueField, fieldSchema, newAliases, strconv.Itoa)
return readListFromData(path, d, rawList, valueField, fieldSchema, aliases, strconv.Itoa)
default:
return fmt.Errorf("%s[%v] unsupported field type", fieldPath, raw)
}
Expand Down Expand Up @@ -818,7 +808,7 @@ func primitiveReflectValueFromInterface(rk reflect.Kind,
}

func readListFromData(path []string, d attributeGetter,
rawList []any, valueField *reflect.Value, fieldSchema *schema.Schema, aliases map[string]string,
rawList []any, valueField *reflect.Value, fieldSchema *schema.Schema, aliases map[string]map[string]string,
offsetConverter func(i int) string) error {
if len(rawList) == 0 {
return nil
Expand Down
15 changes: 9 additions & 6 deletions common/reflect_resource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,12 @@ func TestChooseFieldName(t *testing.T) {
}

func TestChooseFieldNameWithAliasesMap(t *testing.T) {
type Bar struct {
Foo string `json:"foo,omitempty"`
}
assert.Equal(t, "foo", chooseFieldNameWithAliases(reflect.StructField{
Tag: `json:"bar"`,
}, map[string]string{"bar": "foo"}))
}, reflect.ValueOf(Bar{}).Type(), map[string]map[string]string{"common.Bar": {"bar": "foo"}}))
}

type testSliceItem struct {
Expand Down Expand Up @@ -93,8 +96,8 @@ type testForEachTask struct {
Extra string `json:"extra,omitempty"`
}

func (testRecursiveStruct) Aliases() map[string]string {
return map[string]string{}
func (testRecursiveStruct) Aliases() map[string]map[string]string {
return map[string]map[string]string{}
}

func (testRecursiveStruct) CustomizeSchema(s map[string]*schema.Schema) map[string]*schema.Schema {
Expand Down Expand Up @@ -269,9 +272,9 @@ type DummyResourceProvider struct {
DummyNoTfTag
}

func (DummyResourceProvider) Aliases() map[string]string {
return map[string]string{"enabled": "enabled_alias",
"addresses.primary": "primary_alias"}
func (DummyResourceProvider) Aliases() map[string]map[string]string {
return map[string]map[string]string{"common.DummyResourceProvider": {"enabled": "enabled_alias"},
"common.AddressNoTfTag": {"primary": "primary_alias"}}
}

func (DummyResourceProvider) CustomizeSchema(s map[string]*schema.Schema) map[string]*schema.Schema {
Expand Down
Loading