Skip to content

Commit

Permalink
feat: Adds IntEnum support (#484)
Browse files Browse the repository at this point in the history
* feat: Adds IntEnum support

- Generate IntEnum shapes
- Moves shared enum casing logic to Utils so it can be easily shared
- Updates test generator to sort members by lower camel case

* Adds IntEnumGenerator

* ktlintformat

* Addresses swiftlint violations

* Adds proper import and removes wildcard import
  • Loading branch information
epau authored Nov 28, 2022
1 parent 749c379 commit 35d41e8
Show file tree
Hide file tree
Showing 12 changed files with 274 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public struct OrderedGroup<Input, Output, Context: MiddlewareContext> {
order.insert(relativeTo: relativeTo, position: position, ids: middleware.id)
}

func get(id: String)-> AnyMiddleware<Input, Output, Context>? {
func get(id: String) -> AnyMiddleware<Input, Output, Context>? {
return _items[id]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// SPDX-License-Identifier: Apache-2.0
//

#if os(iOS) || os (watchOS) || os(macOS) || os(tvOS)
#if os(iOS) || os(watchOS) || os(macOS) || os(tvOS)
import Foundation.NSProcessInfo

public struct PlatformOperationSystemVersion {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.ServiceIndex
import software.amazon.smithy.model.neighbor.Walker
import software.amazon.smithy.model.shapes.IntegerShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShapeVisitor
Expand Down Expand Up @@ -167,6 +168,13 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Void>() {
return null
}

override fun integerShape(shape: IntegerShape): Void? {
if (shape.isIntEnumShape()) {
writers.useShapeWriter(shape) { writer: SwiftWriter -> IntEnumGenerator(model, symbolProvider, writer, shape.asIntEnumShape().get(), settings).render() }
}
return null
}

override fun unionShape(shape: UnionShape): Void? {
writers.useShapeWriter(shape) { writer: SwiftWriter -> UnionGenerator(model, symbolProvider, writer, shape, settings).render() }
return null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@ import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.swift.codegen.customtraits.NestedTrait
import software.amazon.smithy.swift.codegen.lang.reservedWords
import software.amazon.smithy.swift.codegen.model.expectShape
import software.amazon.smithy.swift.codegen.model.hasTrait
import software.amazon.smithy.swift.codegen.model.nestedNamespaceType
import software.amazon.smithy.utils.CaseUtils

/**
* Generates an appropriate Swift type for a Smithy enum string.
Expand Down Expand Up @@ -227,19 +225,6 @@ class EnumGenerator(
* them to camelCase after removing chars except alphanumeric, space and underscore.
*/
fun EnumDefinition.swiftEnumCaseName(shouldBeEscaped: Boolean = true): String {
var enumCaseName = CaseUtils.toCamelCase(
name.orElseGet {
value
}.replace(Regex("[^a-zA-Z0-9_ ]"), "")
)
if (!SymbolVisitor.isValidSwiftIdentifier(enumCaseName)) {
enumCaseName = "_$enumCaseName"
}

if (shouldBeEscaped && reservedWords.contains(enumCaseName)) {
enumCaseName = SymbolVisitor.escapeReservedWords(enumCaseName)
}

return enumCaseName
return swiftEnumCaseName(name, value, shouldBeEscaped)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package software.amazon.smithy.swift.codegen

import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.IntEnumShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.traits.EnumValueTrait
import software.amazon.smithy.swift.codegen.customtraits.NestedTrait
import software.amazon.smithy.swift.codegen.model.expectShape
import software.amazon.smithy.swift.codegen.model.hasTrait
import software.amazon.smithy.swift.codegen.model.nestedNamespaceType
import java.util.Optional

class IntEnumGenerator(
private val model: Model,
private val symbolProvider: SymbolProvider,
private val writer: SwiftWriter,
private val shape: IntEnumShape,
private val settings: SwiftSettings
) {
private var allCasesBuilder: MutableList<String> = mutableListOf()
private var rawValuesBuilder: MutableList<String> = mutableListOf()

fun render() {
val symbol = symbolProvider.toSymbol(shape)
writer.putContext("enum.name", symbol.name)
val isNestedType = shape.hasTrait<NestedTrait>()
if (isNestedType) {
val service = model.expectShape<ServiceShape>(settings.service)
writer.openBlock("extension ${service.nestedNamespaceType(symbolProvider)} {", "}") {
renderEnum()
}
} else {
renderEnum()
}
writer.removeContext("enum.name")
}

private fun renderEnum() {
writer.writeShapeDocs(shape)
writer.writeAvailableAttribute(null, shape)
writer.openBlock("public enum \$enum.name:L: \$N, \$N, \$N, \$N, \$N {", "}", SwiftTypes.Protocols.Equatable, SwiftTypes.Protocols.RawRepresentable, SwiftTypes.Protocols.CaseIterable, SwiftTypes.Protocols.Codable, SwiftTypes.Protocols.Hashable) {
createEnumWriterContexts()
// add the sdkUnknown case which will always be last
writer.write("case sdkUnknown(\$N)", SwiftTypes.Int)

writer.write("")

// Generate allCases static array
generateAllCasesBlock()

// Generate initializer from rawValue
generateInitFromRawValueBlock()

// Generate rawValue internal enum
generateRawValueEnumBlock()
}
}

fun addEnumCaseToEnum(caseShape: MemberShape) {
writer.writeMemberDocs(model, caseShape)
writer.write("case ${caseShape.swiftEnumCaseName()}")
}

fun addEnumCaseToAllCases(caseShape: MemberShape) {
allCasesBuilder.add(".${caseShape.swiftEnumCaseName(false)}")
}

fun addEnumCaseToRawValuesEnum(caseShape: MemberShape) {
rawValuesBuilder.add("case .${caseShape.swiftEnumCaseName(false)}: return ${caseShape.swiftEnumCaseValue()}")
}

fun createEnumWriterContexts() {
shape
.getCaseMembers()
.sortedBy { it.memberName }
.forEach {
// Add all given enum cases to generated enum definition
addEnumCaseToEnum(it)
addEnumCaseToAllCases(it)
addEnumCaseToRawValuesEnum(it)
}
}

fun generateAllCasesBlock() {
allCasesBuilder.add(".sdkUnknown(0)")
writer.openBlock("public static var allCases: [\$enum.name:L] {", "}") {
writer.openBlock("return [", "]") {
writer.write(allCasesBuilder.joinToString(",\n"))
}
}
}

fun generateInitFromRawValueBlock() {
writer.openBlock("public init(rawValue: \$N) {", "}", SwiftTypes.Int) {
writer.write("let value = Self.allCases.first(where: { \$\$0.rawValue == rawValue })")
writer.write("self = value ?? Self.sdkUnknown(rawValue)")
}
}

fun generateRawValueEnumBlock() {
rawValuesBuilder.add("case let .sdkUnknown(s): return s")
writer.openBlock("public var rawValue: \$N {", "}", SwiftTypes.Int) {
writer.write("switch self {")
writer.write(rawValuesBuilder.joinToString("\n"))
writer.write("}")
}
}

fun IntEnumShape.getCaseMembers(): List<MemberShape> {
return members().filter {
it.hasTrait<EnumValueTrait>()
}
}

fun MemberShape.swiftEnumCaseName(shouldBeEscaped: Boolean = true): String {
return swiftEnumCaseName(
Optional.of(memberName),
"${swiftEnumCaseValue()}",
shouldBeEscaped
)
}

fun MemberShape.swiftEnumCaseValue(): Int {
return expectTrait<EnumValueTrait>(EnumValueTrait::class.java).expectIntValue()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import software.amazon.smithy.swift.codegen.customtraits.SwiftBoxTrait
import software.amazon.smithy.swift.codegen.model.hasTrait
import software.amazon.smithy.swift.codegen.model.recursiveSymbol
import software.amazon.smithy.swift.codegen.model.toMemberNames
import software.amazon.smithy.swift.codegen.utils.toLowerCamelCase
import software.amazon.smithy.utils.StringUtils.lowerCase

/**
Expand Down Expand Up @@ -228,7 +229,7 @@ class ShapeValueGenerator(
// this is important because when a struct is generated in swift it is generated with its members sorted by name.
// when you instantiate that struct you have to call params in order with their param names. if you don't it won't compile
// so we sort here before we write any of the members with their values
val sortedMembers = node.members.toSortedMap(compareBy<StringNode> { it.value.lowercase() })
val sortedMembers = node.members.toSortedMap(compareBy<StringNode> { it.value.toLowerCamelCase() })
sortedMembers.forEach { (keyNode, valueNode) ->
val memberShape: Shape
when (currShape) {
Expand Down Expand Up @@ -324,6 +325,14 @@ class ShapeValueGenerator(
writer.writeInline("Date(timeIntervalSince1970: \$L)", node.value)
}

ShapeType.INT_ENUM -> {
val enumSymbol = generator.symbolProvider.toSymbol(currShape)
writer.writeInline(
"\$L(rawValue: \$L)",
enumSymbol, node.value
)
}

ShapeType.BYTE, ShapeType.SHORT, ShapeType.INTEGER,
ShapeType.LONG, ShapeType.DOUBLE, ShapeType.FLOAT -> writer.writeInline("\$L", node.value)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,9 @@ class SymbolVisitor(private val model: Model, swiftSettings: SwiftSettings) :
}

private fun numberShape(shape: Shape?, typeName: String, defaultValue: String = "0"): Symbol {
if (shape != null && shape.isIntEnumShape()) {
return createEnumSymbol(shape)
}
return createSymbolBuilder(shape, typeName, "Swift").putProperty(SymbolProperty.DEFAULT_VALUE_KEY, defaultValue).build()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,32 @@
*/

package software.amazon.smithy.swift.codegen
import software.amazon.smithy.swift.codegen.lang.reservedWords
import software.amazon.smithy.utils.CaseUtils
import java.util.Optional

fun <T> Optional<T>.getOrNull(): T? = if (isPresent) get() else null

fun String.removeSurroundingBackticks() = removeSurrounding("`", "`")

/**
* Creates an idiomatic name for swift enum cases given an optional name and value.
* Uses either name or value attributes of EnumDefinition in that order and formats
* them to camelCase after removing chars except alphanumeric, space and underscore.
*/
fun swiftEnumCaseName(name: Optional<String>, value: String, shouldBeEscaped: Boolean = true): String {
var enumCaseName = CaseUtils.toCamelCase(
name
.orElseGet { value }
.replace(Regex("[^a-zA-Z0-9_ ]"), "")
)
if (!SymbolVisitor.isValidSwiftIdentifier(enumCaseName)) {
enumCaseName = "_$enumCaseName"
}

if (shouldBeEscaped && reservedWords.contains(enumCaseName)) {
enumCaseName = SymbolVisitor.escapeReservedWords(enumCaseName)
}

return enumCaseName
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import software.amazon.smithy.model.neighbor.RelationshipType
import software.amazon.smithy.model.neighbor.Walker
import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.CollectionShape
import software.amazon.smithy.model.shapes.IntEnumShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
Expand Down Expand Up @@ -103,6 +104,7 @@ fun formatHeaderOrQueryValue(
}
Pair(formattedItemValue, requiresDoCatch)
}
is IntEnumShape -> Pair("$memberName.rawValue", false)
else -> Pair(memberName, false)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,16 @@ class HttpResponseHeaders(
writer.indent()
when (memberTarget) {
is NumberShape -> {
val memberValue = stringToNumber(memberTarget, headerDeclaration, true)
writer.write("self.\$L = $memberValue", memberName)
if (memberTarget.isIntEnumShape) {
val enumSymbol = ctx.symbolProvider.toSymbol(memberTarget)
writer.write(
"self.\$L = \$L(rawValue: \$L(\$L) ?? 0)",
memberName, enumSymbol, SwiftTypes.Int, headerDeclaration
)
} else {
val memberValue = stringToNumber(memberTarget, headerDeclaration, true)
writer.write("self.\$L = \$L", memberName, memberValue)
}
}
is BlobShape -> {
val memberValue = "$headerDeclaration.data(using: .utf8)"
Expand Down Expand Up @@ -94,7 +102,14 @@ class HttpResponseHeaders(
invalidHeaderListErrorName = "invalidBooleanHeaderList"
"${SwiftTypes.Bool}(\$0)"
}
is NumberShape -> "${stringToNumber(collectionMemberTarget, "\$0", false)}"
is NumberShape -> {
if (collectionMemberTarget.isIntEnumShape) {
val enumSymbol = ctx.symbolProvider.toSymbol(collectionMemberTarget)
"${SwiftTypes.Int}(\$0).map({ intValue in $enumSymbol(rawValue: intValue) })"
} else {
"${stringToNumber(collectionMemberTarget, "\$0", false)}"
}
}
is TimestampShape -> {
val bindingIndex = HttpBindingIndex.of(ctx.model)
val tsFormat = bindingIndex.determineTimestampFormat(
Expand Down Expand Up @@ -133,7 +148,7 @@ class HttpResponseHeaders(
// render map function
val collectionMemberTargetShape = ctx.model.expectShape(memberTarget.member.target)
val collectionMemberTargetSymbol = ctx.symbolProvider.toSymbol(collectionMemberTargetShape)
if (!collectionMemberTargetSymbol.isBoxed()) {
if (!collectionMemberTargetSymbol.isBoxed() || collectionMemberTargetShape.isIntEnumShape()) {
writer.openBlock("self.\$L = try \$LHeaderValues.map {", "}", memberName, memberName) {
val transformedHeaderDeclaration = "${memberName}Transformed"
writer.openBlock("guard let \$L = \$L else {", "}", transformedHeaderDeclaration, conversion) {
Expand Down
52 changes: 52 additions & 0 deletions smithy-swift-codegen/src/test/kotlin/IntEnumGeneratorTests.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import io.kotest.matchers.string.shouldContain
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test
import software.amazon.smithy.build.MockManifest
import software.amazon.smithy.swift.codegen.SwiftCodegenPlugin

class IntEnumGeneratorTests {

@Test
fun `generates int enum`() {
val model = javaClass.getResource("int-enum-shape-test.smithy").asSmithy()
val manifest = MockManifest()
val context = buildMockPluginContext(model, manifest, "smithy.example#Example")
SwiftCodegenPlugin().execute(context)
val enumShape = manifest
.getFileString("example/models/Abcs.swift").get()
Assertions.assertNotNull(enumShape)

var expectedGeneratedEnum =
"""
public enum Abcs: Swift.Equatable, Swift.RawRepresentable, Swift.CaseIterable, Swift.Codable, Swift.Hashable {
case a
case b
case c
case sdkUnknown(Swift.Int)
public static var allCases: [Abcs] {
return [
.a,
.b,
.c,
.sdkUnknown(0)
]
}
public init(rawValue: Swift.Int) {
let value = Self.allCases.first(where: { ${'$'}0.rawValue == rawValue })
self = value ?? Self.sdkUnknown(rawValue)
}
public var rawValue: Swift.Int {
switch self {
case .a: return 1
case .b: return 2
case .c: return 3
case let .sdkUnknown(s): return s
}
}
}
""".trimIndent()

enumShape.shouldContain(expectedGeneratedEnum)
}
}
Loading

0 comments on commit 35d41e8

Please sign in to comment.