diff --git a/ttnn/cpp/ttnn/tensor/storage.hpp b/ttnn/cpp/ttnn/tensor/storage.hpp index 6436916528d..4161a3bd9dd 100644 --- a/ttnn/cpp/ttnn/tensor/storage.hpp +++ b/ttnn/cpp/ttnn/tensor/storage.hpp @@ -328,18 +328,19 @@ struct MultiDeviceStorage { auto buffer_it = buffers.find(device->id()); TT_FATAL(buffer_it != buffers.end(), "Buffer not found for device {}", device->id()); TT_ASSERT( - buffer_it->device() == device, + buffer_it->second->device() == device, "Mismatch between device derived from buffer and device derived from MultiDeviceStorage."); return buffer_it->second; } inline std::shared_ptr& get_buffer_for_device(IDevice* device) { std::lock_guard lock(buffer_mtx); - TT_FATAL(buffers.find(device->id()) != buffers.end(), "Buffer not found for device {}", device->id()); + auto buffer_it = buffers.find(device->id()); + TT_FATAL(buffer_it != buffers.end(), "Buffer not found for device {}", device->id()); TT_ASSERT( - buffers.at(device->id())->device() == device, + buffer_it->second->device() == device, "Mismatch between device derived from buffer and device derived from MultiDeviceStorage."); - return buffers.at(device->id()); + return buffer_it->second; } inline std::shared_ptr get_buffer_for_device_id(uint32_t device_id) const { @@ -349,8 +350,9 @@ struct MultiDeviceStorage { inline TensorSpec get_tensor_spec_for_device(IDevice* device) const { std::lock_guard lock(shape_mtx); - TT_FATAL(specs.find(device->id()) != specs.end(), "Shape not found for device {}", device->id()); - return specs.at(device->id()); + auto spec_it = specs.find(device->id()); + TT_FATAL(spec_it != specs.end(), "Shape not found for device {}", device->id()); + return spec_it->second; } inline uint32_t num_buffers() const {