diff --git a/NOTICES.txt b/NOTICES.txt index 52e13d63..3a9cc733 100644 --- a/NOTICES.txt +++ b/NOTICES.txt @@ -17,3 +17,12 @@ This product contains a derivation of `NIOSSLTestHelpers.swift` from SwiftNIO SS * https://www.apache.org/licenses/LICENSE-2.0 * HOMEPAGE: * https://github.com/apple/swift-nio-ssl + +--- + +This product contains derivations of "HTTPProxySimulator" and "HTTPBin" test utils from AsyncHTTPClient. + + * LICENSE (Apache License 2.0): + * https://www.apache.org/licenses/LICENSE-2.0 + * HOMEPAGE: + * https://github.com/swift-server/async-http-client diff --git a/Package.swift b/Package.swift index 74dcbd6b..1fd56471 100644 --- a/Package.swift +++ b/Package.swift @@ -13,6 +13,7 @@ let package = Package( ], dependencies: [ .package(url: "https://github.com/apple/swift-nio.git", from: "2.33.0"), + .package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.16.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.14.0"), .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.11.4"), .package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.2"), @@ -22,6 +23,7 @@ let package = Package( .product(name: "NIO", package: "swift-nio"), .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), + .product(name: "NIOExtras", package: "swift-nio-extras"), .product(name: "NIOFoundationCompat", package: "swift-nio"), .product(name: "NIOHTTP1", package: "swift-nio"), .product(name: "NIOSSL", package: "swift-nio-ssl"), diff --git a/Sources/WebSocketKit/HTTPInitialRequestHandler.swift b/Sources/WebSocketKit/HTTPUpgradeRequestHandler.swift similarity index 60% rename from Sources/WebSocketKit/HTTPInitialRequestHandler.swift rename to Sources/WebSocketKit/HTTPUpgradeRequestHandler.swift index e74cbf2e..84af52d1 100644 --- a/Sources/WebSocketKit/HTTPInitialRequestHandler.swift +++ b/Sources/WebSocketKit/HTTPUpgradeRequestHandler.swift @@ -1,7 +1,7 @@ import NIO import NIOHTTP1 -final class HTTPInitialRequestHandler: ChannelInboundHandler, RemovableChannelHandler { +final class HTTPUpgradeRequestHandler: ChannelInboundHandler, RemovableChannelHandler { typealias InboundIn = HTTPClientResponsePart typealias OutboundOut = HTTPClientRequestPart @@ -11,6 +11,8 @@ final class HTTPInitialRequestHandler: ChannelInboundHandler, RemovableChannelHa let headers: HTTPHeaders let upgradePromise: EventLoopPromise + private var requestSent = false + init(host: String, path: String, query: String?, headers: HTTPHeaders, upgradePromise: EventLoopPromise) { self.host = host self.path = path @@ -20,10 +22,33 @@ final class HTTPInitialRequestHandler: ChannelInboundHandler, RemovableChannelHa } func channelActive(context: ChannelHandlerContext) { + self.sendRequest(context: context) + context.fireChannelActive() + } + + func handlerAdded(context: ChannelHandlerContext) { + if context.channel.isActive { + self.sendRequest(context: context) + } + } + + private func sendRequest(context: ChannelHandlerContext) { + if self.requestSent { + // we might run into this handler twice, once in handlerAdded and once in channelActive. + return + } + self.requestSent = true + var headers = self.headers headers.add(name: "Host", value: self.host) - var uri = self.path.hasPrefix("/") ? self.path : "/" + self.path + var uri: String + if self.path.hasPrefix("/") || self.path.hasPrefix("ws://") || self.path.hasPrefix("wss://") { + uri = self.path + } else { + uri = "/" + self.path + } + if let query = self.query { uri += "?\(query)" } @@ -43,10 +68,13 @@ final class HTTPInitialRequestHandler: ChannelInboundHandler, RemovableChannelHa } func channelRead(context: ChannelHandlerContext, data: NIOAny) { + // `NIOHTTPClientUpgradeHandler` should consume the first response in the success case, + // any response we see here indicates a failure. Report the failure and tidy up at the end of the response. let clientResponse = self.unwrapInboundIn(data) switch clientResponse { case .head(let responseHead): - self.upgradePromise.fail(WebSocketClient.Error.invalidResponseStatus(responseHead)) + let error = WebSocketClient.Error.invalidResponseStatus(responseHead) + self.upgradePromise.fail(error) case .body: break case .end: context.close(promise: nil) diff --git a/Sources/WebSocketKit/WebSocket+Connect.swift b/Sources/WebSocketKit/WebSocket+Connect.swift index b996c798..643fcb6c 100644 --- a/Sources/WebSocketKit/WebSocket+Connect.swift +++ b/Sources/WebSocketKit/WebSocket+Connect.swift @@ -3,6 +3,15 @@ import NIOHTTP1 import Foundation extension WebSocket { + /// Establish a WebSocket connection. + /// + /// - Parameters: + /// - url: URL for the WebSocket server. + /// - headers: Headers to send to the WebSocket server. + /// - configuration: Configuration for the WebSocket client. + /// - eventLoopGroup: Event loop group to be used by the WebSocket client. + /// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`. + /// - Returns: An future which completes when the connection to the WebSocket server is established. public static func connect( to url: String, headers: HTTPHeaders = [:], @@ -22,6 +31,15 @@ extension WebSocket { ) } + /// Establish a WebSocket connection. + /// + /// - Parameters: + /// - url: URL for the WebSocket server. + /// - headers: Headers to send to the WebSocket server. + /// - configuration: Configuration for the WebSocket client. + /// - eventLoopGroup: Event loop group to be used by the WebSocket client. + /// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`. + /// - Returns: An future which completes when the connection to the WebSocket server is established. public static func connect( to url: URL, headers: HTTPHeaders = [:], @@ -43,6 +61,19 @@ extension WebSocket { ) } + /// Establish a WebSocket connection. + /// + /// - Parameters: + /// - scheme: Scheme component of the URI for the WebSocket server. + /// - host: Host component of the URI for the WebSocket server. + /// - port: Port on which to connect to the WebSocket server. + /// - path: Path component of the URI for the WebSocket server. + /// - query: Query component of the URI for the WebSocket server. + /// - headers: Headers to send to the WebSocket server. + /// - configuration: Configuration for the WebSocket client. + /// - eventLoopGroup: Event loop group to be used by the WebSocket client. + /// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`. + /// - Returns: An future which completes when the connection to the WebSocket server is established. public static func connect( scheme: String = "ws", host: String, @@ -67,4 +98,98 @@ extension WebSocket { onUpgrade: onUpgrade ) } + + /// Establish a WebSocket connection via a proxy server. + /// + /// - Parameters: + /// - scheme: Scheme component of the URI for the origin server. + /// - host: Host component of the URI for the origin server. + /// - port: Port on which to connect to the origin server. + /// - path: Path component of the URI for the origin server. + /// - query: Query component of the URI for the origin server. + /// - headers: Headers to send to the origin server. + /// - proxy: Host component of the URI for the proxy server. + /// - proxyPort: Port on which to connect to the proxy server. + /// - proxyHeaders: Headers to send to the proxy server. + /// - proxyConnectDeadline: Deadline for establishing the proxy connection. + /// - configuration: Configuration for the WebSocket client. + /// - eventLoopGroup: Event loop group to be used by the WebSocket client. + /// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`. + /// - Returns: An future which completes when the connection to the origin server is established. + public static func connect( + scheme: String = "ws", + host: String, + port: Int = 80, + path: String = "/", + query: String? = nil, + headers: HTTPHeaders = [:], + proxy: String?, + proxyPort: Int? = nil, + proxyHeaders: HTTPHeaders = [:], + proxyConnectDeadline: NIODeadline = NIODeadline.distantFuture, + configuration: WebSocketClient.Configuration = .init(), + on eventLoopGroup: EventLoopGroup, + onUpgrade: @escaping (WebSocket) -> () + ) -> EventLoopFuture { + return WebSocketClient( + eventLoopGroupProvider: .shared(eventLoopGroup), + configuration: configuration + ).connect( + scheme: scheme, + host: host, + port: port, + path: path, + query: query, + headers: headers, + proxy: proxy, + proxyPort: proxyPort, + proxyHeaders: proxyHeaders, + proxyConnectDeadline: proxyConnectDeadline, + onUpgrade: onUpgrade + ) + } + + + /// Description + /// - Parameters: + /// - url: URL for the origin server. + /// - headers: Headers to send to the origin server. + /// - proxy: Host component of the URI for the proxy server. + /// - proxyPort: Port on which to connect to the proxy server. + /// - proxyHeaders: Headers to send to the proxy server. + /// - proxyConnectDeadline: Deadline for establishing the proxy connection. + /// - configuration: Configuration for the WebSocket client. + /// - eventLoopGroup: Event loop group to be used by the WebSocket client. + /// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`. + /// - Returns: An future which completes when the connection to the origin server is established. + public static func connect( + to url: String, + headers: HTTPHeaders = [:], + proxy: String?, + proxyPort: Int? = nil, + proxyHeaders: HTTPHeaders = [:], + proxyConnectDeadline: NIODeadline = NIODeadline.distantFuture, + configuration: WebSocketClient.Configuration = .init(), + on eventLoopGroup: EventLoopGroup, + onUpgrade: @escaping (WebSocket) -> () + ) -> EventLoopFuture { + guard let url = URL(string: url) else { + return eventLoopGroup.next().makeFailedFuture(WebSocketClient.Error.invalidURL) + } + let scheme = url.scheme ?? "ws" + return self.connect( + scheme: scheme, + host: url.host ?? "localhost", + port: url.port ?? (scheme == "wss" ? 443 : 80), + path: url.path, + query: url.query, + headers: headers, + proxy: proxy, + proxyPort: proxyPort, + proxyHeaders: proxyHeaders, + proxyConnectDeadline: proxyConnectDeadline, + on: eventLoopGroup, + onUpgrade: onUpgrade + ) + } } diff --git a/Sources/WebSocketKit/WebSocket.swift b/Sources/WebSocketKit/WebSocket.swift index 26c7f6d6..ce915c6d 100644 --- a/Sources/WebSocketKit/WebSocket.swift +++ b/Sources/WebSocketKit/WebSocket.swift @@ -24,7 +24,9 @@ public final class WebSocket { self.channel.closeFuture } - private let channel: Channel + @usableFromInline + /* private but @usableFromInline */ + internal let channel: Channel private var onTextCallback: (WebSocket, String) -> () private var onBinaryCallback: (WebSocket, ByteBuffer) -> () private var onPongCallback: (WebSocket) -> () @@ -64,10 +66,10 @@ public final class WebSocket { } /// If set, this will trigger automatic pings on the connection. If ping is not answered before - /// the next ping is sent, then the WebSocket will be presumed innactive and will be closed + /// the next ping is sent, then the WebSocket will be presumed inactive and will be closed /// automatically. /// These pings can also be used to keep the WebSocket alive if there is some other timeout - /// mechanism shutting down innactive connections, such as a Load Balancer deployed in + /// mechanism shutting down inactive connections, such as a Load Balancer deployed in /// front of the server. public var pingInterval: TimeAmount? { didSet { @@ -82,13 +84,13 @@ public final class WebSocket { } } + @inlinable public func send(_ text: S, promise: EventLoopPromise? = nil) where S: Collection, S.Element == Character { let string = String(text) - var buffer = channel.allocator.buffer(capacity: text.count) - buffer.writeString(string) - self.send(raw: buffer.readableBytesView, opcode: .text, fin: true, promise: promise) + let buffer = channel.allocator.buffer(string: string) + self.send(buffer, opcode: .text, fin: true, promise: promise) } @@ -105,6 +107,7 @@ public final class WebSocket { ) } + @inlinable public func send( raw data: Data, opcode: WebSocketOpcode, @@ -113,13 +116,32 @@ public final class WebSocket { ) where Data: DataProtocol { - var buffer = channel.allocator.buffer(capacity: data.count) - buffer.writeBytes(data) + if let byteBufferView = data as? ByteBufferView { + // optimisation: converting from `ByteBufferView` to `ByteBuffer` doesn't allocate or copy any data + send(ByteBuffer(byteBufferView), opcode: opcode, fin: fin, promise: promise) + } else { + let buffer = channel.allocator.buffer(bytes: data) + send(buffer, opcode: opcode, fin: fin, promise: promise) + } + } + + /// Send the provided data in a WebSocket frame. + /// - Parameters: + /// - data: Data to be sent. + /// - opcode: Frame opcode. + /// - fin: The value of the fin bit. + /// - promise: A promise to be completed when the write is complete. + public func send( + _ data: ByteBuffer, + opcode: WebSocketOpcode = .binary, + fin: Bool = true, + promise: EventLoopPromise? = nil + ) { let frame = WebSocketFrame( fin: fin, opcode: opcode, maskKey: self.makeMaskKey(), - data: buffer + data: data ) self.channel.writeAndFlush(frame, promise: promise) } @@ -164,11 +186,7 @@ public final class WebSocket { func makeMaskKey() -> WebSocketMaskingKey? { switch type { case .client: - var bytes: [UInt8] = [] - for _ in 0..<4 { - bytes.append(.random(in: .min ..< .max)) - } - return WebSocketMaskingKey(bytes) + return WebSocketMaskingKey.random() case .server: return nil } @@ -237,14 +255,8 @@ public final class WebSocket { frameSequence.append(frame) self.frameSequence = frameSequence case .continuation: - // we must have an existing sequence - if var frameSequence = self.frameSequence { - // append this frame and update - frameSequence.append(frame) - self.frameSequence = frameSequence - } else { - self.close(code: .protocolError, promise: nil) - } + /// continuations are filtered by ``NIOWebSocketFrameAggregator`` + preconditionFailure("We will never receive a continuation frame") default: // We ignore all other frames. break diff --git a/Sources/WebSocketKit/WebSocketClient.swift b/Sources/WebSocketKit/WebSocketClient.swift index 2f13cfc7..b5cd072d 100644 --- a/Sources/WebSocketKit/WebSocketClient.swift +++ b/Sources/WebSocketKit/WebSocketClient.swift @@ -1,6 +1,7 @@ import Foundation import NIO import NIOConcurrencyHelpers +import NIOExtras import NIOHTTP1 import NIOWebSocket import NIOSSL @@ -26,12 +27,25 @@ public final class WebSocketClient { public var tlsConfiguration: TLSConfiguration? public var maxFrameSize: Int + /// Defends against small payloads in frame aggregation. + /// See `NIOWebSocketFrameAggregator` for details. + public var minNonFinalFragmentSize: Int + /// Max number of fragments in an aggregated frame. + /// See `NIOWebSocketFrameAggregator` for details. + public var maxAccumulatedFrameCount: Int + /// Maximum frame size after aggregation. + /// See `NIOWebSocketFrameAggregator` for details. + public var maxAccumulatedFrameSize: Int + public init( tlsConfiguration: TLSConfiguration? = nil, maxFrameSize: Int = 1 << 14 ) { self.tlsConfiguration = tlsConfiguration self.maxFrameSize = maxFrameSize + self.minNonFinalFragmentSize = 0 + self.maxAccumulatedFrameCount = Int.max + self.maxAccumulatedFrameSize = Int.max } } @@ -59,30 +73,71 @@ public final class WebSocketClient { query: String? = nil, headers: HTTPHeaders = [:], onUpgrade: @escaping (WebSocket) -> () + ) -> EventLoopFuture { + self.connect(scheme: scheme, host: host, port: port, path: path, query: query, headers: headers, proxy: nil, onUpgrade: onUpgrade) + } + + /// Establish a WebSocket connection via a proxy server. + /// + /// - Parameters: + /// - scheme: Scheme component of the URI for the origin server. + /// - host: Host component of the URI for the origin server. + /// - port: Port on which to connect to the origin server. + /// - path: Path component of the URI for the origin server. + /// - query: Query component of the URI for the origin server. + /// - headers: Headers to send to the origin server. + /// - proxy: Host component of the URI for the proxy server. + /// - proxyPort: Port on which to connect to the proxy server. + /// - proxyHeaders: Headers to send to the proxy server. + /// - proxyConnectDeadline: Deadline for establishing the proxy connection. + /// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`. + /// - Returns: An future which completes when the connection to the origin server is established. + public func connect( + scheme: String, + host: String, + port: Int, + path: String = "/", + query: String? = nil, + headers: HTTPHeaders = [:], + proxy: String?, + proxyPort: Int? = nil, + proxyHeaders: HTTPHeaders = [:], + proxyConnectDeadline: NIODeadline = NIODeadline.distantFuture, + onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { assert(["ws", "wss"].contains(scheme)) let upgradePromise = self.group.next().makePromise(of: Void.self) let bootstrap = WebSocketClient.makeBootstrap(on: self.group) .channelOption(ChannelOptions.socket(SocketOptionLevel(IPPROTO_TCP), TCP_NODELAY), value: 1) - .channelInitializer { channel in - let httpHandler = HTTPInitialRequestHandler( + .channelInitializer { channel -> EventLoopFuture in + + let uri: String + var upgradeRequestHeaders = headers + if proxy == nil { + uri = path + } else { + let relativePath = path.hasPrefix("/") ? path : "/" + path + let port = proxyPort.map { ":\($0)" } ?? "" + uri = "\(scheme)://\(host)\(relativePath)\(port)" + + if scheme == "ws" { + upgradeRequestHeaders.add(contentsOf: proxyHeaders) + } + } + + let httpUpgradeRequestHandler = HTTPUpgradeRequestHandler( host: host, - path: path, + path: uri, query: query, - headers: headers, + headers: upgradeRequestHeaders, upgradePromise: upgradePromise ) - var key: [UInt8] = [] - for _ in 0..<16 { - key.append(.random(in: .min ..< .max)) - } let websocketUpgrader = NIOWebSocketClientUpgrader( - requestKey: Data(key).base64EncodedString(), maxFrameSize: self.configuration.maxFrameSize, automaticErrorHandling: true, upgradePipelineHandler: { channel, req in - return WebSocket.client(on: channel, onUpgrade: onUpgrade) + return WebSocket.client(on: channel, config: .init(clientConfig: self.configuration), onUpgrade: onUpgrade) } ) @@ -90,46 +145,105 @@ public final class WebSocketClient { upgraders: [websocketUpgrader], completionHandler: { context in upgradePromise.succeed(()) - channel.pipeline.removeHandler(httpHandler, promise: nil) + channel.pipeline.removeHandler(httpUpgradeRequestHandler, promise: nil) } ) - if scheme == "wss" { - do { - let context = try NIOSSLContext( - configuration: self.configuration.tlsConfiguration ?? .makeClientConfiguration() - ) - let tlsHandler: NIOSSLClientHandler + if proxy == nil || scheme == "ws" { + if scheme == "wss" { do { - tlsHandler = try NIOSSLClientHandler(context: context, serverHostname: host) - } catch let error as NIOSSLExtraError where error == .cannotUseIPAddressInSNI { - tlsHandler = try NIOSSLClientHandler(context: context, serverHostname: nil) + let tlsHandler = try self.makeTLSHandler(tlsConfiguration: self.configuration.tlsConfiguration, host: host) + // The sync methods here are safe because we're on the channel event loop + // due to the promise originating on the event loop of the channel. + try channel.pipeline.syncOperations.addHandler(tlsHandler) + } catch { + return channel.pipeline.close(mode: .all) } - return channel.pipeline.addHandler(tlsHandler).flatMap { - channel.pipeline.addHTTPClientHandlers(leftOverBytesStrategy: .forwardBytes, withClientUpgrade: config) - }.flatMap { - channel.pipeline.addHandler(httpHandler) - } - } catch { - return channel.pipeline.close(mode: .all) } - } else { + return channel.pipeline.addHTTPClientHandlers( leftOverBytesStrategy: .forwardBytes, withClientUpgrade: config ).flatMap { - channel.pipeline.addHandler(httpHandler) + channel.pipeline.addHandler(httpUpgradeRequestHandler) + } + } + + // TLS + proxy + // we need to handle connecting with an additional CONNECT request + let proxyEstablishedPromise = channel.eventLoop.makePromise(of: Void.self) + let encoder = HTTPRequestEncoder() + let decoder = ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .dropBytes)) + + var connectHeaders = proxyHeaders + connectHeaders.add(name: "Host", value: host) + + let proxyRequestHandler = NIOHTTP1ProxyConnectHandler( + targetHost: host, + targetPort: port, + headers: connectHeaders, + deadline: proxyConnectDeadline, + promise: proxyEstablishedPromise + ) + + // This code block adds HTTP handlers to allow the proxy request handler to function. + // They are then removed upon completion only to be re-added in `addHTTPClientHandlers`. + // This is done because the HTTP decoder is not valid after an upgrade, the CONNECT request being counted as one. + do { + try channel.pipeline.syncOperations.addHandler(encoder) + try channel.pipeline.syncOperations.addHandler(decoder) + try channel.pipeline.syncOperations.addHandler(proxyRequestHandler) + } catch { + return channel.eventLoop.makeFailedFuture(error) + } + + proxyEstablishedPromise.futureResult.flatMap { + channel.pipeline.removeHandler(decoder) + }.flatMap { + channel.pipeline.removeHandler(encoder) + }.whenComplete { result in + switch result { + case .success: + do { + let tlsHandler = try self.makeTLSHandler(tlsConfiguration: self.configuration.tlsConfiguration, host: host) + // The sync methods here are safe because we're on the channel event loop + // due to the promise originating on the event loop of the channel. + try channel.pipeline.syncOperations.addHandler(tlsHandler) + try channel.pipeline.syncOperations.addHTTPClientHandlers( + leftOverBytesStrategy: .forwardBytes, + withClientUpgrade: config + ) + try channel.pipeline.syncOperations.addHandler(httpUpgradeRequestHandler) + } catch { + channel.pipeline.close(mode: .all, promise: nil) + } + case .failure: + channel.pipeline.close(mode: .all, promise: nil) } } + + return channel.eventLoop.makeSucceededVoidFuture() } - let connect = bootstrap.connect(host: host, port: port) + let connect = bootstrap.connect(host: proxy ?? host, port: proxyPort ?? port) connect.cascadeFailure(to: upgradePromise) return connect.flatMap { channel in return upgradePromise.futureResult } } + private func makeTLSHandler(tlsConfiguration: TLSConfiguration?, host: String) throws -> NIOSSLClientHandler { + let context = try NIOSSLContext( + configuration: self.configuration.tlsConfiguration ?? .makeClientConfiguration() + ) + let tlsHandler: NIOSSLClientHandler + do { + tlsHandler = try NIOSSLClientHandler(context: context, serverHostname: host) + } catch let error as NIOSSLExtraError where error == .cannotUseIPAddressInSNI { + tlsHandler = try NIOSSLClientHandler(context: context, serverHostname: nil) + } + return tlsHandler + } public func syncShutdown() throws { switch self.eventLoopGroupProvider { @@ -153,13 +267,13 @@ public final class WebSocketClient { if let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop) { return tsBootstrap } - #endif + #endif - if let nioBootstrap = ClientBootstrap(validatingGroup: eventLoop) { - return nioBootstrap - } + if let nioBootstrap = ClientBootstrap(validatingGroup: eventLoop) { + return nioBootstrap + } - fatalError("No matching bootstrap found") + fatalError("No matching bootstrap found") } deinit { diff --git a/Sources/WebSocketKit/WebSocketHandler.swift b/Sources/WebSocketKit/WebSocketHandler.swift index b54f9fb7..1af0f76a 100644 --- a/Sources/WebSocketKit/WebSocketHandler.swift +++ b/Sources/WebSocketKit/WebSocketHandler.swift @@ -2,27 +2,100 @@ import NIO import NIOWebSocket extension WebSocket { + + /// Stores configuration for a WebSocket client/server instance + public struct Configuration { + /// Defends against small payloads in frame aggregation. + /// See `NIOWebSocketFrameAggregator` for details. + public var minNonFinalFragmentSize: Int + /// Max number of fragments in an aggregated frame. + /// See `NIOWebSocketFrameAggregator` for details. + public var maxAccumulatedFrameCount: Int + /// Maximum frame size after aggregation. + /// See `NIOWebSocketFrameAggregator` for details. + public var maxAccumulatedFrameSize: Int + + public init() { + self.minNonFinalFragmentSize = 0 + self.maxAccumulatedFrameCount = Int.max + self.maxAccumulatedFrameSize = Int.max + } + + internal init(clientConfig: WebSocketClient.Configuration) { + self.minNonFinalFragmentSize = clientConfig.minNonFinalFragmentSize + self.maxAccumulatedFrameCount = clientConfig.maxAccumulatedFrameCount + self.maxAccumulatedFrameSize = clientConfig.maxAccumulatedFrameSize + } + } + + /// Sets up a channel to operate as a WebSocket client. + /// - Parameters: + /// - channel: NIO channel which the client will use to communicate. + /// - onUpgrade: An escaping closure to be executed the channel is configured with the WebSocket handlers. + /// - Returns: An future which completes when the WebSocket connection to the server is established. + public static func client( + on channel: Channel, + onUpgrade: @escaping (WebSocket) -> () + ) -> EventLoopFuture { + return self.configure(on: channel, as: .client, with: Configuration(), onUpgrade: onUpgrade) + } + + /// Sets up a channel to operate as a WebSocket client. + /// - Parameters: + /// - channel: NIO channel which the client/server will use to communicate. + /// - config: Configuration for the client channel handlers. + /// - onUpgrade: An escaping closure to be executed the channel is configured with the WebSocket handlers. + /// - Returns: An future which completes when the WebSocket connection to the server is established. public static func client( + on channel: Channel, + config: Configuration, + onUpgrade: @escaping (WebSocket) -> () + ) -> EventLoopFuture { + return self.configure(on: channel, as: .client, with: config, onUpgrade: onUpgrade) + } + + /// Sets up a channel to operate as a WebSocket server. + /// - Parameters: + /// - channel: NIO channel which the server will use to communicate. + /// - onUpgrade: An escaping closure to be executed the channel is configured with the WebSocket handlers. + /// - Returns: An future which completes when the WebSocket connection to the server is established. + public static func server( on channel: Channel, onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { - return self.handle(on: channel, as: .client, onUpgrade: onUpgrade) + return self.configure(on: channel, as: .server, with: Configuration(), onUpgrade: onUpgrade) } + /// Sets up a channel to operate as a WebSocket server. + /// - Parameters: + /// - channel: NIO channel which the server will use to communicate. + /// - config: Configuration for the server channel handlers. + /// - onUpgrade: An escaping closure to be executed the channel is configured with the WebSocket handlers. + /// - Returns: An future which completes when the WebSocket connection to the server is established. public static func server( on channel: Channel, + config: Configuration, onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { - return self.handle(on: channel, as: .server, onUpgrade: onUpgrade) + return self.configure(on: channel, as: .server, with: config, onUpgrade: onUpgrade) } - private static func handle( + private static func configure( on channel: Channel, as type: PeerType, + with config: Configuration, onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { let webSocket = WebSocket(channel: channel, type: type) - return channel.pipeline.addHandler(WebSocketHandler(webSocket: webSocket)).map { _ in + + return channel.pipeline.addHandlers([ + NIOWebSocketFrameAggregator( + minNonFinalFragmentSize: config.minNonFinalFragmentSize, + maxAccumulatedFrameCount: config.maxAccumulatedFrameCount, + maxAccumulatedFrameSize: config.maxAccumulatedFrameSize + ), + WebSocketHandler(webSocket: webSocket) + ]).map { _ in onUpgrade(webSocket) } } diff --git a/Tests/WebSocketKitTests/WebSocketKitTests.swift b/Tests/WebSocketKitTests/WebSocketKitTests.swift index 19bef5a1..7cc33a92 100644 --- a/Tests/WebSocketKitTests/WebSocketKitTests.swift +++ b/Tests/WebSocketKitTests/WebSocketKitTests.swift @@ -1,5 +1,7 @@ import XCTest +import Atomics import NIO +import NIOExtras import NIOHTTP1 import NIOSSL import NIOWebSocket @@ -125,7 +127,7 @@ final class WebSocketKitTests: XCTestCase { let pingPromise = self.elg.next().makePromise(of: String.self) let pongPromise = self.elg.next().makePromise(of: String.self) let pingPongData = ByteBuffer(bytes: "Vapor rules".utf8) - + let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in ws.onPing { ws in pingPromise.succeed("ping") @@ -150,6 +152,41 @@ final class WebSocketKitTests: XCTestCase { try server.close(mode: .all).wait() } + func testWebSocketAggregateFrames() throws { + func byteBuffView(_ str: String) -> ByteBufferView { + ByteBuffer(string: str).readableBytesView + } + + let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in + ws.onText { ws, text in + ws.send(text, opcode: .text, fin: false) + ws.send(" th", opcode: .continuation, fin: false) + ws.send("e mo", opcode: .continuation, fin: false) + ws.send("st", opcode: .continuation, fin: true) + } + }.bind(host: "localhost", port: 0).wait() + + guard let port = server.localAddress?.port else { + XCTFail("couldn't get port from \(server.localAddress.debugDescription)") + return + } + + let promise = elg.next().makePromise(of: String.self) + let closePromise = elg.next().makePromise(of: Void.self) + WebSocket.connect(to: "ws://localhost:\(port)", on: elg) { ws in + ws.send("Hel", opcode: .text, fin: false) + ws.send("lo! Vapor r", opcode: .continuation, fin: false) + ws.send("ules", opcode: .continuation, fin: true) + ws.onText { ws, string in + promise.succeed(string) + ws.close(promise: closePromise) + } + }.cascadeFailure(to: promise) + try XCTAssertEqual(promise.futureResult.wait(), "Hello! Vapor rules the most") + XCTAssertNoThrow(try closePromise.futureResult.wait()) + try server.close(mode: .all).wait() + } + func testErrorCode() throws { let promise = self.elg.next().makePromise(of: WebSocketErrorCode.self) @@ -299,6 +336,122 @@ final class WebSocketKitTests: XCTestCase { try server.close(mode: .all).wait() } + func testProxy() throws { + let promise = elg.next().makePromise(of: String.self) + + let localWebsocketBin: WebsocketBin + let verifyProxyHead = { (ctx: ChannelHandlerContext, requestHead: HTTPRequestHead) in + XCTAssertEqual(requestHead.uri, "ws://apple.com/:\(ctx.localAddress!.port!)") + XCTAssertEqual(requestHead.headers.first(name: "Host"), "apple.com") + } + localWebsocketBin = WebsocketBin( + .http1_1(ssl: false), + proxy: .simulate( + config: WebsocketBin.ProxyConfig(tls: false, headVerification: verifyProxyHead), + authorization: "token amFwcGxlc2VlZDpwYXNzMTIz" + ), + sslContext: nil + ) { req, ws in + ws.onText { ws, text in + ws.send(text) + } + } + + defer { + XCTAssertNoThrow(try localWebsocketBin.shutdown()) + } + + let closePromise = elg.next().makePromise(of: Void.self) + + let client = WebSocketClient( + eventLoopGroupProvider: .shared(self.elg), + configuration: .init() + ) + + client.connect( + scheme: "ws", + host: "apple.com", + port: localWebsocketBin.port, + proxy: "localhost", + proxyPort: localWebsocketBin.port, + proxyHeaders: HTTPHeaders([("proxy-authorization", "token amFwcGxlc2VlZDpwYXNzMTIz")]) + ) { ws in + ws.send("hello") + ws.onText { ws, string in + promise.succeed(string) + ws.close(promise: closePromise) + } + }.cascadeFailure(to: promise) + + XCTAssertEqual(try promise.futureResult.wait(), "hello") + XCTAssertNoThrow(try closePromise.futureResult.wait()) + } + + func testProxyTLS() throws { + let promise = elg.next().makePromise(of: String.self) + + let (cert, key) = generateSelfSignedCert() + let configuration = TLSConfiguration.makeServerConfiguration( + certificateChain: [.certificate(cert)], + privateKey: .privateKey(key) + ) + let sslContext = try! NIOSSLContext(configuration: configuration) + + let verifyProxyHead = { (ctx: ChannelHandlerContext, requestHead: HTTPRequestHead) in + // CONNECT uses a special form of request target, unique to this method, consisting of + // only the host and port number of the tunnel destination, separated by a colon. + // https://httpwg.org/specs/rfc9110.html#CONNECT + XCTAssertEqual(requestHead.uri, "apple.com:\(ctx.localAddress!.port!)") + XCTAssertEqual(requestHead.headers.first(name: "Host"), "apple.com") + } + let localWebsocketBin = WebsocketBin( + .http1_1(ssl: true), + proxy: .simulate( + config: WebsocketBin.ProxyConfig(tls: true, headVerification: verifyProxyHead), + authorization: "token amFwcGxlc2VlZDpwYXNzMTIz" + ), + sslContext: sslContext + ) { req, ws in + ws.onText { ws, text in + ws.send(text) + } + } + + defer { + XCTAssertNoThrow(try localWebsocketBin.shutdown()) + } + + let closePromise = elg.next().makePromise(of: Void.self) + var tlsConfiguration = TLSConfiguration.makeClientConfiguration() + tlsConfiguration.certificateVerification = .none + + let client = WebSocketClient( + eventLoopGroupProvider: .shared(self.elg), + configuration: .init( + tlsConfiguration: tlsConfiguration + ) + ) + + client.connect( + scheme: "wss", + host: "apple.com", + port: localWebsocketBin.port, + proxy: "localhost", + proxyPort: localWebsocketBin.port, + proxyHeaders: HTTPHeaders([("proxy-authorization", "token amFwcGxlc2VlZDpwYXNzMTIz")]) + ) { ws in + ws.send("hello") + ws.onText { ws, string in + promise.succeed(string) + ws.close(promise: closePromise) + } + }.cascadeFailure(to: promise) + + XCTAssertEqual(try promise.futureResult.wait(), "hello") + XCTAssertNoThrow(try closePromise.futureResult.wait()) + } + + var elg: EventLoopGroup! override func setUp() { // needs to be at least two to avoid client / server on same EL timing issues @@ -347,3 +500,333 @@ extension ServerBootstrap { } } } + +fileprivate extension WebSocket { + func send( + _ data: String, + opcode: WebSocketOpcode, + fin: Bool = true, + promise: EventLoopPromise? = nil + ) { + self.send(raw: ByteBuffer(string: data).readableBytesView, opcode: opcode, fin: fin, promise: promise) + } +} + + + +internal final class WebsocketBin { + enum BindTarget { + case unixDomainSocket(String) + case localhostIPv4RandomPort + case localhostIPv6RandomPort + } + + enum Mode { + // refuses all connections + case refuse + // supports http1.1 connections only, which can be either plain text or encrypted + case http1_1(ssl: Bool = false) + } + + enum Proxy { + case none + case simulate(config: ProxyConfig, authorization: String?) + } + + struct ProxyConfig { + var tls: Bool + let headVerification: (ChannelHandlerContext, HTTPRequestHead) -> Void + } + + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + + var port: Int { + return Int(self.serverChannel.localAddress!.port!) + } + + private let mode: Mode + private let sslContext: NIOSSLContext? + private var serverChannel: Channel! + private let isShutdown = ManagedAtomic(false) + + init( + _ mode: Mode = .http1_1(ssl: false), + proxy: Proxy = .none, + bindTarget: BindTarget = .localhostIPv4RandomPort, + sslContext: NIOSSLContext?, + onUpgrade: @escaping (HTTPRequestHead, WebSocket) -> () + ) { + self.mode = mode + self.sslContext = sslContext + + let socketAddress: SocketAddress + switch bindTarget { + case .localhostIPv4RandomPort: + socketAddress = try! SocketAddress(ipAddress: "127.0.0.1", port: 0) + case .localhostIPv6RandomPort: + socketAddress = try! SocketAddress(ipAddress: "::1", port: 0) + case .unixDomainSocket(let path): + socketAddress = try! SocketAddress(unixDomainSocketPath: path) + } + + self.serverChannel = try! ServerBootstrap(group: self.group) + .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + .childChannelInitializer { channel in + do { + + if case .refuse = mode { + throw HTTPBinError.refusedConnection + } + + let webSocket = NIOWebSocketServerUpgrader( + shouldUpgrade: { channel, req in + return channel.eventLoop.makeSucceededFuture([:]) + }, + upgradePipelineHandler: { channel, req in + return WebSocket.server(on: channel) { ws in + onUpgrade(req, ws) + } + } + ) + + // if we need to simulate a proxy, we need to add those handlers first + if case .simulate(config: let config, authorization: let expectedAuthorization) = proxy { + if config.tls { + try self.syncAddTLSHTTPProxyHandlers( + to: channel, + proxyConfig: config, + expectedAuthorization: expectedAuthorization, + upgraders: [webSocket] + ) + } else { + try self.syncAddHTTPProxyHandlers( + to: channel, + proxyConfig: config, + expectedAuthorization: expectedAuthorization, + upgraders: [webSocket] + ) + } + return channel.eventLoop.makeSucceededVoidFuture() + } + + // if a connection has been established, we need to negotiate TLS before + // anything else. Depending on the negotiation, the HTTPHandlers will be added. + if let sslContext = self.sslContext { + try channel.pipeline.syncOperations.addHandler(NIOSSLServerHandler(context: sslContext)) + } + + // if neither HTTP Proxy nor TLS are wanted, we can add HTTP1 handlers directly + try channel.pipeline.syncOperations.configureHTTPServerPipeline( + withPipeliningAssistance: true, + withServerUpgrade: ( + upgraders: [webSocket], + completionHandler: { ctx in + // complete + } + ), + withErrorHandling: true + ) + return channel.eventLoop.makeSucceededVoidFuture() + } catch { + return channel.eventLoop.makeFailedFuture(error) + } + }.bind(to: socketAddress).wait() + } + + + // In the TLS case we must set up the 'proxy' and the 'server' handlers sequentially + // rather than re-using parts because the requestDecoder stops parsing after a CONNECT request + private func syncAddTLSHTTPProxyHandlers( + to channel: Channel, + proxyConfig: ProxyConfig, + expectedAuthorization: String?, + upgraders: [HTTPServerProtocolUpgrader] + ) throws { + let sync = channel.pipeline.syncOperations + let promise = channel.eventLoop.makePromise(of: Void.self) + + let responseEncoder = HTTPResponseEncoder() + let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)) + let proxySimulator = HTTPProxySimulator(promise: promise, config: proxyConfig, expectedAuthorization: expectedAuthorization) + + try sync.addHandler(responseEncoder) + try sync.addHandler(requestDecoder) + + try sync.addHandler(proxySimulator) + + promise.futureResult.flatMap { _ in + channel.pipeline.removeHandler(proxySimulator) + }.flatMap { _ in + channel.pipeline.removeHandler(responseEncoder) + }.flatMap { _ in + channel.pipeline.removeHandler(requestDecoder) + }.whenComplete { result in + switch result { + case .failure: + channel.close(mode: .all, promise: nil) + case .success: + self.httpProxyEstablished(channel, upgraders: upgraders) + break + } + } + } + + + // In the plain-text case we must set up the 'proxy' and the 'server' handlers simultaneously + // so that the combined proxy/upgrade request can be processed by the separate proxy and upgrade handlers + private func syncAddHTTPProxyHandlers( + to channel: Channel, + proxyConfig: ProxyConfig, + expectedAuthorization: String?, + upgraders: [HTTPServerProtocolUpgrader] + ) throws { + let sync = channel.pipeline.syncOperations + let promise = channel.eventLoop.makePromise(of: Void.self) + + let responseEncoder = HTTPResponseEncoder() + let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)) + let proxySimulator = HTTPProxySimulator(promise: promise, config: proxyConfig, expectedAuthorization: expectedAuthorization) + + let serverPipelineHandler = HTTPServerPipelineHandler() + let serverProtocolErrorHandler = HTTPServerProtocolErrorHandler() + + let extraHTTPHandlers: [RemovableChannelHandler] = [ + requestDecoder, + serverPipelineHandler, + serverProtocolErrorHandler + ] + + try sync.addHandler(responseEncoder) + try sync.addHandler(requestDecoder) + + try sync.addHandler(proxySimulator) + + try sync.addHandler(serverPipelineHandler) + try sync.addHandler(serverProtocolErrorHandler) + + + let upgrader = HTTPServerUpgradeHandler(upgraders: upgraders, + httpEncoder: responseEncoder, + extraHTTPHandlers: extraHTTPHandlers, + upgradeCompletionHandler: { ctx in + // complete + }) + + + try sync.addHandler(upgrader) + + promise.futureResult.flatMap { () -> EventLoopFuture in + channel.pipeline.removeHandler(proxySimulator) + }.whenComplete { result in + switch result { + case .failure: + channel.close(mode: .all, promise: nil) + case .success: + break + } + } + } + + private func httpProxyEstablished(_ channel: Channel, upgraders: [HTTPServerProtocolUpgrader]) { + do { + // if a connection has been established, we need to negotiate TLS before + // anything else. Depending on the negotiation, the HTTPHandlers will be added. + if let sslContext = self.sslContext { + try channel.pipeline.syncOperations.addHandler(NIOSSLServerHandler(context: sslContext)) + } + + try channel.pipeline.syncOperations.configureHTTPServerPipeline( + withPipeliningAssistance: true, + withServerUpgrade: ( + upgraders: upgraders, + completionHandler: { ctx in + // complete + } + ), + withErrorHandling: true + ) + } catch { + // in case of an while modifying the pipeline we should close the connection + channel.close(mode: .all, promise: nil) + } + } + + func shutdown() throws { + self.isShutdown.store(true, ordering: .relaxed) + try self.group.syncShutdownGracefully() + } +} + +enum HTTPBinError: Error { + case refusedConnection + case invalidProxyRequest +} + +final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler { + typealias InboundIn = HTTPServerRequestPart + typealias InboundOut = HTTPServerResponsePart + typealias OutboundOut = HTTPServerResponsePart + + + // the promise to succeed, once the proxy connection is setup + let promise: EventLoopPromise + let config: WebsocketBin.ProxyConfig + let expectedAuthorization: String? + + var head: HTTPResponseHead + + init(promise: EventLoopPromise, config: WebsocketBin.ProxyConfig, expectedAuthorization: String?) { + self.promise = promise + self.config = config + self.expectedAuthorization = expectedAuthorization + self.head = HTTPResponseHead(version: .init(major: 1, minor: 1), status: .ok, headers: .init([("Content-Length", "0")])) + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let request = self.unwrapInboundIn(data) + switch request { + case .head(let head): + if self.config.tls { + guard head.method == .CONNECT else { + self.head.status = .badRequest + return + } + } else { + guard head.method == .GET else { + self.head.status = .badRequest + return + } + } + + self.config.headVerification(context, head) + + if let expectedAuthorization = self.expectedAuthorization { + guard let authorization = head.headers["proxy-authorization"].first, + expectedAuthorization == authorization else { + self.head.status = .proxyAuthenticationRequired + return + } + } + if !self.config.tls { + context.fireChannelRead(data) + } + + case .body: + () + case .end: + if self.self.config.tls { + context.write(self.wrapOutboundOut(.head(self.head)), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + } + if self.head.status == .ok { + if !self.config.tls { + context.fireChannelRead(data) + } + self.promise.succeed(()) + } else { + self.promise.fail(HTTPBinError.invalidProxyRequest) + } + } + } +} +