Skip to content

Commit

Permalink
Enable StrictConcurrency checking (vapor#483)
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianfett authored Jun 14, 2024
1 parent 6c3d0a9 commit 7b621c1
Show file tree
Hide file tree
Showing 9 changed files with 63 additions and 50 deletions.
19 changes: 14 additions & 5 deletions Package.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
// swift-tools-version:5.8
import PackageDescription

let swiftSettings: [SwiftSetting] = [
.enableUpcomingFeature("StrictConcurrency")
]

let package = Package(
name: "postgres-nio",
platforms: [
Expand Down Expand Up @@ -41,23 +45,26 @@ let package = Package(
.product(name: "NIOSSL", package: "swift-nio-ssl"),
.product(name: "NIOFoundationCompat", package: "swift-nio"),
.product(name: "ServiceLifecycle", package: "swift-service-lifecycle"),
]
],
swiftSettings: swiftSettings
),
.target(
name: "_ConnectionPoolModule",
dependencies: [
.product(name: "Atomics", package: "swift-atomics"),
.product(name: "DequeModule", package: "swift-collections"),
],
path: "Sources/ConnectionPoolModule"
path: "Sources/ConnectionPoolModule",
swiftSettings: swiftSettings
),
.testTarget(
name: "PostgresNIOTests",
dependencies: [
.target(name: "PostgresNIO"),
.product(name: "NIOEmbedded", package: "swift-nio"),
.product(name: "NIOTestUtils", package: "swift-nio"),
]
],
swiftSettings: swiftSettings
),
.testTarget(
name: "ConnectionPoolModuleTests",
Expand All @@ -67,14 +74,16 @@ let package = Package(
.product(name: "NIOCore", package: "swift-nio"),
.product(name: "NIOConcurrencyHelpers", package: "swift-nio"),
.product(name: "NIOEmbedded", package: "swift-nio"),
]
],
swiftSettings: swiftSettings
),
.testTarget(
name: "IntegrationTests",
dependencies: [
.target(name: "PostgresNIO"),
.product(name: "NIOTestUtils", package: "swift-nio"),
]
],
swiftSettings: swiftSettings
),
]
)
4 changes: 2 additions & 2 deletions Sources/ConnectionPoolModule/ConnectionPool.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *)
public struct ConnectionAndMetadata<Connection: PooledConnection> {
public struct ConnectionAndMetadata<Connection: PooledConnection>: Sendable {

public var connection: Connection

Expand Down Expand Up @@ -495,7 +495,7 @@ public final class ConnectionPool<
}

@usableFromInline
enum TimerRunResult {
enum TimerRunResult: Sendable {
case timerTriggered
case timerCancelled
case cancellationContinuationFinished
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public protocol ConnectionPoolObservabilityDelegate: Sendable {
func requestQueueDepthChanged(_ newDepth: Int)
}

public struct NoOpConnectionPoolMetrics<ConnectionID: Hashable>: ConnectionPoolObservabilityDelegate {
public struct NoOpConnectionPoolMetrics<ConnectionID: Hashable & Sendable>: ConnectionPoolObservabilityDelegate {
public init(connectionIDType: ConnectionID.Type) {}

public func startedConnecting(id: ConnectionID) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ extension PostgresMessage {
/// Identifies an incoming or outgoing postgres message. Sent as the first byte, before the message size.
/// Values are not unique across all identifiers, meaning some messages will require keeping state to identify.
@available(*, deprecated, message: "Will be removed from public API.")
public struct Identifier: ExpressibleByIntegerLiteral, Equatable, CustomStringConvertible {
public struct Identifier: Sendable, ExpressibleByIntegerLiteral, Equatable, CustomStringConvertible {
// special
public static let none: Identifier = 0x00
// special
Expand Down
2 changes: 1 addition & 1 deletion Sources/PostgresNIO/Pool/PostgresClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ extension PostgresConnection: PooledConnection {
self.channel.close(mode: .all, promise: nil)
}

public func onClose(_ closure: @escaping ((any Error)?) -> ()) {
public func onClose(_ closure: @escaping @Sendable ((any Error)?) -> ()) {
self.closeFuture.whenComplete { _ in closure(nil) }
}
}
Expand Down
2 changes: 1 addition & 1 deletion Sources/PostgresNIO/Utilities/PostgresError+Code.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
extension PostgresError {
public struct Code: ExpressibleByStringLiteral, Equatable {
public struct Code: Sendable, ExpressibleByStringLiteral, Equatable {
// Class 00 — Successful Completion
public static let successfulCompletion: Code = "00000"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import DequeModule

@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *)
final class MockConnectionFactory<Clock: _Concurrency.Clock> where Clock.Duration == Duration {
final class MockConnectionFactory<Clock: _Concurrency.Clock>: Sendable where Clock.Duration == Duration {
typealias ConnectionIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator
typealias Request = ConnectionRequest<MockConnection>
typealias KeepAliveBehavior = MockPingPongBehavior
Expand Down
61 changes: 31 additions & 30 deletions Tests/IntegrationTests/PostgresNIOTests.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import Logging
@testable import PostgresNIO
import Atomics
import XCTest
import NIOCore
import NIOPosix
Expand Down Expand Up @@ -112,59 +113,59 @@ final class PostgresNIOTests: XCTestCase {
XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait())
defer { XCTAssertNoThrow( try conn?.close().wait() ) }

var receivedNotifications: [PostgresMessage.NotificationResponse] = []
let receivedNotifications = ManagedAtomic<Int>(0)
conn?.addListener(channel: "example") { context, notification in
receivedNotifications.append(notification)
receivedNotifications.wrappingIncrement(ordering: .relaxed)
XCTAssertEqual(notification.channel, "example")
XCTAssertEqual(notification.payload, "")
}
XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait())
XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait())
// Notifications are asynchronous, so we should run at least one more query to make sure we'll have received the notification response by then
XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait())
XCTAssertEqual(receivedNotifications.count, 1)
XCTAssertEqual(receivedNotifications.first?.channel, "example")
XCTAssertEqual(receivedNotifications.first?.payload, "")
XCTAssertEqual(receivedNotifications.load(ordering: .relaxed), 1)
}

func testNotificationsNonEmptyPayload() {
var conn: PostgresConnection?
XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait())
defer { XCTAssertNoThrow( try conn?.close().wait() ) }
var receivedNotifications: [PostgresMessage.NotificationResponse] = []
let receivedNotifications = ManagedAtomic<Int>(0)
conn?.addListener(channel: "example") { context, notification in
receivedNotifications.append(notification)
receivedNotifications.wrappingIncrement(ordering: .relaxed)
XCTAssertEqual(notification.channel, "example")
XCTAssertEqual(notification.payload, "Notification payload example")
}
XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait())
XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example, 'Notification payload example'").wait())
// Notifications are asynchronous, so we should run at least one more query to make sure we'll have received the notification response by then
XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait())
XCTAssertEqual(receivedNotifications.count, 1)
XCTAssertEqual(receivedNotifications.first?.channel, "example")
XCTAssertEqual(receivedNotifications.first?.payload, "Notification payload example")
XCTAssertEqual(receivedNotifications.load(ordering: .relaxed), 1)
}

func testNotificationsRemoveHandlerWithinHandler() {
var conn: PostgresConnection?
XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait())
defer { XCTAssertNoThrow( try conn?.close().wait() ) }
var receivedNotifications = 0
let receivedNotifications = ManagedAtomic<Int>(0)
conn?.addListener(channel: "example") { context, notification in
receivedNotifications += 1
receivedNotifications.wrappingIncrement(ordering: .relaxed)
context.stop()
}
XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait())
XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait())
XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait())
XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait())
XCTAssertEqual(receivedNotifications, 1)
XCTAssertEqual(receivedNotifications.load(ordering: .relaxed), 1)
}

