diff --git a/rust/helpers.c b/rust/helpers.c index 58a194042c86db..96441744030e4b 100644 --- a/rust/helpers.c +++ b/rust/helpers.c @@ -100,6 +100,12 @@ bool rust_helper_refcount_dec_and_test(refcount_t *r) } EXPORT_SYMBOL_GPL(rust_helper_refcount_dec_and_test); +struct task_struct *rust_helper_get_current(void) +{ + return current; +} +EXPORT_SYMBOL_GPL(rust_helper_get_current); + void rust_helper_get_task_struct(struct task_struct *t) { get_task_struct(t); diff --git a/rust/kernel/task.rs b/rust/kernel/task.rs index 8d7a8222990fa2..8b2b56ba9c6d36 100644 --- a/rust/kernel/task.rs +++ b/rust/kernel/task.rs @@ -5,7 +5,7 @@ //! C header: [`include/linux/sched.h`](../../../../include/linux/sched.h). use crate::bindings; -use core::{cell::UnsafeCell, ptr}; +use core::{cell::UnsafeCell, marker::PhantomData, ops::Deref, ptr}; /// Wraps the kernel's `struct task_struct`. /// @@ -13,6 +13,46 @@ use core::{cell::UnsafeCell, ptr}; /// /// Instances of this type are always ref-counted, that is, a call to `get_task_struct` ensures /// that the allocation remains valid at least until the matching call to `put_task_struct`. +/// +/// # Examples +/// +/// The following is an example of getting the PID of the current thread with zero additional cost +/// when compared to the C version: +/// +/// ``` +/// use kernel::task::Task; +/// +/// let pid = Task::current().pid(); +/// ``` +/// +/// Getting the PID of the current process, also zero additional cost: +/// +/// ``` +/// use kernel::task::Task; +/// +/// let pid = Task::current().group_leader().pid(); +/// ``` +/// +/// Getting the current task and storing it in some struct. The reference count is automatically +/// incremented when creating `State` and decremented when it is dropped: +/// +/// ``` +/// use kernel::{task::Task, ARef}; +/// +/// struct State { +/// creator: ARef, +/// index: u32, +/// } +/// +/// impl State { +/// fn new() -> Self { +/// Self { +/// creator: Task::current().into(), +/// index: 0, +/// } +/// } +/// } +/// ``` #[repr(transparent)] pub struct Task(pub(crate) UnsafeCell); @@ -25,6 +65,20 @@ unsafe impl Sync for Task {} type Pid = bindings::pid_t; impl Task { + /// Returns a task reference for the currently executing task/thread. + pub fn current<'a>() -> TaskRef<'a> { + // SAFETY: Just an FFI call with no additional safety requirements. + let ptr = unsafe { bindings::get_current() }; + + TaskRef { + // SAFETY: If the current thread is still running, the current task is valid. Given + // that `TaskRef` is not `Send`, we know it cannot be transferred to another thread + // (where it could potentially outlive the caller). + task: unsafe { &*ptr.cast() }, + _not_send: PhantomData, + } + } + /// Returns the group leader of the given task. pub fn group_leader(&self) -> &Task { // SAFETY: By the type invariant, we know that `self.0` is valid. @@ -69,3 +123,30 @@ unsafe impl crate::types::AlwaysRefCounted for Task { unsafe { bindings::put_task_struct(obj.cast().as_ptr()) } } } + +/// A wrapper for a shared reference to [`Task`] that isn't [`Send`]. +/// +/// We make this explicitly not [`Send`] so that we can use it to represent the current thread +/// without having to increment/decrement the task's reference count. +/// +/// # Invariants +/// +/// The wrapped [`Task`] remains valid for the lifetime of the object. +pub struct TaskRef<'a> { + task: &'a Task, + _not_send: PhantomData<*mut ()>, +} + +impl Deref for TaskRef<'_> { + type Target = Task; + + fn deref(&self) -> &Self::Target { + self.task + } +} + +impl From> for crate::types::ARef { + fn from(t: TaskRef<'_>) -> Self { + t.deref().into() + } +}