diff --git a/src/server/async_once_cell.rs b/src/async_once_cell.rs similarity index 100% rename from src/server/async_once_cell.rs rename to src/async_once_cell.rs diff --git a/src/client.rs b/src/client.rs index 1c586b5..4cf443a 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,72 +1,270 @@ -use std::pin::Pin; -use std::task::{Context, Poll}; - -use anyhow::{Context as _, Result}; -use pin_project_lite::pin_project; -use ra_multiplex::config::Config; -use ra_multiplex::proto; -use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt}; +use std::io::ErrorKind; +use std::sync::Arc; + +use anyhow::{bail, Context, Result}; +use serde_json::Value; +use tokio::io::BufReader; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::TcpStream; +use tokio::sync::mpsc; +use tokio::{select, task}; +use tracing::{debug, error, info, trace, Instrument}; -pin_project! { - struct Stdio { - #[pin] - stdin: io::Stdin, - #[pin] - stdout: io::Stdout, - } +use crate::instance::{ + InitializeCache, InstanceKey, InstanceRegistry, RaInstance, INIT_REQUEST_ID, +}; +use crate::lsp::jsonrpc::{Message, ResponseSuccess, Version}; +use crate::lsp::transport::{LspReader, LspWriter}; +use crate::proto; + +pub struct Client { + port: u16, + initialize_request_id: Option, + instance: Arc, } -fn stdio() -> Stdio { - Stdio { - stdin: io::stdin(), - stdout: io::stdout(), +impl Client { + /// finds or spawns a rust-analyzer instance and connects the client + pub async fn process(socket: TcpStream, port: u16, registry: InstanceRegistry) -> Result<()> { + let (socket_read, socket_write) = socket.into_split(); + let mut socket_read = BufReader::new(socket_read); + + let mut buffer = Vec::new(); + let proto_init = proto::Init::from_reader(&mut buffer, &mut socket_read).await?; + + let key = InstanceKey::from_proto_init(&proto_init).await; + debug!( + path = ?key.workspace_root(), + server = ?key.server(), + args = ?key.args(), + "client configured", + ); + + let mut client = Client { + port, + initialize_request_id: None, + instance: registry.get(&key).await?, + }; + + client.wait_for_initialize_request(&mut socket_read).await?; + + let (client_tx, client_rx) = client.register_client_with_instance().await; + let (close_tx, close_rx) = mpsc::channel(1); + client.spawn_input_task(client_rx, close_rx, socket_write); + client.spawn_output_task(socket_read, close_tx); + + client.wait_for_initialize_response(client_tx).await?; + Ok(()) } -} -impl AsyncRead for Stdio { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context, - buf: &mut io::ReadBuf, - ) -> Poll> { - self.project().stdin.poll_read(cx, buf) + async fn wait_for_initialize_request( + &mut self, + socket_read: &mut BufReader, + ) -> Result<()> { + let mut reader = LspReader::new(socket_read); + + let message = reader.read_message().await?.context("channel closed")?; + trace!(?message, "<- client"); + + let mut req = match message { + Message::Request(req) if req.method == "initialize" => { + debug!(message = ?Message::from(req.clone()), "recv InitializeRequest"); + req + } + _ => bail!("first client message was not InitializeRequest"), + }; + + // this is an initialize request, it's special because it cannot be sent twice or + // rust-analyzer will crash. + + // we save the request id so we can later use it for the response + self.initialize_request_id = Some(req.id.clone()); + + if self.instance.init_cache.attempt_send_request() { + // it haven't been sent yet, we can send it. 
+ // + // instead of tagging the original id we replace it with a custom id that only + // the `initialize` uses + req.id = Value::String(INIT_REQUEST_ID.to_owned()); + + self.instance + .message_writer + .send(req.into()) + .await + .context("forward client request")?; + } else { + // initialize request was already sent for this instance, no need to send it again + } + Ok(()) } -} -impl AsyncWrite for Stdio { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { - self.project().stdout.poll_write(cx, buf) + async fn wait_for_initialize_response(&self, tx: mpsc::Sender) -> Result<()> { + // parse the cached message and restore the `id` to the value this client expects + let mut res = self.instance.init_cache.response.get().await; + res.id = self + .initialize_request_id + .clone() + .expect("BUG: need to wait_for_initialize_request first"); + debug!(message = ?Message::from(res.clone()), "send response to InitializeRequest"); + tx.send(res.into()) + .await + .context("send initialize response")?; + Ok(()) } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.project().stdout.poll_flush(cx) + async fn register_client_with_instance( + &self, + ) -> (mpsc::Sender, mpsc::Receiver) { + let (client_tx, client_rx) = mpsc::channel(64); + self.instance + .message_readers + .write() + .await + .insert(self.port, client_tx.clone()); + (client_tx, client_rx) } - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.project().stdout.poll_shutdown(cx) + fn spawn_input_task( + &self, + mut rx: mpsc::Receiver, + mut close_rx: mpsc::Receiver, + socket_write: OwnedWriteHalf, + ) { + let mut writer = LspWriter::new(socket_write); + task::spawn( + async move { + // unlike the output task, here we first wait on the channel which is going to + // block until the rust-analyzer server sends a notification, however if we're the last + // client and have just closed the server is unlikely to send any. this results in the + // last client often falsely hanging while the gc task depends on the input channels being + // closed to detect a disconnected client. + // + // when a client sends a shutdown request we receive a message on the close_rx, send + // the reply and close the connection. if no shutdown request was received but the + // client closed close_rx channel will be dropped (unlike the normal rx channel which + // is shared) and the connection will close without sending any response. + while let Some(message) = select! 
{ + message = close_rx.recv() => message, + message = rx.recv() => message, + } { + trace!(?message, "-> client"); + if let Err(err) = writer.write_message(&message).await { + match err.kind() { + // ignore benign errors, treat as socket close + ErrorKind::BrokenPipe => {} + // report fatal errors + _ => error!(?err, "error writing client input: {err}"), + } + break; // break on any error + } + } + debug!("client input closed"); + info!("client disconnected"); + } + .in_current_span(), + ); + } + + fn spawn_output_task( + &self, + socket_read: BufReader, + close_tx: mpsc::Sender, + ) { + let port = self.port; + let instance = Arc::clone(&self.instance); + let instance_tx = self.instance.message_writer.clone(); + task::spawn( + async move { + match read_client_socket( + socket_read, + instance_tx, + close_tx, + port, + &instance.init_cache, + ) + .await + { + Ok(_) => debug!("client output closed"), + Err(err) => error!(?err, "error reading client output"), + } + } + .in_current_span(), + ); } } -pub async fn main(server_path: String, server_args: Vec) -> Result<()> { - let config = Config::load_or_default().await; +fn tag_id(port: u16, id: &Value) -> Result { + match id { + Value::Number(number) => Ok(format!("{port}:n:{number}")), + Value::String(string) => Ok(format!("{port}:s:{string}")), + _ => bail!("unexpected message id type {id:?}"), + } +} + +/// reads from client socket and tags the id for requests, forwards the messages into a mpsc queue +/// to the writer +async fn read_client_socket( + socket_read: BufReader, + tx: mpsc::Sender, + close_tx: mpsc::Sender, + port: u16, + init_cache: &InitializeCache, +) -> Result<()> { + let mut reader = LspReader::new(socket_read); + + while let Some(message) = reader.read_message().await? { + trace!(?message, "<- client"); - let proto_init = proto::Init::new(server_path, server_args); - let mut proto_init = serde_json::to_vec(&proto_init).context("sending proto init")?; - proto_init.push(b'\0'); + match message { + Message::Request(mut req) if req.method == "initialized" => { + // initialized notification can only be sent once per server + if init_cache.attempt_send_notif() { + debug!("send InitializedNotification"); - let mut stream = TcpStream::connect(config.connect) - .await - .context("connect")?; + req.id = tag_id(port, &req.id)?.into(); + if tx.send(req.into()).await.is_err() { + break; + } + } else { + // we're not the first, skip processing the message further + debug!("skip InitializedNotification"); + continue; + } + } - stream - .write_all(&proto_init) - .await - .context("sending proto init")?; - drop(proto_init); + Message::Request(req) if req.method == "shutdown" => { + // client requested the server to shut down but other clients might still be connected. 
+ // instead we disconnect this client to prevent the editor hanging + // see + info!("client sent shutdown request, sending a response and closing connection"); + // + let message = Message::ResponseSuccess(ResponseSuccess { + jsonrpc: Version, + result: Value::Null, + id: req.id, + }); + // ignoring error because we would've closed the connection regardless + let _ = close_tx.send(message).await; + break; + } - io::copy_bidirectional(&mut stream, &mut stdio()) - .await - .context("io error")?; + Message::Request(mut req) => { + req.id = tag_id(port, &req.id)?.into(); + if tx.send(req.into()).await.is_err() { + break; + } + } + + Message::ResponseSuccess(_) | Message::ResponseError(_) => { + debug!(?message, "client response"); + } + + Message::Notification(notif) => { + if tx.send(notif.into()).await.is_err() { + break; + } + } + } + } Ok(()) } diff --git a/src/server/instance.rs b/src/instance.rs similarity index 77% rename from src/server/instance.rs rename to src/instance.rs index 1bcea0e..9bc66a1 100644 --- a/src/server/instance.rs +++ b/src/instance.rs @@ -1,8 +1,6 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; use std::io::ErrorKind; -#[cfg(unix)] -use std::os::unix::process::ExitStatusExt; use std::path::{Path, PathBuf}; use std::process::Stdio; use std::str::{self, FromStr}; @@ -11,9 +9,6 @@ use std::sync::Arc; use std::time::Duration; use anyhow::{bail, Context, Result}; -use ra_multiplex::config::Config; -use ra_multiplex::lsp::{self, Message}; -use ra_multiplex::proto; use serde_json::{Number, Value}; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command}; @@ -21,14 +16,18 @@ use tokio::sync::{mpsc, Mutex, Notify, RwLock}; use tokio::{select, task, time}; use tracing::{debug, error, info, info_span, instrument, trace, warn, Instrument}; -use super::async_once_cell::AsyncOnceCell; +use crate::async_once_cell::AsyncOnceCell; +use crate::config::Config; +use crate::lsp::jsonrpc::{Message, ResponseSuccess}; +use crate::lsp::transport::{LspReader, LspWriter}; +use crate::proto; /// keeps track of the initialize/initialized handshake for an instance #[derive(Default)] pub struct InitializeCache { request_sent: AtomicBool, notif_sent: AtomicBool, - pub response: AsyncOnceCell, + pub response: AsyncOnceCell, } impl InitializeCache { @@ -350,7 +349,7 @@ impl RaInstance { /// read messages sent by clients from a channel and write them into server stdin fn spawn_stdin_task(self: &Arc, rx: mpsc::Receiver, stdin: ChildStdin) { let mut receiver = rx; - let mut stdin = stdin; + let mut writer = LspWriter::new(stdin); task::spawn( async move { @@ -358,7 +357,8 @@ impl RaInstance { // child closes and all the clients disconnect including the sender and this receiver // will not keep blocking (unlike in client input task) while let Some(message) = receiver.recv().await { - if let Err(err) = message.to_writer(&mut stdin).await { + trace!(?message, "-> server"); + if let Err(err) = writer.write_message(&message).await { match err.kind() { // stdin is closed, no need to log an error ErrorKind::BrokenPipe => {} @@ -405,7 +405,7 @@ impl RaInstance { match exit { Ok(status) => { #[cfg(unix)] - let signal = status.signal(); + let signal = std::os::unix::process::ExitStatusExt::signal(&status); #[cfg(not(unix))] let signal = tracing::field::Empty; @@ -428,9 +428,13 @@ impl RaInstance { } } -fn parse_tagged_id(tagged: &str) -> Result<(u16, Value)> { +fn parse_tagged_id(tagged: &Value) -> Result<(u16, Value)> { + 
let Value::String(tagged) = tagged else { + bail!("tagged id must be a String found `{tagged:?}`"); + }; + let (port, rest) = tagged.split_once(':').context("missing first `:`")?; - let port = u16::from_str_radix(port, 16)?; + let port = u16::from_str(port)?; let (value_type, old_id) = rest.split_once(':').context("missing second `:`")?; let old_id = match value_type { "n" => Value::Number(Number::from_str(old_id)?), @@ -441,85 +445,97 @@ fn parse_tagged_id(tagged: &str) -> Result<(u16, Value)> { } async fn read_server_socket( - mut reader: BufReader, + reader: BufReader, senders: &MessageReaders, init_cache: &InitializeCache, ) -> Result<()> { - let mut buffer = Vec::new(); - - while let Some((mut json, bytes)) = lsp::read_message(&mut reader, &mut buffer).await? { - trace!(message = serde_json::to_string(&json).unwrap(), "server"); - - if let Some(id) = json.get("id") { - // we tagged the request id so we expect to only receive tagged responses - let tagged_id = match id { - Value::String(string) if string == INIT_REQUEST_ID => { - // this is a response to the InitializeRequest, we need to process it - // separately - debug!("recv InitializeRequest response"); - init_cache - .response - .set(Message::from_bytes(bytes)) - .await - .ok() // throw away the Err(message), we don't need it and it doesn't implement std::error::Error - .context("received multiple InitializeRequest responses from instance")?; - continue; - } - Value::String(string) => string, - _ => { - debug!( - message = serde_json::to_string(&json).unwrap(), - "response to no request", - ); - // FIXME uncommenting this crashes rust-analyzer, presumably because the client - // then sends a confusing response or something? i'm guessing that the response - // id gets tagged with port but rust-analyzer expects to know the id because - // it's actually a response and not a request. i'm not sure if we can handle - // these at all with multiple clients attached - // - // ideally we could send these to all clients, but what if there is a matching - // response from each client? rust-analyzer only expects one (this might - // actually be why it's crashing) - // - // ignoring these might end up being the safest option, they don't seem to - // matter to neovim anyway - // ```rust - // let message = Message::from_bytes(bytes); - // let senders = senders.read().await; - // for sender in senders.values() { - // sender - // .send(message.clone()) - // .await - // .context("forward server notification")?; - // } - // ``` - continue; - } - }; + let mut reader = LspReader::new(reader); + + while let Some(message) = reader.read_message().await? 
{ + trace!(?message, "<- server"); + + match message { + Message::ResponseSuccess(res) + if res.id == Value::String(INIT_REQUEST_ID.to_owned()) => + { + // this is a response to the InitializeRequest, we need to process it + // separately + debug!(message = ?Message::from(res.clone()), "recv InitializeRequest response"); + init_cache + .response + .set(res.into()) + .await + .ok() // throw away the Err(message), we don't need it and it doesn't implement std::error::Error + .context("received multiple InitializeRequest responses from instance")?; + } - let (port, old_id) = match parse_tagged_id(tagged_id) { - Ok(ok) => ok, - Err(err) => { - warn!(?err, "invalid tagged id"); - continue; - } - }; + Message::ResponseSuccess(mut res) => { + match parse_tagged_id(&res.id) { + Ok((port, id)) => { + res.id = id; + if let Some(sender) = senders.read().await.get(&port) { + // ignore closed channels + let _ignore = sender.send(res.into()).await; + } else { + warn!("no client"); + } + } + Err(err) => { + warn!(?err, "invalid tagged id"); + } + }; + } - json.insert("id".to_owned(), old_id); + Message::ResponseError(mut res) => { + match parse_tagged_id(&res.id) { + Ok((port, id)) => { + res.id = id; + if let Some(sender) = senders.read().await.get(&port) { + // ignore closed channels + let _ignore = sender.send(res.into()).await; + } else { + warn!("no client"); + } + } + Err(err) => { + warn!(?err, "invalid tagged id"); + continue; + } + }; + } - if let Some(sender) = senders.read().await.get(&port) { - let message = Message::from_json(&json, &mut buffer); - // ignore closed channels - let _ignore = sender.send(message).await; - } else { - warn!("no client"); + Message::Request(_) => { + debug!(?message, "server request"); + // FIXME uncommenting this crashes rust-analyzer, presumably because the client + // then sends a confusing response or something? i'm guessing that the response + // id gets tagged with port but rust-analyzer expects to know the id because + // it's actually a response and not a request. i'm not sure if we can handle + // these at all with multiple clients attached + // + // ideally we could send these to all clients, but what if there is a matching + // response from each client? 
rust-analyzer only expects one (this might + // actually be why it's crashing) + // + // ignoring these might end up being the safest option, they don't seem to + // matter to neovim anyway + // ```rust + // let message = Message::from_bytes(bytes); + // let senders = senders.read().await; + // for sender in senders.values() { + // sender + // .send(message.clone()) + // .await + // .context("forward server notification")?; + // } + // ``` } - } else { - // notification messages without an id are sent to all clients - let message = Message::from_bytes(bytes); - for sender in senders.read().await.values() { - // ignore closed channels - let _ignore = sender.send(message.clone()).await; + + Message::Notification(notif) => { + // notification messages without an id are sent to all clients + for sender in senders.read().await.values() { + // ignore closed channels + let _ignore = sender.send(notif.clone().into()).await; + } } } } diff --git a/src/lib.rs b/src/lib.rs index e731881..1a746cd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,9 @@ -pub mod config; -pub mod lsp; -pub mod proto; +mod async_once_cell; +mod client; +mod config; +mod instance; +mod lsp; +mod proto; + +pub mod proxy; +pub mod server; diff --git a/src/lsp.rs b/src/lsp.rs index 96afb56..6a2c5bc 100644 --- a/src/lsp.rs +++ b/src/lsp.rs @@ -28,158 +28,5 @@ //! - Progress notifications - contains a `token` property which could be used to identify the //! client but the specification also says it has nothing to do with the request IDs -use std::fmt::{self, Debug}; -use std::io::{self, ErrorKind}; -use std::str; -use std::sync::Arc; - -use anyhow::{bail, ensure, Context, Result}; -use serde::Serialize; -use serde_json::{Map, Value}; -use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncReadExt, AsyncWrite, AsyncWriteExt}; - -/// Every message begins with a HTTP-style header -/// -/// Headers are terminated by `\r\n` sequence and the final header is followed by another `\r\n`. -/// The currently recognized headers are `content-type` which is optional and contains a `string` -/// (something like a MIME-type) and `content-length` which contains the length of the message body -/// after the final `\r\n` of the header. Header names and values are separated by `: `. 
-pub struct Header { - pub content_length: usize, - pub content_type: Option, -} - -impl Header { - pub async fn from_reader( - buffer: &mut Vec, - mut reader: R, - ) -> Result> { - let mut content_type = None; - let mut content_length = None; - - loop { - buffer.clear(); - match reader.read_until(b'\n', &mut *buffer).await { - Ok(0) => return Ok(None), // EOF - Ok(_) => {} - Err(err) => match err.kind() { - // reader is closed for some reason, no need to log an error about it - ErrorKind::ConnectionReset - | ErrorKind::ConnectionAborted - | ErrorKind::BrokenPipe => return Ok(None), - _ => bail!(err), - }, - } - let header_text = buffer - .strip_suffix(b"\r\n") - .context("malformed header, missing \\r\\n")?; - let header_text = str::from_utf8(header_text).context("malformed header")?; - - if header_text.is_empty() { - // headers are separated by an empty line from the body - break; - } - let (name, value) = match header_text.split_once(": ") { - Some(split) => split, - None => bail!("malformed header, missing value separator: {}", header_text), - }; - - match name.to_ascii_lowercase().as_str() { - "content-type" => { - ensure!(content_type.is_none(), "repeated header content-type"); - content_type = Some(value.to_owned()); - } - "content-length" => { - ensure!(content_length.is_none(), "repeated header content-length"); - content_length = Some(value.parse::().context("content-length header")?); - } - _ => bail!("unknown header: {name}"), - } - } - - let content_length = content_length.context("missing required header content-length")?; - Ok(Some(Header { - content_length, - content_type, - })) - } -} - -/// reads one LSP message from a reader, deserializes it and leaves the serialized body of the -/// message in `buffer` -pub async fn read_message( - mut reader: R, - buffer: &mut Vec, -) -> Result, &[u8])>> -where - R: AsyncBufRead + Unpin, -{ - let header = Header::from_reader(&mut *buffer, &mut reader) - .await - .context("parsing header")?; - let header = match header { - Some(header) => header, - None => return Ok(None), - }; - - buffer.clear(); - buffer.resize(header.content_length, 0); - if let Err(err) = reader.read_exact(&mut *buffer).await { - match err.kind() { - // reader is closed for some reason, no need to log an error about it - ErrorKind::UnexpectedEof - | ErrorKind::ConnectionReset - | ErrorKind::ConnectionAborted - | ErrorKind::BrokenPipe => return Ok(None), - _ => bail!(err), - } - } - - let bytes = buffer.as_slice(); - let json = serde_json::from_slice(bytes).context("invalid body")?; - Ok(Some((json, bytes))) -} - -/// LSP messages -#[derive(Clone)] -pub struct Message { - bytes: Arc<[u8]>, -} - -impl Debug for Message { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("Message") - } -} - -impl Message { - /// construct a message from a byte buffer, should only contain the message body - no headers - pub fn from_bytes(bytes: &[u8]) -> Self { - Self { - bytes: Arc::from(bytes), - } - } - - pub fn as_bytes(&self) -> &[u8] { - &self.bytes - } - - /// construct a message from a serializable value, like JSON - pub fn from_json(json: &impl Serialize, buffer: &mut Vec) -> Self { - buffer.clear(); - serde_json::to_writer(&mut *buffer, json).expect("invalid json"); - Self::from_bytes(&*buffer) - } - - /// serialize LSP message into a writer, prepending the appropriate content-length header - pub async fn to_writer(&self, mut writer: W) -> io::Result<()> - where - W: AsyncWrite + Unpin, - { - writer - .write_all(format!("Content-Length: {}\r\n\r\n", 
self.bytes.len()).as_bytes()) - .await?; - writer.write_all(&self.bytes).await?; - writer.flush().await - } -} +pub mod jsonrpc; +pub mod transport; diff --git a/src/lsp/jsonrpc.rs b/src/lsp/jsonrpc.rs new file mode 100644 index 0000000..e8421b2 --- /dev/null +++ b/src/lsp/jsonrpc.rs @@ -0,0 +1,179 @@ +//! JSON-RPC 2.0 support +//! +//! Support for the JSON-RPC 2.0 protocol as definied here +//! . With the exception of batching which +//! is handled in [`read_message`](super::transport::LspReader::read_message). + +use std::fmt; + +use serde_derive::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, Clone)] +#[serde(untagged)] +pub enum Message { + Request(Request), + Notification(Notification), + ResponseError(ResponseError), + ResponseSuccess(ResponseSuccess), +} + +#[derive(Serialize, Deserialize, Clone)] +#[serde(deny_unknown_fields)] +pub struct Request { + pub jsonrpc: Version, + pub method: String, + #[serde(default)] + pub params: serde_json::Value, + pub id: serde_json::Value, +} + +#[derive(Serialize, Deserialize, Clone)] +#[serde(deny_unknown_fields)] +pub struct Notification { + pub jsonrpc: Version, + pub method: String, + #[serde(default)] + pub params: serde_json::Value, +} + +#[derive(Serialize, Deserialize, Clone)] +#[serde(deny_unknown_fields)] +pub struct ResponseError { + pub jsonrpc: Version, + pub error: Error, + pub id: serde_json::Value, +} + +#[derive(Serialize, Deserialize, Clone)] +#[serde(deny_unknown_fields)] +pub struct ResponseSuccess { + pub jsonrpc: Version, + #[serde(default)] + pub result: serde_json::Value, + pub id: serde_json::Value, +} + +#[derive(Serialize, Deserialize, Clone)] +#[serde(deny_unknown_fields)] +pub struct Error { + pub code: i64, + pub message: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, +} + +#[derive(Serialize, Deserialize, Clone)] +#[serde(untagged)] +pub enum Params { + ByPosition(Vec), + ByName(serde_json::Map), +} + +/// ZST representation of the `"2.0"` version string +#[derive(Clone, Copy)] +pub struct Version; + +impl serde::ser::Serialize for Version { + fn serialize(&self, serializer: S) -> Result + where + S: serde::ser::Serializer, + { + serializer.serialize_str("2.0") + } +} + +impl<'de> serde::de::Visitor<'de> for Version { + type Value = Version; + + fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + fmt.write_str(r#"string value "2.0""#) + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + match v { + "2.0" => Ok(Version), + _ => Err(E::custom("unsupported JSON-RPC version")), + } + } +} + +impl<'de> serde::de::Deserialize<'de> for Version { + fn deserialize(deserializer: D) -> Result + where + D: serde::de::Deserializer<'de>, + { + deserializer.deserialize_str(Version) + } +} + +impl From for Message { + fn from(value: Request) -> Self { + Message::Request(value) + } +} + +impl From for Message { + fn from(value: Notification) -> Self { + Message::Notification(value) + } +} + +impl From for Message { + fn from(value: ResponseSuccess) -> Self { + Message::ResponseSuccess(value) + } +} + +impl From for Message { + fn from(value: ResponseError) -> Self { + Message::ResponseError(value) + } +} + +impl fmt::Debug for Message { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let json = serde_json::to_string(self).expect("BUG: invalid message"); + f.write_str(&json) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test(input: &str) { + let deserialized = serde_json::from_str::(input).expect("failed to 
deserialize"); + let serialized = serde_json::to_string(&deserialized).expect("failed to serialize"); + assert_eq!(input, serialized); + } + + #[test] + fn call_with_positional_parameters() { + test(r#"{"jsonrpc":"2.0","method":"subtract","params":[42,23],"id":1}"#) + } + + #[test] + fn call_with_named_parameters() { + test( + r#"{"jsonrpc":"2.0","method":"subtract","params":{"minuend":42,"subtrahend":23},"id":3}"#, + ); + } + + #[test] + fn response() { + test(r#"{"jsonrpc":"2.0","result":19,"id":1}"#) + } + + #[test] + fn notification() { + test(r#"{"jsonrpc":"2.0","method":"update","params":[1,2,3,4,5]}"#) + } + + #[test] + fn error() { + test(r#"{"jsonrpc":"2.0","error":{"code":-32601,"message":"Method not found"},"id":"1"}"#) + } +} diff --git a/src/lsp/transport.rs b/src/lsp/transport.rs new file mode 100644 index 0000000..85b4191 --- /dev/null +++ b/src/lsp/transport.rs @@ -0,0 +1,181 @@ +use std::io::{self, ErrorKind}; +use std::str; + +use anyhow::{bail, ensure, Context, Result}; +use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +use crate::lsp::jsonrpc::Message; + +pub struct LspReader { + reader: R, + batch: Vec, + buffer: Vec, +} + +/// Every message begins with a HTTP-style header +/// +/// Headers are terminated by `\r\n` sequence and the final header is followed by another `\r\n`. +/// The currently recognized headers are `content-type` which is optional and contains a `string` +/// (something like a MIME-type) and `content-length` which contains the length of the message body +/// after the final `\r\n` of the header. Header names and values are separated by `: `. +/// +/// While we parse the `content-type` header ignore it completely and we don't forward it, +/// expecting both the server and client to assume the default. +/// +/// For mor details see . 
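For illustration (not part of the diff), the framing described in the comment above is easy to reproduce by hand; a minimal sketch, with the optional `content-type` header omitted:

```rust
// A 52-byte `initialized` notification body, framed the way the reader and
// writer in this file expect it: one Content-Length header, a blank line,
// then the JSON body.
let body = r#"{"jsonrpc":"2.0","method":"initialized","params":{}}"#;
let framed = format!("Content-Length: {}\r\n\r\n{}", body.len(), body);
assert!(framed.starts_with("Content-Length: 52\r\n\r\n"));
```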
+pub struct Header { + pub content_length: usize, + pub content_type: Option, +} + +impl LspReader +where + R: AsyncBufRead + Unpin, +{ + pub fn new(reader: R) -> Self { + LspReader { + reader, + batch: Vec::new(), + buffer: Vec::with_capacity(1024), + } + } + + pub async fn read_header(&mut self) -> Result> { + let mut content_type = None; + let mut content_length = None; + + loop { + self.buffer.clear(); + match self.reader.read_until(b'\n', &mut self.buffer).await { + Ok(0) => return Ok(None), // EOF + Ok(_) => {} + Err(err) => match err.kind() { + // reader is closed for some reason, no need to log an error about it + ErrorKind::ConnectionReset + | ErrorKind::ConnectionAborted + | ErrorKind::BrokenPipe => return Ok(None), + _ => bail!(err), + }, + } + let header_text = self + .buffer + .strip_suffix(b"\r\n") + .context(r"malformed header, missing `\r\n` terminator")?; + let header_text = str::from_utf8(header_text) + .context("malformed header, ascii encoding is a subset of utf-8")?; + + if header_text.is_empty() { + // headers are separated by an empty line from the body + break; + } + let (name, value) = match header_text.split_once(": ") { + Some(split) => split, + None => bail!("malformed header, missing value separator: {}", header_text), + }; + + match name.to_ascii_lowercase().as_str() { + "content-type" => { + ensure!(content_type.is_none(), "repeated header content-type"); + content_type = Some(value.to_owned()); + } + "content-length" => { + ensure!(content_length.is_none(), "repeated header content-length"); + content_length = Some(value.parse::().context("content-length header")?); + } + _ => bail!("unknown header name: {name:?}"), + } + } + + let content_length = content_length.context("missing required header content-length")?; + Ok(Some(Header { + content_length, + content_type, + })) + } + + /// Read one message + /// + /// Returns `None` if the reader was closed and it'll never return another + /// message after the first `None`. + /// + /// Batch messages are transparently split into individual messages and + /// delivered in order. 
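The batch splitting mentioned above can be exercised with an in-memory pipe; a hedged sketch that assumes crate-internal access to these types (the `lsp` module is private), plus anyhow and tokio's `io-util`, `rt` and `macros` features, which are assumptions about the build rather than part of the diff:

```rust
use tokio::io::{AsyncWriteExt, BufReader};

use crate::lsp::jsonrpc::Message;
use crate::lsp::transport::LspReader;

// Illustration only: a two-element JSON-RPC batch framed as a single message
// is handed out by read_message() as two Notification values, in order,
// followed by None once the writer side is closed.
#[tokio::test]
async fn batch_is_split_into_messages() -> anyhow::Result<()> {
    let (mut tx, rx) = tokio::io::duplex(4096);
    let body = r#"[{"jsonrpc":"2.0","method":"a","params":null},{"jsonrpc":"2.0","method":"b","params":null}]"#;
    tx.write_all(format!("Content-Length: {}\r\n\r\n{}", body.len(), body).as_bytes())
        .await?;
    drop(tx); // close the writer half so the reader sees EOF after the batch

    let mut reader = LspReader::new(BufReader::new(rx));
    assert!(matches!(reader.read_message().await?, Some(Message::Notification(_))));
    assert!(matches!(reader.read_message().await?, Some(Message::Notification(_))));
    assert!(reader.read_message().await?.is_none());
    Ok(())
}
```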
+ pub async fn read_message(&mut self) -> Result> { + // return pending messages until the last batch is drained + if let Some(pending) = self.batch.pop() { + return Ok(Some(pending)); + } + + let header = self.read_header().await.context("parsing header")?; + let header = match header { + Some(header) => header, + None => return Ok(None), + }; + + self.buffer.clear(); + self.buffer.resize(header.content_length, 0); + if let Err(err) = self.reader.read_exact(&mut self.buffer).await { + match err.kind() { + // reader is closed for some reason, no need to log an error about it + ErrorKind::UnexpectedEof + | ErrorKind::ConnectionReset + | ErrorKind::ConnectionAborted + | ErrorKind::BrokenPipe => return Ok(None), + _ => bail!(err), + } + } + + let bytes = self.buffer.as_slice(); + let body = str::from_utf8(bytes) + .with_context(|| { + let lossy_utf8 = String::from_utf8_lossy(bytes); + format!("parsing body `{lossy_utf8}`") + }) + .context("parsing LSP message")?; + + // handle batches + if body.starts_with('[') { + self.batch = serde_json::from_str(body) + .with_context(|| format!("parsing body `{body}`")) + .context("parsing LSP message")?; + // we're popping the messages from the end of the vec + self.batch.reverse(); + let message = self.batch.pop().context("received an empty batch")?; + Ok(Some(message)) + } else { + let message = serde_json::from_str(body) + .with_context(|| format!("parsing body `{body}`")) + .context("parsing LSP message")?; + Ok(Some(message)) + } + } +} + +pub struct LspWriter { + writer: W, + buffer: Vec, +} + +impl LspWriter +where + W: AsyncWrite + Unpin, +{ + pub fn new(writer: W) -> Self { + LspWriter { + writer, + buffer: Vec::with_capacity(1024), + } + } + + /// serialize LSP message into a writer, prepending the appropriate content-length header + pub async fn write_message(&mut self, message: &Message) -> io::Result<()> { + self.buffer.clear(); + serde_json::to_writer(&mut self.buffer, message).expect("BUG: invalid message"); + + self.writer + .write_all(format!("Content-Length: {}\r\n\r\n", self.buffer.len()).as_bytes()) + .await?; + self.writer.write_all(&self.buffer).await?; + self.writer.flush().await + } +} diff --git a/src/main.rs b/src/main.rs index d9a6437..8a297b4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,7 @@ use std::env; use clap::{Args, Parser, Subcommand}; - -mod client; -mod server; +use ra_multiplex::{proxy, server}; #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] @@ -45,11 +43,11 @@ async fn main() -> anyhow::Result<()> { let cli = Cli::parse(); match cli.command { - Some(Cmd::Server(_args)) => server::main().await, - Some(Cmd::Client(args)) => client::main(args.server_path, args.server_args).await, + Some(Cmd::Server(_args)) => server::run().await, + Some(Cmd::Client(args)) => proxy::run(args.server_path, args.server_args).await, None => { let server_path = env::var("RA_MUX_SERVER").unwrap_or_else(|_| "rust-analyzer".into()); - client::main(server_path, vec![]).await + proxy::run(server_path, vec![]).await } } } diff --git a/src/proxy.rs b/src/proxy.rs new file mode 100644 index 0000000..83271e5 --- /dev/null +++ b/src/proxy.rs @@ -0,0 +1,73 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +use anyhow::{Context as _, Result}; +use pin_project_lite::pin_project; +use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::net::TcpStream; + +use crate::config::Config; +use crate::proto; + +pin_project! 
{ + struct Stdio { + #[pin] + stdin: io::Stdin, + #[pin] + stdout: io::Stdout, + } +} + +fn stdio() -> Stdio { + Stdio { + stdin: io::stdin(), + stdout: io::stdout(), + } +} + +impl AsyncRead for Stdio { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut io::ReadBuf, + ) -> Poll> { + self.project().stdin.poll_read(cx, buf) + } +} + +impl AsyncWrite for Stdio { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + self.project().stdout.poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.project().stdout.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.project().stdout.poll_shutdown(cx) + } +} + +pub async fn run(server_path: String, server_args: Vec) -> Result<()> { + let config = Config::load_or_default().await; + + let proto_init = proto::Init::new(server_path, server_args); + let mut proto_init = serde_json::to_vec(&proto_init).context("sending proto init")?; + proto_init.push(b'\0'); + + let mut stream = TcpStream::connect(config.connect) + .await + .context("connect")?; + + stream + .write_all(&proto_init) + .await + .context("sending proto init")?; + drop(proto_init); + + io::copy_bidirectional(&mut stream, &mut stdio()) + .await + .context("io error")?; + Ok(()) +} diff --git a/src/server.rs b/src/server.rs index fdbc30b..05ff6ed 100644 --- a/src/server.rs +++ b/src/server.rs @@ -7,19 +7,15 @@ //! cargo workspace and routing the messages through TCP to multiple clients. use anyhow::{Context, Result}; -use ra_multiplex::config::Config; use tokio::net::TcpListener; use tokio::task; use tracing::{debug, error, info, info_span, warn, Instrument}; -use crate::server::client::Client; -use crate::server::instance::InstanceRegistry; +use crate::client::Client; +use crate::config::Config; +use crate::instance::InstanceRegistry; -mod async_once_cell; -mod client; -mod instance; - -pub async fn main() -> Result<()> { +pub async fn run() -> Result<()> { let config = Config::load_or_default().await; let registry = InstanceRegistry::new().await; diff --git a/src/server/client.rs b/src/server/client.rs deleted file mode 100644 index a201dc7..0000000 --- a/src/server/client.rs +++ /dev/null @@ -1,270 +0,0 @@ -use std::io::ErrorKind; -use std::sync::Arc; - -use anyhow::{bail, Context, Result}; -use ra_multiplex::lsp::{self, Message}; -use ra_multiplex::proto; -use serde_json::{json, Map, Value}; -use tokio::io::BufReader; -use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; -use tokio::net::TcpStream; -use tokio::sync::mpsc; -use tokio::{select, task}; -use tracing::{debug, error, info, trace, Instrument}; - -use super::instance::{ - InitializeCache, InstanceKey, InstanceRegistry, RaInstance, INIT_REQUEST_ID, -}; - -pub struct Client { - port: u16, - initialize_request_id: Option, - instance: Arc, -} - -impl Client { - /// finds or spawns a rust-analyzer instance and connects the client - pub async fn process(socket: TcpStream, port: u16, registry: InstanceRegistry) -> Result<()> { - let (socket_read, socket_write) = socket.into_split(); - let mut socket_read = BufReader::new(socket_read); - - let mut buffer = Vec::new(); - let proto_init = proto::Init::from_reader(&mut buffer, &mut socket_read).await?; - - let key = InstanceKey::from_proto_init(&proto_init).await; - debug!( - path = ?key.workspace_root(), - server = ?key.server(), - args = ?key.args(), - "client configured", - ); - - let mut client = Client { - port, - initialize_request_id: None, - 
instance: registry.get(&key).await?, - }; - - client.wait_for_initialize_request(&mut socket_read).await?; - - let (client_tx, client_rx) = client.register_client_with_instance().await; - let (close_tx, close_rx) = mpsc::channel(1); - client.spawn_input_task(client_rx, close_rx, socket_write); - client.spawn_output_task(socket_read, close_tx); - - client - .wait_for_initialize_response(client_tx, &mut buffer) - .await?; - Ok(()) - } - - async fn wait_for_initialize_request( - &mut self, - socket_read: &mut BufReader, - ) -> Result<()> { - let mut buffer = Vec::new(); - let (mut json, _bytes) = lsp::read_message(&mut *socket_read, &mut buffer) - .await? - .context("channel closed")?; - if !matches!(json.get("method"), Some(Value::String(method)) if method == "initialize") { - bail!("first client message was not InitializeRequest"); - } - debug!("recv InitializeRequest"); - // this is an initialize request, it's special because it cannot be sent twice or - // rust-analyzer will crash. - - // we save the request id so we can later use it for the response - self.initialize_request_id = Some( - json.remove("id") - .context("InitializeRequest is missing an `id`")?, - ); - if self.instance.init_cache.attempt_send_request() { - // it haven't been sent yet, we can send it. - // - // instead of tagging the original id we replace it with a custom id that only - // the `initialize` uses - json.insert("id".to_owned(), Value::String(INIT_REQUEST_ID.to_owned())); - - self.instance - .message_writer - .send(Message::from_json(&json, &mut buffer)) - .await - .context("forward client request")?; - } else { - // initialize request was already sent for this instance, no need to send it again - } - Ok(()) - } - - async fn wait_for_initialize_response( - &self, - tx: mpsc::Sender, - buffer: &mut Vec, - ) -> Result<()> { - // parse the cached message and restore the `id` to the value this client expects - let response = self.instance.init_cache.response.get().await; - let mut json: Map = serde_json::from_slice(response.as_bytes()) - .expect("BUG: cached initialize response was invalid"); - json.insert( - "id".to_owned(), - self.initialize_request_id - .clone() - .expect("BUG: need to wait_for_initialize_request first"), - ); - let message = Message::from_json(&json, buffer); - debug!("send response to InitializeRequest"); - tx.send(message).await.context("send initialize response")?; - Ok(()) - } - - async fn register_client_with_instance( - &self, - ) -> (mpsc::Sender, mpsc::Receiver) { - let (client_tx, client_rx) = mpsc::channel(64); - self.instance - .message_readers - .write() - .await - .insert(self.port, client_tx.clone()); - (client_tx, client_rx) - } - - fn spawn_input_task( - &self, - mut rx: mpsc::Receiver, - mut close_rx: mpsc::Receiver, - mut socket_write: OwnedWriteHalf, - ) { - task::spawn( - async move { - // unlike the output task, here we first wait on the channel which is going to - // block until the rust-analyzer server sends a notification, however if we're the last - // client and have just closed the server is unlikely to send any. this results in the - // last client often falsely hanging while the gc task depends on the input channels being - // closed to detect a disconnected client. - // - // when a client sends a shutdown request we receive a message on the close_rx, send - // the reply and close the connection. 
if no shutdown request was received but the - // client closed close_rx channel will be dropped (unlike the normal rx channel which - // is shared) and the connection will close without sending any response. - while let Some(message) = select! { - message = close_rx.recv() => message, - message = rx.recv() => message, - } { - if let Err(err) = message.to_writer(&mut socket_write).await { - match err.kind() { - // ignore benign errors, treat as socket close - ErrorKind::BrokenPipe => {} - // report fatal errors - _ => error!(?err, "error writing client input: {err}"), - } - break; // break on any error - } - } - debug!("client input closed"); - info!("client disconnected"); - } - .in_current_span(), - ); - } - - fn spawn_output_task( - &self, - socket_read: BufReader, - close_tx: mpsc::Sender, - ) { - let port = self.port; - let instance = Arc::clone(&self.instance); - let instance_tx = self.instance.message_writer.clone(); - task::spawn( - async move { - match read_client_socket( - socket_read, - instance_tx, - close_tx, - port, - &instance.init_cache, - ) - .await - { - Ok(_) => debug!("client output closed"), - Err(err) => error!(?err, "error reading client output"), - } - } - .in_current_span(), - ); - } -} - -fn tag_id(port: u16, id: &Value) -> Result { - match id { - Value::Number(number) => Ok(format!("{port:04x}:n:{number}")), - Value::String(string) => Ok(format!("{port:04x}:s:{string}")), - _ => bail!("unexpected message id type {id:?}"), - } -} - -/// reads from client socket and tags the id for requests, forwards the messages into a mpsc queue -/// to the writer -async fn read_client_socket( - mut socket_read: BufReader, - tx: mpsc::Sender, - close_tx: mpsc::Sender, - port: u16, - init_cache: &InitializeCache, -) -> Result<()> { - let mut buffer = Vec::new(); - - while let Some((mut json, bytes)) = lsp::read_message(&mut socket_read, &mut buffer).await? { - trace!(message = serde_json::to_string(&json).unwrap(), "client"); - if matches!(json.get("method"), Some(Value::String(method)) if method == "initialized") { - // initialized notification can only be sent once per server - if init_cache.attempt_send_notif() { - debug!("send InitializedNotification"); - } else { - // we're not the first, skip processing the message further - debug!("skip InitializedNotification"); - continue; - } - } - if matches!(json.get("method"), Some(Value::String(method)) if method == "shutdown") { - // client requested the server to shut down but other clients might still be connected. 
- // instead we disconnect this client to prevent the editor hanging - // see - if let Some(shutdown_request_id) = json.get("id") { - info!("client sent shutdown request, sending a response and closing connection"); - // - let message = Message::from_json( - &json!({ - "id": shutdown_request_id, - "jsonrpc": "2.0", - "result": null, - }), - &mut buffer, - ); - // ignoring error because we would've closed the connection regardless - let _ = close_tx.send(message).await; - } - break; - } - if let Some(id) = json.get("id") { - // messages containing an id need the id modified so we can discern which client to send - // the response to - let tagged_id = tag_id(port, id)?; - json.insert("id".to_owned(), Value::String(tagged_id)); - - let message = Message::from_json(&json, &mut buffer); - if tx.send(message).await.is_err() { - break; - } - } else { - // notification messages without an id don't need any modification and can be forwarded - // to rust-analyzer as is - let message = Message::from_bytes(bytes); - if tx.send(message).await.is_err() { - break; - } - } - } - Ok(()) -}
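As a closing illustration of the id-tagging scheme that both the old and the new client code rely on (the new code switches the port prefix from the old `{port:04x}` hex form to plain decimal): a standalone sketch of the round trip, using assumed example values rather than the crate's own `tag_id`/`parse_tagged_id`:

```rust
use serde_json::{json, Value};

// A client connected from port 53017 sends a request with numeric id 7; the
// proxy forwards it to rust-analyzer as the string id "53017:n:7" so the
// response can later be routed back to that client with its original id.
fn main() {
    let port: u16 = 53017;
    let original: Value = json!(7);

    // tag, mirroring `tag_id` in the new client.rs
    let tagged = match &original {
        Value::Number(n) => format!("{port}:n:{n}"),
        Value::String(s) => format!("{port}:s:{s}"),
        _ => panic!("unexpected id type"),
    };
    assert_eq!(tagged, "53017:n:7");

    // untag, mirroring `parse_tagged_id` in the new instance.rs
    let (port_str, rest) = tagged.split_once(':').unwrap();
    let (kind, old_id) = rest.split_once(':').unwrap();
    let restored: Value = match kind {
        "n" => json!(old_id.parse::<i64>().unwrap()),
        "s" => json!(old_id),
        _ => panic!("unexpected id kind"),
    };
    assert_eq!((port_str.parse::<u16>().unwrap(), restored), (port, original));
}
```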