func testNotificationsRemoveHandlerOutsideHandler() {
var conn: PostgresConnection?
XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait())
defer { XCTAssertNoThrow( try conn?.close().wait() ) }
var receivedNotifications = 0
let receivedNotifications = ManagedAtomic<Int>(0)
let context = conn?.addListener(channel: "example") { context, notification in
receivedNotifications += 1
receivedNotifications.wrappingIncrement(ordering: .relaxed)
}
XCTAssertNotNil(context)
XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait())
Expand All @@ -173,47 +174,47 @@ final class PostgresNIOTests: XCTestCase {
context?.stop()
XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait())
XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait())
XCTAssertEqual(receivedNotifications, 1)
XCTAssertEqual(receivedNotifications.load(ordering: .relaxed), 1)
}

func testNotificationsMultipleRegisteredHandlers() {
var conn: PostgresConnection?
XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait())
defer { XCTAssertNoThrow( try conn?.close().wait() ) }
var receivedNotifications1 = 0
let receivedNotifications1 = ManagedAtomic<Int>(0)
conn?.addListener(channel: "example") { context, notification in
receivedNotifications1 += 1
receivedNotifications1.wrappingIncrement(ordering: .relaxed)
}
var receivedNotifications2 = 0
let receivedNotifications2 = ManagedAtomic<Int>(0)
conn?.addListener(channel: "example") { context, notification in
receivedNotifications2 += 1
receivedNotifications2.wrappingIncrement(ordering: .relaxed)
}
XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait())
XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait())
XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait())
XCTAssertEqual(receivedNotifications1, 1)
XCTAssertEqual(receivedNotifications2, 1)
XCTAssertEqual(receivedNotifications1.load(ordering: .relaxed), 1)
XCTAssertEqual(receivedNotifications2.load(ordering: .relaxed), 1)
}

