Skip to content

Commit

Permalink
[shortfin] Implement async alloc/dealloc of buffers. (#507)
Browse files Browse the repository at this point in the history
* Device allocations are now async, queue ordered alloc/dealloc.
* Program invocations asynchronously deallocate function call results
whenever they can. When a result cannot be async deallocated, a small tracy
zone `SyncImportTimelineResource` is emitted for that result.
* Adds `ProgramInvocation.assume_no_alias` instance boolean to disable
the assumption which allows async deallocation to work.
* Adds global `ProgramInvocation.global_assume_no_alias` property to control
this behavior process-wide.

This is a very fiddly optimization which requires (esp in multi-device
cases) a number of things to line up. Tested on amdgpu and CPU with a
number of sample workloads (with logging enabled and visually
confirmed).

See #980 for detailed analysis and further work required.
  • Loading branch information
stellaraccident authored Feb 19, 2025
1 parent 888a98a commit b299af3
Show file tree
Hide file tree
Showing 14 changed files with 569 additions and 134 deletions.
176 changes: 115 additions & 61 deletions shortfin/python/lib_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -316,32 +316,97 @@ local::ProgramInvocation::Future PyFunctionCall(
return local::ProgramInvocation::Invoke(std::move(inv));
}

py::object PyRehydrateRef(local::ProgramInvocation *inv,
iree::vm_opaque_ref ref) {
auto type = ref.get()->type;
// Note that these accessors are dangerous as they assert/abort if
// process-wide registration is not done properly. We assume here that
// since we got a ref out that the basics are set up soundly, but if actually
// doing this on user/dynamic types, we would want to be more defensive.
// TODO: Don't just do a linear scan if we have more than a couple.
// TODO: Find a reliable way to statically cache the type id.
if (local::ProgramInvocationMarshalableFactory::invocation_marshalable_type<
array::device_array>() == type) {
// device_array
return py::cast(local::ProgramInvocationMarshalableFactory::
CreateFromInvocationResultRef<array::device_array>(
inv, std::move(ref)));
} else if (local::ProgramInvocationMarshalableFactory::
invocation_marshalable_type<array::storage>() == type) {
// storage
return py::cast(
local::ProgramInvocationMarshalableFactory::
CreateFromInvocationResultRef<array::storage>(inv, std::move(ref)));
// Wraps a ProgramInvocation::Ptr representing a completed (awaited) invocation.
// Holds some additional accounting for marshaling results back to Python.
class PyProgramInvocation {
public:
// Wraps the given invocation pointer by moving it into inv_.
PyProgramInvocation(local::ProgramInvocation::Ptr inv)
: inv_(std::move(inv)) {}
// Copying is disabled; instances can only be moved.
PyProgramInvocation(const PyProgramInvocation &) = delete;
// Move constructor. Transfers the wrapped invocation, any cached result tuple
// and the failure latch from `other`. The per-instance `assume_no_alias` flag
// is carried over as well so that a moved-to object does not silently revert
// to the default value. Initializers are listed in member declaration order.
PyProgramInvocation(PyProgramInvocation &&other)
    : assume_no_alias(other.assume_no_alias),
      inv_(std::move(other.inv_)),
      cached_results_(std::move(other.cached_results_)),
      results_failure_(other.results_failure_) {}

// Fields that can be bound.
bool assume_no_alias = true;
static std::optional<bool> global_assume_no_alias;

// Verifies that this wrapper still holds an invocation.
// Throws std::invalid_argument when inv_ is empty.
void CheckValid() {
  if (inv_) return;
  throw std::invalid_argument("Deallocated invocation");
}
throw std::invalid_argument(
fmt::format("Cannot marshal ref type {} to Python",
to_string_view(iree_vm_ref_type_name(type))));
}
// Returns a mutable reference to the wrapped invocation pointer. No validity
// check is performed here; callers use CheckValid() first when they need one.
local::ProgramInvocation::Ptr &inv() { return inv_; }

// Returns the invocation results marshaled to Python as a tuple.
//
// The tuple is built on first call and cached; later calls return the cached
// object. If a prior marshaling attempt threw part-way through, the failure
// is latched and every subsequent call throws std::logic_error instead of
// retrying.
//
// Throws:
//   std::logic_error      - a previous marshaling attempt failed, or a result
//                           slot did not yield a ref.
//   std::invalid_argument - the invocation is no longer held (CheckValid) or
//                           a ref type cannot be marshaled (RehydrateRef).
py::object results() {
  if (results_failure_) {
    throw std::logic_error("Prior attempt to marshal IREE results failed");
  }
  if (cached_results_) {
    return cached_results_;
  }

  // Cache results.
  CheckValid();
  // Latch the failure flag up front; it is only cleared once every result has
  // been marshaled, so any exception below leaves it set.
  results_failure_ = true;

  // The process-wide override, when set, takes precedence over the
  // per-instance flag.
  local::CoarseInvocationTimelineImporter::Options options;
  options.assume_no_alias = assume_no_alias;
  if (global_assume_no_alias) {
    options.assume_no_alias = *global_assume_no_alias;
  }
  local::CoarseInvocationTimelineImporter timeline_importer(inv().get(),
                                                            options);
  size_t size = inv_->results_size();
  py::object tp = py::steal(PyTuple_New(size));
  for (size_t i = 0; i < size; ++i) {
    iree::vm_opaque_ref ref = inv_->result_ref(i);
    if (!ref) {
      // Throw by value: `throw new ...` would throw a pointer, which is not
      // matched by handlers for std::exception and leaks the allocation.
      throw std::logic_error("Program returned unsupported Python type");
    }
    py::object item = RehydrateRef(std::move(ref), &timeline_importer);
    // PyTuple_SET_ITEM steals the reference released from `item`.
    PyTuple_SET_ITEM(tp.ptr(), i, item.release().ptr());
  }

  cached_results_ = std::move(tp);
  results_failure_ = false;
  return cached_results_;
}

private:
// Converts a single result ref into the matching Python-visible object.
// Only refs whose type matches array::device_array or array::storage are
// handled; any other ref type raises std::invalid_argument.
py::object RehydrateRef(
    iree::vm_opaque_ref ref,
    local::CoarseInvocationTimelineImporter *timeline_importer) {
  using Factory = local::ProgramInvocationMarshalableFactory;
  auto type = ref.get()->type;

  // Caution: the type accessors used below assert/abort when process-wide
  // registration has not been done properly. Having obtained a ref at all is
  // taken as evidence that the basics are in place; marshaling user/dynamic
  // types would call for more defensive checks.
  // TODO: Don't just do a linear scan if we have more than a couple.
  // TODO: Find a reliable way to statically cache the type id.

  // device_array
  if (Factory::invocation_marshalable_type<array::device_array>() == type) {
    return py::cast(
        Factory::CreateFromInvocationResultRef<array::device_array>(
            inv().get(), timeline_importer, std::move(ref)));
  }

  // storage
  if (Factory::invocation_marshalable_type<array::storage>() == type) {
    return py::cast(Factory::CreateFromInvocationResultRef<array::storage>(
        inv().get(), timeline_importer, std::move(ref)));
  }

  // No known marshalable type matched.
  throw std::invalid_argument(
      fmt::format("Cannot marshal ref type {} to Python",
                  to_string_view(iree_vm_ref_type_name(type))));
}

local::ProgramInvocation::Ptr inv_;
py::object cached_results_;
bool results_failure_ = false;
};
std::optional<bool> PyProgramInvocation::global_assume_no_alias;

py::object RunInForeground(std::shared_ptr<Refs> refs, local::System &self,
py::object coro) {
Expand Down Expand Up @@ -743,56 +808,45 @@ void BindLocal(py::module_ &m) {
return local::ProgramModule::ParameterProvider(system, c_params);
},
py::arg("system"), py::arg("params"));
py::class_<local::ProgramInvocation::Ptr>(m, "ProgramInvocation")
py::class_<PyProgramInvocation>(m, "ProgramInvocation")
.def_rw("assume_no_alias", &PyProgramInvocation::assume_no_alias,
"Assumes that no results alias inputs or other buffers")
.def_rw_static(
"global_assume_no_alias",
&PyProgramInvocation::global_assume_no_alias,
"Globally changes the assume_no_alias flag for all invocations")
.def("invoke",
[](local::ProgramInvocation::Ptr &self) {
if (!self) throw std::invalid_argument("Deallocated invocation");
return local::ProgramInvocation::Invoke(std::move(self));
[](PyProgramInvocation &self) {
self.CheckValid();
return local::ProgramInvocation::Invoke(std::move(self.inv()));
})
.def("add_arg",
[](local::ProgramInvocation::Ptr &self, py::handle arg) {
if (!self) throw std::invalid_argument("Deallocated invocation");
py::capsule inv_capsule(self.get());
[](PyProgramInvocation &self, py::handle arg) {
self.CheckValid();
py::capsule inv_capsule(&self.inv());
PyAddProgramInvocationArg(inv_capsule, arg);
})
.def("__iter__",
[](local::ProgramInvocation::Ptr &self) {
if (!self) throw std::invalid_argument("Deallocated invocation");
size_t size = self->results_size();
py::object tp = py::steal(PyTuple_New(size));
for (size_t i = 0; i < size; ++i) {
iree::vm_opaque_ref ref = self->result_ref(i);
if (!ref) {
throw new std::logic_error(
"Program returned unsupported Python type");
}
py::object item = PyRehydrateRef(self.get(), std::move(ref));
PyTuple_SET_ITEM(tp.ptr(), i, item.release().ptr());
}
return tp.attr("__iter__")();
[](PyProgramInvocation &self) {
return self.results().attr("__iter__")();
})
.def(
"__len__",
[](local::ProgramInvocation::Ptr &self) {
if (!self) throw std::invalid_argument("Deallocated invocation");
return self->results_size();
[](PyProgramInvocation &self) {
self.CheckValid();
return self.inv()->results_size();
},
"The number of results in this invocation")
.def(
"__getitem__",
[](local::ProgramInvocation::Ptr &self, iree_host_size_t i) {
if (!self) throw std::invalid_argument("Deallocated invocation");
iree::vm_opaque_ref ref = self->result_ref(i);
if (!ref) {
throw new std::logic_error(
"Program returned unsupported Python type");
}
return PyRehydrateRef(self.get(), std::move(ref));
[](PyProgramInvocation &self, iree_host_size_t i) {
self.CheckValid();
return self.results().attr("__getitem__")(i);
},
"Gets the i'th result")
.def("__repr__", [](local::ProgramInvocation::Ptr &self) {
if (!self) return std::string("ProgramInvocation(INVALID)");
return self->to_s();
.def("__repr__", [](PyProgramInvocation &self) {
if (!self.inv()) return std::string("ProgramInvocation(INVALID)");
return self.inv()->to_s();
});

py::class_<local::BaseProgramParameters>(m, "BaseProgramParameters");
Expand Down Expand Up @@ -1207,7 +1261,7 @@ void BindLocal(py::module_ &m) {
// expensive in the C++ API: essentially, ProgramInvocations flow
// through the system precisely one way. As a low level facility, this
// is deemed acceptable.
return py::cast(std::move(result));
return py::cast(PyProgramInvocation(std::move(result)));
});
py::class_<local::MessageFuture, local::Future>(m, "MessageFuture")
.def("result", [](local::MessageFuture &self) {
Expand Down
10 changes: 6 additions & 4 deletions shortfin/src/shortfin/array/array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ void device_array::AddAsInvocationArgument(

iree::vm_opaque_ref ref;
*(&ref) = iree_hal_buffer_view_move_ref(buffer_view);
inv->AddArg(std::move(ref));
inv->AddArg(std::move(ref), storage().timeline_resource_.get());

storage().AddInvocationArgBarrier(inv, barrier);
}
Expand All @@ -119,16 +119,18 @@ iree_vm_ref_type_t device_array::invocation_marshalable_type() {
}

device_array device_array::CreateFromInvocationResultRef(
local::ProgramInvocation *inv, iree::vm_opaque_ref ref) {
local::ProgramInvocation *inv,
local::CoarseInvocationTimelineImporter *timeline_importer,
iree::vm_opaque_ref ref) {
SHORTFIN_TRACE_SCOPE_NAMED("PyDeviceArray::CreateFromInvocationResultRef");
// We don't retain the buffer view in the device array, so just deref it
// vs stealing the ref.
iree_hal_buffer_view_t *bv = iree_hal_buffer_view_deref(*ref.get());
iree::hal_buffer_ptr buffer =
iree::hal_buffer_ptr::borrow_reference(iree_hal_buffer_view_buffer(bv));

auto imported_storage =
storage::ImportInvocationResultStorage(inv, std::move(buffer));
auto imported_storage = storage::ImportInvocationResultStorage(
inv, timeline_importer, std::move(buffer));
std::span<const iree_hal_dim_t> shape(iree_hal_buffer_view_shape_dims(bv),
iree_hal_buffer_view_shape_rank(bv));
return device_array(
Expand Down
4 changes: 3 additions & 1 deletion shortfin/src/shortfin/array/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,9 @@ class SHORTFIN_API device_array
void AddAsInvocationArgument(local::ProgramInvocation *inv,
local::ProgramResourceBarrier barrier) override;
static device_array CreateFromInvocationResultRef(
local::ProgramInvocation *inv, iree::vm_opaque_ref ref);
local::ProgramInvocation *inv,
local::CoarseInvocationTimelineImporter *timeline_importer,
iree::vm_opaque_ref ref);
static iree_vm_ref_type_t invocation_marshalable_type();
friend class shortfin::local::ProgramInvocationMarshalableFactory;
};
Expand Down
Loading

0 comments on commit b299af3

Please sign in to comment.