Skip to content

Commit

Permalink
Squashed commit of the following:
Browse files Browse the repository at this point in the history
commit 3902491d7789cd7a6f0ef2bb1572abefff1073cf
Author: Stella Laurenzo <stellaraccident@gmail.com>
Date:   Tue Feb 18 20:56:02 2025 -0800

    Flip flag off.

commit 1880f4962b14071c7897a1c18770c246243784b7
Author: Stella Laurenzo <stellaraccident@gmail.com>
Date:   Tue Feb 18 20:48:51 2025 -0800

    Disable logging.

commit b29b3c5
Author: Stella Laurenzo <stellaraccident@gmail.com>
Date:   Tue Feb 18 20:30:40 2025 -0800

    Allow no alias override.

commit e5514c3
Author: Stella Laurenzo <stellaraccident@gmail.com>
Date:   Tue Feb 18 19:04:14 2025 -0800

    Do whole inv result tracking.

commit 878ed0c
Author: Stella Laurenzo <stellaraccident@gmail.com>
Date:   Tue Feb 18 12:41:43 2025 -0800

    Fix header

commit 4f2f6b8
Merge: c527839 8b806bf
Author: Stella Laurenzo <stellaraccident@gmail.com>
Date:   Tue Feb 18 12:35:35 2025 -0800

    Merge branch 'main' of github.com:nod-ai/sharktank into shortfin_async_alloc

commit c527839
Merge: e3f1eac 51cf2f4
Author: Stella Laurenzo <stellaraccident@gmail.com>
Date:   Wed Nov 13 17:41:51 2024 -0800

    Merge branch 'main' into shortfin_async_alloc

commit e3f1eac
Author: Stella Laurenzo <stellaraccident@gmail.com>
Date:   Wed Nov 13 17:09:43 2024 -0800

    [shortfin] Implement async alloc/dealloc of buffers.

    This has been a TODO since day one. For device buffers, this now properly stream-orders the alloc/dealloc.

Fix npe
  • Loading branch information
stellaraccident committed Feb 19, 2025
1 parent 8b806bf commit 172f914
Show file tree
Hide file tree
Showing 14 changed files with 540 additions and 132 deletions.
176 changes: 115 additions & 61 deletions shortfin/python/lib_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -316,32 +316,97 @@ local::ProgramInvocation::Future PyFunctionCall(
return local::ProgramInvocation::Invoke(std::move(inv));
}

