diff --git a/src/client.rs b/src/client.rs index 2c7240f..4cf443a 100644 --- a/src/client.rs +++ b/src/client.rs @@ -2,7 +2,7 @@ use std::io::ErrorKind; use std::sync::Arc; use anyhow::{bail, Context, Result}; -use serde_json::{json, Map, Value}; +use serde_json::Value; use tokio::io::BufReader; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::TcpStream; @@ -13,7 +13,8 @@ use tracing::{debug, error, info, trace, Instrument}; use crate::instance::{ InitializeCache, InstanceKey, InstanceRegistry, RaInstance, INIT_REQUEST_ID, }; -use crate::lsp::transport::{LspReader, LspWriter, Message}; +use crate::lsp::jsonrpc::{Message, ResponseSuccess, Version}; +use crate::lsp::transport::{LspReader, LspWriter}; use crate::proto; pub struct Client { @@ -52,9 +53,7 @@ impl Client { client.spawn_input_task(client_rx, close_rx, socket_write); client.spawn_output_task(socket_read, close_tx); - client - .wait_for_initialize_response(client_tx, &mut buffer) - .await?; + client.wait_for_initialize_response(client_tx).await?; Ok(()) } @@ -63,30 +62,34 @@ impl Client { socket_read: &mut BufReader, ) -> Result<()> { let mut reader = LspReader::new(socket_read); - let mut buffer = Vec::new(); - let (mut json, _bytes) = reader.read_message().await?.context("channel closed")?; - if !matches!(json.get("method"), Some(Value::String(method)) if method == "initialize") { - bail!("first client message was not InitializeRequest"); - } - debug!("recv InitializeRequest"); + + let message = reader.read_message().await?.context("channel closed")?; + trace!(?message, "<- client"); + + let mut req = match message { + Message::Request(req) if req.method == "initialize" => { + debug!(message = ?Message::from(req.clone()), "recv InitializeRequest"); + req + } + _ => bail!("first client message was not InitializeRequest"), + }; + // this is an initialize request, it's special because it cannot be sent twice or // rust-analyzer will crash. // we save the request id so we can later use it for the response - self.initialize_request_id = Some( - json.remove("id") - .context("InitializeRequest is missing an `id`")?, - ); + self.initialize_request_id = Some(req.id.clone()); + if self.instance.init_cache.attempt_send_request() { // it haven't been sent yet, we can send it. // // instead of tagging the original id we replace it with a custom id that only // the `initialize` uses - json.insert("id".to_owned(), Value::String(INIT_REQUEST_ID.to_owned())); + req.id = Value::String(INIT_REQUEST_ID.to_owned()); self.instance .message_writer - .send(Message::from_json(&json, &mut buffer)) + .send(req.into()) .await .context("forward client request")?; } else { @@ -95,24 +98,17 @@ impl Client { Ok(()) } - async fn wait_for_initialize_response( - &self, - tx: mpsc::Sender, - buffer: &mut Vec, - ) -> Result<()> { + async fn wait_for_initialize_response(&self, tx: mpsc::Sender) -> Result<()> { // parse the cached message and restore the `id` to the value this client expects - let response = self.instance.init_cache.response.get().await; - let mut json: Map = serde_json::from_slice(response.as_bytes()) - .expect("BUG: cached initialize response was invalid"); - json.insert( - "id".to_owned(), - self.initialize_request_id - .clone() - .expect("BUG: need to wait_for_initialize_request first"), - ); - let message = Message::from_json(&json, buffer); - debug!("send response to InitializeRequest"); - tx.send(message).await.context("send initialize response")?; + let mut res = self.instance.init_cache.response.get().await; + res.id = self + .initialize_request_id + .clone() + .expect("BUG: need to wait_for_initialize_request first"); + debug!(message = ?Message::from(res.clone()), "send response to InitializeRequest"); + tx.send(res.into()) + .await + .context("send initialize response")?; Ok(()) } @@ -151,7 +147,8 @@ impl Client { message = close_rx.recv() => message, message = rx.recv() => message, } { - if let Err(err) = writer.write_message(message).await { + trace!(?message, "-> client"); + if let Err(err) = writer.write_message(&message).await { match err.kind() { // ignore benign errors, treat as socket close ErrorKind::BrokenPipe => {} @@ -198,8 +195,8 @@ impl Client { fn tag_id(port: u16, id: &Value) -> Result { match id { - Value::Number(number) => Ok(format!("{port:04x}:n:{number}")), - Value::String(string) => Ok(format!("{port:04x}:s:{string}")), + Value::Number(number) => Ok(format!("{port}:n:{number}")), + Value::String(string) => Ok(format!("{port}:s:{string}")), _ => bail!("unexpected message id type {id:?}"), } } @@ -214,56 +211,58 @@ async fn read_client_socket( init_cache: &InitializeCache, ) -> Result<()> { let mut reader = LspReader::new(socket_read); - let mut buffer = Vec::new(); - while let Some((mut json, bytes)) = reader.read_message().await? { - trace!(message = serde_json::to_string(&json).unwrap(), "client"); - if matches!(json.get("method"), Some(Value::String(method)) if method == "initialized") { - // initialized notification can only be sent once per server - if init_cache.attempt_send_notif() { - debug!("send InitializedNotification"); - } else { - // we're not the first, skip processing the message further - debug!("skip InitializedNotification"); - continue; + while let Some(message) = reader.read_message().await? { + trace!(?message, "<- client"); + + match message { + Message::Request(mut req) if req.method == "initialized" => { + // initialized notification can only be sent once per server + if init_cache.attempt_send_notif() { + debug!("send InitializedNotification"); + + req.id = tag_id(port, &req.id)?.into(); + if tx.send(req.into()).await.is_err() { + break; + } + } else { + // we're not the first, skip processing the message further + debug!("skip InitializedNotification"); + continue; + } } - } - if matches!(json.get("method"), Some(Value::String(method)) if method == "shutdown") { - // client requested the server to shut down but other clients might still be connected. - // instead we disconnect this client to prevent the editor hanging - // see - if let Some(shutdown_request_id) = json.get("id") { + + Message::Request(req) if req.method == "shutdown" => { + // client requested the server to shut down but other clients might still be connected. + // instead we disconnect this client to prevent the editor hanging + // see info!("client sent shutdown request, sending a response and closing connection"); // - let message = Message::from_json( - &json!({ - "id": shutdown_request_id, - "jsonrpc": "2.0", - "result": null, - }), - &mut buffer, - ); + let message = Message::ResponseSuccess(ResponseSuccess { + jsonrpc: Version, + result: Value::Null, + id: req.id, + }); // ignoring error because we would've closed the connection regardless let _ = close_tx.send(message).await; + break; } - break; - } - if let Some(id) = json.get("id") { - // messages containing an id need the id modified so we can discern which client to send - // the response to - let tagged_id = tag_id(port, id)?; - json.insert("id".to_owned(), Value::String(tagged_id)); - let message = Message::from_json(&json, &mut buffer); - if tx.send(message).await.is_err() { - break; + Message::Request(mut req) => { + req.id = tag_id(port, &req.id)?.into(); + if tx.send(req.into()).await.is_err() { + break; + } } - } else { - // notification messages without an id don't need any modification and can be forwarded - // to rust-analyzer as is - let message = Message::from_bytes(bytes); - if tx.send(message).await.is_err() { - break; + + Message::ResponseSuccess(_) | Message::ResponseError(_) => { + debug!(?message, "client response"); + } + + Message::Notification(notif) => { + if tx.send(notif.into()).await.is_err() { + break; + } } } } diff --git a/src/instance.rs b/src/instance.rs index a3951d0..9bc66a1 100644 --- a/src/instance.rs +++ b/src/instance.rs @@ -18,7 +18,8 @@ use tracing::{debug, error, info, info_span, instrument, trace, warn, Instrument use crate::async_once_cell::AsyncOnceCell; use crate::config::Config; -use crate::lsp::transport::{LspReader, LspWriter, Message}; +use crate::lsp::jsonrpc::{Message, ResponseSuccess}; +use crate::lsp::transport::{LspReader, LspWriter}; use crate::proto; /// keeps track of the initialize/initialized handshake for an instance @@ -26,7 +27,7 @@ use crate::proto; pub struct InitializeCache { request_sent: AtomicBool, notif_sent: AtomicBool, - pub response: AsyncOnceCell, + pub response: AsyncOnceCell, } impl InitializeCache { @@ -356,7 +357,8 @@ impl RaInstance { // child closes and all the clients disconnect including the sender and this receiver // will not keep blocking (unlike in client input task) while let Some(message) = receiver.recv().await { - if let Err(err) = writer.write_message(message).await { + trace!(?message, "-> server"); + if let Err(err) = writer.write_message(&message).await { match err.kind() { // stdin is closed, no need to log an error ErrorKind::BrokenPipe => {} @@ -426,9 +428,13 @@ impl RaInstance { } } -fn parse_tagged_id(tagged: &str) -> Result<(u16, Value)> { +fn parse_tagged_id(tagged: &Value) -> Result<(u16, Value)> { + let Value::String(tagged) = tagged else { + bail!("tagged id must be a String found `{tagged:?}`"); + }; + let (port, rest) = tagged.split_once(':').context("missing first `:`")?; - let port = u16::from_str_radix(port, 16)?; + let port = u16::from_str(port)?; let (value_type, old_id) = rest.split_once(':').context("missing second `:`")?; let old_id = match value_type { "n" => Value::Number(Number::from_str(old_id)?), @@ -444,81 +450,92 @@ async fn read_server_socket( init_cache: &InitializeCache, ) -> Result<()> { let mut reader = LspReader::new(reader); - let mut buffer = Vec::new(); - - while let Some((mut json, bytes)) = reader.read_message().await? { - trace!(message = serde_json::to_string(&json).unwrap(), "server"); - - if let Some(id) = json.get("id") { - // we tagged the request id so we expect to only receive tagged responses - let tagged_id = match id { - Value::String(string) if string == INIT_REQUEST_ID => { - // this is a response to the InitializeRequest, we need to process it - // separately - debug!("recv InitializeRequest response"); - init_cache - .response - .set(Message::from_bytes(bytes)) - .await - .ok() // throw away the Err(message), we don't need it and it doesn't implement std::error::Error - .context("received multiple InitializeRequest responses from instance")?; - continue; - } - Value::String(string) => string, - _ => { - debug!( - message = serde_json::to_string(&json).unwrap(), - "response to no request", - ); - // FIXME uncommenting this crashes rust-analyzer, presumably because the client - // then sends a confusing response or something? i'm guessing that the response - // id gets tagged with port but rust-analyzer expects to know the id because - // it's actually a response and not a request. i'm not sure if we can handle - // these at all with multiple clients attached - // - // ideally we could send these to all clients, but what if there is a matching - // response from each client? rust-analyzer only expects one (this might - // actually be why it's crashing) - // - // ignoring these might end up being the safest option, they don't seem to - // matter to neovim anyway - // ```rust - // let message = Message::from_bytes(bytes); - // let senders = senders.read().await; - // for sender in senders.values() { - // sender - // .send(message.clone()) - // .await - // .context("forward server notification")?; - // } - // ``` - continue; - } - }; - let (port, old_id) = match parse_tagged_id(tagged_id) { - Ok(ok) => ok, - Err(err) => { - warn!(?err, "invalid tagged id"); - continue; - } - }; + while let Some(message) = reader.read_message().await? { + trace!(?message, "<- server"); + + match message { + Message::ResponseSuccess(res) + if res.id == Value::String(INIT_REQUEST_ID.to_owned()) => + { + // this is a response to the InitializeRequest, we need to process it + // separately + debug!(message = ?Message::from(res.clone()), "recv InitializeRequest response"); + init_cache + .response + .set(res.into()) + .await + .ok() // throw away the Err(message), we don't need it and it doesn't implement std::error::Error + .context("received multiple InitializeRequest responses from instance")?; + } + + Message::ResponseSuccess(mut res) => { + match parse_tagged_id(&res.id) { + Ok((port, id)) => { + res.id = id; + if let Some(sender) = senders.read().await.get(&port) { + // ignore closed channels + let _ignore = sender.send(res.into()).await; + } else { + warn!("no client"); + } + } + Err(err) => { + warn!(?err, "invalid tagged id"); + } + }; + } - json.insert("id".to_owned(), old_id); + Message::ResponseError(mut res) => { + match parse_tagged_id(&res.id) { + Ok((port, id)) => { + res.id = id; + if let Some(sender) = senders.read().await.get(&port) { + // ignore closed channels + let _ignore = sender.send(res.into()).await; + } else { + warn!("no client"); + } + } + Err(err) => { + warn!(?err, "invalid tagged id"); + continue; + } + }; + } - if let Some(sender) = senders.read().await.get(&port) { - let message = Message::from_json(&json, &mut buffer); - // ignore closed channels - let _ignore = sender.send(message).await; - } else { - warn!("no client"); + Message::Request(_) => { + debug!(?message, "server request"); + // FIXME uncommenting this crashes rust-analyzer, presumably because the client + // then sends a confusing response or something? i'm guessing that the response + // id gets tagged with port but rust-analyzer expects to know the id because + // it's actually a response and not a request. i'm not sure if we can handle + // these at all with multiple clients attached + // + // ideally we could send these to all clients, but what if there is a matching + // response from each client? rust-analyzer only expects one (this might + // actually be why it's crashing) + // + // ignoring these might end up being the safest option, they don't seem to + // matter to neovim anyway + // ```rust + // let message = Message::from_bytes(bytes); + // let senders = senders.read().await; + // for sender in senders.values() { + // sender + // .send(message.clone()) + // .await + // .context("forward server notification")?; + // } + // ``` } - } else { - // notification messages without an id are sent to all clients - let message = Message::from_bytes(bytes); - for sender in senders.read().await.values() { - // ignore closed channels - let _ignore = sender.send(message.clone()).await; + + Message::Notification(notif) => { + // notification messages without an id are sent to all clients + for sender in senders.read().await.values() { + // ignore closed channels + let _ignore = sender.send(notif.clone().into()).await; + } } } } diff --git a/src/lsp.rs b/src/lsp.rs index a27f6c0..6a2c5bc 100644 --- a/src/lsp.rs +++ b/src/lsp.rs @@ -28,4 +28,5 @@ //! - Progress notifications - contains a `token` property which could be used to identify the //! client but the specification also says it has nothing to do with the request IDs +pub mod jsonrpc; pub mod transport; diff --git a/src/lsp/jsonrpc.rs b/src/lsp/jsonrpc.rs new file mode 100644 index 0000000..e8421b2 --- /dev/null +++ b/src/lsp/jsonrpc.rs @@ -0,0 +1,179 @@ +//! JSON-RPC 2.0 support +//! +//! Support for the JSON-RPC 2.0 protocol as definied here +//! . With the exception of batching which +//! is handled in [`read_message`](super::transport::LspReader::read_message). + +use std::fmt; + +use serde_derive::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, Clone)] +#[serde(untagged)] +pub enum Message { + Request(Request), + Notification(Notification), + ResponseError(ResponseError), + ResponseSuccess(ResponseSuccess), +} + +#[derive(Serialize, Deserialize, Clone)] +#[serde(deny_unknown_fields)] +pub struct Request { + pub jsonrpc: Version, + pub method: String, + #[serde(default)] + pub params: serde_json::Value, + pub id: serde_json::Value, +} + +#[derive(Serialize, Deserialize, Clone)] +#[serde(deny_unknown_fields)] +pub struct Notification { + pub jsonrpc: Version, + pub method: String, + #[serde(default)] + pub params: serde_json::Value, +} + +#[derive(Serialize, Deserialize, Clone)] +#[serde(deny_unknown_fields)] +pub struct ResponseError { + pub jsonrpc: Version, + pub error: Error, + pub id: serde_json::Value, +} + +#[derive(Serialize, Deserialize, Clone)] +#[serde(deny_unknown_fields)] +pub struct ResponseSuccess { + pub jsonrpc: Version, + #[serde(default)] + pub result: serde_json::Value, + pub id: serde_json::Value, +} + +#[derive(Serialize, Deserialize, Clone)] +#[serde(deny_unknown_fields)] +pub struct Error { + pub code: i64, + pub message: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, +} + +#[derive(Serialize, Deserialize, Clone)] +#[serde(untagged)] +pub enum Params { + ByPosition(Vec), + ByName(serde_json::Map), +} + +/// ZST representation of the `"2.0"` version string +#[derive(Clone, Copy)] +pub struct Version; + +impl serde::ser::Serialize for Version { + fn serialize(&self, serializer: S) -> Result + where + S: serde::ser::Serializer, + { + serializer.serialize_str("2.0") + } +} + +impl<'de> serde::de::Visitor<'de> for Version { + type Value = Version; + + fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + fmt.write_str(r#"string value "2.0""#) + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + match v { + "2.0" => Ok(Version), + _ => Err(E::custom("unsupported JSON-RPC version")), + } + } +} + +impl<'de> serde::de::Deserialize<'de> for Version { + fn deserialize(deserializer: D) -> Result + where + D: serde::de::Deserializer<'de>, + { + deserializer.deserialize_str(Version) + } +} + +impl From for Message { + fn from(value: Request) -> Self { + Message::Request(value) + } +} + +impl From for Message { + fn from(value: Notification) -> Self { + Message::Notification(value) + } +} + +impl From for Message { + fn from(value: ResponseSuccess) -> Self { + Message::ResponseSuccess(value) + } +} + +impl From for Message { + fn from(value: ResponseError) -> Self { + Message::ResponseError(value) + } +} + +impl fmt::Debug for Message { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let json = serde_json::to_string(self).expect("BUG: invalid message"); + f.write_str(&json) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test(input: &str) { + let deserialized = serde_json::from_str::(input).expect("failed to deserialize"); + let serialized = serde_json::to_string(&deserialized).expect("failed to serialize"); + assert_eq!(input, serialized); + } + + #[test] + fn call_with_positional_parameters() { + test(r#"{"jsonrpc":"2.0","method":"subtract","params":[42,23],"id":1}"#) + } + + #[test] + fn call_with_named_parameters() { + test( + r#"{"jsonrpc":"2.0","method":"subtract","params":{"minuend":42,"subtrahend":23},"id":3}"#, + ); + } + + #[test] + fn response() { + test(r#"{"jsonrpc":"2.0","result":19,"id":1}"#) + } + + #[test] + fn notification() { + test(r#"{"jsonrpc":"2.0","method":"update","params":[1,2,3,4,5]}"#) + } + + #[test] + fn error() { + test(r#"{"jsonrpc":"2.0","error":{"code":-32601,"message":"Method not found"},"id":"1"}"#) + } +} diff --git a/src/lsp/transport.rs b/src/lsp/transport.rs index 9168130..85b4191 100644 --- a/src/lsp/transport.rs +++ b/src/lsp/transport.rs @@ -1,15 +1,14 @@ -use std::fmt::{self, Debug}; use std::io::{self, ErrorKind}; use std::str; -use std::sync::Arc; use anyhow::{bail, ensure, Context, Result}; -use serde::Serialize; -use serde_json::{Map, Value}; use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use crate::lsp::jsonrpc::Message; + pub struct LspReader { reader: R, + batch: Vec, buffer: Vec, } @@ -19,6 +18,11 @@ pub struct LspReader { /// The currently recognized headers are `content-type` which is optional and contains a `string` /// (something like a MIME-type) and `content-length` which contains the length of the message body /// after the final `\r\n` of the header. Header names and values are separated by `: `. +/// +/// While we parse the `content-type` header ignore it completely and we don't forward it, +/// expecting both the server and client to assume the default. +/// +/// For mor details see . pub struct Header { pub content_length: usize, pub content_type: Option, @@ -31,11 +35,12 @@ where pub fn new(reader: R) -> Self { LspReader { reader, + batch: Vec::new(), buffer: Vec::with_capacity(1024), } } - async fn read_header(&mut self) -> Result> { + pub async fn read_header(&mut self) -> Result> { let mut content_type = None; let mut content_length = None; @@ -55,8 +60,9 @@ where let header_text = self .buffer .strip_suffix(b"\r\n") - .context("malformed header, missing \\r\\n")?; - let header_text = str::from_utf8(header_text).context("malformed header")?; + .context(r"malformed header, missing `\r\n` terminator")?; + let header_text = str::from_utf8(header_text) + .context("malformed header, ascii encoding is a subset of utf-8")?; if header_text.is_empty() { // headers are separated by an empty line from the body @@ -76,7 +82,7 @@ where ensure!(content_length.is_none(), "repeated header content-length"); content_length = Some(value.parse::().context("content-length header")?); } - _ => bail!("unknown header: {name}"), + _ => bail!("unknown header name: {name:?}"), } } @@ -87,9 +93,19 @@ where })) } - /// reads one LSP message from a reader, deserializes it and leaves the serialized body of the - /// message in `buffer` - pub async fn read_message(&mut self) -> Result, &[u8])>> { + /// Read one message + /// + /// Returns `None` if the reader was closed and it'll never return another + /// message after the first `None`. + /// + /// Batch messages are transparently split into individual messages and + /// delivered in order. + pub async fn read_message(&mut self) -> Result> { + // return pending messages until the last batch is drained + if let Some(pending) = self.batch.pop() { + return Ok(Some(pending)); + } + let header = self.read_header().await.context("parsing header")?; let header = match header { Some(header) => header, @@ -110,25 +126,34 @@ where } let bytes = self.buffer.as_slice(); - let json = serde_json::from_slice(bytes).context("invalid body")?; - Ok(Some((json, bytes))) + let body = str::from_utf8(bytes) + .with_context(|| { + let lossy_utf8 = String::from_utf8_lossy(bytes); + format!("parsing body `{lossy_utf8}`") + }) + .context("parsing LSP message")?; + + // handle batches + if body.starts_with('[') { + self.batch = serde_json::from_str(body) + .with_context(|| format!("parsing body `{body}`")) + .context("parsing LSP message")?; + // we're popping the messages from the end of the vec + self.batch.reverse(); + let message = self.batch.pop().context("received an empty batch")?; + Ok(Some(message)) + } else { + let message = serde_json::from_str(body) + .with_context(|| format!("parsing body `{body}`")) + .context("parsing LSP message")?; + Ok(Some(message)) + } } } pub struct LspWriter { writer: W, -} - -/// LSP messages -#[derive(Clone)] -pub struct Message { - bytes: Arc<[u8]>, -} - -impl Debug for Message { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("Message") - } + buffer: Vec, } impl LspWriter @@ -136,35 +161,21 @@ where W: AsyncWrite + Unpin, { pub fn new(writer: W) -> Self { - LspWriter { writer } + LspWriter { + writer, + buffer: Vec::with_capacity(1024), + } } /// serialize LSP message into a writer, prepending the appropriate content-length header - pub async fn write_message(&mut self, message: Message) -> io::Result<()> { + pub async fn write_message(&mut self, message: &Message) -> io::Result<()> { + self.buffer.clear(); + serde_json::to_writer(&mut self.buffer, message).expect("BUG: invalid message"); + self.writer - .write_all(format!("Content-Length: {}\r\n\r\n", message.bytes.len()).as_bytes()) + .write_all(format!("Content-Length: {}\r\n\r\n", self.buffer.len()).as_bytes()) .await?; - self.writer.write_all(&message.bytes).await?; + self.writer.write_all(&self.buffer).await?; self.writer.flush().await } } - -impl Message { - /// construct a message from a byte buffer, should only contain the message body - no headers - pub fn from_bytes(bytes: &[u8]) -> Self { - Self { - bytes: Arc::from(bytes), - } - } - - pub fn as_bytes(&self) -> &[u8] { - &self.bytes - } - - /// construct a message from a serializable value, like JSON - pub fn from_json(json: &impl Serialize, buffer: &mut Vec) -> Self { - buffer.clear(); - serde_json::to_writer(&mut *buffer, json).expect("invalid json"); - Self::from_bytes(&*buffer) - } -}