Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix warnings #454

Merged
merged 2 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Sources/PostgresNIO/New/PSQLRowStream.swift
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ final class PSQLRowStream: @unchecked Sendable {
elementType: DataRow.self,
failureType: Error.self,
backPressureStrategy: AdaptiveRowBuffer(),
finishOnDeinit: false,
delegate: self
)

Expand Down
111 changes: 54 additions & 57 deletions Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import Crypto
import Foundation

extension UInt8: ExpressibleByUnicodeScalarLiteral {
extension UInt8 {
fileprivate static var NUL: UInt8 { return 0x00 /* yeah, just U+0000 man */ }
fileprivate static var comma: UInt8 { return 0x2c /* .init(ascii: ",") */ }
fileprivate static var equals: UInt8 { return 0x3d /* .init(ascii: "=") */ }
public init(unicodeScalarLiteral value: Unicode.Scalar) {
self.init(ascii: value)
}
}

fileprivate extension String {
Expand Down Expand Up @@ -87,7 +84,7 @@ fileprivate extension Array where Element == UInt8 {
*/
var isValidScramValue: Bool {
// TODO: FInd a better way than doing a whole construction of String...
return self.count > 0 && !(String(bytes: self, encoding: .utf8)?.contains(",") ?? true)
return self.count > 0 && !(String(decoding: self, as: Unicode.UTF8.self).contains(","))
}

}
Expand Down Expand Up @@ -171,40 +168,40 @@ fileprivate struct SCRAMMessageParser {
static func parseAttributePair(name: [UInt8], value: [UInt8], isGS2Header: Bool = false) -> SCRAMAttribute? {
guard name.count == 1 || isGS2Header else { return nil }
switch name.first {
case "m" where !isGS2Header: return .m(value)
case "r" where !isGS2Header: return String(printableAscii: value).map { .r($0) }
case "c" where !isGS2Header:
guard let parsedAttrs = value.decodingBase64().flatMap({ parse(raw: $0, isGS2Header: true) }) else { return nil }
guard (1...3).contains(parsedAttrs.count) else { return nil }
switch (parsedAttrs.first, parsedAttrs.dropFirst(1).first, parsedAttrs.dropFirst(2).first) {
case let (.gp(.bind(name, .none)), .a(ident), .gm(data)): return .c(binding: .bind(name, data), authIdentity: ident)
case let (.gp(.bind(name, .none)), .gm(data), .none): return .c(binding: .bind(name, data))
case let (.gp(bind), .a(ident), .none): return .c(binding: bind, authIdentity: ident)
case let (.gp(bind), .none, .none): return .c(binding: bind)
default: return nil
}
case "n" where !isGS2Header: return String(bytes: value, encoding: .utf8)?.decodedAsSaslName.map { .n($0) }
case "s" where !isGS2Header: return value.decodingBase64().map { .s($0) }
case "i" where !isGS2Header: return String(printableAscii: value).flatMap { UInt32.init($0) }.map { .i($0) }
case "p" where !isGS2Header: return value.decodingBase64().map { .p($0) }
case "v" where !isGS2Header: return value.decodingBase64().map { .v($0) }
case "e" where !isGS2Header: // TODO: actually map the specific enum string values
guard value.isValidScramValue else { return nil }
return String(bytes: value, encoding: .utf8).flatMap { SCRAMServerError(rawValue: $0) }.map { .e($0) }

case "y" where isGS2Header && value.count == 0: return .gp(.unused)
case "n" where isGS2Header && value.count == 0: return .gp(.unsupported)
case "p" where isGS2Header: return String(asciiAlphanumericMorse: value).map { .gp(.bind($0, nil)) }
case "a" where isGS2Header: return String(bytes: value, encoding: .utf8)?.decodedAsSaslName.map { .a($0) }
case .none where isGS2Header: return .a(nil)
case UInt8(ascii: "m") where !isGS2Header: return .m(value)
case UInt8(ascii: "r") where !isGS2Header: return String(printableAscii: value).map { .r($0) }
case UInt8(ascii: "c") where !isGS2Header:
guard let parsedAttrs = value.decodingBase64().flatMap({ parse(raw: $0, isGS2Header: true) }) else { return nil }
guard (1...3).contains(parsedAttrs.count) else { return nil }
switch (parsedAttrs.first, parsedAttrs.dropFirst(1).first, parsedAttrs.dropFirst(2).first) {
case let (.gp(.bind(name, .none)), .a(ident), .gm(data)): return .c(binding: .bind(name, data), authIdentity: ident)
case let (.gp(.bind(name, .none)), .gm(data), .none): return .c(binding: .bind(name, data))
case let (.gp(bind), .a(ident), .none): return .c(binding: bind, authIdentity: ident)
case let (.gp(bind), .none, .none): return .c(binding: bind)
default: return nil
}
case UInt8(ascii: "n") where !isGS2Header: return String(decoding: value, as: Unicode.UTF8.self).decodedAsSaslName.map { .n($0) }
case UInt8(ascii: "s") where !isGS2Header: return value.decodingBase64().map { .s($0) }
case UInt8(ascii: "i") where !isGS2Header: return String(printableAscii: value).flatMap { UInt32.init($0) }.map { .i($0) }
case UInt8(ascii: "p") where !isGS2Header: return value.decodingBase64().map { .p($0) }
case UInt8(ascii: "v") where !isGS2Header: return value.decodingBase64().map { .v($0) }
case UInt8(ascii: "e") where !isGS2Header: // TODO: actually map the specific enum string values
guard value.isValidScramValue else { return nil }
return SCRAMServerError(rawValue: String(decoding: value, as: Unicode.UTF8.self)).flatMap { .e($0) }

default:
if isGS2Header {
return .gm(name + value)
} else {
guard value.count > 0, value.isValidScramValue else { return nil }
return .optional(name: CChar(name[0]), value: value)
}
case UInt8(ascii: "y") where isGS2Header && value.count == 0: return .gp(.unused)
case UInt8(ascii: "n") where isGS2Header && value.count == 0: return .gp(.unsupported)
case UInt8(ascii: "p") where isGS2Header: return String(asciiAlphanumericMorse: value).map { .gp(.bind($0, nil)) }
case UInt8(ascii: "a") where isGS2Header: return String(decoding: value, as: Unicode.UTF8.self).decodedAsSaslName.map { .a($0) }
case .none where isGS2Header: return .a(nil)

default:
if isGS2Header {
return .gm(name + value)
} else {
guard value.count > 0, value.isValidScramValue else { return nil }
return .optional(name: CChar(name[0]), value: value)
}
}
}

Expand All @@ -230,45 +227,45 @@ fileprivate struct SCRAMMessageParser {
for attribute in attributes {
switch attribute {
case .m(let value):
result.append("m"); result.append("="); result.append(contentsOf: value)
result.append(UInt8(ascii: "m")); result.append(.equals); result.append(contentsOf: value)
case .r(let nonce):
result.append("r"); result.append("="); result.append(contentsOf: nonce.utf8.map { UInt8($0) })
result.append(UInt8(ascii: "r")); result.append(.equals); result.append(contentsOf: nonce.utf8.map { UInt8($0) })
case .n(let name):
result.append("n"); result.append("="); result.append(contentsOf: name.encodedAsSaslName.utf8.map { UInt8($0) })
result.append(UInt8(ascii: "n")); result.append(.equals); result.append(contentsOf: name.encodedAsSaslName.utf8.map { UInt8($0) })
case .s(let salt):
result.append("s"); result.append("="); result.append(contentsOf: salt.encodingBase64())
result.append(UInt8(ascii: "s")); result.append(.equals); result.append(contentsOf: salt.encodingBase64())
case .i(let count):
result.append("i"); result.append("="); result.append(contentsOf: "\(count)".utf8.map { UInt8($0) })
result.append(UInt8(ascii: "i")); result.append(.equals); result.append(contentsOf: "\(count)".utf8.map { UInt8($0) })
case .p(let proof):
result.append("p"); result.append("="); result.append(contentsOf: proof.encodingBase64())
result.append(UInt8(ascii: "p")); result.append(.equals); result.append(contentsOf: proof.encodingBase64())
case .v(let signature):
result.append("v"); result.append("="); result.append(contentsOf: signature.encodingBase64())
result.append(UInt8(ascii: "v")); result.append(.equals); result.append(contentsOf: signature.encodingBase64())
case .e(let error):
result.append("e"); result.append("="); result.append(contentsOf: error.rawValue.utf8.map { UInt8($0) })
result.append(UInt8(ascii: "e")); result.append(.equals); result.append(contentsOf: error.rawValue.utf8.map { UInt8($0) })
case .c(let binding, let identity):
if isInitialGS2Header {
switch binding {
case .unsupported: result.append("n")
case .unused: result.append("y")
case .bind(let name, _): result.append("p"); result.append("="); result.append(contentsOf: name.utf8.map { UInt8($0) })
case .unsupported: result.append(UInt8(ascii: "n"))
case .unused: result.append(UInt8(ascii: "y"))
case .bind(let name, _): result.append(UInt8(ascii: "p")); result.append(.equals); result.append(contentsOf: name.utf8.map { UInt8($0) })
}
result.append(",")
result.append(.comma)
if let identity = identity {
result.append("a"); result.append("="); result.append(contentsOf: identity.encodedAsSaslName.utf8.map { UInt8($0) })
result.append(UInt8(ascii: "a")); result.append(.equals); result.append(contentsOf: identity.encodedAsSaslName.utf8.map { UInt8($0) })
}
result.append(",")
result.append(.comma)
} else {
guard var partial = serialize([attribute], isInitialGS2Header: true) else { return nil }
if case let .bind(_, data) = binding {
guard let data = data else { return nil }
partial.append(contentsOf: data)
}
result.append("c"); result.append("="); result.append(contentsOf: partial.encodingBase64())
result.append(UInt8(ascii: "c")); result.append(.equals); result.append(contentsOf: partial.encodingBase64())
}
default:
return nil
}
result.append(",")
result.append(.comma)
}
return result.dropLast()
}
Expand Down Expand Up @@ -472,7 +469,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common {
let saltedPassword = Hi(string: password, salt: serverSalt, iterations: serverIterations)
let clientKey = HMAC<SHA256>.authenticationCode(for: "Client Key".data(using: .utf8)!, using: .init(data: saltedPassword))
let storedKey = SHA256.hash(data: Data(clientKey))
var authMessage = firstMessageBare; authMessage.append(","); authMessage.append(contentsOf: message); authMessage.append(","); authMessage.append(contentsOf: clientFinalNoProof)
var authMessage = firstMessageBare; authMessage.append(.comma); authMessage.append(contentsOf: message); authMessage.append(.comma); authMessage.append(contentsOf: clientFinalNoProof)
let clientSignature = HMAC<SHA256>.authenticationCode(for: authMessage, using: .init(data: storedKey))
var clientProof = Array(clientKey)

Expand All @@ -485,7 +482,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common {
}

// Generate a `client-final-message`
var clientFinalMessage = clientFinalNoProof; clientFinalMessage.append(",")
var clientFinalMessage = clientFinalNoProof; clientFinalMessage.append(.comma)
guard let proofPart = SCRAMMessageParser.serialize([.p(Array(clientProof))]) else { throw SASLAuthenticationError.genericAuthenticationFailure }
clientFinalMessage.append(contentsOf: proofPart)

Expand Down Expand Up @@ -590,7 +587,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common {
// Compute client signature
let clientKey = HMAC<SHA256>.authenticationCode(for: "Client Key".data(using: .utf8)!, using: .init(data: saltedPassword))
let storedKey = SHA256.hash(data: Data(clientKey))
var authMessage = clientBareFirstMessage; authMessage.append(","); authMessage.append(contentsOf: serverFirstMessage); authMessage.append(","); authMessage.append(contentsOf: message.dropLast(proof.count + 3))
var authMessage = clientBareFirstMessage; authMessage.append(.comma); authMessage.append(contentsOf: serverFirstMessage); authMessage.append(.comma); authMessage.append(contentsOf: message.dropLast(proof.count + 3))
let clientSignature = HMAC<SHA256>.authenticationCode(for: authMessage, using: .init(data: storedKey))

// Recompute client key from signature and proof, verify match
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,7 @@ final class ConnectionPoolTests: XCTestCase {

pool.connectionReceivedNewMaxStreamSetting(connection, newMaxStreamSetting: 21)

for (index, request) in requests.enumerated() {
for (_, request) in requests.enumerated() {
let connection = try await request.future.success
connections.append(connection)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ import class Foundation.JSONEncoder
import NIOCore
@testable import PostgresNIO

extension ConnectionStateMachine.ConnectionAction: Equatable {
// fully-qualifying all types in the extension has the same effect as adding a `@retroactive` before the protocol
extension PostgresNIO.ConnectionStateMachine.ConnectionAction: Swift.Equatable {
public static func == (lhs: Self, rhs: Self) -> Bool {
switch (lhs, rhs) {
case (.read, read):
Expand Down Expand Up @@ -47,7 +48,8 @@ extension ConnectionStateMachine.ConnectionAction: Equatable {
}
}

extension ConnectionStateMachine.ConnectionAction.CleanUpContext: Equatable {
// fully-qualifying all types in the extension has the same effect as adding a `@retroactive` before the protocol'
extension PostgresNIO.ConnectionStateMachine.ConnectionAction.CleanUpContext: Swift.Equatable {
public static func == (lhs: Self, rhs: Self) -> Bool {
guard lhs.closePromise?.futureResult === rhs.closePromise?.futureResult else {
return false
Expand Down Expand Up @@ -96,13 +98,15 @@ extension ConnectionStateMachine {
}
}

extension PSQLError: Equatable {
// fully-qualifying all types in the extension has the same effect as adding a `@retroactive` before the protocol
extension PostgresNIO.PSQLError: Swift.Equatable {
public static func == (lhs: PSQLError, rhs: PSQLError) -> Bool {
return true
}
}

extension PSQLTask: Equatable {
// fully-qualifying all types in the extension has the same effect as adding a `@retroactive` before the protocol
extension PostgresNIO.PSQLTask: Swift.Equatable {
public static func == (lhs: PSQLTask, rhs: PSQLTask) -> Bool {
switch (lhs, rhs) {
case (.extendedQuery(let lhs), .extendedQuery(let rhs)):
Expand Down
2 changes: 1 addition & 1 deletion Tests/PostgresNIOTests/New/Messages/DataRowTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class DataRowTests: XCTestCase {
}
}

extension DataRow: ExpressibleByArrayLiteral {
extension PostgresNIO.DataRow: Swift.ExpressibleByArrayLiteral {
public typealias ArrayLiteralElement = PostgresEncodable

public init(arrayLiteral elements: PostgresEncodable...) {
Expand Down
Loading