py::object PyRehydrateRef(local::ProgramInvocation *inv,
iree::vm_opaque_ref ref) {
auto type = ref.get()->type;
// Note that these accessors are dangerous as they assert/abort if
// process-wide registration is not done properly. We assume here that
// since we got a ref out that the basics are set up soundly, but if actually
// doing this on user/dynamic types, we would want to be more defensive.
// TODO: Don't just do a linear scan if we have more than a couple.
// TODO: Find a reliable way to statically cache the type id.
if (local::ProgramInvocationMarshalableFactory::invocation_marshalable_type<
array::device_array>() == type) {
// device_array
return py::cast(local::ProgramInvocationMarshalableFactory::
CreateFromInvocationResultRef<array::device_array>(
inv, std::move(ref)));
} else if (local::ProgramInvocationMarshalableFactory::
invocation_marshalable_type<array::storage>() == type) {
// storage
return py::cast(
local::ProgramInvocationMarshalableFactory::
CreateFromInvocationResultRef<array::storage>(inv, std::move(ref)));
// Wraps a ProgramInvocation::Ptr representing a completed (awaited) invocation.
// Holds some additional accounting for marshaling results back to Python.
class PyProgramInvocation {
public:
// Wraps `inv` by moving it into this object.
PyProgramInvocation(local::ProgramInvocation::Ptr inv)
    : inv_(std::move(inv)) {}
// Copying is disabled; instances can only be moved.
PyProgramInvocation(const PyProgramInvocation &) = delete;
// Move construction transfers the invocation pointer, any cached results and
// the failure flag. The bindable `assume_no_alias` setting is carried over as
// well so that a moved wrapper does not silently fall back to the default.
// Initializers are listed in member declaration order.
PyProgramInvocation(PyProgramInvocation &&other)
    : assume_no_alias(other.assume_no_alias),
      inv_(std::move(other.inv_)),
      cached_results_(std::move(other.cached_results_)),
      results_failure_(other.results_failure_) {}

// Fields that can be bound.
// Per-invocation setting; results() copies it into the timeline importer
// options when marshaling results.
bool assume_no_alias = false;
// Process-wide override: when it holds a value, results() uses that value
// instead of the per-invocation `assume_no_alias`.
static std::optional<bool> global_assume_no_alias;

// Guard used before touching the wrapped invocation: raises
// std::invalid_argument when the invocation pointer is null.
void CheckValid() {
  if (inv_) return;
  throw std::invalid_argument("Deallocated invocation");
}
throw std::invalid_argument(
fmt::format("Cannot marshal ref type {} to Python",
to_string_view(iree_vm_ref_type_name(type))));
}
// Returns a mutable reference to the wrapped invocation pointer. The pointer
// may be null (CheckValid() throws in that case); the "invoke" binding hands
// it to ProgramInvocation::Invoke via std::move.
local::ProgramInvocation::Ptr &inv() { return inv_; }

// Returns the invocation results as a Python tuple. The tuple is built on the
// first successful call and the cached object is returned afterwards.
//
// Throws:
//   std::logic_error if a previous marshaling attempt failed, or if a result
//     slot holds a null ref.
//   std::invalid_argument if the wrapped invocation pointer is null.
//   Anything RehydrateRef throws for an unsupported ref type.
py::object results() {
  if (results_failure_) {
    throw std::logic_error("Prior attempt to marshal IREE results failed");
  }
  if (cached_results_) {
    return cached_results_;
  }

  // Cache results.
  CheckValid();
  // Mark failure up front and only clear it once every result has been
  // marshaled, so that an exception part-way through leaves the flag set and
  // later calls fail fast rather than retrying.
  results_failure_ = true;

  local::CoarseInvocationTimelineImporter::Options options;
  options.assume_no_alias = assume_no_alias;
  if (global_assume_no_alias) {
    // The global override, when set, wins over the per-invocation value.
    options.assume_no_alias = *global_assume_no_alias;
  }
  local::CoarseInvocationTimelineImporter timeline_importer(inv().get(),
                                                            options);
  size_t size = inv_->results_size();
  py::object tp = py::steal(PyTuple_New(size));
  for (size_t i = 0; i < size; ++i) {
    iree::vm_opaque_ref ref = inv_->result_ref(i);
    if (!ref) {
      // Throw by value: `throw new ...` would throw a heap-allocated pointer,
      // which leaks and is not caught by handlers for std::exception.
      throw std::logic_error("Program returned unsupported Python type");
    }
    py::object item = RehydrateRef(std::move(ref), &timeline_importer);
    // PyTuple_SET_ITEM steals the reference released from `item`.
    PyTuple_SET_ITEM(tp.ptr(), i, item.release().ptr());
  }

  cached_results_ = std::move(tp);
  results_failure_ = false;
  return cached_results_;
}

private:
// Converts a single invocation result ref into the matching Python object.
// Supported ref types are array::device_array and array::storage; any other
// type raises std::invalid_argument naming the offending ref type.
py::object RehydrateRef(
    iree::vm_opaque_ref ref,
    local::CoarseInvocationTimelineImporter *timeline_importer) {
  using Factory = local::ProgramInvocationMarshalableFactory;
  auto ref_type = ref.get()->type;

  // Caution: the type accessors below assert/abort when process-wide
  // registration has not been performed correctly. Having obtained a ref at
  // all, we take the basic setup to be sound; handling user/dynamic types
  // here would call for a more defensive approach.
  // TODO: Don't just do a linear scan if we have more than a couple.
  // TODO: Find a reliable way to statically cache the type id.

  // device_array
  if (Factory::invocation_marshalable_type<array::device_array>() ==
      ref_type) {
    return py::cast(
        Factory::CreateFromInvocationResultRef<array::device_array>(
            inv().get(), timeline_importer, std::move(ref)));
  }

  // storage
  if (Factory::invocation_marshalable_type<array::storage>() == ref_type) {
    return py::cast(Factory::CreateFromInvocationResultRef<array::storage>(
        inv().get(), timeline_importer, std::move(ref)));
  }

  // No marshaler matched.
  throw std::invalid_argument(
      fmt::format("Cannot marshal ref type {} to Python",
                  to_string_view(iree_vm_ref_type_name(ref_type))));
}

// The wrapped invocation. A null value is reported by CheckValid() as
// "Deallocated invocation".
local::ProgramInvocation::Ptr inv_;
// Tuple of marshaled results, set by the first successful results() call.
py::object cached_results_;
// Set while results() is marshaling and left set if that attempt throws, so
// that subsequent results() calls fail immediately.
bool results_failure_ = false;
};
std::optional<bool> PyProgramInvocation::global_assume_no_alias;

py::object RunInForeground(std::shared_ptr<Refs> refs, local::System &self,
py::object coro) {
Expand Down Expand Up @@ -743,56 +808,45 @@ void BindLocal(py::module_ &m) {
return local::ProgramModule::ParameterProvider(system, c_params);
},
py::arg("system"), py::arg("params"));
py::class_<local::ProgramInvocation::Ptr>(m, "ProgramInvocation")
py::class_<PyProgramInvocation>(m, "ProgramInvocation")
.def_rw("assume_no_alias", &PyProgramInvocation::assume_no_alias,
"Assumes that no results alias inputs or other buffers")
.def_rw_static(
"global_assume_no_alias",
&PyProgramInvocation::global_assume_no_alias,
"Globally changes the assume_no_alias flag for all invocations")
.def("invoke",
[](local::ProgramInvocation::Ptr &self) {
if (!self) throw std::invalid_argument("Deallocated invocation");
return local::ProgramInvocation::Invoke(std::move(self));
[](PyProgramInvocation &self) {
self.CheckValid();
return local::ProgramInvocation::Invoke(std::move(self.inv()));
})
.def("add_arg",
[](local::ProgramInvocation::Ptr &self, py::handle arg) {
if (!self) throw std::invalid_argument("Deallocated invocation");
py::capsule inv_capsule(self.get());
[](PyProgramInvocation &self, py::handle arg) {
self.CheckValid();
py::capsule inv_capsule(&self.inv());
PyAddProgramInvocationArg(inv_capsule, arg);
})
.def("__iter__",
[](local::ProgramInvocation::Ptr &self) {
if (!self) throw std::invalid_argument("Deallocated invocation");
size_t size = self->results_size();
py::object tp = py::steal(PyTuple_New(size));
for (size_t i = 0; i < size; ++i) {
iree::vm_opaque_ref ref = self->result_ref(i);
if (!ref) {
throw new std::logic_error(
"Program returned unsupported Python type");
}
py::object item = PyRehydrateRef(self.get(), std::move(ref));
PyTuple_SET_ITEM(tp.ptr(), i, item.release().ptr());
}
return tp.attr("__iter__")();
[](PyProgramInvocation &self) {
return self.results().attr("__iter__")();
})
.def(
"__len__",
[](local::ProgramInvocation::Ptr &self) {
if (!self) throw std::invalid_argument("Deallocated invocation");
return self->results_size();
[](PyProgramInvocation &self) {
self.CheckValid();
return self.inv()->results_size();
},
"The number of results in this invocation")
.def(
"__getitem__",
[](local::ProgramInvocation::Ptr &self, iree_host_size_t i) {
if (!self) throw std::invalid_argument("Deallocated invocation");
iree::vm_opaque_ref ref = self->result_ref(i);
if (!ref) {
throw new std::logic_error(
"Program returned unsupported Python type");
}
return PyRehydrateRef(self.get(), std::move(ref));
[](PyProgramInvocation &self, iree_host_size_t i) {
self.CheckValid();
return self.results().attr("__getitem__")(i);
},
"Gets the i'th result")
.def("__repr__", [](local::ProgramInvocation::Ptr &self) {
if (!self) return std::string("ProgramInvocation(INVALID)");
return self->to_s();
.def("__repr__", [](PyProgramInvocation &self) {
if (!self.inv()) return std::string("ProgramInvocation(INVALID)");
return self.inv()->to_s();
});

py::class_<local::BaseProgramParameters>(m, "BaseProgramParameters");
Expand Down Expand Up @@ -1207,7 +1261,7 @@ void BindLocal(py::module_ &m) {
// expensive in the C++ API: essentially, ProgramInvocations flow
// through the system precisely one way. As a low level facility, this
// is deemed acceptable.
return py::cast(std::move(result));
return py::cast(PyProgramInvocation(std::move(result)));
});
py::class_<local::MessageFuture, local::Future>(m, "MessageFuture")
.def("result", [](local::MessageFuture &self) {
Expand Down
10 changes: 6 additions & 4 deletions shortfin/src/shortfin/array/array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ void device_array::AddAsInvocationArgument(

iree::vm_opaque_ref ref;
*(&ref) = iree_hal_buffer_view_move_ref(buffer_view);
inv->AddArg(std::move(ref));
inv->AddArg(std::move(ref), storage().timeline_resource_.get());

storage().AddInvocationArgBarrier(inv, barrier);
}
Expand All @@ -119,16 +119,18 @@ iree_vm_ref_type_t device_array::invocation_marshalable_type() {
}

device_array device_array::CreateFromInvocationResultRef(
local::ProgramInvocation *inv, iree::vm_opaque_ref ref) {
local::ProgramInvocation *inv,
local::CoarseInvocationTimelineImporter *timeline_importer,
iree::vm_opaque_ref ref) {
SHORTFIN_TRACE_SCOPE_NAMED("PyDeviceArray::CreateFromInvocationResultRef");
// We don't retain the buffer view in the device array, so just deref it
// vs stealing the ref.
iree_hal_buffer_view_t *bv = iree_hal_buffer_view_deref(*ref.get());
iree::hal_buffer_ptr buffer =
iree::hal_buffer_ptr::borrow_reference(iree_hal_buffer_view_buffer(bv));

auto imported_storage =
storage::ImportInvocationResultStorage(inv, std::move(buffer));
auto imported_storage = storage::ImportInvocationResultStorage(
inv, timeline_importer, std::move(buffer));
std::span<const iree_hal_dim_t> shape(iree_hal_buffer_view_shape_dims(bv),
iree_hal_buffer_view_shape_rank(bv));
return device_array(
Expand Down
4 changes: 3 additions & 1 deletion shortfin/src/shortfin/array/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,9 @@ class SHORTFIN_API device_array
void AddAsInvocationArgument(local::ProgramInvocation *inv,
local::ProgramResourceBarrier barrier) override;
static device_array CreateFromInvocationResultRef(
local::ProgramInvocation *inv, iree::vm_opaque_ref ref);
local::ProgramInvocation *inv,
local::CoarseInvocationTimelineImporter *timeline_importer,
iree::vm_opaque_ref ref);
static iree_vm_ref_type_t invocation_marshalable_type();
friend class shortfin::local::ProgramInvocationMarshalableFactory;
};
Expand Down
Loading

0 comments on commit 172f914

Please sign in to comment.