Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes to AWS construct to support streaming #1496

Draft
wants to merge 5 commits into
base: series/0.19
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 41 additions & 11 deletions modules/aws-http4s/src/smithy4s/aws/AwsCall.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,60 @@

package smithy4s.aws

import smithy4s.kinds._
import fs2.Stream
import cats.effect.Resource

// format: off
trait AwsCall[F[_], Input, Err, Output, StreamedInput, StreamedOutput] {
/**
* A call to an AWS operation in effect `F`. It exposes three entry points (`run`, `upload`
* and `download`); each one requires implicit evidence from `AwsOperationKind` about the
* `StreamedInput` / `StreamedOutput` type parameters, so that only the entry point matching
* the kind of operation can be invoked.
*/
sealed trait AwsCall[F[_], Input, Err, Output, StreamedInput, StreamedOutput] {

/**
* Runs the call and exposes the output in an effect, provided it is proven that the
* call does not have a streamed component to it.
*/
def run(implicit ev: AwsOperationKind.Unary[StreamedInput, StreamedOutput]): F[Output]

/**
* Uploads a streamed payload and returns the output in an effect, provided there is
* `AwsOperationKind.ByteUpload` evidence for the streamed input and output types of the call.
*
* @param payload the stream of elements to upload
*/
def upload(payload: Stream[F, StreamedInput])(implicit ev: AwsOperationKind.ByteUpload[StreamedInput, StreamedOutput]): F[Output]

/**
* Exposes the result of the call as a `Resource` holding an `AwsDownloadResult` (output
* metadata plus streamed payload), provided there is `AwsOperationKind.ByteDownload`
* evidence for the streamed input and output types of the call.
*/
def download(implicit ev: AwsOperationKind.ByteDownload[StreamedInput, StreamedOutput]) : Resource[F, AwsDownloadResult[F, Output, StreamedOutput]]

/** Utility to turn AwsCall[F, I, E, O, Nothing, SO] into AwsCall[F, I, E, O, SI, SO].
* This is an unchecked cast of `this`: no conversion of the underlying call takes place.
* TODO: ensure this is relevant
*/
def wideUpload[SI]: AwsCall[F, Input, Err, Output, SI, StreamedOutput] = this.asInstanceOf[AwsCall[F, Input, Err, Output, SI, StreamedOutput]]

/** Utility to turn AwsCall[F, I, E, O, SI, Nothing] into AwsCall[F, I, E, O, SI, SO].
* This is an unchecked cast of `this`: no conversion of the underlying call takes place.
* TODO: ensure this is relevant
*/
def wideDownload[SO]: AwsCall[F, Input, Err, Output, StreamedInput, SO] = this.asInstanceOf[AwsCall[F, Input, Err, Output, StreamedInput, SO]]
}

/**
* Result of a download call: a `metadata` value together with the `payload` stream.
* NOTE(review): `metadata` is presumably the non-streamed part of the operation output — confirm.
*/
case class AwsDownloadResult[F[_], O, SO](metadata: O, payload: Stream[F, SO])

object AwsCall {

def liftEffect[F[_]] : PolyFunction5[Kind1[F]#toKind5, AwsCall[F, *, *, *, *, *]] = new PolyFunction5[Kind1[F]#toKind5, AwsCall[F, *, *, *, *, *]]{
def apply[I, E, O, SI, SO](fo: F[O]) : AwsCall[F, I, E, O, SI, SO] = new AwsCall[F, I, E, O, SI, SO]{
def run(implicit ev: AwsOperationKind.Unary[SI, SO]): F[O] = fo
}
/** Always fails with a `RuntimeException`; it never returns a value. */
private def uncallable: Nothing =
  throw new RuntimeException("Impossible call")

/**
 * An [[AwsCall]] without any streamed component: only `run` is meaningful and it simply
 * returns the wrapped effect. `upload` and `download` fail via `uncallable`.
 *
 * The constructor field is named `effect` rather than `run`: a field called `run` would
 * clash with the `run(implicit ev)` method required by [[AwsCall]].
 *
 * @param effect the effect producing the output of the call
 */
private final case class UnaryAwsCall[F[_], Input, Err, Output](effect: F[Output])
    extends AwsCall[F, Input, Err, Output, Nothing, Nothing] {

  def run(implicit ev: AwsOperationKind.Unary[Nothing, Nothing]): F[Output] = effect

  def upload(payload: Stream[F, Nothing])(implicit
      ev: AwsOperationKind.ByteUpload[Nothing, Nothing]
  ): F[Output] = uncallable

  // Uses the shared `uncallable` helper, like the other non-applicable entry points.
  def download(implicit
      ev: AwsOperationKind.ByteDownload[Nothing, Nothing]
  ): Resource[F, AwsDownloadResult[F, Output, Nothing]] = uncallable
}

/**
 * An [[AwsCall]] with a streamed input and no streamed output: `upload` hands the payload
 * to `uploadFunction`, whereas `run` and `download` fail via `uncallable`.
 *
 * @param uploadFunction turns the payload stream into the effect producing the output
 */
private final case class BlobUploadAwsCall[F[_], Input, Err, Output, StreamedInput](
    uploadFunction: Stream[F, StreamedInput] => F[Output]
) extends AwsCall[F, Input, Err, Output, StreamedInput, Nothing] {

  def upload(payload: Stream[F, StreamedInput])(implicit
      ev: AwsOperationKind.ByteUpload[StreamedInput, Nothing]
  ): F[Output] =
    uploadFunction.apply(payload)

  def run(implicit
      ev: AwsOperationKind.Unary[StreamedInput, Nothing]
  ): F[Output] =
    uncallable

  def download(implicit
      ev: AwsOperationKind.ByteDownload[StreamedInput, Nothing]
  ): Resource[F, AwsDownloadResult[F, Output, Nothing]] =
    uncallable
}

/**
 * Builds an [[AwsCall]] that has no streamed input and a streamed output.
 *
 * @param res given a function from `Byte` to `StreamedOutput`, produces the resource
 *            holding the download result
 */
def download[F[_], Input, Err, Output, StreamedOutput](
    res: (Byte => StreamedOutput) => Resource[F, AwsDownloadResult[F, Output, StreamedOutput]]
): AwsCall[F, Input, Err, Output, Nothing, StreamedOutput] =
  BlobDownloadAwsCall[F, Input, Err, Output, StreamedOutput](res)

/**
 * An [[AwsCall]] with a streamed output and no streamed input: `download` invokes
 * `downloadResult` with the byte conversion carried by the `ByteDownload` evidence,
 * whereas `run` and `upload` fail via `uncallable`.
 *
 * @param downloadResult given a function from `Byte` to `StreamedOutput`, produces the
 *                       resource holding the download result
 */
private final case class BlobDownloadAwsCall[F[_], Input, Err, Output, StreamedOutput](
    downloadResult: (Byte => StreamedOutput) => Resource[F, AwsDownloadResult[F, Output, StreamedOutput]]
) extends AwsCall[F, Input, Err, Output, Nothing, StreamedOutput] {

  def download(implicit
      ev: AwsOperationKind.ByteDownload[Nothing, StreamedOutput]
  ): Resource[F, AwsDownloadResult[F, Output, StreamedOutput]] =
    downloadResult(byte => ev(byte))

  def run(implicit
      ev: AwsOperationKind.Unary[Nothing, StreamedOutput]
  ): F[Output] =
    uncallable

  def upload(payload: Stream[F, Nothing])(implicit
      ev: AwsOperationKind.ByteUpload[Nothing, StreamedOutput]
  ): F[Output] =
    uncallable
}
}
20 changes: 19 additions & 1 deletion modules/aws-http4s/src/smithy4s/aws/AwsOperationKind.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package smithy4s.aws

import scala.annotation.implicitNotFound
import smithy4s.Bijection

object AwsOperationKind {

Expand All @@ -29,12 +30,29 @@ object AwsOperationKind {
}

@implicitNotFound(
"Cannot prove that the operation is a blob upload: it's either meant to upload something else than bytes or has a streamed output"
"Cannot prove that the operation is a blob upload. No instance of ByteUpload[${StreamedInput}, ${StreamedOutput}]"
)
sealed trait ByteUpload[StreamedInput, StreamedOutput]
/** Implicit instances of the `ByteUpload` evidence. */
object ByteUpload {

