diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index b3dfea30..0255e462 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -86,6 +86,7 @@ final class PSQLRowStream: @unchecked Sendable { elementType: DataRow.self, failureType: Error.self, backPressureStrategy: AdaptiveRowBuffer(), + finishOnDeinit: false, delegate: self ) diff --git a/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift b/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift index ac1d9ead..2a717b6b 100644 --- a/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift +++ b/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift @@ -1,13 +1,10 @@ import Crypto import Foundation -extension UInt8: ExpressibleByUnicodeScalarLiteral { +extension UInt8 { fileprivate static var NUL: UInt8 { return 0x00 /* yeah, just U+0000 man */ } fileprivate static var comma: UInt8 { return 0x2c /* .init(ascii: ",") */ } fileprivate static var equals: UInt8 { return 0x3d /* .init(ascii: "=") */ } - public init(unicodeScalarLiteral value: Unicode.Scalar) { - self.init(ascii: value) - } } fileprivate extension String { @@ -87,7 +84,7 @@ fileprivate extension Array where Element == UInt8 { */ var isValidScramValue: Bool { // TODO: FInd a better way than doing a whole construction of String... - return self.count > 0 && !(String(bytes: self, encoding: .utf8)?.contains(",") ?? true) + return self.count > 0 && !(String(decoding: self, as: Unicode.UTF8.self).contains(",")) } } @@ -171,40 +168,40 @@ fileprivate struct SCRAMMessageParser { static func parseAttributePair(name: [UInt8], value: [UInt8], isGS2Header: Bool = false) -> SCRAMAttribute? { guard name.count == 1 || isGS2Header else { return nil } switch name.first { - case "m" where !isGS2Header: return .m(value) - case "r" where !isGS2Header: return String(printableAscii: value).map { .r($0) } - case "c" where !isGS2Header: - guard let parsedAttrs = value.decodingBase64().flatMap({ parse(raw: $0, isGS2Header: true) }) else { return nil } - guard (1...3).contains(parsedAttrs.count) else { return nil } - switch (parsedAttrs.first, parsedAttrs.dropFirst(1).first, parsedAttrs.dropFirst(2).first) { - case let (.gp(.bind(name, .none)), .a(ident), .gm(data)): return .c(binding: .bind(name, data), authIdentity: ident) - case let (.gp(.bind(name, .none)), .gm(data), .none): return .c(binding: .bind(name, data)) - case let (.gp(bind), .a(ident), .none): return .c(binding: bind, authIdentity: ident) - case let (.gp(bind), .none, .none): return .c(binding: bind) - default: return nil - } - case "n" where !isGS2Header: return String(bytes: value, encoding: .utf8)?.decodedAsSaslName.map { .n($0) } - case "s" where !isGS2Header: return value.decodingBase64().map { .s($0) } - case "i" where !isGS2Header: return String(printableAscii: value).flatMap { UInt32.init($0) }.map { .i($0) } - case "p" where !isGS2Header: return value.decodingBase64().map { .p($0) } - case "v" where !isGS2Header: return value.decodingBase64().map { .v($0) } - case "e" where !isGS2Header: // TODO: actually map the specific enum string values - guard value.isValidScramValue else { return nil } - return String(bytes: value, encoding: .utf8).flatMap { SCRAMServerError(rawValue: $0) }.map { .e($0) } - - case "y" where isGS2Header && value.count == 0: return .gp(.unused) - case "n" where isGS2Header && value.count == 0: return .gp(.unsupported) - case "p" where isGS2Header: return String(asciiAlphanumericMorse: value).map { .gp(.bind($0, nil)) } - case "a" where isGS2Header: return String(bytes: value, encoding: .utf8)?.decodedAsSaslName.map { .a($0) } - case .none where isGS2Header: return .a(nil) + case UInt8(ascii: "m") where !isGS2Header: return .m(value) + case UInt8(ascii: "r") where !isGS2Header: return String(printableAscii: value).map { .r($0) } + case UInt8(ascii: "c") where !isGS2Header: + guard let parsedAttrs = value.decodingBase64().flatMap({ parse(raw: $0, isGS2Header: true) }) else { return nil } + guard (1...3).contains(parsedAttrs.count) else { return nil } + switch (parsedAttrs.first, parsedAttrs.dropFirst(1).first, parsedAttrs.dropFirst(2).first) { + case let (.gp(.bind(name, .none)), .a(ident), .gm(data)): return .c(binding: .bind(name, data), authIdentity: ident) + case let (.gp(.bind(name, .none)), .gm(data), .none): return .c(binding: .bind(name, data)) + case let (.gp(bind), .a(ident), .none): return .c(binding: bind, authIdentity: ident) + case let (.gp(bind), .none, .none): return .c(binding: bind) + default: return nil + } + case UInt8(ascii: "n") where !isGS2Header: return String(decoding: value, as: Unicode.UTF8.self).decodedAsSaslName.map { .n($0) } + case UInt8(ascii: "s") where !isGS2Header: return value.decodingBase64().map { .s($0) } + case UInt8(ascii: "i") where !isGS2Header: return String(printableAscii: value).flatMap { UInt32.init($0) }.map { .i($0) } + case UInt8(ascii: "p") where !isGS2Header: return value.decodingBase64().map { .p($0) } + case UInt8(ascii: "v") where !isGS2Header: return value.decodingBase64().map { .v($0) } + case UInt8(ascii: "e") where !isGS2Header: // TODO: actually map the specific enum string values + guard value.isValidScramValue else { return nil } + return SCRAMServerError(rawValue: String(decoding: value, as: Unicode.UTF8.self)).flatMap { .e($0) } - default: - if isGS2Header { - return .gm(name + value) - } else { - guard value.count > 0, value.isValidScramValue else { return nil } - return .optional(name: CChar(name[0]), value: value) - } + case UInt8(ascii: "y") where isGS2Header && value.count == 0: return .gp(.unused) + case UInt8(ascii: "n") where isGS2Header && value.count == 0: return .gp(.unsupported) + case UInt8(ascii: "p") where isGS2Header: return String(asciiAlphanumericMorse: value).map { .gp(.bind($0, nil)) } + case UInt8(ascii: "a") where isGS2Header: return String(decoding: value, as: Unicode.UTF8.self).decodedAsSaslName.map { .a($0) } + case .none where isGS2Header: return .a(nil) + + default: + if isGS2Header { + return .gm(name + value) + } else { + guard value.count > 0, value.isValidScramValue else { return nil } + return .optional(name: CChar(name[0]), value: value) + } } } @@ -230,45 +227,45 @@ fileprivate struct SCRAMMessageParser { for attribute in attributes { switch attribute { case .m(let value): - result.append("m"); result.append("="); result.append(contentsOf: value) + result.append(UInt8(ascii: "m")); result.append(.equals); result.append(contentsOf: value) case .r(let nonce): - result.append("r"); result.append("="); result.append(contentsOf: nonce.utf8.map { UInt8($0) }) + result.append(UInt8(ascii: "r")); result.append(.equals); result.append(contentsOf: nonce.utf8.map { UInt8($0) }) case .n(let name): - result.append("n"); result.append("="); result.append(contentsOf: name.encodedAsSaslName.utf8.map { UInt8($0) }) + result.append(UInt8(ascii: "n")); result.append(.equals); result.append(contentsOf: name.encodedAsSaslName.utf8.map { UInt8($0) }) case .s(let salt): - result.append("s"); result.append("="); result.append(contentsOf: salt.encodingBase64()) + result.append(UInt8(ascii: "s")); result.append(.equals); result.append(contentsOf: salt.encodingBase64()) case .i(let count): - result.append("i"); result.append("="); result.append(contentsOf: "\(count)".utf8.map { UInt8($0) }) + result.append(UInt8(ascii: "i")); result.append(.equals); result.append(contentsOf: "\(count)".utf8.map { UInt8($0) }) case .p(let proof): - result.append("p"); result.append("="); result.append(contentsOf: proof.encodingBase64()) + result.append(UInt8(ascii: "p")); result.append(.equals); result.append(contentsOf: proof.encodingBase64()) case .v(let signature): - result.append("v"); result.append("="); result.append(contentsOf: signature.encodingBase64()) + result.append(UInt8(ascii: "v")); result.append(.equals); result.append(contentsOf: signature.encodingBase64()) case .e(let error): - result.append("e"); result.append("="); result.append(contentsOf: error.rawValue.utf8.map { UInt8($0) }) + result.append(UInt8(ascii: "e")); result.append(.equals); result.append(contentsOf: error.rawValue.utf8.map { UInt8($0) }) case .c(let binding, let identity): if isInitialGS2Header { switch binding { - case .unsupported: result.append("n") - case .unused: result.append("y") - case .bind(let name, _): result.append("p"); result.append("="); result.append(contentsOf: name.utf8.map { UInt8($0) }) + case .unsupported: result.append(UInt8(ascii: "n")) + case .unused: result.append(UInt8(ascii: "y")) + case .bind(let name, _): result.append(UInt8(ascii: "p")); result.append(.equals); result.append(contentsOf: name.utf8.map { UInt8($0) }) } - result.append(",") + result.append(.comma) if let identity = identity { - result.append("a"); result.append("="); result.append(contentsOf: identity.encodedAsSaslName.utf8.map { UInt8($0) }) + result.append(UInt8(ascii: "a")); result.append(.equals); result.append(contentsOf: identity.encodedAsSaslName.utf8.map { UInt8($0) }) } - result.append(",") + result.append(.comma) } else { guard var partial = serialize([attribute], isInitialGS2Header: true) else { return nil } if case let .bind(_, data) = binding { guard let data = data else { return nil } partial.append(contentsOf: data) } - result.append("c"); result.append("="); result.append(contentsOf: partial.encodingBase64()) + result.append(UInt8(ascii: "c")); result.append(.equals); result.append(contentsOf: partial.encodingBase64()) } default: return nil } - result.append(",") + result.append(.comma) } return result.dropLast() } @@ -472,7 +469,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { let saltedPassword = Hi(string: password, salt: serverSalt, iterations: serverIterations) let clientKey = HMAC.authenticationCode(for: "Client Key".data(using: .utf8)!, using: .init(data: saltedPassword)) let storedKey = SHA256.hash(data: Data(clientKey)) - var authMessage = firstMessageBare; authMessage.append(","); authMessage.append(contentsOf: message); authMessage.append(","); authMessage.append(contentsOf: clientFinalNoProof) + var authMessage = firstMessageBare; authMessage.append(.comma); authMessage.append(contentsOf: message); authMessage.append(.comma); authMessage.append(contentsOf: clientFinalNoProof) let clientSignature = HMAC.authenticationCode(for: authMessage, using: .init(data: storedKey)) var clientProof = Array(clientKey) @@ -485,7 +482,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { } // Generate a `client-final-message` - var clientFinalMessage = clientFinalNoProof; clientFinalMessage.append(",") + var clientFinalMessage = clientFinalNoProof; clientFinalMessage.append(.comma) guard let proofPart = SCRAMMessageParser.serialize([.p(Array(clientProof))]) else { throw SASLAuthenticationError.genericAuthenticationFailure } clientFinalMessage.append(contentsOf: proofPart) @@ -590,7 +587,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { // Compute client signature let clientKey = HMAC.authenticationCode(for: "Client Key".data(using: .utf8)!, using: .init(data: saltedPassword)) let storedKey = SHA256.hash(data: Data(clientKey)) - var authMessage = clientBareFirstMessage; authMessage.append(","); authMessage.append(contentsOf: serverFirstMessage); authMessage.append(","); authMessage.append(contentsOf: message.dropLast(proof.count + 3)) + var authMessage = clientBareFirstMessage; authMessage.append(.comma); authMessage.append(contentsOf: serverFirstMessage); authMessage.append(.comma); authMessage.append(contentsOf: message.dropLast(proof.count + 3)) let clientSignature = HMAC.authenticationCode(for: authMessage, using: .init(data: storedKey)) // Recompute client key from signature and proof, verify match diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index ba3c6a3f..3e3c9d65 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -803,7 +803,7 @@ final class ConnectionPoolTests: XCTestCase { pool.connectionReceivedNewMaxStreamSetting(connection, newMaxStreamSetting: 21) - for (index, request) in requests.enumerated() { + for (_, request) in requests.enumerated() { let connection = try await request.future.success connections.append(connection) } diff --git a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift index febeee37..d20032a8 100644 --- a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift @@ -2,7 +2,8 @@ import class Foundation.JSONEncoder import NIOCore @testable import PostgresNIO -extension ConnectionStateMachine.ConnectionAction: Equatable { +// fully-qualifying all types in the extension has the same effect as adding a `@retroactive` before the protocol +extension PostgresNIO.ConnectionStateMachine.ConnectionAction: Swift.Equatable { public static func == (lhs: Self, rhs: Self) -> Bool { switch (lhs, rhs) { case (.read, read): @@ -47,7 +48,8 @@ extension ConnectionStateMachine.ConnectionAction: Equatable { } } -extension ConnectionStateMachine.ConnectionAction.CleanUpContext: Equatable { +// fully-qualifying all types in the extension has the same effect as adding a `@retroactive` before the protocol' +extension PostgresNIO.ConnectionStateMachine.ConnectionAction.CleanUpContext: Swift.Equatable { public static func == (lhs: Self, rhs: Self) -> Bool { guard lhs.closePromise?.futureResult === rhs.closePromise?.futureResult else { return false @@ -96,13 +98,15 @@ extension ConnectionStateMachine { } } -extension PSQLError: Equatable { +// fully-qualifying all types in the extension has the same effect as adding a `@retroactive` before the protocol +extension PostgresNIO.PSQLError: Swift.Equatable { public static func == (lhs: PSQLError, rhs: PSQLError) -> Bool { return true } } -extension PSQLTask: Equatable { +// fully-qualifying all types in the extension has the same effect as adding a `@retroactive` before the protocol +extension PostgresNIO.PSQLTask: Swift.Equatable { public static func == (lhs: PSQLTask, rhs: PSQLTask) -> Bool { switch (lhs, rhs) { case (.extendedQuery(let lhs), .extendedQuery(let rhs)): diff --git a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift index db31b98a..a90d1e93 100644 --- a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift @@ -113,7 +113,7 @@ class DataRowTests: XCTestCase { } } -extension DataRow: ExpressibleByArrayLiteral { +extension PostgresNIO.DataRow: Swift.ExpressibleByArrayLiteral { public typealias ArrayLiteralElement = PostgresEncodable public init(arrayLiteral elements: PostgresEncodable...) {