Skip to content

Commit f3c1bdd

Browse files
committed
large payload tests and performance improvements
1 parent 4ac580e commit f3c1bdd

File tree

3 files changed

+192
-77
lines changed

3 files changed

+192
-77
lines changed

Sources/AMQPClient/AMQPChannel.swift

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ public final class AMQPChannel: Sendable {
3232
private let isConfirmMode = ManagedAtomic(false)
3333
private let isTxMode = ManagedAtomic(false)
3434
private let deliveryTag = ManagedAtomic(UInt64(1))
35-
private let frameMax: UInt32
35+
private let frameMax: Int
3636

3737
init(channelID: Frame.ChannelID, eventLoop: EventLoop, channel: AMQPChannelHandler, frameMax: UInt32) {
3838
ID = channelID
3939
self.eventLoop = eventLoop
4040
self.channel = channel
41-
self.frameMax = frameMax
41+
self.frameMax = Int(frameMax)
4242
}
4343

4444
/// Close the channel
@@ -90,22 +90,19 @@ public final class AMQPChannel: Sendable {
9090

9191
let header = Frame.Payload.header(.init(classID: classID, weight: 0, bodySize: UInt64(body.readableBytes), properties: properties))
9292

93-
let payloads: [Frame.Payload]
93+
var payloads: [Frame.Payload]
9494

9595
if body.readableBytes <= frameMax {
9696
payloads = [publish, header, .body(body)]
9797
} else {
98-
var parts = [publish, header]
99-
var buffer = body
98+
payloads = [publish, header]
99+
var body = body
100100

101-
while(buffer.readableBytes > 0) {
102-
guard let bytes = buffer.readBytes(length: frameMax < buffer.readableBytes ? Int(frameMax) : buffer.readableBytes) else {
103-
preconditionFailure("invalid bytes read")
104-
}
105-
parts.append(.body(.init(bytes: bytes)))
101+
while body.readableBytes > 0 {
102+
// slice is always valid
103+
let slice = body.readSlice(length: min(frameMax, body.readableBytes))!
104+
payloads.append(Frame.Payload.body(slice))
106105
}
107-
108-
payloads = parts
109106
}
110107

111108
let result: EventLoopFuture<Void> = channel.send(payloads: payloads)

Sources/AMQPClient/ChannelHandlers/AMQPConnectionMultiplexHandler.swift

Lines changed: 87 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ internal final class AMQPConnectionMultiplexHandler: ChannelDuplexHandler {
4242
// NOTE: this can be extended to keep some state of the open request so a response can be verified against its request
4343
var pendingRequests: Deque<EventLoopPromise<AMQPResponse>>
4444
weak var eventHandler: AMQPChannelHandler?
45-
var nextMessage: (frame: Frame.Method.Basic, properties: Properties?, bodySize: UInt64?, prevBody: ByteBuffer?)?
45+
var nextMessage: PartialDelivery?
4646

4747
init(initialResponsePromise: EventLoopPromise<AMQPResponse>) {
4848
pendingRequests = .init([initialResponsePromise])
@@ -207,7 +207,7 @@ internal final class AMQPConnectionMultiplexHandler: ChannelDuplexHandler {
207207
channel.fulfilNextPendingRequest(with: .channel(.message(.get())))
208208
case .deliver, .getOk, .return:
209209
// TODO: wrap this away more nicely, assert message must be nil
210-
channel.nextMessage = (frame: basic, nil, nil, nil)
210+
channel.nextMessage = PartialDelivery(method: basic)
211211
case .recoverOk:
212212
channel.fulfilNextPendingRequest(with: .channel(.basic(.recovered)))
213213
case let .consumeOk(consumerTag):
@@ -266,72 +266,56 @@ internal final class AMQPConnectionMultiplexHandler: ChannelDuplexHandler {
266266
}
267267
}
268268
case let .header(header):
269-
channel.nextMessage?.properties = header.properties
270-
channel.nextMessage?.bodySize = header.bodySize
269+
channel.nextMessage?.setHeader(header)
271270
case let .body(body):
272-
guard let msg = channel.nextMessage, let properties = msg.properties, let bodySize = msg.bodySize else {
273-
// TODO: take down channel
274-
return
275-
}
276-
277-
let prevSize = msg.prevBody?.readableBytes ?? 0
278-
if (prevSize + body.readableBytes < bodySize) {
279-
if var prevBody = msg.prevBody {
280-
prevBody.writeImmutableBuffer(body)
281-
channel.nextMessage?.prevBody = prevBody
282-
} else {
283-
channel.nextMessage?.prevBody = body
284-
}
285-
286-
return
287-
}
288-
289-
let completeBody: ByteBuffer
290-
291-
if var prevBody = msg.prevBody {
292-
prevBody.writeImmutableBuffer(body)
293-
completeBody = prevBody
294-
} else {
295-
completeBody = body
296-
}
297-
298-
switch msg.frame {
299-
case let .getOk(getOk):
300-
channel.fulfilNextPendingRequest(with: .channel(.message(.get(.init(
301-
message: .init(
302-
exchange: getOk.exchange,
303-
routingKey: getOk.routingKey,
304-
deliveryTag: getOk.deliveryTag,
305-
properties: properties,
306-
redelivered: getOk.redelivered,
307-
body: completeBody
308-
),
309-
messageCount: getOk.messageCount
310-
)))))
311-
case let .deliver(deliver):
312-
channel.eventHandler?.receiveDelivery(
313-
.init(
314-
exchange: deliver.exchange,
315-
routingKey: deliver.routingKey,
316-
deliveryTag: deliver.deliveryTag,
271+
// TODO: take down channel
272+
guard channel.nextMessage != nil else { return }
273+
274+
// written like this to avoid unnecessary copies
275+
channel.nextMessage!.addBody(body)
276+
277+
if channel.nextMessage!.isComplete {
278+
let (method, properties, completeBody) = channel.nextMessage!.asCompletedMessage()
279+
channel.nextMessage = nil
280+
281+
switch method {
282+
case let .getOk(getOk):
283+
channel.fulfilNextPendingRequest(with: .channel(.message(.get(.init(
284+
message: .init(
285+
exchange: getOk.exchange,
286+
routingKey: getOk.routingKey,
287+
deliveryTag: getOk.deliveryTag,
288+
properties: properties,
289+
redelivered: getOk.redelivered,
290+
body: completeBody
291+
),
292+
messageCount: getOk.messageCount
293+
)))))
294+
case let .deliver(deliver):
295+
channel.eventHandler?.receiveDelivery(
296+
.init(
297+
exchange: deliver.exchange,
298+
routingKey: deliver.routingKey,
299+
deliveryTag: deliver.deliveryTag,
300+
properties: properties,
301+
redelivered: deliver.redelivered,
302+
body: completeBody
303+
),
304+
for: deliver.consumerTag
305+
)
306+
case let .return(`return`):
307+
channel.eventHandler?.receiveReturn(.init(
308+
replyCode: `return`.replyCode,
309+
replyText: `return`.replyText,
310+
exchange: `return`.exchange,
311+
routingKey: `return`.routingKey,
317312
properties: properties,
318-
redelivered: deliver.redelivered,
319313
body: completeBody
320-
),
321-
for: deliver.consumerTag
322-
)
323-
case let .return(`return`):
324-
channel.eventHandler?.receiveReturn(.init(
325-
replyCode: `return`.replyCode,
326-
replyText: `return`.replyText,
327-
exchange: `return`.exchange,
328-
routingKey: `return`.routingKey,
329-
properties: properties,
330-
body: completeBody
331-
))
332-
default:
333-
// TODO: take down channel
334-
preconditionUnexpectedFrame(frame)
314+
))
315+
default:
316+
// TODO: take down channel
317+
preconditionUnexpectedFrame(frame)
318+
}
335319
}
336320
case .heartbeat:
337321
let heartbeat = Frame(channelID: frame.channelID, payload: .heartbeat)
@@ -400,3 +384,41 @@ extension AMQPConnectionMultiplexHandler.ChannelState {
400384
eventHandler?.reportAsClosed(error: error)
401385
}
402386
}
387+
388+
private struct PartialDelivery {
389+
let method: Frame.Method.Basic
390+
private var header: Frame.Header?
391+
private var payload: ByteBuffer?
392+
393+
init(method: Frame.Method.Basic) {
394+
self.method = method
395+
}
396+
397+
var isComplete: Bool { header != nil && header!.bodySize <= (payload?.readableBytes ?? 0) }
398+
399+
// NOTE: should be made throwing with validation for a more restrictive protocol implementation
400+
mutating func setHeader(_ header: consuming Frame.Header) {
401+
// validate that self.header == nil
402+
self.header = header
403+
}
404+
405+
// NOTE: should be made throwing with validation for a more restrictive protocol implementation
406+
mutating func addBody(_ buffer: consuming ByteBuffer) {
407+
guard let header else { return } // probably should take channel down
408+
409+
if payload == nil {
410+
buffer.reserveCapacity(Int(header.bodySize))
411+
payload = consume buffer
412+
} else {
413+
payload!.writeBuffer(&buffer)
414+
}
415+
}
416+
417+
func asCompletedMessage() -> (Frame.Method.Basic, Properties, ByteBuffer) {
418+
// NOTE: this could be made a consuming func once partial is possible I think
419+
assert(isComplete)
420+
421+
// header and payloads are guaranteed to be non-nil after isComplete
422+
return (method, header!.properties, payload!)
423+
}
424+
}
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
2+
@testable import AMQPClient
3+
import AMQPProtocol
4+
import NIO
5+
import NIOPosix
6+
import XCTest
7+
8+
final class AMQPChannelLargePayloadsTest: XCTestCase {
9+
let testExchange = "lage_payloads"
10+
var connection: AMQPConnection!
11+
var channel: AMQPChannel!
12+
13+
override func setUp() async throws {
14+
connection = try await AMQPConnection.connect(use: MultiThreadedEventLoopGroup.singleton.next(), from: .init(connection: .plain, server: .init()))
15+
channel = try await connection.openChannel()
16+
try await channel.exchangeDeclare(name: testExchange, type: "fanout")
17+
}
18+
19+
override func tearDown() async throws {
20+
try await channel?.close()
21+
try await connection?.close()
22+
}
23+
24+
func testPublishLargePayloads() async throws {
25+
try await channel.basicPublish(from: ByteBuffer(megaBytes: 50), exchange: testExchange, routingKey: "")
26+
try await channel.basicPublish(from: ByteBuffer(megaBytes: 10), exchange: testExchange, routingKey: "")
27+
}
28+
29+
func testGetLargePayloads() async throws {
30+
let queueName = "lp-test-get"
31+
let payload1 = ByteBuffer(megaBytes: 20)
32+
let payload2 = ByteBuffer(megaBytes: 30, repeating: UInt8(ascii: "b"))
33+
34+
try await channel.queueDeclare(name: queueName, exclusive: true)
35+
try await channel.queueBind(queue: queueName, exchange: testExchange)
36+
37+
try await channel.basicPublish(from: payload1, exchange: testExchange, routingKey: "")
38+
try await channel.basicPublish(from: payload2, exchange: testExchange, routingKey: "")
39+
let getResponse1 = try await channel.basicGet(queue: queueName, noAck: true)
40+
let getResponse2 = try await channel.basicGet(queue: queueName, noAck: true)
41+
42+
XCTAssertEqual(getResponse1?.message.body, payload1)
43+
XCTAssertEqual(getResponse1?.messageCount, 1)
44+
XCTAssertEqual(getResponse2?.message.body, payload2)
45+
XCTAssertEqual(getResponse2?.messageCount, 0)
46+
}
47+
48+
func testConsumeLargePayloads() async throws {
49+
let queueName = "lp-test-consume"
50+
let payload1 = ByteBuffer(megaBytes: 20, repeating: UInt8(ascii: "a"))
51+
let payload2 = ByteBuffer(megaBytes: 10, repeating: UInt8(ascii: "b"))
52+
53+
try await channel.queueDeclare(name: queueName, exclusive: true)
54+
try await channel.queueBind(queue: queueName, exchange: testExchange)
55+
56+
var messages = try await channel.basicConsume(queue: queueName, noAck: true).prefix(2).makeAsyncIterator()
57+
58+
try await channel.basicPublish(from: payload1, exchange: testExchange, routingKey: "")
59+
try await channel.basicPublish(from: payload2, exchange: testExchange, routingKey: "")
60+
61+
let received1 = try await messages.next()
62+
let received2 = try await messages.next()
63+
let received3 = try await messages.next()
64+
65+
XCTAssertEqual(received1?.body, payload1)
66+
XCTAssertEqual(received2?.body, payload2)
67+
XCTAssertNil(received3)
68+
}
69+
70+
func testReturnLargePayloads() async throws {
71+
let payload1 = ByteBuffer(megaBytes: 20, repeating: UInt8(ascii: "a"))
72+
let payload2 = ByteBuffer(megaBytes: 10, repeating: UInt8(ascii: "b"))
73+
74+
try await channel.exchangeDeclare(name: "no-binding", type: "fanout")
75+
76+
var returns = try await channel.returnConsume(named: "test").prefix(2).makeAsyncIterator()
77+
78+
79+
try await channel.basicPublish(from: payload1, exchange: "no-binding", routingKey: "", mandatory: true)
80+
try await channel.basicPublish(from: payload2, exchange: "no-binding", routingKey: "", mandatory: true)
81+
82+
let received1 = try await returns.next()
83+
let received2 = try await returns.next()
84+
let received3 = try await returns.next()
85+
86+
XCTAssertEqual(received1?.body, payload1)
87+
XCTAssertEqual(received2?.body, payload2)
88+
XCTAssertNil(received3)
89+
}
90+
}
91+
92+
extension ByteBuffer {
93+
init(megaBytes: Int, repeating: UInt8 = .init(ascii: "a")) {
94+
self.init(repeating: repeating, count: megaBytes * 1024 * 1024)
95+
}
96+
}

0 commit comments

Comments
 (0)