  /** Evidence for a streamed input of type `Byte` and no streamed output. */
  implicit val ByteUpload: ByteUpload[Byte, Nothing] = new ByteUpload[Byte, Nothing] {}

  /** Evidence for any streamed input `T` that has a `Bijection[Byte, T]`, and no streamed output. */
  implicit def fromBijection[T: Bijection[Byte, *]]: ByteUpload[T, Nothing] = new ByteUpload[T, Nothing] {}
}

/**
 * Evidence that an operation is a blob download, i.e. that its streamed output can be
 * obtained from raw bytes.
 */
@implicitNotFound(
  "Cannot prove that the operation is a blob download. No instance of ByteDownload[${StreamedInput}, ${StreamedOutput}]"
)
sealed trait ByteDownload[StreamedInput, StreamedOutput] {

  /**
   * Views a raw byte as a `StreamedOutput` by means of an unchecked cast.
   * NOTE(review): this is only sound when `StreamedOutput` is represented as a `Byte` at
   * runtime — confirm that this holds for the instances derived through `fromBijection`.
   */
  def apply(value: Byte): StreamedOutput = value.asInstanceOf[StreamedOutput]
}
/** Implicit instances of the `ByteDownload` evidence. */
object ByteDownload {

  /** Evidence for a streamed output of type `Byte` and no streamed input. */
  implicit val ByteDownload: ByteDownload[Nothing, Byte] = new ByteDownload[Nothing, Byte] {}

