Skip to content

Commit

Permalink
Support additional connection parameters (#361)
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianfett authored Dec 12, 2023
1 parent e60e495 commit fa3137d
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,19 @@ extension PostgresConnection {
/// This property is provided for compatibility with Amazon RDS Proxy, which requires it to be `false`.
/// If you are not using Amazon RDS Proxy, you should leave this set to `true` (the default).
public var requireBackendKeyData: Bool


/// Additional parameters to send to the server on startup. The name value pairs are added to the initial
/// startup message that the client sends to the server.
public var additionalStartupParameters: [(String, String)]

/// Create an options structure with default values.
///
/// Most users should not need to adjust the defaults.
public init() {
self.connectTimeout = .seconds(10)
self.tlsServerName = nil
self.requireBackendKeyData = true
self.additionalStartupParameters = []
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1113,11 +1113,19 @@ struct SendPrepareStatement {
let query: String
}

struct AuthContext: Equatable, CustomDebugStringConvertible {
let username: String
let password: String?
let database: String?

struct AuthContext: CustomDebugStringConvertible {
var username: String
var password: String?
var database: String?
var additionalParameters: [(String, String)]

init(username: String, password: String? = nil, database: String? = nil, additionalParameters: [(String, String)] = []) {
self.username = username
self.password = password
self.database = database
self.additionalParameters = additionalParameters
}

var debugDescription: String {
"""
AuthContext(username: \(String(reflecting: self.username)), \
Expand All @@ -1127,6 +1135,22 @@ struct AuthContext: Equatable, CustomDebugStringConvertible {
}
}

extension AuthContext: Equatable {
static func ==(lhs: Self, rhs: Self) -> Bool {
guard lhs.username == rhs.username
&& lhs.password == rhs.password
&& lhs.database == rhs.database
&& lhs.additionalParameters.count == rhs.additionalParameters.count
else {
return false
}

return lhs.additionalParameters.elementsEqual(rhs.additionalParameters) { lhs, rhs in
lhs.0 == rhs.0 && lhs.1 == rhs.1
}
}
}

enum PasswordAuthencationMode: Equatable {
case cleartext
case md5(salt: UInt32)
Expand Down
2 changes: 1 addition & 1 deletion Sources/PostgresNIO/New/PostgresChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
case .wait:
break
case .sendStartupMessage(let authContext):
self.encoder.startup(user: authContext.username, database: authContext.database)
self.encoder.startup(user: authContext.username, database: authContext.database, options: authContext.additionalParameters)
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil)
case .sendSSLRequest:
self.encoder.ssl()
Expand Down
9 changes: 8 additions & 1 deletion Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ struct PostgresFrontendMessageEncoder {
self.buffer = buffer
}

mutating func startup(user: String, database: String?) {
mutating func startup(user: String, database: String?, options: [(String, String)]) {
self.clearIfNeeded()
self.buffer.psqlLengthPrefixed { buffer in
buffer.writeInteger(Self.startupVersionThree)
Expand All @@ -37,6 +37,13 @@ struct PostgresFrontendMessageEncoder {
buffer.writeNullTerminatedString(database)
}

// we don't send replication parameters, as the default is false and this is what we
// need for a client
for (key, value) in options {
buffer.writeNullTerminatedString(key)
buffer.writeNullTerminatedString(value)
}

buffer.writeInteger(UInt8(0))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder {
case 196608:
var user: String?
var database: String?
var options: String?
var options = [(String, String)]()

while let name = messageSlice.readNullTerminatedString(), messageSlice.readerIndex < finalIndex {
let value = messageSlice.readNullTerminatedString()

Expand All @@ -51,11 +51,10 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder {
case "database":
database = value

case "options":
options = value

default:
break
if let value = value {
options.append((name, value))
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ enum PostgresFrontendMessage: Equatable {
static let requestCode: Int32 = 80877103
}

struct Startup: Hashable {
struct Startup: Equatable {
static let versionThree: Int32 = 0x00_03_00_00

/// Creates a `Startup` with "3.0" as the protocol version.
Expand All @@ -119,7 +119,7 @@ enum PostgresFrontendMessage: Equatable {
/// The protocol version number is followed by one or more pairs of parameter
/// name and value strings. A zero byte is required as a terminator after
/// the last name/value pair. `user` is required, others are optional.
struct Parameters: Hashable {
struct Parameters: Equatable {
enum Replication {
case `true`
case `false`
Expand All @@ -136,12 +136,33 @@ enum PostgresFrontendMessage: Equatable {
/// of setting individual run-time parameters.) Spaces within this string are
/// considered to separate arguments, unless escaped with a
/// backslash (\); write \\ to represent a literal backslash.
var options: String?
var options: [(String, String)]

/// Used to connect in streaming replication mode, where a small set of
/// replication commands can be issued instead of SQL statements. Value
/// can be true, false, or database, and the default is false.
var replication: Replication

static func ==(lhs: Self, rhs: Self) -> Bool {
guard lhs.user == rhs.user
&& lhs.database == rhs.database
&& lhs.replication == rhs.replication
&& lhs.options.count == rhs.options.count
else {
return false
}

var lhsIterator = lhs.options.makeIterator()
var rhsIterator = rhs.options.makeIterator()

while let lhsNext = lhsIterator.next(), let rhsNext = rhsIterator.next() {
guard lhsNext.0 == rhsNext.0 && lhsNext.1 == rhsNext.1 else {
return false
}
}
return true
}

}

var parameters: Parameters
Expand Down
41 changes: 39 additions & 2 deletions Tests/PostgresNIOTests/New/Messages/StartupTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class StartupTests: XCTestCase {
let user = "test"
let database = "abc123"

encoder.startup(user: user, database: database)
encoder.startup(user: user, database: database, options: [])
byteBuffer = encoder.flushBuffer()

let byteBufferLength = Int32(byteBuffer.readableBytes)
Expand All @@ -32,7 +32,7 @@ class StartupTests: XCTestCase {

let user = "test"

encoder.startup(user: user, database: nil)
encoder.startup(user: user, database: nil, options: [])
byteBuffer = encoder.flushBuffer()

let byteBufferLength = Int32(byteBuffer.readableBytes)
Expand All @@ -44,4 +44,41 @@ class StartupTests: XCTestCase {

XCTAssertEqual(byteBuffer.readableBytes, 0)
}

func testStartupMessageWithAdditionalOptions() {
var encoder = PostgresFrontendMessageEncoder(buffer: .init())
var byteBuffer = ByteBuffer()

let user = "test"
let database = "abc123"

encoder.startup(user: user, database: database, options: [("some", "options")])
byteBuffer = encoder.flushBuffer()

let byteBufferLength = Int32(byteBuffer.readableBytes)
XCTAssertEqual(byteBufferLength, byteBuffer.readInteger())
XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger())
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user")
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test")
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database")
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "abc123")
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "some")
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "options")
XCTAssertEqual(byteBuffer.readInteger(), UInt8(0))

XCTAssertEqual(byteBuffer.readableBytes, 0)
}
}

extension PostgresFrontendMessage.Startup.Parameters.Replication {
var stringValue: String {
switch self {
case .true:
return "true"
case .false:
return "false"
case .database:
return "replication"
}
}
}
9 changes: 4 additions & 5 deletions Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@ class PostgresChannelHandlerTests: XCTestCase {

XCTAssertEqual(startup.parameters.user, config.username)
XCTAssertEqual(startup.parameters.database, config.database)
XCTAssertEqual(startup.parameters.options, nil)
XCTAssertEqual(startup.parameters.replication, .false)

XCTAssert(startup.parameters.options.isEmpty)

XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.authentication(.ok)))
XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678))))
XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.readyForQuery(.idle)))
Expand Down Expand Up @@ -209,7 +208,7 @@ class PostgresChannelHandlerTests: XCTestCase {

XCTAssertEqual(startup.parameters.user, config.username)
XCTAssertEqual(startup.parameters.database, config.database)
XCTAssertEqual(startup.parameters.options, nil)
XCTAssert(startup.parameters.options.isEmpty)
XCTAssertEqual(startup.parameters.replication, .false)

var buffer = ByteBuffer()
Expand Down Expand Up @@ -282,7 +281,7 @@ extension AuthContext {
PostgresFrontendMessage.Startup.Parameters(
user: self.username,
database: self.database,
options: nil,
options: self.additionalParameters,
replication: .false
)
}
Expand Down
2 changes: 1 addition & 1 deletion Tests/PostgresNIOTests/New/PostgresConnectionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ class PostgresConnectionTests: XCTestCase {

async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: self.logger)
let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self)
XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", replication: .false))))
XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: [], replication: .false))))
try await channel.writeInbound(PostgresBackendMessage.authentication(.ok))
try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678)))
try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle))
Expand Down

0 comments on commit fa3137d

Please sign in to comment.