From 732bcceae239f7a78ffc37562d59990db40a0345 Mon Sep 17 00:00:00 2001 From: "U. Lasiotus" Date: Sun, 8 Dec 2024 11:18:31 -0800 Subject: [PATCH] Start porting tokio/mio. Things look promising: a simple test (sys/mio-test) works. Still a lot of work to do. --- Makefile.toml | 20 + src/imager/src/main.rs | 4 +- src/sys/Cargo.lock | 51 +- src/sys/Cargo.toml | 1 + src/sys/lib/moto-ipc/src/io_channel.rs | 21 + src/sys/lib/moto-rt/src/lib.rs | 10 + src/sys/lib/moto-rt/src/net.rs | 33 +- src/sys/lib/moto-rt/src/poll.rs | 101 ++++ src/sys/lib/rt.vdso/src/main.rs | 24 + src/sys/lib/rt.vdso/src/posix.rs | 15 + src/sys/lib/rt.vdso/src/rt_net.rs | 651 +++++++++++++++++++++---- src/sys/lib/rt.vdso/src/rt_poll.rs | 72 +++ src/sys/lib/rt.vdso/src/runtime.rs | 253 ++++++++++ src/sys/tests/mio-test/Cargo.toml | 7 + src/sys/tests/mio-test/src/main.rs | 5 + src/sys/tests/mio-test/src/simple.rs | 234 +++++++++ 16 files changed, 1384 insertions(+), 118 deletions(-) create mode 100644 src/sys/lib/moto-rt/src/poll.rs create mode 100644 src/sys/lib/rt.vdso/src/rt_poll.rs create mode 100644 src/sys/lib/rt.vdso/src/runtime.rs create mode 100644 src/sys/tests/mio-test/Cargo.toml create mode 100644 src/sys/tests/mio-test/src/main.rs create mode 100644 src/sys/tests/mio-test/src/simple.rs diff --git a/Makefile.toml b/Makefile.toml index bb60741..daa9e64 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -56,6 +56,7 @@ dependencies = [ "mdbg_debug", "rnetbench_debug", "systest_debug", + "mio_test_debug", "make_img_debug", ] @@ -78,6 +79,7 @@ dependencies = [ "mdbg_release", "rnetbench_release", "systest_release", + "mio_test_release", "make_img_release", ] @@ -99,6 +101,7 @@ dependencies = [ "mdbg_debug", "rnetbench_debug", "systest_debug", + "mio_test_debug", "make_img_debug", ] @@ -120,6 +123,7 @@ dependencies = [ "mdbg_release", "rnetbench_release", "systest_release", + "mio_test_release", "make_img_release", ] @@ -294,6 +298,22 @@ cargo +dev-x86_64-unknown-moturus clippy --release --target x86_64-unknown-motur cp "${CARGO_TARGET_DIR}/x86_64-unknown-moturus/release/systest" "${MOTO_BIN}/systest" ''' +[tasks.mio_test_debug] +cwd = "./src/sys/tests/mio-test" +script = ''' +cargo +dev-x86_64-unknown-moturus build --target x86_64-unknown-moturus +cargo +dev-x86_64-unknown-moturus clippy --target x86_64-unknown-moturus +strip -o "${MOTO_BIN}/mio-test" "${CARGO_TARGET_DIR}/x86_64-unknown-moturus/debug/mio-test" +''' + +[tasks.mio_test_release] +cwd = "./src/sys/tests/mio-test" +script = ''' +cargo +dev-x86_64-unknown-moturus build --release --target x86_64-unknown-moturus +cargo +dev-x86_64-unknown-moturus clippy --release --target x86_64-unknown-moturus +cp "${CARGO_TARGET_DIR}/x86_64-unknown-moturus/release/mio-test" "${MOTO_BIN}/mio-test" +''' + [tasks.rush_debug] cwd = "./src/bin/rush" script = ''' diff --git a/src/imager/src/main.rs b/src/imager/src/main.rs index fc27c93..84b6e4a 100644 --- a/src/imager/src/main.rs +++ b/src/imager/src/main.rs @@ -20,7 +20,7 @@ use std::io::{self, Seek, SeekFrom}; const SECTOR_SIZE: u32 = 512; // For the "full" image. -static BIN_FULL: [&str; 10] = [ +static BIN_FULL: [&str; 11] = [ "bin/httpd", "bin/kibim", "bin/rush", @@ -31,7 +31,7 @@ static BIN_FULL: [&str; 10] = [ "sys/sys-tty", "sys/sysbox", "sys/systest", -// "sys/mio-test", + "sys/mio-test", ]; // For the "web" image. 
diff --git a/src/sys/Cargo.lock b/src/sys/Cargo.lock index ea58624..7ee860f 100644 --- a/src/sys/Cargo.lock +++ b/src/sys/Cargo.lock @@ -472,7 +472,7 @@ dependencies = [ "frusa", "intrusive-collections", "log", - "moto-rt", + "moto-rt 0.1.0", "moto-sys", "x86", "x86_64", @@ -518,7 +518,7 @@ name = "mdbg" version = "0.1.0" dependencies = [ "clap", - "moto-rt", + "moto-rt 0.1.0", "moto-sys", ] @@ -528,12 +528,31 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +[[package]] +name = "mio" +version = "1.0.2" +source = "git+https://github.com/moturus/mio.git?branch=motor-os_20241121#9fc61237be4ca3f61fc2aec7793d990bc2eaee05" +dependencies = [ + "libc", + "log", + "moto-rt 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", + "wasi", + "windows-sys", +] + +[[package]] +name = "mio-test" +version = "0.1.0" +dependencies = [ + "mio", +] + [[package]] name = "moto-ipc" version = "0.2.5" dependencies = [ "compiler_builtins", - "moto-rt", + "moto-rt 0.1.0", "moto-sys", "rustc-std-workspace-alloc", "rustc-std-workspace-core", @@ -545,7 +564,7 @@ version = "0.1.0" dependencies = [ "log", "moto-ipc", - "moto-rt", + "moto-rt 0.1.0", "moto-sys", "spin 0.5.2", ] @@ -555,7 +574,7 @@ name = "moto-mpmc" version = "0.1.0" dependencies = [ "crossbeam-utils", - "moto-rt", + "moto-rt 0.1.0", "moto-sys", ] @@ -568,12 +587,18 @@ dependencies = [ "rustc-std-workspace-core", ] +[[package]] +name = "moto-rt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64a5ea411b361b70e6f124111875f03b13faa1b6f1b2e0af93e68c51ca4120c7" + [[package]] name = "moto-sys" version = "0.2.4" dependencies = [ "compiler_builtins", - "moto-rt", + "moto-rt 0.1.0", "rustc-std-workspace-alloc", "rustc-std-workspace-core", ] @@ -583,7 +608,7 @@ name = "moto-sys-io" version = "0.2.4" dependencies = [ "moto-ipc", - "moto-rt", + "moto-rt 0.1.0", "moto-sys", "smoltcp", ] @@ -734,7 +759,7 @@ dependencies = [ "elfloader", "frusa", "moto-ipc", - "moto-rt", + "moto-rt 0.1.0", "moto-sys", "moto-sys-io", ] @@ -903,7 +928,7 @@ dependencies = [ "ipnetwork", "log", "moto-ipc", - "moto-rt", + "moto-rt 0.1.0", "moto-sys", "moto-sys-io", "moto-virtio", @@ -920,7 +945,7 @@ version = "0.1.0" dependencies = [ "moto-ipc", "moto-log", - "moto-rt", + "moto-rt 0.1.0", "moto-sys", ] @@ -931,7 +956,7 @@ dependencies = [ "log", "moto-ipc", "moto-log", - "moto-rt", + "moto-rt 0.1.0", "moto-sys", "x86_64", ] @@ -941,7 +966,7 @@ name = "sysbox" version = "0.1.0" dependencies = [ "moto-ipc", - "moto-rt", + "moto-rt 0.1.0", "moto-sys", "moto-sys-io", "time", @@ -955,7 +980,7 @@ dependencies = [ "futures", "moto-ipc", "moto-mpmc", - "moto-rt", + "moto-rt 0.1.0", "moto-sys", ] diff --git a/src/sys/Cargo.toml b/src/sys/Cargo.toml index 5ff96dd..b7b1592 100644 --- a/src/sys/Cargo.toml +++ b/src/sys/Cargo.toml @@ -26,6 +26,7 @@ members = [ "tools/sysbox", # tests + "tests/mio-test", "tests/systest", ] resolver = "2" diff --git a/src/sys/lib/moto-ipc/src/io_channel.rs b/src/sys/lib/moto-ipc/src/io_channel.rs index 89eae15..2f47895 100644 --- a/src/sys/lib/moto-ipc/src/io_channel.rs +++ b/src/sys/lib/moto-ipc/src/io_channel.rs @@ -233,6 +233,17 @@ impl RawChannel { } } + fn may_alloc_page(&self, subchannel: SubChannel) -> bool { + let (bitmap_ref, subchannel_mask) = match subchannel { + SubChannel::Client(mask) => (&self.client_pages_in_use, mask), + SubChannel::Server(mask) => (&self.server_pages_in_use, mask), 
+ }; + + let bitmap = bitmap_ref.load(Ordering::Relaxed); + let ones = (bitmap | !subchannel_mask).trailing_ones(); + ones != 64 + } + fn alloc_page(&self, subchannel: SubChannel) -> Result { let (bitmap_ref, subchannel_mask) = match subchannel { SubChannel::Client(mask) => (&self.client_pages_in_use, mask), @@ -535,6 +546,11 @@ impl ClientConnection { } } + pub fn may_alloc_page(&self, subchannel_mask: u64) -> bool { + self.raw_channel() + .may_alloc_page(SubChannel::Client(subchannel_mask)) + } + pub fn get_page(&self, page_idx: u16) -> Result { if page_idx & !IoPage::SERVER_FLAG > (CHANNEL_PAGE_COUNT as u16) { Err(moto_rt::E_INVALID_ARGUMENT) @@ -744,6 +760,11 @@ impl ServerConnection { } } + pub fn may_alloc_page(&self, subchannel_mask: u64) -> bool { + self.raw_channel() + .may_alloc_page(SubChannel::Server(subchannel_mask)) + } + pub fn get_page(&self, page_idx: u16) -> Result { if page_idx & !IoPage::SERVER_FLAG > (CHANNEL_PAGE_COUNT as u16) { Err(moto_rt::E_INVALID_ARGUMENT) diff --git a/src/sys/lib/moto-rt/src/lib.rs b/src/sys/lib/moto-rt/src/lib.rs index e431e38..6c80821 100644 --- a/src/sys/lib/moto-rt/src/lib.rs +++ b/src/sys/lib/moto-rt/src/lib.rs @@ -78,6 +78,8 @@ pub mod net; #[allow(nonstandard_style)] pub mod netc; +#[cfg(not(feature = "base"))] +pub mod poll; #[cfg(not(feature = "base"))] pub mod process; #[cfg(not(feature = "base"))] @@ -199,6 +201,7 @@ pub struct RtVdsoVtableV1 { // Networking. pub dns_lookup: AtomicU64, pub net_bind: AtomicU64, + pub net_listen: AtomicU64, pub net_accept: AtomicU64, pub net_tcp_connect: AtomicU64, pub net_udp_connect: AtomicU64, @@ -210,6 +213,13 @@ pub struct RtVdsoVtableV1 { pub net_udp_recv_from: AtomicU64, pub net_udp_peek_from: AtomicU64, pub net_udp_send_to: AtomicU64, + + // Polling. + pub poll_new: AtomicU64, + pub poll_add: AtomicU64, + pub poll_set: AtomicU64, + pub poll_del: AtomicU64, + pub poll_wait: AtomicU64, } #[cfg(not(feature = "base"))] diff --git a/src/sys/lib/moto-rt/src/net.rs b/src/sys/lib/moto-rt/src/net.rs index 908a390..130a2ee 100644 --- a/src/sys/lib/moto-rt/src/net.rs +++ b/src/sys/lib/moto-rt/src/net.rs @@ -21,6 +21,7 @@ pub const SO_SNDTIMEO: u64 = 2; pub const SO_SHUTDOWN: u64 = 3; pub const SO_NODELAY: u64 = 4; pub const SO_TTL: u64 = 5; +pub const SO_NONBLOCKING: u64 = 6; fn setsockopt(rt_fd: RtFd, opt: u64, ptr: usize, len: usize) -> Result<(), ErrorCode> { let vdso_setsockopt: extern "C" fn(RtFd, u64, usize, usize) -> ErrorCode = unsafe { @@ -52,6 +53,16 @@ pub fn bind(proto: u8, addr: &netc::sockaddr) -> Result { to_result!(vdso_bind(proto, addr)) } +pub fn listen(rt_fd: RtFd, max_backlog: u32) -> Result<(), ErrorCode> { + let vdso_listen: extern "C" fn(RtFd, u32) -> ErrorCode = unsafe { + core::mem::transmute( + RtVdsoVtableV1::get().net_listen.load(Ordering::Relaxed) as usize as *const (), + ) + }; + + ok_or_error(vdso_listen(rt_fd, max_backlog)) +} + pub fn accept(rt_fd: RtFd) -> Result<(RtFd, netc::sockaddr), ErrorCode> { let vdso_accept: extern "C" fn(RtFd, *mut netc::sockaddr) -> RtFd = unsafe { core::mem::transmute( @@ -68,6 +79,9 @@ pub fn accept(rt_fd: RtFd) -> Result<(RtFd, netc::sockaddr), ErrorCode> { Ok((res, addr)) } +/// Create a TCP stream by connecting to a remote addr. +/// +/// If timeout.is_zero(), the connect is nonblocking. 
pub fn tcp_connect(addr: &netc::sockaddr, timeout: Duration) -> Result { let vdso_tcp_connect: extern "C" fn(*const netc::sockaddr, u64) -> RtFd = unsafe { core::mem::transmute( @@ -100,9 +114,7 @@ pub fn socket_addr(_rt_fd: RtFd) -> Result { pub fn peer_addr(rt_fd: RtFd) -> Result { let vdso_peer_addr: extern "C" fn(RtFd, *mut netc::sockaddr) -> ErrorCode = unsafe { core::mem::transmute( - RtVdsoVtableV1::get() - .net_peer_addr - .load(Ordering::Relaxed) as usize as *const (), + RtVdsoVtableV1::get().net_peer_addr.load(Ordering::Relaxed) as usize as *const (), ) }; @@ -135,11 +147,12 @@ pub fn only_v6(_rt_fd: RtFd) -> Result { pub fn take_error(_rt_fd: RtFd) -> Result { // getsockopt - Err(crate::E_NOT_IMPLEMENTED) + todo!() } -pub fn set_nonblocking(_rt_fd: RtFd, _nonblocking: bool) -> Result<(), ErrorCode> { - todo!() +pub fn set_nonblocking(rt_fd: RtFd, nonblocking: bool) -> Result<(), ErrorCode> { + let nonblocking: u8 = if nonblocking { 1 } else { 0 }; + setsockopt(rt_fd, SO_NONBLOCKING, &nonblocking as *const _ as usize, 1) } pub fn peek(_rt_fd: RtFd, _buf: &mut [u8]) -> Result { @@ -153,6 +166,7 @@ pub fn set_read_timeout(rt_fd: RtFd, timeout: Option) -> Result<(), Er }; if timeout == 0 { + // See TcpStream::set_read_timeout() doc in Rust stdlib. return Err(crate::E_INVALID_ARGUMENT); } @@ -187,6 +201,11 @@ pub fn set_write_timeout(rt_fd: RtFd, timeout: Option) -> Result<(), E None => u64::MAX, }; + if timeout == 0 { + // See TcpStream::set_write_timeout() doc in Rust stdlib. + return Err(crate::E_INVALID_ARGUMENT); + } + setsockopt( rt_fd, SO_SNDTIMEO, @@ -229,7 +248,7 @@ pub fn linger(_rt_fd: RtFd) -> Result, ErrorCode> { } pub fn set_nodelay(rt_fd: RtFd, nodelay: bool) -> Result<(), ErrorCode> { - let nodelay = if nodelay { 1 } else { 0 }; + let nodelay: u8 = if nodelay { 1 } else { 0 }; setsockopt(rt_fd, SO_NODELAY, &nodelay as *const _ as usize, 1) } diff --git a/src/sys/lib/moto-rt/src/poll.rs b/src/sys/lib/moto-rt/src/poll.rs new file mode 100644 index 0000000..f429df9 --- /dev/null +++ b/src/sys/lib/moto-rt/src/poll.rs @@ -0,0 +1,101 @@ +use crate::ok_or_error; +use crate::to_result; +use crate::ErrorCode; +use crate::RtFd; +use crate::RtVdsoVtableV1; +use core::sync::atomic::Ordering; + +#[cfg(not(feature = "rustc-dep-of-std"))] +extern crate alloc; + +pub const POLL_READABLE: u64 = 1; +pub const POLL_WRITABLE: u64 = 2; +pub const POLL_READ_CLOSED: u64 = 4; +pub const POLL_WRITE_CLOSED: u64 = 8; +pub const POLL_ERROR: u64 = 16; + +pub type Token = u64; +pub type Interests = u64; +pub type EventBits = u64; + +#[derive(Clone, Copy, Debug)] +pub struct Event { + pub token: Token, + pub events: EventBits, +} + +pub fn new() -> Result { + let vdso_poll_new: extern "C" fn() -> RtFd = unsafe { + core::mem::transmute( + RtVdsoVtableV1::get().poll_new.load(Ordering::Relaxed) as usize as *const (), + ) + }; + + to_result!(vdso_poll_new()) +} + +pub fn add( + poll_fd: RtFd, + source_fd: RtFd, + token: Token, + interests: Interests, +) -> Result<(), ErrorCode> { + let vdso_poll_add: extern "C" fn(RtFd, RtFd, u64, u64) -> ErrorCode = unsafe { + core::mem::transmute( + RtVdsoVtableV1::get().poll_add.load(Ordering::Relaxed) as usize as *const (), + ) + }; + + ok_or_error(vdso_poll_add(poll_fd, source_fd, token, interests)) +} + +pub fn set( + poll_fd: RtFd, + source_fd: RtFd, + token: Token, + interests: Interests, +) -> Result<(), ErrorCode> { + let vdso_poll_set: extern "C" fn(RtFd, RtFd, u64, u64) -> ErrorCode = unsafe { + core::mem::transmute( + 
RtVdsoVtableV1::get().poll_set.load(Ordering::Relaxed) as usize as *const (), + ) + }; + + ok_or_error(vdso_poll_set(poll_fd, source_fd, token, interests)) +} + +pub fn del(poll_fd: RtFd, source_fd: RtFd) -> Result<(), ErrorCode> { + let vdso_poll_del: extern "C" fn(RtFd, RtFd) -> ErrorCode = unsafe { + core::mem::transmute( + RtVdsoVtableV1::get().poll_del.load(Ordering::Relaxed) as usize as *const (), + ) + }; + + ok_or_error(vdso_poll_del(poll_fd, source_fd)) +} + +pub fn wait( + poll_fd: RtFd, + events: *mut Event, + events_num: usize, + timeout: Option, +) -> Result { + let vdso_poll_wait: extern "C" fn(RtFd, u64, *mut Event, usize) -> i32 = unsafe { + core::mem::transmute( + RtVdsoVtableV1::get().poll_wait.load(Ordering::Relaxed) as usize as *const (), + ) + }; + + let timeout = if let Some(timo) = timeout { + timo.as_u64() + } else { + u64::MAX + }; + + let res = vdso_poll_wait(poll_fd, timeout, events, events_num); + if res < 0 { + return Err((-res) as ErrorCode); + } + + Ok(res as usize) +} diff --git a/src/sys/lib/rt.vdso/src/main.rs b/src/sys/lib/rt.vdso/src/main.rs index d344685..d926666 100644 --- a/src/sys/lib/rt.vdso/src/main.rs +++ b/src/sys/lib/rt.vdso/src/main.rs @@ -10,10 +10,12 @@ mod rt_alloc; mod rt_fs; mod rt_futex; mod rt_net; +mod rt_poll; mod rt_process; mod rt_thread; mod rt_time; mod rt_tls; +mod runtime; mod stdio; #[macro_use] @@ -300,6 +302,10 @@ pub extern "C" fn _rt_entry(version: u64) { vtable .net_bind .store(rt_net::bind as *const () as usize as u64, Ordering::Relaxed); + vtable.net_listen.store( + rt_net::listen as *const () as usize as u64, + Ordering::Relaxed, + ); vtable.net_accept.store( rt_net::accept as *const () as usize as u64, Ordering::Relaxed, @@ -325,6 +331,24 @@ pub extern "C" fn _rt_entry(version: u64) { Ordering::Relaxed, ); + // Poll. + vtable + .poll_new + .store(rt_poll::new as *const () as usize as u64, Ordering::Relaxed); + vtable + .poll_add + .store(rt_poll::add as *const () as usize as u64, Ordering::Relaxed); + vtable + .poll_set + .store(rt_poll::set as *const () as usize as u64, Ordering::Relaxed); + vtable + .poll_del + .store(rt_poll::del as *const () as usize as u64, Ordering::Relaxed); + vtable.poll_wait.store( + rt_poll::wait as *const () as usize as u64, + Ordering::Relaxed, + ); + // The final fence. 
core::sync::atomic::fence(core::sync::atomic::Ordering::Release); diff --git a/src/sys/lib/rt.vdso/src/posix.rs b/src/sys/lib/rt.vdso/src/posix.rs index 50ded95..3d24077 100644 --- a/src/sys/lib/rt.vdso/src/posix.rs +++ b/src/sys/lib/rt.vdso/src/posix.rs @@ -12,9 +12,12 @@ use crate::stdio::Stdio; use alloc::collections::VecDeque; use alloc::sync::Arc; use alloc::vec::Vec; +use moto_rt::poll::Interests; +use moto_rt::poll::Token; use moto_rt::ErrorCode; use moto_rt::RtFd; use moto_rt::E_BAD_HANDLE; +use moto_rt::E_INVALID_ARGUMENT; use moto_rt::E_OK; pub trait PosixFile: Any + Send + Sync { @@ -30,6 +33,18 @@ pub trait PosixFile: Any + Send + Sync { fn close(&self) -> Result<(), ErrorCode> { Err(E_BAD_HANDLE) } + fn poll_add(&self, poll_fd: RtFd, token: Token, interests: Interests) -> Result<(), ErrorCode> { + todo!() + // Err(E_INVALID_ARGUMENT) + } + fn poll_set(&self, poll_fd: RtFd, token: Token, interests: Interests) -> Result<(), ErrorCode> { + todo!() + // Err(E_INVALID_ARGUMENT) + } + fn poll_del(&self, poll_fd: RtFd) -> Result<(), ErrorCode> { + todo!() + // Err(E_INVALID_ARGUMENT) + } } pub extern "C" fn posix_read(rt_fd: i32, buf: *mut u8, buf_sz: usize) -> i64 { diff --git a/src/sys/lib/rt.vdso/src/rt_net.rs b/src/sys/lib/rt.vdso/src/rt_net.rs index 47478c2..2e7eded 100644 --- a/src/sys/lib/rt.vdso/src/rt_net.rs +++ b/src/sys/lib/rt.vdso/src/rt_net.rs @@ -1,10 +1,14 @@ use crate::posix; use crate::posix::PosixFile; +use crate::runtime::ResponseHandler; +use crate::runtime::WaitObject; use core::any::Any; use moto_rt::error::*; use moto_rt::moto_log; use moto_rt::mutex::Mutex; use moto_rt::netc; +use moto_rt::poll::Interests; +use moto_rt::poll::Token; use moto_rt::RtFd; use moto_sys_io::api_net; @@ -54,6 +58,20 @@ pub extern "C" fn bind(proto: u8, addr: *const netc::sockaddr) -> RtFd { posix::push_file(listener) } +pub extern "C" fn listen(rt_fd: RtFd, max_backlog: u32) -> ErrorCode { + let Some(posix_file) = posix::get_file(rt_fd) else { + return E_BAD_HANDLE; + }; + let Some(listener) = (posix_file.as_ref() as &dyn Any).downcast_ref::() else { + return E_BAD_HANDLE; + }; + + match listener.listen(max_backlog) { + Ok(()) => E_OK, + Err(err) => err, + } +} + pub extern "C" fn accept(rt_fd: RtFd, peer_addr: *mut netc::sockaddr) -> RtFd { let Some(posix_file) = posix::get_file(rt_fd) else { return -(E_BAD_HANDLE as RtFd); @@ -91,41 +109,15 @@ pub unsafe extern "C" fn setsockopt(rt_fd: RtFd, option: u64, ptr: usize, len: u let Some(posix_file) = posix::get_file(rt_fd) else { return E_BAD_HANDLE; }; - let Some(tcp_stream) = (posix_file.as_ref() as &dyn Any).downcast_ref::() else { - return E_BAD_HANDLE; - }; - match option { - moto_rt::net::SO_RCVTIMEO => { - assert_eq!(len, core::mem::size_of::()); - let timeout = *(ptr as *const u64); - tcp_stream.set_read_timeout(timeout); - moto_rt::E_OK - } - moto_rt::net::SO_SNDTIMEO => { - assert_eq!(len, core::mem::size_of::()); - let timeout = *(ptr as *const u64); - tcp_stream.set_write_timeout(timeout); - moto_rt::E_OK - } - moto_rt::net::SO_SHUTDOWN => { - assert_eq!(len, 1); - let val = *(ptr as *const u8); - let read = val & moto_rt::net::SHUTDOWN_READ != 0; - let write = val & moto_rt::net::SHUTDOWN_WRITE != 0; - tcp_stream.shutdown(read, write) - } - moto_rt::net::SO_NODELAY => { - assert_eq!(len, 1); - let nodelay = *(ptr as *const u8); - tcp_stream.set_nodelay(nodelay) - } - moto_rt::net::SO_TTL => { - assert_eq!(len, 4); - let ttl = *(ptr as *const u32); - tcp_stream.set_ttl(ttl) - } - _ => panic!("unrecognized option {option}"), 
+ if let Some(tcp_stream) = (posix_file.as_ref() as &dyn Any).downcast_ref::() { + tcp_stream.setsockopt(option, ptr, len) + } else if let Some(tcp_listener) = + (posix_file.as_ref() as &dyn Any).downcast_ref::() + { + tcp_listener.setsockopt(option, ptr, len) + } else { + E_BAD_HANDLE } } @@ -253,7 +245,7 @@ impl NetRuntime { } } - fn release_channel(&mut self, channel: Arc) { + fn release_channel(&mut self, channel: &NetChannel) { channel.reservations.fetch_sub(1, Ordering::Relaxed); if let Some(channel) = self.full_channels.remove(&channel.id()) { // TODO: maybe clear empty channels? @@ -284,8 +276,12 @@ struct NetChannel { // Threads waiting to add their msg to send_queue. send_waiters: Mutex>, + // Streams waiting for "can write" notification. + write_waiters: Mutex>>, + // Threads waiting for specific resp_id: map resp_id => (thread handle, resp). - resp_waiters: Mutex)>>, + legacy_resp_waiters: Mutex)>>, + response_handlers: Mutex>>, io_thread_join_handle: AtomicU64, io_thread_wake_handle: AtomicU64, @@ -316,6 +312,20 @@ impl NetChannel { maybe_msg = msg; should_sleep &= sleep; + if !self.send_queue.is_full() { + loop { + // Cannot use `while let Some(...) = ` because the lock + // won't be released... + let maybe_waiter = self.write_waiters.lock().pop_front(); + let Some(waiter) = maybe_waiter else { + break; + }; + if let Some(waiter) = waiter.upgrade() { + waiter.maybe_can_write(); + } + } + } + if should_sleep { assert!(maybe_msg.is_none()); @@ -366,7 +376,7 @@ impl NetChannel { // msg.command // ); - let wait_handle: Option = if msg.id == 0 { + let wait_handle: SysHandle = if msg.id == 0 { // This is an incoming packet, or similar, without a dedicated waiter. let stream_handle = msg.handle; let stream = { @@ -382,25 +392,32 @@ impl NetChannel { // and we will lose the wakeup. Sad story, don't ask... 
let mut rx_lock = stream.rx_waiter.lock(); stream.process_incoming_msg(msg); - rx_lock.take() + rx_lock.take().unwrap_or(SysHandle::NONE) } else { self.on_orphan_message(msg); - None + SysHandle::NONE } } else { - let mut resp_waiters = self.resp_waiters.lock(); + let mut resp_waiters = self.legacy_resp_waiters.lock(); if let Some((handle, resp)) = resp_waiters.get_mut(&msg.id) { *resp = Some(msg); - Some(*handle) + *handle } else { - panic!("unexpected msg"); + core::mem::drop(resp_waiters); + let Some(resp_handler) = self.response_handlers.lock().remove(&msg.id) else { + panic!("unexpected msg"); + }; + if let Some(handler) = resp_handler.upgrade() { + handler.on_response(msg); + } + SysHandle::NONE } }; - if let Some(wait_handle) = wait_handle { - if wait_handle.as_u64() != moto_sys::UserThreadControlBlock::get().self_handle { - let _ = moto_sys::SysCpu::wake(wait_handle); - } + if wait_handle != SysHandle::NONE + && wait_handle.as_u64() != moto_sys::UserThreadControlBlock::get().self_handle + { + let _ = moto_sys::SysCpu::wake(wait_handle); } if received_messages > 32 { @@ -489,7 +506,9 @@ impl NetChannel { next_msg_id: CachePadded::new(AtomicU64::new(1)), send_queue: crossbeam_queue::ArrayQueue::new(io_channel::CHANNEL_PAGE_COUNT), send_waiters: Mutex::new(VecDeque::new()), - resp_waiters: Mutex::new(BTreeMap::new()), + write_waiters: Mutex::new(VecDeque::new()), + legacy_resp_waiters: Mutex::new(BTreeMap::new()), + response_handlers: Mutex::new(BTreeMap::new()), io_thread_join_handle: AtomicU64::new(SysHandle::NONE.into()), io_thread_wake_handle: AtomicU64::new(SysHandle::NONE.into()), io_thread_running: CachePadded::new(AtomicBool::new(false)), @@ -532,7 +551,7 @@ impl NetChannel { assert!(self.subchannels_in_use[idx].swap(false, Ordering::AcqRel)); } - fn tcp_stream_created(self: &Arc, stream: &Arc) { + fn tcp_stream_created(&self, stream: &Arc) { assert!(self .tcp_streams .lock() @@ -540,21 +559,21 @@ impl NetChannel { .is_none()); } - fn tcp_stream_dropped(self: &Arc, handle: u64, subchannel_idx: usize) { + fn tcp_stream_dropped(&self, handle: u64, subchannel_idx: usize) { let stream = self.tcp_streams.lock().remove(&handle).unwrap(); assert_eq!(0, stream.strong_count()); self.release_subchannel(subchannel_idx); - NET.lock().release_channel(self.clone()); + NET.lock().release_channel(self); } - fn tcp_listener_created(self: &Arc, listener: &Arc) { + fn tcp_listener_created(&self, listener: &Arc) { self.tcp_listeners .lock() .insert(listener.handle, Arc::downgrade(listener)); } - fn tcp_listener_dropped(self: &Arc, handle: u64) { + fn tcp_listener_dropped(&self, handle: u64) { assert_eq!( 0, self.tcp_listeners @@ -564,10 +583,10 @@ impl NetChannel { .strong_count() ); - NET.lock().release_channel(self.clone()); + NET.lock().release_channel(self); } - fn send_msg(self: &Arc, msg: io_channel::Msg) { + fn send_msg(&self, msg: io_channel::Msg) { loop { if self.send_queue.push(msg).is_ok() { self.maybe_wake_io_thread(); @@ -582,10 +601,10 @@ impl NetChannel { } } - fn wait_for_resp(self: &Arc, resp_id: u64) -> io_channel::Msg { + fn wait_for_resp(&self, resp_id: u64) -> io_channel::Msg { loop { { - let mut recv_waiters = self.resp_waiters.lock(); + let mut recv_waiters = self.legacy_resp_waiters.lock(); if let Some(resp) = recv_waiters.get_mut(&resp_id).unwrap().1.take() { recv_waiters.remove(&resp_id); return resp; @@ -598,12 +617,12 @@ impl NetChannel { } // Send message and wait for response. 
- fn send_receive(self: &Arc, mut req: io_channel::Msg) -> io_channel::Msg { + fn send_receive(&self, mut req: io_channel::Msg) -> io_channel::Msg { let req_id = self.next_msg_id.fetch_add(1, Ordering::Relaxed); // Add to waiters before sending the message, otherwise the response may // arive too quickly and the receiving code will panic due to a missing waiter. - self.resp_waiters.lock().insert( + self.legacy_resp_waiters.lock().insert( req_id, ( moto_sys::UserThreadControlBlock::get().self_handle.into(), @@ -616,6 +635,42 @@ impl NetChannel { self.wait_for_resp(req_id) } + fn new_req_id(&self) -> u64 { + self.next_msg_id.fetch_add(1, Ordering::Relaxed) + } + + fn post_msg(&self, req: io_channel::Msg) -> Result<(), io_channel::Msg> { + if self.send_queue.push(req).is_ok() { + self.maybe_wake_io_thread(); + Ok(()) + } else { + Err(req) + } + } + + fn post_msg_with_response_waiter( + &self, + req: io_channel::Msg, + handler: Weak, + ) -> Result<(), ErrorCode> { + assert_ne!(0, req.id); + + // Add to response handlers before sending the message, otherwise the response may + // arive too quickly and the receiving code will panic due to a missing waiter. + assert!(self + .response_handlers + .lock() + .insert(req.id, handler) + .is_none()); + + if self.send_queue.push(req).is_ok() { + self.maybe_wake_io_thread(); + Ok(()) + } else { + Err(E_NOT_READY) + } + } + // Note: this is called from the IO thread, so must not sleep/block. fn on_orphan_message(&self, msg: io_channel::Msg) { match msg.command { @@ -623,6 +678,7 @@ impl NetChannel { // RX raced with the client dropping the sream. Need to get page to free it. let sz_read = msg.payload.args_64()[1]; if sz_read > 0 { + crate::moto_log!("orphan RX"); let _ = self.conn.get_page(msg.payload.shared_pages()[0]); } } @@ -679,6 +735,9 @@ pub struct TcpStream { local_addr: SocketAddr, remote_addr: SocketAddr, handle: u64, + wait_object: WaitObject, + nonblocking: AtomicBool, + me: Weak, // This is, most of the time, a single-producer, single-consumer queue. recv_queue: Mutex>, @@ -687,6 +746,9 @@ pub struct TcpStream { // A partially consumed incoming RX. rx_buf: Mutex>, + // A pending tx message. 
+ tx_msg: Mutex>, + rx_waiter: Mutex>, tcp_state: AtomicU32, // rt_api::TcpState @@ -746,12 +808,45 @@ impl PosixFile for TcpStream { fn flush(&self) -> Result<(), ErrorCode> { Ok(()) } + fn close(&self) -> Result<(), ErrorCode> { Ok(()) } + + fn poll_add(&self, poll_fd: RtFd, token: Token, interests: Interests) -> Result<(), ErrorCode> { + self.wait_object.add_interests(poll_fd, token, interests)?; + self.raise_events(interests, token); + Ok(()) + } + + fn poll_set(&self, poll_fd: RtFd, token: Token, interests: Interests) -> Result<(), ErrorCode> { + self.wait_object.set_interests(poll_fd, token, interests)?; + self.raise_events(interests, token); + Ok(()) + } + + fn poll_del(&self, poll_fd: RtFd) -> Result<(), ErrorCode> { + self.wait_object.del_interests(poll_fd) + } } impl TcpStream { + fn raise_events(&self, interests: Interests, token: Token) { + let mut events = 0; + + if (interests & moto_rt::poll::POLL_WRITABLE != 0) && self.have_write_buffer_space() { + events |= moto_rt::poll::POLL_WRITABLE; + } + if ((interests & moto_rt::poll::POLL_READABLE) != 0) + && (self.rx_buf.lock().is_some() || !self.recv_queue.lock().is_empty()) + { + events |= moto_rt::poll::POLL_READABLE; + } + if events != 0 { + self.wait_object.on_event(events); + } + } + fn ack_rx(&self) { let mut req = io_channel::Msg::new(); req.command = api_net::CMD_TCP_STREAM_RX_ACK; @@ -769,6 +864,7 @@ impl TcpStream { match msg.command { api_net::CMD_TCP_STREAM_RX => { self.recv_queue.lock().push_back(msg); + self.wait_object.on_event(moto_rt::poll::POLL_READABLE); } api_net::EVT_TCP_STREAM_STATE_CHANGED => { self.tcp_state @@ -792,16 +888,23 @@ impl TcpStream { let subchannel_idx = channel.reserve_subchannel(); let subchannel_mask = api_net::io_subchannel_mask(subchannel_idx); + let mut nonblocking = false; let req = if let Some(timo) = timeout { - api_net::tcp_stream_connect_timeout_request( - socket_addr, - subchannel_mask, - Instant::now() + timo, - ) + let abs_timeout = if timo.is_zero() { + nonblocking = true; + Instant::from_u64(1) // Nonblocking + } else { + Instant::now() + timo + }; + api_net::tcp_stream_connect_timeout_request(socket_addr, subchannel_mask, abs_timeout) } else { api_net::tcp_stream_connect_request(socket_addr, subchannel_mask) }; + if nonblocking { + todo!() + } + let resp = channel.send_receive(req); if resp.status() != moto_rt::E_OK { #[cfg(debug_assertions)] @@ -813,18 +916,24 @@ impl TcpStream { ); channel.release_subchannel(subchannel_idx); - NET.lock().release_channel(channel.clone()); + NET.lock().release_channel(&channel); return Err(resp.status()); } - let inner = Arc::new(TcpStream { + let inner = Arc::new_cyclic(|me| TcpStream { local_addr: api_net::get_socket_addr(&resp.payload).unwrap(), remote_addr: *socket_addr, handle: resp.handle, + wait_object: WaitObject::new( + moto_rt::poll::POLL_READABLE | moto_rt::poll::POLL_WRITABLE, + ), + me: me.clone(), + nonblocking: AtomicBool::new(nonblocking), channel: channel.clone(), recv_queue: Mutex::new(VecDeque::new()), next_rx_seq: AtomicU64::new(1), rx_buf: Mutex::new(None), + tx_msg: Mutex::new(None), rx_waiter: Mutex::new(None), tcp_state: AtomicU32::new(api_net::TcpState::ReadWrite.into()), rx_done: AtomicBool::new(false), @@ -852,6 +961,87 @@ impl TcpStream { Ok(inner) } + unsafe fn setsockopt(&self, option: u64, ptr: usize, len: usize) -> ErrorCode { + match option { + moto_rt::net::SO_NONBLOCKING => { + assert_eq!(len, 1); + let nonblocking = *(ptr as *const u8); + if nonblocking > 1 { + return E_INVALID_ARGUMENT; + } + 
self.set_nonblocking(nonblocking == 1) + } + moto_rt::net::SO_RCVTIMEO => { + assert_eq!(len, core::mem::size_of::()); + let timeout = *(ptr as *const u64); + self.set_read_timeout(timeout); + moto_rt::E_OK + } + moto_rt::net::SO_SNDTIMEO => { + assert_eq!(len, core::mem::size_of::()); + let timeout = *(ptr as *const u64); + self.set_write_timeout(timeout); + moto_rt::E_OK + } + moto_rt::net::SO_SHUTDOWN => { + assert_eq!(len, 1); + let val = *(ptr as *const u8); + let read = val & moto_rt::net::SHUTDOWN_READ != 0; + let write = val & moto_rt::net::SHUTDOWN_WRITE != 0; + self.shutdown(read, write) + } + moto_rt::net::SO_NODELAY => { + assert_eq!(len, 1); + let nodelay = *(ptr as *const u8); + self.set_nodelay(nodelay) + } + moto_rt::net::SO_TTL => { + assert_eq!(len, 4); + let ttl = *(ptr as *const u32); + self.set_ttl(ttl) + } + _ => panic!("unrecognized option {option}"), + } + } + + unsafe fn getsockopt(&self, option: u64, ptr: usize, len: usize) -> ErrorCode { + match option { + moto_rt::net::SO_RCVTIMEO => { + assert_eq!(len, core::mem::size_of::()); + let timeout = self.read_timeout(); + *(ptr as *mut u64) = timeout; + moto_rt::E_OK + } + moto_rt::net::SO_SNDTIMEO => { + assert_eq!(len, core::mem::size_of::()); + let timeout = self.write_timeout(); + *(ptr as *mut u64) = timeout; + moto_rt::E_OK + } + moto_rt::net::SO_NODELAY => { + assert_eq!(len, 1); + match self.nodelay() { + Ok(nodelay) => { + *(ptr as *mut u8) = nodelay; + moto_rt::E_OK + } + Err(err) => err, + } + } + moto_rt::net::SO_TTL => { + assert_eq!(len, 4); + match self.ttl() { + Ok(ttl) => { + *(ptr as *mut u32) = ttl; + moto_rt::E_OK + } + Err(err) => err, + } + } + _ => panic!("unrecognized option {option}"), + } + } + fn set_read_timeout(&self, timeout_ns: u64) { self.rx_timeout_ns.store(timeout_ns, Ordering::Relaxed); } @@ -986,12 +1176,16 @@ impl TcpStream { Err(moto_rt::E_NOT_READY) } - pub fn read(&self, buf: &mut [u8]) -> Result { + fn read(&self, buf: &mut [u8]) -> Result { match self.poll_rx(buf) { Ok(sz) => return Ok(sz), Err(err) => assert_eq!(err, moto_rt::E_NOT_READY), } + if self.nonblocking.load(Ordering::Relaxed) { + return Err(moto_rt::E_NOT_READY); + } + let rx_timeout_ns = self.rx_timeout_ns.load(Ordering::Relaxed); let rx_timeout = if rx_timeout_ns == u64::MAX { None @@ -1070,7 +1264,33 @@ impl TcpStream { } } - pub fn write(&self, buf: &[u8]) -> Result { + fn maybe_can_write(&self) { + if self.have_write_buffer_space() { + self.wait_object.on_event(moto_rt::poll::POLL_WRITABLE); + } else { + self.channel.write_waiters.lock().push_back(self.me.clone()); + } + } + + fn have_write_buffer_space(&self) -> bool { + { + let mut tx_lock = self.tx_msg.lock(); + if let Some((msg, write_sz)) = tx_lock.take() { + if let Err(msg) = self.try_tx(msg, write_sz) { + *tx_lock = Some((msg, write_sz)); + return false; + } + } + } + + self.channel.conn.may_alloc_page(self.subchannel_mask) + } + + fn write(&self, buf: &[u8]) -> Result { + if self.nonblocking.load(Ordering::Relaxed) { + return self.write_nonblocking(buf); + } + if buf.is_empty() || !self.tcp_state().can_write() { return Ok(0); } @@ -1148,7 +1368,8 @@ impl TcpStream { ); } - let msg = api_net::tcp_stream_tx_msg(self.handle, io_page, write_sz, timestamp.as_u64()); + let msg = + api_net::tcp_stream_tx_msg(self.handle, io_page, write_sz, Instant::now().as_u64()); self.channel.send_msg(msg); self.stats_tx_bytes .fetch_add(write_sz as u64, Ordering::Relaxed); @@ -1163,6 +1384,58 @@ impl TcpStream { Ok(write_sz) } + fn write_nonblocking(&self, buf: &[u8]) -> 
Result { + if buf.is_empty() || !self.tcp_state().can_write() { + return Ok(0); + } + + let write_sz = buf.len().min(io_channel::PAGE_SIZE); + + // Serialize writes, as we have only one self.tx_msg to store into. + let mut tx_lock = self.tx_msg.lock(); + if tx_lock.is_some() { + return Err(moto_rt::E_NOT_READY); + } + + let Ok(io_page) = self.channel.conn.alloc_page(self.subchannel_mask) else { + return Err(moto_rt::E_NOT_READY); + }; + unsafe { + core::ptr::copy_nonoverlapping( + buf.as_ptr(), + io_page.bytes_mut().as_mut_ptr(), + write_sz, + ); + } + + let msg = + api_net::tcp_stream_tx_msg(self.handle, io_page, write_sz, Instant::now().as_u64()); + if let Err(msg) = self.try_tx(msg, write_sz) { + *tx_lock = Some((msg, write_sz)); + self.channel.write_waiters.lock().push_back(self.me.clone()); + } + + // We copied write_sz bytes out, so must return Ok(write_sz). + Ok(write_sz) + } + + fn try_tx(&self, msg: io_channel::Msg, write_sz: usize) -> Result<(), io_channel::Msg> { + self.channel.post_msg(msg)?; + + self.stats_tx_bytes + .fetch_add(write_sz as u64, Ordering::Relaxed); + #[cfg(debug_assertions)] + moto_log!( + "{}:{} stream 0x{:x} TX bytes {}", + file!(), + line!(), + self.handle, + self.stats_tx_bytes.load(Ordering::Relaxed) + ); + + Ok(()) + } + fn peer_addr(&self) -> &SocketAddr { &self.remote_addr } @@ -1266,16 +1539,41 @@ impl TcpStream { Ok(None) } - fn set_nonblocking(&self, _nonblocking: bool) -> Result<(), ErrorCode> { + fn set_nonblocking(&self, _nonblocking: bool) -> ErrorCode { todo!() } } +struct AcceptRequest { + channel: Arc, + subchannel_idx: usize, + req: moto_ipc::io_channel::Msg, +} + +struct PendingAccept { + req: AcceptRequest, + resp: moto_ipc::io_channel::Msg, +} + pub struct TcpListener { socket_addr: SocketAddr, channel: Arc, handle: u64, nonblocking: AtomicBool, + wait_object: WaitObject, + + // All outgoing accept requests are stored here: req_id => req. + accept_requests: Mutex>, + + // Incoming async accepts are stored here. Better processed + // in arrival order. + async_accepts: Mutex>, + + // Incoming sync accepts are stored here: req_id => acc; + // have to be processed by id. + sync_accepts: Mutex>, + max_backlog: AtomicU32, + me: Weak, } impl Drop for TcpListener { @@ -1292,6 +1590,50 @@ impl PosixFile for TcpListener { fn close(&self) -> Result<(), ErrorCode> { Ok(()) } + + fn poll_add(&self, poll_fd: RtFd, token: Token, interests: Interests) -> Result<(), ErrorCode> { + self.wait_object.add_interests(poll_fd, token, interests)?; + + crate::moto_log!("TODO: TcpListener: add pending accepts"); + + Ok(()) + } + + fn poll_set(&self, poll_fd: RtFd, token: Token, interests: Interests) -> Result<(), ErrorCode> { + todo!() + // Err(E_INVALID_ARGUMENT) + } + fn poll_del(&self, poll_fd: RtFd) -> Result<(), ErrorCode> { + todo!() + // Err(E_INVALID_ARGUMENT) + } +} + +impl ResponseHandler for TcpListener { + fn on_response(&self, resp: io_channel::Msg) { + let req = self.accept_requests.lock().remove(&resp.id).unwrap(); + let wake_handle = SysHandle::from_u64(req.req.wake_handle); + + if wake_handle != SysHandle::NONE { + // The accept was blocking; a thread is waiting. 
+ assert!(self + .sync_accepts + .lock() + .insert(req.req.id, PendingAccept { req, resp }) + .is_none()); + let _ = moto_sys::SysCpu::wake(wake_handle); + return; + } + + self.async_accepts + .lock() + .push_back(PendingAccept { req, resp }); + if self.async_accepts.lock().len() < (self.max_backlog.load(Ordering::Relaxed) as usize) { + self.post_accept(false).unwrap(); // TODO: how to post an accept later? + } + + self.wait_object.on_event(moto_rt::poll::POLL_READABLE); + } } impl TcpListener { @@ -1300,15 +1642,21 @@ impl TcpListener { let channel = NET.lock().reserve_channel(); let resp = channel.send_receive(req); if resp.status() != moto_rt::E_OK { - NET.lock().release_channel(channel); + NET.lock().release_channel(&channel); return Err(resp.status()); } - let inner = Arc::new(TcpListener { + let inner = Arc::new_cyclic(|me| TcpListener { socket_addr: *socket_addr, channel: channel.clone(), handle: resp.handle, nonblocking: AtomicBool::new(false), + wait_object: WaitObject::new(moto_rt::poll::POLL_READABLE), + accept_requests: Mutex::new(BTreeMap::new()), + async_accepts: Mutex::new(VecDeque::new()), + sync_accepts: Mutex::new(BTreeMap::new()), + max_backlog: AtomicU32::new(32), + me: me.clone(), }); channel.tcp_listener_created(&inner); @@ -1323,52 +1671,92 @@ impl TcpListener { Ok(inner) } + fn listen(&self, max_backlog: u32) -> Result<(), ErrorCode> { + if !self.nonblocking.load(Ordering::Relaxed) { + return Err(E_INVALID_ARGUMENT); + } + + if max_backlog == 0 { + return Err(E_INVALID_ARGUMENT); + } + self.max_backlog.store(max_backlog, Ordering::Relaxed); + if !self.accept_requests.lock().is_empty() { + return Ok(()); // Already listening. + } + + if self.async_accepts.lock().len() >= (max_backlog as usize) { + return Ok(()); // The backlog is too large. + } + + self.post_accept(false) + .map(|_| ()) + .inspect_err(|_| panic!("TODO: what can we do here?")) + } + fn socket_addr(&self) -> Result { Ok(self.socket_addr) } - fn accept(&self) -> Result<(Arc, SocketAddr), ErrorCode> { - // Because a listener can spawn thousands, millions of sockets - // (think a long-running web server), we cannot use the listener's - // channel for incoming connections. + fn get_pending_accept(&self) -> Result { + if let Some(pending_accept) = self.async_accepts.lock().pop_front() { + return Ok(pending_accept); + } if self.nonblocking.load(Ordering::Relaxed) { - todo!() + return Err(E_NOT_READY); + }; + + let req_id = self.post_accept(true).unwrap(); // TODO: wait for channel to become ready. 
+ loop { + { + if let Some(pending_accept) = self.sync_accepts.lock().remove(&req_id) { + return Ok(pending_accept); + } + } + + let _ = moto_sys::SysCpu::wait(&mut [], SysHandle::NONE, SysHandle::NONE, None); } - let channel = NET.lock().reserve_channel(); - let subchannel_idx = channel.reserve_subchannel(); - let subchannel_mask = api_net::io_subchannel_mask(subchannel_idx); + } - let req = api_net::accept_tcp_listener_request(self.handle, subchannel_mask); - let resp = channel.send_receive(req); - if resp.status() != moto_rt::E_OK { - channel.release_subchannel(subchannel_idx); - NET.lock().release_channel(channel); - return Err(resp.status()); + fn accept(&self) -> Result<(Arc, SocketAddr), ErrorCode> { + let pending_accept = self.get_pending_accept()?; + if pending_accept.resp.status() != moto_rt::E_OK { + pending_accept + .req + .channel + .release_subchannel(pending_accept.req.subchannel_idx); + NET.lock().release_channel(&pending_accept.req.channel); + return Err(pending_accept.resp.status()); } - let remote_addr = api_net::get_socket_addr(&resp.payload).unwrap(); + let remote_addr = api_net::get_socket_addr(&pending_accept.resp.payload).unwrap(); - let inner = Arc::new(TcpStream { + let inner = Arc::new_cyclic(|me| TcpStream { local_addr: self.socket_addr, remote_addr, - handle: resp.handle, - channel: channel.clone(), + handle: pending_accept.resp.handle, + wait_object: WaitObject::new( + moto_rt::poll::POLL_READABLE | moto_rt::poll::POLL_WRITABLE, + ), + me: me.clone(), + nonblocking: AtomicBool::new(self.nonblocking.load(Ordering::Relaxed)), + channel: pending_accept.req.channel.clone(), recv_queue: Mutex::new(VecDeque::new()), next_rx_seq: AtomicU64::new(1), rx_buf: Mutex::new(None), + tx_msg: Mutex::new(None), rx_waiter: Mutex::new(None), tcp_state: AtomicU32::new(api_net::TcpState::ReadWrite.into()), rx_done: AtomicBool::new(false), rx_timeout_ns: AtomicU64::new(u64::MAX), tx_timeout_ns: AtomicU64::new(u64::MAX), - subchannel_idx, - subchannel_mask, + subchannel_idx: pending_accept.req.subchannel_idx, + subchannel_mask: api_net::io_subchannel_mask(pending_accept.req.subchannel_idx), stats_rx_bytes: AtomicU64::new(0), stats_tx_bytes: AtomicU64::new(0), }); - channel.tcp_stream_created(&inner); + pending_accept.req.channel.tcp_stream_created(&inner); inner.ack_rx(); #[cfg(debug_assertions)] @@ -1384,7 +1772,78 @@ impl TcpListener { Ok((inner, remote_addr)) } - fn set_ttl(&self, _ttl: u32) -> Result<(), ErrorCode> { + fn post_accept(&self, blocking: bool) -> Result { + // Because a listener can spawn thousands, millions of sockets + // (think a long-running web server), we cannot use the listener's + // channel for incoming connections. 
+ let channel = NET.lock().reserve_channel(); + let subchannel_idx = channel.reserve_subchannel(); + let subchannel_mask = api_net::io_subchannel_mask(subchannel_idx); + + let mut req = api_net::accept_tcp_listener_request(self.handle, subchannel_mask); + let req_id = channel.new_req_id(); + req.id = req_id; + if blocking { + req.wake_handle = moto_sys::UserThreadControlBlock::get().self_handle; + } + let accept_request = AcceptRequest { + channel: channel.clone(), + subchannel_idx, + req, + }; + + assert!(self + .accept_requests + .lock() + .insert(req.id, accept_request) + .is_none()); + + channel + .post_msg_with_response_waiter(req, self.me.clone()) + .inspect_err(|_| { + assert!(self.accept_requests.lock().remove(&req.id).is_some()); + channel.release_subchannel(subchannel_idx); + NET.lock().release_channel(&channel); + }) + .map(|_| req_id) + } + + unsafe fn setsockopt(&self, option: u64, ptr: usize, len: usize) -> ErrorCode { + match option { + moto_rt::net::SO_NONBLOCKING => { + assert_eq!(len, 1); + let nonblocking = *(ptr as *const u8); + if nonblocking > 1 { + return E_INVALID_ARGUMENT; + } + self.set_nonblocking(nonblocking == 1) + } + moto_rt::net::SO_TTL => { + assert_eq!(len, 4); + let ttl = *(ptr as *const u32); + self.set_ttl(ttl) + } + _ => panic!("unrecognized option {option}"), + } + } + + unsafe fn getsockopt(&self, option: u64, ptr: usize, len: usize) -> ErrorCode { + match option { + moto_rt::net::SO_TTL => { + assert_eq!(len, 4); + match self.ttl() { + Ok(ttl) => { + *(ptr as *mut u32) = ttl; + moto_rt::E_OK + } + Err(err) => err, + } + } + _ => panic!("unrecognized option {option}"), + } + } + + fn set_ttl(&self, _ttl: u32) -> ErrorCode { todo!() } @@ -1405,11 +1864,11 @@ impl TcpListener { Ok(None) } - fn set_nonblocking(&self, nonblocking: bool) -> Result<(), ErrorCode> { + fn set_nonblocking(&self, nonblocking: bool) -> ErrorCode { self.nonblocking.store(nonblocking, Ordering::Relaxed); if nonblocking { - todo!("Kick existing blocking accept()s."); + crate::moto_log!("Kick existing blocking accept()s."); } - Ok(()) + E_OK } } diff --git a/src/sys/lib/rt.vdso/src/rt_poll.rs b/src/sys/lib/rt.vdso/src/rt_poll.rs new file mode 100644 index 0000000..28ad973 --- /dev/null +++ b/src/sys/lib/rt.vdso/src/rt_poll.rs @@ -0,0 +1,72 @@ +use crate::posix; +use crate::posix::PosixFile; +use crate::runtime::Registry; +use alloc::collections::btree_map::BTreeMap; +use alloc::sync::Arc; +use core::any::Any; +use moto_rt::poll::Event; +use moto_rt::ErrorCode; +use moto_rt::RtFd; +use moto_rt::E_BAD_HANDLE; + +pub extern "C" fn new() -> RtFd { + posix::new_file(|fd| Arc::new(Registry::new(fd))) +} + +pub extern "C" fn add(poll_fd: RtFd, source_fd: RtFd, token: u64, events: u64) -> ErrorCode { + let Some(posix_file) = posix::get_file(poll_fd) else { + return E_BAD_HANDLE; + }; + let Some(registry) = (posix_file.as_ref() as &dyn Any).downcast_ref::() else { + return E_BAD_HANDLE; + }; + + registry.add(source_fd, token, events) +} + +pub extern "C" fn set(poll_fd: RtFd, source_fd: RtFd, token: u64, events: u64) -> ErrorCode { + let Some(posix_file) = posix::get_file(poll_fd) else { + return E_BAD_HANDLE; + }; + let Some(registry) = (posix_file.as_ref() as &dyn Any).downcast_ref::() else { + return E_BAD_HANDLE; + }; + + registry.set(source_fd, token, events) +} + +pub extern "C" fn del(poll_fd: RtFd, source_fd: RtFd) -> ErrorCode { + let Some(posix_file) = posix::get_file(poll_fd) else { + return E_BAD_HANDLE; + }; + let Some(registry) = (posix_file.as_ref() as &dyn 
Any).downcast_ref::() else { + return E_BAD_HANDLE; + }; + + registry.del(source_fd) +} + +// Returns the number of events or minus error code. +pub unsafe extern "C" fn wait( + poll_fd: RtFd, + timeout: u64, + events_ptr: *mut Event, + events_num: usize, +) -> i32 { + assert!(events_num < (i32::MAX as usize)); + + let Some(posix_file) = posix::get_file(poll_fd) else { + return -(E_BAD_HANDLE as i32); + }; + let Some(registry) = (posix_file.as_ref() as &dyn Any).downcast_ref::() else { + return -(E_BAD_HANDLE as i32); + }; + + let events = core::slice::from_raw_parts_mut(events_ptr, events_num); + let deadline = if timeout == u64::MAX { + None + } else { + Some(moto_rt::time::Instant::from_u64(timeout)) + }; + registry.wait(events, deadline) +} diff --git a/src/sys/lib/rt.vdso/src/runtime.rs b/src/sys/lib/rt.vdso/src/runtime.rs new file mode 100644 index 0000000..ed9c332 --- /dev/null +++ b/src/sys/lib/rt.vdso/src/runtime.rs @@ -0,0 +1,253 @@ +//! Runtime to support I/O and polling mechanisms. +//! +//! Somewhat similar to Linux's epoll, but supports only edge-triggered events. + +use core::any::Any; +use core::sync::atomic::AtomicU64; +use core::sync::atomic::Ordering; + +use crate::posix; +use crate::posix::PosixFile; +use crate::spin::Mutex; +use alloc::collections::btree_map::BTreeMap; +use alloc::sync::Arc; +use alloc::sync::Weak; +use moto_ipc::io_channel; +use moto_rt::poll::Event; +use moto_rt::poll::EventBits; +use moto_rt::poll::Interests; +use moto_rt::poll::Token; +use moto_rt::ErrorCode; +use moto_rt::RtFd; +use moto_rt::E_BAD_HANDLE; +use moto_rt::E_INVALID_ARGUMENT; +use moto_rt::E_OK; +use moto_rt::E_TIMED_OUT; +use moto_sys::SysHandle; + +pub trait ResponseHandler { + fn on_response(&self, resp: io_channel::Msg); +} + +/// A leaf object that can be waited on. +/// +/// Wait objects are flat, they either represent sockets (and, later, files) +/// directly, or a user-managed EventObject, which is similar to eventfd in Linux. +/// Wait objects are owned by their parent objects (e.g. sockets); +/// but a wait object can be added to multiple "Registries" with different tokens. +pub struct WaitObject { + supported_interests: Interests, + + // Registry FD -> (Registry, Token). + // TODO: is there a way to go from Arc to Arc? + // If so, then we can have Weak below. 
+ #[allow(clippy::type_complexity)] + registries: Mutex, Token, Interests)>>, +} + +impl Drop for WaitObject { + fn drop(&mut self) { + // todo!("notify registries") + crate::moto_log!("WaitObject::drop(): notify registries"); + } +} + +impl WaitObject { + pub fn new(supported_interests: Interests) -> Self { + Self { + supported_interests, + registries: Mutex::new(BTreeMap::new()), + } + } + + pub fn add_interests( + &self, + registry_fd: RtFd, + token: Token, + interests: Interests, + ) -> Result<(), ErrorCode> { + let Some(posix_file) = posix::get_file(registry_fd) else { + return Err(E_BAD_HANDLE); + }; + if (posix_file.as_ref() as &dyn Any) + .downcast_ref::() + .is_none() + { + return Err(E_BAD_HANDLE); + } + let registry = Arc::downgrade(&posix_file); + + let mut registries = self.registries.lock(); + match registries.entry(registry_fd) { + alloc::collections::btree_map::Entry::Vacant(entry) => { + entry.insert((registry, token, interests)); + } + alloc::collections::btree_map::Entry::Occupied(_) => return Err(E_INVALID_ARGUMENT), + } + + Ok(()) + } + + pub fn set_interests( + &self, + registry_fd: RtFd, + token: Token, + interests: Interests, + ) -> Result<(), ErrorCode> { + let Some(posix_file) = posix::get_file(registry_fd) else { + return Err(E_BAD_HANDLE); + }; + if (posix_file.as_ref() as &dyn Any) + .downcast_ref::() + .is_none() + { + return Err(E_BAD_HANDLE); + } + let registry = Arc::downgrade(&posix_file); + + let mut registries = self.registries.lock(); + if let Some(val) = registries.get_mut(®istry_fd) { + *val = (registry, token, interests); + Ok(()) + } else { + Err(E_INVALID_ARGUMENT) + } + } + + pub fn del_interests(&self, registry_fd: RtFd) -> Result<(), ErrorCode> { + self.registries + .lock() + .remove(®istry_fd) + .map(|_| ()) + .ok_or(E_INVALID_ARGUMENT) + } + + pub fn on_event(&self, events: EventBits) { + let mut dropped_registries = alloc::vec::Vec::new(); + let mut registries = self.registries.lock(); + for entry in &*registries { + let (registry_id, (registry, token, interests)) = entry; + if interests & events != 0 { + if let Some(registry) = registry.upgrade() { + let registry = (registry.as_ref() as &dyn Any) + .downcast_ref::() + .unwrap(); + registry.on_event(*token, interests & events); + } else { + dropped_registries.push(*registry_id); + } + } + } + + for id in dropped_registries { + registries.remove(&id); + } + } +} + +pub struct Registry { + fd: RtFd, + events: Mutex>, + wait_handle: AtomicU64, +} + +impl PosixFile for Registry {} + +impl Registry { + pub fn new(fd: RtFd) -> Self { + Self { + fd, + events: Mutex::new(BTreeMap::new()), + wait_handle: AtomicU64::new(SysHandle::NONE.as_u64()), + } + } + + pub fn add(&self, source_fd: RtFd, token: Token, interests: Interests) -> ErrorCode { + let Some(posix_file) = posix::get_file(source_fd) else { + return E_BAD_HANDLE; + }; + + if let Err(err) = posix_file.poll_add(self.fd, token, interests) { + err + } else { + E_OK + } + } + + pub fn set(&self, source_fd: RtFd, token: Token, interests: Interests) -> ErrorCode { + let Some(posix_file) = posix::get_file(source_fd) else { + return E_BAD_HANDLE; + }; + + if let Err(err) = posix_file.poll_set(self.fd, token, interests) { + err + } else { + E_OK + } + } + + pub fn del(&self, source_fd: RtFd) -> ErrorCode { + let Some(posix_file) = posix::get_file(source_fd) else { + return E_BAD_HANDLE; + }; + + if let Err(err) = posix_file.poll_del(self.fd) { + err + } else { + E_OK + } + } + + pub fn wait(&self, events_buf: &mut [Event], deadline: Option) -> i32 { + 
self.wait_handle.store( + moto_sys::UserThreadControlBlock::get().self_handle, + Ordering::Release, + ); + + loop { + if !self.events.lock().is_empty() { + break; + } + let _ = moto_sys::SysCpu::wait(&mut [], SysHandle::NONE, SysHandle::NONE, deadline); + if !self.events.lock().is_empty() { + break; + } + if let Some(deadline) = deadline { + if deadline <= moto_rt::time::Instant::now() { + self.wait_handle + .store(SysHandle::NONE.as_u64(), Ordering::Release); + return -(E_TIMED_OUT as i32); + } + } + } + self.wait_handle + .store(SysHandle::NONE.as_u64(), Ordering::Release); + + let mut events = self.events.lock(); + let mut idx = 0; + while idx < events_buf.len() { + let Some((token, bits)) = events.pop_first() else { + break; + }; + let entry = &mut events_buf[idx]; + entry.token = token; + entry.events = bits; + idx += 1; + } + + (idx + 1) as i32 + } + + pub fn on_event(&self, token: Token, event_bits: EventBits) { + self.events + .lock() + .entry(token) + .and_modify(|curr| *curr |= event_bits) + .or_insert(event_bits); + + let handle = SysHandle::from_u64(self.wait_handle.load(Ordering::Acquire)); + if handle != SysHandle::NONE { + let _ = moto_sys::SysCpu::wake(handle); + } + } +} diff --git a/src/sys/tests/mio-test/Cargo.toml b/src/sys/tests/mio-test/Cargo.toml new file mode 100644 index 0000000..56b41f2 --- /dev/null +++ b/src/sys/tests/mio-test/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "mio-test" +version = "0.1.0" +edition = "2021" + +[dependencies] +mio = { git = "https://github.com/moturus/mio.git", branch = "motor-os_20241121", features = ["net", "os-poll"]} \ No newline at end of file diff --git a/src/sys/tests/mio-test/src/main.rs b/src/sys/tests/mio-test/src/main.rs new file mode 100644 index 0000000..7d5603e --- /dev/null +++ b/src/sys/tests/mio-test/src/main.rs @@ -0,0 +1,5 @@ +mod simple; + +fn main() { + simple::test() +} diff --git a/src/sys/tests/mio-test/src/simple.rs b/src/sys/tests/mio-test/src/simple.rs new file mode 100644 index 0000000..b4905e4 --- /dev/null +++ b/src/sys/tests/mio-test/src/simple.rs @@ -0,0 +1,234 @@ +// A simple test: a single TCP server/listener listens; +// NUM_CLIENTS connect to the server, write "ping"; +// the server responds "pong". +const NUM_CLIENTS: usize = 4; + +// This is a modified tcp_server.rs example from toko-rs/mio. +use mio::event::Event; +use mio::net::TcpListener; +use mio::{Events, Interest, Poll, Registry, Token}; +use std::collections::HashMap; +use std::io::{self, Read, Write}; +use std::sync::atomic::AtomicBool; + +// Setup some tokens to allow us to identify which event is for which socket. +const SERVER: Token = Token(0); + +// Some data we'll send over the connection. +const PING: &[u8] = b"ping"; +const PONG: &[u8] = b"pong"; + +const ADDR: &str = "127.0.0.1:9000"; + +struct ClientConnection { + stream: mio::net::TcpStream, + ping: bool, +} + +fn server_thread(ready: &AtomicBool) -> io::Result<()> { + // Create a poll instance. + let mut poll = Poll::new()?; + // Create storage for events. + let mut events = Events::with_capacity(128); + + // Setup the TCP server socket. + let addr = ADDR.parse().unwrap(); + let mut server = TcpListener::bind(addr)?; + + // Register the server with poll we can receive events for it. + poll.registry() + .register(&mut server, SERVER, Interest::READABLE)?; + + // Map of `Token` -> `TcpStream`. + let mut connections = HashMap::new(); + // Unique token for each incoming connection. 
+ let mut connection_token = Token(SERVER.0 + 1); + + ready.store(true, std::sync::atomic::Ordering::Release); + + let mut num_clients = 0; + + loop { + if let Err(err) = poll.poll(&mut events, None) { + if interrupted(&err) { + continue; + } + return Err(err); + } + + for event in events.iter() { + match event.token() { + SERVER => loop { + // Received an event for the TCP server socket, which + // indicates we can accept an connection. + let (mut connection, _address) = match server.accept() { + Ok((connection, address)) => (connection, address), + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + // If we get a `WouldBlock` error we know our + // listener has no more incoming connections queued, + // so we can return to polling and wait for some + // more. + break; + } + Err(e) => { + // If it was any other kind of error, something went + // wrong and we terminate with an error. + return Err(e); + } + }; + + let token = next(&mut connection_token); + poll.registry() + .register(&mut connection, token, Interest::READABLE)?; + + connections.insert( + token, + ClientConnection { + stream: connection, + ping: false, + }, + ); + }, + token => { + // Maybe received an event for a TCP connection. + let done = if let Some(connection) = connections.get_mut(&token) { + handle_connection_event(poll.registry(), connection, event)? + } else { + // Sporadic events happen, we can safely ignore them. + false + }; + if done { + if let Some(mut connection) = connections.remove(&token) { + poll.registry().deregister(&mut connection.stream)?; + num_clients += 1; + if num_clients == NUM_CLIENTS { + return Ok(()); + } + } + } + } + } + } + } +} + +fn next(current: &mut Token) -> Token { + let next = current.0; + current.0 += 1; + Token(next) +} + +/// Returns `true` if the connection is done. +fn handle_connection_event( + registry: &Registry, + connection: &mut ClientConnection, + event: &Event, +) -> io::Result { + if event.is_writable() { + assert!(connection.ping); + // We can (maybe) write to the connection. + match connection.stream.write(PONG) { + // We want to write the entire `DATA` buffer in a single go. If we + // write less we'll return a short write error (same as + // `io::Write::write_all` does). + Ok(n) if n < PONG.len() => return Err(io::ErrorKind::WriteZero.into()), + Ok(_) => { + // After we've written something we'll reregister the connection + // to only respond to readable events. + registry.reregister(&mut connection.stream, event.token(), Interest::READABLE)? + } + // Would block "errors" are the OS's way of saying that the + // connection is not actually ready to perform this I/O operation. + Err(ref err) if would_block(err) => {} + // Got interrupted (how rude!), we'll try again. + Err(ref err) if interrupted(err) => { + return handle_connection_event(registry, connection, event) + } + // Other errors we'll consider fatal. + Err(err) => return Err(err), + } + } + + if event.is_readable() { + let mut connection_closed = false; + let mut received_data = vec![0; 4096]; + let mut bytes_read = 0; + // We can (maybe) read from the connection. + loop { + match connection.stream.read(&mut received_data[bytes_read..]) { + Ok(0) => { + // Reading 0 bytes means the other side has closed the + // connection or is done writing, then so are we. 
+ connection_closed = true; + break; + } + Ok(n) => { + bytes_read += n; + if bytes_read == received_data.len() { + received_data.resize(received_data.len() + 1024, 0); + } + } + // Would block "errors" are the OS's way of saying that the + // connection is not actually ready to perform this I/O operation. + Err(ref err) if would_block(err) => break, + Err(ref err) if interrupted(err) => continue, + // Other errors we'll consider fatal. + Err(err) => return Err(err), + } + } + + if !connection.ping && bytes_read > 0 { + assert_eq!(bytes_read, PING.len()); + assert_eq!(&received_data[..bytes_read], PING); + connection.ping = true; + registry.reregister(&mut connection.stream, event.token(), Interest::WRITABLE)? + } + + if connection_closed { + return Ok(true); + } + } + + Ok(false) +} + +fn would_block(err: &io::Error) -> bool { + err.kind() == io::ErrorKind::WouldBlock +} + +fn interrupted(err: &io::Error) -> bool { + err.kind() == io::ErrorKind::Interrupted +} + +fn client() -> io::Result<()> { + let mut conn = std::net::TcpStream::connect(ADDR)?; + conn.write_all(PING)?; + + let mut buf = [0_u8; PONG.len()]; + conn.read_exact(&mut buf)?; + assert_eq!(buf, PONG); + + Ok(()) +} + +pub fn test() { + let server_ready = AtomicBool::new(false); + std::thread::scope(|s| { + let server = s.spawn(|| server_thread(&server_ready).unwrap()); + while !server_ready.load(std::sync::atomic::Ordering::Relaxed) { + core::hint::spin_loop(); + } + + let mut clients = Vec::with_capacity(NUM_CLIENTS); + for _ in 0..NUM_CLIENTS { + clients.push(std::thread::spawn(|| client().unwrap())); + } + + for client in clients { + client.join().unwrap(); + } + server.join().unwrap(); + }); + + println!("simple PASS"); +}
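
For reference, a minimal sketch of how the poll surface added above (moto_rt::poll plus net::listen and the SO_NONBLOCKING sockopt) is meant to be driven; this is roughly what mio's Motor OS backend does underneath Poll::poll in the test above. It is not part of the diff: listener_fd is assumed to come from moto_rt::net::bind, and the poll_listener name, token value, and backlog are illustrative.

use moto_rt::poll::{self, Event, POLL_READABLE};
use moto_rt::{ErrorCode, RtFd};

// Register a bound TCP listener with a poll registry and accept connections
// when it becomes readable. Events are edge-triggered, matching the
// Registry/WaitObject pair in rt.vdso/src/runtime.rs.
fn poll_listener(listener_fd: RtFd) -> Result<(), ErrorCode> {
    // listen() as added above requires the socket to be nonblocking first.
    moto_rt::net::set_nonblocking(listener_fd, true)?;
    moto_rt::net::listen(listener_fd, 32)?;

    let poll_fd = poll::new()?;
    const ACCEPT: poll::Token = 1;
    poll::add(poll_fd, listener_fd, ACCEPT, POLL_READABLE)?;

    let mut events = [Event { token: 0, events: 0 }; 16];
    loop {
        // None = no deadline: block until at least one event is delivered.
        let n = poll::wait(poll_fd, events.as_mut_ptr(), events.len(), None)?;
        for event in &events[..n] {
            if event.token == ACCEPT && (event.events & POLL_READABLE) != 0 {
                // Drain the backlog: a nonblocking accept() reports E_NOT_READY
                // once nothing is queued, and the edge-triggered event will not
                // be re-raised for connections that are already pending.
                while let Ok((stream_fd, _peer)) = moto_rt::net::accept(listener_fd) {
                    let _ = stream_fd; // hand off to a worker, etc.
                }
            }
        }
    }
}

Because the vdso registry supports only edge-triggered events (see the module comment in runtime.rs), consumers have to drain accept()/read() until E_NOT_READY after each wakeup, which is what the mio-test server loop above does by looping until WouldBlock.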