Skip to content

Commit

Permalink
feat: Add default trait support (#803)
Browse files Browse the repository at this point in the history
* Add default value handling to SwiftSymbolProvider.

* Add default value handling for generating expected body when generating expected HTTP response in protocol tests.

* Add more comprehensive default value handling for deserialization, also add error correction for server failing to send a required value in response by filling with zero or zero-equivalents.

* Modify ShapeExt::defaultValue logic to handle an optional value having a default value.

* Handle edgecase where floating point value has integer default value given in trait.

* Provide zero-equivalent error correction default value for enums by using .sdkUnknown case.

* Handle JSON number value equality with loosened restriction (JSON handles 2 and 2.0 the same, as numbers)

* Add enum trait handling

* Fix enum case name codegen for default value

* Fix int enum value handling & address future Swift 6 error warning by using Foundation.Data() to convert string to data. Use symbol's property bag to set flag for importing Data outside of SwiftSymbolProvider.

* Refactor blob and document shape type default value codegen to reduce duplication.

* Use flag for importing Foundation.Data set by SwiftSymbolProvider (when handling blob shape) and add import.

* Ktlint

* Refactor to use when{} as expression; return no default value instead of throwing error in MemberShapeDecodeGenerator for null node case.

* ktlint

* Address PR comments

* Address PR comments

* Address PR comment

* Use closure to handle dependency import for the default value resolved by SwiftSymbolProvider at a later time when the resolved symbol gets used by SwiftWriter. Also, fix timestamp value handling and add logic for date-time case as well.

* ktlint

---------

Co-authored-by: Sichan Yoo <chanyoo@amazon.com>
  • Loading branch information
sichanyoo and Sichan Yoo authored Sep 6, 2024
1 parent 6a3e98f commit e87dd41
Show file tree
Hide file tree
Showing 8 changed files with 351 additions and 34 deletions.
2 changes: 2 additions & 0 deletions Sources/SmithyTestUtil/JSONComparator.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ fileprivate func anyValuesAreEqual(_ lhs: Any?, _ rhs: Any?) -> Bool {
return anyDictsAreEqual(lhsDict, rhsDict)
} else if let lhsArray = lhs as? [Any], let rhsArray = rhs as? [Any] {
return anyArraysAreEqual(lhsArray, rhsArray)
} else if let lhn = lhs as? NSNumber, let rhn = rhs as? NSNumber {
return lhn == rhn
} else {
return type(of: lhs) == type(of: rhs) && "\(lhs)" == "\(rhs)"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.codegen.core.SymbolReference
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.NullableIndex
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.BigDecimalShape
import software.amazon.smithy.model.shapes.BigIntegerShape
import software.amazon.smithy.model.shapes.BlobShape
Expand All @@ -23,6 +24,7 @@ import software.amazon.smithy.model.shapes.DocumentShape
import software.amazon.smithy.model.shapes.DoubleShape
import software.amazon.smithy.model.shapes.EnumShape
import software.amazon.smithy.model.shapes.FloatShape
import software.amazon.smithy.model.shapes.IntEnumShape
import software.amazon.smithy.model.shapes.IntegerShape
import software.amazon.smithy.model.shapes.ListShape
import software.amazon.smithy.model.shapes.LongShape
Expand All @@ -39,8 +41,11 @@ import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.TimestampShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.ClientOptionalTrait
import software.amazon.smithy.model.traits.DefaultTrait
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.traits.InputTrait
import software.amazon.smithy.model.traits.SparseTrait
import software.amazon.smithy.model.traits.StreamingTrait
import software.amazon.smithy.swift.codegen.customtraits.NestedTrait
Expand All @@ -49,8 +54,15 @@ import software.amazon.smithy.swift.codegen.model.SymbolProperty
import software.amazon.smithy.swift.codegen.model.boxed
import software.amazon.smithy.swift.codegen.model.buildSymbol
import software.amazon.smithy.swift.codegen.model.defaultName
import software.amazon.smithy.swift.codegen.model.defaultValue
import software.amazon.smithy.swift.codegen.model.defaultValueClosure
import software.amazon.smithy.swift.codegen.model.getTrait
import software.amazon.smithy.swift.codegen.model.hasTrait
import software.amazon.smithy.swift.codegen.model.nestedNamespaceType
import software.amazon.smithy.swift.codegen.swiftmodules.FoundationTypes
import software.amazon.smithy.swift.codegen.swiftmodules.SmithyReadWriteTypes
import software.amazon.smithy.swift.codegen.swiftmodules.SmithyTimestampsTypes
import software.amazon.smithy.swift.codegen.swiftmodules.SmithyTypes
import software.amazon.smithy.swift.codegen.swiftmodules.SwiftTypes
import software.amazon.smithy.swift.codegen.utils.ModelFileUtils
import software.amazon.smithy.swift.codegen.utils.clientName
Expand Down Expand Up @@ -104,21 +116,21 @@ class SwiftSymbolProvider(private val model: Model, val swiftSettings: SwiftSett
return escaper.escapeMemberName(shape.memberName.toLowerCamelCase())
}

override fun integerShape(shape: IntegerShape): Symbol = numberShape(shape, "Int", "0")
override fun integerShape(shape: IntegerShape): Symbol = numberShape(shape, "Int")

override fun floatShape(shape: FloatShape): Symbol = numberShape(shape, "Float", "0.0")
override fun floatShape(shape: FloatShape): Symbol = numberShape(shape, "Float")

override fun longShape(shape: LongShape): Symbol = numberShape(shape, "Int", "0")
override fun longShape(shape: LongShape): Symbol = numberShape(shape, "Int")

override fun doubleShape(shape: DoubleShape): Symbol = numberShape(shape, "Double", "0.0")
override fun doubleShape(shape: DoubleShape): Symbol = numberShape(shape, "Double")

override fun byteShape(shape: ByteShape): Symbol = numberShape(shape, "Int8", "0")
override fun byteShape(shape: ByteShape): Symbol = numberShape(shape, "Int8")

override fun shortShape(shape: ShortShape): Symbol = numberShape(shape, "Int16", "0")
override fun shortShape(shape: ShortShape): Symbol = numberShape(shape, "Int16")

override fun bigIntegerShape(shape: BigIntegerShape): Symbol = numberShape(shape, "Int", defaultValue = "0")
override fun bigIntegerShape(shape: BigIntegerShape): Symbol = numberShape(shape, "Int")

override fun bigDecimalShape(shape: BigDecimalShape): Symbol = numberShape(shape, "Double", "0.0")
override fun bigDecimalShape(shape: BigDecimalShape): Symbol = numberShape(shape, "Double")

override fun stringShape(shape: StringShape): Symbol {
val enumTrait = shape.getTrait(EnumTrait::class.java)
Expand Down Expand Up @@ -149,7 +161,7 @@ class SwiftSymbolProvider(private val model: Model, val swiftSettings: SwiftSett
}

override fun booleanShape(shape: BooleanShape): Symbol {
return createSymbolBuilder(shape, "Bool", namespace = "Swift", SwiftDeclaration.STRUCT).putProperty(SymbolProperty.DEFAULT_VALUE_KEY, "false").build()
return createSymbolBuilder(shape, "Bool", namespace = "Swift", SwiftDeclaration.STRUCT).build()
}

override fun structureShape(shape: StructureShape): Symbol {
Expand Down Expand Up @@ -205,7 +217,7 @@ class SwiftSymbolProvider(private val model: Model, val swiftSettings: SwiftSett
.putProperty(SymbolProperty.NESTED_SYMBOL, symbol)
.build()
}
return symbol
return handleDefaultValue(shape, symbol.toBuilder()).build()
}

override fun timestampShape(shape: TimestampShape): Symbol {
Expand Down Expand Up @@ -243,25 +255,30 @@ class SwiftSymbolProvider(private val model: Model, val swiftSettings: SwiftSett
.build()
}

private fun numberShape(shape: Shape?, typeName: String, defaultValue: String = "0"): Symbol {
if (shape != null && shape.isIntEnumShape()) {
private fun numberShape(shape: Shape, typeName: String): Symbol {
if (shape.isIntEnumShape()) {
return createEnumSymbol(shape)
}
return createSymbolBuilder(shape, typeName, "Swift", SwiftDeclaration.STRUCT).putProperty(SymbolProperty.DEFAULT_VALUE_KEY, defaultValue).build()
return createSymbolBuilder(shape, typeName, "Swift", SwiftDeclaration.STRUCT).build()
}

/**
* Creates a symbol builder for the shape with the given type name in the root namespace.
*/
private fun createSymbolBuilder(shape: Shape?, typeName: String, declaration: SwiftDeclaration, boxed: Boolean = false): Symbol.Builder {
private fun createSymbolBuilder(
shape: Shape,
typeName: String,
declaration: SwiftDeclaration,
boxed: Boolean = false
): Symbol.Builder {
val builder = Symbol.builder()
.putProperty("shape", shape)
.putProperty("decl", declaration.keyword)
.name(typeName)
if (boxed) {
builder.boxed()
}
return builder
return handleDefaultValue(shape, builder)
}

/**
Expand All @@ -270,7 +287,7 @@ class SwiftSymbolProvider(private val model: Model, val swiftSettings: SwiftSett
* the namespace (and ultimately the package name) to `foo.bar` for the symbol.
*/
private fun createSymbolBuilder(
shape: Shape?,
shape: Shape,
typeName: String,
namespace: String,
declaration: SwiftDeclaration,
Expand All @@ -284,6 +301,117 @@ class SwiftSymbolProvider(private val model: Model, val swiftSettings: SwiftSett
return ModelFileUtils.filename(swiftSettings, name)
}

/**
* Resolve default value for a given shape and save it as a property in symbol builder if needed.
*
* The default trait can be applied to list shape, map shape, and all simple types as per Smithy spec.
* Both the member shape and the target shape may have the default trait.
*
* There exist default value restrictions for the following shapes:
* - enum: can be set to any valid string value of the enum.
* - intEnum: can be set to any valid integer value of the enum.
* - document: can be set to null, true, false, string, numbers, an empty list, or an empty map.
* - list: can only be set to an empty list.
* - map: can only be set to an empty map.
*/
private fun handleDefaultValue(shape: Shape, builder: Symbol.Builder): Symbol.Builder {
// Skip if the current shape is a member shape with @clientOptional trait
if (shape.hasTrait<ClientOptionalTrait>()) return builder
// Skip if the current shape doesn't have default trait. Otherwise, get the default value as literal string
val defaultValueLiteral = shape.getTrait<DefaultTrait>()?.toNode()?.toString() ?: return builder
// If default value is "null", it is explicit notation for no default value. Return unmodified builder.
if (defaultValueLiteral == "null") return builder

// The current shape may be a member shape or a root level shape.
val targetShape = when (shape) {
is MemberShape -> {
// If containing shape is an input shape, return unmodified builder.
if (model.expectShape(shape.container).hasTrait<InputTrait>()) return builder
model.expectShape(shape.target)
}
else -> shape
}
val node = shape.getTrait<DefaultTrait>()!!.toNode()

return when (targetShape) {
is ListShape -> builder.defaultValue("[]")
is EnumShape -> builder.defaultValue(".${swiftEnumCaseName(null, defaultValueLiteral)}")
is IntEnumShape -> {
// Get the corresponding enum member name (enum case name) for the int value from default trait
val enumMemberName = targetShape.enumValues.entries.firstOrNull {
it.value == defaultValueLiteral.toInt()
}!!.key
builder.defaultValue(".${swiftEnumCaseName(enumMemberName, defaultValueLiteral)}")
}
is StringShape -> builder.defaultValue("\"$defaultValueLiteral\"")
is MapShape -> builder.defaultValue("[:]")
is BlobShape -> handleBlobDefaultValue(defaultValueLiteral, targetShape, builder)
is DocumentShape -> {
handleDocumentDefaultValue(defaultValueLiteral, node, builder)
}
is TimestampShape -> handleTimestampDefaultValue(defaultValueLiteral, node, builder)
is FloatShape, is DoubleShape -> {
val decimal = ".0".takeIf { !defaultValueLiteral.contains(".") } ?: ""
builder.defaultValue(defaultValueLiteral + decimal)
}
// For: boolean, byte, short, integer, long, bigInteger, bigDecimal,
// just take the literal string value from the trait.
else -> builder.defaultValue(defaultValueLiteral)
}
}

// Document: default value can be set to null, true, false, string, numbers, an empty list, or an empty map.
private fun handleDocumentDefaultValue(literal: String, node: Node, builder: Symbol.Builder): Symbol.Builder {
var formatString = when {
node.isObjectNode -> "\$N.object([:])"
node.isArrayNode -> "\$N.array([])"
node.isBooleanNode -> "\$N.boolean($literal)"
node.isStringNode -> "\$N.string(\"$literal\")"
node.isNumberNode -> "\$N.number($literal)"
else -> return builder // no-op
}
return builder.defaultValueClosure { writer ->
writer.format(formatString, SmithyReadWriteTypes.Document)
}
}

private fun handleBlobDefaultValue(literal: String, shape: Shape, builder: Symbol.Builder): Symbol.Builder {
return builder.defaultValueClosure(
if (shape.hasTrait<StreamingTrait>()) {
{ writer ->
writer.format(
"\$N.data(\$N(\"$literal\".utf8))",
SmithyTypes.ByteStream,
FoundationTypes.Data
)
}
} else {
{ writer ->
writer.format("\$N(\"$literal\".utf8)", FoundationTypes.Data)
}
}
)
}

private fun handleTimestampDefaultValue(literal: String, node: Node, builder: Symbol.Builder): Symbol.Builder {
// Smithy validates that default value given to timestamp shape must either be a
// number (for epoch-seconds) or a date-time string compliant with RFC3339.
return builder.defaultValueClosure(
if (node.isNumberNode) {
{ writer ->
writer.format("\$N(timeIntervalSince1970: $literal)", FoundationTypes.Date)
}
} else {
{ writer ->
writer.format(
"\$N(format: .dateTime).date(from: \"$literal\")",
SmithyTimestampsTypes.TimestampFormatter
)
}
}
)
}

/**
* Add all the [members] as references needed to declare the given symbol being built.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ import software.amazon.smithy.swift.cod.DocumentationConverter
import software.amazon.smithy.swift.codegen.integration.SectionId
import software.amazon.smithy.swift.codegen.integration.SectionWriter
import software.amazon.smithy.swift.codegen.integration.SwiftIntegration
import software.amazon.smithy.swift.codegen.model.SymbolProperty
import software.amazon.smithy.swift.codegen.model.defaultValue
import software.amazon.smithy.swift.codegen.model.defaultValueFromClosure
import software.amazon.smithy.swift.codegen.model.isBoxed
import software.amazon.smithy.swift.codegen.model.isBuiltIn
import software.amazon.smithy.swift.codegen.model.isGeneric
Expand Down Expand Up @@ -190,7 +192,7 @@ class SwiftWriter(
}

if (shouldSetDefault) {
type.defaultValue()?.let {
getDefaultValue(type)?.let {
formatted += " = $it"
}
}
Expand All @@ -200,6 +202,14 @@ class SwiftWriter(
else -> throw CodegenException("Invalid type provided for \$T. Expected a Symbol, but found `$type`")
}
}

private fun getDefaultValue(symbol: Symbol): String? {
return if (symbol.properties.containsKey(SymbolProperty.DEFAULT_VALUE_CLOSURE_KEY)) {
symbol.defaultValueFromClosure(writer)
} else {
symbol.defaultValue()
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ fun String.removeSurroundingBackticks() = removeSurrounding("`", "`")
*/
fun swiftEnumCaseName(name: String?, value: String, shouldBeEscaped: Boolean = true): String {
val resolvedName = name ?: value
var enumCaseName = CaseUtils.toCamelCase(resolvedName.replace(Regex("[^a-zA-Z0-9_ ]"), ""))
var enumCaseName = CaseUtils.toCamelCase(resolvedName.replace(Regex("[^a-zA-Z0-9_ -]"), ""))
if (!SwiftSymbolProvider.isValidSwiftIdentifier(enumCaseName)) {
enumCaseName = "_$enumCaseName"
}
Expand Down
Loading

0 comments on commit e87dd41

Please sign in to comment.