diff --git a/.changes/fb00b4ae-ffdb-4137-baa8-574848296da1.json b/.changes/fb00b4ae-ffdb-4137-baa8-574848296da1.json new file mode 100644 index 000000000..431e3b97f --- /dev/null +++ b/.changes/fb00b4ae-ffdb-4137-baa8-574848296da1.json @@ -0,0 +1,8 @@ +{ + "id": "fb00b4ae-ffdb-4137-baa8-574848296da1", + "type": "bugfix", + "description": "Refactor XML deserialization to handle flat collections", + "issues": [ + "awslabs/aws-sdk-kotlin#1220" + ] +} \ No newline at end of file diff --git a/build.gradle.kts b/build.gradle.kts index cc54aef71..6e1e0c77a 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -130,7 +130,8 @@ apiValidation { "channel-benchmarks", "http-benchmarks", "serde-benchmarks", - "serde-benchmarks-codegen", + "serde-codegen-support", + "serde-tests", "nullability-tests", "paginator-tests", "waiter-tests", diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/AbstractCodeWriterExt.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/AbstractCodeWriterExt.kt index 54bded579..2b943da04 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/AbstractCodeWriterExt.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/AbstractCodeWriterExt.kt @@ -187,3 +187,7 @@ fun > T.callIf(test: Boolean, runnable: Runnable): T { } return this } + +/** Escape the [expressionStart] character to avoid problems during formatting */ +fun > T.escape(text: String): String = + text.replace("$expressionStart", "$expressionStart$expressionStart") diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt index 36bc0cec8..b382f3de2 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt @@ -103,6 +103,7 @@ object RuntimeTypes { val Attributes = symbol("Attributes") val attributesOf = symbol("attributesOf") val AttributeKey = symbol("AttributeKey") + val createOrAppend = symbol("createOrAppend") val get = symbol("get") val mutableMultiMapOf = symbol("mutableMultiMapOf") val putIfAbsent = symbol("putIfAbsent") @@ -230,6 +231,7 @@ object RuntimeTypes { val SerialKind = symbol("SerialKind") val SerializationException = symbol("SerializationException") val DeserializationException = symbol("DeserializationException") + val getOrDeserializeErr = symbol("getOrDeserializeErr") val serializeStruct = symbol("serializeStruct") val serializeList = symbol("serializeList") @@ -241,6 +243,18 @@ object RuntimeTypes { val asSdkSerializable = symbol("asSdkSerializable") val field = symbol("field") + val parse = symbol("parse") + val parseInt = symbol("parseInt") + val parseShort = symbol("parseShort") + val parseLong = symbol("parseLong") + val parseFloat = symbol("parseFloat") + val parseDouble = symbol("parseDouble") + val parseByte = symbol("parseByte") + val parseBoolean = symbol("parseBoolean") + val parseTimestamp = symbol("parseTimestamp") + val parseBigInteger = symbol("parseBigInteger") + val parseBigDecimal = symbol("parseBigDecimal") + object SerdeJson : RuntimeTypePackage(KotlinDependency.SERDE_JSON) { val JsonSerialName = symbol("JsonSerialName") val JsonSerializer = symbol("JsonSerializer") @@ -260,8 +274,13 @@ object RuntimeTypes { val XmlMapName = symbol("XmlMapName") val XmlError = symbol("XmlError") val XmlSerializer = symbol("XmlSerializer") - val XmlDeserializer = symbol("XmlDeserializer") val XmlUnwrappedOutput = symbol("XmlUnwrappedOutput") + + val XmlTagReader = symbol("XmlTagReader") + val xmlStreamReader = symbol("xmlStreamReader") + val xmlRootTagReader = symbol("xmlTagReader") + val data = symbol("data") + val tryData = symbol("tryData") } object SerdeFormUrl : RuntimeTypePackage(KotlinDependency.SERDE_FORM_URL) { diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/lang/KotlinTypes.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/lang/KotlinTypes.kt index 3f60e404c..fb967ccf7 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/lang/KotlinTypes.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/lang/KotlinTypes.kt @@ -41,6 +41,7 @@ object KotlinTypes { val List: Symbol = stdlibSymbol("List") val listOf: Symbol = stdlibSymbol("listOf") val MutableList: Symbol = stdlibSymbol("MutableList") + val MutableMap: Symbol = stdlibSymbol("MutableMap") val Map: Symbol = stdlibSymbol("Map") val mutableListOf: Symbol = stdlibSymbol("mutableListOf") val mutableMapOf: Symbol = stdlibSymbol("mutableMapOf") diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGenerator.kt index 9eed98597..e74474ba3 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGenerator.kt @@ -14,7 +14,7 @@ import software.amazon.smithy.kotlin.codegen.lang.toEscapedLiteral import software.amazon.smithy.kotlin.codegen.model.* import software.amazon.smithy.kotlin.codegen.rendering.serde.deserializerName import software.amazon.smithy.kotlin.codegen.rendering.serde.formatInstant -import software.amazon.smithy.kotlin.codegen.rendering.serde.parseInstant +import software.amazon.smithy.kotlin.codegen.rendering.serde.parseInstantExpr import software.amazon.smithy.kotlin.codegen.rendering.serde.serializerName import software.amazon.smithy.kotlin.codegen.utils.getOrNull import software.amazon.smithy.model.Model @@ -813,14 +813,12 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { HttpBinding.Location.HEADER, defaultTimestampFormat, ) - writer - .addImport(RuntimeTypes.Core.Instant) - .write( - "builder.#L = response.headers[#S]?.let { #L }", - memberName, - headerName, - parseInstant("it", tsFormat), - ) + writer.write( + "builder.#L = response.headers[#S]?.let { #L }", + memberName, + headerName, + writer.parseInstantExpr("it", tsFormat), + ) } is ListShape -> { // member > boolean, number, string, or timestamp @@ -849,8 +847,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { if (tsFormat == TimestampFormatTrait.Format.HTTP_DATE) { splitFn = "splitHttpDateHeaderListValues" } - writer.addImport(RuntimeTypes.Core.Instant) - parseInstant("it", tsFormat) + writer.parseInstantExpr("it", tsFormat) } is StringShape -> { when { diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGenerator.kt index 5680af9f5..2f8e8de51 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGenerator.kt @@ -605,5 +605,3 @@ open class DeserializeStructGenerator( } } } - -private fun nullabilitySuffix(isSparse: Boolean): String = if (isSparse) "?" else "" diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerdeExt.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerdeExt.kt index 56ab30cd5..071d8ba36 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerdeExt.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerdeExt.kt @@ -9,8 +9,7 @@ import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.codegen.core.SymbolReference import software.amazon.smithy.kotlin.codegen.KotlinSettings -import software.amazon.smithy.kotlin.codegen.core.SymbolRenderer -import software.amazon.smithy.kotlin.codegen.core.defaultName +import software.amazon.smithy.kotlin.codegen.core.* import software.amazon.smithy.kotlin.codegen.core.mangledSuffix import software.amazon.smithy.kotlin.codegen.model.buildSymbol import software.amazon.smithy.model.Model @@ -216,11 +215,31 @@ fun formatInstant(paramName: String, tsFmt: TimestampFormatTrait.Format, forceSt * @param paramName The name of the local identifier to convert to an `Instant` * @param tsFmt The timestamp format [paramName] is expected to be converted from */ -fun parseInstant(paramName: String, tsFmt: TimestampFormatTrait.Format): String = when (tsFmt) { - TimestampFormatTrait.Format.EPOCH_SECONDS -> "Instant.fromEpochSeconds($paramName)" - TimestampFormatTrait.Format.DATE_TIME -> "Instant.fromIso8601($paramName)" - TimestampFormatTrait.Format.HTTP_DATE -> "Instant.fromRfc5322($paramName)" - else -> throw CodegenException("unknown timestamp format: $tsFmt") +fun KotlinWriter.parseInstantExpr(paramName: String, tsFmt: TimestampFormatTrait.Format): String { + val fn = when (tsFmt) { + TimestampFormatTrait.Format.EPOCH_SECONDS -> "fromEpochSeconds" + TimestampFormatTrait.Format.DATE_TIME -> "fromIso8601" + TimestampFormatTrait.Format.HTTP_DATE -> "fromRfc5322" + else -> throw CodegenException("unknown timestamp format: $tsFmt") + } + return format("#T.#L(#L)", RuntimeTypes.Core.Instant, fn, paramName) +} + +fun TimestampFormatTrait.Format.toRuntimeEnum(): String = when (this) { + TimestampFormatTrait.Format.EPOCH_SECONDS -> "TimestampFormat.EPOCH_SECONDS" + TimestampFormatTrait.Format.DATE_TIME -> "TimestampFormat.ISO_8601" + TimestampFormatTrait.Format.HTTP_DATE -> "TimestampFormat.RFC_5322" + else -> throw CodegenException("unknown timestamp format: $this") +} + +fun TimestampFormatTrait.Format.toRuntimeEnum(writer: KotlinWriter): String { + val enum = when (this) { + TimestampFormatTrait.Format.EPOCH_SECONDS -> "EPOCH_SECONDS" + TimestampFormatTrait.Format.DATE_TIME -> "ISO_8601" + TimestampFormatTrait.Format.HTTP_DATE -> "RFC_5322" + else -> throw CodegenException("unknown timestamp format: $this") + } + return writer.format("#T.#L", RuntimeTypes.Core.TimestampFormat, enum) } /** @@ -289,3 +308,5 @@ internal fun Shape.childShape(model: Model): Shape? = when (this) { is MapShape -> model.expectShape(this.value.target) else -> null } + +internal fun nullabilitySuffix(isSparse: Boolean): String = if (isSparse) "?" else "" diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerializeStructGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerializeStructGenerator.kt index 017f612bd..ff099b0c9 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerializeStructGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerializeStructGenerator.kt @@ -680,10 +680,3 @@ open class SerializeStructGenerator( return "serialize$suffix" } } - -fun TimestampFormatTrait.Format.toRuntimeEnum(): String = when (this) { - TimestampFormatTrait.Format.EPOCH_SECONDS -> "TimestampFormat.EPOCH_SECONDS" - TimestampFormatTrait.Format.DATE_TIME -> "TimestampFormat.ISO_8601" - TimestampFormatTrait.Format.HTTP_DATE -> "TimestampFormat.RFC_5322" - else -> throw CodegenException("unknown timestamp format: $this") -} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt index da3d78448..4c4e59894 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt @@ -6,34 +6,38 @@ package software.amazon.smithy.kotlin.codegen.rendering.serde import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.kotlin.codegen.core.KotlinWriter +import software.amazon.smithy.codegen.core.SymbolReference +import software.amazon.smithy.kotlin.codegen.core.* import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes -import software.amazon.smithy.kotlin.codegen.core.withBlock +import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes +import software.amazon.smithy.kotlin.codegen.model.* import software.amazon.smithy.kotlin.codegen.model.knowledge.SerdeIndex -import software.amazon.smithy.kotlin.codegen.model.targetOrSelf +import software.amazon.smithy.kotlin.codegen.model.traits.SyntheticClone +import software.amazon.smithy.kotlin.codegen.model.traits.UnwrappedXmlOutput import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator -import software.amazon.smithy.kotlin.codegen.rendering.protocol.toRenderingContext -import software.amazon.smithy.model.shapes.MemberShape -import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.model.shapes.Shape -import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.* +import software.amazon.smithy.model.traits.SparseTrait import software.amazon.smithy.model.traits.TimestampFormatTrait +import software.amazon.smithy.model.traits.XmlAttributeTrait +import software.amazon.smithy.model.traits.XmlFlattenedTrait +import software.amazon.smithy.model.traits.XmlNameTrait +import software.amazon.smithy.utils.StringUtils +import kotlin.jvm.optionals.getOrDefault /** * XML parser generator based on common deserializer interface and XML serde descriptors */ open class XmlParserGenerator( - // FIXME - shouldn't be necessary but XML serde descriptor generator needs it for rendering context - private val protocolGenerator: ProtocolGenerator, private val defaultTimestampFormat: TimestampFormatTrait.Format, ) : StructuredDataParserGenerator { - open fun descriptorGenerator( - ctx: ProtocolGenerator.GenerationContext, - shape: Shape, - members: List, - writer: KotlinWriter, - ): XmlSerdeDescriptorGenerator = XmlSerdeDescriptorGenerator(ctx.toRenderingContext(protocolGenerator, shape, writer), members) + /** + * Deserialization context that holds current state + * @param tagReader the name of the current tag reader to operate on + */ + data class SerdeCtx( + val tagReader: String, + ) override fun operationDeserializer( ctx: ProtocolGenerator.GenerationContext, @@ -73,42 +77,59 @@ open class XmlParserGenerator( documentMembers: List, writer: KotlinWriter, ) { - writer.write("val deserializer = #T(payload)", RuntimeTypes.Serde.SerdeXml.XmlDeserializer) + writer.write("val root = #T(payload)", RuntimeTypes.Serde.SerdeXml.xmlRootTagReader) val shape = ctx.model.expectShape(op.output.get()) - renderDeserializerBody(ctx, shape, documentMembers, writer) + val serdeCtx = unwrapOperationBody(ctx, SerdeCtx("root"), op, writer) + + if (op.hasTrait()) { + renderDeserializerUnwrappedXmlBody(ctx, serdeCtx, shape, writer) + } else { + renderDeserializerBody(ctx, serdeCtx, shape, documentMembers, writer) + } } + /** + * Hook for protocols to perform logic prior to deserializing the operation output. + * Implementations must return the [SerdeCtx] to use for further deserialization. + */ + protected open fun unwrapOperationBody( + ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, + op: OperationShape, + writer: KotlinWriter, + ): SerdeCtx = serdeCtx + protected fun renderDeserializerBody( ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, shape: Shape, members: List, writer: KotlinWriter, ) { - descriptorGenerator(ctx, shape, members, writer).render() if (shape.isUnionShape) { - val name = ctx.symbolProvider.toSymbol(shape).name - DeserializeUnionGenerator(ctx, name, members, writer, defaultTimestampFormat).render() + deserializeUnion(ctx, serdeCtx, members, writer) } else { - DeserializeStructGenerator(ctx, members, writer, defaultTimestampFormat).render() + deserializeStruct(ctx, serdeCtx, members, writer) } } - protected fun documentDeserializer( + private fun documentDeserializer( ctx: ProtocolGenerator.GenerationContext, shape: Shape, members: Collection = shape.members(), ): Symbol { val symbol = ctx.symbolProvider.toSymbol(shape) return shape.documentDeserializer(ctx.settings, symbol, members) { writer -> - writer.openBlock("internal fun #identifier.name:L(deserializer: #T): #T {", RuntimeTypes.Serde.Deserializer, symbol) + writer.openBlock("internal fun #identifier.name:L(reader: #T): #T {", RuntimeTypes.Serde.SerdeXml.XmlTagReader, symbol) .call { + val serdeCtx = SerdeCtx("reader") if (shape.isUnionShape) { writer.write("var value: #T? = null", symbol) - renderDeserializerBody(ctx, shape, members.toList(), writer) + renderDeserializerBody(ctx, serdeCtx, shape, members.toList(), writer) writer.write("return value ?: throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "Deserialized union value unexpectedly null: ${symbol.name}") } else { writer.write("val builder = #T.Builder()", symbol) - renderDeserializerBody(ctx, shape, members.toList(), writer) + renderDeserializerBody(ctx, serdeCtx, shape, members.toList(), writer) writer.write("builder.correctErrors()") writer.write("return builder.build()") } @@ -128,13 +149,25 @@ open class XmlParserGenerator( val fnName = symbol.errorDeserializerName() writer.openBlock("internal fun #L(builder: #T.Builder, payload: ByteArray) {", fnName, symbol) .call { - writer.write("val deserializer = #T(payload)", RuntimeTypes.Serde.SerdeXml.XmlDeserializer) - renderDeserializerBody(ctx, errorShape, members, writer) + writer.write("val root = #T(payload)", RuntimeTypes.Serde.SerdeXml.xmlRootTagReader) + val serdeCtx = unwrapOperationError(ctx, SerdeCtx("root"), errorShape, writer) + renderDeserializerBody(ctx, serdeCtx, errorShape, members, writer) } .closeBlock("}") } } + /** + * Hook for protocols to perform logic prior to deserializing an operation error. + * Implementations must return the [SerdeCtx] to use for further deserialization. + */ + protected open fun unwrapOperationError( + ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, + errorShape: StructureShape, + writer: KotlinWriter, + ): SerdeCtx = serdeCtx + override fun payloadDeserializer( ctx: ProtocolGenerator.GenerationContext, shape: Shape, @@ -152,10 +185,469 @@ open class XmlParserGenerator( // short circuit when the shape has no modeled members to deserialize write("return #T.Builder().build()", symbol) } else { - write("val deserializer = #T(payload)", RuntimeTypes.Serde.SerdeXml.XmlDeserializer) - write("return #T(deserializer)", deserializeFn) + writer.write("val root = #T(payload)", RuntimeTypes.Serde.SerdeXml.xmlRootTagReader) + write("return #T(root)", deserializeFn) + } + } + } + } + + private fun KotlinWriter.deserializeLoop( + serdeCtx: SerdeCtx, + ignoreUnexpected: Boolean = true, + block: KotlinWriter.(SerdeCtx) -> Unit, + ) { + withBlock("loop@while (true) {", "}") { + write("val curr = #L.nextTag() ?: break@loop", serdeCtx.tagReader) + withBlock("when (curr.tagName) {", "}") { + block(this, serdeCtx.copy(tagReader = "curr")) + if (ignoreUnexpected) { + write("else -> {}") + } else { + write("else -> throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "Unexpected tag \${curr.tag}") + } + } + // maintain stream reader state by dropping the current element and all it's children + // this ensures nested elements with potentially the same name as a higher level element + // are not erroneously returned and matched by `nextTag()` + write("curr.drop()") + } + } + + private fun originalMember(ctx: ProtocolGenerator.GenerationContext, member: MemberShape): MemberShape { + val containerShapeId = ctx.model.expectShape(member.container).getTrait()?.archetype ?: member.container + val container = ctx.model.expectShape(containerShapeId) + return container.getMember(member.memberName).getOrDefault(member) + } + + private fun KotlinWriter.writeMemberDebugComment( + ctx: ProtocolGenerator.GenerationContext, + member: MemberShape, + ) { + val originalMember = originalMember(ctx, member) + write("// ${originalMember.memberName} ${escape(originalMember.id.toString())}") + } + + private fun deserializeUnion( + ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, + members: List, + writer: KotlinWriter, + ) { + writer.deserializeLoop(serdeCtx) { innerCtx -> + members.forEach { member -> + val name = member.getTrait()?.value ?: member.memberName + writeMemberDebugComment(ctx, member) + val unionTypeName = member.unionTypeName(ctx) + withBlock("#S -> value = #L(", ")", name, unionTypeName) { + deserializeMember(ctx, innerCtx, member, writer) + } + } + } + } + + private fun deserializeStruct( + ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, + members: List, + writer: KotlinWriter, + ) { + // split attribute members and non attribute members + val (attributeMembers, payloadMembers) = members.partition { + it.hasTrait() + } + + attributeMembers.forEach { member -> + deserializeAttributeMember(ctx, serdeCtx, member, writer) + } + + // don't generate a parse loop if no attribute members + if (payloadMembers.isEmpty()) return + writer.write("") + writer.deserializeLoop(serdeCtx) { innerCtx -> + payloadMembers.forEach { member -> + val name = member.getTrait()?.value ?: member.memberName + + writeMemberDebugComment(ctx, member) + writeInline("#S -> builder.#L = ", name, ctx.symbolProvider.toMemberName(member)) + deserializeMember(ctx, innerCtx, member, writer) + } + } + } + + private fun deserializeAttributeMember( + ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, + member: MemberShape, + writer: KotlinWriter, + ) { + val memberName = member.getTrait()?.value ?: member.memberName + writer.withBlock( + "#L.tag.getAttr(#S)?.let {", + "}", + serdeCtx.tagReader, + memberName, + ) { + writeInline("builder.#L = ", ctx.symbolProvider.toMemberName(member)) + deserializePrimitiveMember(ctx, member, "it", textExprIsResult = false, this) + } + } + + private fun deserializeMember( + ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, + member: MemberShape, + writer: KotlinWriter, + ) { + val target = ctx.model.expectShape(member.target) + when (target.type) { + ShapeType.LIST, ShapeType.SET -> { + if (member.hasTrait()) { + deserializeFlatList(ctx, serdeCtx, member, writer) + } else { + deserializeList(ctx, serdeCtx, member, writer) + } + } + ShapeType.MAP -> { + if (member.hasTrait()) { + deserializeFlatMap(ctx, serdeCtx, member, writer) + } else { + deserializeMap(ctx, serdeCtx, member, writer) + } + } + ShapeType.STRUCTURE, ShapeType.UNION -> { + val deserializeFn = documentDeserializer(ctx, target) + writer.write("#T(#L)", deserializeFn, serdeCtx.tagReader) + } + else -> deserializePrimitiveMember( + ctx, + member, + writer.format("#L.#T()", serdeCtx.tagReader, RuntimeTypes.Serde.SerdeXml.tryData), + textExprIsResult = true, + writer, + ) + } + } + + // TODO - this could probably be moved to SerdeExt and commonized + + private fun Shape.shapeDeserializerDefinitionFile( + ctx: ProtocolGenerator.GenerationContext, + ): String { + val target = targetOrSelf(ctx.model) + val shapeName = StringUtils.capitalize(target.id.getName(ctx.service)) + return "${shapeName}ShapeDeserializer.kt" + } + + private fun Shape.shapeDeserializer( + ctx: ProtocolGenerator.GenerationContext, + block: (fnName: String, writer: KotlinWriter) -> Unit, + ): Symbol { + val target = targetOrSelf(ctx.model) + val shapeName = StringUtils.capitalize(target.id.getName(ctx.service)) + val symbol = ctx.symbolProvider.toSymbol(this) + + val fnName = "deserialize${shapeName}Shape" + return buildSymbol { + name = fnName + namespace = ctx.settings.pkg.serde + definitionFile = shapeDeserializerDefinitionFile(ctx) + reference(symbol, SymbolReference.ContextOption.DECLARE) + renderBy = { + block(fnName, it) + } + } + } + + private fun deserializeShape( + ctx: ProtocolGenerator.GenerationContext, + shape: Shape, + block: KotlinWriter.() -> Unit, + ): Symbol { + val symbol = ctx.symbolProvider.toSymbol(shape) + val deserializeFn = shape.shapeDeserializer(ctx) { fnName, writer -> + writer.withBlock( + "internal fun #L(reader: #T): #T {", + "}", + fnName, + RuntimeTypes.Serde.SerdeXml.XmlTagReader, + symbol, + ) { + block(this) + } + } + return deserializeFn + } + + private fun deserializeList( + ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, + member: MemberShape, + writer: KotlinWriter, + ) { + val target = ctx.model.expectShape(member.target) + val targetMember = target.member + val isSparse = target.hasTrait() + val deserializeFn = deserializeShape(ctx, target) { + write("val result = mutableListOf<#T#L>()", ctx.symbolProvider.toSymbol(targetMember), nullabilitySuffix(isSparse)) + deserializeLoop(SerdeCtx(tagReader = "reader")) { innerCtx -> + val memberName = targetMember.getTrait()?.value ?: targetMember.memberName + withBlock("#S -> {", "}", memberName) { + deserializeListInner(ctx, innerCtx, target, this) + write("result.add(el)") + } + } + write("return result") + } + writer.write("#T(#L)", deserializeFn, serdeCtx.tagReader) + } + + private fun flatCollectionAccumulatorExpr( + ctx: ProtocolGenerator.GenerationContext, + member: MemberShape, + ): String { + val escapedMemberName = ctx.symbolProvider.toMemberName(member) + return when (val container = ctx.model.expectShape(member.container)) { + is StructureShape -> "builder.$escapedMemberName" + is UnionShape -> { + val unionVariantName = member.unionVariantName() + "value?.as${unionVariantName}OrNull()" + } + else -> error("unexpected container shape $container for member $member") + } + } + + private fun deserializeFlatList( + ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, + member: MemberShape, + writer: KotlinWriter, + ) { + val target = ctx.model.expectShape(member.target) + writer.withBlock("run {", "}") { + deserializeListInner(ctx, serdeCtx, target, this) + val accum = flatCollectionAccumulatorExpr(ctx, member) + write("#T(#L, el)", RuntimeTypes.Core.Collections.createOrAppend, accum) + } + } + + private fun deserializeListInner( + ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, + target: CollectionShape, + writer: KotlinWriter, + ) { + // <- sparse + // CDATA || TAG(s) <- not sparse + val isSparse = target.hasTrait() + with(writer) { + if (isSparse) { + openBlock("val el = if (#L.nextHasValue()) {", serdeCtx.tagReader) + .call { + deserializeMember(ctx, serdeCtx, target.member, this) + } + .closeAndOpenBlock("} else {") + .write("null") + .closeBlock("}") + } else { + writeInline("val el = ") + deserializeMember(ctx, serdeCtx, target.member, this) + } + } + } + + private fun deserializeMap( + ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, + member: MemberShape, + writer: KotlinWriter, + ) { + val target = ctx.model.expectShape(member.target) + val keySymbol = KotlinTypes.String + val valueSymbol = ctx.symbolProvider.toSymbol(target.value) + writer.addImportReferences(valueSymbol, SymbolReference.ContextOption.USE) + val isSparse = target.hasTrait() + + val deserializeFn = deserializeShape(ctx, target) { + write("val result = mutableMapOf<#T, #T#L>()", keySymbol, valueSymbol, nullabilitySuffix(isSparse)) + deserializeLoop(SerdeCtx("reader")) { innerCtx -> + withBlock("#S -> {", "}", "entry") { + val deserializeEntryFn = deserializeMapEntry(ctx, target) + write("#T(result, ${innerCtx.tagReader})", deserializeEntryFn) + } + } + write("return result") + } + writer.write("#T(#L)", deserializeFn, serdeCtx.tagReader) + } + + private fun deserializeFlatMap( + ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, + member: MemberShape, + writer: KotlinWriter, + ) { + val target = ctx.model.expectShape(member.target) + val keySymbol = KotlinTypes.String + val valueSymbol = ctx.symbolProvider.toSymbol(target.value) + val isSparse = target.hasTrait() + writer.addImportReferences(valueSymbol, SymbolReference.ContextOption.USE) + writer.withBlock("run {", "}") { + val accum = flatCollectionAccumulatorExpr(ctx, member) + write( + "val dest = #L?.toMutableMap() ?: mutableMapOf<#T, #T#L>()", + accum, + keySymbol, + valueSymbol, + nullabilitySuffix(isSparse), + ) + val deserializeEntryFn = deserializeMapEntry(ctx, target) + write("#T(dest, #L)", deserializeEntryFn, serdeCtx.tagReader) + write("dest") + } + } + + private fun deserializeMapEntry( + ctx: ProtocolGenerator.GenerationContext, + map: MapShape, + ): Symbol { + val shapeName = StringUtils.capitalize(map.id.getName(ctx.service)) + val keySymbol = KotlinTypes.String + val valueSymbol = ctx.symbolProvider.toSymbol(map.value) + val isSparse = map.hasTrait() + val serdeCtx = SerdeCtx("reader") + + return buildSymbol { + name = "deserialize${shapeName}Entry" + namespace = ctx.settings.pkg.serde + definitionFile = map.shapeDeserializerDefinitionFile(ctx) + renderBy = { writer -> + // NOTE: we make this internal rather than private because flat maps don't generate a + // dedicated map deserializer, they inline the entry deserialization since the map + // being built up is not processed all at once + writer.withBlock( + "internal fun $name(dest: #T<#T, #T#L>, reader: #T) {", + "}", + KotlinTypes.Collections.MutableMap, + keySymbol, + valueSymbol, + nullabilitySuffix(isSparse), + RuntimeTypes.Serde.SerdeXml.XmlTagReader, + ) { + write("var key: #T? = null", keySymbol) + write("var value: #T? = null", valueSymbol) + writer.addImportReferences(valueSymbol, SymbolReference.ContextOption.USE) + deserializeLoop(serdeCtx) { innerCtx -> + val keyName = map.key.getTrait()?.value ?: map.key.memberName + writeInline("#S -> key = ", keyName) + deserializeMember(ctx, innerCtx, map.key, this) + // FIXME - We re-use deserializeMember here but key types targeting enums + // have to pull the raw string value back out because of + // https://github.com/awslabs/smithy-kotlin/issues/1045 + val targetValueShape = ctx.model.expectShape(map.key.target) + if (targetValueShape.type == ShapeType.ENUM) { + writer.indent() + .write(".value") + } + + val valueName = map.value.getTrait()?.value ?: map.value.memberName + if (isSparse) { + openBlock("#S -> value = if (${innerCtx.tagReader}.nextHasValue()) {", valueName) + .call { + deserializeMember(ctx, innerCtx, map.value, this) + } + .closeAndOpenBlock("} else {") + .write("null") + .closeBlock("}") + } else { + writeInline("#S -> value = ", valueName) + deserializeMember(ctx, innerCtx, map.value, this) + } + } + write("if (key == null) throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "missing key map entry") + if (!isSparse) { + write("if (value == null) throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "missing value map entry") + } + write("dest[key] = value") + } + } + } + } + + private fun deserializePrimitiveMember( + ctx: ProtocolGenerator.GenerationContext, + member: MemberShape, + textExpr: String, + textExprIsResult: Boolean, + writer: KotlinWriter, + ) { + val target = ctx.model.expectShape(member.target) + + val parseFn = when (target.type) { + ShapeType.BLOB -> writer.format("#T { it.#T() } ", RuntimeTypes.Serde.parse, RuntimeTypes.Core.Text.Encoding.decodeBase64Bytes) + ShapeType.BOOLEAN -> writer.format("#T()", RuntimeTypes.Serde.parseBoolean) + ShapeType.STRING -> { + if (!textExprIsResult) { + writer.write(textExpr) + return + } else { + null } } + ShapeType.TIMESTAMP -> { + val trait = member.getTrait() ?: target.getTrait() + val tsFormat = trait?.format ?: defaultTimestampFormat + val runtimeEnum = tsFormat.toRuntimeEnum(writer) + writer.format("#T(#L)", RuntimeTypes.Serde.parseTimestamp, runtimeEnum) + } + ShapeType.BYTE -> writer.format("#T()", RuntimeTypes.Serde.parseByte) + ShapeType.SHORT -> writer.format("#T()", RuntimeTypes.Serde.parseShort) + ShapeType.INTEGER -> writer.format("#T()", RuntimeTypes.Serde.parseInt) + ShapeType.LONG -> writer.format("#T()", RuntimeTypes.Serde.parseLong) + ShapeType.FLOAT -> writer.format("#T()", RuntimeTypes.Serde.parseFloat) + ShapeType.DOUBLE -> writer.format("#T()", RuntimeTypes.Serde.parseDouble) + ShapeType.BIG_DECIMAL -> writer.format("#T()", RuntimeTypes.Serde.parseBigDecimal) + ShapeType.BIG_INTEGER -> writer.format("#T()", RuntimeTypes.Serde.parseBigInteger) + ShapeType.ENUM -> { + if (!textExprIsResult) { + writer.write("#T.fromValue(#L)", ctx.symbolProvider.toSymbol(target), textExpr) + return + } + writer.format("#T { #T.fromValue(it) } ", RuntimeTypes.Serde.parse, ctx.symbolProvider.toSymbol(target)) + } + ShapeType.INT_ENUM -> { + writer.format("#T { #T.fromValue(it.toInt()) } ", RuntimeTypes.Serde.parse, ctx.symbolProvider.toSymbol(target)) + } + else -> error("unknown primitive member shape $member") + } + + val escapedErrMessage = "expected $target".replace("$", "$$") + writer.write(textExpr) + .indent() + .callIf(parseFn != null) { + writer.write(".#L", parseFn) + } + .write(".#T { #S }", RuntimeTypes.Serde.getOrDeserializeErr, escapedErrMessage) + .dedent() + } + + private fun renderDeserializerUnwrappedXmlBody( + ctx: ProtocolGenerator.GenerationContext, + serdeCtx: SerdeCtx, + shape: Shape, + writer: KotlinWriter, + ) { + val members = shape.members() + check(members.size == 1) { + "unwrapped XML output trait is only allowed on operation output structs with exactly one member" + } + + val member = members.first() + writer.withBlock("when (#L.tagName) {", "}", serdeCtx.tagReader) { + val name = member.getTrait()?.value ?: member.memberName + writeMemberDebugComment(ctx, member) + writeInline("#S -> builder.#L = ", name, ctx.symbolProvider.toMemberName(member)) + deserializeMember(ctx, serdeCtx, member, writer) } } } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlSerdeDescriptorGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlSerdeDescriptorGenerator.kt index b5bbf4de7..a5f47cf3c 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlSerdeDescriptorGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlSerdeDescriptorGenerator.kt @@ -8,7 +8,6 @@ package software.amazon.smithy.kotlin.codegen.rendering.serde import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.kotlin.codegen.core.RenderingContext import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes -import software.amazon.smithy.kotlin.codegen.core.addImport import software.amazon.smithy.kotlin.codegen.core.defaultName import software.amazon.smithy.kotlin.codegen.model.expectShape import software.amazon.smithy.kotlin.codegen.model.expectTrait @@ -78,7 +77,6 @@ open class XmlSerdeDescriptorGenerator( nameSuffix: String, ): List { ctx.writer.addImport( - RuntimeTypes.Serde.SerdeXml.XmlDeserializer, RuntimeTypes.Serde.SerdeXml.XmlSerialName, ) diff --git a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlSerdeDescriptorGeneratorTest.kt b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlSerdeDescriptorGeneratorTest.kt index edb4981aa..3da6ce551 100644 --- a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlSerdeDescriptorGeneratorTest.kt +++ b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlSerdeDescriptorGeneratorTest.kt @@ -5,7 +5,6 @@ package software.amazon.smithy.kotlin.codegen.rendering.serde -import software.amazon.smithy.kotlin.codegen.core.RUNTIME_ROOT_NS import software.amazon.smithy.kotlin.codegen.test.* import software.amazon.smithy.model.shapes.ShapeId import kotlin.test.Test @@ -141,50 +140,6 @@ class XmlSerdeDescriptorGeneratorTest { contents.shouldContainOnlyOnceWithDiff(expectedDescriptors) } - @Test - fun `it generates expected import declarations`() { - val snippet = """ - @http(method: "POST", uri: "/foo") - operation Foo { - input: FooRequest, - output: FooRequest - } - - @xmlName("CustomFooRequest") - structure FooRequest { - @xmlAttribute - payload: String, - @xmlFlattened - listVal: ListOfString - } - - list ListOfString { - member: String - } - """ - - val expected = """ - import $RUNTIME_ROOT_NS.serde.SdkFieldDescriptor - import $RUNTIME_ROOT_NS.serde.SdkObjectDescriptor - import $RUNTIME_ROOT_NS.serde.SerialKind - import $RUNTIME_ROOT_NS.serde.asSdkSerializable - import $RUNTIME_ROOT_NS.serde.deserializeList - import $RUNTIME_ROOT_NS.serde.deserializeMap - import $RUNTIME_ROOT_NS.serde.deserializeStruct - import $RUNTIME_ROOT_NS.serde.field - import $RUNTIME_ROOT_NS.serde.serializeList - import $RUNTIME_ROOT_NS.serde.serializeMap - import $RUNTIME_ROOT_NS.serde.serializeStruct - import $RUNTIME_ROOT_NS.serde.xml.Flattened - import $RUNTIME_ROOT_NS.serde.xml.XmlAttribute - import $RUNTIME_ROOT_NS.serde.xml.XmlDeserializer - import $RUNTIME_ROOT_NS.serde.xml.XmlSerialName - """.formatForTest("") - - val contents = getContents(snippet, "FooRequest") - contents.shouldContainOnlyOnceWithDiff(expected) - } - @Test fun `it generates field descriptors for flattened xml trait and object descriptor for XmlName trait`() { val snippet = """ diff --git a/gradle.properties b/gradle.properties index 11f5dc5b1..ebe9af9e9 100644 --- a/gradle.properties +++ b/gradle.properties @@ -16,4 +16,4 @@ org.gradle.jvmargs=-Xmx2G -XX:MaxMetaspaceSize=1G sdkVersion=1.0.16-SNAPSHOT # codegen -codegenVersion=0.30.17-SNAPSHOT \ No newline at end of file +codegenVersion=0.30.17-SNAPSHOT diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 87e9b9cc3..eea8dd4af 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -1,5 +1,5 @@ [versions] -kotlin-version = "1.9.21" +kotlin-version = "1.9.22" dokka-version = "1.9.10" aws-kotlin-repo-tools-version = "0.4.0" diff --git a/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializer.kt b/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializer.kt index c01e062e5..4ce0e7c98 100644 --- a/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializer.kt +++ b/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializer.kt @@ -6,10 +6,8 @@ package aws.smithy.kotlin.runtime.awsprotocol.xml import aws.smithy.kotlin.runtime.InternalApi import aws.smithy.kotlin.runtime.awsprotocol.ErrorDetails -import aws.smithy.kotlin.runtime.serde.* -import aws.smithy.kotlin.runtime.serde.xml.XmlCollectionName -import aws.smithy.kotlin.runtime.serde.xml.XmlDeserializer -import aws.smithy.kotlin.runtime.serde.xml.XmlSerialName +import aws.smithy.kotlin.runtime.serde.getOrDeserializeErr +import aws.smithy.kotlin.runtime.serde.xml.* internal data class Ec2QueryErrorResponse(val errors: List, val requestId: String?) @@ -17,83 +15,65 @@ internal data class Ec2QueryError(val code: String?, val message: String?) @InternalApi public suspend fun parseEc2QueryErrorResponse(payload: ByteArray): ErrorDetails { - val response = Ec2QueryErrorResponseDeserializer.deserialize(XmlDeserializer(payload, true)) + val response = Ec2QueryErrorResponseDeserializer.deserialize(xmlTagReader(payload)) val firstError = response.errors.firstOrNull() return ErrorDetails(firstError?.code, firstError?.message, response.requestId) } /** * Deserializes EC2 Query protocol errors as specified by - * https://awslabs.github.io/smithy/1.0/spec/aws/aws-ec2-query-protocol.html#operation-error-serialization + * https://smithy.io/2.0/aws/protocols/aws-ec2-query-protocol.html#operation-error-serialization */ internal object Ec2QueryErrorResponseDeserializer { - private val ERRORS_DESCRIPTOR = SdkFieldDescriptor( - SerialKind.List, - XmlSerialName("Errors"), - XmlCollectionName("Error"), - ) - private val REQUESTID_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("RequestId")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("Response")) - field(ERRORS_DESCRIPTOR) - field(REQUESTID_DESCRIPTOR) - } - - suspend fun deserialize(deserializer: Deserializer): Ec2QueryErrorResponse { - var errors = listOf() + fun deserialize(root: XmlTagReader): Ec2QueryErrorResponse = runCatching { + var errors: List? = null var requestId: String? = null + if (root.tagName != "Response") error("expected found ${root.tag}") - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - ERRORS_DESCRIPTOR.index -> errors = deserializer.deserializeList(ERRORS_DESCRIPTOR) { - val collection = mutableListOf() - while (hasNextElement()) { - if (nextHasValue()) { - val element = Ec2QueryErrorDeserializer.deserialize(deserializer) - collection.add(element) - } else { - deserializeNull() - continue - } - } - collection - } - REQUESTID_DESCRIPTOR.index -> requestId = deserializeString() - null -> break@loop - else -> skipValue() - } + loop@while (true) { + val curr = root.nextTag() ?: break@loop + when (curr.tagName) { + "Errors" -> errors = Ec2QueryErrorListDeserializer.deserialize(curr) + "RequestId" -> requestId = curr.data() } + curr.drop() } - return Ec2QueryErrorResponse(errors, requestId) + Ec2QueryErrorResponse(errors ?: emptyList(), requestId) + }.getOrDeserializeErr { "Unable to deserialize EC2Query error" } +} + +internal object Ec2QueryErrorListDeserializer { + fun deserialize(root: XmlTagReader): List { + val errors = mutableListOf() + loop@while (true) { + val curr = root.nextTag() ?: break@loop + when (curr.tagName) { + "Error" -> { + val el = Ec2QueryErrorDeserializer.deserialize(curr) + errors.add(el) + } + } + curr.drop() + } + return errors } } internal object Ec2QueryErrorDeserializer { - private val CODE_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("Code")) - private val MESSAGE_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("Message")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("Error")) - field(CODE_DESCRIPTOR) - field(MESSAGE_DESCRIPTOR) - } - suspend fun deserialize(deserializer: Deserializer): Ec2QueryError { + fun deserialize(root: XmlTagReader): Ec2QueryError { var code: String? = null var message: String? = null - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - CODE_DESCRIPTOR.index -> code = deserializeString() - MESSAGE_DESCRIPTOR.index -> message = deserializeString() - null -> break@loop - else -> skipValue() - } + loop@while (true) { + val curr = root.nextTag() ?: break@loop + when (curr.tagName) { + "Code" -> code = curr.data() + "Message", "message" -> message = curr.data() } + curr.drop() } - return Ec2QueryError(code, message) } } diff --git a/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializer.kt b/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializer.kt index 6eef5f463..60e494425 100644 --- a/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializer.kt +++ b/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializer.kt @@ -7,8 +7,9 @@ package aws.smithy.kotlin.runtime.awsprotocol.xml import aws.smithy.kotlin.runtime.InternalApi import aws.smithy.kotlin.runtime.awsprotocol.ErrorDetails import aws.smithy.kotlin.runtime.serde.* -import aws.smithy.kotlin.runtime.serde.xml.XmlDeserializer -import aws.smithy.kotlin.runtime.serde.xml.XmlSerialName +import aws.smithy.kotlin.runtime.serde.xml.XmlTagReader +import aws.smithy.kotlin.runtime.serde.xml.data +import aws.smithy.kotlin.runtime.serde.xml.xmlTagReader /** * Provides access to specific values regardless of message form @@ -19,16 +20,6 @@ internal interface RestXmlErrorDetails { val message: String? } -// Models "ErrorResponse" type in https://awslabs.github.io/smithy/1.0/spec/aws/aws-restxml-protocol.html#operation-error-serialization -internal data class XmlErrorResponse( - val error: XmlError?, - override val requestId: String? = error?.requestId, -) : RestXmlErrorDetails { - override val code: String? = error?.code - override val message: String? = error?.message -} - -// Models "Error" type in https://awslabs.github.io/smithy/1.0/spec/aws/aws-restxml-protocol.html#operation-error-serialization internal data class XmlError( override val requestId: String?, override val code: String?, @@ -39,96 +30,56 @@ internal data class XmlError( * Deserializes rest XML protocol errors as specified by: * https://awslabs.github.io/smithy/1.0/spec/aws/aws-restxml-protocol.html#error-response-serialization * - * Returns parsed data in normalized form or throws IllegalArgumentException if response cannot be parsed. - * NOTE: we use an explicit XML deserializer here because we rely on validating the root element name - * for dealing with the alternate error response forms + * Returns parsed data in normalized form or throws [DeserializationException] if response cannot be parsed. */ @InternalApi public suspend fun parseRestXmlErrorResponse(payload: ByteArray): ErrorDetails { - val details = ErrorResponseDeserializer.deserialize(XmlDeserializer(payload, true)) - ?: XmlErrorDeserializer.deserialize(XmlDeserializer(payload, true)) - ?: throw DeserializationException("Unable to deserialize RestXml error.") + val details = XmlErrorDeserializer.deserialize(xmlTagReader(payload)) return ErrorDetails(details.code, details.message, details.requestId) } -/* - * The deserializers in this file were initially generated by the SDK and then - * adapted to fit this use case of deserializing well-known error structures from - * restXml-based services. - */ - /** - * Deserializes rest Xml protocol errors as specified by: - * - Smithy spec: https://awslabs.github.io/smithy/1.0/spec/aws/aws-restxml-protocol.html#operation-error-serialization - */ -internal object ErrorResponseDeserializer { - private val ERROR_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("Error")) - private val REQUESTID_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("RequestId")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("ErrorResponse")) - field(ERROR_DESCRIPTOR) - field(REQUESTID_DESCRIPTOR) - } - - suspend fun deserialize(deserializer: Deserializer): XmlErrorResponse? { - var requestId: String? = null - var xmlError: XmlError? = null - - return try { - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - ERROR_DESCRIPTOR.index -> xmlError = XmlErrorDeserializer.deserialize(deserializer) - REQUESTID_DESCRIPTOR.index -> requestId = deserializeString() - null -> break@loop - else -> skipValue() - } - } - } - - XmlErrorResponse(xmlError, requestId ?: xmlError?.requestId) - } catch (e: DeserializationException) { - null // return so an appropriate exception type can be instantiated above here. - } - } -} - -/** - * This deserializer is used for both the nested Error node from ErrorResponse as well as the top-level - * Error node as described in https://awslabs.github.io/smithy/1.0/spec/aws/aws-restxml-protocol.html#operation-error-serialization + * This deserializer is used for both wrapped and unwrapped restXml errors. */ internal object XmlErrorDeserializer { - private val MESSAGE_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("Message")) - private val CODE_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("Code")) - private val REQUESTID_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("RequestId")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("Error")) - field(MESSAGE_DESCRIPTOR) - field(CODE_DESCRIPTOR) - field(REQUESTID_DESCRIPTOR) - } - - suspend fun deserialize(deserializer: Deserializer): XmlError? { + fun deserialize(root: XmlTagReader): XmlError = runCatching { var message: String? = null var code: String? = null var requestId: String? = null - return try { - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - MESSAGE_DESCRIPTOR.index -> message = deserializeString() - CODE_DESCRIPTOR.index -> code = deserializeString() - REQUESTID_DESCRIPTOR.index -> requestId = deserializeString() - null -> break@loop - else -> skipValue() - } + val rootTagName = root.tagName + check(rootTagName == "ErrorResponse" || rootTagName == "Error") { + "expected restXml error response with root tag of or " + } + + // wrapped error, unwrap it + var errTag = root + if (root.tagName == "ErrorResponse") { + errTag = root.nextTag() ?: error("expected more tags after ") + } + + if (errTag.tagName == "Error") { + loop@while (true) { + val curr = errTag.nextTag() ?: break@loop + when (curr.tagName) { + "Code" -> code = curr.data() + "Message", "message" -> message = curr.data() + "RequestId" -> requestId = curr.data() } + curr.drop() } + } - XmlError(requestId, code, message) - } catch (e: DeserializationException) { - null // return so an appropriate exception type can be instantiated above here. + // wrapped responses + if (requestId == null) { + loop@while (true) { + val curr = root.nextTag() ?: break@loop + when (curr.tagName) { + "RequestId" -> requestId = curr.data() + } + } } - } + + XmlError(requestId, code, message) + }.getOrDeserializeErr { "Unable to deserialize RestXml error" } } diff --git a/runtime/runtime-core/api/runtime-core.api b/runtime/runtime-core/api/runtime-core.api index c4f8afeef..0e37794c0 100644 --- a/runtime/runtime-core/api/runtime-core.api +++ b/runtime/runtime-core/api/runtime-core.api @@ -112,6 +112,10 @@ public final class aws/smithy/kotlin/runtime/collections/AttributesKt { public static final fun toMutableAttributes (Laws/smithy/kotlin/runtime/collections/Attributes;)Laws/smithy/kotlin/runtime/collections/MutableAttributes; } +public final class aws/smithy/kotlin/runtime/collections/CollectionExtKt { + public static final fun createOrAppend (Ljava/util/List;Ljava/lang/Object;)Ljava/util/List; +} + public abstract interface class aws/smithy/kotlin/runtime/collections/MultiMap : java/util/Map, kotlin/jvm/internal/markers/KMappedMarker { public abstract fun contains (Ljava/lang/Object;Ljava/lang/Object;)Z public abstract fun getEntryValues ()Lkotlin/sequences/Sequence; @@ -2252,6 +2256,10 @@ public abstract interface class aws/smithy/kotlin/runtime/util/PropertyProvider public abstract fun getProperty (Ljava/lang/String;)Ljava/lang/String; } +public final class aws/smithy/kotlin/runtime/util/ResultExtKt { + public static final fun mapErr (Ljava/lang/Object;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object; +} + public final class aws/smithy/kotlin/runtime/util/SingleFlightGroup { public fun ()V public final fun singleFlight (Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; diff --git a/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/collections/CollectionExt.kt b/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/collections/CollectionExt.kt new file mode 100644 index 000000000..c355bb720 --- /dev/null +++ b/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/collections/CollectionExt.kt @@ -0,0 +1,18 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.runtime.collections + +/** + * Creates a new list or appends to an existing one if not null. + * + * If [dest] is null this function creates a new list with element [x] and returns it. + * Otherwise, it appends [x] to [dest] and returns the mutated list. + */ +public fun createOrAppend(dest: List?, x: T): List { + if (dest == null) return listOf(x) + val mut = dest.toMutableList() + mut.add(x) + return mut +} diff --git a/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/util/ResultExt.kt b/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/util/ResultExt.kt new file mode 100644 index 000000000..b2ed33b0d --- /dev/null +++ b/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/util/ResultExt.kt @@ -0,0 +1,18 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.runtime.util + +import aws.smithy.kotlin.runtime.InternalApi + +/** + * Maps the exception to a new error if this instance represents [failure][Result.isFailure], leaving + * a [success][Result.isSuccess] value untouched. + */ +@InternalApi +public inline fun Result.mapErr(onFailure: (Throwable) -> Throwable): Result = + when (val ex = exceptionOrNull()) { + null -> this + else -> Result.failure(onFailure(ex)) + } diff --git a/runtime/serde/api/serde.api b/runtime/serde/api/serde.api index c9d0d90b6..8be9fc0e9 100644 --- a/runtime/serde/api/serde.api +++ b/runtime/serde/api/serde.api @@ -39,6 +39,10 @@ public final class aws/smithy/kotlin/runtime/serde/DeserializerKt { public static final fun deserializeStruct (Laws/smithy/kotlin/runtime/serde/Deserializer;Laws/smithy/kotlin/runtime/serde/SdkObjectDescriptor;Lkotlin/jvm/functions/Function1;)V } +public final class aws/smithy/kotlin/runtime/serde/ExceptionsKt { + public static final fun getOrDeserializeErr (Ljava/lang/Object;Lkotlin/jvm/functions/Function0;)Ljava/lang/Object; +} + public abstract interface class aws/smithy/kotlin/runtime/serde/FieldTrait { } @@ -64,6 +68,31 @@ public abstract interface class aws/smithy/kotlin/runtime/serde/MapSerializer : public abstract fun mapEntry (Ljava/lang/String;Laws/smithy/kotlin/runtime/serde/SdkFieldDescriptor;Lkotlin/jvm/functions/Function1;)V } +public final class aws/smithy/kotlin/runtime/serde/ParsersKt { + public static final fun parse (Ljava/lang/Object;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object; + public static final fun parse (Ljava/lang/String;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object; + public static final fun parseBigDecimal (Ljava/lang/Object;)Ljava/lang/Object; + public static final fun parseBigDecimal (Ljava/lang/String;)Ljava/lang/Object; + public static final fun parseBigInteger (Ljava/lang/Object;)Ljava/lang/Object; + public static final fun parseBigInteger (Ljava/lang/String;)Ljava/lang/Object; + public static final fun parseBoolean (Ljava/lang/Object;)Ljava/lang/Object; + public static final fun parseBoolean (Ljava/lang/String;)Ljava/lang/Object; + public static final fun parseByte (Ljava/lang/Object;)Ljava/lang/Object; + public static final fun parseByte (Ljava/lang/String;)Ljava/lang/Object; + public static final fun parseDouble (Ljava/lang/Object;)Ljava/lang/Object; + public static final fun parseDouble (Ljava/lang/String;)Ljava/lang/Object; + public static final fun parseFloat (Ljava/lang/Object;)Ljava/lang/Object; + public static final fun parseFloat (Ljava/lang/String;)Ljava/lang/Object; + public static final fun parseInt (Ljava/lang/Object;)Ljava/lang/Object; + public static final fun parseInt (Ljava/lang/String;)Ljava/lang/Object; + public static final fun parseLong (Ljava/lang/Object;)Ljava/lang/Object; + public static final fun parseLong (Ljava/lang/String;)Ljava/lang/Object; + public static final fun parseShort (Ljava/lang/Object;)Ljava/lang/Object; + public static final fun parseShort (Ljava/lang/String;)Ljava/lang/Object; + public static final fun parseTimestamp (Ljava/lang/Object;Laws/smithy/kotlin/runtime/time/TimestampFormat;)Ljava/lang/Object; + public static final fun parseTimestamp (Ljava/lang/String;Laws/smithy/kotlin/runtime/time/TimestampFormat;)Ljava/lang/Object; +} + public abstract interface class aws/smithy/kotlin/runtime/serde/PrimitiveDeserializer { public abstract fun deserializeBigDecimal ()Ljava/math/BigDecimal; public abstract fun deserializeBigInteger ()Ljava/math/BigInteger; diff --git a/runtime/serde/common/src/aws/smithy/kotlin/runtime/serde/Exceptions.kt b/runtime/serde/common/src/aws/smithy/kotlin/runtime/serde/Exceptions.kt index e774a49e6..18f908727 100644 --- a/runtime/serde/common/src/aws/smithy/kotlin/runtime/serde/Exceptions.kt +++ b/runtime/serde/common/src/aws/smithy/kotlin/runtime/serde/Exceptions.kt @@ -5,6 +5,8 @@ package aws.smithy.kotlin.runtime.serde import aws.smithy.kotlin.runtime.ClientException +import aws.smithy.kotlin.runtime.InternalApi +import aws.smithy.kotlin.runtime.util.mapErr /** * Exception class for all serialization errors @@ -33,3 +35,12 @@ public class DeserializationException : ClientException { public constructor(cause: Throwable?) : super(cause) } + +/** + * Get the underlying [success][Result.isSuccess] value or wrap the failure in a [DeserializationException] + * and throw it. + */ +@InternalApi +public inline fun Result.getOrDeserializeErr(errorMessage: () -> String): T = + mapErr { DeserializationException(errorMessage(), it) } + .getOrThrow() diff --git a/runtime/serde/common/src/aws/smithy/kotlin/runtime/serde/Parsers.kt b/runtime/serde/common/src/aws/smithy/kotlin/runtime/serde/Parsers.kt new file mode 100644 index 000000000..3991f0f1a --- /dev/null +++ b/runtime/serde/common/src/aws/smithy/kotlin/runtime/serde/Parsers.kt @@ -0,0 +1,87 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.runtime.serde + +import aws.smithy.kotlin.runtime.InternalApi +import aws.smithy.kotlin.runtime.content.BigDecimal +import aws.smithy.kotlin.runtime.content.BigInteger +import aws.smithy.kotlin.runtime.time.Instant +import aws.smithy.kotlin.runtime.time.TimestampFormat + +@InternalApi +public inline fun String.parse(transform: (String) -> T): Result = runCatching { transform(this) } + +@InternalApi +public fun String.parseBoolean(): Result = parse(String::toBoolean) + +@InternalApi +public fun String.parseInt(): Result = parse(String::toInt) + +@InternalApi +public fun String.parseShort(): Result = parse(String::toShort) + +@InternalApi +public fun String.parseLong(): Result = parse(String::toLong) + +@InternalApi +public fun String.parseFloat(): Result = parse(String::toFloat) + +@InternalApi +public fun String.parseDouble(): Result = parse(String::toDouble) + +@InternalApi +public fun String.parseByte(): Result = parse { it.toInt().toByte() } + +@InternalApi +public fun String.parseBigInteger(): Result = parse(::BigInteger) + +@InternalApi +public fun String.parseBigDecimal(): Result = parse(::BigDecimal) + +private fun String.toTimestamp(fmt: TimestampFormat): Instant = when (fmt) { + TimestampFormat.ISO_8601_CONDENSED, + TimestampFormat.ISO_8601_CONDENSED_DATE, + TimestampFormat.ISO_8601, + -> Instant.fromIso8601(this) + + TimestampFormat.RFC_5322 -> Instant.fromRfc5322(this) + TimestampFormat.EPOCH_SECONDS -> Instant.fromEpochSeconds(this) +} + +@InternalApi +public fun String.parseTimestamp(fmt: TimestampFormat): Result = parse { it.toTimestamp(fmt) } + +@InternalApi +public inline fun Result.parse(transform: (String) -> T): Result = mapCatching(transform) + +@InternalApi +public fun Result.parseBoolean(): Result = parse(String::toBoolean) + +@InternalApi +public fun Result.parseInt(): Result = parse(String::toInt) + +@InternalApi +public fun Result.parseShort(): Result = parse(String::toShort) + +@InternalApi +public fun Result.parseLong(): Result = parse(String::toLong) + +@InternalApi +public fun Result.parseFloat(): Result = parse(String::toFloat) + +@InternalApi +public fun Result.parseDouble(): Result = parse(String::toDouble) + +@InternalApi +public fun Result.parseByte(): Result = parse { it.toInt().toByte() } + +@InternalApi +public fun Result.parseBigInteger(): Result = parse(::BigInteger) + +@InternalApi +public fun Result.parseBigDecimal(): Result = parse(::BigDecimal) + +@InternalApi +public fun Result.parseTimestamp(fmt: TimestampFormat): Result = parse { it.toTimestamp(fmt) } diff --git a/runtime/serde/serde-xml/api/serde-xml.api b/runtime/serde/serde-xml/api/serde-xml.api index 846e3d41f..7ed0bcdca 100644 --- a/runtime/serde/serde-xml/api/serde-xml.api +++ b/runtime/serde/serde-xml/api/serde-xml.api @@ -199,6 +199,23 @@ public final class aws/smithy/kotlin/runtime/serde/xml/XmlStreamWriterKt { public static synthetic fun xmlStreamWriter$default (ZILjava/lang/Object;)Laws/smithy/kotlin/runtime/serde/xml/XmlStreamWriter; } +public final class aws/smithy/kotlin/runtime/serde/xml/XmlTagReader { + public fun (Laws/smithy/kotlin/runtime/serde/xml/XmlToken$BeginElement;Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader;)V + public final fun drop ()V + public final fun getTag ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken$BeginElement; + public final fun getTagName ()Ljava/lang/String; + public final fun nextHasValue ()Z + public final fun nextTag ()Laws/smithy/kotlin/runtime/serde/xml/XmlTagReader; + public final fun nextToken ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken; +} + +public final class aws/smithy/kotlin/runtime/serde/xml/XmlTagReaderKt { + public static final fun data (Laws/smithy/kotlin/runtime/serde/xml/XmlTagReader;)Ljava/lang/String; + public static final fun tagReader (Laws/smithy/kotlin/runtime/serde/xml/XmlToken$BeginElement;Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader;)Laws/smithy/kotlin/runtime/serde/xml/XmlTagReader; + public static final fun tryData (Laws/smithy/kotlin/runtime/serde/xml/XmlTagReader;)Ljava/lang/Object; + public static final fun xmlTagReader ([B)Laws/smithy/kotlin/runtime/serde/xml/XmlTagReader; +} + public abstract class aws/smithy/kotlin/runtime/serde/xml/XmlToken { public abstract fun getDepth ()I public fun toString ()Ljava/lang/String; @@ -216,6 +233,7 @@ public final class aws/smithy/kotlin/runtime/serde/xml/XmlToken$BeginElement : a public final fun copy (ILaws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName;Ljava/util/Map;Ljava/util/List;)Laws/smithy/kotlin/runtime/serde/xml/XmlToken$BeginElement; public static synthetic fun copy$default (Laws/smithy/kotlin/runtime/serde/xml/XmlToken$BeginElement;ILaws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName;Ljava/util/Map;Ljava/util/List;ILjava/lang/Object;)Laws/smithy/kotlin/runtime/serde/xml/XmlToken$BeginElement; public fun equals (Ljava/lang/Object;)Z + public final fun getAttr (Ljava/lang/String;)Ljava/lang/String; public final fun getAttributes ()Ljava/util/Map; public fun getDepth ()I public final fun getName ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName; @@ -258,6 +276,7 @@ public final class aws/smithy/kotlin/runtime/serde/xml/XmlToken$Namespace { } public final class aws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName { + public static final field Companion Laws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName$Companion; public fun (Ljava/lang/String;Ljava/lang/String;)V public synthetic fun (Ljava/lang/String;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun component1 ()Ljava/lang/String; @@ -272,6 +291,10 @@ public final class aws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName { public fun toString ()Ljava/lang/String; } +public final class aws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName$Companion { + public final fun from (Ljava/lang/String;)Laws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName; +} + public final class aws/smithy/kotlin/runtime/serde/xml/XmlToken$StartDocument : aws/smithy/kotlin/runtime/serde/xml/XmlToken { public static final field INSTANCE Laws/smithy/kotlin/runtime/serde/xml/XmlToken$StartDocument; public fun getDepth ()I diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializer.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializer.kt index e9b58efe2..e6dd81942 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializer.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializer.kt @@ -31,6 +31,7 @@ internal sealed class FieldLocation { * restXml based services DO NOT always send documents with a root element name that matches the shape ID name * (S3 in particular). This means there is nothing in the model that gives you enough information to validate the tag. */ +@Deprecated("XmlDeserializer is deprecated and will be removed in a future release") @InternalApi public class XmlDeserializer( private val reader: XmlStreamReader, diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlFieldTraits.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlFieldTraits.kt index f14271d22..24c0d1acc 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlFieldTraits.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlFieldTraits.kt @@ -168,7 +168,7 @@ internal fun SdkFieldDescriptor.toQualifiedNames( /** * Determines if the qualified name of this field descriptor matches the given name. */ -internal fun SdkFieldDescriptor.nameMatches(other: String): Boolean = toQualifiedNames().any { it.tag == other } +internal fun SdkFieldDescriptor.nameMatches(other: String): Boolean = toQualifiedNames().any { it.toString() == other } /** * Requires that the given name matches one of this field descriptor's qualified names. diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlPrimitiveDeserializer.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlPrimitiveDeserializer.kt index 1707fdf32..00f124d19 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlPrimitiveDeserializer.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlPrimitiveDeserializer.kt @@ -2,7 +2,6 @@ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ - package aws.smithy.kotlin.runtime.serde.xml import aws.smithy.kotlin.runtime.content.BigDecimal diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlTagReader.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlTagReader.kt new file mode 100644 index 000000000..338e2944f --- /dev/null +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlTagReader.kt @@ -0,0 +1,123 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.runtime.serde.xml + +import aws.smithy.kotlin.runtime.InternalApi +import aws.smithy.kotlin.runtime.serde.DeserializationException + +/** + * An [XmlStreamReader] scoped to reading a single XML element [tag] + * XmlTagReader provides a "tag" scoped view into an XML document. Methods return + * `null` when the current tag has been exhausted. + */ +@InternalApi +public class XmlTagReader( + public val tag: XmlToken.BeginElement, + private val reader: XmlStreamReader, +) { + // last tag we emitted and returned to the caller + private var lastEmitted: XmlTagReader? = null + private var closed = false + + /** + * Get the fully qualified tag name of [tag] + */ + public val tagName: String + get() = tag.name.toString() + + /** + * Return the next actionable token or null if stream is exhausted. + */ + public fun nextToken(): XmlToken? { + if (closed) return null + val peek = reader.peek() + if (peek.terminates(tag)) { + // consume it and close the tag reader + reader.nextToken() + closed = true + return null + } + return reader.nextToken() + } + + /** + * Check if the next token has a value, returns false if [XmlToken.EndElement] + * would be returned. + */ + public fun nextHasValue(): Boolean { + if (closed) return false + return reader.peek() !is XmlToken.EndElement + } + + /** + * Exhaust this [XmlTagReader] to completion. This should always + * be invoked to maintain deserialization state. + */ + public fun drop() { + do { + val tok = nextToken() + } while (tok != null) + } + + /** + * Return an [XmlTagReader] for the next [XmlToken.BeginElement]. The returned reader + * is only valid until [nextTag] is called or [drop] is invoked on it, whichever comes first. + */ + public fun nextTag(): XmlTagReader? { + lastEmitted?.drop() + + var cand = nextToken() + while (cand != null && cand !is XmlToken.BeginElement) { + cand = nextToken() + } + + val nextTok = cand as? XmlToken.BeginElement + + return nextTok?.tagReader(reader).also { newScope -> + lastEmitted = newScope + } + } +} + +/** + * Get a [XmlTagReader] for the root tag. This is the entry point for beginning + * deserialization. + */ +@InternalApi +public fun xmlTagReader(payload: ByteArray): XmlTagReader = + xmlStreamReader(payload).root() + +private fun XmlStreamReader.root(): XmlTagReader { + val start = seek() ?: error("expected start tag: last = $lastToken") + return start.tagReader(this) +} + +/** + * Create a new reader scoped to this element. + */ +@InternalApi +public fun XmlToken.BeginElement.tagReader(reader: XmlStreamReader): XmlTagReader { + val start = reader.lastToken as? XmlToken.BeginElement ?: error("expected start tag found ${reader.lastToken}") + check(name == start.name) { "expected start tag $name but current reader state is on ${start.name}" } + return XmlTagReader(this, reader) +} + +/** + * Unwrap the next token as [XmlToken.Text] and return its value or throw a [DeserializationException] + */ +@InternalApi +public fun XmlTagReader.data(): String = + when (val next = nextToken()) { + is XmlToken.Text -> next.value ?: "" + null, is XmlToken.EndElement -> "" + else -> throw DeserializationException("expected XmlToken.Text element, found $next") + } + +/** + * Attempt to get the text token as [XmlToken.Text] and return a result containing its' value on success + * or the exception thrown on failure. + */ +@InternalApi +public fun XmlTagReader.tryData(): Result = runCatching { data() } diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlToken.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlToken.kt index ca1bfd8c8..8358ab37e 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlToken.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlToken.kt @@ -28,12 +28,29 @@ public sealed class XmlToken { */ @InternalApi public data class QualifiedName(public val local: String, public val prefix: String? = null) { - override fun toString(): String = tag - - public val tag: String get() = when (prefix) { + override fun toString(): String = when (prefix) { null -> local else -> "$prefix:$local" } + + @InternalApi + public companion object { + + /** + * Construct a [QualifiedName] from a raw string representation + */ + public fun from(qualified: String): QualifiedName { + val split = qualified.split(":", limit = 2) + val (local, prefix) = when (split.size == 2) { + true -> split[1] to split[0] + false -> split[0] to null + } + return QualifiedName(local, prefix) + } + } + + val tag: String + get() = toString() } /** @@ -46,13 +63,17 @@ public sealed class XmlToken { public val attributes: Map = emptyMap(), public val nsDeclarations: List = emptyList(), ) : XmlToken() { + // Convenience constructor for name-only nodes. public constructor(depth: Int, name: String) : this(depth, QualifiedName(name)) // Convenience constructor for name-only nodes with attributes. public constructor(depth: Int, name: String, attributes: Map) : this(depth, QualifiedName(name), attributes) - override fun toString(): String = "<${this.name} (${this.depth})>" + override fun toString(): String = "<$name ($depth)>" + + // convenience function for codegen + public fun getAttr(qualified: String): String? = attributes[QualifiedName.from(qualified)] } /** @@ -63,7 +84,7 @@ public sealed class XmlToken { // Convenience constructor for name-only nodes. public constructor(depth: Int, name: String) : this(depth, QualifiedName(name)) - override fun toString(): String = " (${this.depth})" + override fun toString(): String = " ($depth)" } /** @@ -71,7 +92,7 @@ public sealed class XmlToken { */ @InternalApi public data class Text(override val depth: Int, public val value: String?) : XmlToken() { - override fun toString(): String = "${this.value} (${this.depth})" + override fun toString(): String = "$value ($depth)" } @InternalApi @@ -90,9 +111,9 @@ public sealed class XmlToken { } override fun toString(): String = when (this) { - is BeginElement -> "<${this.name}>" - is EndElement -> "" - is Text -> "${this.value}" + is BeginElement -> "<$name>" + is EndElement -> "" + is Text -> "$value" StartDocument -> "[StartDocument]" EndDocument -> "[EndDocument]" } diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/deserialization/LexingXmlStreamReader.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/deserialization/LexingXmlStreamReader.kt index 1c2d648f6..a241d9578 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/deserialization/LexingXmlStreamReader.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/deserialization/LexingXmlStreamReader.kt @@ -42,16 +42,15 @@ public class LexingXmlStreamReader(private val source: XmlLexer) : XmlStreamRead override fun skipNext() { val peekToken = peek(1) ?: return val startDepth = peekToken.depth + scanUntilDepth(startDepth, nextToken()) + } - tailrec fun scanUntilDepth(from: XmlToken?) { - when { - from == null || from is XmlToken.EndDocument -> return // End of document - from is XmlToken.EndElement && from.depth == startDepth -> return // Returned to original start depth - else -> scanUntilDepth(nextToken()) // Keep scannin'! - } + private tailrec fun scanUntilDepth(startDepth: Int, from: XmlToken?) { + when { + from == null || from is XmlToken.EndDocument -> return // End of document + from is XmlToken.EndElement && from.depth == startDepth -> return // Returned to original start depth + else -> scanUntilDepth(startDepth, nextToken()) // Keep scannin'! } - - scanUntilDepth(nextToken()) } override fun subTreeReader(subtreeStartDepth: XmlStreamReader.SubtreeStartDepth): XmlStreamReader = @@ -118,16 +117,15 @@ private class ChildXmlStreamReader( * An empty XML stream reader that trivially returns `null` for all [nextToken] and [peek] invocations. * @param parent The [LexingXmlStreamReader] on which this child reader is based. */ -private class EmptyXmlStreamReader(private val parent: XmlStreamReader) : XmlStreamReader { +private class EmptyXmlStreamReader(private val parent: XmlStreamReader?) : XmlStreamReader { override val lastToken: XmlToken? - get() = parent.lastToken + get() = parent?.lastToken override fun nextToken(): XmlToken? = null override fun peek(index: Int): XmlToken? = null override fun skipNext() = Unit - override fun subTreeReader(subtreeStartDepth: XmlStreamReader.SubtreeStartDepth): XmlStreamReader = this } diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/dom/XmlNode.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/dom/XmlNode.kt index 0d1db20e7..7959ae783 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/dom/XmlNode.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/dom/XmlNode.kt @@ -121,7 +121,7 @@ internal fun formatXmlNode(curr: XmlNode, depth: Int, sb: StringBuilder, pretty: // open tag append("$indent<") - append(curr.name.tag) + append(curr.name.toString()) curr.namespaces.forEach { // namespaces declared by this node append(" xmlns") @@ -134,7 +134,7 @@ internal fun formatXmlNode(curr: XmlNode, depth: Int, sb: StringBuilder, pretty: // attributes if (curr.attributes.isNotEmpty()) append(" ") curr.attributes.forEach { - append("${it.key.tag}=\"${it.value}\"") + append("${it.key}=\"${it.value}\"") } append(">") @@ -155,7 +155,7 @@ internal fun formatXmlNode(curr: XmlNode, depth: Int, sb: StringBuilder, pretty: } append("") if (pretty && depth > 0) appendLine() diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/SharedTestData.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/SharedTestData.kt deleted file mode 100644 index 94a905df3..000000000 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/SharedTestData.kt +++ /dev/null @@ -1,358 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package aws.smithy.kotlin.runtime.serde.xml - -import aws.smithy.kotlin.runtime.serde.* - -class SimpleStructClass { - var x: Int? = null - var y: Int? = null - var z: String? = null - - // Only for testing, not serialization - var unknownFieldCount: Int = 0 - - companion object { - val X_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("x")) - val Y_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("y")) - val Z_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("z"), XmlAttribute) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("payload")) - field(X_DESCRIPTOR) - field(Y_DESCRIPTOR) - field(Z_DESCRIPTOR) - } - - fun deserialize(deserializer: Deserializer): SimpleStructClass { - val result = SimpleStructClass() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - X_DESCRIPTOR.index -> result.x = deserializeInt() - Y_DESCRIPTOR.index -> result.y = deserializeInt() - Z_DESCRIPTOR.index -> result.z = deserializeString() - null -> break@loop - else -> throw DeserializationException(IllegalStateException("unexpected field in BasicStructTest deserializer")) - } - } - } - return result - } - } -} - -class SimpleStructOfStringsClass { - var x: String? = null - var y: String? = null - - // Only for testing, not serialization - var unknownFieldCount: Int = 0 - - companion object { - val X_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("x")) - val Y_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("y")) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("payload")) - field(X_DESCRIPTOR) - field(Y_DESCRIPTOR) - } - - fun deserialize(deserializer: Deserializer): SimpleStructOfStringsClass { - val result = SimpleStructOfStringsClass() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - X_DESCRIPTOR.index -> result.x = deserializeString() - Y_DESCRIPTOR.index -> result.y = deserializeString() - null -> break@loop - else -> throw DeserializationException(IllegalStateException("unexpected field in BasicStructTest deserializer")) - } - } - } - return result - } - } -} - -class StructWithAttribsClass { - var x: Int? = null - var y: Int? = null - var unknownFieldCount: Int = 0 - - companion object { - val X_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("x"), XmlAttribute) - val Y_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("y"), XmlAttribute) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("payload")) - field(X_DESCRIPTOR) - field(Y_DESCRIPTOR) - } - - fun deserialize(deserializer: Deserializer): StructWithAttribsClass { - val result = StructWithAttribsClass() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - X_DESCRIPTOR.index -> result.x = deserializeInt() - Y_DESCRIPTOR.index -> result.y = deserializeInt() - null -> break@loop - Deserializer.FieldIterator.UNKNOWN_FIELD -> { - result.unknownFieldCount++ - skipValue() - } - else -> throw DeserializationException(IllegalStateException("unexpected field in BasicStructTest deserializer")) - } - } - } - return result - } - } -} - -class StructWithMultiAttribsAndTextValClass { - var x: Int? = null - var y: Int? = null - var txt: String? = null - var unknownFieldCount: Int = 0 - - companion object { - val X_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("xval"), XmlAttribute) - val Y_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("yval"), XmlAttribute) - val TXT_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("x")) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("payload")) - field(TXT_DESCRIPTOR) - field(X_DESCRIPTOR) - field(Y_DESCRIPTOR) - } - - fun deserialize(deserializer: Deserializer): StructWithMultiAttribsAndTextValClass { - val result = StructWithMultiAttribsAndTextValClass() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - X_DESCRIPTOR.index -> result.x = deserializeInt() - Y_DESCRIPTOR.index -> result.y = deserializeInt() - TXT_DESCRIPTOR.index -> result.txt = deserializeString() - null -> break@loop - Deserializer.FieldIterator.UNKNOWN_FIELD -> { - result.unknownFieldCount++ - skipValue() - } - else -> throw DeserializationException(IllegalStateException("unexpected field in BasicStructTest deserializer")) - } - } - } - return result - } - } -} - -class RecursiveShapesInputOutput private constructor(builder: Builder) { - val nested: RecursiveShapesInputOutputNested1? = builder.nested - - companion object { - operator fun invoke(block: Builder.() -> kotlin.Unit): RecursiveShapesInputOutput = Builder().apply(block).build() - } - - override fun toString(): kotlin.String = buildString { - append("RecursiveShapesInputOutput(") - append("nested=$nested)") - } - - override fun hashCode(): kotlin.Int { - var result = nested?.hashCode() ?: 0 - return result - } - - override fun equals(other: kotlin.Any?): kotlin.Boolean { - if (this === other) return true - - other as RecursiveShapesInputOutput - - if (nested != other.nested) return false - - return true - } - - fun copy(block: Builder.() -> kotlin.Unit = {}): RecursiveShapesInputOutput = Builder(this).apply(block).build() - - public class Builder() { - var nested: RecursiveShapesInputOutputNested1? = null - - constructor(x: RecursiveShapesInputOutput) : this() { - this.nested = x.nested - } - - fun build(): RecursiveShapesInputOutput = RecursiveShapesInputOutput(this) - } -} - -class RecursiveShapesInputOutputNested1 private constructor(builder: Builder) { - val foo: String? = builder.foo - val nested: RecursiveShapesInputOutputNested2? = builder.nested - - companion object { - fun dslBuilder(): Builder = Builder() - - operator fun invoke(block: Builder.() -> kotlin.Unit): RecursiveShapesInputOutputNested1 = Builder().apply(block).build() - } - - override fun toString(): kotlin.String = buildString { - append("RecursiveShapesInputOutputNested1(") - append("foo=$foo,") - append("nested=$nested)") - } - - override fun hashCode(): kotlin.Int { - var result = foo?.hashCode() ?: 0 - result = 31 * result + (nested?.hashCode() ?: 0) - return result - } - - override fun equals(other: kotlin.Any?): kotlin.Boolean { - if (this === other) return true - - other as RecursiveShapesInputOutputNested1 - - if (foo != other.foo) return false - if (nested != other.nested) return false - - return true - } - - fun copy(block: Builder.() -> kotlin.Unit = {}): RecursiveShapesInputOutputNested1 = Builder(this).apply(block).build() - - public class Builder() { - var foo: String? = null - var nested: RecursiveShapesInputOutputNested2? = null - - constructor(x: RecursiveShapesInputOutputNested1) : this() { - this.foo = x.foo - this.nested = x.nested - } - - fun build(): RecursiveShapesInputOutputNested1 = RecursiveShapesInputOutputNested1(this) - } -} - -class RecursiveShapesInputOutputNested2 private constructor(builder: Builder) { - val bar: String? = builder.bar - val recursiveMember: RecursiveShapesInputOutputNested1? = builder.recursiveMember - - companion object { - fun dslBuilder(): Builder = Builder() - - operator fun invoke(block: Builder.() -> kotlin.Unit): RecursiveShapesInputOutputNested2 = Builder().apply(block).build() - } - - override fun toString(): kotlin.String = buildString { - append("RecursiveShapesInputOutputNested2(") - append("bar=$bar,") - append("recursiveMember=$recursiveMember)") - } - - override fun hashCode(): kotlin.Int { - var result = bar?.hashCode() ?: 0 - result = 31 * result + (recursiveMember?.hashCode() ?: 0) - return result - } - - override fun equals(other: kotlin.Any?): kotlin.Boolean { - if (this === other) return true - - other as RecursiveShapesInputOutputNested2 - - if (bar != other.bar) return false - if (recursiveMember != other.recursiveMember) return false - - return true - } - - fun copy(block: Builder.() -> kotlin.Unit = {}): RecursiveShapesInputOutputNested2 = Builder(this).apply(block).build() - - public class Builder() { - var bar: String? = null - var recursiveMember: RecursiveShapesInputOutputNested1? = null - - constructor(x: RecursiveShapesInputOutputNested2) : this() { - this.bar = x.bar - this.recursiveMember = x.recursiveMember - } - - fun build(): RecursiveShapesInputOutputNested2 = RecursiveShapesInputOutputNested2(this) - } -} - -/* - @xmlNamespace(uri: "http://foo.com") - structure XmlNamespacesInputOutput { - nested: XmlNamespaceNested - } - - // Ignored since it's not at the top-level - @xmlNamespace(uri: "http://foo.com") - structure XmlNamespaceNested { - @xmlNamespace(uri: "http://baz.com", prefix: "baz") - foo: String, - - @xmlNamespace(uri: "http://qux.com") - values: XmlNamespacedList - } - - list XmlNamespacedList { - @xmlNamespace(uri: "http://bux.com") - member: String, - } -*/ -class XmlNamespacesRequest(val nested: XmlNamespaceNested?) { - companion object { - private val NESTED_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("nested")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("XmlNamespacesInputOutput")) - trait(XmlNamespace("http://foo.com")) - field(NESTED_DESCRIPTOR) - } - } - - fun serialize(serializer: Serializer) { - serializer.serializeStruct(OBJ_DESCRIPTOR) { - nested?.let { field(NESTED_DESCRIPTOR, XmlNamespaceNestedDocumentSerializer(it)) } - } - } -} - -data class XmlNamespaceNested( - val foo: String? = null, - val values: List? = null, -) - -internal class XmlNamespaceNestedDocumentSerializer(val input: XmlNamespaceNested) : SdkSerializable { - - companion object { - private val FOO_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("foo"), XmlNamespace("http://baz.com", "baz")) - private val VALUES_DESCRIPTOR = SdkFieldDescriptor(SerialKind.List, XmlSerialName("values"), XmlNamespace("http://qux.com"), XmlCollectionValueNamespace("http://bux.com")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("XmlNamespaceNested")) - trait(XmlNamespace("http://foo.com")) - field(FOO_DESCRIPTOR) - field(VALUES_DESCRIPTOR) - } - } - - override fun serialize(serializer: Serializer) { - serializer.serializeStruct(OBJ_DESCRIPTOR) { - input.foo?.let { field(FOO_DESCRIPTOR, it) } - if (input.values != null) { - listField(VALUES_DESCRIPTOR) { - for (el0 in input.values) { - serializeString(el0) - } - } - } - } - } -} diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerAWSTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerAWSTest.kt deleted file mode 100644 index bf7b1ee36..000000000 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerAWSTest.kt +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package aws.smithy.kotlin.runtime.serde.xml - -import aws.smithy.kotlin.runtime.serde.* -import kotlin.test.Test -import kotlin.test.assertNotNull -import kotlin.test.assertTrue - -class XmlDeserializerAWSTest { - - class HostedZoneConfig private constructor(builder: Builder) { - val comment: String? = builder.comment - - companion object { - val COMMENT_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("Comment")) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("HostedZoneConfig")) - trait(XmlNamespace("https://route53.amazonaws.com/doc/2013-04-01/")) - field(COMMENT_DESCRIPTOR) - } - - fun deserialize(deserializer: Deserializer): HostedZoneConfig { - val builder = Builder() - - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - COMMENT_DESCRIPTOR.index -> builder.comment = deserializeString() - null -> break@loop - Deserializer.FieldIterator.UNKNOWN_FIELD -> { - } - else -> throw DeserializationException(IllegalStateException("unexpected field index in HostedZoneConfig deserializer")) - } - } - } - return HostedZoneConfig(builder) - } - - operator fun invoke(block: Builder.() -> Unit) = Builder().apply(block).build() - } - - public class Builder { - var comment: String? = null - - fun build(): HostedZoneConfig = HostedZoneConfig(this) - } - } - - class CreateHostedZoneRequest private constructor(builder: Builder) { - val name: String? = builder.name - val callerReference: String? = builder.callerReference - val hostedZoneConfig: HostedZoneConfig? = builder.hostedZoneConfig - - companion object { - val NAME_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("Name")) - val CALLER_REFERENCE_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("CallerReference")) - val HOSTED_ZONE_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("HostedZoneConfig")) - - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("CreateHostedZoneRequest")) - trait(XmlNamespace("https://route53.amazonaws.com/doc/2013-04-01/")) - field(NAME_DESCRIPTOR) - field(CALLER_REFERENCE_DESCRIPTOR) - field(HOSTED_ZONE_DESCRIPTOR) - } - - fun deserialize(deserializer: Deserializer): CreateHostedZoneRequest { - val builder = Builder() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - NAME_DESCRIPTOR.index -> builder.name = deserializeString() - CALLER_REFERENCE_DESCRIPTOR.index -> builder.callerReference = deserializeString() - HOSTED_ZONE_DESCRIPTOR.index -> - builder.hostedZoneConfig = HostedZoneConfig.deserialize(deserializer) - null -> break@loop - Deserializer.FieldIterator.UNKNOWN_FIELD -> skipValue() - else -> throw DeserializationException(IllegalStateException("unexpected field index in CreateHostedZoneRequest deserializer")) - } - } - } - return builder.build() - } - - operator fun invoke(block: Builder.() -> Unit) = Builder().apply(block).build() - } - - public class Builder { - var name: String? = null - var callerReference: String? = null - var hostedZoneConfig: HostedZoneConfig? = null - - fun build(): CreateHostedZoneRequest = CreateHostedZoneRequest(this) - } - } - - @Test - fun itHandlesRoute53XML() { - val testXml = """ - - - - java.sdk.com. - a322f752-8156-4746-8c04-e174ca1f51ce - - comment - - - """.trimIndent() - - val unit = XmlDeserializer(testXml.encodeToByteArray()) - - val createHostedZoneRequest = CreateHostedZoneRequest.deserialize(unit) - - assertTrue(createHostedZoneRequest.name == "java.sdk.com.") - assertTrue(createHostedZoneRequest.callerReference == "a322f752-8156-4746-8c04-e174ca1f51ce") - assertNotNull(createHostedZoneRequest.hostedZoneConfig) - assertTrue(createHostedZoneRequest.hostedZoneConfig.comment == "comment") - } -} diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerListTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerListTest.kt deleted file mode 100644 index 2ba7052e6..000000000 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerListTest.kt +++ /dev/null @@ -1,712 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package aws.smithy.kotlin.runtime.serde.xml - -import aws.smithy.kotlin.runtime.serde.* -import io.kotest.matchers.collections.shouldContainExactly -import kotlin.test.Test -import kotlin.test.assertEquals -import kotlin.test.assertNotNull -import kotlin.test.assertTrue - -class XmlDeserializerListTest { - - class ListDeserializer private constructor(builder: BuilderImpl) { - val list: List? = builder.list - - companion object { - operator fun invoke(block: DslBuilder.() -> Unit) = BuilderImpl().apply(block).build() - fun dslBuilder(): DslBuilder = BuilderImpl() - - fun deserialize( - deserializer: Deserializer, - OBJ_DESCRIPTOR: SdkObjectDescriptor, - ELEMENT_LIST_FIELD_DESCRIPTOR: SdkFieldDescriptor, - ): ListDeserializer { - val builder = dslBuilder() - - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - ELEMENT_LIST_FIELD_DESCRIPTOR.index -> - builder.list = - deserializer.deserializeList(ELEMENT_LIST_FIELD_DESCRIPTOR) { - val list = mutableListOf() - while (hasNextElement()) { - list.add(deserializeInt()) - } - return@deserializeList list - } - null -> break@loop - else -> skipValue() - } - } - } - - return builder.build() - } - } - - interface Builder { - fun build(): ListDeserializer - // TODO - Java fill in Java builder - } - - interface DslBuilder { - var list: List? - - fun build(): ListDeserializer - } - - private class BuilderImpl : Builder, DslBuilder { - override var list: List? = null - - override fun build(): ListDeserializer = ListDeserializer(this) - } - } - - @Test - fun itHandlesListSingleElement() { - val payload = """ - - - 1 - - - """.encodeToByteArray() - val ELEMENT_LIST_FIELD_DESCRIPTOR = SdkFieldDescriptor(SerialKind.List, XmlSerialName("list")) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(ELEMENT_LIST_FIELD_DESCRIPTOR) - } - - val deserializer = XmlDeserializer(payload) - val actual = ListDeserializer.deserialize(deserializer, OBJ_DESCRIPTOR, ELEMENT_LIST_FIELD_DESCRIPTOR).list - val expected = listOf(1) - - actual.shouldContainExactly(expected) - } - - @Test - fun itHandlesListMultipleElementsAndCustomMemberName() { - val payload = """ - - - 1 - 2 - 3 - - - """.encodeToByteArray() - val ELEMENT_LIST_FIELD_DESCRIPTOR = SdkFieldDescriptor(SerialKind.List, XmlSerialName("list"), XmlCollectionName("element")) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(ELEMENT_LIST_FIELD_DESCRIPTOR) - } - - val deserializer = XmlDeserializer(payload) - val actual = ListDeserializer.deserialize(deserializer, OBJ_DESCRIPTOR, ELEMENT_LIST_FIELD_DESCRIPTOR).list - val expected = listOf(1, 2, 3) - - actual.shouldContainExactly(expected) - } - - class SparseListDeserializer private constructor(builder: BuilderImpl) { - val list: List? = builder.list - - companion object { - operator fun invoke(block: DslBuilder.() -> Unit) = BuilderImpl().apply(block).build() - fun dslBuilder(): DslBuilder = BuilderImpl() - - fun deserialize( - deserializer: Deserializer, - OBJ_DESCRIPTOR: SdkObjectDescriptor, - ELEMENT_LIST_FIELD_DESCRIPTOR: SdkFieldDescriptor, - ): SparseListDeserializer { - val builder = dslBuilder() - - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - ELEMENT_LIST_FIELD_DESCRIPTOR.index -> - builder.list = - deserializer.deserializeList(ELEMENT_LIST_FIELD_DESCRIPTOR) { - val col0 = mutableListOf() - while (hasNextElement()) { - val el0 = if (nextHasValue()) { - deserializeInt() - } else { - deserializeNull() - } - col0.add(el0) - } - col0 - } - null -> break@loop - else -> skipValue() - } - } - } - - return builder.build() - } - } - - interface Builder { - fun build(): SparseListDeserializer - // TODO - Java fill in Java builder - } - - interface DslBuilder { - var list: List? - - fun build(): SparseListDeserializer - } - - private class BuilderImpl : Builder, DslBuilder { - override var list: List? = null - - override fun build(): SparseListDeserializer = SparseListDeserializer(this) - } - } - - @Test - fun itHandlesSparseLists() { - val payload = """ - - - 1 - - 3 - - - """.encodeToByteArray() - val ELEMENT_LIST_FIELD_DESCRIPTOR = SdkFieldDescriptor(SerialKind.List, XmlSerialName("list"), SparseValues) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(ELEMENT_LIST_FIELD_DESCRIPTOR) - } - - val deserializer = XmlDeserializer(payload) - val actual = - SparseListDeserializer.deserialize(deserializer, OBJ_DESCRIPTOR, ELEMENT_LIST_FIELD_DESCRIPTOR).list - val expected = listOf(1, null, 3) - - actual.shouldContainExactly(expected) - } - - @Test - fun itHandlesEmptyLists() { - val payload = """ - - - - - """.encodeToByteArray() - val ELEMENT_LIST_FIELD_DESCRIPTOR = SdkFieldDescriptor(SerialKind.List, XmlSerialName("list")) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(ELEMENT_LIST_FIELD_DESCRIPTOR) - } - - val deserializer = XmlDeserializer(payload) - val actual = ListDeserializer.deserialize(deserializer, OBJ_DESCRIPTOR, ELEMENT_LIST_FIELD_DESCRIPTOR).list - val expected = emptyList() - - assertEquals(expected, actual) - } - - @Test - fun itHandlesFlatLists() { - val payload = """ - - 1 - 2 - 3 - - """.encodeToByteArray() - val elementFieldDescriptor = SdkFieldDescriptor(SerialKind.List, XmlSerialName("element"), Flattened) - val objectDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(elementFieldDescriptor) - } - val deserializer = XmlDeserializer(payload) - val actual = ListDeserializer.deserialize(deserializer, objectDescriptor, elementFieldDescriptor).list - val expected = listOf(1, 2, 3) - - actual.shouldContainExactly(expected) - } - - @Test - fun itHandlesListOfObjectsWithMissingFields() { - val payload = """ - - - - a - b - - - - d - - - - """.encodeToByteArray() - val listWrapperFieldDescriptor = - SdkFieldDescriptor(SerialKind.List, XmlSerialName("list"), XmlCollectionName(element = "payload")) - val objectDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(listWrapperFieldDescriptor) - } - - val deserializer = XmlDeserializer(payload) - var actual: MutableList? = null - deserializer.deserializeStruct(objectDescriptor) { - loop@ while (true) { - when (findNextFieldIndex()) { - listWrapperFieldDescriptor.index -> - actual = - deserializer.deserializeList(listWrapperFieldDescriptor) { - val list = mutableListOf() - while (hasNextElement()) { - list.add(SimpleStructOfStringsClass.deserialize(deserializer)) - } - return@deserializeList list - } - null -> break@loop - else -> skipValue() - } - } - } - - assertEquals(2, actual!!.size) - - assertEquals("a", actual!![0].x) - assertEquals("b", actual!![0].y) - assertEquals("", actual!![1].x) - assertEquals("d", actual!![1].y) - } - - @Test - fun itHandlesListOfObjectsWithEmptyValues() { - val payload = """ - - - - 1 - 2 - - - - - """.encodeToByteArray() - val listWrapperFieldDescriptor = - SdkFieldDescriptor(SerialKind.List, XmlSerialName("list"), XmlCollectionName(element = "payload")) - val objectDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(listWrapperFieldDescriptor) - } - - val deserializer = XmlDeserializer(payload) - var actual: MutableList? = null - deserializer.deserializeStruct(objectDescriptor) { - loop@ while (true) { - when (findNextFieldIndex()) { - listWrapperFieldDescriptor.index -> - actual = - deserializer.deserializeList(listWrapperFieldDescriptor) { - val list = mutableListOf() - while (hasNextElement()) { - list.add(SimpleStructClass.deserialize(deserializer)) - } - return@deserializeList list - } - null -> break@loop - else -> skipValue() - } - } - } - assertEquals(2, actual!!.size) - assertEquals(1, actual!![0].x) - assertEquals(2, actual!![0].y) - assertEquals(null, actual!![1].x) - assertEquals(null, actual!![1].y) - } - - @Test - fun itHandlesNestedLists() { - val payload = """ - - - - - a - 3 - - - a - 3 - - - - - b - 4 - - - c - 5 - - - - - d - 8 - - - e - 9 - - - - - """.encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val actual = NestedListOperationOperationDeserializer().deserialize(deserializer) - - assertTrue(actual.parentList?.size == 3) - } - - @Test - fun itHandlesListsOfStructs() { - val payload = """ - - - - a - 3 - - - b - 4 - - - c - 6 - - - - """.encodeToByteArray() - - val listDescriptor = SdkFieldDescriptor(SerialKind.List, XmlSerialName("parentList")) - val deserializer = XmlDeserializer(payload) - val actual = FooOperationDeserializer().deserialize(deserializer, listDescriptor) - - assertTrue(actual.parentList?.size == 3) - } - - @Test - fun itHandlesFlatListsOfStructs() { - val payload = """ - - - a - 3 - - - b - 4 - - - c - 6 - - - """.encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val listDescriptor = SdkFieldDescriptor(SerialKind.List, XmlSerialName("flatList"), Flattened) - val actual = FooOperationDeserializer().deserialize(deserializer, listDescriptor) - - val parentList = assertNotNull(actual.parentList) - assertEquals(3, parentList.size) - assertEquals(parentList[0].fooMember, "a") - assertEquals(parentList[0].someInt, 3) - assertEquals(parentList[2].fooMember, "c") - assertEquals(parentList[2].someInt, 6) - } -} - -internal class FooOperationDeserializer { - - fun deserialize( - deserializer: Deserializer, - LIST_DESCRIPTOR: SdkFieldDescriptor, - ): FooResponse { - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("FooResponse")) - field(LIST_DESCRIPTOR) - } - - val builder = FooResponse.Builder() - - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@while (true) { - when (findNextFieldIndex()) { - LIST_DESCRIPTOR.index -> - builder.parentList = - deserializer.deserializeList(LIST_DESCRIPTOR) { - val col0 = mutableListOf() - while (hasNextElement()) { - val el0 = if (nextHasValue()) { PayloadStructDocumentDeserializer().deserialize(deserializer) } else { deserializeNull(); continue } - col0.add(el0) - } - col0 - } - null -> break@loop - else -> skipValue() - } - } - } - - return builder.build() - } -} - -internal class PayloadStructDocumentDeserializer { - - companion object { - private val FOOMEMBER_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("fooMember")) - private val SOMEINT_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("someInt")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - field(FOOMEMBER_DESCRIPTOR) - field(SOMEINT_DESCRIPTOR) - } - } - - fun deserialize(deserializer: Deserializer): PayloadStruct { - val builder = PayloadStruct.dslBuilder() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@while (true) { - when (findNextFieldIndex()) { - FOOMEMBER_DESCRIPTOR.index -> builder.fooMember = deserializeString() - SOMEINT_DESCRIPTOR.index -> builder.someInt = deserializeInt() - null -> break@loop - else -> skipValue() - } - } - } - return builder.build() - } -} - -class FooResponse private constructor(builder: Builder) { - val parentList: List? = builder.parentList - - companion object { - fun builder(): Builder = Builder() - - operator fun invoke(block: Builder.() -> kotlin.Unit): FooResponse = Builder().apply(block).build() - } - - override fun toString(): kotlin.String = buildString { - append("FooResponse(") - append("parentList=$parentList)") - } - - override fun hashCode(): kotlin.Int { - var result = parentList?.hashCode() ?: 0 - return result - } - - override fun equals(other: kotlin.Any?): kotlin.Boolean { - if (this === other) return true - - other as FooResponse - - if (parentList != other.parentList) return false - - return true - } - - fun copy(block: Builder.() -> kotlin.Unit = {}): FooResponse = Builder(this).apply(block).build() - - public class Builder() { - var parentList: List? = null - - constructor(x: FooResponse) : this() { - this.parentList = x.parentList - } - - fun build(): FooResponse = FooResponse(this) - } -} - -class PayloadStruct private constructor(builder: BuilderImpl) { - val fooMember: String? = builder.fooMember - val someInt: Int? = builder.someInt - - companion object { - fun builder(): Builder = BuilderImpl() - - fun dslBuilder(): DslBuilder = BuilderImpl() - - operator fun invoke(block: DslBuilder.() -> kotlin.Unit): PayloadStruct = BuilderImpl().apply(block).build() - } - - override fun toString(): kotlin.String = buildString { - append("PayloadStruct(") - append("fooMember=$fooMember,") - append("someInt=$someInt)") - } - - override fun hashCode(): kotlin.Int { - var result = fooMember?.hashCode() ?: 0 - result = 31 * result + (someInt ?: 0) - return result - } - - override fun equals(other: kotlin.Any?): kotlin.Boolean { - if (this === other) return true - - other as PayloadStruct - - if (fooMember != other.fooMember) return false - if (someInt != other.someInt) return false - - return true - } - - fun copy(block: DslBuilder.() -> kotlin.Unit = {}): PayloadStruct = BuilderImpl(this).apply(block).build() - - interface Builder { - fun build(): PayloadStruct - fun fooMember(fooMember: String): Builder - fun someInt(someInt: Int): Builder - } - - interface DslBuilder { - var fooMember: String? - var someInt: Int? - - fun build(): PayloadStruct - } - - private class BuilderImpl() : Builder, DslBuilder { - override var fooMember: String? = null - override var someInt: Int? = null - - constructor(x: PayloadStruct) : this() { - this.fooMember = x.fooMember - this.someInt = x.someInt - } - - override fun build(): PayloadStruct = PayloadStruct(this) - override fun fooMember(fooMember: String): Builder = apply { this.fooMember = fooMember } - override fun someInt(someInt: Int): Builder = apply { this.someInt = someInt } - } -} - -class NestedListResponse private constructor(builder: BuilderImpl) { - val parentList: List>? = builder.parentList - - companion object { - fun builder(): Builder = BuilderImpl() - - fun dslBuilder(): DslBuilder = BuilderImpl() - - operator fun invoke(block: DslBuilder.() -> kotlin.Unit): NestedListResponse = BuilderImpl().apply(block).build() - } - - override fun toString(): kotlin.String = buildString { - append("NestedListResponse(") - append("parentList=$parentList)") - } - - override fun hashCode(): kotlin.Int { - var result = parentList?.hashCode() ?: 0 - return result - } - - override fun equals(other: kotlin.Any?): kotlin.Boolean { - if (this === other) return true - - other as NestedListResponse - - if (parentList != other.parentList) return false - - return true - } - - fun copy(block: DslBuilder.() -> kotlin.Unit = {}): NestedListResponse = BuilderImpl(this).apply(block).build() - - interface Builder { - fun build(): NestedListResponse - fun parentList(parentList: List>): Builder - } - - interface DslBuilder { - var parentList: List>? - - fun build(): NestedListResponse - } - - private class BuilderImpl() : Builder, DslBuilder { - override var parentList: List>? = null - - constructor(x: NestedListResponse) : this() { - this.parentList = x.parentList - } - - override fun build(): NestedListResponse = NestedListResponse(this) - override fun parentList(parentList: List>): Builder = apply { this.parentList = parentList } - } -} - -internal class NestedListOperationOperationDeserializer { - - companion object { - private val PARENTLIST_DESCRIPTOR = SdkFieldDescriptor(SerialKind.List, XmlSerialName("parentList")) - private val PARENTLIST_C0_DESCRIPTOR = SdkFieldDescriptor(SerialKind.List, XmlSerialName("parentListC0")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("NestedListResponse")) - field(PARENTLIST_DESCRIPTOR) - } - } - - fun deserialize(deserializer: Deserializer): NestedListResponse { - val builder = NestedListResponse.dslBuilder() - - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@while (true) { - when (findNextFieldIndex()) { - PARENTLIST_DESCRIPTOR.index -> - builder.parentList = - deserializer.deserializeList(PARENTLIST_DESCRIPTOR) { - val col0 = mutableListOf>() - while (hasNextElement()) { - val el0 = deserializer.deserializeList(PARENTLIST_C0_DESCRIPTOR) { - val col1 = mutableListOf() - while (hasNextElement()) { - val el1 = if (nextHasValue()) { PayloadStructDocumentDeserializer().deserialize(deserializer) } else { deserializeNull(); continue } - col1.add(el1) - } - col1 - } - col0.add(el0) - } - col0 - } - null -> break@loop - else -> skipValue() - } - } - } - - return builder.build() - } -} diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerMapTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerMapTest.kt deleted file mode 100644 index d93a75b41..000000000 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerMapTest.kt +++ /dev/null @@ -1,795 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package aws.smithy.kotlin.runtime.serde.xml - -import aws.smithy.kotlin.runtime.serde.* -import io.kotest.matchers.maps.shouldContainExactly -import kotlin.test.Test - -class XmlDeserializerMapTest { - - @Test - fun itHandlesMapsWithDefaultNodeNames() { - val payload = """ - - - - key1 - 1 - - - key2 - 2 - - - - """.encodeToByteArray() - val fieldDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("values")) - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(fieldDescriptor) - } - - var actual = mutableMapOf() - val deserializer = XmlDeserializer(payload) - deserializer.deserializeStruct(objDescriptor) { - loop@while (true) { - when (findNextFieldIndex()) { - fieldDescriptor.index -> - actual = - deserializer.deserializeMap(fieldDescriptor) { - val map0 = mutableMapOf() - while (hasNextEntry()) { - val k0 = key() - val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue } - map0[k0] = v0 - } - map0 - } - null -> break@loop - else -> skipValue() - } - } - } - val expected = mapOf("key1" to 1, "key2" to 2) - actual.shouldContainExactly(expected) - } - - @Test - fun itHandlesMapsWithCustomNodeNames() { - val payload = """ - - - - key1 - 1 - - - key2 - 2 - - - - """.encodeToByteArray() - val fieldDescriptor = - SdkFieldDescriptor(SerialKind.Map, XmlSerialName("mymap"), XmlMapName("myentry", "mykey", "myvalue")) - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(fieldDescriptor) - } - var actual = mutableMapOf() - val deserializer = XmlDeserializer(payload) - deserializer.deserializeStruct(objDescriptor) { - loop@while (true) { - when (findNextFieldIndex()) { - fieldDescriptor.index -> - actual = - deserializer.deserializeMap(fieldDescriptor) { - val map0 = mutableMapOf() - while (hasNextEntry()) { - val k0 = key() - val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue } - map0[k0] = v0 - } - map0 - } - null -> break@loop - else -> skipValue() - } - } - } - val expected = mapOf("key1" to 1, "key2" to 2) - actual.shouldContainExactly(expected) - } - - // https://awslabs.github.io/smithy/1.0/spec/core/xml-traits.html#flattened-map-serialization - @Test - fun itHandlesFlatMaps() { - val payload = """ - - - key1 - 1 - - - key2 - 2 - - - key3 - 3 - - - """.encodeToByteArray() - val containerFieldDescriptor = - SdkFieldDescriptor(SerialKind.Map, XmlSerialName("flatMap"), XmlMapName(null, "key", "value"), Flattened) - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(containerFieldDescriptor) - } - var actual = mutableMapOf() - val deserializer = XmlDeserializer(payload) - deserializer.deserializeStruct(objDescriptor) { - loop@while (true) { - when (findNextFieldIndex()) { - containerFieldDescriptor.index -> - actual = - deserializer.deserializeMap(containerFieldDescriptor) { - val map0 = mutableMapOf() - while (hasNextEntry()) { - val k0 = key() - val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue } - map0[k0] = v0 - } - map0 - } - null -> break@loop - else -> skipValue() - } - } - } - val expected = mapOf("key1" to 1, "key2" to 2, "key3" to 3) - actual.shouldContainExactly(expected) - } - - @Test - fun itHandlesEmptyMaps() { - val payload = """ - - - - """.encodeToByteArray() - val containerFieldDescriptor = - SdkFieldDescriptor(SerialKind.Map, XmlSerialName("Map")) - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(containerFieldDescriptor) - } - - val deserializer = XmlDeserializer(payload) - var actual = mutableMapOf() - deserializer.deserializeStruct(objDescriptor) { - loop@while (true) { - when (findNextFieldIndex()) { - containerFieldDescriptor.index -> - actual = - deserializer.deserializeMap(containerFieldDescriptor) { - val map0 = mutableMapOf() - while (hasNextEntry()) { - val k0 = key() - val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue } - map0[k0] = v0 - } - map0 - } - null -> break@loop - else -> skipValue() - } - } - } - - val expected = emptyMap() - actual.shouldContainExactly(expected) - } - - @Test - fun itHandlesSparseMaps() { - val payload = """ - - - - key1 - 1 - - - key2 - - - - - """.encodeToByteArray() - val fieldDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("values")) - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(fieldDescriptor) - } - - val deserializer = XmlDeserializer(payload) - var actual = mutableMapOf() - deserializer.deserializeStruct(objDescriptor) { - loop@while (true) { - when (findNextFieldIndex()) { - fieldDescriptor.index -> - actual = - deserializer.deserializeMap(fieldDescriptor) { - val map = mutableMapOf() - while (hasNextEntry()) { - val key = key() - val value = when (nextHasValue()) { - true -> deserializeInt() - false -> deserializeNull() - } - - map[key] = value - } - return@deserializeMap map - } - null -> break@loop - else -> skipValue() - } - } - } - - val expected = mapOf("key1" to 1, "key2" to null) - actual.shouldContainExactly(expected) - } - - @Test - fun itHandlesCheckingMapValuesForNull() { - val payload = """ - - - - key1 - 1 - - - key2 - - - - - """.encodeToByteArray() - val fieldDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("values")) - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(fieldDescriptor) - } - - val deserializer = XmlDeserializer(payload) - var actual = mutableMapOf() - deserializer.deserializeStruct(objDescriptor) { - loop@while (true) { - when (findNextFieldIndex()) { - fieldDescriptor.index -> - actual = - deserializer.deserializeMap(fieldDescriptor) { - val map0 = mutableMapOf() - while (hasNextEntry()) { - val k0 = key() - val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue } - map0[k0] = v0 - } - map0 - } - null -> break@loop - else -> skipValue() - } - } - } - - val expected = mapOf("key1" to 1) - actual.shouldContainExactly(expected) - } - - @Test - fun itHandlesNestedMap() { - val payload = """ - - - - outer1 - - - - inner1 - innerValue1 - - - inner2 - innerValue2 - - - - - - outer2 - - - - inner3 - innerValue3 - - - inner4 - innerValue4 - - - - - - - """.encodeToByteArray() - val ELEMENT_MAP_FIELD_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("map"), XmlMapName(entry = "outerEntry", value = "outerValue")) - val nestedMapDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("nestedMap"), XmlMapName(entry = "innerEntry", value = "innerValue")) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(ELEMENT_MAP_FIELD_DESCRIPTOR) - } - - val deserializer = XmlDeserializer(payload) - var actual = mutableMapOf>() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@while (true) { - when (findNextFieldIndex()) { - ELEMENT_MAP_FIELD_DESCRIPTOR.index -> - actual = - deserializer.deserializeMap(ELEMENT_MAP_FIELD_DESCRIPTOR) { - val map0 = mutableMapOf>() - while (hasNextEntry()) { - val k0 = key() - val v0 = deserializer.deserializeMap(nestedMapDescriptor) { - val map1 = mutableMapOf() - while (hasNextEntry()) { - val k1 = key() - val v1 = if (nextHasValue()) { deserializeString() } else { deserializeNull(); continue } - map1[k1] = v1 - } - map1 - } - map0[k0] = v0 - } - map0 - } - null -> break@loop - else -> skipValue() - } - } - } - - val expected = mapOf( - "outer1" to mapOf("inner1" to "innerValue1", "inner2" to "innerValue2"), - "outer2" to mapOf("inner3" to "innerValue3", "inner4" to "innerValue4"), - ) - - actual.shouldContainExactly(expected) - } - - @Test - fun itHandlesNestedStructAsValue() { - val payload = """ - - - - foo - - there - - - - baz - - bye - - - - - """.encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val resp = XmlMapsOperationDeserializer().deserialize(deserializer) - - println(resp) - } - - // https://github.com/awslabs/aws-sdk-kotlin/issues/962 - @Test - fun itHandlesConsecutiveFlatMaps() { - val payload = """ - - - key1 - 1 - - - key2 - 2 - - - key3 - 3 - - - key4 - 4 - - - key5 - 5 - - - """.encodeToByteArray() - val firstMapDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("firstMap"), XmlMapName(null, "key", "value"), Flattened) - val secondMapDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("secondMap"), XmlMapName(null, "key", "value"), Flattened) - - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(firstMapDescriptor) - field(secondMapDescriptor) - } - var firstMap = mutableMapOf() - var secondMap = mutableMapOf() - val deserializer = XmlDeserializer(payload) - deserializer.deserializeStruct(objDescriptor) { - loop@while (true) { - when (findNextFieldIndex()) { - firstMapDescriptor.index -> - firstMap = - deserializer.deserializeMap(firstMapDescriptor) { - val map0 = mutableMapOf() - while (hasNextEntry()) { - val k0 = key() - val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue } - map0[k0] = v0 - } - map0 - } - secondMapDescriptor.index -> - secondMap = - deserializer.deserializeMap(secondMapDescriptor) { - val map0 = mutableMapOf() - while (hasNextEntry()) { - val k0 = key() - val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue } - map0[k0] = v0 - } - map0 - } - null -> break@loop - else -> skipValue() - } - } - } - - val expectedFirstMap = mapOf("key1" to 1, "key2" to 2, "key3" to 3) - firstMap.shouldContainExactly(expectedFirstMap) - val expectedSecondMap = mapOf("key4" to 4, "key5" to 5) - secondMap.shouldContainExactly(expectedSecondMap) - } - - @Test - fun itHandlesMapsFollowedByFlatMaps() { - val payload = """ - - - - key1 - 1 - - - key2 - 2 - - - - key3 - 3 - - - key4 - 4 - - - """.encodeToByteArray() - val mapDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("map")) - val flatMapDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("flatMap"), Flattened) - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(mapDescriptor) - field(flatMapDescriptor) - } - - var map = mutableMapOf() - var flatMap = mutableMapOf() - - val deserializer = XmlDeserializer(payload) - deserializer.deserializeStruct(objDescriptor) { - loop@while (true) { - when (findNextFieldIndex()) { - mapDescriptor.index -> - map = - deserializer.deserializeMap(mapDescriptor) { - val map0 = mutableMapOf() - while (hasNextEntry()) { - val k0 = key() - val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue } - map0[k0] = v0 - } - map0 - } - flatMapDescriptor.index -> - flatMap = - deserializer.deserializeMap(flatMapDescriptor) { - val map0 = mutableMapOf() - while (hasNextEntry()) { - val k0 = key() - val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue } - map0[k0] = v0 - } - map0 - } - null -> break@loop - else -> skipValue() - } - } - } - map.shouldContainExactly(mapOf("key1" to 1, "key2" to 2)) - flatMap.shouldContainExactly(mapOf("key3" to 3, "key4" to 4)) - } - - @Test - fun itHandlesFlatMapsFollowedByMaps() { - val payload = """ - - - key3 - 3 - - - key4 - 4 - - - - key1 - 1 - - - key2 - 2 - - - - """.encodeToByteArray() - val mapDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("map")) - val flatMapDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("flatMap"), Flattened) - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("object")) - field(mapDescriptor) - field(flatMapDescriptor) - } - - var map = mutableMapOf() - var flatMap = mutableMapOf() - - val deserializer = XmlDeserializer(payload) - deserializer.deserializeStruct(objDescriptor) { - loop@while (true) { - when (findNextFieldIndex()) { - mapDescriptor.index -> - map = - deserializer.deserializeMap(mapDescriptor) { - val map0 = mutableMapOf() - while (hasNextEntry()) { - val k0 = key() - val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue } - map0[k0] = v0 - } - map0 - } - flatMapDescriptor.index -> - flatMap = - deserializer.deserializeMap(flatMapDescriptor) { - val map0 = mutableMapOf() - while (hasNextEntry()) { - val k0 = key() - val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue } - map0[k0] = v0 - } - map0 - } - null -> break@loop - else -> skipValue() - } - } - } - map.shouldContainExactly(mapOf("key1" to 1, "key2" to 2)) - flatMap.shouldContainExactly(mapOf("key3" to 3, "key4" to 4)) - } -} - -internal class XmlMapsOperationDeserializer() { - - companion object { - private val MYMAP_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("myMap")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("XmlMapsInputOutput")) - field(MYMAP_DESCRIPTOR) - } - } - - fun deserialize(deserializer: XmlDeserializer): XmlMapsInputOutput { - val builder = XmlMapsInputOutput.dslBuilder() - - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@while (true) { - when (findNextFieldIndex()) { - MYMAP_DESCRIPTOR.index -> - builder.myMap = - deserializer.deserializeMap(MYMAP_DESCRIPTOR) { - val map0 = mutableMapOf() - while (hasNextEntry()) { - val k0 = key() - val v0 = if (nextHasValue()) { GreetingStructDocumentDeserializer().deserialize(deserializer) } else { deserializeNull(); continue } - map0[k0] = v0 - } - map0 - } - null -> break@loop - else -> skipValue() - } - } - } - - return builder.build() - } -} - -internal class GreetingStructDocumentDeserializer { - - companion object { - private val HI_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("hi")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("GreetingStruct")) - field(HI_DESCRIPTOR) - } - } - - fun deserialize(deserializer: Deserializer): GreetingStruct { - val builder = GreetingStruct.dslBuilder() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@while (true) { - when (findNextFieldIndex()) { - HI_DESCRIPTOR.index -> builder.hi = deserializeString() - null -> break@loop - else -> skipValue() - } - } - } - return builder.build() - } -} - -class XmlMapsInputOutput private constructor(builder: BuilderImpl) { - val myMap: Map? = builder.myMap - - companion object { - fun builder(): Builder = BuilderImpl() - - fun dslBuilder(): DslBuilder = BuilderImpl() - - operator fun invoke(block: DslBuilder.() -> kotlin.Unit): XmlMapsInputOutput = BuilderImpl().apply(block).build() - } - - override fun toString(): kotlin.String = buildString { - append("XmlMapsInputOutput(") - append("myMap=$myMap)") - } - - override fun hashCode(): kotlin.Int { - var result = myMap?.hashCode() ?: 0 - return result - } - - override fun equals(other: kotlin.Any?): kotlin.Boolean { - if (this === other) return true - - other as XmlMapsInputOutput - - if (myMap != other.myMap) return false - - return true - } - - fun copy(block: DslBuilder.() -> kotlin.Unit = {}): XmlMapsInputOutput = BuilderImpl(this).apply(block).build() - - interface Builder { - fun build(): XmlMapsInputOutput - fun myMap(myMap: Map): Builder - } - - interface DslBuilder { - var myMap: Map? - - fun build(): XmlMapsInputOutput - } - - private class BuilderImpl() : Builder, DslBuilder { - override var myMap: Map? = null - - constructor(x: XmlMapsInputOutput) : this() { - this.myMap = x.myMap - } - - override fun build(): XmlMapsInputOutput = XmlMapsInputOutput(this) - override fun myMap(myMap: Map): Builder = apply { this.myMap = myMap } - } -} - -class GreetingStruct private constructor(builder: BuilderImpl) { - val hi: String? = builder.hi - - companion object { - fun builder(): Builder = BuilderImpl() - - fun dslBuilder(): DslBuilder = BuilderImpl() - - operator fun invoke(block: DslBuilder.() -> kotlin.Unit): GreetingStruct = BuilderImpl().apply(block).build() - } - - override fun toString(): kotlin.String = buildString { - append("GreetingStruct(") - append("hi=$hi)") - } - - override fun hashCode(): kotlin.Int { - var result = hi?.hashCode() ?: 0 - return result - } - - override fun equals(other: kotlin.Any?): kotlin.Boolean { - if (this === other) return true - - other as GreetingStruct - - if (hi != other.hi) return false - - return true - } - - fun copy(block: DslBuilder.() -> kotlin.Unit = {}): GreetingStruct = BuilderImpl(this).apply(block).build() - - interface Builder { - fun build(): GreetingStruct - fun hi(hi: String): Builder - } - - interface DslBuilder { - var hi: String? - - fun build(): GreetingStruct - } - - private class BuilderImpl() : Builder, DslBuilder { - override var hi: String? = null - - constructor(x: GreetingStruct) : this() { - this.hi = x.hi - } - - override fun build(): GreetingStruct = GreetingStruct(this) - override fun hi(hi: String): Builder = apply { this.hi = hi } - } -} diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerNamespaceTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerNamespaceTest.kt deleted file mode 100644 index becb7e347..000000000 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerNamespaceTest.kt +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package aws.smithy.kotlin.runtime.serde.xml - -import aws.smithy.kotlin.runtime.serde.* -import kotlin.test.Test -import kotlin.test.assertEquals - -// See https://awslabs.github.io/smithy/spec/xml.html#xmlname-trait -class XmlDeserializerNamespaceTest { - - @Test - fun `it handles struct with namespace declarations but default tags`() { - val payload = """ - - example1 - example2 - - """.trimIndent().encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val bst = NamespaceStructTest.deserialize(deserializer) - - assertEquals("example1", bst.foo) - assertEquals("example2", bst.bar) - } - - class NamespaceStructTest { - var foo: String? = null - var bar: String? = null - - companion object { - val FOO_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("foo")) - val BAR_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("bar")) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("MyStructure")) - trait(XmlNamespace("http://foo.com")) - field(FOO_DESCRIPTOR) - field(BAR_DESCRIPTOR) - } - - fun deserialize(deserializer: Deserializer): NamespaceStructTest { - val result = NamespaceStructTest() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - FOO_DESCRIPTOR.index -> result.foo = deserializeString() - BAR_DESCRIPTOR.index -> result.bar = deserializeString() - null -> break@loop - else -> throw DeserializationException(IllegalStateException("unexpected field in BasicStructTest deserializer")) - } - } - } - return result - } - } - } - - @Test - fun `it handles struct with node namespace`() { - val payload = """ - - example1 - example2 - - """.trimIndent().encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val bst = NodeNamespaceStructTest.deserialize(deserializer) - - assertEquals("example1", bst.foo) - assertEquals("example2", bst.bar) - } - - class NodeNamespaceStructTest { - var foo: String? = null - var bar: String? = null - - companion object { - val FOO_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("foo")) - val BAR_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("baz:bar")) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("MyStructure")) - trait(XmlNamespace("http://foo.com", "baz")) - field(FOO_DESCRIPTOR) - field(BAR_DESCRIPTOR) - } - - fun deserialize(deserializer: Deserializer): NodeNamespaceStructTest { - val result = NodeNamespaceStructTest() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - FOO_DESCRIPTOR.index -> result.foo = deserializeString() - BAR_DESCRIPTOR.index -> result.bar = deserializeString() - null -> break@loop - else -> throw DeserializationException(IllegalStateException("unexpected field in BasicStructTest deserializer")) - } - } - } - return result - } - } - } -} diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerPrimitiveTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerPrimitiveTest.kt deleted file mode 100644 index 26e02418f..000000000 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerPrimitiveTest.kt +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package aws.smithy.kotlin.runtime.serde.xml - -import aws.smithy.kotlin.runtime.serde.* -import kotlin.math.abs -import kotlin.test.Test -import kotlin.test.assertEquals -import kotlin.test.assertFailsWith -import kotlin.test.assertTrue - -class XmlDeserializerPrimitiveTest { - @Test - fun itHandlesDoubles() { - val deserializer = XmlPrimitiveDeserializer("1.2".wrapInStruct(), SdkFieldDescriptor(SerialKind.Double, XmlSerialName("node"))) - val actual = deserializer.deserializeDouble() - val expected = 1.2 - assertTrue(abs(actual - expected) <= 0.0001) - } - - @Test - fun itHandlesFloats() { - val deserializer = XmlPrimitiveDeserializer("1.2".wrapInStruct(), SdkFieldDescriptor(SerialKind.Float, XmlSerialName("node"))) - val actual = deserializer.deserializeFloat() - val expected = 1.2f - assertTrue(abs(actual - expected) <= 0.0001f) - } - - @Test - fun itHandlesInt() { - val deserializer = XmlPrimitiveDeserializer("${Int.MAX_VALUE}".wrapInStruct(), SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("node"))) - val actual = deserializer.deserializeInt() - val expected = 2147483647 - assertEquals(expected, actual) - } - - @Test - fun itHandlesByteAsNumber() { - val deserializer = XmlPrimitiveDeserializer("1".wrapInStruct(), SdkFieldDescriptor(SerialKind.Byte, XmlSerialName("node"))) - val actual = deserializer.deserializeByte() - val expected: Byte = 1 - assertEquals(expected, actual) - } - - @Test - fun itHandlesShort() { - val deserializer = XmlPrimitiveDeserializer("${Short.MAX_VALUE}".wrapInStruct(), SdkFieldDescriptor(SerialKind.Short, XmlSerialName("node"))) - val actual = deserializer.deserializeShort() - val expected: Short = 32767 - assertEquals(expected, actual) - } - - @Test - fun itHandlesLong() { - val deserializer = XmlPrimitiveDeserializer("${Long.MAX_VALUE}".wrapInStruct(), SdkFieldDescriptor(SerialKind.Long, XmlSerialName("node"))) - val actual = deserializer.deserializeLong() - val expected = 9223372036854775807L - assertEquals(expected, actual) - } - - @Test - fun itHandlesBool() { - val deserializer = XmlPrimitiveDeserializer("true".wrapInStruct(), SdkFieldDescriptor(SerialKind.Boolean, XmlSerialName("node"))) - val actual = deserializer.deserializeBoolean() - assertTrue(actual) - } - - @Test - fun itFailsInvalidTypeSpecificationForInt() { - val deserializer = XmlPrimitiveDeserializer("1.2".wrapInStruct(), SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("node"))) - assertFailsWith(DeserializationException::class) { - deserializer.deserializeInt() - } - } - - // TODO: It's unclear if this test should result in an exception or null value. - @Test - fun itFailsMissingTypeSpecificationForInt() { - val deserializer = XmlPrimitiveDeserializer("".wrapInStruct(), SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("node"))) - assertFailsWith(DeserializationException::class) { - deserializer.deserializeInt() - } - } - - // TODO: It's unclear if this test should result in an exception or null value. - @Test - fun itFailsWhitespaceTypeSpecificationForInt() { - val deserializer = XmlPrimitiveDeserializer(" ".wrapInStruct(), SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("node"))) - assertFailsWith(DeserializationException::class) { - deserializer.deserializeInt() - } - } - - private fun String.wrapInStruct(): ByteArray = "$this".encodeToByteArray() -} diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerStructTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerStructTest.kt deleted file mode 100644 index a500f0d66..000000000 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializerStructTest.kt +++ /dev/null @@ -1,404 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package aws.smithy.kotlin.runtime.serde.xml - -import aws.smithy.kotlin.runtime.serde.* -import kotlin.test.Test -import kotlin.test.assertEquals - -class XmlDeserializerStructTest { - @Test - fun `it handles basic structs with attribs`() { - val payload = """ - - - - - """.trimIndent().encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val bst = StructWithAttribsClass.deserialize(deserializer) - - assertEquals(1, bst.x) - assertEquals(2, bst.y) - } - - @Test - fun `it handles basic structs with multi attribs and text`() { - val payload = """ - - - - - nodeval - - """.trimIndent().encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val bst = StructWithMultiAttribsAndTextValClass.deserialize(deserializer) - - assertEquals(1, bst.x) - assertEquals(2, bst.y) - assertEquals("nodeval", bst.txt) - } - - @Test - fun itHandlesBasicStructsWithAttribsAndText() { - val payload = """ - - x1 - - true - - """.encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val bst = BasicAttribTextStructTest.deserialize(deserializer) - - assertEquals(1, bst.xa) - assertEquals("x1", bst.xt) - assertEquals(2, bst.y) - assertEquals(1, bst.unknownFieldCount) - } - - class BasicAttribTextStructTest { - var xa: Int? = null - var xt: String? = null - var y: Int? = null - var z: Boolean? = null - var unknownFieldCount: Int = 0 - - companion object { - val X_ATTRIB_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("xa"), XmlAttribute) - val X_VALUE_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("x")) - val Y_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("ya"), XmlAttribute) - val Z_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Boolean, XmlSerialName("z")) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("payload")) - field(X_ATTRIB_DESCRIPTOR) - field(X_VALUE_DESCRIPTOR) - field(Y_DESCRIPTOR) - field(Z_DESCRIPTOR) - } - - fun deserialize(deserializer: Deserializer): BasicAttribTextStructTest { - val result = BasicAttribTextStructTest() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - X_ATTRIB_DESCRIPTOR.index -> result.xa = deserializeInt() - X_VALUE_DESCRIPTOR.index -> result.xt = deserializeString() - Y_DESCRIPTOR.index -> result.y = deserializeInt() - Z_DESCRIPTOR.index -> result.z = deserializeBoolean() - null -> break@loop - Deserializer.FieldIterator.UNKNOWN_FIELD -> { - result.unknownFieldCount++ - skipValue() - } - else -> throw DeserializationException(IllegalStateException("unexpected field in BasicStructTest deserializer")) - } - } - } - return result - } - } - } - - @Test - fun itHandlesBasicStructs() { - val payload = """ - - 1 - 2 - - """.encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val bst = SimpleStructClass.deserialize(deserializer) - - assertEquals(1, bst.x) - assertEquals(2, bst.y) - } - - @Test - fun itHandlesBasicStructsWithNullValues() { - val payload1 = """ - - a - - - """.encodeToByteArray() - - val deserializer = XmlDeserializer(payload1) - val bst = SimpleStructOfStringsClass.deserialize(deserializer) - - assertEquals("a", bst.x) - assertEquals("", bst.y) - - val payload2 = """ - - - 2 - - """.encodeToByteArray() - - val deserializer2 = XmlDeserializer(payload2) - val bst2 = SimpleStructOfStringsClass.deserialize(deserializer2) - - assertEquals("", bst2.x) - assertEquals("2", bst2.y) - } - - @Test - fun itEnumeratesUnknownStructFields() { - val payload = """ - - 1 - 2 - - """.encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val bst = SimpleStructClass.deserialize(deserializer) - - assertEquals(1, bst.x) - assertEquals(2, bst.y) - assertEquals("strval", bst.z) - } - - @Test - fun itHandlesNestedXmlStructures() { - val payload = """ - - - Foo1 - - Bar1 - - Foo2 - - Bar2 - - - - - - """.encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val bst = RecursiveShapesOperationDeserializer().deserialize(deserializer) - - println(bst.nested?.nested) - } - - class BasicUnwrappedStructTest { - var x: String? = null - - companion object { - val X_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("x")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("payload")) - trait(XmlUnwrappedOutput) - field(X_DESCRIPTOR) - } - - fun deserialize(deserializer: Deserializer): BasicUnwrappedStructTest { - val result = BasicUnwrappedStructTest() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - X_DESCRIPTOR.index -> result.x = deserializeString() - null -> break@loop - else -> throw DeserializationException(IllegalStateException("unexpected field in BasicUnwrappedStructTest deserializer")) - } - } - } - return result - } - } - } - - @Test - fun itHandlesBasicUnwrappedStructs() { - val payload = """ - text - """.encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val bst = BasicUnwrappedStructTest.deserialize(deserializer) - - assertEquals("text", bst.x) - } - - @Test - fun itHandlesBasicUnwrappedStructsWithNullValues() { - val payload = """ - - """.encodeToByteArray() - - val deserializer = XmlDeserializer(payload) - val bst = BasicUnwrappedStructTest.deserialize(deserializer) - - assertEquals(null, bst.x) - } - - class AliasStruct { - var message: String? = null - var attribute: String? = null - - companion object { - val MESSAGE_DESCRIPTOR = SdkFieldDescriptor( - SerialKind.String, - XmlSerialName("Message"), - XmlAliasName("message"), - XmlAliasName("msg"), - ) - val ATTRIBUTE_DESCRIPTOR = SdkFieldDescriptor( - SerialKind.String, - XmlAttribute, - XmlSerialName("Attribute"), - XmlAliasName("attribute"), - XmlAliasName("attr"), - ) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("Struct")) - trait(XmlAliasName("struct")) - field(MESSAGE_DESCRIPTOR) - field(ATTRIBUTE_DESCRIPTOR) - } - - fun deserialize(deserializer: Deserializer): AliasStruct { - val result = AliasStruct() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@ while (true) { - when (findNextFieldIndex()) { - MESSAGE_DESCRIPTOR.index -> result.message = deserializeString() - ATTRIBUTE_DESCRIPTOR.index -> result.attribute = deserializeString() - null -> break@loop - else -> throw DeserializationException(IllegalStateException("unexpected field in AliasStruct deserializer")) - } - } - } - return result - } - } - } - - @Test - fun itHandlesAliasMatchingOnElements() { - val tests = listOf( - "Hi there", - "Hi there", - "Hi there", - "Hi there", - ) - tests.forEach { payload -> - val deserializer = XmlDeserializer(payload.encodeToByteArray()) - val bst = AliasStruct.deserialize(deserializer) - assertEquals("Hi there", bst.message, "Can't find 'Hi there' in $payload") - } - } - - @Test - fun itHandlesAliasMatchingOnAttributes() { - val tests = listOf( - """""", - """""", - """""", - ) - tests.forEach { payload -> - val deserializer = XmlDeserializer(payload.encodeToByteArray()) - val bst = AliasStruct.deserialize(deserializer) - assertEquals("Hi there", bst.attribute, "Can't find 'Hi there' in $payload") - } - } -} - -internal class RecursiveShapesOperationDeserializer { - - companion object { - private val NESTED_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("nested")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("RecursiveShapesInputOutput")) - field(NESTED_DESCRIPTOR) - } - } - - fun deserialize(deserializer: Deserializer): RecursiveShapesInputOutput { - val builder = RecursiveShapesInputOutput.Builder() - - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@while (true) { - when (findNextFieldIndex()) { - NESTED_DESCRIPTOR.index -> builder.nested = RecursiveShapesInputOutputNested1DocumentDeserializer().deserialize(deserializer) - null -> break@loop - else -> skipValue() - } - } - } - - return builder.build() - } -} - -internal class RecursiveShapesInputOutputNested1DocumentDeserializer { - - companion object { - private val FOO_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("foo")) - private val NESTED_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("nested")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - field(FOO_DESCRIPTOR) - field(NESTED_DESCRIPTOR) - } - } - - fun deserialize(deserializer: Deserializer): RecursiveShapesInputOutputNested1 { - val builder = RecursiveShapesInputOutputNested1.dslBuilder() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@while (true) { - when (findNextFieldIndex()) { - FOO_DESCRIPTOR.index -> builder.foo = deserializeString() - NESTED_DESCRIPTOR.index -> builder.nested = RecursiveShapesInputOutputNested2DocumentDeserializer().deserialize(deserializer) - null -> break@loop - else -> skipValue() - } - } - } - return builder.build() - } -} - -internal class RecursiveShapesInputOutputNested2DocumentDeserializer { - - companion object { - private val BAR_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("bar")) - private val RECURSIVEMEMBER_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("recursiveMember")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - field(BAR_DESCRIPTOR) - field(RECURSIVEMEMBER_DESCRIPTOR) - } - } - - fun deserialize(deserializer: Deserializer): RecursiveShapesInputOutputNested2 { - val builder = RecursiveShapesInputOutputNested2.dslBuilder() - deserializer.deserializeStruct(OBJ_DESCRIPTOR) { - loop@while (true) { - when (findNextFieldIndex()) { - BAR_DESCRIPTOR.index -> builder.bar = deserializeString() - RECURSIVEMEMBER_DESCRIPTOR.index -> builder.recursiveMember = RecursiveShapesInputOutputNested1DocumentDeserializer().deserialize(deserializer) - null -> break@loop - else -> skipValue() - } - } - } - return builder.build() - } -} diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlSerializerTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlSerializerTest.kt deleted file mode 100644 index ad60d281c..000000000 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlSerializerTest.kt +++ /dev/null @@ -1,963 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package aws.smithy.kotlin.runtime.serde.xml - -import aws.smithy.kotlin.runtime.serde.* -import kotlin.test.Test -import kotlin.test.assertEquals - -/* -Remove all whitespace and newline chars from XML string and return the compact form -e.g. - -``` - - - 1 - - -``` - -becomes: `1` - */ -private fun String.toXmlCompactString(): String = - trimIndent() - .replace("\n", "") - .replace(Regex(">\\s+"), ">") - -class XmlSerializerTest { - - @Test - fun canSerializeClassWithClassField() { - val a = A( - B(2), - ) - val xml = XmlSerializer() - a.serialize(xml) - assertEquals("""2""", xml.toByteArray().decodeToString()) - } - - class A(private val b: B) : SdkSerializable { - companion object { - val descriptorB: SdkFieldDescriptor = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("b")) - - val objectDescriptor: SdkObjectDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("a")) - field(descriptorB) - } - } - - override fun serialize(serializer: Serializer) { - serializer.serializeStruct(objectDescriptor) { - field(descriptorB, b) - } - } - } - - data class B(private val value: Int) : SdkSerializable { - companion object { - val descriptorValue = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("v")) - - val objectDescriptor: SdkObjectDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("b")) - field(descriptorValue) - } - } - - override fun serialize(serializer: Serializer) { - serializer.serializeStruct(objectDescriptor) { - field(descriptorValue, value) - } - } - } - - @Test - fun canSerializePrimitiveList() { - // https://awslabs.github.io/smithy/spec/xml.html#wrapped-list-serialization - val list = listOf("example1", "example2", "example3") - val xml = XmlSerializer() - val listDescriptor = SdkFieldDescriptor(SerialKind.List, XmlSerialName("values")) - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("Foo")) - field(listDescriptor) - } - - xml.serializeStruct(objDescriptor) { - listField(listDescriptor) { - for (value in list) { - serializeString(value) - } - } - } - - val expected = """ - - - example1 - example2 - example3 - - - """.toXmlCompactString() - assertEquals(expected, xml.toByteArray().decodeToString()) - } - - @Test - fun canSerializeRenamedList() { - val list = listOf("example1", "example2", "example3") - val xml = XmlSerializer() - val listDescriptor = SdkFieldDescriptor(SerialKind.List, XmlSerialName("values"), XmlCollectionName("Item")) - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("Foo")) - field(listDescriptor) - } - - xml.serializeStruct(objDescriptor) { - listField(listDescriptor) { - for (value in list) { - serializeString(value) - } - } - } - - val expected = """ - - - example1 - example2 - example3 - - - """.toXmlCompactString() - assertEquals(expected, xml.toByteArray().decodeToString()) - } - - @Test - fun canSerializeFlattenedList() { - // https://awslabs.github.io/smithy/spec/xml.html#flattened-list-serialization - val list = listOf("example1", "example2", "example3") - val xml = XmlSerializer() - val listDescriptor = SdkFieldDescriptor(SerialKind.List, XmlSerialName("flat"), Flattened) - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("Foo")) - field(listDescriptor) - } - - xml.serializeStruct(objDescriptor) { - listField(listDescriptor) { - for (value in list) { - serializeString(value) - } - } - } - - val expected = """ - - example1 - example2 - example3 - - """.toXmlCompactString() - assertEquals(expected, xml.toByteArray().decodeToString()) - } - - @Test - fun canSerializeListOfClasses() { - val obj = listOf( - B(1), - B(2), - B(3), - ) - val xml = XmlSerializer() - xml.serializeList(SdkFieldDescriptor(SerialKind.List, XmlSerialName("list"))) { - for (value in obj) { - serializeSdkSerializable(value) - } - } - - val expected = """ - - - 1 - - - 2 - - - 3 - - - """.toXmlCompactString() - assertEquals(expected, xml.toByteArray().decodeToString()) - } - - @Test - fun canSerializeFlatListOfClasses() { - val obj = listOf( - B(1), - B(2), - B(3), - ) - val xml = XmlSerializer() - xml.serializeList(SdkFieldDescriptor(SerialKind.List, XmlSerialName("list"), Flattened)) { - for (value in obj) { - serializeSdkSerializable(value) - } - } - val expected = """ - - 1 - - - 2 - - - 3 - - """.toXmlCompactString() - assertEquals(expected, xml.toByteArray().decodeToString()) - } - - @Test - fun canSerializeMap() { - // See https://awslabs.github.io/smithy/spec/xml.html#wrapped-map-serialization - val foo = Foo( - mapOf( - "example-key1" to "example1", - "example-key2" to "example2", - ), - ) - val xml = XmlSerializer() - foo.serialize(xml) - - val expected = """ - - - - example-key1 - example1 - - - example-key2 - example2 - - - - """.toXmlCompactString() - - assertEquals(expected, xml.toByteArray().decodeToString()) - } - - @Test - fun canSerializeFlattenedMap() { - // See https://awslabs.github.io/smithy/spec/xml.html#flattened-map-serialization - val bar = Bar( - mapOf( - "example-key1" to "example1", - "example-key2" to "example2", - "example-key3" to "example3", - ), - ) - val serializer = XmlSerializer() - bar.serialize(serializer) - - val expected = """ - - - example-key1 - example1 - - - - example-key2 - example2 - - - - example-key3 - example3 - - - """.toXmlCompactString() - - assertEquals(expected, serializer.toByteArray().decodeToString()) - } - - @Test - fun canSerializeMapOfLists() { - val objs = mapOf( - "A1" to listOf("a", "b", "c"), - "A2" to listOf("d", "e", "f"), - "A3" to listOf("g", "h", "i"), - ) - val xml = XmlSerializer() - xml.serializeMap(SdkFieldDescriptor(SerialKind.Map, XmlSerialName("objs"))) { - for (obj in objs) { - listEntry(obj.key, SdkFieldDescriptor(SerialKind.List, XmlSerialName("elements"))) { - for (v in obj.value) { - serializeString(v) - } - } - } - } - - val expected = """ - - - A1 - - - a - b - c - - - - - A2 - - - d - e - f - - - - - A3 - - - g - h - i - - - - - """.toXmlCompactString() - assertEquals(expected, xml.toByteArray().decodeToString()) - } - - @Test - fun canSerializeListOfLists() { - val objs = listOf( - listOf("a", "b", "c"), - listOf("d", "e", "f"), - listOf("g", "h", "i"), - ) - val xml = XmlSerializer() - xml.serializeList(SdkFieldDescriptor(SerialKind.List, XmlSerialName("objs"))) { - for (obj in objs) { - xml.serializeList(SdkFieldDescriptor(SerialKind.List, XmlSerialName("elements"))) { - for (v in obj) { - serializeString(v) - } - } - } - } - - val expected = """ - - - a - b - c - - - d - e - f - - - g - h - i - - - """.toXmlCompactString() - assertEquals(expected, xml.toByteArray().decodeToString()) - } - - @Test - fun canSerializeListOfMaps() { - val objs = listOf( - mapOf("a" to "b", "c" to "d"), - mapOf("e" to "f", "g" to "h"), - mapOf("i" to "j", "k" to "l"), - ) - val xml = XmlSerializer() - xml.serializeList(SdkFieldDescriptor(SerialKind.List, XmlSerialName("elements"))) { - for (obj in objs) { - xml.serializeMap(SdkFieldDescriptor(SerialKind.Map, XmlSerialName("entries"))) { - for (v in obj) { - entry(v.key, v.value) - } - } - } - } - val expected = """ - - - - a - b - - - c - d - - - - - e - f - - - g - h - - - - - i - j - - - k - l - - - - """.toXmlCompactString() - assertEquals(expected, xml.toByteArray().decodeToString()) - } - - @Test - fun canSerializeMapOfMaps() { - val objs = mapOf( - "A1" to mapOf("a" to "b", "c" to "d"), - "A2" to mapOf("e" to "f", "g" to "h"), - "A3" to mapOf("i" to "j", "k" to "l"), - ) - val serializer = XmlSerializer() - serializer.serializeMap(SdkFieldDescriptor(SerialKind.Map, XmlSerialName("objs"))) { - for (obj in objs) { - mapEntry(obj.key, SdkFieldDescriptor(SerialKind.Map)) { - for (v in obj.value) { - entry(v.key, v.value) - } - } - } - } - - // NOTE the child map entries do not have a surrounding tag around them, much like a map of structs omit the - // structure tag - val expected = """ - - - A1 - - - a - b - - - c - d - - - - - A2 - - - e - f - - - g - h - - - - - A3 - - - i - j - - - k - l - - - - - """.toXmlCompactString() - assertEquals(expected, serializer.toByteArray().decodeToString()) - } - - @Test - fun canSerializeMapOfStructs() { - val objs = mapOf( - "foo" to B(1), - "bar" to B(2), - ) - - val serializer = XmlSerializer() - - serializer.serializeMap(SdkFieldDescriptor(SerialKind.Map, XmlSerialName("myMap"))) { - objs.entries.forEach { (key, value) -> entry(key, value) } - } - - val expected = """ - - - foo - - 1 - - - - bar - - 2 - - - - """.toXmlCompactString() - assertEquals(expected, serializer.toByteArray().decodeToString()) - } - - class Bar(var flatMap: Map? = null) : SdkSerializable { - companion object { - val FLAT_MAP_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("flatMap"), XmlMapName(entry = "flatMap"), Flattened) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("Bar")) - field(FLAT_MAP_DESCRIPTOR) - } - } - - override fun serialize(serializer: Serializer) { - serializer.serializeStruct(OBJ_DESCRIPTOR) { - mapField(FLAT_MAP_DESCRIPTOR) { - for (value in flatMap!!) { - entry(value.key, value.value) - } - } - } - } - } - - class Foo(var values: Map? = null) : SdkSerializable { - companion object { - val FLAT_MAP_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("values"), XmlMapName(entry = "entry")) - val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("Foo")) - field(FLAT_MAP_DESCRIPTOR) - } - } - - override fun serialize(serializer: Serializer) { - serializer.serializeStruct(OBJ_DESCRIPTOR) { - mapField(FLAT_MAP_DESCRIPTOR) { - for (value in values!!) { - entry(value.key, value.value) - } - } - } - } - } - - @Test - fun canSerializeAllPrimitives() { - val xml = XmlSerializer() - val data = Primitives( - true, 10, 20, 30, 40, 50f, 60.0, 'A', "Str0", - listOf(1, 2, 3), - ) - data.serialize(xml) - - assertEquals("""true1020304050.060.0AStr0123""", xml.toByteArray().decodeToString()) - } - - @Test - fun canSerializeNamespaces() { - // See https://awslabs.github.io/smithy/spec/xml.html#xmlnamespace-trait - val myStructure = MyStructure1("example", "example") - val xml = XmlSerializer() - myStructure.serialize(xml) - val expected1 = """ - - example - example - - """.toXmlCompactString() - assertEquals(expected1, xml.toByteArray().decodeToString()) - - val myStructure2 = MyStructure2("example", "example") - val xml2 = XmlSerializer() - myStructure2.serialize(xml2) - val expected2 = """ - - example - example - - """.toXmlCompactString() - assertEquals(expected2, xml2.toByteArray().decodeToString()) - } - - @Test - fun canSerializeNestedNamespaces() { - val input = XmlNamespacesRequest( - nested = XmlNamespaceNested( - foo = "Foo", - values = listOf("Bar", "Baz"), - ), - ) - - val serializer = XmlSerializer() - input.serialize(serializer) - - val expected = """ - - - Foo - - Bar - Baz - - - - """.toXmlCompactString() - - assertEquals(expected, serializer.toByteArray().decodeToString()) - } - - class MyStructure1(private val foo: String, private val bar: String) : SdkSerializable { - companion object { - val fooDescriptor: SdkFieldDescriptor = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("foo")) - val barDescriptor: SdkFieldDescriptor = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("bar")) - - val objectDescriptor: SdkObjectDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("MyStructure")) - trait(XmlNamespace("http://foo.com")) - field(fooDescriptor) - field(barDescriptor) - } - } - - override fun serialize(serializer: Serializer) { - serializer.serializeStruct(objectDescriptor) { - field(fooDescriptor, foo) - field(barDescriptor, bar) - } - } - } - - class MyStructure2(private val foo: String, private val bar: String) : SdkSerializable { - companion object { - val fooDescriptor: SdkFieldDescriptor = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("foo")) - val barDescriptor: SdkFieldDescriptor = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("baz:bar")) - - val objectDescriptor: SdkObjectDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("MyStructure")) - trait(XmlNamespace("http://foo.com", "baz")) - - field(fooDescriptor) - field(barDescriptor) - } - } - - override fun serialize(serializer: Serializer) { - serializer.serializeStruct(objectDescriptor) { - field(fooDescriptor, foo) - field(barDescriptor, bar) - } - } - } - - @Test - fun canIgnoresNestedStructNamespaces() { - /* - @xmlNamespace(uri: "http://foo.com") - structure Foo { - nested: Bar, - } - - // Ignored - not at top level - // TODO - nothing in the spec defines this...only the protocol tests - @xmlNamespace(uri: "http://bar.com") - structure Bar { - x: String - } - */ - - val serializer = XmlSerializer() - val nestedDescriptor = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("nested")) - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("Foo")) - trait(XmlNamespace("http://foo.com")) - field(nestedDescriptor) - } - - val nested = object : SdkSerializable { - override fun serialize(serializer: Serializer) { - val xDescriptor = SdkFieldDescriptor(SerialKind.String, XmlSerialName("x")) - val obj2Descriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("Bar")) - trait(XmlNamespace("http://bar.com")) - field(xDescriptor) - } - serializer.serializeStruct(obj2Descriptor) { - field(xDescriptor, "blerg") - } - } - } - - serializer.serializeStruct(objDescriptor) { - field(nestedDescriptor, nested) - } - - val expected = """ - - - blerg - - - """.toXmlCompactString() - - assertEquals(expected, serializer.toByteArray().decodeToString()) - } - - @Test - fun itSerializesRecursiveShapes() { - val expected = """ - - - Foo1 - - Bar1 - - Foo2 - - Bar2 - - - - - - """.toXmlCompactString() - - val input = RecursiveShapesInputOutput { - nested = RecursiveShapesInputOutputNested1 { - foo = "Foo1" - nested = RecursiveShapesInputOutputNested2 { - bar = "Bar1" - recursiveMember = RecursiveShapesInputOutputNested1 { - foo = "Foo2" - nested = RecursiveShapesInputOutputNested2 { - bar = "Bar2" - } - } - } - } - } - - val serializer = XmlSerializer() - RecursiveShapesInputOutputSerializer().serialize(serializer, input) - val actual = serializer.toByteArray().decodeToString() - println(actual) - assertEquals(expected, actual) - } - - @Test - fun itCanSerializeAttributes() { - val boolDescriptor = SdkFieldDescriptor(SerialKind.Boolean, XmlSerialName("bool"), XmlAttribute) - val strDescriptor = SdkFieldDescriptor(SerialKind.Boolean, XmlSerialName("str"), XmlAttribute) - val intDescriptor = SdkFieldDescriptor(SerialKind.Boolean, XmlSerialName("number"), XmlAttribute) - // timestamps are ignored as they aren't special cased (as of right now) but rather serialized through string/raw - - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("Foo")) - field(boolDescriptor) - field(strDescriptor) - field(intDescriptor) - } - - // NOTE: attribute fields MUST be generated as the first fields after serializeStruct() to work properly - val serializer = XmlSerializer() - serializer.serializeStruct(objDescriptor) { - field(boolDescriptor, true) - field(strDescriptor, "bar") - field(intDescriptor, 2) - } - - val expected = """ - - """.toXmlCompactString() - - assertEquals(expected, serializer.toByteArray().decodeToString()) - } - - @Test - fun itCanSerializeAttributesWithNamespaces() { - val nestedDescriptor = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("nestedField"), XmlNamespace("https://example.com", "xsi")) - - val objDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("Foo")) - field(nestedDescriptor) - } - - val nestedSerializer = object : SdkSerializable { - override fun serialize(serializer: Serializer) { - val attrDescriptor = SdkFieldDescriptor(SerialKind.String, XmlSerialName("xsi:myAttr"), XmlAttribute) - val nestedObjDescriptor = SdkObjectDescriptor.build { - trait(XmlSerialName("Nested")) - field(attrDescriptor) - } - serializer.serializeStruct(nestedObjDescriptor) { - field(attrDescriptor, "nestedAttrValue") - } - } - } - - // NOTE: attribute fields MUST be generated as the first fields after serializeStruct() to work properly - val serializer = XmlSerializer() - serializer.serializeStruct(objDescriptor) { - field(nestedDescriptor, nestedSerializer) - } - - // The order these attributes come out in exactly the order they're put in (as defined by XmlSerializer). - val expected = """ - - - - """.toXmlCompactString() - - assertEquals(expected, serializer.toByteArray().decodeToString()) - } -} - -data class Primitives( - // val unit: Unit, - val boolean: Boolean, - val byte: Byte, - val short: Short, - val int: Int, - val long: Long, - val float: Float, - val double: Double, - val char: Char, - val string: String, - // val unitNullable: Unit?, - val listInt: List, -) : SdkSerializable { - companion object { - val descriptorBoolean = SdkFieldDescriptor(SerialKind.Boolean, XmlSerialName("boolean")) - val descriptorByte = SdkFieldDescriptor(SerialKind.Byte, XmlSerialName("byte")) - val descriptorShort = SdkFieldDescriptor(SerialKind.Short, XmlSerialName("short")) - val descriptorInt = SdkFieldDescriptor(SerialKind.Integer, XmlSerialName("int")) - val descriptorLong = SdkFieldDescriptor(SerialKind.Long, XmlSerialName("long")) - val descriptorFloat = SdkFieldDescriptor(SerialKind.Float, XmlSerialName("float")) - val descriptorDouble = SdkFieldDescriptor(SerialKind.Double, XmlSerialName("double")) - val descriptorChar = SdkFieldDescriptor(SerialKind.Char, XmlSerialName("char")) - val descriptorString = SdkFieldDescriptor(SerialKind.String, XmlSerialName("string")) - - // val descriptorUnitNullable = SdkFieldDescriptor("unitNullable") - val descriptorListInt = SdkFieldDescriptor(SerialKind.List, XmlSerialName("listInt"), XmlCollectionName(element = "number")) - } - - override fun serialize(serializer: Serializer) { - serializer.serializeStruct(SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("struct"))) { - serializeNull() - field(descriptorBoolean, boolean) - field(descriptorByte, byte) - field(descriptorShort, short) - field(descriptorInt, int) - field(descriptorLong, long) - field(descriptorFloat, float) - field(descriptorDouble, double) - field(descriptorChar, char) - field(descriptorString, string) - // serializeNull(descriptorUnitNullable) - listField(descriptorListInt) { - for (value in listInt) { - serializeInt(value) - } - } - } - } -} - -// structure RecursiveShapesInputOutput { -// nested: RecursiveShapesInputOutputNested1 -// } -// -// structure RecursiveShapesInputOutputNested1 { -// foo: String, -// nested: RecursiveShapesInputOutputNested2 -// } -// -// structure RecursiveShapesInputOutputNested2 { -// bar: String, -// recursiveMember: RecursiveShapesInputOutputNested1, -// } -internal class RecursiveShapesInputOutputSerializer { - companion object { - private val NESTED_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("nested")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("RecursiveShapesInputOutput")) - field(NESTED_DESCRIPTOR) - } - } - - fun serialize(serializer: Serializer, input: RecursiveShapesInputOutput) { - serializer.serializeStruct(OBJ_DESCRIPTOR) { - input.nested?.let { field(NESTED_DESCRIPTOR, RecursiveShapesInputOutputNested1DocumentSerializer(it)) } - } - } -} - -internal class RecursiveShapesInputOutputNested1DocumentSerializer(val input: RecursiveShapesInputOutputNested1) : SdkSerializable { - - companion object { - private val FOO_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("foo")) - private val NESTED_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("nested")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("RecursiveShapesInputOutputNested1")) - field(FOO_DESCRIPTOR) - field(NESTED_DESCRIPTOR) - } - } - - override fun serialize(serializer: Serializer) { - serializer.serializeStruct(OBJ_DESCRIPTOR) { - input.foo?.let { field(FOO_DESCRIPTOR, it) } - input.nested?.let { field(NESTED_DESCRIPTOR, RecursiveShapesInputOutputNested2DocumentSerializer(it)) } - } - } -} - -internal class RecursiveShapesInputOutputNested2DocumentSerializer(val input: RecursiveShapesInputOutputNested2) : SdkSerializable { - - companion object { - private val BAR_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("bar")) - private val RECURSIVEMEMBER_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Struct, XmlSerialName("recursiveMember")) - private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build { - trait(XmlSerialName("RecursiveShapesInputOutputNested2")) - field(BAR_DESCRIPTOR) - field(RECURSIVEMEMBER_DESCRIPTOR) - } - } - - override fun serialize(serializer: Serializer) { - serializer.serializeStruct(OBJ_DESCRIPTOR) { - input.bar?.let { field(BAR_DESCRIPTOR, it) } - input.recursiveMember?.let { field(RECURSIVEMEMBER_DESCRIPTOR, RecursiveShapesInputOutputNested1DocumentSerializer(it)) } - } - } -} diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReaderTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReaderTest.kt index fc1ef5dca..26f129c8e 100644 --- a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReaderTest.kt +++ b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReaderTest.kt @@ -193,8 +193,7 @@ class XmlStreamReaderTest { assertEquals(expected, actual) } - @Test - fun itSkipsValuesRecursively() { + private fun skipTest() { val payload = """ 1> @@ -226,15 +225,19 @@ class XmlStreamReaderTest { } val nt = reader.peek() - assertIs(nt) + assertIs(nt) assertEquals("unknown", nt.name.local) + reader.skipNext() val y = reader.nextToken() as XmlToken.BeginElement assertEquals("y", y.name.local) } + @Test + fun itSkipsNextValuesRecursively() = skipTest() + @Test fun itSkipsSimpleValues() { val payload = """ diff --git a/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlTagReaderTest.kt b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlTagReaderTest.kt new file mode 100644 index 000000000..2e0c07b29 --- /dev/null +++ b/runtime/serde/serde-xml/common/test/aws/smithy/kotlin/runtime/serde/xml/XmlTagReaderTest.kt @@ -0,0 +1,148 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.runtime.serde.xml + +import aws.smithy.kotlin.runtime.serde.parseInt +import kotlin.test.* + +class XmlTagReaderTest { + + @Test + fun testNextTag() { + // inner b could be confused as closing the outer b if depth isn't tracked properly + val payload = """ + + + + + + + + + more + + """.encodeToByteArray() + val scoped = xmlTagReader(payload) + val expected = listOf("a", "b", "c", "d") + .map { XmlToken.BeginElement(2, it) } + + expected.forEach { expectedStartTag -> + val tagReader = assertNotNull(scoped.nextTag()) + assertEquals(expectedStartTag, tagReader.tag) + tagReader.drop() + } + } + + @Test + fun testNextTagScope() { + // test scope of each tag reader + val payload = """ + + + 1 + 2 + + + 3 + 4 + + + + abc + + + """.encodeToByteArray() + val scoped = xmlTagReader(payload) + assertEquals(XmlToken.BeginElement(1, "Root"), scoped.tag) + + val s1 = assertNotNull(scoped.nextTag()) + assertEquals(XmlToken.BeginElement(2, "Child1"), s1.tag) + val s1Elements = listOf( + XmlToken.BeginElement(3, "x"), + XmlToken.Text(3, "1"), + XmlToken.EndElement(3, "x"), + XmlToken.BeginElement(3, "y"), + XmlToken.Text(3, "2"), + XmlToken.EndElement(3, "y"), + ) + assertEquals(s1Elements, s1.allTokens()) + + val s2 = assertNotNull(scoped.nextTag()) + assertEquals(XmlToken.BeginElement(2, "Child2"), s2.tag) + + val aReader = assertNotNull(s2.nextTag()) + assertEquals(XmlToken.BeginElement(3, "a"), aReader.tag) + assertNull(aReader.nextTag()) + + val bReader = assertNotNull(s2.nextTag()) + assertEquals(XmlToken.BeginElement(3, "b"), bReader.tag) + assertEquals(XmlToken.Text(3, "4"), bReader.nextToken()) + assertNull(bReader.nextToken()) + bReader.drop() + + // self close token behavior + val selfCloseReader = assertNotNull(scoped.nextTag()) + assertEquals(emptyList(), selfCloseReader.allTokens()) + selfCloseReader.drop() + + val s4 = assertNotNull(scoped.nextTag()) + assertEquals(XmlToken.BeginElement(2, "Child4"), s4.tag) + } + + @Test + fun testData() { + val payload = """ + + + 1 + 2 + + + this is an a + decoder should skip + + ignored a + ignored b + ignored c + + + + + + + + """.encodeToByteArray() + + val decoder = xmlTagReader(payload) + loop@while (true) { + val curr = decoder.nextTag() ?: break@loop + when (curr.tagName) { + "Child1" -> { + assertEquals(1, curr.nextTag()?.data()?.parseInt()?.getOrNull()) + assertEquals(2, curr.nextTag()?.data()?.parseInt()?.getOrNull()) + } + "Child2" -> { + assertEquals("this is an a", curr.nextTag()?.data()) + // intentionally ignore the next tag and don't consume the entire child subtree + } + "Child4" -> assertEquals(" ", curr.nextTag()?.data()) + else -> {} + } + // consume the current tag entirely before trying to process the next + curr.drop() + } + } +} + +fun XmlTagReader.allTokens(): List { + val tokenList = mutableListOf() + var nextToken: XmlToken? + do { + nextToken = this.nextToken() + if (nextToken != null) tokenList.add(nextToken) + } while (nextToken != null) + + return tokenList +} diff --git a/runtime/smithy-test/common/src/aws/smithy/kotlin/runtime/smithy/test/XmlAssertions.kt b/runtime/smithy-test/common/src/aws/smithy/kotlin/runtime/smithy/test/XmlAssertions.kt index f0c964a90..5d75333bb 100644 --- a/runtime/smithy-test/common/src/aws/smithy/kotlin/runtime/smithy/test/XmlAssertions.kt +++ b/runtime/smithy-test/common/src/aws/smithy/kotlin/runtime/smithy/test/XmlAssertions.kt @@ -14,7 +14,7 @@ import kotlin.test.assertEquals /** * Assert XML strings for equality ignoring key order */ -public suspend fun assertXmlStringsEqual(expected: String, actual: String) { +public fun assertXmlStringsEqual(expected: String, actual: String) { // parse into a dom representation and sort the dom into a canonical form for comparison val expectedNode = XmlNode.parse(expected.encodeToByteArray()).apply { toCanonicalForm() } val actualNode = XmlNode.parse(actual.encodeToByteArray()).apply { toCanonicalForm() } diff --git a/settings.gradle.kts b/settings.gradle.kts index 4608e6c65..e9e3d91d2 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -73,11 +73,12 @@ include(":tests") include(":tests:benchmarks:aws-signing-benchmarks") include(":tests:benchmarks:channel-benchmarks") include(":tests:benchmarks:http-benchmarks") -include(":tests:benchmarks:serde-benchmarks-codegen") include(":tests:benchmarks:serde-benchmarks") include(":tests:compile") include(":tests:codegen:nullability-tests") include(":tests:codegen:paginator-tests") +include(":tests:codegen:serde-tests") +include(":tests:codegen:serde-codegen-support") include(":tests:codegen:waiter-tests") include(":tests:integration:slf4j-1x-consumer") include(":tests:integration:slf4j-2x-consumer") diff --git a/tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeBenchmarkJsonProtocol.kt b/tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeBenchmarkJsonProtocol.kt deleted file mode 100644 index 58dd3abb7..000000000 --- a/tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeBenchmarkJsonProtocol.kt +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package software.amazon.smithy.kotlin.codegen.protocols.json - -import software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration -import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator -import software.amazon.smithy.model.shapes.ShapeId - -/** - * Dummy protocol for use in serde-benchmark project models. Generates JSON based serializers/deserializers - */ -class SerdeBenchmarkJsonProtocol : KotlinIntegration { - companion object { - val ID: ShapeId = ShapeId.from("aws.benchmarks.protocols#serdeBenchmarkJson") - } - - override val protocolGenerators: List = listOf(SerdeBenchmarkJsonProtocolGenerator) -} diff --git a/tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeBenchmarkXmlProtocol.kt b/tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeBenchmarkXmlProtocol.kt deleted file mode 100644 index c2d940daf..000000000 --- a/tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeBenchmarkXmlProtocol.kt +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package software.amazon.smithy.kotlin.codegen.protocols.xml - -import software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration -import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator -import software.amazon.smithy.model.shapes.ShapeId - -/** - * Dummy protocol for use in serde-benchmark project models. Generates XML-based serializers/deserializers. - */ -class SerdeBenchmarkXmlProtocol : KotlinIntegration { - companion object { - val ID: ShapeId = ShapeId.from("aws.benchmarks.protocols#serdeBenchmarkXml") - } - - override val protocolGenerators: List = listOf(SerdeBenchmarkXmlProtocolGenerator) -} diff --git a/tests/benchmarks/serde-benchmarks-codegen/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration b/tests/benchmarks/serde-benchmarks-codegen/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration deleted file mode 100644 index 88d7b7750..000000000 --- a/tests/benchmarks/serde-benchmarks-codegen/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration +++ /dev/null @@ -1,2 +0,0 @@ -software.amazon.smithy.kotlin.codegen.protocols.json.SerdeBenchmarkJsonProtocol -software.amazon.smithy.kotlin.codegen.protocols.xml.SerdeBenchmarkXmlProtocol diff --git a/tests/benchmarks/serde-benchmarks/README.md b/tests/benchmarks/serde-benchmarks/README.md index a618a0e35..6cca8e91f 100644 --- a/tests/benchmarks/serde-benchmarks/README.md +++ b/tests/benchmarks/serde-benchmarks/README.md @@ -8,20 +8,20 @@ This project contains micro benchmarks for the serialization implementation(s). ./gradlew :runtime:serde:serde-benchmarks:jvmBenchmark ``` -Baseline `0.7.8-beta` on EC2 **[m5.4xlarge](https://aws.amazon.com/ec2/instance-types/m5/)** in **OpenJK 1.8.0_312**: +Baseline on EC2 **[m5.4xlarge](https://aws.amazon.com/ec2/instance-types/m5/)** in **Corretto-17.0.10.8.1**: ``` jvm summary: -Benchmark (sourceFilename) Mode Cnt Score Error Units -a.s.k.b.s.json.CitmBenchmark.tokensBenchmark N/A avgt 5 12.530 ± 0.611 ms/op -a.s.k.b.s.json.TwitterBenchmark.deserializeBenchmark N/A avgt 5 10.148 ± 7.515 ms/op -a.s.k.b.s.json.TwitterBenchmark.serializeBenchmark N/A avgt 5 1.534 ± 1.608 ms/op -a.s.k.b.s.json.TwitterBenchmark.tokensBenchmark N/A avgt 5 6.381 ± 3.615 ms/op -a.s.k.b.s.xml.BufferStreamWriterBenchmark.serializeBenchmark N/A avgt 5 11.746 ± 0.262 ms/op -a.s.k.b.s.xml.XmlDeserializerBenchmark.deserializeBenchmark N/A avgt 5 90.697 ± 1.178 ms/op -a.s.k.b.s.xml.XmlLexerBenchmark.deserializeBenchmark countries-states.xml avgt 5 22.665 ± 0.473 ms/op -a.s.k.b.s.xml.XmlLexerBenchmark.deserializeBenchmark kotlin-article.xml avgt 5 0.734 ± 0.017 ms/op -a.s.k.b.s.xml.XmlSerializerBenchmark.serializeBenchmark N/A avgt 5 27.324 ± 31.331 ms/op +Benchmark (sourceFilename) Mode Cnt Score Error Units +a.s.k.b.s.json.CitmBenchmark.tokensBenchmark N/A avgt 5 10.066 ± 0.033 ms/op +a.s.k.b.s.json.TwitterBenchmark.deserializeBenchmark N/A avgt 5 7.295 ± 0.033 ms/op +a.s.k.b.s.json.TwitterBenchmark.serializeBenchmark N/A avgt 5 1.498 ± 0.026 ms/op +a.s.k.b.s.json.TwitterBenchmark.tokensBenchmark N/A avgt 5 4.431 ± 0.029 ms/op +a.s.k.b.s.xml.BufferStreamWriterBenchmark.serializeBenchmark N/A avgt 5 10.540 ± 0.134 ms/op +a.s.k.b.s.xml.XmlDeserializerBenchmark.deserializeBenchmark N/A avgt 5 33.566 ± 0.074 ms/op +a.s.k.b.s.xml.XmlLexerBenchmark.deserializeBenchmark countries-states.xml avgt 5 25.200 ± 0.079 ms/op +a.s.k.b.s.xml.XmlLexerBenchmark.deserializeBenchmark kotlin-article.xml avgt 5 0.846 ± 0.003 ms/op +a.s.k.b.s.xml.XmlSerializerBenchmark.serializeBenchmark N/A avgt 5 21.714 ± 0.385 ms/op ``` ## JSON Data @@ -44,7 +44,7 @@ Raw data was imported from multiple sources: ## Benchmarks -The `model` folder contains hand rolled Smithy models for some of the benchmarks. The `smithy-benchmarks-codegen` project -contains the codegen support to generate these models. +The `model` folder contains hand rolled Smithy models for some of the benchmarks. +The `tests/codegen/serde-codegen-support` module contains the codegen support to generate these models. These models are generated as part of the build. Until you run `assemble` you may see errors in your IDE. \ No newline at end of file diff --git a/tests/benchmarks/serde-benchmarks/build.gradle.kts b/tests/benchmarks/serde-benchmarks/build.gradle.kts index 61e89b8ea..a604a67ec 100644 --- a/tests/benchmarks/serde-benchmarks/build.gradle.kts +++ b/tests/benchmarks/serde-benchmarks/build.gradle.kts @@ -75,13 +75,14 @@ afterEvaluate { val codegen by configurations.getting dependencies { - codegen(project(":tests:benchmarks:serde-benchmarks-codegen")) + codegen(project(":tests:codegen:serde-codegen-support")) codegen(libs.smithy.cli) codegen(libs.smithy.model) } tasks.generateSmithyProjections { smithyBuildConfigs.set(files("smithy-build.json")) + buildClasspath.set(codegen) } data class BenchmarkModel(val name: String) { diff --git a/tests/benchmarks/serde-benchmarks/jvm/src/aws/smithy/kotlin/benchmarks/serde/xml/XmlDeserializerBenchmark.kt b/tests/benchmarks/serde-benchmarks/jvm/src/aws/smithy/kotlin/benchmarks/serde/xml/XmlDeserializerBenchmark.kt index 90d4a6c5e..4b9b8ee97 100644 --- a/tests/benchmarks/serde-benchmarks/jvm/src/aws/smithy/kotlin/benchmarks/serde/xml/XmlDeserializerBenchmark.kt +++ b/tests/benchmarks/serde-benchmarks/jvm/src/aws/smithy/kotlin/benchmarks/serde/xml/XmlDeserializerBenchmark.kt @@ -7,7 +7,7 @@ package aws.smithy.kotlin.benchmarks.serde.xml import aws.smithy.kotlin.benchmarks.serde.BenchmarkBase import aws.smithy.kotlin.benchmarks.serde.xml.countriesstates.model.CountriesAndStates import aws.smithy.kotlin.benchmarks.serde.xml.countriesstates.serde.deserializeCountriesAndStatesDocument -import aws.smithy.kotlin.runtime.serde.xml.XmlDeserializer +import aws.smithy.kotlin.runtime.serde.xml.xmlTagReader import kotlinx.benchmark.* import kotlinx.coroutines.runBlocking @@ -18,7 +18,7 @@ open class XmlDeserializerBenchmark : BenchmarkBase() { private fun deserialize(): CountriesAndStates = runBlocking { - val deserializer = XmlDeserializer(source) + val deserializer = xmlTagReader(source) deserializeCountriesAndStatesDocument(deserializer) } diff --git a/tests/benchmarks/serde-benchmarks/jvm/src/aws/smithy/kotlin/benchmarks/serde/xml/XmlSerializerBenchmark.kt b/tests/benchmarks/serde-benchmarks/jvm/src/aws/smithy/kotlin/benchmarks/serde/xml/XmlSerializerBenchmark.kt index e9a819446..6d46e43d7 100644 --- a/tests/benchmarks/serde-benchmarks/jvm/src/aws/smithy/kotlin/benchmarks/serde/xml/XmlSerializerBenchmark.kt +++ b/tests/benchmarks/serde-benchmarks/jvm/src/aws/smithy/kotlin/benchmarks/serde/xml/XmlSerializerBenchmark.kt @@ -8,8 +8,8 @@ import aws.smithy.kotlin.benchmarks.serde.BenchmarkBase import aws.smithy.kotlin.benchmarks.serde.xml.countriesstates.model.CountriesAndStates import aws.smithy.kotlin.benchmarks.serde.xml.countriesstates.serde.deserializeCountriesAndStatesDocument import aws.smithy.kotlin.benchmarks.serde.xml.countriesstates.serde.serializeCountriesAndStatesDocument -import aws.smithy.kotlin.runtime.serde.xml.XmlDeserializer import aws.smithy.kotlin.runtime.serde.xml.XmlSerializer +import aws.smithy.kotlin.runtime.serde.xml.xmlTagReader import kotlinx.benchmark.* import kotlinx.coroutines.runBlocking @@ -21,7 +21,7 @@ open class XmlSerializerBenchmark : BenchmarkBase() { @Setup fun init() { dataSet = runBlocking { - val deserializer = XmlDeserializer(source) + val deserializer = xmlTagReader(source) deserializeCountriesAndStatesDocument(deserializer) } } diff --git a/tests/benchmarks/serde-benchmarks/model/countriesstates.smithy b/tests/benchmarks/serde-benchmarks/model/countriesstates.smithy index 25dcd16d4..88e5c06e9 100644 --- a/tests/benchmarks/serde-benchmarks/model/countriesstates.smithy +++ b/tests/benchmarks/serde-benchmarks/model/countriesstates.smithy @@ -2,9 +2,9 @@ $version: "1.0" namespace aws.benchmarks.countries_states -use aws.benchmarks.protocols#serdeBenchmarkXml +use aws.serde.protocols#serdeXml -@serdeBenchmarkXml +@serdeXml service CountriesStatesService { version: "2019-12-16", operations: [GetCountriesAndStates] diff --git a/tests/benchmarks/serde-benchmarks/model/serde-protocols.smithy b/tests/benchmarks/serde-benchmarks/model/serde-protocols.smithy deleted file mode 100644 index b14ea9bf9..000000000 --- a/tests/benchmarks/serde-benchmarks/model/serde-protocols.smithy +++ /dev/null @@ -1,13 +0,0 @@ -$version: "1.0" - -namespace aws.benchmarks.protocols - -// dummy protocols just for benchmarking purposes - -@protocolDefinition -@trait -structure serdeBenchmarkJson{} - -@protocolDefinition -@trait -structure serdeBenchmarkXml{} diff --git a/tests/benchmarks/serde-benchmarks/model/twitter.smithy b/tests/benchmarks/serde-benchmarks/model/twitter.smithy index b11264123..d57f03115 100644 --- a/tests/benchmarks/serde-benchmarks/model/twitter.smithy +++ b/tests/benchmarks/serde-benchmarks/model/twitter.smithy @@ -2,9 +2,9 @@ $version: "1.0" namespace aws.benchmarks.twitter -use aws.benchmarks.protocols#serdeBenchmarkJson +use aws.serde.protocols#serdeJson -@serdeBenchmarkJson +@serdeJson service Twitter { version: "2019-12-16", operations: [GetFeed] diff --git a/tests/benchmarks/serde-benchmarks-codegen/build.gradle.kts b/tests/codegen/serde-codegen-support/build.gradle.kts similarity index 81% rename from tests/benchmarks/serde-benchmarks-codegen/build.gradle.kts rename to tests/codegen/serde-codegen-support/build.gradle.kts index 3a21d75ba..f5551afa4 100644 --- a/tests/benchmarks/serde-benchmarks-codegen/build.gradle.kts +++ b/tests/codegen/serde-codegen-support/build.gradle.kts @@ -9,7 +9,7 @@ plugins { skipPublishing() -description = "Codegen support for serde-benchmarks project" +description = "Codegen support for serde related integration tests" dependencies { implementation(project(":codegen:smithy-kotlin-codegen")) diff --git a/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/ProtocolSupplier.kt b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/ProtocolSupplier.kt new file mode 100644 index 000000000..f08c06908 --- /dev/null +++ b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/ProtocolSupplier.kt @@ -0,0 +1,15 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.kotlin.codegen.protocols + +import software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration +import software.amazon.smithy.kotlin.codegen.protocols.json.SerdeJsonProtocolGenerator +import software.amazon.smithy.kotlin.codegen.protocols.xml.SerdeXmlProtocolGenerator +import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator + +class ProtocolSupplier : KotlinIntegration { + override val protocolGenerators: List + get() = listOf(SerdeJsonProtocolGenerator, SerdeXmlProtocolGenerator) +} diff --git a/tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/BenchmarkProtocolGenerator.kt b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/SerdeProtocolGenerator.kt similarity index 92% rename from tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/BenchmarkProtocolGenerator.kt rename to tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/SerdeProtocolGenerator.kt index c8d3a1765..a1d85f729 100644 --- a/tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/BenchmarkProtocolGenerator.kt +++ b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/SerdeProtocolGenerator.kt @@ -13,7 +13,7 @@ import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.traits.TimestampFormatTrait -abstract class BenchmarkProtocolGenerator : HttpBindingProtocolGenerator() { +abstract class SerdeProtocolGenerator : HttpBindingProtocolGenerator() { abstract val contentTypes: ProtocolContentTypes override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS @@ -38,7 +38,7 @@ abstract class BenchmarkProtocolGenerator : HttpBindingProtocolGenerator() { RuntimeTypes.Core.ExecutionContext, RuntimeTypes.Http.HttpCall, ) { - write("error(\"not needed for benchmark tests\")") + write("error(\"not needed for codegen related tests\")") } } } diff --git a/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeJsonProtocol.kt b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeJsonProtocol.kt new file mode 100644 index 000000000..8f91830d9 --- /dev/null +++ b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeJsonProtocol.kt @@ -0,0 +1,22 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.kotlin.codegen.protocols.json + +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.node.ObjectNode +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.traits.AnnotationTrait + +/** + * Dummy protocol for use in serde-benchmark project models. Generates JSON based serializers/deserializers + */ +class SerdeJsonProtocol : AnnotationTrait { + companion object { + val ID: ShapeId = ShapeId.from("aws.serde.protocols#serdeJson") + class Provider : AnnotationTrait.Provider(ID, ::SerdeJsonProtocol) + } + constructor(node: ObjectNode) : super(ID, node) + constructor() : this(Node.objectNode()) +} diff --git a/tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeBenchmarkJsonProtocolGenerator.kt b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeJsonProtocolGenerator.kt similarity index 71% rename from tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeBenchmarkJsonProtocolGenerator.kt rename to tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeJsonProtocolGenerator.kt index fe6ae04e4..30543899b 100644 --- a/tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeBenchmarkJsonProtocolGenerator.kt +++ b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/json/SerdeJsonProtocolGenerator.kt @@ -4,17 +4,17 @@ */ package software.amazon.smithy.kotlin.codegen.protocols.json -import software.amazon.smithy.kotlin.codegen.protocols.BenchmarkProtocolGenerator +import software.amazon.smithy.kotlin.codegen.protocols.SerdeProtocolGenerator import software.amazon.smithy.kotlin.codegen.rendering.protocol.* import software.amazon.smithy.kotlin.codegen.rendering.serde.* import software.amazon.smithy.model.shapes.ShapeId /** - * Protocol generator for benchmark protocol [SerdeBenchmarkJsonProtocol] + * Protocol generator for benchmark protocol [SerdeJsonProtocol] */ -object SerdeBenchmarkJsonProtocolGenerator : BenchmarkProtocolGenerator() { +object SerdeJsonProtocolGenerator : SerdeProtocolGenerator() { override val contentTypes = ProtocolContentTypes.consistent("application/json") - override val protocol: ShapeId = SerdeBenchmarkJsonProtocol.ID + override val protocol: ShapeId = SerdeJsonProtocol.ID override fun structuredDataSerializer(ctx: ProtocolGenerator.GenerationContext): StructuredDataSerializerGenerator = JsonSerializerGenerator(this) diff --git a/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeXmlProtocol.kt b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeXmlProtocol.kt new file mode 100644 index 000000000..87e209dd6 --- /dev/null +++ b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeXmlProtocol.kt @@ -0,0 +1,23 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.kotlin.codegen.protocols.xml + +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.node.ObjectNode +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.traits.AnnotationTrait + +/** + * Dummy protocol for use in testing projects that need to test XML codegen. Generates XML-based serializers/deserializers. + */ +class SerdeXmlProtocol : AnnotationTrait { + companion object { + val ID: ShapeId = ShapeId.from("aws.serde.protocols#serdeXml") + class Provider : AnnotationTrait.Provider(ID, ::SerdeXmlProtocol) + } + + constructor(node: ObjectNode) : super(ID, node) + constructor() : this(Node.objectNode()) +} diff --git a/tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeBenchmarkXmlProtocolGenerator.kt b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeXmlProtocolGenerator.kt similarity index 74% rename from tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeBenchmarkXmlProtocolGenerator.kt rename to tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeXmlProtocolGenerator.kt index 9b421b6eb..8a30a7979 100644 --- a/tests/benchmarks/serde-benchmarks-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeBenchmarkXmlProtocolGenerator.kt +++ b/tests/codegen/serde-codegen-support/src/main/kotlin/software/amazon/smithy/kotlin/codegen/protocols/xml/SerdeXmlProtocolGenerator.kt @@ -4,7 +4,7 @@ */ package software.amazon.smithy.kotlin.codegen.protocols.xml -import software.amazon.smithy.kotlin.codegen.protocols.BenchmarkProtocolGenerator +import software.amazon.smithy.kotlin.codegen.protocols.SerdeProtocolGenerator import software.amazon.smithy.kotlin.codegen.rendering.protocol.* import software.amazon.smithy.kotlin.codegen.rendering.serde.StructuredDataParserGenerator import software.amazon.smithy.kotlin.codegen.rendering.serde.StructuredDataSerializerGenerator @@ -13,14 +13,14 @@ import software.amazon.smithy.kotlin.codegen.rendering.serde.XmlSerializerGenera import software.amazon.smithy.model.shapes.ShapeId /** - * Protocol generator for benchmark protocol [SerdeBenchmarkXmlProtocol]. + * Protocol generator for testing [SerdeXmlProtocol]. */ -object SerdeBenchmarkXmlProtocolGenerator : BenchmarkProtocolGenerator() { +object SerdeXmlProtocolGenerator : SerdeProtocolGenerator() { override val contentTypes = ProtocolContentTypes.consistent("application/xml") - override val protocol: ShapeId = SerdeBenchmarkXmlProtocol.ID + override val protocol: ShapeId = SerdeXmlProtocol.ID override fun structuredDataParser(ctx: ProtocolGenerator.GenerationContext): StructuredDataParserGenerator = - XmlParserGenerator(this, defaultTimestampFormat) + XmlParserGenerator(defaultTimestampFormat) override fun structuredDataSerializer(ctx: ProtocolGenerator.GenerationContext): StructuredDataSerializerGenerator = XmlSerializerGenerator(this, defaultTimestampFormat) diff --git a/tests/codegen/serde-codegen-support/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration b/tests/codegen/serde-codegen-support/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration new file mode 100644 index 000000000..8d5752fb7 --- /dev/null +++ b/tests/codegen/serde-codegen-support/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration @@ -0,0 +1 @@ +software.amazon.smithy.kotlin.codegen.protocols.ProtocolSupplier diff --git a/tests/codegen/serde-codegen-support/src/main/resources/META-INF/services/software.amazon.smithy.model.traits.TraitService b/tests/codegen/serde-codegen-support/src/main/resources/META-INF/services/software.amazon.smithy.model.traits.TraitService new file mode 100644 index 000000000..895c9d9c8 --- /dev/null +++ b/tests/codegen/serde-codegen-support/src/main/resources/META-INF/services/software.amazon.smithy.model.traits.TraitService @@ -0,0 +1,2 @@ +software.amazon.smithy.kotlin.codegen.protocols.json.SerdeJsonProtocol$Companion$Provider +software.amazon.smithy.kotlin.codegen.protocols.xml.SerdeXmlProtocol$Companion$Provider \ No newline at end of file diff --git a/tests/codegen/serde-codegen-support/src/main/resources/META-INF/smithy/manifest b/tests/codegen/serde-codegen-support/src/main/resources/META-INF/smithy/manifest new file mode 100644 index 000000000..31b96587d --- /dev/null +++ b/tests/codegen/serde-codegen-support/src/main/resources/META-INF/smithy/manifest @@ -0,0 +1 @@ +protocols.smithy \ No newline at end of file diff --git a/tests/codegen/serde-codegen-support/src/main/resources/META-INF/smithy/protocols.smithy b/tests/codegen/serde-codegen-support/src/main/resources/META-INF/smithy/protocols.smithy new file mode 100644 index 000000000..2f4413170 --- /dev/null +++ b/tests/codegen/serde-codegen-support/src/main/resources/META-INF/smithy/protocols.smithy @@ -0,0 +1,13 @@ +$version: "2.0" + +namespace aws.serde.protocols + +// dummy protocols just for testing/benchmarking purposes + +@protocolDefinition +@trait +structure serdeJson{} + +@protocolDefinition +@trait +structure serdeXml{} diff --git a/tests/codegen/serde-tests/.gitignore b/tests/codegen/serde-tests/.gitignore new file mode 100644 index 000000000..fc706388c --- /dev/null +++ b/tests/codegen/serde-tests/.gitignore @@ -0,0 +1 @@ +generated-src \ No newline at end of file diff --git a/tests/codegen/serde-tests/build.gradle.kts b/tests/codegen/serde-tests/build.gradle.kts new file mode 100644 index 000000000..ee2f94ab5 --- /dev/null +++ b/tests/codegen/serde-tests/build.gradle.kts @@ -0,0 +1,98 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +import aws.sdk.kotlin.gradle.codegen.smithyKotlinProjectionSrcDir +import aws.sdk.kotlin.gradle.dsl.skipPublishing + +plugins { + alias(libs.plugins.kotlin.jvm) + alias(libs.plugins.aws.kotlin.repo.tools.smithybuild) +} + +skipPublishing() + +val codegen by configurations.getting +dependencies { + codegen(project(":codegen:smithy-kotlin-codegen")) + codegen(project(":tests:codegen:serde-codegen-support")) + codegen(libs.smithy.cli) + codegen(libs.smithy.model) +} + +tasks.generateSmithyProjections { + smithyBuildConfigs.set(files("smithy-build.json")) + inputs.dir(project.layout.projectDirectory.dir("model")) + listOf("xml", "json").forEach { projectionName -> + val fromDir = smithyBuild.smithyKotlinProjectionSrcDir(projectionName) + outputs.dir(fromDir) + } + buildClasspath.set(codegen) +} + +val optinAnnotations = listOf("kotlin.RequiresOptIn", "aws.smithy.kotlin.runtime.InternalApi") +kotlin.sourceSets.all { + optinAnnotations.forEach { languageSettings.optIn(it) } +} + +tasks.test { + useJUnitPlatform() + testLogging { + events("passed", "skipped", "failed") + showStandardStreams = true + } +} + +dependencies { + compileOnly(project(":codegen:smithy-kotlin-codegen")) + + implementation(libs.kotlinx.coroutines.core) + implementation(project(":runtime:runtime-core")) + implementation(project(":runtime:serde")) + implementation(project(":runtime:serde:serde-json")) + implementation(project(":runtime:serde:serde-xml")) + implementation(project(":runtime:smithy-test")) + + testImplementation(libs.kotlin.test.junit5) +} + +val generatedSrcDir = project.layout.projectDirectory.dir("generated-src/main/kotlin") + +val stageGeneratedSources = tasks.register("stageGeneratedSources") { + group = "codegen" + dependsOn(tasks.generateSmithyProjections) + outputs.dir(generatedSrcDir) + // FIXME - this task up-to-date checks are wrong, likely something is not setup right with inputs/outputs somewhere + // for now just always run it + outputs.upToDateWhen { false } + doLast { + listOf("xml", "json").forEach { projectionName -> + val fromDir = smithyBuild.smithyKotlinProjectionSrcDir(projectionName) + logger.info("copying from ${fromDir.get()} to $generatedSrcDir") + copy { + from(fromDir) + into(generatedSrcDir) + include("**/model/*.kt") + include("**/serde/*.kt") + exclude("**/auth/*.kt") + exclude("**/endpoints/**.kt") + exclude("**/serde/*OperationSerializer.kt") + exclude("**/serde/*OperationDeserializer.kt") + } + } + } +} + +kotlin.sourceSets.getByName("main") { + kotlin.srcDir(generatedSrcDir) +} + +tasks.withType { + dependsOn(stageGeneratedSources) + // generated code has warnings unfortunately, see https://github.com/awslabs/aws-sdk-kotlin/issues/1169 + kotlinOptions.allWarningsAsErrors = false +} + +tasks.clean.configure { + delete(project.layout.projectDirectory.dir("generated-src")) +} diff --git a/tests/codegen/serde-tests/model/shared.smithy b/tests/codegen/serde-tests/model/shared.smithy new file mode 100644 index 000000000..042ce8c09 --- /dev/null +++ b/tests/codegen/serde-tests/model/shared.smithy @@ -0,0 +1,168 @@ +$version: "2.0" + +namespace aws.tests.serde.shared + +list StringList { + member: String, +} + +@sparse +list SparseStringList { + member: String +} + +map StringMap { + key: String, + value: String, +} + +map StringListMap { + key: String, + value: StringList +} + +map NestedStringMap { + key: String, + value: StringMap +} + +@sparse +map SparseStringMap { + key: String, + value: String, +} + +list NestedStringList { + member: StringList, +} + +list IntegerList { + member: Integer, +} + +@uniqueItems +list IntegerSet { + member: Integer, +} + +enum FooEnum { + FOO = "Foo" + BAZ = "Baz" + BAR = "Bar" + ONE = "1" + ZERO = "0" +} + +list FooEnumList { + member: FooEnum, +} + +map FooEnumMap { + key: String, + value: FooEnum, +} + +map FooEnumKeyMap { + key: FooEnum, + value: Integer +} + +@timestampFormat("date-time") +timestamp DateTime + +@timestampFormat("epoch-seconds") +timestamp EpochSeconds + +@timestampFormat("http-date") +timestamp HttpDate + +intEnum IntegerEnum { + A = 1 + B = 2 + C = 3 +} + +list IntegerEnumList { + member: IntegerEnum +} + +map IntegerEnumMap { + key: String, + value: IntegerEnum +} + + +@mixin +structure PrimitiveTypesMixin { + strField: String, + byteField: Byte, + intField: Integer, + shortField: Short, + longField: Long, + floatField: Float, + doubleField: Double, + bigIntegerField: BigInteger, + bigDecimalField: BigDecimal, + boolField: Boolean, + blobField: Blob, + enumField: FooEnum, + intEnumField: IntegerEnum, + dateTimeField: DateTime, + epochTimeField: EpochSeconds, + httpTimeField: HttpDate, +} + +@mixin +union PrimitiveTypesUnionMixin { + strField: String, + byteField: Byte, + intField: Integer, + shortField: Short, + longField: Long, + floatField: Float, + doubleField: Double, + bigIntegerField: BigInteger, + bigDecimalField: BigDecimal, + boolField: Boolean, + blobField: Blob, + enumField: FooEnum, + intEnumField: IntegerEnum, + dateTimeField: DateTime, + epochTimeField: EpochSeconds, + httpTimeField: HttpDate, + unitField: Unit +} + +@mixin +structure MapTypesMixin { + normalMap: StringMap, + sparseMap: SparseStringMap, + nestedMap: NestedStringMap, + listMap: StringListMap, + enumValueMap: FooEnumMap, + enumKeyMap: FooEnumKeyMap, +} + +@mixin +union MapTypesUnionMixin { + normalMap: StringMap, + sparseMap: SparseStringMap, + nestedMap: NestedStringMap, +} + +@mixin +structure ListTypesMixin { + normalList: StringList, + sparseList: SparseStringList, + nestedList: NestedStringList, +} + +@mixin +union ListTypesUnionMixin { + normalList: StringList, + + sparseList: SparseStringList, + + nestedList: NestedStringList, +} + diff --git a/tests/codegen/serde-tests/model/xml.smithy b/tests/codegen/serde-tests/model/xml.smithy new file mode 100644 index 000000000..c06fd0ac2 --- /dev/null +++ b/tests/codegen/serde-tests/model/xml.smithy @@ -0,0 +1,86 @@ +$version: "2.0" + +namespace aws.tests.serde.xml + +use aws.serde.protocols#serdeXml +use aws.tests.serde.shared#PrimitiveTypesMixin +use aws.tests.serde.shared#ListTypesMixin +use aws.tests.serde.shared#MapTypesMixin +use aws.tests.serde.shared#PrimitiveTypesUnionMixin +use aws.tests.serde.shared#ListTypesUnionMixin +use aws.tests.serde.shared#MapTypesUnionMixin +use aws.tests.serde.shared#StringMap +use aws.tests.serde.shared#StringListMap +use aws.tests.serde.shared#NestedStringMap +use aws.tests.serde.shared#FooEnumMap +use aws.tests.serde.shared#IntegerList +use aws.tests.serde.shared#StringList +use aws.tests.serde.shared#NestedStringList + +@serdeXml +service XmlService { + version: "2022-07-07", + operations: [TestOp] +} + +@http(uri: "/top", method: "POST") +operation TestOp { + input: StructType, + output: StructType, +} + +structure StructType with [PrimitiveTypesMixin, ListTypesMixin, MapTypesMixin] { + unionField: UnionType, + + recursive: StructType, + + @xmlAttribute + extra: Long, + + @xmlName("prefix:local") + renamedWithPrefix: String, + + @xmlFlattened + @xmlName("flatlist1") + flatList: StringList, + + @xmlFlattened + @xmlName("flatlist2") + secondFlatList: IntegerList + + @xmlFlattened + @xmlName("flatenummap") + flatEnumMap: FooEnumMap, + + renamedMemberList: RenamedMemberIntList + + renamedMemberMap: RenamedMap +} + +list RenamedMemberIntList { + @xmlName("item") + member: String +} + +map RenamedMap { + @xmlName("aKey") + key: String + + @xmlName("aValue") + value: String +} + +union UnionType with [PrimitiveTypesUnionMixin, ListTypesUnionMixin, MapTypesUnionMixin] { + @xmlFlattened + @xmlName("flatmap") + flatMap: StringMap, + + @xmlFlattened + @xmlName("flatlist") + flatList: StringList, + + @xmlName("double") + fpDouble: Double, + + struct: StructType, +} diff --git a/tests/codegen/serde-tests/smithy-build.json b/tests/codegen/serde-tests/smithy-build.json new file mode 100644 index 000000000..1ea49d608 --- /dev/null +++ b/tests/codegen/serde-tests/smithy-build.json @@ -0,0 +1,31 @@ +{ + "version": "1.0", + "sources": ["model"], + "projections": { + "xml": { + "transforms": [ + { + "name": "includeServices", + "args": { + "services": [ + "aws.tests.serde.xml#XmlService" + ] + } + } + ], + "plugins": { + "kotlin-codegen": { + "service": "aws.tests.serde.xml#XmlService", + "package": { + "name": "aws.smithy.kotlin.tests.serde.xml", + "version": "0.0.1" + }, + "build": { + "rootProject": false, + "generateDefaultBuildFiles": false + } + } + } + } + } +} diff --git a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/AbstractXmlTest.kt b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/AbstractXmlTest.kt new file mode 100644 index 000000000..2e65cb4e5 --- /dev/null +++ b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/AbstractXmlTest.kt @@ -0,0 +1,27 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.tests.serde + +import aws.smithy.kotlin.runtime.serde.xml.* +import aws.smithy.kotlin.runtime.smithy.test.assertXmlStringsEqual +import kotlin.test.assertEquals + +abstract class AbstractXmlTest { + fun testRoundTrip( + expected: T, + payload: String, + serializerFn: (XmlSerializer, T) -> Unit, + deserializerFn: (XmlTagReader) -> T, + ) { + val serializer = XmlSerializer() + serializerFn(serializer, expected) + val actualPayload = serializer.toByteArray().decodeToString() + assertXmlStringsEqual(payload, actualPayload) + + val reader = xmlTagReader(payload.encodeToByteArray()) + val actualDeserialized = deserializerFn(reader) + assertEquals(expected, actualDeserialized) + } +} diff --git a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlListTest.kt b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlListTest.kt new file mode 100644 index 000000000..ca083d198 --- /dev/null +++ b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlListTest.kt @@ -0,0 +1,126 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.tests.serde + +import aws.smithy.kotlin.runtime.serde.xml.xmlTagReader +import aws.smithy.kotlin.tests.serde.xml.model.StructType +import aws.smithy.kotlin.tests.serde.xml.serde.deserializeStructTypeDocument +import aws.smithy.kotlin.tests.serde.xml.serde.serializeStructTypeDocument +import kotlin.test.Test +import kotlin.test.assertEquals + +class XmlListTest : AbstractXmlTest() { + @Test + fun testNormalList() { + val expected = StructType { + normalList = listOf("bar", "baz") + } + val payload = """ + + + bar + baz + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testSparseList() { + val expected = StructType { + sparseList = listOf("bar", null, "baz") + } + val payload = """ + + + bar + + baz + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testNestedList() { + val expected = StructType { + nestedList = listOf( + listOf("a", "b", "c"), + listOf("x", "y", "z"), + ) + } + val payload = """ + + + + a + b + c + + + x + y + z + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testListWithRenamedMember() { + val expected = StructType { + renamedMemberList = listOf("bar", "baz") + } + val payload = """ + + + bar + baz + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testFlatList() { + val expected = StructType { + flatList = listOf("foo", "bar") + } + val payload = """ + + foo + bar + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testInterspersedFlatLists() { + // see https://github.com/awslabs/aws-sdk-kotlin/issues/1220 + val expected = StructType { + flatList = listOf("foo", "bar") + secondFlatList = listOf(1, 2) + } + val payload = """ + + foo + 1 + bar + 2 + + """.trimIndent() + + // we don't round trip this because the format isn't going to match + val reader = xmlTagReader(payload.encodeToByteArray()) + val actualDeserialized = deserializeStructTypeDocument(reader) + assertEquals(expected, actualDeserialized) + } +} diff --git a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlMapTest.kt b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlMapTest.kt new file mode 100644 index 000000000..530df9186 --- /dev/null +++ b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlMapTest.kt @@ -0,0 +1,277 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.tests.serde + +import aws.smithy.kotlin.runtime.serde.xml.xmlTagReader +import aws.smithy.kotlin.tests.serde.xml.model.FooEnum +import aws.smithy.kotlin.tests.serde.xml.model.StructType +import aws.smithy.kotlin.tests.serde.xml.model.UnionType +import aws.smithy.kotlin.tests.serde.xml.serde.deserializeStructTypeDocument +import aws.smithy.kotlin.tests.serde.xml.serde.serializeStructTypeDocument +import kotlin.test.Test +import kotlin.test.assertEquals + +class XmlMapTest : AbstractXmlTest() { + @Test + fun testNormalMap() { + val expected = StructType { + normalMap = mapOf( + "foo" to "bar", + "baz" to "quux", + ) + } + val payload = """ + + + + foo + bar + + + baz + quux + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testSparseMap() { + val expected = StructType { + sparseMap = mapOf( + "foo" to "bar", + "null" to null, + "baz" to "quux", + ) + } + val payload = """ + + + + foo + bar + + + null + + + + baz + quux + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testNestedMap() { + val expected = StructType { + nestedMap = mapOf( + "foo" to mapOf( + "k1" to "v1", + "k2" to "v2", + ), + "bar" to mapOf( + "k3" to "v3", + "k4" to "v4", + ), + ) + } + val payload = """ + + + + foo + + + k1 + v1 + + + k2 + v2 + + + + + bar + + + k3 + v3 + + + k4 + v4 + + + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testMapWithRenamedMember() { + val expected = StructType { + renamedMemberMap = mapOf( + "foo" to "bar", + "baz" to "quux", + ) + } + val payload = """ + + + + foo + bar + + + baz + quux + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testFlatMap() { + val expected = StructType { + flatEnumMap = mapOf( + "foo" to FooEnum.Foo, + "bar" to FooEnum.Bar, + ) + } + val payload = """ + + + foo + Foo + + + bar + Bar + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testInterspersedFlatMaps() { + // see https://github.com/awslabs/aws-sdk-kotlin/issues/1220 + val expected = StructType { + flatEnumMap = mapOf( + "foo" to FooEnum.Foo, + "bar" to FooEnum.Bar, + ) + unionField = UnionType.Struct( + StructType { + normalMap = mapOf("k1" to "v1", "k2" to "v2") + flatEnumMap = mapOf("inner" to FooEnum.Baz) + }, + ) + } + val payload = """ + + + foo + Foo + + + + + + + k1 + v1 + + + k2 + v2 + + + + inner + Baz + + + + + bar + Bar + + + """.trimIndent() + + // we don't round trip this because the format isn't going to match + val reader = xmlTagReader(payload.encodeToByteArray()) + val actualDeserialized = deserializeStructTypeDocument(reader) + assertEquals(expected, actualDeserialized) + } + + @Test + fun testEnumValueMap() { + val expected = StructType { + enumValueMap = mapOf( + "foo" to FooEnum.Foo, + "bar" to FooEnum.Bar, + ) + } + val payload = """ + + + + foo + Foo + + + bar + Bar + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testEnumKeyMap() { + // see also https://github.com/awslabs/smithy-kotlin/issues/1045 + val expected = StructType { + enumKeyMap = mapOf( + FooEnum.Foo.value to 1, + "Bar" to 2, + "Unknown" to 3, + ) + } + val payload = """ + + + + Foo + 1 + + + Bar + 2 + + + Unknown + 3 + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } +} diff --git a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlStructTest.kt b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlStructTest.kt new file mode 100644 index 000000000..8fa03295c --- /dev/null +++ b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlStructTest.kt @@ -0,0 +1,114 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.tests.serde + +import aws.smithy.kotlin.runtime.content.BigInteger +import aws.smithy.kotlin.runtime.text.encoding.encodeBase64String +import aws.smithy.kotlin.runtime.time.Instant +import aws.smithy.kotlin.tests.serde.xml.model.FooEnum +import aws.smithy.kotlin.tests.serde.xml.model.IntegerEnum +import aws.smithy.kotlin.tests.serde.xml.model.StructType +import aws.smithy.kotlin.tests.serde.xml.serde.deserializeStructTypeDocument +import aws.smithy.kotlin.tests.serde.xml.serde.serializeStructTypeDocument +import java.math.BigDecimal +import kotlin.test.Test + +class XmlStructTest : AbstractXmlTest() { + @Test + fun testStructPrimitives() { + val expected = StructType { + strField = "a string" + byteField = 2.toByte() + intField = 3 + shortField = 4 + longField = 5L + floatField = 6.0f + doubleField = 7.1 + bigIntegerField = BigInteger("1234") + bigDecimalField = BigDecimal("1.234") + boolField = true + blobField = "blob field".encodeToByteArray() + enumField = FooEnum.Bar + intEnumField = IntegerEnum.C + dateTimeField = Instant.fromIso8601("2020-10-16T15:46:24.982Z") + epochTimeField = Instant.fromEpochSeconds(1657204347) + httpTimeField = Instant.fromRfc5322("Sat, 22 Jul 2017 19:30:00 GMT") + extra = 42 + } + + val base64BlobField = expected.blobField!!.encodeBase64String() + + val payload = """ + + a string + 2 + 3 + 4 + 5 + 6.0 + 7.1 + 1234 + 1.234 + true + $base64BlobField + Bar + 3 + 2020-10-16T15:46:24.982Z + 1657204347 + Sat, 22 Jul 2017 19:30:00 GMT + + """.trimIndent() + + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testRenamedMembers() { + val expected = StructType { + renamedWithPrefix = "foo" + flatList = listOf("bar", "baz") + } + val payload = """ + + foo + bar + baz + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testRecursiveType() { + val expected = StructType { + strField = "first" + recursive { + strField = "second" + extra = 42 + recursive { + strField = "third" + normalList = listOf("foo", "bar") + } + } + } + val payload = """ + + first + + second + + third + + foo + bar + + + + + """.trimIndent() + + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } +} diff --git a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlUnionTest.kt b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlUnionTest.kt new file mode 100644 index 000000000..5ec3f5446 --- /dev/null +++ b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlUnionTest.kt @@ -0,0 +1,304 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.tests.serde + +import aws.smithy.kotlin.runtime.time.Instant +import aws.smithy.kotlin.tests.serde.xml.model.StructType +import aws.smithy.kotlin.tests.serde.xml.model.UnionType +import aws.smithy.kotlin.tests.serde.xml.serde.deserializeStructTypeDocument +import aws.smithy.kotlin.tests.serde.xml.serde.serializeStructTypeDocument +import kotlin.test.Test + +class XmlUnionTest : AbstractXmlTest() { + @Test + fun testString() { + val expected = StructType { + unionField = UnionType.StrField("a string") + } + val payload = """ + + + a string + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testByte() { + val expected = StructType { + unionField = UnionType.ByteField(1) + } + val payload = """ + + + 1 + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testInt() { + val expected = StructType { + unionField = UnionType.IntField(1) + } + val payload = """ + + + 1 + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testLong() { + val expected = StructType { + unionField = UnionType.LongField(1) + } + val payload = """ + + + 1 + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testTimestamp() { + val expected = StructType { + unionField = UnionType.DateTimeField( + Instant.fromIso8601("2020-10-16T15:46:24.982Z"), + ) + } + val payload = """ + + + 2020-10-16T15:46:24.982Z + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testNormalList() { + val expected = StructType { + unionField = UnionType.NormalList(listOf("foo", "bar")) + } + + val payload = """ + + + + foo + bar + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testFlatList() { + val expected = StructType { + unionField = UnionType.FlatList(listOf("foo", "bar")) + } + + val payload = """ + + + foo + bar + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testNestedList() { + val expected = StructType { + unionField = UnionType.NestedList( + listOf( + listOf("a", "b", "c"), + listOf("x", "y", "z"), + ), + ) + } + + val payload = """ + + + + + a + b + c + + + x + y + z + + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testNormalMap() { + val expected = StructType { + unionField = UnionType.NormalMap( + mapOf( + "k1" to "v1", + "k2" to "v2", + ), + ) + } + val payload = """ + + + + + k1 + v1 + + + k2 + v2 + + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testNestedMap() { + val expected = StructType { + unionField = UnionType.NestedMap( + mapOf( + "foo" to mapOf( + "k1" to "v1", + "k2" to "v2", + ), + "bar" to mapOf( + "k3" to "v3", + "k4" to "v4", + ), + ), + ) + } + val payload = """ + + + + + foo + + + k1 + v1 + + + k2 + v2 + + + + + bar + + + k3 + v3 + + + k4 + v4 + + + + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + @Test + fun testFlatMap() { + val expected = StructType { + unionField = UnionType.FlatMap( + mapOf( + "foo" to "bar", + "bar" to "baz", + ), + ) + } + val payload = """ + + + + foo + bar + + + bar + baz + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } + + // FIXME - https://github.com/awslabs/smithy-kotlin/issues/1040 + // @Test + // fun testUnitField() { } + + @Test + fun testStruct() { + val expected = StructType { + unionField = UnionType.Struct( + StructType { + normalMap = mapOf("k1" to "v1", "k2" to "v2") + strField = "a string" + }, + ) + } + val payload = """ + + + + + + k1 + v1 + + + k2 + v2 + + + a string + + + + """.trimIndent() + testRoundTrip(expected, payload, ::serializeStructTypeDocument, ::deserializeStructTypeDocument) + } +}