  /** Evidence for any streamed output `T` that has a `Bijection[Byte, T]`, and no streamed input. */
  implicit def fromBijection[T: Bijection[Byte, *]]: ByteDownload[Nothing, T] = new ByteDownload[Nothing, T] {}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright 2021-2024 Disney Streaming
*
* Licensed under the Tomorrow Open Source Technology License, Version 1.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://disneystreaming.github.io/TOST-1.0.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package smithy4s.aws
package internals

import cats.effect.Concurrent
import cats.effect.Resource
import cats.syntax.all._
import fs2.Chunk
import org.http4s._
import org.http4s.client.Client
import org.typelevel.ci.CIString
import smithy4s._
import smithy4s.aws.kernel.AwsCrypto._

/**
 * The ways in which the payload of a request can be accounted for in the value of the
 * `X-Amz-Content-SHA256` header.
 */
private[aws] sealed trait AwsPayloadSignature {
  import AwsPayloadSignature._

  /**
   * The value to put in the header for this signature.
   *
   * This is a `def` rather than a `val` on purpose: a `val` would be computed in the trait
   * initializer, i.e. while the instance (and, for the case object, possibly its singleton
   * reference) is still being constructed, and the match below inspects `this`.
   */
  def headerValue: String = this match {
    case Sha256(v)       => v
    case UnsignedPayload => "UNSIGNED-PAYLOAD"
    // case StreamingUnsignedPayload => "STREAMING-UNSIGNED-PAYLOAD-TRAILER"
  }
}

/**
* This is a draft API. There are many other ways to include the payload in the signature.
* Some of which are complex: using trailers and/or multiple chunks
*/
private[aws] object AwsPayloadSignature {

  /** A payload signature whose header value is the given string. */
  case class Sha256(value: String) extends AwsPayloadSignature

  /** Signals that the payload is not signed; rendered as "UNSIGNED-PAYLOAD". */
  case object UnsignedPayload extends AwsPayloadSignature
  // case object StreamingUnsignedPayload extends AwsPayloadSignature

  /** Name of the header that carries the payload signature. */
  val `X-Amz-Content-SHA256` = CIString("X-Amz-Content-SHA256")

  /** Renders the given signature as a raw `X-Amz-Content-SHA256` header. */
  def makeHeader(value: AwsPayloadSignature): Header.Raw = {
    val rendered = value.headerValue
    Header.Raw(`X-Amz-Content-SHA256`, rendered)
  }

  /**
   * Endpoint middleware that wraps a client so that, for every request, the whole body is
   * hashed with `sha256HexDigest` and the result is attached as the `X-Amz-Content-SHA256`
   * header before the request is handed over to the wrapped client.
   */
  def signSingleChunk[F[_]: Concurrent]: Endpoint.Middleware[Client[F]] =
    new Endpoint.Middleware[Client[F]] {
      def prepare[Alg[_[_, _, _, _, _]]](service: Service[Alg])(
          endpoint: service.Endpoint[_, _, _, _, _]
      ): Client[F] => Client[F] = { underlying =>
        Client { original =>
          Resource
            .eval(hashSingleChunk(original))
            .flatMap(hashed => underlying.run(hashed))
        }
      }
    }

  /**
   * Compiles the entire body of the request into a single chunk, hashes it, and returns the
   * request with the corresponding `X-Amz-Content-SHA256` header put on it.
   */
  private def hashSingleChunk[F[_]: Concurrent](
      request: Request[F]
  ): F[Request[F]] =
    request.body.chunks.compile.to(Chunk).map { chunks =>
      val bytes  = chunks.flatten.toArray
      val header = makeHeader(Sha256(sha256HexDigest(bytes)))
      request.putHeaders(header)
    }
}
37 changes: 25 additions & 12 deletions modules/aws-http4s/src/smithy4s/aws/internals/AwsSigning.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,18 @@ package internals
import cats.effect.Concurrent
import cats.effect.Resource
import cats.syntax.all._
import fs2.Chunk
import org.http4s._
import org.http4s.client.Client
import org.typelevel.ci.CIString
import smithy4s._
import smithy4s.aws.kernel.AwsCrypto._
import smithy4s.aws.internals.AwsPayloadSignature.`X-Amz-Content-SHA256`

import java.net.URLEncoder
import java.nio.charset.StandardCharsets

/**
* A Client middleware that signs http requests before they are sent to AWS.
* This works by compiling the body of the request in memory in a chunk before sending
* it back, which means it is not proper to use it in the context of streaming.
*/
private[aws] object AwsSigning {

Expand Down Expand Up @@ -108,8 +106,7 @@ private[aws] object AwsSigning {
// scalafmt: { align.preset = most, danglingParentheses.preset = false, maxColumn = 240, align.tokens = [{code = ":"}]}
(request: Request[F]) => {

val bodyF = request.body.chunks.compile.to(Chunk).map(_.flatten)
val awsHeadersF = (bodyF, timestamp, credentials, region).mapN { case (body, timestamp, credentials, region) =>
val awsHeadersF = (timestamp, credentials, region).mapN { case (timestamp, credentials, region) =>
val credentialsScope = s"${timestamp.conciseDate}/$region/$endpointPrefix/aws4_request"
val queryParams: Vector[(String, String)] =
request.uri.query.toVector.sorted.map { case (k, v) => k -> v.getOrElse("") }
Expand All @@ -122,23 +119,39 @@ private[aws] object AwsSigning {
}
.mkString("&")

// // !\ Important: these must remain in the same order
val baseHeadersList = List(
val amzHeaders: List[(CIString, String)] = request.headers.headers
.filter(_.name.toString.toLowerCase.startsWith("x-amz"))
.map(h => (h.name, h.value))
.filterNot(_._2 == null)

// It is assumed that the hash value is computed before this middleware runs,
// via another middleware. If it is not, we use a default value.
val contentSha = amzHeaders.find(_._1 == `X-Amz-Content-SHA256`)
val payloadHash = contentSha.map(_._2).getOrElse(AwsPayloadSignature.UnsignedPayload.headerValue)
val missingContentShaHeader =
if (contentSha.isEmpty) List(`X-Amz-Content-SHA256` -> AwsPayloadSignature.UnsignedPayload.headerValue)
else List.empty

val addedHeaders: List[(CIString, String)] = List(
`Content-Type` -> request.contentType.map(contentType.value(_)).orNull,
`Host` -> request.uri.host.map(_.renderString).orNull,
`X-Amz-Date` -> timestamp.conciseDateTime,
`X-Amz-Security-Token` -> credentials.sessionToken.orNull,
`X-Amz-Target` -> (serviceName + "." + operationName)
).filterNot(_._2 == null)
).filterNot(_._2 == null) ++
// The content-hash header is required, so we add it ourselves when the request does not already carry it
missingContentShaHeader

// Headers included in the signature need to be sorted alphabetically
val allHeaders = (addedHeaders ++ amzHeaders).sortBy(_._1)

val canonicalHeadersString = baseHeadersList
val canonicalHeadersString = allHeaders
.map { case (key, value) =>
key.toString.toLowerCase + ":" + value.trim
}
.mkString(newline)
lazy val signedHeadersString = baseHeadersList.map(_._1).map(_.toString.toLowerCase()).mkString(";")
lazy val signedHeadersString = allHeaders.map(_._1).map(_.toString.toLowerCase()).mkString(";")

val payloadHash = sha256HexDigest(body.toArray)
val pathString = request.uri.path.toAbsolute.renderString
val canonicalRequest = new StringBuilder()
.append(request.method.name.toUpperCase())
Expand Down Expand Up @@ -171,7 +184,7 @@ private[aws] object AwsSigning {
val signature = toHexString(hmacSha256(stringToSign, signatureKey))
val authHeaderValue = s"${algorithm} Credential=${credentials.accessKeyId}/$credentialsScope, SignedHeaders=$signedHeadersString, Signature=$signature"
val authHeader = Headers("Authorization" -> authHeaderValue)
val baseHeaders = Headers(baseHeadersList.map { case (k, v) => Header.Raw(k, v) })
val baseHeaders = Headers(addedHeaders.map { case (k, v) => Header.Raw(k, v) })
authHeader ++ baseHeaders
}

Expand Down
Loading