Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: decrease generated artifact size #1057

Merged
merged 4 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changes/1332be89-09d8-4b30-9e42-6f7a353c4c72.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"id": "1332be89-09d8-4b30-9e42-6f7a353c4c72",
"type": "misc",
"description": "Decrease generated client artifact sizes by reducing the number of suspension points for operations and inlining commonly used HTTP builders"
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class AwsQuery : QueryHttpBindingProtocolGenerator() {
writer: KotlinWriter,
) {
writer.write("""checkNotNull(payload){ "unable to parse error from empty response" }""")
writer.write("#T(payload)", RuntimeTypes.AwsXmlProtocols.parseRestXmlErrorResponse)
writer.write("#T(payload)", RuntimeTypes.AwsXmlProtocols.parseRestXmlErrorResponseNoSuspend)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class Ec2Query : QueryHttpBindingProtocolGenerator() {
writer: KotlinWriter,
) {
writer.write("""checkNotNull(payload){ "unable to parse error from empty response" }""")
writer.write("#T(payload)", RuntimeTypes.AwsXmlProtocols.parseEc2QueryErrorResponse)
writer.write("#T(payload)", RuntimeTypes.AwsXmlProtocols.parseEc2QueryErrorResponseNoSuspend)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ open class RestXml : AwsHttpBindingProtocolGenerator() {
writer: KotlinWriter,
) {
writer.write("""checkNotNull(payload){ "unable to parse error from empty response" }""")
writer.write("#T(payload)", RuntimeTypes.AwsXmlProtocols.parseRestXmlErrorResponse)
writer.write("#T(payload)", RuntimeTypes.AwsXmlProtocols.parseRestXmlErrorResponseNoSuspend)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,11 @@ abstract class AwsHttpBindingProtocolGenerator : HttpBindingProtocolGenerator()
override fun operationErrorHandler(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): Symbol =
op.errorHandler(ctx.settings) { writer ->
writer.withBlock(
"private suspend fun ${op.errorHandlerName()}(context: #T, call: #T): #Q {",
"private fun ${op.errorHandlerName()}(context: #T, call: #T, payload: #T?): #Q {",
"}",
RuntimeTypes.Core.ExecutionContext,
RuntimeTypes.Http.HttpCall,
KotlinTypes.ByteArray,
KotlinTypes.Nothing,
) {
renderThrowOperationError(ctx, op, writer)
Expand All @@ -107,8 +108,7 @@ abstract class AwsHttpBindingProtocolGenerator : HttpBindingProtocolGenerator()
),
) {
val exceptionBaseSymbol = ExceptionBaseClassGenerator.baseExceptionSymbol(ctx.settings)
writer.write("val payload = call.response.body.#T()", RuntimeTypes.Http.readAll)
.write("val wrappedResponse = call.response.#T(payload)", RuntimeTypes.AwsProtocolCore.withPayload)
writer.write("val wrappedResponse = call.response.#T(payload)", RuntimeTypes.AwsProtocolCore.withPayload)
.write("val wrappedCall = call.copy(response = wrappedResponse)")
.write("")
.declareSection(
Expand Down Expand Up @@ -151,7 +151,7 @@ abstract class AwsHttpBindingProtocolGenerator : HttpBindingProtocolGenerator()
name = "${errSymbol.name}Deserializer"
namespace = ctx.settings.pkg.serde
}
writer.write("#S -> #T().deserialize(context, wrappedCall)", getErrorCode(ctx, err), errDeserializerSymbol)
writer.write("#S -> #T().deserialize(context, wrappedCall, payload)", getErrorCode(ctx, err), errDeserializerSymbol)
}
write("else -> #T(errorDetails.message)", exceptionBaseSymbol)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ object RuntimeTypes {
val EndpointResolver = symbol("EndpointResolver")
val ResolveEndpointRequest = symbol("ResolveEndpointRequest")
val execute = symbol("execute")
val HttpDeserialize = symbol("HttpDeserialize")
val HttpDeserializer = symbol("HttpDeserializer")
val HttpOperationContext = symbol("HttpOperationContext")
val HttpSerialize = symbol("HttpSerialize")
val HttpSerializer = symbol("HttpSerializer")
val OperationAuthConfig = symbol("OperationAuthConfig")
val OperationMetrics = symbol("OperationMetrics")
val OperationRequest = symbol("OperationRequest")
Expand Down Expand Up @@ -407,8 +407,8 @@ object RuntimeTypes {
val RestJsonErrorDeserializer = symbol("RestJsonErrorDeserializer")
}
object AwsXmlProtocols : RuntimeTypePackage(KotlinDependency.AWS_XML_PROTOCOLS) {
val parseRestXmlErrorResponse = symbol("parseRestXmlErrorResponse")
val parseEc2QueryErrorResponse = symbol("parseEc2QueryErrorResponse")
val parseRestXmlErrorResponseNoSuspend = symbol("parseRestXmlErrorResponseNoSuspend")
val parseEc2QueryErrorResponseNoSuspend = symbol("parseEc2QueryErrorResponseNoSuspend")
}

object AwsEventStream : RuntimeTypePackage(KotlinDependency.AWS_EVENT_STREAM) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
* The function should have the following signature:
*
* ```
* suspend fun throwFooOperationError(context: ExecutionContext, call: HttpCall): Nothing {
* fun throwFooOperationError(context: ExecutionContext, call: HttpCall, payload: ByteArray?): Nothing {
* <-- CURRENT WRITER CONTEXT -->
* }
* ```
Expand Down Expand Up @@ -169,20 +169,25 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
val operationSerializerSymbols = setOf(
RuntimeTypes.Http.HttpBody,
RuntimeTypes.Http.HttpMethod,
RuntimeTypes.HttpClient.Operation.HttpSerialize,
RuntimeTypes.Http.Request.HttpRequestBuilder,
RuntimeTypes.Http.Request.url,
)

val serdeMeta = HttpSerdeMeta(op.isInputEventStream(ctx.model))

ctx.delegator.useSymbolWriter(serializerSymbol) { writer ->
// import all of http, http.request, and serde packages. All serializers requires one or more of the symbols
// and most require quite a few. Rather than try and figure out which specific ones are used just take them
// all to ensure all the various DSL builders are available, etc
writer
.addImport(operationSerializerSymbols)
.write("")
.openBlock("internal class #T: #T<#T> {", serializerSymbol, RuntimeTypes.HttpClient.Operation.HttpSerialize, inputSymbol)
.openBlock("internal class #T: #T.#L<#T> {", serializerSymbol, RuntimeTypes.HttpClient.Operation.HttpSerializer, serdeMeta.variantName, inputSymbol)
.call {
writer.openBlock("override suspend fun serialize(context: #T, input: #T): #T {", RuntimeTypes.Core.ExecutionContext, inputSymbol, RuntimeTypes.Http.Request.HttpRequestBuilder)
val modifier = if (serdeMeta.isStreaming) "suspend " else ""
writer.openBlock(
"override #Lfun serialize(context: #T, input: #T): #T {",
modifier,
RuntimeTypes.Core.ExecutionContext,
inputSymbol,
RuntimeTypes.Http.Request.HttpRequestBuilder,
)
.write("val builder = #T()", RuntimeTypes.Http.Request.HttpRequestBuilder)
.call {
renderHttpSerialize(ctx, op, writer)
Expand Down Expand Up @@ -546,18 +551,22 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {

val resolver = getProtocolHttpBindingResolver(ctx.model, ctx.service)
val responseBindings = resolver.responseBindings(op)

val serdeMeta = httpDeserializerInfo(ctx, op)

ctx.delegator.useSymbolWriter(deserializerSymbol) { writer ->
writer
.write("")
.openBlock(
"internal class #T: #T<#T> {",
"internal class #T: #T.#L<#T> {",
deserializerSymbol,
RuntimeTypes.HttpClient.Operation.HttpDeserialize,
RuntimeTypes.HttpClient.Operation.HttpDeserializer,
serdeMeta.variantName,
outputSymbol,
)
.write("")
.call {
renderHttpDeserialize(ctx, outputSymbol, responseBindings, op, writer)
renderHttpDeserialize(ctx, outputSymbol, responseBindings, serdeMeta, op, writer)
}
.closeBlock("}")
}
Expand All @@ -569,8 +578,12 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
protected open fun renderIsHttpError(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) {
writer.addImport(RuntimeTypes.Http.isSuccess)
writer.withBlock("if (!response.status.#T()) {", "}", RuntimeTypes.Http.isSuccess) {
val serdeMeta = httpDeserializerInfo(ctx, op)
if (serdeMeta.isStreaming) {
writer.write("val payload = response.body.#T()", RuntimeTypes.Http.readAll)
}
val errorHandlerFn = operationErrorHandler(ctx, op)
write("#T(context, call)", errorHandlerFn)
write("#T(context, call, payload)", errorHandlerFn)
}
}

Expand All @@ -587,7 +600,6 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
RuntimeTypes.Serde.SerialKind,
RuntimeTypes.Serde.deserializeStruct,
RuntimeTypes.Http.Response.HttpResponse,
RuntimeTypes.HttpClient.Operation.HttpDeserialize,
)

val deserializerSymbol = buildSymbol {
Expand All @@ -598,16 +610,19 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
reference(outputSymbol, SymbolReference.ContextOption.DECLARE)
}

// exception deserializers are never streaming
val serdeMeta = HttpSerdeMeta(false)

ctx.delegator.useSymbolWriter(deserializerSymbol) { writer ->
val resolver = getProtocolHttpBindingResolver(ctx.model, ctx.service)
val responseBindings = resolver.responseBindings(shape)
writer
.addImport(exceptionDeserializerSymbols)
.write("")
.openBlock("internal class #T: #T<#T> {", deserializerSymbol, RuntimeTypes.HttpClient.Operation.HttpDeserialize, outputSymbol)
.openBlock("internal class #T: #T.NonStreaming<#T> {", deserializerSymbol, RuntimeTypes.HttpClient.Operation.HttpDeserializer, outputSymbol)
.write("")
.call {
renderHttpDeserialize(ctx, outputSymbol, responseBindings, null, writer)
renderHttpDeserialize(ctx, outputSymbol, responseBindings, serdeMeta, null, writer)
}
.closeBlock("}")
}
Expand All @@ -617,18 +632,31 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
ctx: ProtocolGenerator.GenerationContext,
outputSymbol: Symbol,
responseBindings: List<HttpBindingDescriptor>,
serdeMeta: HttpSerdeMeta,
// this method is shared between operation and exception deserialization. In the case of operations this MUST be set
op: OperationShape?,
writer: KotlinWriter,
) {
writer
.openBlock(
"override suspend fun deserialize(context: #T, call: #T): #T {",
RuntimeTypes.Core.ExecutionContext,
RuntimeTypes.Http.HttpCall,
outputSymbol,
)
.write("val response = call.response")
if (serdeMeta.isStreaming) {
writer
.openBlock(
"override suspend fun deserialize(context: #T, call: #T): #T {",
RuntimeTypes.Core.ExecutionContext,
RuntimeTypes.Http.HttpCall,
outputSymbol,
)
} else {
writer
.openBlock(
"override fun deserialize(context: #T, call: #T, payload: #T?): #T {",
RuntimeTypes.Core.ExecutionContext,
RuntimeTypes.Http.HttpCall,
KotlinTypes.ByteArray,
outputSymbol,
)
}

writer.write("val response = call.response")
.call {
if (outputSymbol.shape?.isError == false && op != null) {
// handle operation errors
Expand Down Expand Up @@ -657,7 +685,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
if (op != null && op.isOutputEventStream(ctx.model)) {
deserializeViaEventStream(ctx, op, writer)
} else {
deserializeViaPayload(ctx, outputSymbol, responseBindings, op, writer)
deserializeViaPayload(ctx, outputSymbol, responseBindings, serdeMeta, op, writer)
}
}
.call {
Expand All @@ -681,6 +709,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
ctx: ProtocolGenerator.GenerationContext,
outputSymbol: Symbol,
responseBindings: List<HttpBindingDescriptor>,
serdeMeta: HttpSerdeMeta,
// this method is shared between operation and exception deserialization. In the case of operations this MUST be set
op: OperationShape?,
writer: KotlinWriter,
Expand All @@ -707,10 +736,11 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
sdg.errorDeserializer(ctx, outputSymbol.shape as StructureShape, documentMembers)
}

writer.write("val payload = response.body.#T()", RuntimeTypes.Http.readAll)
.withBlock("if (payload != null) {", "}") {
if (!serdeMeta.isStreaming) {
writer.withBlock("if (payload != null) {", "}") {
write("#T(builder, payload)", bodyDeserializerFn)
}
}
}
}
}
Expand Down Expand Up @@ -872,7 +902,6 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
""
}

// writer.addImport("${KotlinDependency.CLIENT_RT_HTTP.namespace}.util", splitFn)
writer
.addImport(splitFn, KotlinDependency.HTTP, subpackage = "util")
.write("builder.#L = response.headers.getAll(#S)?.flatMap(::$splitFn)$mapFn", memberName, headerName)
Expand Down Expand Up @@ -940,9 +969,12 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
val memberName = binding.member.defaultName()
val target = ctx.model.expectShape(binding.member.target)
val targetSymbol = ctx.symbolProvider.toSymbol(target)

// NOTE: we don't need serde metadata to know what to do here. Everything is non-streaming except streaming
// blob payloads.
when (target.type) {
ShapeType.STRING -> {
writer.write("val contents = response.body.#T()?.decodeToString()", RuntimeTypes.Http.readAll)
writer.write("val contents = payload?.decodeToString()")
if (target.isEnum) {
writer.write("builder.$memberName = contents?.let { #T.fromValue(it) }", targetSymbol)
} else {
Expand All @@ -951,36 +983,32 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
}

ShapeType.ENUM -> {
writer.write("val contents = response.body.#T()?.decodeToString()", RuntimeTypes.Http.readAll)
writer.write("val contents = payload?.decodeToString()")
writer.write("builder.#L = contents?.let { #T.fromValue(it) }", memberName, targetSymbol)
}

ShapeType.INT_ENUM -> {
writer.write("val contents = response.body.#T()?.decodeToString()", RuntimeTypes.Http.readAll)
writer.write("val contents = payload?.decodeToString()")
writer.write("builder.#L = contents?.let { #T.fromValue(it.toInt()) }", memberName, targetSymbol)
}

ShapeType.BLOB -> {
val isBinaryStream = target.hasTrait<StreamingTrait>()
val conversion = if (isBinaryStream) {
writer.addImport(RuntimeTypes.Http.toByteStream)
"toByteStream()"
if (isBinaryStream) {
writer.write("builder.#L = response.body.#T()", memberName, RuntimeTypes.Http.toByteStream)
} else {
writer.addImport(RuntimeTypes.Http.readAll)
"readAll()"
writer.write("builder.#L = payload", memberName)
}
writer.write("builder.$memberName = response.body.$conversion")
}

ShapeType.STRUCTURE, ShapeType.UNION, ShapeType.DOCUMENT -> {
// delegate to the payload deserializer
val sdg = structuredDataParser(ctx)
val payloadDeserializerFn = sdg.payloadDeserializer(ctx, binding.member)

writer.write("val payload = response.body.#T()", RuntimeTypes.Http.readAll)
.withBlock("if (payload != null) {", "}") {
write("builder.#L = #T(payload)", memberName, payloadDeserializerFn)
}
writer.withBlock("if (payload != null) {", "}") {
write("builder.#L = #T(payload)", memberName, payloadDeserializerFn)
}
}

else ->
Expand Down Expand Up @@ -1061,3 +1089,18 @@ private fun renderNonBlankGuard(ctx: ProtocolGenerator.GenerationContext, member
private fun MemberShape.isNonBlankInStruct(ctx: ProtocolGenerator.GenerationContext): Boolean =
ctx.model.expectShape(target).isStringShape &&
getTrait<LengthTrait>()?.min?.getOrNull()?.takeIf { it > 0 } != null

private data class HttpSerdeMeta(val isStreaming: Boolean) {
/**
* The name of the HttpSerializer<T>/HttpDeserializer<T> variant
*/
val variantName: String
get() = if (isStreaming) "Streaming" else "NonStreaming"
}

private fun httpDeserializerInfo(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): HttpSerdeMeta {
val isStreaming = ctx.model.expectShape<StructureShape>(op.output.get()).hasStreamingMember(ctx.model) ||
op.isOutputEventStream(ctx.model)

return HttpSerdeMeta(isStreaming)
}
Loading
Loading