diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index fb50a3fa..72900e0c 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -32,13 +32,17 @@ final class PSQLRowStream: @unchecked Sendable { case empty case failure(Error) } - + + private enum Consumed { + case tag(String) + case emptyResponse + } + private enum DownstreamState { case waitingForConsumer(BufferState) case iteratingRows(onRow: (PostgresRow) throws -> (), EventLoopPromise, PSQLRowsDataSource) case waitingForAll([PostgresRow], EventLoopPromise<[PostgresRow]>, PSQLRowsDataSource) - case consumed(Result) - case finished + case consumed(Result) case asyncSequence(AsyncSequenceSource, PSQLRowsDataSource, onFinish: @Sendable () -> ()) } @@ -108,13 +112,13 @@ final class PSQLRowStream: @unchecked Sendable { case .empty: source.finish() onFinish() - self.downstreamState = .finished + self.downstreamState = .consumed(.success(.emptyResponse)) case .finished(let buffer, let commandTag): _ = source.yield(contentsOf: buffer) source.finish() onFinish() - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(.tag(commandTag))) case .failure(let error): source.finish(error) @@ -139,7 +143,7 @@ final class PSQLRowStream: @unchecked Sendable { case .waitingForConsumer, .iteratingRows, .waitingForAll: preconditionFailure("Invalid state: \(self.downstreamState)") - case .consumed, .finished: + case .consumed: break case .asyncSequence(_, let dataSource, _): @@ -164,7 +168,7 @@ final class PSQLRowStream: @unchecked Sendable { dataSource.cancel(for: self) onFinish() - case .consumed, .finished: + case .consumed: return case .waitingForConsumer, .iteratingRows, .waitingForAll: @@ -207,7 +211,7 @@ final class PSQLRowStream: @unchecked Sendable { PostgresRow(data: $0, lookupTable: self.lookupTable, columns: self.rowDescription) } - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(.tag(commandTag))) return self.eventLoop.makeSucceededFuture(rows) case .failure(let error): @@ -215,7 +219,7 @@ final class PSQLRowStream: @unchecked Sendable { return self.eventLoop.makeFailedFuture(error) case .empty: - self.downstreamState = .finished + self.downstreamState = .consumed(.success(.emptyResponse)) return self.eventLoop.makeSucceededFuture([]) } } @@ -265,7 +269,7 @@ final class PSQLRowStream: @unchecked Sendable { return promise.futureResult case .empty: - self.downstreamState = .finished + self.downstreamState = .consumed(.success(.emptyResponse)) return self.eventLoop.makeSucceededVoidFuture() case .finished(let buffer, let commandTag): @@ -279,7 +283,7 @@ final class PSQLRowStream: @unchecked Sendable { try onRow(row) } - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(.tag(commandTag))) return self.eventLoop.makeSucceededVoidFuture() } catch { self.downstreamState = .consumed(.failure(error)) @@ -350,9 +354,6 @@ final class PSQLRowStream: @unchecked Sendable { case .consumed(.failure): break - - case .finished: - preconditionFailure("How can we receive further rows, if we are supposed to be done") } } @@ -376,22 +377,22 @@ final class PSQLRowStream: @unchecked Sendable { preconditionFailure("How can we get another end, if an end was already signalled?") case .iteratingRows(_, let promise, _): - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(.tag(commandTag))) promise.succeed(()) case .waitingForAll(let rows, let promise, _): - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(.tag(commandTag))) promise.succeed(rows) case .asyncSequence(let source, _, let onFinish): - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(.tag(commandTag))) source.finish() onFinish() - case .consumed: + case .consumed(.success(.tag)), .consumed(.failure): break - case .finished, .waitingForConsumer(.empty): + case .consumed(.success(.emptyResponse)), .waitingForConsumer(.empty): preconditionFailure("How can we get an end for empty query response?") } } @@ -417,10 +418,10 @@ final class PSQLRowStream: @unchecked Sendable { consumer.finish(error) onFinish() - case .consumed: + case .consumed(.success(.tag)), .consumed(.failure): break - case .finished: + case .consumed(.success(.emptyResponse)): preconditionFailure("How can we get an error for empty query response?") } } @@ -442,13 +443,14 @@ final class PSQLRowStream: @unchecked Sendable { } var commandTag: String { - switch self.downstreamState { - case .consumed(.success(let commandTag)): - return commandTag - case .finished: + guard case .consumed(.success(let consumed)) = self.downstreamState else { + preconditionFailure("commandTag may only be called if all rows have been consumed") + } + switch consumed { + case .tag(let tag): + return tag + case .emptyResponse: return "" - default: - preconditionFailure("commandTag may only be called if there are no more rows to be consumed") } } }