From 281fded73565d8206062d737d1937dbc1a554d1c Mon Sep 17 00:00:00 2001 From: Riatre Foo Date: Thu, 11 Apr 2024 05:29:22 +0800 Subject: [PATCH] Add binding for channel_open_request_auth_agent_callback This callback is required for implementing ssh agent forward as unlike X11 forward, there is no other way to establish a forwarding channel. The API design looks slightly convoluted, it's because in libssh: 1. Callback is triggered while handling protocol packets in other libssh call. 2. The callback creates a new channel and prepare for bidirectional forwarding between it and ssh agent. 3. The callback then returns a borrow of the newly created channel for libssh to make reply to the remote side. To do 3 we have to somehow steal a `struct ssh_channel*` from the user-owned channel. We decided to do so by create channel in the binding code, keep a ref and move it to user. Due to locking issues we have to take the Channel back if the user decided to not accept forward request. See SATEFY comment in bridge_channel_open_request_auth_agent_callback for details. --- libssh-rs/src/channel.rs | 3 ++ libssh-rs/src/error.rs | 10 +++++ libssh-rs/src/lib.rs | 95 +++++++++++++++++++++++++++++++++++++--- 3 files changed, 102 insertions(+), 6 deletions(-) diff --git a/libssh-rs/src/channel.rs b/libssh-rs/src/channel.rs index 82313c6..1d358ff 100644 --- a/libssh-rs/src/channel.rs +++ b/libssh-rs/src/channel.rs @@ -46,6 +46,9 @@ impl Drop for Channel { // Prevent any callbacks firing as part the remainder of this drop operation sys::ssh_remove_channel_callbacks(self.chan_inner, self._callbacks.as_mut()); } + if self.chan_inner.is_null() { + return; + } let (_sess, chan) = self.lock_session(); unsafe { sys::ssh_channel_free(chan); diff --git a/libssh-rs/src/error.rs b/libssh-rs/src/error.rs index 23c1d53..ab26fbc 100644 --- a/libssh-rs/src/error.rs +++ b/libssh-rs/src/error.rs @@ -17,6 +17,10 @@ pub enum Error { Sftp(crate::sftp::SftpError), } +#[derive(Error)] +#[error("{0}")] +pub struct RequestAuthAgentError(pub Error, pub crate::Channel); + /// Represents the result of a fallible operation pub type SshResult = Result; @@ -53,3 +57,9 @@ impl From for Error { Error::Fatal(err.to_string()) } } + +impl std::fmt::Debug for RequestAuthAgentError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "RequestAuthAgentError({})", self.0) + } +} diff --git a/libssh-rs/src/lib.rs b/libssh-rs/src/lib.rs index 4352ea5..a988c99 100644 --- a/libssh-rs/src/lib.rs +++ b/libssh-rs/src/lib.rs @@ -17,8 +17,8 @@ use std::os::unix::io::RawFd as RawSocket; #[cfg(windows)] use std::os::windows::io::RawSocket; use std::ptr::null_mut; -use std::sync::Once; use std::sync::{Arc, Mutex, MutexGuard}; +use std::sync::{Once, Weak}; use std::time::Duration; mod channel; @@ -72,9 +72,12 @@ fn initialize() -> SshResult<()> { } pub(crate) struct SessionHolder { + outer: Weak>, sess: sys::ssh_session, callbacks: sys::ssh_callbacks_struct, auth_callback: Option) -> SshResult>>, + channel_open_request_auth_agent_callback: + Option Result<(), RequestAuthAgentError>>>, } unsafe impl Send for SessionHolder {} @@ -197,11 +200,16 @@ impl Session { channel_open_request_x11_function: None, channel_open_request_auth_agent_function: None, }; - let sess = Arc::new(Mutex::new(SessionHolder { - sess, - callbacks, - auth_callback: None, - })); + let sess = Arc::new_cyclic(|outer| { + let outer = outer.clone(); + Mutex::new(SessionHolder { + outer, + sess, + callbacks, + auth_callback: None, + channel_open_request_auth_agent_callback: None, + }) + }); { let mut sess = sess.lock().unwrap(); @@ -274,6 +282,55 @@ impl Session { } } + unsafe extern "C" fn bridge_channel_open_request_auth_agent_callback( + session: sys::ssh_session, + userdata: *mut ::std::os::raw::c_void, + ) -> sys::ssh_channel { + let result = std::panic::catch_unwind(|| -> SshResult { + let sess: &mut SessionHolder = &mut *(userdata as *mut SessionHolder); + assert!( + std::ptr::eq(session, sess.sess), + "invalid callback invocation: session mismatch" + ); + let cb = sess + .channel_open_request_auth_agent_callback + .as_mut() + .unwrap(); + let chan = unsafe { sys::ssh_channel_new(session) }; + if chan.is_null() { + return Err(sess + .last_error() + .unwrap_or_else(|| Error::fatal("ssh_channel_new failed"))); + } + match cb(Channel::new(&sess.outer.upgrade().unwrap(), chan)) { + // SAFETY: We steal a *mut sys::ssh_channel_struct here and let libssh + // temporarily "borrows" it for an unspecified amount of time. + // libssh is guaranteed to finish using it before returning from the outermost + // libssh function call that triggered this callback. As such function call + // always happens with Session locked and dropping Channel needs to lock the + // session first, we can be sure that this *mut sys::ssh_channel_struct will not + // be freed while libssh is still using it. + Ok(_) => Ok(chan), + Err(RequestAuthAgentError(err, mut chan_obj)) => { + unsafe { sys::ssh_channel_free(chan_obj.chan_inner) }; + chan_obj.chan_inner = std::ptr::null_mut(); + Err(err) + } + } + }); + match result { + Err(err) => { + eprintln!("Panic in request auth agent callback: {:?}", err); + std::ptr::null_mut() + } + Ok(Err(err)) => { + eprintln!("Error in request auth agent callback: {:#}", err); + std::ptr::null_mut() + } + Ok(Ok(chan)) => chan, + } + } + /// Sets a callback that is used by libssh when it needs to prompt /// for the passphrase during public key authentication. /// This is NOT used for password or keyboard interactive authentication. @@ -326,6 +383,32 @@ impl Session { sess.callbacks.auth_function = Some(Self::bridge_auth_callback); } + /// Sets a callback that is used by libssh when the remote side requests a new channel + /// for SSH agent forwarding. + /// The callback has the signature: + /// + /// ```no_run + /// use libssh_rs::RequestAuthAgentResult; + /// fn callback(channel: Channel) -> RequestAuthAgentResult { + /// unimplemented!() + /// } + /// ``` + /// + /// The callback should decide whether to allow the agent forward and if so, take ownership of + /// the channel (and further move it elsewhere to handle agent protocol within). Otherwise or + /// in case of an error, the callback should return the channel back as it is not possible to + /// drop it in the callback. + pub fn set_channel_open_request_auth_agent_callback(&self, callback: F) + where + F: FnMut(Channel) -> Result<(), RequestAuthAgentError> + 'static, + { + let mut sess = self.lock_session(); + sess.channel_open_request_auth_agent_callback + .replace(Box::new(callback)); + sess.callbacks.channel_open_request_auth_agent_function = + Some(Self::bridge_channel_open_request_auth_agent_callback); + } + /// Create a new channel. /// Channels are used to handle I/O for commands and forwarded streams. pub fn new_channel(&self) -> SshResult {