diff --git a/libssh-rs/src/channel.rs b/libssh-rs/src/channel.rs index 82313c6..1d358ff 100644 --- a/libssh-rs/src/channel.rs +++ b/libssh-rs/src/channel.rs @@ -46,6 +46,9 @@ impl Drop for Channel { // Prevent any callbacks firing as part the remainder of this drop operation sys::ssh_remove_channel_callbacks(self.chan_inner, self._callbacks.as_mut()); } + if self.chan_inner.is_null() { + return; + } let (_sess, chan) = self.lock_session(); unsafe { sys::ssh_channel_free(chan); diff --git a/libssh-rs/src/error.rs b/libssh-rs/src/error.rs index 23c1d53..ab26fbc 100644 --- a/libssh-rs/src/error.rs +++ b/libssh-rs/src/error.rs @@ -17,6 +17,10 @@ pub enum Error { Sftp(crate::sftp::SftpError), } +#[derive(Error)] +#[error("{0}")] +pub struct RequestAuthAgentError(pub Error, pub crate::Channel); + /// Represents the result of a fallible operation pub type SshResult = Result; @@ -53,3 +57,9 @@ impl From for Error { Error::Fatal(err.to_string()) } } + +impl std::fmt::Debug for RequestAuthAgentError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "RequestAuthAgentError({})", self.0) + } +} diff --git a/libssh-rs/src/lib.rs b/libssh-rs/src/lib.rs index 4352ea5..a988c99 100644 --- a/libssh-rs/src/lib.rs +++ b/libssh-rs/src/lib.rs @@ -17,8 +17,8 @@ use std::os::unix::io::RawFd as RawSocket; #[cfg(windows)] use std::os::windows::io::RawSocket; use std::ptr::null_mut; -use std::sync::Once; use std::sync::{Arc, Mutex, MutexGuard}; +use std::sync::{Once, Weak}; use std::time::Duration; mod channel; @@ -72,9 +72,12 @@ fn initialize() -> SshResult<()> { } pub(crate) struct SessionHolder { + outer: Weak>, sess: sys::ssh_session, callbacks: sys::ssh_callbacks_struct, auth_callback: Option) -> SshResult>>, + channel_open_request_auth_agent_callback: + Option Result<(), RequestAuthAgentError>>>, } unsafe impl Send for SessionHolder {} @@ -197,11 +200,16 @@ impl Session { channel_open_request_x11_function: None, channel_open_request_auth_agent_function: None, }; - let sess = Arc::new(Mutex::new(SessionHolder { - sess, - callbacks, - auth_callback: None, - })); + let sess = Arc::new_cyclic(|outer| { + let outer = outer.clone(); + Mutex::new(SessionHolder { + outer, + sess, + callbacks, + auth_callback: None, + channel_open_request_auth_agent_callback: None, + }) + }); { let mut sess = sess.lock().unwrap(); @@ -274,6 +282,55 @@ impl Session { } } + unsafe extern "C" fn bridge_channel_open_request_auth_agent_callback( + session: sys::ssh_session, + userdata: *mut ::std::os::raw::c_void, + ) -> sys::ssh_channel { + let result = std::panic::catch_unwind(|| -> SshResult { + let sess: &mut SessionHolder = &mut *(userdata as *mut SessionHolder); + assert!( + std::ptr::eq(session, sess.sess), + "invalid callback invocation: session mismatch" + ); + let cb = sess + .channel_open_request_auth_agent_callback + .as_mut() + .unwrap(); + let chan = unsafe { sys::ssh_channel_new(session) }; + if chan.is_null() { + return Err(sess + .last_error() + .unwrap_or_else(|| Error::fatal("ssh_channel_new failed"))); + } + match cb(Channel::new(&sess.outer.upgrade().unwrap(), chan)) { + // SAFETY: We steal a *mut sys::ssh_channel_struct here and let libssh + // temporarily "borrows" it for an unspecified amount of time. + // libssh is guaranteed to finish using it before returning from the outermost + // libssh function call that triggered this callback. As such function call + // always happens with Session locked and dropping Channel needs to lock the + // session first, we can be sure that this *mut sys::ssh_channel_struct will not + // be freed while libssh is still using it. + Ok(_) => Ok(chan), + Err(RequestAuthAgentError(err, mut chan_obj)) => { + unsafe { sys::ssh_channel_free(chan_obj.chan_inner) }; + chan_obj.chan_inner = std::ptr::null_mut(); + Err(err) + } + } + }); + match result { + Err(err) => { + eprintln!("Panic in request auth agent callback: {:?}", err); + std::ptr::null_mut() + } + Ok(Err(err)) => { + eprintln!("Error in request auth agent callback: {:#}", err); + std::ptr::null_mut() + } + Ok(Ok(chan)) => chan, + } + } + /// Sets a callback that is used by libssh when it needs to prompt /// for the passphrase during public key authentication. /// This is NOT used for password or keyboard interactive authentication. @@ -326,6 +383,32 @@ impl Session { sess.callbacks.auth_function = Some(Self::bridge_auth_callback); } + /// Sets a callback that is used by libssh when the remote side requests a new channel + /// for SSH agent forwarding. + /// The callback has the signature: + /// + /// ```no_run + /// use libssh_rs::RequestAuthAgentResult; + /// fn callback(channel: Channel) -> RequestAuthAgentResult { + /// unimplemented!() + /// } + /// ``` + /// + /// The callback should decide whether to allow the agent forward and if so, take ownership of + /// the channel (and further move it elsewhere to handle agent protocol within). Otherwise or + /// in case of an error, the callback should return the channel back as it is not possible to + /// drop it in the callback. + pub fn set_channel_open_request_auth_agent_callback(&self, callback: F) + where + F: FnMut(Channel) -> Result<(), RequestAuthAgentError> + 'static, + { + let mut sess = self.lock_session(); + sess.channel_open_request_auth_agent_callback + .replace(Box::new(callback)); + sess.callbacks.channel_open_request_auth_agent_function = + Some(Self::bridge_channel_open_request_auth_agent_callback); + } + /// Create a new channel. /// Channels are used to handle I/O for commands and forwarded streams. pub fn new_channel(&self) -> SshResult {