-
Notifications
You must be signed in to change notification settings - Fork 83
/
Copy pathHTTPDrippingDownloadHandler.swift
255 lines (223 loc) · 9.22 KB
/
HTTPDrippingDownloadHandler.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2024 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import HTTPTypes
import NIOCore
import NIOHTTPTypes
/// HTTP request handler sending a configurable stream of zeroes. Uses HTTPTypes request/response parts.
public final class HTTPDrippingDownloadHandler: ChannelDuplexHandler {
public typealias InboundIn = HTTPRequestPart
public typealias OutboundOut = HTTPResponsePart
public typealias OutboundIn = Never
// Predefine buffer to reuse over and over again when sending chunks to requester. NIO allows
// us to give it reference counted buffers. Reusing like this allows us to avoid allocations.
static let downloadBodyChunk = ByteBuffer(repeating: 0, count: 65536)
private var frequency: TimeAmount
private var size: Int
private var count: Int
private var delay: TimeAmount
private var code: HTTPResponse.Status
private enum Phase {
/// We haven't gotten the request head - nothing to respond to
case waitingOnHead
/// We got the request head and are delaying the response
case delayingResponse
/// We're dripping response chunks to the peer, tracking how many chunks we have left
case dripping(DrippingState)
/// We either sent everything to the client or the request ended short
case done
}
private struct DrippingState {
var chunksLeft: Int
var currentChunkBytesLeft: Int
}
private var phase = Phase.waitingOnHead
private var scheduled: Scheduled<Void>?
private var scheduledCallbackHandler: HTTPDrippingDownloadHandlerScheduledCallbackHandler?
private var pendingRead = false
private var pendingWrite = false
private var activelyWritingChunk = false
/// Initializes an `HTTPDrippingDownloadHandler`.
/// - Parameters:
/// - count: How many chunks should be sent. Note that the underlying HTTP
/// stack may split or combine these chunks into data frames as
/// they see fit.
/// - size: How large each chunk should be
/// - frequency: How much time to wait between sending each chunk
/// - delay: How much time to wait before sending the first chunk
/// - code: What HTTP status code to send
public init(
count: Int = 0,
size: Int = 0,
frequency: TimeAmount = .zero,
delay: TimeAmount = .zero,
code: HTTPResponse.Status = .ok
) {
self.frequency = frequency
self.size = size
self.count = count
self.delay = delay
self.code = code
}
public convenience init?(queryArgsString: Substring.UTF8View) {
self.init()
let pairs = queryArgsString.split(separator: UInt8(ascii: "&"))
for pair in pairs {
var pairParts = pair.split(separator: UInt8(ascii: "="), maxSplits: 1).makeIterator()
guard let first = pairParts.next(), let second = pairParts.next() else {
continue
}
guard let secondNum = Int(Substring(second)) else {
return nil
}
switch Substring(first) {
case "frequency":
self.frequency = .seconds(Int64(secondNum))
case "size":
self.size = secondNum
case "count":
self.count = secondNum
case "delay":
self.delay = .seconds(Int64(secondNum))
case "code":
self.code = HTTPResponse.Status(code: secondNum)
default:
continue
}
}
}
public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let part = self.unwrapInboundIn(data)
switch part {
case .head:
self.phase = .delayingResponse
if self.delay == TimeAmount.zero {
// If no delay, we might as well start responding immediately
self.onResponseDelayCompleted(context: context)
} else {
let this = NIOLoopBound(self, eventLoop: context.eventLoop)
let loopBoundContext = NIOLoopBound(context, eventLoop: context.eventLoop)
self.scheduled = context.eventLoop.scheduleTask(in: self.delay) {
this.value.onResponseDelayCompleted(context: loopBoundContext.value)
}
}
case .body, .end:
return
}
}
private func onResponseDelayCompleted(context: ChannelHandlerContext) {
guard case .delayingResponse = self.phase else {
return
}
var head = HTTPResponse(status: self.code)
// If the length isn't too big, let's include a content length header
if case (let contentLength, false) = self.size.multipliedReportingOverflow(by: self.count) {
head.headerFields = HTTPFields(dictionaryLiteral: (.contentLength, "\(contentLength)"))
}
context.writeAndFlush(self.wrapOutboundOut(.head(head)), promise: nil)
self.phase = .dripping(
DrippingState(
chunksLeft: self.count,
currentChunkBytesLeft: self.size
)
)
self.writeChunk(context: context)
}
public func channelInactive(context: ChannelHandlerContext) {
self.phase = .done
self.scheduled?.cancel()
context.fireChannelInactive()
}
public func channelWritabilityChanged(context: ChannelHandlerContext) {
if case .dripping = self.phase, self.pendingWrite && context.channel.isWritable {
self.writeChunk(context: context)
}
}
private func writeChunk(context: ChannelHandlerContext) {
// Make sure we don't accidentally reenter
if self.activelyWritingChunk {
return
}
self.activelyWritingChunk = true
defer {
self.activelyWritingChunk = false
}
// If we're not dripping the response body, there's nothing to do
guard case .dripping(var drippingState) = self.phase else {
return
}
// If we've sent all chunks, send end and be done
if drippingState.chunksLeft < 1 {
self.phase = .done
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
return
}
var dataWritten = false
while drippingState.currentChunkBytesLeft > 0, context.channel.isWritable {
let toSend = min(
drippingState.currentChunkBytesLeft,
HTTPDrippingDownloadHandler.downloadBodyChunk.readableBytes
)
let buffer = HTTPDrippingDownloadHandler.downloadBodyChunk.getSlice(
at: HTTPDrippingDownloadHandler.downloadBodyChunk.readerIndex,
length: toSend
)!
context.write(self.wrapOutboundOut(.body(buffer)), promise: nil)
drippingState.currentChunkBytesLeft -= toSend
dataWritten = true
}
// If we weren't able to send the full chunk, it's because the channel isn't writable. We yield until it is
if drippingState.currentChunkBytesLeft > 0 {
self.pendingWrite = true
self.phase = .dripping(drippingState)
if dataWritten {
context.flush()
}
return
}
// We sent the full chunk. If we have no more chunks to write, we're done!
drippingState.chunksLeft -= 1
if drippingState.chunksLeft == 0 {
self.phase = .done
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
return
}
if dataWritten {
context.flush()
}
// More chunks to write.. Kick off timer
drippingState.currentChunkBytesLeft = self.size
self.phase = .dripping(drippingState)
if self.scheduledCallbackHandler == nil {
let this = NIOLoopBound(self, eventLoop: context.eventLoop)
let loopBoundContext = NIOLoopBound(context, eventLoop: context.eventLoop)
self.scheduledCallbackHandler = HTTPDrippingDownloadHandlerScheduledCallbackHandler(
handler: this,
context: loopBoundContext
)
}
// SAFTEY: scheduling the callback only potentially throws when invoked off eventloop
do {
try context.eventLoop.scheduleCallback(in: self.frequency, handler: self.scheduledCallbackHandler!)
} catch {
context.fireErrorCaught(error)
}
}
private struct HTTPDrippingDownloadHandlerScheduledCallbackHandler: NIOScheduledCallbackHandler & Sendable {
var handler: NIOLoopBound<HTTPDrippingDownloadHandler>
var context: NIOLoopBound<ChannelHandlerContext>
func handleScheduledCallback(eventLoop: some EventLoop) {
self.handler.value.writeChunk(context: self.context.value)
}
}
}