diff --git a/inc/coro/mutex.hpp b/inc/coro/mutex.hpp index 6e5646e..de27f52 100644 --- a/inc/coro/mutex.hpp +++ b/inc/coro/mutex.hpp @@ -61,11 +61,7 @@ class scoped_lock class mutex { public: - explicit mutex() noexcept - : m_unlocked_value(&m_state), // This address is guaranteed to be unique and un-used elsewhere. - m_state(const_cast(m_unlocked_value)) - { - } + explicit mutex() noexcept : m_state(const_cast(unlocked_value())) {} ~mutex() = default; mutex(const mutex&) = delete; @@ -110,18 +106,18 @@ class mutex private: friend class lock_operation; - /// Inactive value, this cannot be nullptr since we want nullptr to signify that the mutex - /// is locked but there are zero waiters, this makes it easy to CAS new waiters into the - /// m_state linked list. - const void* m_unlocked_value; - - /// unlocked -> state == m_unlocked_value + /// unlocked -> state == unlocked_value() /// locked but empty waiter list == nullptr /// locked with waiters == lock_operation* std::atomic m_state; /// A list of grabbed internal waiters that are only accessed by the unlock()'er. lock_operation* m_internal_waiters{nullptr}; + + /// Inactive value, this cannot be nullptr since we want nullptr to signify that the mutex + /// is locked but there are zero waiters, this makes it easy to CAS new waiters into the + /// m_state linked list. + auto unlocked_value() const noexcept -> const void* { return &m_state; } }; } // namespace coro diff --git a/src/mutex.cpp b/src/mutex.cpp index 1385b8d..3d8563f 100644 --- a/src/mutex.cpp +++ b/src/mutex.cpp @@ -13,6 +13,7 @@ auto scoped_lock::unlock() -> void { if (m_mutex != nullptr) { + std::atomic_thread_fence(std::memory_order::release); m_mutex->unlock(); // Only allow a scoped lock to unlock the mutex a single time. m_mutex = nullptr; @@ -35,9 +36,10 @@ auto mutex::lock_operation::await_suspend(std::coroutine_handle<> awaiting_corou void* current = m_mutex.m_state.load(std::memory_order::acquire); void* new_value; + const void* unlocked_value = m_mutex.unlocked_value(); do { - if (current == m_mutex.m_unlocked_value) + if (current == unlocked_value) { // If the current value is 'unlocked' then attempt to lock it. new_value = nullptr; @@ -52,8 +54,9 @@ auto mutex::lock_operation::await_suspend(std::coroutine_handle<> awaiting_corou } while (!m_mutex.m_state.compare_exchange_weak(current, new_value, std::memory_order::acq_rel)); // Don't suspend if the state went from unlocked -> locked with zero waiters. - if (current == m_mutex.m_unlocked_value) + if (current == unlocked_value) { + std::atomic_thread_fence(std::memory_order::acquire); return false; } @@ -63,7 +66,7 @@ auto mutex::lock_operation::await_suspend(std::coroutine_handle<> awaiting_corou auto mutex::try_lock() -> bool { - void* expected = const_cast(m_unlocked_value); + void* expected = const_cast(unlocked_value()); return m_state.compare_exchange_strong(expected, nullptr, std::memory_order::acq_rel, std::memory_order::relaxed); } @@ -78,7 +81,7 @@ auto mutex::unlock() -> void // mutex as unlocked. if (m_state.compare_exchange_strong( current, - const_cast(m_unlocked_value), + const_cast(unlocked_value()), std::memory_order::release, std::memory_order::relaxed)) {