Skip to content

Commit

Permalink
Fix warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianfett committed Feb 21, 2024
1 parent 69ccfdf commit 65a76ee
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 62 deletions.
1 change: 1 addition & 0 deletions Sources/PostgresNIO/New/PSQLRowStream.swift
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ final class PSQLRowStream: @unchecked Sendable {
elementType: DataRow.self,
failureType: Error.self,
backPressureStrategy: AdaptiveRowBuffer(),
finishOnDeinit: false,
delegate: self
)

Expand Down
109 changes: 53 additions & 56 deletions Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import Crypto
import Foundation

extension UInt8: ExpressibleByUnicodeScalarLiteral {
extension UInt8 {
fileprivate static var NUL: UInt8 { return 0x00 /* yeah, just U+0000 man */ }
fileprivate static var comma: UInt8 { return 0x2c /* .init(ascii: ",") */ }
fileprivate static var equals: UInt8 { return 0x3d /* .init(ascii: "=") */ }
public init(unicodeScalarLiteral value: Unicode.Scalar) {
self.init(ascii: value)
}
}

fileprivate extension String {
Expand Down Expand Up @@ -171,40 +168,40 @@ fileprivate struct SCRAMMessageParser {
static func parseAttributePair(name: [UInt8], value: [UInt8], isGS2Header: Bool = false) -> SCRAMAttribute? {
guard name.count == 1 || isGS2Header else { return nil }
switch name.first {
case "m" where !isGS2Header: return .m(value)
case "r" where !isGS2Header: return String(printableAscii: value).map { .r($0) }
case "c" where !isGS2Header:
guard let parsedAttrs = value.decodingBase64().flatMap({ parse(raw: $0, isGS2Header: true) }) else { return nil }
guard (1...3).contains(parsedAttrs.count) else { return nil }
switch (parsedAttrs.first, parsedAttrs.dropFirst(1).first, parsedAttrs.dropFirst(2).first) {
case let (.gp(.bind(name, .none)), .a(ident), .gm(data)): return .c(binding: .bind(name, data), authIdentity: ident)
case let (.gp(.bind(name, .none)), .gm(data), .none): return .c(binding: .bind(name, data))
case let (.gp(bind), .a(ident), .none): return .c(binding: bind, authIdentity: ident)
case let (.gp(bind), .none, .none): return .c(binding: bind)
default: return nil
}
case "n" where !isGS2Header: return String(bytes: value, encoding: .utf8)?.decodedAsSaslName.map { .n($0) }
case "s" where !isGS2Header: return value.decodingBase64().map { .s($0) }
case "i" where !isGS2Header: return String(printableAscii: value).flatMap { UInt32.init($0) }.map { .i($0) }
case "p" where !isGS2Header: return value.decodingBase64().map { .p($0) }
case "v" where !isGS2Header: return value.decodingBase64().map { .v($0) }
case "e" where !isGS2Header: // TODO: actually map the specific enum string values
guard value.isValidScramValue else { return nil }
return String(bytes: value, encoding: .utf8).flatMap { SCRAMServerError(rawValue: $0) }.map { .e($0) }

case "y" where isGS2Header && value.count == 0: return .gp(.unused)
case "n" where isGS2Header && value.count == 0: return .gp(.unsupported)
case "p" where isGS2Header: return String(asciiAlphanumericMorse: value).map { .gp(.bind($0, nil)) }
case "a" where isGS2Header: return String(bytes: value, encoding: .utf8)?.decodedAsSaslName.map { .a($0) }
case .none where isGS2Header: return .a(nil)
case UInt8(ascii: "m") where !isGS2Header: return .m(value)
case UInt8(ascii: "r") where !isGS2Header: return String(printableAscii: value).map { .r($0) }
case UInt8(ascii: "c") where !isGS2Header:
guard let parsedAttrs = value.decodingBase64().flatMap({ parse(raw: $0, isGS2Header: true) }) else { return nil }
guard (1...3).contains(parsedAttrs.count) else { return nil }
switch (parsedAttrs.first, parsedAttrs.dropFirst(1).first, parsedAttrs.dropFirst(2).first) {
case let (.gp(.bind(name, .none)), .a(ident), .gm(data)): return .c(binding: .bind(name, data), authIdentity: ident)
case let (.gp(.bind(name, .none)), .gm(data), .none): return .c(binding: .bind(name, data))
case let (.gp(bind), .a(ident), .none): return .c(binding: bind, authIdentity: ident)
case let (.gp(bind), .none, .none): return .c(binding: bind)
default: return nil
}
case UInt8(ascii: "n") where !isGS2Header: return String(bytes: value, encoding: .utf8)?.decodedAsSaslName.map { .n($0) }
case UInt8(ascii: "s") where !isGS2Header: return value.decodingBase64().map { .s($0) }
case UInt8(ascii: "i") where !isGS2Header: return String(printableAscii: value).flatMap { UInt32.init($0) }.map { .i($0) }
case UInt8(ascii: "p") where !isGS2Header: return value.decodingBase64().map { .p($0) }
case UInt8(ascii: "v") where !isGS2Header: return value.decodingBase64().map { .v($0) }
case UInt8(ascii: "e") where !isGS2Header: // TODO: actually map the specific enum string values
guard value.isValidScramValue else { return nil }
return String(bytes: value, encoding: .utf8).flatMap { SCRAMServerError(rawValue: $0) }.map { .e($0) }

default:
if isGS2Header {
return .gm(name + value)
} else {
guard value.count > 0, value.isValidScramValue else { return nil }
return .optional(name: CChar(name[0]), value: value)
}
case UInt8(ascii: "y") where isGS2Header && value.count == 0: return .gp(.unused)
case UInt8(ascii: "n") where isGS2Header && value.count == 0: return .gp(.unsupported)
case UInt8(ascii: "p") where isGS2Header: return String(asciiAlphanumericMorse: value).map { .gp(.bind($0, nil)) }
case UInt8(ascii: "a") where isGS2Header: return String(bytes: value, encoding: .utf8)?.decodedAsSaslName.map { .a($0) }
case .none where isGS2Header: return .a(nil)

default:
if isGS2Header {
return .gm(name + value)
} else {
guard value.count > 0, value.isValidScramValue else { return nil }
return .optional(name: CChar(name[0]), value: value)
}
}
}

Expand All @@ -230,45 +227,45 @@ fileprivate struct SCRAMMessageParser {
for attribute in attributes {
switch attribute {
case .m(let value):
result.append("m"); result.append("="); result.append(contentsOf: value)
result.append(UInt8(ascii: "m")); result.append(.equals); result.append(contentsOf: value)
case .r(let nonce):
result.append("r"); result.append("="); result.append(contentsOf: nonce.utf8.map { UInt8($0) })
result.append(UInt8(ascii: "r")); result.append(.equals); result.append(contentsOf: nonce.utf8.map { UInt8($0) })
case .n(let name):
result.append("n"); result.append("="); result.append(contentsOf: name.encodedAsSaslName.utf8.map { UInt8($0) })
result.append(UInt8(ascii: "n")); result.append(.equals); result.append(contentsOf: name.encodedAsSaslName.utf8.map { UInt8($0) })
case .s(let salt):
result.append("s"); result.append("="); result.append(contentsOf: salt.encodingBase64())
result.append(UInt8(ascii: "s")); result.append(.equals); result.append(contentsOf: salt.encodingBase64())
case .i(let count):
result.append("i"); result.append("="); result.append(contentsOf: "\(count)".utf8.map { UInt8($0) })
result.append(UInt8(ascii: "i")); result.append(.equals); result.append(contentsOf: "\(count)".utf8.map { UInt8($0) })
case .p(let proof):
result.append("p"); result.append("="); result.append(contentsOf: proof.encodingBase64())
result.append(UInt8(ascii: "p")); result.append(.equals); result.append(contentsOf: proof.encodingBase64())
case .v(let signature):
result.append("v"); result.append("="); result.append(contentsOf: signature.encodingBase64())
result.append(UInt8(ascii: "v")); result.append(.equals); result.append(contentsOf: signature.encodingBase64())
case .e(let error):
result.append("e"); result.append("="); result.append(contentsOf: error.rawValue.utf8.map { UInt8($0) })
result.append(UInt8(ascii: "e")); result.append(.equals); result.append(contentsOf: error.rawValue.utf8.map { UInt8($0) })
case .c(let binding, let identity):
if isInitialGS2Header {
switch binding {
case .unsupported: result.append("n")
case .unused: result.append("y")
case .bind(let name, _): result.append("p"); result.append("="); result.append(contentsOf: name.utf8.map { UInt8($0) })
case .unsupported: result.append(UInt8(ascii: "n"))
case .unused: result.append(UInt8(ascii: "y"))
case .bind(let name, _): result.append(UInt8(ascii: "p")); result.append(.equals); result.append(contentsOf: name.utf8.map { UInt8($0) })
}
result.append(",")
result.append(.comma)
if let identity = identity {
result.append("a"); result.append("="); result.append(contentsOf: identity.encodedAsSaslName.utf8.map { UInt8($0) })
result.append(UInt8(ascii: "a")); result.append(.equals); result.append(contentsOf: identity.encodedAsSaslName.utf8.map { UInt8($0) })
}
result.append(",")
result.append(.comma)
} else {
guard var partial = serialize([attribute], isInitialGS2Header: true) else { return nil }
if case let .bind(_, data) = binding {
guard let data = data else { return nil }
partial.append(contentsOf: data)
}
result.append("c"); result.append("="); result.append(contentsOf: partial.encodingBase64())
result.append(UInt8(ascii: "c")); result.append(.equals); result.append(contentsOf: partial.encodingBase64())
}
default:
return nil
}
result.append(",")
result.append(.comma)
}
return result.dropLast()
}
Expand Down Expand Up @@ -472,7 +469,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common {
let saltedPassword = Hi(string: password, salt: serverSalt, iterations: serverIterations)
let clientKey = HMAC<SHA256>.authenticationCode(for: "Client Key".data(using: .utf8)!, using: .init(data: saltedPassword))
let storedKey = SHA256.hash(data: Data(clientKey))
var authMessage = firstMessageBare; authMessage.append(","); authMessage.append(contentsOf: message); authMessage.append(","); authMessage.append(contentsOf: clientFinalNoProof)
var authMessage = firstMessageBare; authMessage.append(.comma); authMessage.append(contentsOf: message); authMessage.append(.comma); authMessage.append(contentsOf: clientFinalNoProof)
let clientSignature = HMAC<SHA256>.authenticationCode(for: authMessage, using: .init(data: storedKey))
var clientProof = Array(clientKey)

Expand All @@ -485,7 +482,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common {
}

// Generate a `client-final-message`
var clientFinalMessage = clientFinalNoProof; clientFinalMessage.append(",")
var clientFinalMessage = clientFinalNoProof; clientFinalMessage.append(.comma)
guard let proofPart = SCRAMMessageParser.serialize([.p(Array(clientProof))]) else { throw SASLAuthenticationError.genericAuthenticationFailure }
clientFinalMessage.append(contentsOf: proofPart)

Expand Down Expand Up @@ -590,7 +587,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common {
// Compute client signature
let clientKey = HMAC<SHA256>.authenticationCode(for: "Client Key".data(using: .utf8)!, using: .init(data: saltedPassword))
let storedKey = SHA256.hash(data: Data(clientKey))
var authMessage = clientBareFirstMessage; authMessage.append(","); authMessage.append(contentsOf: serverFirstMessage); authMessage.append(","); authMessage.append(contentsOf: message.dropLast(proof.count + 3))
var authMessage = clientBareFirstMessage; authMessage.append(.comma); authMessage.append(contentsOf: serverFirstMessage); authMessage.append(.comma); authMessage.append(contentsOf: message.dropLast(proof.count + 3))
let clientSignature = HMAC<SHA256>.authenticationCode(for: authMessage, using: .init(data: storedKey))

// Recompute client key from signature and proof, verify match
Expand Down
2 changes: 1 addition & 1 deletion Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,7 @@ final class ConnectionPoolTests: XCTestCase {

pool.connectionReceivedNewMaxStreamSetting(connection, newMaxStreamSetting: 21)

for (index, request) in requests.enumerated() {
for (_, request) in requests.enumerated() {
let connection = try await request.future.success
connections.append(connection)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ import class Foundation.JSONEncoder
import NIOCore
@testable import PostgresNIO

extension ConnectionStateMachine.ConnectionAction: Equatable {
// fully-qualifying all types in the extension has the same effect as adding a `@retroactive` before the protocol
extension PostgresNIO.ConnectionStateMachine.ConnectionAction: Swift.Equatable {
public static func == (lhs: Self, rhs: Self) -> Bool {
switch (lhs, rhs) {
case (.read, read):
Expand Down Expand Up @@ -47,7 +48,8 @@ extension ConnectionStateMachine.ConnectionAction: Equatable {
}
}

extension ConnectionStateMachine.ConnectionAction.CleanUpContext: Equatable {
// fully-qualifying all types in the extension has the same effect as adding a `@retroactive` before the protocol'
extension PostgresNIO.ConnectionStateMachine.ConnectionAction.CleanUpContext: Swift.Equatable {
public static func == (lhs: Self, rhs: Self) -> Bool {
guard lhs.closePromise?.futureResult === rhs.closePromise?.futureResult else {
return false
Expand Down Expand Up @@ -96,13 +98,15 @@ extension ConnectionStateMachine {
}
}

extension PSQLError: Equatable {
// fully-qualifying all types in the extension has the same effect as adding a `@retroactive` before the protocol
extension PostgresNIO.PSQLError: Swift.Equatable {
public static func == (lhs: PSQLError, rhs: PSQLError) -> Bool {
return true
}
}

extension PSQLTask: Equatable {
// fully-qualifying all types in the extension has the same effect as adding a `@retroactive` before the protocol
extension PostgresNIO.PSQLTask: Swift.Equatable {
public static func == (lhs: PSQLTask, rhs: PSQLTask) -> Bool {
switch (lhs, rhs) {
case (.extendedQuery(let lhs), .extendedQuery(let rhs)):
Expand Down
2 changes: 1 addition & 1 deletion Tests/PostgresNIOTests/New/Messages/DataRowTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class DataRowTests: XCTestCase {
}
}

extension DataRow: ExpressibleByArrayLiteral {
extension DataRow: @retroactive ExpressibleByArrayLiteral {
public typealias ArrayLiteralElement = PostgresEncodable

public init(arrayLiteral elements: PostgresEncodable...) {
Expand Down

0 comments on commit 65a76ee

Please sign in to comment.