diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift index 22c59d8a..dd0f5404 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift @@ -85,7 +85,11 @@ extension PostgresConnection { /// This property is provided for compatibility with Amazon RDS Proxy, which requires it to be `false`. /// If you are not using Amazon RDS Proxy, you should leave this set to `true` (the default). public var requireBackendKeyData: Bool - + + /// Additional parameters to send to the server on startup. The name value pairs are added to the initial + /// startup message that the client sends to the server. + public var additionalStartupParameters: [(String, String)] + /// Create an options structure with default values. /// /// Most users should not need to adjust the defaults. @@ -93,6 +97,7 @@ extension PostgresConnection { self.connectTimeout = .seconds(10) self.tlsServerName = nil self.requireBackendKeyData = true + self.additionalStartupParameters = [] } } diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 9cde0cf3..d7a609a6 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -1113,11 +1113,19 @@ struct SendPrepareStatement { let query: String } -struct AuthContext: Equatable, CustomDebugStringConvertible { - let username: String - let password: String? - let database: String? - +struct AuthContext: CustomDebugStringConvertible { + var username: String + var password: String? + var database: String? + var additionalParameters: [(String, String)] + + init(username: String, password: String? = nil, database: String? = nil, additionalParameters: [(String, String)] = []) { + self.username = username + self.password = password + self.database = database + self.additionalParameters = additionalParameters + } + var debugDescription: String { """ AuthContext(username: \(String(reflecting: self.username)), \ @@ -1127,6 +1135,22 @@ struct AuthContext: Equatable, CustomDebugStringConvertible { } } +extension AuthContext: Equatable { + static func ==(lhs: Self, rhs: Self) -> Bool { + guard lhs.username == rhs.username + && lhs.password == rhs.password + && lhs.database == rhs.database + && lhs.additionalParameters.count == rhs.additionalParameters.count + else { + return false + } + + return lhs.additionalParameters.elementsEqual(rhs.additionalParameters) { lhs, rhs in + lhs.0 == rhs.0 && lhs.1 == rhs.1 + } + } +} + enum PasswordAuthencationMode: Equatable { case cleartext case md5(salt: UInt32) diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 9d0ef2a5..54ae0fc9 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -328,7 +328,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { case .wait: break case .sendStartupMessage(let authContext): - self.encoder.startup(user: authContext.username, database: authContext.database) + self.encoder.startup(user: authContext.username, database: authContext.database, options: authContext.additionalParameters) context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) case .sendSSLRequest: self.encoder.ssl() diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift index e98ab1f1..97805418 100644 --- a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift +++ b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift @@ -25,7 +25,7 @@ struct PostgresFrontendMessageEncoder { self.buffer = buffer } - mutating func startup(user: String, database: String?) { + mutating func startup(user: String, database: String?, options: [(String, String)]) { self.clearIfNeeded() self.buffer.psqlLengthPrefixed { buffer in buffer.writeInteger(Self.startupVersionThree) @@ -37,6 +37,13 @@ struct PostgresFrontendMessageEncoder { buffer.writeNullTerminatedString(database) } + // we don't send replication parameters, as the default is false and this is what we + // need for a client + for (key, value) in options { + buffer.writeNullTerminatedString(key) + buffer.writeNullTerminatedString(value) + } + buffer.writeInteger(UInt8(0)) } } diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index 46c043b1..55ccd0a9 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -39,8 +39,8 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder { case 196608: var user: String? var database: String? - var options: String? - + var options = [(String, String)]() + while let name = messageSlice.readNullTerminatedString(), messageSlice.readerIndex < finalIndex { let value = messageSlice.readNullTerminatedString() @@ -51,11 +51,10 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder { case "database": database = value - case "options": - options = value - default: - break + if let value = value { + options.append((name, value)) + } } } diff --git a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift index 010667dc..2532959a 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift @@ -103,7 +103,7 @@ enum PostgresFrontendMessage: Equatable { static let requestCode: Int32 = 80877103 } - struct Startup: Hashable { + struct Startup: Equatable { static let versionThree: Int32 = 0x00_03_00_00 /// Creates a `Startup` with "3.0" as the protocol version. @@ -119,7 +119,7 @@ enum PostgresFrontendMessage: Equatable { /// The protocol version number is followed by one or more pairs of parameter /// name and value strings. A zero byte is required as a terminator after /// the last name/value pair. `user` is required, others are optional. - struct Parameters: Hashable { + struct Parameters: Equatable { enum Replication { case `true` case `false` @@ -136,12 +136,33 @@ enum PostgresFrontendMessage: Equatable { /// of setting individual run-time parameters.) Spaces within this string are /// considered to separate arguments, unless escaped with a /// backslash (\); write \\ to represent a literal backslash. - var options: String? + var options: [(String, String)] /// Used to connect in streaming replication mode, where a small set of /// replication commands can be issued instead of SQL statements. Value /// can be true, false, or database, and the default is false. var replication: Replication + + static func ==(lhs: Self, rhs: Self) -> Bool { + guard lhs.user == rhs.user + && lhs.database == rhs.database + && lhs.replication == rhs.replication + && lhs.options.count == rhs.options.count + else { + return false + } + + var lhsIterator = lhs.options.makeIterator() + var rhsIterator = rhs.options.makeIterator() + + while let lhsNext = lhsIterator.next(), let rhsNext = rhsIterator.next() { + guard lhsNext.0 == rhsNext.0 && lhsNext.1 == rhsNext.1 else { + return false + } + } + return true + } + } var parameters: Parameters diff --git a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift index 39e9bb42..5af3bf34 100644 --- a/Tests/PostgresNIOTests/New/Messages/StartupTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/StartupTests.swift @@ -11,7 +11,7 @@ class StartupTests: XCTestCase { let user = "test" let database = "abc123" - encoder.startup(user: user, database: database) + encoder.startup(user: user, database: database, options: []) byteBuffer = encoder.flushBuffer() let byteBufferLength = Int32(byteBuffer.readableBytes) @@ -32,7 +32,7 @@ class StartupTests: XCTestCase { let user = "test" - encoder.startup(user: user, database: nil) + encoder.startup(user: user, database: nil, options: []) byteBuffer = encoder.flushBuffer() let byteBufferLength = Int32(byteBuffer.readableBytes) @@ -44,4 +44,41 @@ class StartupTests: XCTestCase { XCTAssertEqual(byteBuffer.readableBytes, 0) } + + func testStartupMessageWithAdditionalOptions() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + var byteBuffer = ByteBuffer() + + let user = "test" + let database = "abc123" + + encoder.startup(user: user, database: database, options: [("some", "options")]) + byteBuffer = encoder.flushBuffer() + + let byteBufferLength = Int32(byteBuffer.readableBytes) + XCTAssertEqual(byteBufferLength, byteBuffer.readInteger()) + XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger()) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "abc123") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "some") + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "options") + XCTAssertEqual(byteBuffer.readInteger(), UInt8(0)) + + XCTAssertEqual(byteBuffer.readableBytes, 0) + } +} + +extension PostgresFrontendMessage.Startup.Parameters.Replication { + var stringValue: String { + switch self { + case .true: + return "true" + case .false: + return "false" + case .database: + return "replication" + } + } } diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index b81d0899..dfdcc53e 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -37,9 +37,8 @@ class PostgresChannelHandlerTests: XCTestCase { XCTAssertEqual(startup.parameters.user, config.username) XCTAssertEqual(startup.parameters.database, config.database) - XCTAssertEqual(startup.parameters.options, nil) - XCTAssertEqual(startup.parameters.replication, .false) - + XCTAssert(startup.parameters.options.isEmpty) + XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.authentication(.ok))) XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678)))) XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.readyForQuery(.idle))) @@ -209,7 +208,7 @@ class PostgresChannelHandlerTests: XCTestCase { XCTAssertEqual(startup.parameters.user, config.username) XCTAssertEqual(startup.parameters.database, config.database) - XCTAssertEqual(startup.parameters.options, nil) + XCTAssert(startup.parameters.options.isEmpty) XCTAssertEqual(startup.parameters.replication, .false) var buffer = ByteBuffer() @@ -282,7 +281,7 @@ extension AuthContext { PostgresFrontendMessage.Startup.Parameters( user: self.username, database: self.database, - options: nil, + options: self.additionalParameters, replication: .false ) } diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 3b1a8ca9..82baf914 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -602,7 +602,7 @@ class PostgresConnectionTests: XCTestCase { async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: self.logger) let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) - XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", replication: .false)))) + XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: [], replication: .false)))) try await channel.writeInbound(PostgresBackendMessage.authentication(.ok)) try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678))) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle))