Skip to content

Commit

Permalink
Tweaks to promises (#364)
Browse files Browse the repository at this point in the history
Refactor promises a bit so that they hide more internal details and will later support boxed callbacks
  • Loading branch information
jadamcrain authored May 23, 2024
1 parent 509e381 commit 62e35d7
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 66 deletions.
4 changes: 2 additions & 2 deletions dnp3/src/master/association.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ impl TaskStates {
if self.time_sync.is_pending() {
if let Some(procedure) = config.auto_time_sync {
return self.time_sync.create_next_task(|| {
TimeSync(TimeSyncTask::get_procedure(procedure, Promise::None)).wrap()
TimeSync(TimeSyncTask::get_procedure(procedure, None)).wrap()
});
}
}
Expand Down Expand Up @@ -680,7 +680,7 @@ impl Association {
None => Next::None,
Some(next) => {
if now >= next {
Next::Now(Task::LinkStatus(Promise::None))
Next::Now(Task::LinkStatus(Promise::null()))
} else {
Next::NotBefore(next)
}
Expand Down
56 changes: 25 additions & 31 deletions dnp3/src/master/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ impl MasterChannel {

/// Get the current decoding level used by this master
pub async fn get_decode_level(&mut self) -> Result<DecodeLevel, Shutdown> {
let (tx, rx) = tokio::sync::oneshot::channel::<Result<DecodeLevel, Shutdown>>();
self.send_master_message(MasterMsg::GetDecodeLevel(Promise::OneShot(tx)))
let (promise, rx) = Promise::one_shot();
self.send_master_message(MasterMsg::GetDecodeLevel(promise))
.await?;
rx.await?
}
Expand Down Expand Up @@ -161,7 +161,7 @@ impl MasterChannel {
) -> Result<AssociationHandle, AssociationError> {
self.assert_channel_type(MasterChannelType::Stream)?;

let (tx, rx) = tokio::sync::oneshot::channel::<Result<(), AssociationError>>();
let (promise, rx) = Promise::one_shot();
let addr = FragmentAddr {
link: address,
phys: PhysAddr::None,
Expand All @@ -172,7 +172,7 @@ impl MasterChannel {
read_handler,
assoc_handler,
assoc_information,
Promise::OneShot(tx),
promise,
))
.await?;
rx.await?
Expand All @@ -196,7 +196,7 @@ impl MasterChannel {
) -> Result<AssociationHandle, AssociationError> {
self.assert_channel_type(MasterChannelType::Udp)?;

let (tx, rx) = tokio::sync::oneshot::channel::<Result<(), AssociationError>>();
let (promise, rx) = Promise::one_shot();
let addr = FragmentAddr {
link: address,
phys: PhysAddr::Udp(destination),
Expand All @@ -207,7 +207,7 @@ impl MasterChannel {
read_handler,
assoc_handler,
assoc_information,
Promise::OneShot(tx),
promise,
))
.await?;
rx.await?
Expand Down Expand Up @@ -266,14 +266,9 @@ impl AssociationHandle {
request: ReadRequest,
period: Duration,
) -> Result<PollHandle, PollError> {
let (tx, rx) = tokio::sync::oneshot::channel::<Result<PollHandle, PollError>>();
self.send_poll_message(PollMsg::AddPoll(
self.clone(),
request,
period,
Promise::OneShot(tx),
))
.await?;
let (promise, rx) = Promise::one_shot();
self.send_poll_message(PollMsg::AddPoll(self.clone(), request, period, promise))
.await?;
rx.await?
}

Expand All @@ -289,8 +284,8 @@ impl AssociationHandle {
///
/// If successful, the [ReadHandler](ReadHandler) will process the received measurement data
pub async fn read(&mut self, request: ReadRequest) -> Result<(), TaskError> {
let (tx, rx) = tokio::sync::oneshot::channel::<Result<(), TaskError>>();
let task = SingleReadTask::new(request, Promise::OneShot(tx));
let (promise, rx) = Promise::one_shot();
let task = SingleReadTask::new(request, promise);
self.send_task(task).await?;
rx.await?
}
Expand All @@ -305,8 +300,8 @@ impl AssociationHandle {
function: FunctionCode,
headers: Headers,
) -> Result<(), WriteError> {
let (tx, rx) = tokio::sync::oneshot::channel::<Result<(), WriteError>>();
let task = EmptyResponseTask::new(function, headers, Promise::OneShot(tx));
let (promise, rx) = Promise::one_shot();
let task = EmptyResponseTask::new(function, headers, promise);
self.send_task(task).await?;
rx.await?
}
Expand All @@ -319,8 +314,8 @@ impl AssociationHandle {
request: ReadRequest,
handler: Box<dyn ReadHandler>,
) -> Result<(), TaskError> {
let (tx, rx) = tokio::sync::oneshot::channel::<Result<(), TaskError>>();
let task = SingleReadTask::new_with_custom_handler(request, handler, Promise::OneShot(tx));
let (promise, rx) = Promise::one_shot();
let task = SingleReadTask::new_with_custom_handler(request, handler, promise);
self.send_task(task).await?;
rx.await?
}
Expand All @@ -333,8 +328,8 @@ impl AssociationHandle {
mode: CommandMode,
headers: CommandHeaders,
) -> Result<(), CommandError> {
let (tx, rx) = tokio::sync::oneshot::channel::<Result<(), CommandError>>();
let task = CommandTask::from_mode(mode, headers, Promise::OneShot(tx));
let (promise, rx) = Promise::one_shot();
let task = CommandTask::from_mode(mode, headers, promise);
self.send_task(task).await?;
rx.await?
}
Expand All @@ -354,8 +349,8 @@ impl AssociationHandle {
}

async fn restart(&mut self, restart_type: RestartType) -> Result<Duration, TaskError> {
let (tx, rx) = tokio::sync::oneshot::channel::<Result<Duration, TaskError>>();
let task = RestartTask::new(restart_type, Promise::OneShot(tx));
let (promise, rx) = Promise::one_shot();
let task = RestartTask::new(restart_type, promise);
self.send_task(task).await?;
rx.await?
}
Expand All @@ -365,8 +360,8 @@ impl AssociationHandle {
&mut self,
procedure: TimeSyncProcedure,
) -> Result<(), TimeSyncError> {
let (tx, rx) = tokio::sync::oneshot::channel::<Result<(), TimeSyncError>>();
let task = TimeSyncTask::get_procedure(procedure, Promise::OneShot(tx));
let (promise, rx) = Promise::one_shot();
let task = TimeSyncTask::get_procedure(procedure, Some(promise));
self.send_task(task).await?;
rx.await?
}
Expand All @@ -376,8 +371,8 @@ impl AssociationHandle {
&mut self,
headers: Vec<DeadBandHeader>,
) -> Result<(), WriteError> {
let (tx, rx) = tokio::sync::oneshot::channel::<Result<(), WriteError>>();
let task = WriteDeadBandsTask::new(headers, Promise::OneShot(tx));
let (promise, rx) = Promise::one_shot();
let task = WriteDeadBandsTask::new(headers, promise);
self.send_task(task).await?;
rx.await?
}
Expand All @@ -390,9 +385,8 @@ impl AssociationHandle {
/// If a [`TaskError::UnexpectedResponseHeaders`] is returned, the link might be alive
/// but it didn't answer with the expected `LINK_STATUS`.
pub async fn check_link_status(&mut self) -> Result<(), TaskError> {
let (tx, rx) = tokio::sync::oneshot::channel::<Result<(), TaskError>>();
self.send_task(Task::LinkStatus(Promise::OneShot(tx)))
.await?;
let (promise, rx) = Promise::one_shot();
self.send_task(Task::LinkStatus(promise)).await?;
rx.await?
}

Expand Down
48 changes: 39 additions & 9 deletions dnp3/src/master/promise.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,54 @@
pub(crate) type CallbackType<T> = Box<dyn FnOnce(T) + Send + Sync + 'static>;

/// A generic callback type that must be invoked once and only once.
/// The user can select to implement it using FnOnce or a
/// one-shot reply channel
pub(crate) enum Promise<T> {
/// nothing happens when the promise is completed
None,
enum Inner<T> {
/// one-shot reply channel is consumed when the promise is completed
OneShot(tokio::sync::oneshot::Sender<T>),
/// Boxed FnOnce
#[allow(dead_code)]
CallBack(CallbackType<T>, T),
}

pub(crate) struct Promise<T> {
inner: Option<Inner<T>>,
}

impl<T> Promise<T> {
pub(crate) fn null() -> Self {
Self { inner: None }
}

fn new(inner: Inner<T>) -> Self {
Self { inner: Some(inner) }
}

pub(crate) fn one_shot() -> (Self, tokio::sync::oneshot::Receiver<T>) {
let (tx, rx) = tokio::sync::oneshot::channel();
(Self::OneShot(tx), rx)
(Self::new(Inner::OneShot(tx)), rx)
}

pub(crate) fn complete(mut self, value: T) {
if let Some(x) = self.inner.take() {
match x {
Inner::OneShot(s) => {
s.send(value).ok();
}
Inner::CallBack(cb, _) => cb(value),
}
}
}
}

pub(crate) fn complete(self, value: T) {
match self {
Promise::None => {}
Promise::OneShot(s) => {
s.send(value).ok();
impl<T> Drop for Promise<T> {
fn drop(&mut self) {
if let Some(x) = self.inner.take() {
match x {
Inner::OneShot(_) => {}
Inner::CallBack(cb, default) => {
cb(default);
}
}
}
}
Expand Down
16 changes: 4 additions & 12 deletions dnp3/src/master/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,19 +318,11 @@ impl MasterSession {
res.map(|_| ())
}
Task::LinkStatus(promise) => {
match self
let res = self
.run_link_status_task(io, task.dest, writer, reader)
.await
{
Ok(result) => {
promise.complete(Ok(result));
Ok(())
}
Err(err) => {
promise.complete(Err(err));
Err(err)
}
}
.await;
promise.complete(res);
res
}
};

Expand Down
23 changes: 11 additions & 12 deletions dnp3/src/master/tasks/time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ enum State {

pub(crate) struct TimeSyncTask {
state: State,
promise: Promise<Result<(), TimeSyncError>>,
promise: Option<Promise<Result<(), TimeSyncError>>>,
}

impl From<TimeSyncTask> for Task {
Expand All @@ -42,7 +42,7 @@ impl TimeSyncProcedure {
}

impl TimeSyncTask {
fn new(state: State, promise: Promise<Result<(), TimeSyncError>>) -> Self {
fn new(state: State, promise: Option<Promise<Result<(), TimeSyncError>>>) -> Self {
Self { state, promise }
}

Expand All @@ -52,7 +52,7 @@ impl TimeSyncTask {

pub(crate) fn get_procedure(
procedure: TimeSyncProcedure,
promise: Promise<Result<(), TimeSyncError>>,
promise: Option<Promise<Result<(), TimeSyncError>>>,
) -> Self {
Self::new(procedure.get_start_state(), promise)
}
Expand Down Expand Up @@ -110,12 +110,12 @@ impl TimeSyncTask {

pub(crate) fn on_task_error(self, association: Option<&mut Association>, err: TaskError) {
match self.promise {
Promise::None => {
None => {
if let Some(association) = association {
association.on_time_sync_failure(err.into());
}
}
_ => self.promise.complete(Err(err.into())),
Some(x) => x.complete(Err(err.into())),
}
}

Expand Down Expand Up @@ -292,15 +292,15 @@ impl TimeSyncTask {

fn report_success(self, association: &mut Association) {
match self.promise {
Promise::None => association.on_time_sync_success(),
_ => self.promise.complete(Ok(())),
None => association.on_time_sync_success(),
Some(x) => x.complete(Ok(())),
}
}

fn report_error(self, association: &mut Association, error: TimeSyncError) {
match self.promise {
Promise::None => association.on_time_sync_failure(error),
_ => self.promise.complete(Err(error)),
None => association.on_time_sync_failure(error),
Some(x) => x.complete(Err(error)),
}
}
}
Expand Down Expand Up @@ -395,9 +395,8 @@ mod tests {
handler(system_time),
Box::new(NullAssociationInformation),
);
let (tx, rx) = tokio::sync::oneshot::channel();
let task =
NonReadTask::TimeSync(TimeSyncTask::get_procedure(procedure, Promise::OneShot(tx)));
let (promise, rx) = Promise::one_shot();
let task = NonReadTask::TimeSync(TimeSyncTask::get_procedure(procedure, Some(promise)));

(task, system_time, association, rx)
}
Expand Down

0 comments on commit 62e35d7

Please sign in to comment.