func testNotificationsMultipleRegisteredHandlersRemoval() throws {
var conn: PostgresConnection?
XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait())
defer { XCTAssertNoThrow( try conn?.close().wait() ) }
var receivedNotifications1 = 0
let receivedNotifications1 = ManagedAtomic<Int>(0)
XCTAssertNotNil(conn?.addListener(channel: "example") { context, notification in
receivedNotifications1 += 1
receivedNotifications1.wrappingIncrement(ordering: .relaxed)
context.stop()
})
var receivedNotifications2 = 0
let receivedNotifications2 = ManagedAtomic<Int>(0)
XCTAssertNotNil(conn?.addListener(channel: "example") { context, notification in
receivedNotifications2 += 1
receivedNotifications2.wrappingIncrement(ordering: .relaxed)
})
XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait())
XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait())
XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait())
XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait())
XCTAssertEqual(receivedNotifications1, 1)
XCTAssertEqual(receivedNotifications2, 2)
XCTAssertEqual(receivedNotifications1.load(ordering: .relaxed), 1)
XCTAssertEqual(receivedNotifications2.load(ordering: .relaxed), 2)
}

func testNotificationHandlerFiltersOnChannel() {
Expand Down Expand Up @@ -1283,11 +1284,11 @@ final class PostgresNIOTests: XCTestCase {
XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait())
defer { XCTAssertNoThrow( try conn?.close().wait() ) }
var queries: [[PostgresRow]]?
XCTAssertNoThrow(queries = try conn?.prepare(query: "SELECT $1::text as foo;", handler: { query in
XCTAssertNoThrow(queries = try conn?.prepare(query: "SELECT $1::text as foo;", handler: { [eventLoop] query in
let a = query.execute(["a"])
let b = query.execute(["b"])
let c = query.execute(["c"])
return EventLoopFuture.whenAllSucceed([a, b, c], on: self.eventLoop)
return EventLoopFuture.whenAllSucceed([a, b, c], on: eventLoop)
}).wait())
XCTAssertEqual(queries?.count, 3)
var resultIterator = queries?.makeIterator()
Expand Down
19 changes: 11 additions & 8 deletions Tests/PostgresNIOTests/New/PostgresConnectionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ class PostgresConnectionTests: XCTestCase {
func testSimpleListenConnectionDrops() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await withThrowingTaskGroup(of: Void.self) { taskGroup in
try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup in
taskGroup.addTask {
let events = try await connection.listen("foo")
var iterator = events.makeAsyncIterator()
Expand All @@ -197,7 +197,7 @@ class PostgresConnectionTests: XCTestCase {
_ = try await iterator.next()
XCTFail("Did not expect to not throw")
} catch {
self.logger.error("error", metadata: ["error": "\(error)"])
logger.error("error", metadata: ["error": "\(error)"])
}
}

Expand Down Expand Up @@ -226,10 +226,10 @@ class PostgresConnectionTests: XCTestCase {

func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in
try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in
for _ in 1...2 {
taskGroup.addTask {
let rows = try await connection.query("SELECT 1;", logger: self.logger)
let rows = try await connection.query("SELECT 1;", logger: logger)
var iterator = rows.decode(Int.self).makeAsyncIterator()
let first = try await iterator.next()
XCTAssertEqual(first, 1)
Expand Down Expand Up @@ -286,10 +286,10 @@ class PostgresConnectionTests: XCTestCase {
func testCloseClosesImmediatly() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in
try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in
for _ in 1...2 {
taskGroup.addTask {
try await connection.query("SELECT 1;", logger: self.logger)
try await connection.query("SELECT 1;", logger: logger)
}
}

Expand Down Expand Up @@ -319,8 +319,9 @@ class PostgresConnectionTests: XCTestCase {

func testIfServerJustClosesTheErrorReflectsThat() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
let logger = self.logger

async let response = try await connection.query("SELECT 1;", logger: self.logger)
async let response = try await connection.query("SELECT 1;", logger: logger)

let listenMessage = try await channel.waitForUnpreparedRequest()
XCTAssertEqual(listenMessage.parse.query, "SELECT 1;")
Expand Down Expand Up @@ -423,6 +424,7 @@ class PostgresConnectionTests: XCTestCase {
case pleaseDontCrash
}
channel.pipeline.fireUserInboundEventTriggered(MyEvent.pleaseDontCrash)
try await connection.close()
}

func testSerialExecutionOfSamePreparedStatement() async throws {
Expand Down Expand Up @@ -651,7 +653,8 @@ class PostgresConnectionTests: XCTestCase {
database: "database"
)

async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: self.logger)
let logger = self.logger
async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: logger)
let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self)
XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: [], replication: .false))))
try await channel.writeInbound(PostgresBackendMessage.authentication(.ok))
Expand Down

0 comments on commit 7b621c1

Please sign in to comment.