From 80ba7e262c090e2fcfa8341c40bf5fb530f24ac1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=87a=C4=9Fatay=20Yi=C4=9Fit=20=C5=9Eahin?= Date: Thu, 3 Oct 2024 18:39:57 +0300 Subject: [PATCH] virtq: use enum_dispatch Closes #989 --- Cargo.lock | 13 +++++++ Cargo.toml | 1 + src/drivers/fs/virtio_fs.rs | 24 ++++++------ src/drivers/net/virtio/mmio.rs | 7 ++-- src/drivers/net/virtio/mod.rs | 34 ++++++++--------- src/drivers/virtio/virtqueue/mod.rs | 29 +++++--------- src/drivers/virtio/virtqueue/packed.rs | 46 ++++++++++++----------- src/drivers/virtio/virtqueue/split.rs | 52 +++++++++++++------------- src/drivers/vsock/mod.rs | 25 +++++++------ 9 files changed, 121 insertions(+), 110 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6ed94e3ca6..3dd24805ef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -437,6 +437,18 @@ dependencies = [ "zerocopy-derive", ] +[[package]] +name = "enum_dispatch" +version = "0.3.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa18ce2bc66555b3218614519ac839ddb759a7d6720732f979ef8d13be147ecd" +dependencies = [ + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.79", +] + [[package]] name = "event-listener" version = "5.3.1" @@ -622,6 +634,7 @@ dependencies = [ "cfg-if", "crossbeam-utils", "dyn-clone", + "enum_dispatch", "fdt", "float-cmp", "free-list", diff --git a/Cargo.toml b/Cargo.toml index 4f43cde124..8920b3d721 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -107,6 +107,7 @@ talc = { version = "4" } time = { version = "0.3", default-features = false } volatile = "0.6" zerocopy = { version = "0.7", default-features = false } +enum_dispatch = "0.3.13" [dependencies.smoltcp] version = "0.11" diff --git a/src/drivers/fs/virtio_fs.rs b/src/drivers/fs/virtio_fs.rs index 9a10dc4f5f..a90330049b 100644 --- a/src/drivers/fs/virtio_fs.rs +++ b/src/drivers/fs/virtio_fs.rs @@ -18,7 +18,7 @@ use crate::drivers::virtio::transport::pci::{ComCfg, IsrStatus, NotifCfg}; use crate::drivers::virtio::virtqueue::error::VirtqError; use crate::drivers::virtio::virtqueue::split::SplitVq; use crate::drivers::virtio::virtqueue::{ - AvailBufferToken, BufferElem, BufferType, Virtq, VqIndex, VqSize, + AvailBufferToken, BufferElem, BufferType, VirtQueue, Virtq, VqIndex, VqSize, }; use crate::fs::fuse::{self, FuseInterface, Rsp, RspHeader}; use crate::mm::device_alloc::DeviceAlloc; @@ -42,7 +42,7 @@ pub(crate) struct VirtioFsDriver { pub(super) com_cfg: ComCfg, pub(super) isr_stat: IsrStatus, pub(super) notif_cfg: NotifCfg, - pub(super) vqueues: Vec>, + pub(super) vqueues: Vec, pub(super) irq: InterruptLine, } @@ -130,15 +130,17 @@ impl VirtioFsDriver { // create the queues and tell device about them for i in 0..vqnum as u16 { - let vq = SplitVq::new( - &mut self.com_cfg, - &self.notif_cfg, - VqSize::from(VIRTIO_MAX_QUEUE_SIZE), - VqIndex::from(i), - self.dev_cfg.features.into(), - ) - .unwrap(); - self.vqueues.push(Box::new(vq)); + let vq = VirtQueue::Split( + SplitVq::new( + &mut self.com_cfg, + &self.notif_cfg, + VqSize::from(VIRTIO_MAX_QUEUE_SIZE), + VqIndex::from(i), + self.dev_cfg.features.into(), + ) + .unwrap(), + ); + self.vqueues.push(vq); } // At this point the device is "live" diff --git a/src/drivers/net/virtio/mmio.rs b/src/drivers/net/virtio/mmio.rs index ff73ef3b8e..ef09e314ba 100644 --- a/src/drivers/net/virtio/mmio.rs +++ b/src/drivers/net/virtio/mmio.rs @@ -2,7 +2,6 @@ //! //! The module contains ... -use alloc::boxed::Box; use alloc::vec::Vec; use core::str::FromStr; @@ -13,7 +12,7 @@ use volatile::VolatileRef; use crate::drivers::net::virtio::{CtrlQueue, NetDevCfg, RxQueues, TxQueues, VirtioNetDriver}; use crate::drivers::virtio::error::{VirtioError, VirtioNetError}; use crate::drivers::virtio::transport::mmio::{ComCfg, IsrStatus, NotifCfg}; -use crate::drivers::virtio::virtqueue::Virtq; +use crate::drivers::virtio::virtqueue::VirtQueue; // Backend-dependent interface for Virtio network driver impl VirtioNetDriver { @@ -46,8 +45,8 @@ impl VirtioNetDriver { 1514 }; - let send_vqs = TxQueues::new(Vec::>::new(), &dev_cfg); - let recv_vqs = RxQueues::new(Vec::>::new(), &dev_cfg); + let send_vqs = TxQueues::new(Vec::::new(), &dev_cfg); + let recv_vqs = RxQueues::new(Vec::::new(), &dev_cfg); Ok(VirtioNetDriver { dev_cfg, com_cfg: ComCfg::new(registers, 1), diff --git a/src/drivers/net/virtio/mod.rs b/src/drivers/net/virtio/mod.rs index 4c39d12b8b..ecbaa09d0c 100644 --- a/src/drivers/net/virtio/mod.rs +++ b/src/drivers/net/virtio/mod.rs @@ -32,7 +32,7 @@ use crate::drivers::virtio::transport::pci::{ComCfg, IsrStatus, NotifCfg}; use crate::drivers::virtio::virtqueue::packed::PackedVq; use crate::drivers::virtio::virtqueue::split::SplitVq; use crate::drivers::virtio::virtqueue::{ - AvailBufferToken, BufferElem, BufferType, UsedBufferToken, Virtq, VqIndex, VqSize, + AvailBufferToken, BufferElem, BufferType, UsedBufferToken, VirtQueue, Virtq, VqIndex, VqSize, }; use crate::executor::device::{RxToken, TxToken}; use crate::mm::device_alloc::DeviceAlloc; @@ -46,23 +46,23 @@ pub(crate) struct NetDevCfg { pub features: virtio::net::F, } -pub struct CtrlQueue(Option>); +pub struct CtrlQueue(Option); impl CtrlQueue { - pub fn new(vq: Option>) -> Self { + pub fn new(vq: Option) -> Self { CtrlQueue(vq) } } pub struct RxQueues { - vqs: Vec>, + vqs: Vec, poll_sender: async_channel::Sender, poll_receiver: async_channel::Receiver, packet_size: u32, } impl RxQueues { - pub fn new(vqs: Vec>, dev_cfg: &NetDevCfg) -> Self { + pub fn new(vqs: Vec, dev_cfg: &NetDevCfg) -> Self { let (poll_sender, poll_receiver) = async_channel::unbounded(); // See Virtio specification v1.1 - 5.1.6.3.1 @@ -92,11 +92,11 @@ impl RxQueues { /// Adds a given queue to the underlying vector and populates the queue with RecvBuffers. /// /// Queues are all populated according to Virtio specification v1.1. - 5.1.6.3.1 - fn add(&mut self, mut vq: Box) { + fn add(&mut self, mut vq: VirtQueue) { const BUFF_PER_PACKET: u16 = 2; let num_packets: u16 = u16::from(vq.size()) / BUFF_PER_PACKET; fill_queue( - vq.as_mut(), + &mut vq, num_packets, self.packet_size, self.poll_sender.clone(), @@ -185,14 +185,14 @@ fn fill_queue( /// Structure which handles transmission of packets and delegation /// to the respective queue structures. pub struct TxQueues { - vqs: Vec>, + vqs: Vec, /// Indicates, whether the Driver/Device are using multiple /// queues for communication. packet_length: u32, } impl TxQueues { - pub fn new(vqs: Vec>, dev_cfg: &NetDevCfg) -> Self { + pub fn new(vqs: Vec, dev_cfg: &NetDevCfg) -> Self { let packet_length = if dev_cfg.features.contains(virtio::net::F::GUEST_TSO4) | dev_cfg.features.contains(virtio::net::F::GUEST_TSO6) | dev_cfg.features.contains(virtio::net::F::GUEST_UFO) @@ -224,7 +224,7 @@ impl TxQueues { } } - fn add(&mut self, vq: Box) { + fn add(&mut self, vq: VirtQueue) { // Currently we are doing nothing with the additional queues. They are inactive and might be used in the // future self.vqs.push(vq); @@ -370,7 +370,7 @@ impl NetworkDriver for VirtioNetDriver { } fill_queue( - self.recv_vqs.vqs[0].as_mut(), + &mut self.recv_vqs.vqs[0], num_buffers, self.recv_vqs.packet_size, self.recv_vqs.poll_sender.clone(), @@ -672,7 +672,7 @@ impl VirtioNetDriver { // Add a control if feature is negotiated if self.dev_cfg.features.contains(virtio::net::F::CTRL_VQ) { if self.dev_cfg.features.contains(virtio::net::F::RING_PACKED) { - self.ctrl_vq = CtrlQueue(Some(Box::new( + self.ctrl_vq = CtrlQueue(Some(VirtQueue::Packed( PackedVq::new( &mut self.com_cfg, &self.notif_cfg, @@ -683,7 +683,7 @@ impl VirtioNetDriver { .unwrap(), ))); } else { - self.ctrl_vq = CtrlQueue(Some(Box::new( + self.ctrl_vq = CtrlQueue(Some(VirtQueue::Split( SplitVq::new( &mut self.com_cfg, &self.notif_cfg, @@ -759,7 +759,7 @@ impl VirtioNetDriver { // Interrupt for receiving packets is wanted vq.enable_notifs(); - self.recv_vqs.add(Box::from(vq)); + self.recv_vqs.add(VirtQueue::Packed(vq)); let mut vq = PackedVq::new( &mut self.com_cfg, @@ -772,7 +772,7 @@ impl VirtioNetDriver { // Interrupt for comunicating that a sended packet left, is not needed vq.disable_notifs(); - self.send_vqs.add(Box::from(vq)); + self.send_vqs.add(VirtQueue::Packed(vq)); } else { let mut vq = SplitVq::new( &mut self.com_cfg, @@ -785,7 +785,7 @@ impl VirtioNetDriver { // Interrupt for receiving packets is wanted vq.enable_notifs(); - self.recv_vqs.add(Box::from(vq)); + self.recv_vqs.add(VirtQueue::Split(vq)); let mut vq = SplitVq::new( &mut self.com_cfg, @@ -798,7 +798,7 @@ impl VirtioNetDriver { // Interrupt for comunicating that a sended packet left, is not needed vq.disable_notifs(); - self.send_vqs.add(Box::from(vq)); + self.send_vqs.add(VirtQueue::Split(vq)); } } diff --git a/src/drivers/virtio/virtqueue/mod.rs b/src/drivers/virtio/virtqueue/mod.rs index 2ae238091d..e3d90b3567 100644 --- a/src/drivers/virtio/virtqueue/mod.rs +++ b/src/drivers/virtio/virtqueue/mod.rs @@ -20,13 +20,12 @@ use core::mem::MaybeUninit; use core::{mem, ptr}; use async_channel::TryRecvError; +use enum_dispatch::enum_dispatch; +use packed::PackedVq; +use split::SplitVq; use virtio::{le32, le64, pvirtq, virtq}; use self::error::VirtqError; -#[cfg(not(feature = "pci"))] -use super::transport::mmio::{ComCfg, NotifCfg}; -#[cfg(feature = "pci")] -use super::transport::pci::{ComCfg, NotifCfg}; use crate::arch::mm::{paging, VirtAddr}; use crate::mm::device_alloc::DeviceAlloc; @@ -99,6 +98,7 @@ type UsedBufferTokenSender = async_channel::Sender; /// might not provide the complete feature set of each queue. Drivers who /// do need these features should refrain from providing support for both /// Virtqueue types and use the structs directly instead. +#[enum_dispatch] pub trait Virtq { /// The `notif` parameter indicates if the driver wants to have a notification for this specific /// transfer. This is only for performance optimization. As it is NOT ensured, that the device sees the @@ -193,21 +193,6 @@ pub trait Virtq { notif: bool, ) -> Result<(), VirtqError>; - /// Creates a new Virtq of the specified [VqSize] and the [VqIndex]. - /// The index represents the "ID" of the virtqueue. - /// Upon creation the virtqueue is "registered" at the device via the `ComCfg` struct. - /// - /// Be aware, that devices define a maximum number of queues and a maximal size they can handle. - fn new( - com_cfg: &mut ComCfg, - notif_cfg: &NotifCfg, - size: VqSize, - index: VqIndex, - features: virtio::F, - ) -> Result - where - Self: Sized; - /// Returns the size of a Virtqueue. This represents the overall size and not the capacity the /// queue currently has for new descriptors. fn size(&self) -> VqSize; @@ -292,6 +277,12 @@ trait VirtqPrivate { } } +#[enum_dispatch(Virtq)] +pub(crate) enum VirtQueue { + Split(SplitVq), + Packed(PackedVq), +} + trait VirtqDescriptor { fn flags_mut(&mut self) -> &mut virtq::DescF; diff --git a/src/drivers/virtio/virtqueue/packed.rs b/src/drivers/virtio/virtqueue/packed.rs index b1ea248eee..9d17b74e7b 100644 --- a/src/drivers/virtio/virtqueue/packed.rs +++ b/src/drivers/virtio/virtqueue/packed.rs @@ -649,7 +649,30 @@ impl Virtq for PackedVq { self.index } - fn new( + fn size(&self) -> VqSize { + self.size + } + + fn has_used_buffers(&self) -> bool { + let desc = &self.descr_ring.ring[usize::from(self.descr_ring.poll_index)]; + self.descr_ring.is_marked_used(desc.flags) + } +} + +impl VirtqPrivate for PackedVq { + type Descriptor = pvirtq::Desc; + + fn create_indirect_ctrl( + buffer_tkn: &AvailBufferToken, + ) -> Result, VirtqError> { + Ok(Self::descriptor_iter(buffer_tkn)? + .collect::>() + .into_boxed_slice()) + } +} + +impl PackedVq { + pub(crate) fn new( com_cfg: &mut ComCfg, notif_cfg: &NotifCfg, size: VqSize, @@ -740,25 +763,4 @@ impl Virtq for PackedVq { last_next: Default::default(), }) } - - fn size(&self) -> VqSize { - self.size - } - - fn has_used_buffers(&self) -> bool { - let desc = &self.descr_ring.ring[usize::from(self.descr_ring.poll_index)]; - self.descr_ring.is_marked_used(desc.flags) - } -} - -impl VirtqPrivate for PackedVq { - type Descriptor = pvirtq::Desc; - - fn create_indirect_ctrl( - buffer_tkn: &AvailBufferToken, - ) -> Result, VirtqError> { - Ok(Self::descriptor_iter(buffer_tkn)? - .collect::>() - .into_boxed_slice()) - } } diff --git a/src/drivers/virtio/virtqueue/split.rs b/src/drivers/virtio/virtqueue/split.rs index fb1d9c5908..d40073f823 100644 --- a/src/drivers/virtio/virtqueue/split.rs +++ b/src/drivers/virtio/virtqueue/split.rs @@ -230,7 +230,33 @@ impl Virtq for SplitVq { self.index } - fn new( + fn size(&self) -> VqSize { + self.size + } + + fn has_used_buffers(&self) -> bool { + self.ring.read_idx != self.ring.used_ring().idx.to_ne() + } +} + +impl VirtqPrivate for SplitVq { + type Descriptor = virtq::Desc; + fn create_indirect_ctrl( + buffer_tkn: &AvailBufferToken, + ) -> Result, VirtqError> { + Ok(Self::descriptor_iter(buffer_tkn)? + .zip(1..) + .map(|(descriptor, next_id)| Self::Descriptor { + next: next_id.into(), + ..descriptor + }) + .collect::>() + .into_boxed_slice()) + } +} + +impl SplitVq { + pub(crate) fn new( com_cfg: &mut ComCfg, notif_cfg: &NotifCfg, size: VqSize, @@ -318,28 +344,4 @@ impl Virtq for SplitVq { index, }) } - - fn size(&self) -> VqSize { - self.size - } - - fn has_used_buffers(&self) -> bool { - self.ring.read_idx != self.ring.used_ring().idx.to_ne() - } -} - -impl VirtqPrivate for SplitVq { - type Descriptor = virtq::Desc; - fn create_indirect_ctrl( - buffer_tkn: &AvailBufferToken, - ) -> Result, VirtqError> { - Ok(Self::descriptor_iter(buffer_tkn)? - .zip(1..) - .map(|(descriptor, next_id)| Self::Descriptor { - next: next_id.into(), - ..descriptor - }) - .collect::>() - .into_boxed_slice()) - } } diff --git a/src/drivers/vsock/mod.rs b/src/drivers/vsock/mod.rs index 7ab72e3008..5005da5565 100644 --- a/src/drivers/vsock/mod.rs +++ b/src/drivers/vsock/mod.rs @@ -12,6 +12,7 @@ use pci_types::InterruptLine; use virtio::vsock::Hdr; use virtio::FeatureBits; +use super::virtio::virtqueue::VirtQueue; use crate::config::VIRTIO_MAX_QUEUE_SIZE; use crate::drivers::virtio::error::VirtioVsockError; #[cfg(feature = "pci")] @@ -67,7 +68,7 @@ fn fill_queue( } pub(crate) struct RxQueue { - vq: Option>, + vq: Option, poll_sender: async_channel::Sender, poll_receiver: async_channel::Receiver, packet_size: u32, @@ -85,12 +86,12 @@ impl RxQueue { } } - pub fn add(&mut self, mut vq: Box) { + pub fn add(&mut self, mut vq: VirtQueue) { const BUFF_PER_PACKET: u16 = 2; let num_packets: u16 = u16::from(vq.size()) / BUFF_PER_PACKET; info!("num_packets {}", num_packets); fill_queue( - vq.as_mut(), + &mut vq, num_packets, self.packet_size, self.poll_sender.clone(), @@ -144,7 +145,7 @@ impl RxQueue { if let Some(ref mut vq) = self.vq { f(&header, &packet[..]); - fill_queue(vq.as_mut(), 1, self.packet_size, self.poll_sender.clone()); + fill_queue(vq, 1, self.packet_size, self.poll_sender.clone()); } else { panic!("Invalid length of receive queue"); } @@ -153,7 +154,7 @@ impl RxQueue { } pub(crate) struct TxQueue { - vq: Option>, + vq: Option, /// Indicates, whether the Driver/Device are using multiple /// queues for communication. packet_length: u32, @@ -167,7 +168,7 @@ impl TxQueue { } } - pub fn add(&mut self, vq: Box) { + pub fn add(&mut self, vq: VirtQueue) { self.vq = Some(vq); } @@ -224,7 +225,7 @@ impl TxQueue { } pub(crate) struct EventQueue { - vq: Option>, + vq: Option, poll_sender: async_channel::Sender, poll_receiver: async_channel::Receiver, packet_size: u32, @@ -245,11 +246,11 @@ impl EventQueue { /// Adds a given queue to the underlying vector and populates the queue with RecvBuffers. /// /// Queues are all populated according to Virtio specification v1.1. - 5.1.6.3.1 - fn add(&mut self, mut vq: Box) { + fn add(&mut self, mut vq: VirtQueue) { const BUFF_PER_PACKET: u16 = 2; let num_packets: u16 = u16::from(vq.size()) / BUFF_PER_PACKET; fill_queue( - vq.as_mut(), + &mut vq, num_packets, self.packet_size, self.poll_sender.clone(), @@ -398,7 +399,7 @@ impl VirtioVsockDriver { } // create the queues and tell device about them - self.recv_vq.add(Box::new( + self.recv_vq.add(VirtQueue::Split( SplitVq::new( &mut self.com_cfg, &self.notif_cfg, @@ -411,7 +412,7 @@ impl VirtioVsockDriver { // Interrupt for receiving packets is wanted self.recv_vq.enable_notifs(); - self.send_vq.add(Box::new( + self.send_vq.add(VirtQueue::Split( SplitVq::new( &mut self.com_cfg, &self.notif_cfg, @@ -425,7 +426,7 @@ impl VirtioVsockDriver { self.send_vq.disable_notifs(); // create the queues and tell device about them - self.event_vq.add(Box::new( + self.event_vq.add(VirtQueue::Split( SplitVq::new( &mut self.com_cfg, &self.notif_cfg,