-
Notifications
You must be signed in to change notification settings - Fork 82
/
Copy pathHTTPReceiveDiscardHandler.swift
90 lines (77 loc) · 3.2 KB
/
HTTPReceiveDiscardHandler.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2024 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import HTTPTypes
import NIOCore
import NIOHTTPTypes
/// HTTP request handler that receives arbitrary bytes and discards them
/// HTTP request handler that receives arbitrary bytes and discards them.
///
/// The handler counts the body bytes of each request and answers with a
/// plain-text response carrying a `Content-Length` header:
/// - `200 OK` with "Received N bytes" when the request ends, if no
///   expectation is configured or exactly the expected number of bytes arrived.
/// - `400 Bad Request` as soon as the body exceeds the expectation, or when the
///   request ends with fewer bytes than expected.
///
/// At most one response is written per request.
public final class HTTPReceiveDiscardHandler: ChannelInboundHandler {
    public typealias InboundIn = HTTPRequestPart
    public typealias OutboundOut = HTTPResponsePart

    /// Exact number of body bytes each request is expected to carry, or `nil`
    /// to accept any amount.
    private let expectation: Int?
    /// Whether an error response has already been written for the current request.
    private var expectationViolated = false
    /// Body bytes received so far for the current request.
    private var received = 0

    /// Initializes `HTTPReceiveDiscardHandler`
    /// - Parameter expectation: how many bytes should be expected. If a request
    ///   carries more or fewer bytes than expected, a `400 Bad Request`
    ///   response is sent to the client. Pass `nil` to accept any amount.
    public init(expectation: Int?) {
        self.expectation = expectation
    }

    public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        let part = self.unwrapInboundIn(data)
        switch part {
        case .head:
            // Start of a new request: reset per-request state so a reused
            // (keep-alive) connection does not carry counts or a previous
            // violation over into this request.
            self.received = 0
            self.expectationViolated = false
        case .body(let buffer):
            self.received += buffer.readableBytes
            // If the expectation is violated, send 4xx -- but only once. Later
            // chunks of the same request must not write additional responses.
            if !self.expectationViolated,
                let expectation = self.expectation,
                self.received > expectation
            {
                self.onExpectationViolated(context: context, expectation: expectation)
            }
        case .end:
            if self.expectationViolated {
                // Already flushed a response, nothing else to do
                return
            }
            if let expectation = self.expectation, self.received != expectation {
                self.onExpectationViolated(context: context, expectation: expectation)
                return
            }
            let responseBody = ByteBuffer(string: "Received \(self.received) bytes")
            self.writeSimpleResponse(context: context, status: .ok, body: responseBody)
        }
    }

    /// Marks the current request as failed and replies with `400 Bad Request`.
    ///
    /// The response body says whether the received byte count overshot or
    /// fell short of `expectation`.
    /// - Parameters:
    ///   - context: The channel handler context to write the response on.
    ///   - expectation: The expected body size that was not met.
    private func onExpectationViolated(context: ChannelHandlerContext, expectation: Int) {
        self.expectationViolated = true
        // `onExpectationViolated` is reached from `.end` when fewer bytes than
        // expected arrived, so the message must not always claim an excess.
        let relation = self.received > expectation ? "in excess of" : "less than"
        let body = ByteBuffer(
            string:
                "Received \(relation) expectation; expected(\(expectation)) received(\(self.received))"
        )
        self.writeSimpleResponse(context: context, status: .badRequest, body: body)
    }

    /// Writes a complete response (head, body, end) and flushes it.
    ///
    /// The response head carries `status` and a `Content-Length` header
    /// matching `body.readableBytes`.
    /// - Parameters:
    ///   - context: The channel handler context to write on.
    ///   - status: The HTTP status of the response.
    ///   - body: The full response body.
    private func writeSimpleResponse(
        context: ChannelHandlerContext,
        status: HTTPResponse.Status,
        body: ByteBuffer
    ) {
        let bodyLen = body.readableBytes
        let responseHead = HTTPResponse(
            status: status,
            headerFields: HTTPFields(dictionaryLiteral: (.contentLength, "\(bodyLen)"))
        )
        context.write(self.wrapOutboundOut(.head(responseHead)), promise: nil)
        context.write(self.wrapOutboundOut(.body(body)), promise: nil)
        // Only the final part is flushed so all three parts go out together.
        context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
    }
}