diff --git a/Cargo.lock b/Cargo.lock index bee7cbe5f..4311bded8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1181,6 +1181,7 @@ dependencies = [ "pretty_assertions", "proptest", "proptest-derive", + "pyo3", "smol_str", "thiserror 2.0.12", ] @@ -1202,6 +1203,17 @@ dependencies = [ "thiserror 2.0.12", ] +[[package]] +name = "hugr-py" +version = "0.1.0" +dependencies = [ + "bumpalo", + "hugr-model", + "paste", + "pyo3", + "smol_str", +] + [[package]] name = "hyper" version = "1.6.0" @@ -1425,6 +1437,12 @@ dependencies = [ "hashbrown 0.15.2", ] +[[package]] +name = "indoc" +version = "2.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" + [[package]] name = "inkwell" version = "0.5.0" @@ -1647,6 +1665,15 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + [[package]] name = "mime" version = "0.3.17" @@ -1960,6 +1987,12 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "portable-atomic" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" + [[package]] name = "portgraph" version = "0.13.3" @@ -2093,6 +2126,69 @@ dependencies = [ "proptest", ] +[[package]] +name = "pyo3" +version = "0.23.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.23.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.23.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.23.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.23.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn", +] + [[package]] name = "quick-error" version = "1.2.3" @@ -2611,6 +2707,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "target-lexicon" +version = "0.12.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" + [[package]] name = "tempfile" version = "3.17.1" @@ -2846,6 +2948,12 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "unindent" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" + [[package]] name = "unsafe-libyaml" version = "0.2.11" diff --git a/Cargo.toml b/Cargo.toml index 72ae96b9f..8aaeaa5f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ members = [ "hugr-cli", "hugr-model", "hugr-llvm", + "hugr-py", ] default-members = ["hugr", "hugr-core", "hugr-passes", "hugr-cli", "hugr-model"] @@ -84,6 +85,7 @@ pest_derive = "2.7.12" pretty = "0.12.4" pretty_assertions = "1.4.1" zstd = "0.13.2" +pyo3 = "0.23.4" [profile.dev.package] insta.opt-level = 3 diff --git a/hugr-core/src/std_extensions/arithmetic/float_types.rs b/hugr-core/src/std_extensions/arithmetic/float_types.rs index 579d89e6b..3122bf30f 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_types.rs @@ -66,7 +66,7 @@ impl std::ops::Deref for ConstF64 { impl ConstF64 { /// Name of the constructor for creating constant 64bit floats. #[cfg_attr(not(feature = "model_unstable"), allow(dead_code))] - pub(crate) const CTR_NAME: &'static str = "arithmetic.float.const-f64"; + pub(crate) const CTR_NAME: &'static str = "arithmetic.float.const_f64"; /// Create a new [`ConstF64`] pub fn new(value: f64) -> Self { diff --git a/hugr-core/tests/model.rs b/hugr-core/tests/model.rs index 2a35acf98..b5ba7597f 100644 --- a/hugr-core/tests/model.rs +++ b/hugr-core/tests/model.rs @@ -4,7 +4,7 @@ use std::str::FromStr; use hugr::std_extensions::std_reg; use hugr_core::{export::export_hugr, import::import_hugr}; -use hugr_model::v0 as model; +use hugr_model::v0::{self as model}; fn roundtrip(source: &str) -> String { let bump = model::bumpalo::Bump::new(); diff --git a/hugr-core/tests/snapshots/model__roundtrip_const.snap b/hugr-core/tests/snapshots/model__roundtrip_const.snap index 695d5d212..c38ee9a24 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_const.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_const.snap @@ -12,6 +12,8 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cons (import arithmetic.int.const) +(import arithmetic.float.const_f64) + (import core.const.adt) (import arithmetic.int.types.int) @@ -24,8 +26,6 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cons (import core.adt) -(import arithmetic.float.const-f64) - (define-func example.bools (core.fn [] [(core.adt [[] []]) (core.adt [[] []])] (ext)) @@ -69,7 +69,7 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cons (arithmetic.int.const 6 3) (arithmetic.int.const 6 4) (arithmetic.int.const 6 5)]) - (arithmetic.float.const-f64 -3.0)])) + (arithmetic.float.const_f64 -3.0)])) [] [%0] (signature (core.fn @@ -88,7 +88,7 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cons [] [arithmetic.float.types.float64 arithmetic.float.types.float64] (ext))) - ((core.load_const _ _ (arithmetic.float.const-f64 1.0)) [] [%0] + ((core.load_const _ _ (arithmetic.float.const_f64 1.0)) [] [%0] (signature (core.fn [] [arithmetic.float.types.float64] (ext)))) ((core.load_const _ diff --git a/hugr-model/Cargo.toml b/hugr-model/Cargo.toml index b60d6c9af..1b1b0fdce 100644 --- a/hugr-model/Cargo.toml +++ b/hugr-model/Cargo.toml @@ -29,6 +29,7 @@ pest_derive = { workspace = true } pretty = { workspace = true } smol_str = { workspace = true, features = ["serde"] } thiserror.workspace = true +pyo3 = { workspace = true, optional = true } [lints] workspace = true diff --git a/hugr-model/src/v0/ast/mod.rs b/hugr-model/src/v0/ast/mod.rs index f30b36ad9..d7dbb65f8 100644 --- a/hugr-model/src/v0/ast/mod.rs +++ b/hugr-model/src/v0/ast/mod.rs @@ -29,6 +29,8 @@ use super::{LinkName, Literal, RegionKind, SymbolName, VarName}; mod parse; mod print; +#[cfg(feature = "pyo3")] +mod python; mod resolve; mod view; diff --git a/hugr-model/src/v0/ast/parse.rs b/hugr-model/src/v0/ast/parse.rs index 8467a8026..de707749a 100644 --- a/hugr-model/src/v0/ast/parse.rs +++ b/hugr-model/src/v0/ast/parse.rs @@ -444,3 +444,4 @@ impl_from_str!(Param, param, parse_param); impl_from_str!(Module, module, parse_module); impl_from_str!(SeqPart, part, parse_seq_part); impl_from_str!(Literal, literal, parse_literal); +impl_from_str!(Symbol, symbol, parse_symbol); diff --git a/hugr-model/src/v0/ast/print.rs b/hugr-model/src/v0/ast/print.rs index 9cfcab79c..75a45ba7f 100644 --- a/hugr-model/src/v0/ast/print.rs +++ b/hugr-model/src/v0/ast/print.rs @@ -426,6 +426,7 @@ impl_display!(Param, print_param); impl_display!(Term, print_term); impl_display!(SeqPart, print_seq_part); impl_display!(Literal, print_literal); +impl_display!(Symbol, print_symbol); impl Display for VarName { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/hugr-model/src/v0/ast/python.rs b/hugr-model/src/v0/ast/python.rs new file mode 100644 index 000000000..93ed394d8 --- /dev/null +++ b/hugr-model/src/v0/ast/python.rs @@ -0,0 +1,414 @@ +use std::sync::Arc; + +use super::{Module, Node, Operation, Param, Region, SeqPart, Symbol, Term}; +use pyo3::{ + exceptions::PyTypeError, + types::{PyAnyMethods, PyStringMethods as _, PyTypeMethods as _}, + Bound, PyAny, PyResult, +}; + +impl<'py> pyo3::FromPyObject<'py> for Term { + fn extract_bound(term: &Bound<'py, PyAny>) -> PyResult { + let name = term.get_type().name()?; + + Ok(match name.to_str()? { + "Wildcard" => Self::Wildcard, + "ExtSet" => Self::ExtSet, + "Var" => { + let name = term.getattr("name")?.extract()?; + Self::Var(name) + } + "Apply" => { + let symbol = term.getattr("symbol")?.extract()?; + let args: Vec<_> = term.getattr("args")?.extract()?; + Self::Apply(symbol, args.into()) + } + "List" => { + let parts: Vec<_> = term.getattr("parts")?.extract()?; + Self::List(parts.into()) + } + "Tuple" => { + let parts: Vec<_> = term.getattr("parts")?.extract()?; + Self::Tuple(parts.into()) + } + "Literal" => { + let literal = term.getattr("value")?.extract()?; + Self::Literal(literal) + } + "Func" => { + let region = term.getattr("region")?.extract()?; + Self::Func(Arc::new(region)) + } + _ => { + return Err(PyTypeError::new_err(format!( + "Unknown Term type: {}.", + name.to_str()? + ))) + } + }) + } +} + +impl<'py> pyo3::IntoPyObject<'py> for &Term { + type Target = pyo3::PyAny; + type Output = pyo3::Bound<'py, Self::Target>; + type Error = pyo3::PyErr; + + fn into_pyobject(self, py: pyo3::Python<'py>) -> Result { + let py_module = py.import("hugr.model")?; + match self { + Term::Wildcard => { + let py_class = py_module.getattr("Wildcard")?; + py_class.call0() + } + Term::Var(var_name) => { + let py_class = py_module.getattr("Var")?; + py_class.call1((var_name.as_ref(),)) + } + Term::Apply(symbol_name, terms) => { + let py_class = py_module.getattr("Apply")?; + py_class.call1((symbol_name.as_ref(), terms.as_ref())) + } + Term::List(parts) => { + let py_class = py_module.getattr("List")?; + py_class.call1((parts.as_ref(),)) + } + Term::Literal(literal) => { + let py_class = py_module.getattr("Literal")?; + py_class.call1((literal,)) + } + Term::Tuple(parts) => { + let py_class = py_module.getattr("Tuple")?; + py_class.call1((parts.as_ref(),)) + } + Term::Func(region) => { + let py_class = py_module.getattr("Func")?; + py_class.call1((region.as_ref(),)) + } + Term::ExtSet => { + let py_class = py_module.getattr("ExtSet")?; + py_class.call0() + } + } + } +} + +impl<'py> pyo3::FromPyObject<'py> for SeqPart { + fn extract_bound(part: &Bound<'py, PyAny>) -> PyResult { + let name = part.get_type().name()?; + + if name.to_str()? == "Splice" { + let term = part.getattr("seq")?.extract()?; + Ok(Self::Splice(term)) + } else { + let term = part.extract()?; + Ok(Self::Item(term)) + } + } +} + +impl<'py> pyo3::IntoPyObject<'py> for &SeqPart { + type Target = pyo3::PyAny; + type Output = pyo3::Bound<'py, Self::Target>; + type Error = pyo3::PyErr; + + fn into_pyobject(self, py: pyo3::Python<'py>) -> Result { + let py_module = py.import("hugr.model")?; + match self { + SeqPart::Item(term) => term.into_pyobject(py), + SeqPart::Splice(term) => { + let py_class = py_module.getattr("Splice")?; + py_class.call1((term,)) + } + } + } +} + +impl<'py> pyo3::FromPyObject<'py> for Param { + fn extract_bound(symbol: &Bound<'py, PyAny>) -> PyResult { + let name = symbol.getattr("name")?.extract()?; + let r#type = symbol.getattr("type")?.extract()?; + Ok(Self { name, r#type }) + } +} + +impl<'py> pyo3::IntoPyObject<'py> for &Param { + type Target = pyo3::PyAny; + type Output = pyo3::Bound<'py, Self::Target>; + type Error = pyo3::PyErr; + + fn into_pyobject(self, py: pyo3::Python<'py>) -> Result { + let py_module = py.import("hugr.model")?; + let py_class = py_module.getattr("Param")?; + py_class.call1((self.name.as_ref(), &self.r#type)) + } +} + +impl<'py> pyo3::FromPyObject<'py> for Symbol { + fn extract_bound(symbol: &Bound<'py, PyAny>) -> PyResult { + let name = symbol.getattr("name")?.extract()?; + let params: Vec<_> = symbol.getattr("params")?.extract()?; + let constraints: Vec<_> = symbol.getattr("constraints")?.extract()?; + let signature = symbol.getattr("signature")?.extract()?; + Ok(Self { + name, + signature, + params: params.into(), + constraints: constraints.into(), + }) + } +} + +impl<'py> pyo3::IntoPyObject<'py> for &Symbol { + type Target = pyo3::PyAny; + type Output = pyo3::Bound<'py, Self::Target>; + type Error = pyo3::PyErr; + + fn into_pyobject(self, py: pyo3::Python<'py>) -> Result { + let py_module = py.import("hugr.model")?; + let py_class = py_module.getattr("Symbol")?; + py_class.call1(( + self.name.as_ref(), + self.params.as_ref(), + self.constraints.as_ref(), + &self.signature, + )) + } +} + +impl<'py> pyo3::FromPyObject<'py> for Node { + fn extract_bound(node: &Bound<'py, PyAny>) -> PyResult { + let operation = node.getattr("operation")?.extract()?; + let inputs: Vec<_> = node.getattr("inputs")?.extract()?; + let outputs: Vec<_> = node.getattr("outputs")?.extract()?; + let regions: Vec<_> = node.getattr("regions")?.extract()?; + let meta: Vec<_> = node.getattr("meta")?.extract()?; + let signature = node.getattr("signature")?.extract()?; + + Ok(Self { + operation, + inputs: inputs.into(), + outputs: outputs.into(), + regions: regions.into(), + meta: meta.into(), + signature, + }) + } +} + +impl<'py> pyo3::IntoPyObject<'py> for &Node { + type Target = pyo3::PyAny; + type Output = pyo3::Bound<'py, Self::Target>; + type Error = pyo3::PyErr; + + fn into_pyobject(self, py: pyo3::Python<'py>) -> Result { + let py_module = py.import("hugr.model")?; + let py_class = py_module.getattr("Node")?; + py_class.call1(( + &self.operation, + self.inputs.as_ref(), + self.outputs.as_ref(), + self.regions.as_ref(), + self.meta.as_ref(), + &self.signature, + )) + } +} + +impl<'py> pyo3::FromPyObject<'py> for Operation { + fn extract_bound(op: &Bound<'py, PyAny>) -> PyResult { + let name = op.get_type().name()?; + + Ok(match name.to_str()? { + "InvalidOp" => Self::Invalid, + "Dfg" => Self::Dfg, + "Cfg" => Self::Cfg, + "Block" => Self::Block, + "DefineFunc" => { + let symbol = op.getattr("symbol")?.extract()?; + Self::DefineFunc(Box::new(symbol)) + } + "DeclareFunc" => { + let symbol = op.getattr("symbol")?.extract()?; + Self::DeclareFunc(Box::new(symbol)) + } + "DeclareConstructor" => { + let symbol = op.getattr("symbol")?.extract()?; + Self::DeclareConstructor(Box::new(symbol)) + } + "DeclareOperation" => { + let symbol = op.getattr("symbol")?.extract()?; + Self::DeclareOperation(Box::new(symbol)) + } + "DeclareAlias" => { + let symbol = op.getattr("symbol")?.extract()?; + Self::DeclareAlias(Box::new(symbol)) + } + "DefineAlias" => { + let symbol = op.getattr("symbol")?.extract()?; + let value = op.getattr("value")?.extract()?; + Self::DefineAlias(Box::new(symbol), value) + } + "TailLoop" => Self::TailLoop, + "Conditional" => Self::Conditional, + "Import" => { + let name = op.getattr("name")?.extract()?; + Self::Import(name) + } + "CustomOp" => { + let operation = op.getattr("operation")?.extract()?; + Self::Custom(operation) + } + _ => return Err(PyTypeError::new_err("Unknown Operation type.")), + }) + } +} + +impl<'py> pyo3::IntoPyObject<'py> for &Operation { + type Target = pyo3::PyAny; + type Output = pyo3::Bound<'py, Self::Target>; + type Error = pyo3::PyErr; + + fn into_pyobject(self, py: pyo3::Python<'py>) -> Result { + let py_module = py.import("hugr.model")?; + match self { + Operation::Invalid => { + let py_class = py_module.getattr("InvalidOp")?; + py_class.call0() + } + Operation::Dfg => { + let py_class = py_module.getattr("Dfg")?; + py_class.call0() + } + Operation::Cfg => { + let py_class = py_module.getattr("Cfg")?; + py_class.call0() + } + Operation::Block => { + let py_class = py_module.getattr("Block")?; + py_class.call0() + } + Operation::DefineFunc(symbol) => { + let py_class = py_module.getattr("DefineFunc")?; + py_class.call1((symbol.as_ref(),)) + } + Operation::DeclareFunc(symbol) => { + let py_class = py_module.getattr("DeclareFunc")?; + py_class.call1((symbol.as_ref(),)) + } + Operation::DeclareConstructor(symbol) => { + let py_class = py_module.getattr("DeclareConstructor")?; + py_class.call1((symbol.as_ref(),)) + } + Operation::DeclareOperation(symbol) => { + let py_class = py_module.getattr("DeclareOperation")?; + py_class.call1((symbol.as_ref(),)) + } + Operation::DeclareAlias(symbol) => { + let py_class = py_module.getattr("DeclareAlias")?; + py_class.call1((symbol.as_ref(),)) + } + Operation::DefineAlias(symbol, value) => { + let py_class = py_module.getattr("DefineAlias")?; + py_class.call1((symbol.as_ref(), value)) + } + Operation::TailLoop => { + let py_class = py_module.getattr("TailLoop")?; + py_class.call0() + } + Operation::Conditional => { + let py_class = py_module.getattr("Conditional")?; + py_class.call0() + } + Operation::Import(name) => { + let py_class = py_module.getattr("Import")?; + py_class.call1((name.as_ref(),)) + } + Operation::Custom(term) => { + let py_class = py_module.getattr("CustomOp")?; + py_class.call1((term,)) + } + } + } +} + +impl<'py> pyo3::FromPyObject<'py> for Region { + fn extract_bound(region: &Bound<'py, PyAny>) -> PyResult { + let kind = region.getattr("kind")?.extract()?; + let sources: Vec<_> = region.getattr("sources")?.extract()?; + let targets: Vec<_> = region.getattr("targets")?.extract()?; + let children: Vec<_> = region.getattr("children")?.extract()?; + let meta: Vec<_> = region.getattr("meta")?.extract()?; + let signature = region.getattr("signature")?.extract()?; + + Ok(Self { + kind, + sources: sources.into(), + targets: targets.into(), + children: children.into(), + meta: meta.into(), + signature, + }) + } +} + +impl<'py> pyo3::IntoPyObject<'py> for &Region { + type Target = pyo3::PyAny; + type Output = pyo3::Bound<'py, Self::Target>; + type Error = pyo3::PyErr; + + fn into_pyobject(self, py: pyo3::Python<'py>) -> Result { + let py_module = py.import("hugr.model")?; + let py_class = py_module.getattr("Region")?; + py_class.call1(( + self.kind, + self.sources.as_ref(), + self.targets.as_ref(), + self.children.as_ref(), + self.meta.as_ref(), + &self.signature, + )) + } +} + +impl<'py> pyo3::FromPyObject<'py> for Module { + fn extract_bound(region: &Bound<'py, PyAny>) -> PyResult { + let root = region.getattr("root")?.extract()?; + Ok(Self { root }) + } +} + +impl<'py> pyo3::IntoPyObject<'py> for &Module { + type Target = pyo3::PyAny; + type Output = pyo3::Bound<'py, Self::Target>; + type Error = pyo3::PyErr; + + fn into_pyobject(self, py: pyo3::Python<'py>) -> Result { + let py_module = py.import("hugr.model")?; + let py_class = py_module.getattr("Module")?; + py_class.call1((&self.root,)) + } +} + +macro_rules! impl_into_pyobject_owned { + ($ident:ty) => { + impl<'py> pyo3::IntoPyObject<'py> for $ident { + type Target = pyo3::PyAny; + type Output = pyo3::Bound<'py, Self::Target>; + type Error = pyo3::PyErr; + + fn into_pyobject(self, py: pyo3::Python<'py>) -> Result { + (&self).into_pyobject(py) + } + } + }; +} + +impl_into_pyobject_owned!(Term); +impl_into_pyobject_owned!(SeqPart); +impl_into_pyobject_owned!(Param); +impl_into_pyobject_owned!(Symbol); +impl_into_pyobject_owned!(Module); +impl_into_pyobject_owned!(Node); +impl_into_pyobject_owned!(Region); +impl_into_pyobject_owned!(Operation); diff --git a/hugr-model/src/v0/ast/resolve.rs b/hugr-model/src/v0/ast/resolve.rs index 72a2a5070..a2b83c316 100644 --- a/hugr-model/src/v0/ast/resolve.rs +++ b/hugr-model/src/v0/ast/resolve.rs @@ -19,6 +19,7 @@ pub struct Context<'a> { links: LinkTable<&'a str>, symbols: SymbolTable<'a>, imports: FxHashMap, + terms: FxHashMap, TermId>, } impl<'a> Context<'a> { @@ -30,6 +31,7 @@ impl<'a> Context<'a> { links: LinkTable::new(), symbols: SymbolTable::new(), imports: FxHashMap::default(), + terms: FxHashMap::default(), } } @@ -91,7 +93,10 @@ impl<'a> Context<'a> { Term::ExtSet => table::Term::ExtSet(&[]), }; - Ok(self.module.insert_term(term)) + Ok(*self + .terms + .entry(term.clone()) + .or_insert_with(|| self.module.insert_term(term))) } fn resolve_seq_parts(&mut self, parts: &'a [SeqPart]) -> BuildResult<&'a [table::SeqPart]> { diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index 32846cac3..bdb6194ae 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -83,6 +83,10 @@ //! [Table]: crate::v0::table //! [AST]: crate::v0::ast use ordered_float::OrderedFloat; +#[cfg(feature = "pyo3")] +use pyo3::types::PyAnyMethods as _; +#[cfg(feature = "pyo3")] +use pyo3::PyTypeInfo as _; use smol_str::SmolStr; use std::sync::Arc; use table::LinkIndex; @@ -290,6 +294,37 @@ pub enum ScopeClosure { Closed, } +#[cfg(feature = "pyo3")] +impl<'py> pyo3::FromPyObject<'py> for ScopeClosure { + fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult { + let value: usize = ob.getattr("value")?.extract()?; + match value { + 0 => Ok(Self::Open), + 1 => Ok(Self::Closed), + _ => Err(pyo3::exceptions::PyTypeError::new_err( + "Invalid ScopeClosure.", + )), + } + } +} + +#[cfg(feature = "pyo3")] +impl<'py> pyo3::IntoPyObject<'py> for ScopeClosure { + type Target = pyo3::PyAny; + type Output = pyo3::Bound<'py, Self::Target>; + type Error = pyo3::PyErr; + + fn into_pyobject(self, py: pyo3::Python<'py>) -> Result { + let py_module = py.import("hugr.model")?; + let py_class = py_module.getattr("ScopeClosure")?; + + match self { + ScopeClosure::Open => py_class.getattr("OPEN"), + ScopeClosure::Closed => py_class.getattr("CLOSED"), + } + } +} + /// The kind of a region. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] pub enum RegionKind { @@ -302,6 +337,39 @@ pub enum RegionKind { Module = 2, } +#[cfg(feature = "pyo3")] +impl<'py> pyo3::FromPyObject<'py> for RegionKind { + fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult { + let value: usize = ob.getattr("value")?.extract()?; + match value { + 0 => Ok(Self::DataFlow), + 1 => Ok(Self::ControlFlow), + 2 => Ok(Self::Module), + _ => Err(pyo3::exceptions::PyTypeError::new_err( + "Invalid RegionKind.", + )), + } + } +} + +#[cfg(feature = "pyo3")] +impl<'py> pyo3::IntoPyObject<'py> for RegionKind { + type Target = pyo3::PyAny; + type Output = pyo3::Bound<'py, Self::Target>; + type Error = pyo3::PyErr; + + fn into_pyobject(self, py: pyo3::Python<'py>) -> Result { + let py_module = py.import("hugr.model")?; + let py_class = py_module.getattr("RegionKind")?; + + match self { + RegionKind::DataFlow => py_class.getattr("DATA_FLOW"), + RegionKind::ControlFlow => py_class.getattr("CONTROL_FLOW"), + RegionKind::Module => py_class.getattr("MODULE"), + } + } +} + /// The name of a variable. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct VarName(SmolStr); @@ -319,6 +387,25 @@ impl AsRef for VarName { } } +#[cfg(feature = "pyo3")] +impl<'py> pyo3::FromPyObject<'py> for VarName { + fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult { + let name: String = ob.extract()?; + Ok(Self::new(name)) + } +} + +#[cfg(feature = "pyo3")] +impl<'py> pyo3::IntoPyObject<'py> for &VarName { + type Target = pyo3::types::PyString; + type Output = pyo3::Bound<'py, Self::Target>; + type Error = pyo3::PyErr; + + fn into_pyobject(self, py: pyo3::Python<'py>) -> Result { + Ok(self.as_ref().into_pyobject(py)?) + } +} + /// The name of a symbol. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct SymbolName(SmolStr); @@ -336,6 +423,14 @@ impl AsRef for SymbolName { } } +#[cfg(feature = "pyo3")] +impl<'py> pyo3::FromPyObject<'py> for SymbolName { + fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult { + let name: String = ob.extract()?; + Ok(Self::new(name)) + } +} + /// The name of a link. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct LinkName(SmolStr); @@ -359,6 +454,25 @@ impl AsRef for LinkName { } } +#[cfg(feature = "pyo3")] +impl<'py> pyo3::FromPyObject<'py> for LinkName { + fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult { + let name: String = ob.extract()?; + Ok(Self::new(name)) + } +} + +#[cfg(feature = "pyo3")] +impl<'py> pyo3::IntoPyObject<'py> for &LinkName { + type Target = pyo3::types::PyString; + type Output = pyo3::Bound<'py, Self::Target>; + type Error = pyo3::PyErr; + + fn into_pyobject(self, py: pyo3::Python<'py>) -> Result { + Ok(self.as_ref().into_pyobject(py)?) + } +} + /// A static literal value. /// /// Literal values may be large since they can include strings and byte @@ -376,6 +490,47 @@ pub enum Literal { Float(OrderedFloat), } +#[cfg(feature = "pyo3")] +impl<'py> pyo3::FromPyObject<'py> for Literal { + fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult { + if pyo3::types::PyString::is_type_of(ob) { + let value: String = ob.extract()?; + Ok(Literal::Str(value.into())) + } else if pyo3::types::PyInt::is_type_of(ob) { + let value: u64 = ob.extract()?; + Ok(Literal::Nat(value)) + } else if pyo3::types::PyFloat::is_type_of(ob) { + let value: f64 = ob.extract()?; + Ok(Literal::Float(value.into())) + } else if pyo3::types::PyBytes::is_type_of(ob) { + let value: Vec = ob.extract()?; + Ok(Literal::Bytes(value.into())) + } else { + Err(pyo3::exceptions::PyTypeError::new_err( + "Invalid literal value.", + )) + } + } +} + +#[cfg(feature = "pyo3")] +impl<'py> pyo3::IntoPyObject<'py> for &Literal { + type Target = pyo3::PyAny; + type Output = pyo3::Bound<'py, Self::Target>; + type Error = pyo3::PyErr; + + fn into_pyobject(self, py: pyo3::Python<'py>) -> Result { + Ok(match self { + Literal::Str(s) => s.as_str().into_pyobject(py)?.into_any(), + Literal::Nat(n) => n.into_pyobject(py)?.into_any(), + Literal::Bytes(b) => pyo3::types::PyBytes::new(py, b) + .into_pyobject(py)? + .into_any(), + Literal::Float(f) => f.0.into_pyobject(py)?.into_any(), + }) + } +} + #[cfg(test)] mod test { use super::*; diff --git a/hugr-model/tests/fixtures/model-const.edn b/hugr-model/tests/fixtures/model-const.edn index 13a161b05..b15dd3c5d 100644 --- a/hugr-model/tests/fixtures/model-const.edn +++ b/hugr-model/tests/fixtures/model-const.edn @@ -40,7 +40,7 @@ (arithmetic.int.const 6 3) (arithmetic.int.const 6 4) (arithmetic.int.const 6 5)]) - (arithmetic.float.const-f64 -3.0)))) + (arithmetic.float.const_f64 -3.0)))) [] [%0] (signature (core.fn diff --git a/hugr-py/.gitignore b/hugr-py/.gitignore new file mode 100644 index 000000000..256f6cea7 --- /dev/null +++ b/hugr-py/.gitignore @@ -0,0 +1 @@ +hugr/*.so diff --git a/hugr-py/Cargo.toml b/hugr-py/Cargo.toml new file mode 100644 index 000000000..8990db3b2 --- /dev/null +++ b/hugr-py/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "hugr-py" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +name = "hugr_py" +crate-type = ["cdylib", "rlib"] +path = "rust/lib.rs" + +[dependencies] +bumpalo = { workspace = true, features = ["collections"] } +hugr-model = { version = "0.18", path = "../hugr-model", features = ["pyo3"] } +paste.workspace = true +pyo3 = { workspace = true } +smol_str.workspace = true diff --git a/hugr-py/pyproject.toml b/hugr-py/pyproject.toml index 93b5707ff..f8c21ee4b 100644 --- a/hugr-py/pyproject.toml +++ b/hugr-py/pyproject.toml @@ -42,5 +42,10 @@ homepage = "https://github.com/CQCL/hugr/tree/main/hugr-py" repository = "https://github.com/CQCL/hugr/tree/main/hugr-py" [build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" +requires = ["maturin>=1.4,<2.0"] +build-backend = "maturin" + +[tool.maturin] +features = ["pyo3/extension-module"] +python-source = "src/" +module-name = "hugr._hugr" diff --git a/hugr-py/rust/lib.rs b/hugr-py/rust/lib.rs new file mode 100644 index 000000000..000b31fa8 --- /dev/null +++ b/hugr-py/rust/lib.rs @@ -0,0 +1,67 @@ +use hugr_model::v0::ast; +use pyo3::{exceptions::PyValueError, prelude::*}; + +macro_rules! syntax_to_and_from_string { + ($name:ident, $ty:ty) => { + paste::paste! { + #[pyfunction] + fn [<$name _to_string>](ob: ast::$ty) -> PyResult { + Ok(format!("{}", ob)) + } + + #[pyfunction] + fn [](string: String) -> PyResult { + string + .parse::() + .map_err(|err| PyValueError::new_err(err.to_string())) + } + } + }; +} + +syntax_to_and_from_string!(term, Term); +syntax_to_and_from_string!(node, Node); +syntax_to_and_from_string!(region, Region); +syntax_to_and_from_string!(module, Module); +syntax_to_and_from_string!(param, Param); +syntax_to_and_from_string!(symbol, Symbol); + +#[pyfunction] +fn module_to_bytes(module: ast::Module) -> PyResult> { + let bump = bumpalo::Bump::new(); + let resolved = module + .resolve(&bump) + .map_err(|err| PyValueError::new_err(err.to_string()))?; + let bytes = hugr_model::v0::binary::write_to_vec(&resolved); + Ok(bytes) +} + +#[pyfunction] +fn bytes_to_module(bytes: &[u8]) -> PyResult { + let bump = bumpalo::Bump::new(); + let table = hugr_model::v0::binary::read_from_slice(bytes, &bump) + .map_err(|err| PyValueError::new_err(err.to_string()))?; + let module = table + .as_ast() + .ok_or_else(|| PyValueError::new_err("Malformed module"))?; + Ok(module) +} + +#[pymodule] +fn _hugr(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(term_to_string, m)?)?; + m.add_function(wrap_pyfunction!(string_to_term, m)?)?; + m.add_function(wrap_pyfunction!(node_to_string, m)?)?; + m.add_function(wrap_pyfunction!(string_to_node, m)?)?; + m.add_function(wrap_pyfunction!(region_to_string, m)?)?; + m.add_function(wrap_pyfunction!(string_to_region, m)?)?; + m.add_function(wrap_pyfunction!(module_to_string, m)?)?; + m.add_function(wrap_pyfunction!(module_to_bytes, m)?)?; + m.add_function(wrap_pyfunction!(string_to_module, m)?)?; + m.add_function(wrap_pyfunction!(bytes_to_module, m)?)?; + m.add_function(wrap_pyfunction!(param_to_string, m)?)?; + m.add_function(wrap_pyfunction!(string_to_param, m)?)?; + m.add_function(wrap_pyfunction!(symbol_to_string, m)?)?; + m.add_function(wrap_pyfunction!(string_to_symbol, m)?)?; + Ok(()) +} diff --git a/hugr-py/src/hugr/__init__.py b/hugr-py/src/hugr/__init__.py index c6ab6c7f1..2de9aefc5 100644 --- a/hugr-py/src/hugr/__init__.py +++ b/hugr-py/src/hugr/__init__.py @@ -2,6 +2,7 @@ representation. """ +from . import model from .hugr.base import Hugr from .hugr.node_port import Direction, InPort, Node, OutPort, Wire @@ -12,6 +13,7 @@ "InPort", "Direction", "Wire", + "model", ] # This is updated by our release-please workflow, triggered by this diff --git a/hugr-py/src/hugr/_hugr.cpython-312-darwin.so b/hugr-py/src/hugr/_hugr.cpython-312-darwin.so new file mode 100755 index 000000000..e1559b04f Binary files /dev/null and b/hugr-py/src/hugr/_hugr.cpython-312-darwin.so differ diff --git a/hugr-py/src/hugr/_hugr/__init__.pyi b/hugr-py/src/hugr/_hugr/__init__.pyi new file mode 100644 index 000000000..7cf2aa8ad --- /dev/null +++ b/hugr-py/src/hugr/_hugr/__init__.pyi @@ -0,0 +1,16 @@ +import hugr.model + +def term_to_string(term: hugr.model.Term) -> str: ... +def string_to_term(string: str) -> hugr.model.Term: ... +def node_to_string(node: hugr.model.Node) -> str: ... +def string_to_node(string: str) -> hugr.model.Node: ... +def region_to_string(region: hugr.model.Region) -> str: ... +def string_to_region(string: str) -> hugr.model.Region: ... +def param_to_string(region: hugr.model.Param) -> str: ... +def string_to_param(string: str) -> hugr.model.Param: ... +def symbol_to_string(region: hugr.model.Symbol) -> str: ... +def string_to_symbol(string: str) -> hugr.model.Symbol: ... +def module_to_string(module: hugr.model.Module) -> str: ... +def string_to_module(string: str) -> hugr.model.Module: ... +def module_to_bytes(module: hugr.model.Module) -> bytes: ... +def bytes_to_module(binary: bytes) -> hugr.model.Module: ... diff --git a/hugr-py/src/hugr/hugr/base.py b/hugr-py/src/hugr/hugr/base.py index 0364d2dfd..f86687b7f 100644 --- a/hugr-py/src/hugr/hugr/base.py +++ b/hugr-py/src/hugr/hugr/base.py @@ -14,6 +14,7 @@ overload, ) +import hugr.model as model from hugr._serialization.ops import OpType as SerialOp from hugr._serialization.serial_hugr import SerialHugr from hugr.exceptions import ParentBeforeChild @@ -730,6 +731,15 @@ def to_json(self) -> str: """ return self._to_serial().to_json() + def to_model(self) -> model.Module: + return model.Module(self.to_model_region()) + + def to_model_region(self) -> model.Region: + from hugr.model.export import ModelExport + + export = ModelExport(self) + return export.export_region_module(self.root) + @classmethod def load_json(cls, json_str: str) -> Hugr: """Deserialize a JSON string into a HUGR. diff --git a/hugr-py/src/hugr/model/__init__.py b/hugr-py/src/hugr/model/__init__.py new file mode 100644 index 000000000..ba6a20f41 --- /dev/null +++ b/hugr-py/src/hugr/model/__init__.py @@ -0,0 +1,268 @@ +"""HUGR model data structures.""" + +from collections.abc import Sequence +from dataclasses import dataclass, field +from enum import Enum +from typing import Protocol + +import hugr._hugr as rust + + +class Term(Protocol): + """A model term for static data such as types, constants and metadata.""" + + def __str__(self) -> str: + return rust.term_to_string(self) + + @staticmethod + def from_str(s: str) -> "Term": + return rust.string_to_term(s) + + +@dataclass(frozen=True) +class Wildcard(Term): + """Standin for any term.""" + + +@dataclass(frozen=True) +class Var(Term): + """Local variable, identified by its name.""" + + name: str + + +@dataclass(frozen=True) +class Apply(Term): + """Symbol application.""" + + symbol: str + args: Sequence[Term] = field(default_factory=list) + + +@dataclass(frozen=True) +class Splice: + """A sequence spliced into the parent sequence.""" + + seq: Term + + +SeqPart = Term | Splice + + +@dataclass(frozen=True) +class List(Term): + """List of static data.""" + + parts: Sequence[SeqPart] = field(default_factory=list) + + +@dataclass(frozen=True) +class Tuple(Term): + """Tuple of static data.""" + + parts: Sequence[SeqPart] = field(default_factory=list) + + +@dataclass(frozen=True) +class Literal(Term): + """Static literal value.""" + + value: str | float | int | bytes + + +@dataclass(frozen=True) +class Func(Term): + """Function constant.""" + + region: "Region" + + +@dataclass(frozen=True) +class ExtSet(Term): + """Extension set. (deprecated).""" + + +@dataclass +class Param: + """A parameter to a Symbol.""" + + name: str + type: Term + + def __str__(self): + return rust.param_to_string(self) + + @staticmethod + def from_str(s: str) -> "Param": + return rust.string_to_param(s) + + +@dataclass +class Symbol: + """A named symbol.""" + + name: str + params: Sequence[Param] = field(default_factory=list) + constraints: Sequence[Term] = field(default_factory=list) + signature: Term = field(default_factory=Wildcard) + + def __str__(self): + return rust.symbol_to_string(self) + + @staticmethod + def from_str(s: str) -> "Symbol": + return rust.string_to_symbol(s) + + +class Op(Protocol): + """The operation of a node.""" + + +@dataclass(frozen=True) +class InvalidOp(Op): + """Invalid operation intended to serve as a placeholder.""" + + +@dataclass(frozen=True) +class Dfg(Op): + """Dataflow graph.""" + + +@dataclass(frozen=True) +class Cfg(Op): + """Control flow graph.""" + + +@dataclass(frozen=True) +class Block(Op): + """Basic block in a control flow graph.""" + + +@dataclass(frozen=True) +class DefineFunc(Op): + """Function definiton.""" + + symbol: Symbol + + +@dataclass(frozen=True) +class DeclareFunc(Op): + """Function declaration.""" + + symbol: Symbol + + +@dataclass(frozen=True) +class CustomOp(Op): + """Custom operation.""" + + operation: Term + + +@dataclass(frozen=True) +class DefineAlias(Op): + """Alias definition.""" + + symbol: Symbol + value: Term + + +@dataclass(frozen=True) +class DeclareAlias(Op): + """Alias declaration.""" + + symbol: Symbol + + +@dataclass(frozen=True) +class TailLoop(Op): + """Tail-controlled loop operation.""" + + +@dataclass(frozen=True) +class Conditional(Op): + """Conditional branch operation.""" + + +@dataclass(frozen=True) +class DeclareConstructor(Op): + """Constructor declaration.""" + + symbol: Symbol + + +@dataclass(frozen=True) +class DeclareOperation(Op): + """Operation declaration.""" + + symbol: Symbol + + +@dataclass(frozen=True) +class Import(Op): + """Import operation.""" + + name: str + + +@dataclass +class Node: + """A node in a hugr graph.""" + + operation: Op = field(default_factory=lambda: InvalidOp()) + inputs: Sequence[str] = field(default_factory=list) + outputs: Sequence[str] = field(default_factory=list) + regions: Sequence["Region"] = field(default_factory=list) + meta: Sequence[Term] = field(default_factory=list) + signature: Term | None = None + + def __str__(self) -> str: + return rust.node_to_string(self) + + +class RegionKind(Enum): + """The kind of a hugr region.""" + + DATA_FLOW = 0 + CONTROL_FLOW = 1 + MODULE = 2 + + +@dataclass +class Region: + """A hugr region containing an unordered collection of nodes.""" + + kind: RegionKind = RegionKind.DATA_FLOW + sources: Sequence[str] = field(default_factory=list) + targets: Sequence[str] = field(default_factory=list) + children: Sequence[Node] = field(default_factory=list) + meta: Sequence[Term] = field(default_factory=list) + signature: Term | None = None + + def __str__(self): + return rust.region_to_string(self) + + @staticmethod + def from_str(s: str) -> "Region": + return rust.string_to_region(s) + + +@dataclass +class Module: + """A top level hugr module.""" + + root: Region + + def __str__(self): + return rust.module_to_string(self) + + def __bytes__(self): + return rust.module_to_bytes(self) + + @staticmethod + def from_str(s: str) -> "Module": + return rust.string_to_module(s) + + @staticmethod + def from_bytes(b: bytes) -> "Module": + return rust.bytes_to_module(b) diff --git a/hugr-py/src/hugr/model/export.py b/hugr-py/src/hugr/model/export.py new file mode 100644 index 000000000..7464fe7fe --- /dev/null +++ b/hugr-py/src/hugr/model/export.py @@ -0,0 +1,558 @@ +"""Helpers to export hugr graphs from their python representation to hugr model.""" + +from collections.abc import Sequence +from typing import Generic, TypeVar, cast + +import hugr.model as model +from hugr.hugr.base import Hugr, Node +from hugr.hugr.node_port import InPort, OutPort +from hugr.ops import ( + CFG, + DFG, + AliasDecl, + AliasDefn, + AsExtOp, + Call, + CallIndirect, + Conditional, + Const, + Custom, + DataflowBlock, + ExitBlock, + FuncDecl, + FuncDefn, + Input, + LoadConst, + LoadFunc, + Output, + Tag, + TailLoop, +) +from hugr.tys import ConstKind, FunctionKind, Type, TypeBound, TypeParam, TypeTypeParam + + +class ModelExport: + """Helper to export a Hugr.""" + + def __init__(self, hugr: Hugr): + self.hugr = hugr + self.link_ports: _UnionFind[InPort | OutPort] = _UnionFind() + self.link_names: dict[InPort | OutPort, str] = {} + + for a, b in self.hugr.links(): + self.link_ports.union(a, b) + + def link_name(self, port): + root = self.link_ports[port] + + if root in self.link_names: + return self.link_names[root] + else: + index = str(len(self.link_names)) + self.link_names[root] = index + return index + + def export_node(self, node: Node) -> model.Node | None: + node_data = self.hugr[node] + + inputs = [self.link_name(InPort(node, i)) for i in range(node_data._num_inps)] + outputs = [self.link_name(OutPort(node, i)) for i in range(node_data._num_outs)] + + match node_data.op: + case DFG() as op: + region = self.export_region_dfg(node) + + return model.Node( + operation=model.Dfg(), + regions=[region], + signature=op.outer_signature().to_model(), + inputs=inputs, + outputs=outputs, + ) + + case Custom() as op: + name = f"{op.extension}.{op.op_name}" + args = cast(list[model.Term], [arg.to_model() for arg in op.args]) + signature = op.signature.to_model() + + return model.Node( + operation=model.CustomOp(model.Apply(name, args)), + signature=signature, + inputs=inputs, + outputs=outputs, + ) + + case AsExtOp() as op: + name = op.op_def().qualified_name() + args = cast( + list[model.Term], [arg.to_model() for arg in op.type_args()] + ) + signature = op.outer_signature().to_model() + + return model.Node( + operation=model.CustomOp(model.Apply(name, args)), + signature=signature, + inputs=inputs, + outputs=outputs, + ) + + case Conditional() as op: + regions = [ + self.export_region_dfg(child) for child in node_data.children + ] + + signature = op.outer_signature().to_model() + + return model.Node( + operation=model.Conditional(), + regions=regions, + signature=signature, + inputs=inputs, + outputs=outputs, + ) + + case TailLoop() as op: + region = self.export_region_dfg(node) + signature = op.outer_signature().to_model() + return model.Node( + operation=model.TailLoop(), + regions=[region], + signature=signature, + inputs=inputs, + outputs=outputs, + ) + + case FuncDefn() as op: + name = _mangle_name(node, op.f_name) + symbol = self.export_symbol( + name, op.signature.params, op.signature.body + ) + region = self.export_region_dfg(node) + + return model.Node( + operation=model.DefineFunc(symbol), + regions=[region], + ) + + case FuncDecl() as op: + name = _mangle_name(node, op.f_name) + symbol = self.export_symbol( + name, op.signature.params, op.signature.body + ) + return model.Node( + operation=model.DeclareFunc(symbol), + ) + + case AliasDecl() as op: + symbol = model.Symbol(name=op.alias, signature=model.Apply("core.type")) + + return model.Node(operation=model.DeclareAlias(symbol)) + + case AliasDefn() as op: + symbol = model.Symbol(name=op.alias, signature=model.Apply("core.type")) + + alias_value = cast(model.Term, op.definition.to_model()) + + return model.Node(operation=model.DefineAlias(symbol, alias_value)) + + case Call() as op: + input_types = [type.to_model() for type in op.instantiation.input] + output_types = [type.to_model() for type in op.instantiation.output] + signature = op.instantiation.to_model() + func_args = cast( + list[model.Term], [type.to_model() for type in op.type_args] + ) + func_name = self.find_func_input(node) + + if func_name is None: + error = f"Call node {node} is not connected to a function." + raise ValueError(error) + + func = model.Apply(func_name, func_args) + + return model.Node( + operation=model.CustomOp( + model.Apply( + "core.call", + [ + model.List(input_types), + model.List(output_types), + model.ExtSet(), + func, + ], + ) + ), + signature=signature, + inputs=inputs, + outputs=outputs, + ) + + case LoadFunc() as op: + signature = op.instantiation.to_model() + func_args = cast( + list[model.Term], [type.to_model() for type in op.type_args] + ) + func_name = self.find_func_input(node) + + if func_name is None: + error = f"LoadFunc node {node} is not connected to a function." + raise ValueError(error) + + func = model.Apply(func_name, func_args) + + return model.Node( + operation=model.CustomOp( + model.Apply( + "core.load_const", [signature, model.ExtSet(), func] + ) + ), + signature=signature, + inputs=inputs, + outputs=outputs, + ) + + case CallIndirect() as op: + input_types = [type.to_model() for type in op.signature.input] + output_types = [type.to_model() for type in op.signature.output] + + func = model.Apply( + "core.fn", + [model.List(input_types), model.List(output_types), model.ExtSet()], + ) + + signature = model.Apply( + "core.fn", + [ + model.List([func, *input_types]), + model.List(output_types), + model.ExtSet(), + ], + ) + + return model.Node( + operation=model.CustomOp( + model.Apply( + "core.call_indirect", + [ + model.List(input_types), + model.List(output_types), + model.ExtSet(), + ], + ) + ), + signature=signature, + inputs=inputs, + outputs=outputs, + ) + + case LoadConst() as op: + value = self.find_const_input(node) + + if value is None: + error = f"LoadConst node {node} is not connected to a constant." + raise ValueError(error) + + type = cast(model.Term, op.type_.to_model()) + signature = op.outer_signature().to_model() + + return model.Node( + operation=model.CustomOp( + model.Apply("core.load_const", [type, model.ExtSet(), value]) + ), + signature=signature, + inputs=inputs, + outputs=outputs, + ) + + case Const() as op: + return None + + case CFG() as op: + signature = op.outer_signature().to_model() + region = self.export_region_cfg(node) + + # TODO: Export CFGs + return model.Node( + operation=model.Cfg(), + signature=signature, + inputs=inputs, + outputs=outputs, + regions=[region], + ) + + case DataflowBlock() as op: + region = self.export_region_dfg(node) + + input_types = [ + model.Apply( + "core.ctrl", + [model.List([type.to_model() for type in op.inputs])], + ) + ] + + other_output_types = [type.to_model() for type in op.other_outputs] + output_types = [ + model.Apply( + "core.ctrl", + [ + model.List( + [ + *[type.to_model() for type in row], + *other_output_types, + ] + ) + ], + ) + for row in op.sum_ty.variant_rows + ] + + signature = model.Apply( + "core.fn", + [model.List(input_types), model.List(output_types), model.ExtSet()], + ) + + return model.Node( + operation=model.Block(), + inputs=inputs, + outputs=outputs, + regions=[region], + signature=signature, + ) + + case Tag() as op: + variants = model.List( + [ + model.List([type.to_model() for type in row]) + for row in op.sum_ty.variant_rows + ] + ) + + types = model.List( + [type.to_model() for type in op.sum_ty.variant_rows[op.tag]] + ) + + tag = model.Literal(op.tag) + signature = op.outer_signature().to_model() + + return model.Node( + operation=model.CustomOp( + model.Apply("core.make_adt", [variants, types, tag]) + ), + inputs=inputs, + outputs=outputs, + signature=signature, + ) + + case op: + error = f"Unknown operation: {op}" + raise ValueError(error) + + def export_region_module(self, node: Node) -> model.Region: + node_data = self.hugr[node] + children = [] + + for child in node_data.children: + child_node = self.export_node(child) + + if child_node is not None: + children.append(child_node) + + return model.Region(kind=model.RegionKind.MODULE, children=children) + + def export_region_dfg(self, node: Node) -> model.Region: + node_data = self.hugr[node] + children: list[model.Node] = [] + source_types: model.Term = model.Wildcard() + target_types: model.Term = model.Wildcard() + sources = [] + targets = [] + + for child in node_data.children: + child_data = self.hugr[child] + + match child_data.op: + case Input() as op: + source_types = model.List([type.to_model() for type in op.types]) + sources = [ + self.link_name(OutPort(child, i)) + for i in range(child_data._num_outs) + ] + + case Output() as op: + target_types = model.List([type.to_model() for type in op.types]) + targets = [ + self.link_name(InPort(child, i)) + for i in range(child_data._num_inps) + ] + + case _: + child_node = self.export_node(child) + + if child_node is not None: + children.append(child_node) + + signature = model.Apply("core.fn", [source_types, target_types, model.ExtSet()]) + + return model.Region( + kind=model.RegionKind.DATA_FLOW, + signature=signature, + children=children, + sources=sources, + targets=targets, + ) + + def export_region_cfg(self, node: Node) -> model.Region: + node_data = self.hugr[node] + + source = None + targets = [] + source_types: model.Term = model.Wildcard() + target_types: model.Term = model.Wildcard() + children = [] + + for child in node_data.children: + child_data = self.hugr[child] + + match child_data.op: + case ExitBlock() as op: + target_types = model.List( + [type.to_model() for type in op.cfg_outputs] + ) + targets = [ + self.link_name(InPort(child, i)) + for i in range(child_data._num_inps) + ] + case DataflowBlock() as op: + if source is None: + source_types = model.List( + [type.to_model() for type in op.inputs] + ) + source = self.link_name(OutPort(child, 0)) + + child_node = self.export_node(child) + + if child_node is not None: + children.append(child_node) + case _: + error = f"Unexpected operation in CFG {node}" + raise ValueError(error) + + if source is None: + error = f"CFG {node} has no entry block." + raise ValueError(error) + + signature = model.Apply("core.fn", [source_types, target_types, model.ExtSet()]) + + return model.Region( + kind=model.RegionKind.CONTROL_FLOW, + targets=targets, + sources=[source], + signature=signature, + children=children, + ) + + def export_symbol( + self, name: str, param_types: Sequence[TypeParam], body: Type + ) -> model.Symbol: + constraints = [] + params = [] + + for i, param_type in enumerate(param_types): + param_name = str(i) + + params.append(model.Param(name=param_name, type=param_type.to_model())) + + match param_type: + case TypeTypeParam(bound=TypeBound.Copyable): + constraints.append( + model.Apply("core.nonlinear", [model.Var(param_name)]) + ) + case _: + pass + + return model.Symbol( + name=name, + params=params, + constraints=constraints, + signature=cast(model.Term, body.to_model()), + ) + + def find_func_input(self, node: Node) -> str | None: + try: + func_node = next( + out_port.node + for (in_port, out_ports) in self.hugr.incoming_links(node) + if isinstance(self.hugr.port_kind(in_port), FunctionKind) + for out_port in out_ports + ) + except StopIteration: + return None + + match self.hugr[func_node].op: + case FuncDecl() as func_op: + name = func_op.f_name + case FuncDefn() as func_op: + name = func_op.f_name + case _: + return None + + return _mangle_name(node, name) + + def find_const_input(self, node: Node) -> model.Term | None: + try: + const_node = next( + out_port.node + for (in_port, out_ports) in self.hugr.incoming_links(node) + if isinstance(self.hugr.port_kind(in_port), ConstKind) + for out_port in out_ports + ) + except StopIteration: + return None + + match self.hugr[const_node].op: + case Const() as op: + return op.val.to_model() + case op: + return None + + +def _mangle_name(node: Node, name: str) -> str: + # Until we come to an agreement on the uniqueness of names, we mangle the names + # by adding the node id. + return f"_{name}_{node.idx}" + + +T = TypeVar("T") + + +class _UnionFind(Generic[T]): + def __init__(self) -> None: + self.parents: dict[T, T] = {} + self.sizes: dict[T, int] = {} + + def __getitem__(self, item: T) -> T: + if item not in self.parents: + self.parents[item] = item + self.sizes[item] = 1 + return item + + # Path splitting + while self.parents[item] != item: + parent = self.parents[item] + self.parents[item] = self.parents[parent] + item = parent + + return item + + def union(self, a: T, b: T): + a = self[a] + b = self[b] + + if a == b: + return + + if self.sizes[a] < self.sizes[b]: + (a, b) = (b, a) + + self.parents[b] = a + self.sizes[a] += self.sizes[b] diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py index 46571bdcb..1555dab4d 100644 --- a/hugr-py/src/hugr/ops.py +++ b/hugr-py/src/hugr/ops.py @@ -183,7 +183,7 @@ def _to_serial(self, parent: Node) -> sops.Input: def outer_signature(self) -> tys.FunctionType: return tys.FunctionType(input=[], output=self.types) - def __call__(self) -> Command: + def __call__(self, *args) -> Command: return super().__call__() def name(self) -> str: diff --git a/hugr-py/src/hugr/std/collections/array.py b/hugr-py/src/hugr/std/collections/array.py index 3e7f3bcfc..80dad72dc 100644 --- a/hugr-py/src/hugr/std/collections/array.py +++ b/hugr-py/src/hugr/std/collections/array.py @@ -3,7 +3,9 @@ from __future__ import annotations from dataclasses import dataclass +from typing import cast +import hugr.model as model from hugr import tys, val from hugr.std import _load_extension from hugr.utils import comma_sep_str @@ -79,3 +81,13 @@ def to_value(self) -> val.Extension: def __str__(self) -> str: return f"array({comma_sep_str(self.v)})" + + def to_model(self) -> model.Term: + return model.Apply( + "collections.array.const", + [ + model.Literal(len(self.v)), + cast(model.Term, self.ty.ty.to_model()), + model.List([value.to_model() for value in self.v]), + ], + ) diff --git a/hugr-py/src/hugr/std/float.py b/hugr-py/src/hugr/std/float.py index 7b7a0cf24..e155ef3c7 100644 --- a/hugr-py/src/hugr/std/float.py +++ b/hugr-py/src/hugr/std/float.py @@ -4,6 +4,7 @@ from dataclasses import dataclass +import hugr.model as model from hugr import val from hugr.std import _load_extension @@ -30,3 +31,6 @@ def to_value(self) -> val.Extension: def __str__(self) -> str: return f"{self.v}" + + def to_model(self) -> model.Term: + return model.Apply("arithmetic.float.const_f64", [model.Literal(self.v)]) diff --git a/hugr-py/src/hugr/std/int.py b/hugr-py/src/hugr/std/int.py index 8652cc135..4c1d0cdeb 100644 --- a/hugr-py/src/hugr/std/int.py +++ b/hugr-py/src/hugr/std/int.py @@ -7,6 +7,7 @@ from typing_extensions import Self +import hugr.model as model from hugr import ext, tys, val from hugr.ops import AsExtOp, DataflowOp, ExtOp, RegisteredOp from hugr.std import _load_extension @@ -71,6 +72,11 @@ def to_value(self) -> val.Extension: def __str__(self) -> str: return f"{self.v}" + def to_model(self) -> model.Term: + return model.Apply( + "arithmetic.int.const", [model.Literal(self.width), model.Literal(self.v)] + ) + INT_OPS_EXTENSION = _load_extension("arithmetic.int") diff --git a/hugr-py/src/hugr/tys.py b/hugr-py/src/hugr/tys.py index 7494f4b14..c92e37a0c 100644 --- a/hugr-py/src/hugr/tys.py +++ b/hugr-py/src/hugr/tys.py @@ -3,9 +3,10 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Protocol, cast, runtime_checkable import hugr._serialization.tys as stys +import hugr.model as model from hugr.utils import comma_sep_repr, comma_sep_str, ser_it if TYPE_CHECKING: @@ -30,6 +31,10 @@ def _to_serial(self) -> stys.BaseTypeParam: def _to_serial_root(self) -> stys.TypeParam: return stys.TypeParam(root=self._to_serial()) # type: ignore[arg-type] + def to_model(self) -> model.Term: + """Convert the type parameter to a model Term.""" + raise NotImplementedError(self) + @runtime_checkable class TypeArg(Protocol): @@ -46,6 +51,10 @@ def resolve(self, registry: ext.ExtensionRegistry) -> TypeArg: """Resolve types in the argument using the given registry.""" return self + def to_model(self) -> model.Term | model.Splice: + """Convert the type argument to a model Term.""" + raise NotImplementedError(self) + @runtime_checkable class Type(Protocol): @@ -82,6 +91,10 @@ def resolve(self, registry: ext.ExtensionRegistry) -> Type: """Resolve types in the type using the given registry.""" return self + def to_model(self) -> model.Term | model.Splice: + """Convert the type to a model Term.""" + raise NotImplementedError(self) + #: Row of types. TypeRow = list[Type] @@ -103,6 +116,10 @@ def _to_serial(self) -> stys.TypeTypeParam: def __str__(self) -> str: return str(self.bound) + def to_model(self) -> model.Term: + # Note that we drop the bound. + return model.Apply("core.type") + @dataclass(frozen=True) class BoundedNatParam(TypeParam): @@ -118,6 +135,10 @@ def __str__(self) -> str: return "Nat" return f"Nat({self.upper_bound})" + def to_model(self) -> model.Term: + # Note that we drop the bound. + return model.Apply("core.nat") + @dataclass(frozen=True) class StringParam(TypeParam): @@ -129,6 +150,9 @@ def _to_serial(self) -> stys.StringParam: def __str__(self) -> str: return "String" + def to_model(self) -> model.Term: + return model.Apply("core.str") + @dataclass(frozen=True) class ListParam(TypeParam): @@ -142,6 +166,10 @@ def _to_serial(self) -> stys.ListParam: def __str__(self) -> str: return f"[{self.param}]" + def to_model(self) -> model.Term: + item_type = self.param.to_model() + return model.Apply("core.list", [item_type]) + @dataclass(frozen=True) class TupleParam(TypeParam): @@ -155,6 +183,10 @@ def _to_serial(self) -> stys.TupleParam: def __str__(self) -> str: return f"({comma_sep_str(self.params)})" + def to_model(self) -> model.Term: + item_types = model.List([param.to_model() for param in self.params]) + return model.Apply("core.tuple", [item_types]) + @dataclass(frozen=True) class ExtensionsParam(TypeParam): @@ -166,6 +198,9 @@ def _to_serial(self) -> stys.ExtensionsParam: def __str__(self) -> str: return "Extensions" + def to_model(self) -> model.Term: + return model.Apply("core.ext_set") + # ------------------------------------------ # --------------- TypeArg ------------------ @@ -187,6 +222,9 @@ def resolve(self, registry: ext.ExtensionRegistry) -> TypeArg: def __str__(self) -> str: return f"Type({self.ty!s})" + def to_model(self) -> model.Term | model.Splice: + return self.ty.to_model() + @dataclass(frozen=True) class BoundedNatArg(TypeArg): @@ -200,6 +238,9 @@ def _to_serial(self) -> stys.BoundedNatArg: def __str__(self) -> str: return str(self.n) + def to_model(self) -> model.Term: + return model.Literal(self.n) + @dataclass(frozen=True) class StringArg(TypeArg): @@ -213,6 +254,9 @@ def _to_serial(self) -> stys.StringArg: def __str__(self) -> str: return f'"{self.value}"' + def to_model(self) -> model.Term: + return model.Literal(self.value) + @dataclass(frozen=True) class SequenceArg(TypeArg): @@ -229,6 +273,11 @@ def resolve(self, registry: ext.ExtensionRegistry) -> TypeArg: def __str__(self) -> str: return f"({comma_sep_str(self.elems)})" + def to_model(self) -> model.Term: + # TODO: We should separate lists and tuples. + # For now we assume that this is a list. + return model.List([elem.to_model() for elem in self.elems]) + @dataclass(frozen=True) class ExtensionsArg(TypeArg): @@ -242,6 +291,10 @@ def _to_serial(self) -> stys.ExtensionsArg: def __str__(self) -> str: return f"Extensions({comma_sep_str(self.extensions)})" + def to_model(self) -> model.Term: + # Since extension sets will be deprecated, this is just a placeholder. + return model.ExtSet() + @dataclass(frozen=True) class VariableArg(TypeArg): @@ -256,6 +309,9 @@ def _to_serial(self) -> stys.VariableArg: def __str__(self) -> str: return f"${self.idx}" + def to_model(self) -> model.Term: + return model.Var(str(self.idx)) + # ---------------------------------------------- # --------------- Type ------------------------- @@ -293,6 +349,12 @@ def resolve(self, registry: ext.ExtensionRegistry) -> Sum: """Resolve types in the sum type using the given registry.""" return Sum([[ty.resolve(registry) for ty in row] for row in self.variant_rows]) + def to_model(self) -> model.Term: + variants = model.List( + [model.List([typ.to_model() for typ in row]) for row in self.variant_rows] + ) + return model.Apply("core.adt", [variants]) + @dataclass(eq=False) class UnitSum(Sum): @@ -386,6 +448,9 @@ def type_bound(self) -> TypeBound: def __repr__(self) -> str: return f"${self.idx}" + def to_model(self) -> model.Term: + return model.Var(str(self.idx)) + @dataclass(frozen=True) class RowVariable(Type): @@ -403,6 +468,9 @@ def type_bound(self) -> TypeBound: def __repr__(self) -> str: return f"${self.idx}" + def to_model(self): + return model.Splice(model.Var(str(self.idx))) + @dataclass(frozen=True) class USize(Type): @@ -417,6 +485,9 @@ def type_bound(self) -> TypeBound: def __repr__(self) -> str: return "USize" + def to_model(self) -> model.Term: + return model.Apply("prelude.usize") + @dataclass(frozen=True) class Alias(Type): @@ -434,6 +505,9 @@ def type_bound(self) -> TypeBound: def __repr__(self) -> str: return self.name + def to_model(self) -> model.Term: + return model.Apply(self.name) + @dataclass(frozen=True) class FunctionType(Type): @@ -508,6 +582,12 @@ def with_runtime_reqs(self, runtime_reqs: ExtensionSet) -> FunctionType: def __str__(self) -> str: return f"{comma_sep_str(self.input)} -> {comma_sep_str(self.output)}" + def to_model(self) -> model.Term: + inputs = model.List([input.to_model() for input in self.input]) + outputs = model.List([output.to_model() for output in self.output]) + exts = model.ExtSet() + return model.Apply("core.fn", [inputs, outputs, exts]) + @dataclass(frozen=True) class PolyFuncType(Type): @@ -556,6 +636,11 @@ def empty(cls) -> PolyFuncType: """ return PolyFuncType(params=[], body=FunctionType.empty()) + def to_model(self) -> model.Term: + # A `PolyFuncType` should not be a `Type`. + error = "PolyFuncType used as a Type" + raise TypeError(error) + @dataclass class ExtType(Type): @@ -600,6 +685,17 @@ def __eq__(self, value): return self.type_def == value.type_def and self.args == value.args return super().__eq__(value) + def to_model(self) -> model.Term: + # This cast is only neccessary because `Type` can both be an + # actual type or a row variable. + args = [cast(model.Term, arg.to_model()) for arg in self.args] + + extension_name = self.type_def.get_extension().name + type_name = self.type_def.name + name = f"{extension_name}.{type_name}" + + return model.Apply(name, args) + def _type_str(name: str, args: Sequence[TypeArg]) -> str: if len(args) == 0: @@ -644,6 +740,13 @@ def resolve(self, registry: ext.ExtensionRegistry) -> Type: def __str__(self) -> str: return _type_str(self.id, self.args) + def to_model(self) -> model.Term: + # This cast is only neccessary because `Type` can both be an + # actual type or a row variable. + args = [cast(model.Term, arg.to_model()) for arg in self.args] + + return model.Apply(self.id, args) + @dataclass class _QubitDef(Type): @@ -656,6 +759,10 @@ def _to_serial(self) -> stys.Qubit: def __repr__(self) -> str: return "Qubit" + def to_model(self) -> model.Term: + # TODO: Is this the correct name? + return model.Apply("prelude.Qubit", []) + #: Qubit type. Qubit = _QubitDef() diff --git a/hugr-py/src/hugr/val.py b/hugr-py/src/hugr/val.py index 26b84cde9..5ad308d07 100644 --- a/hugr-py/src/hugr/val.py +++ b/hugr-py/src/hugr/val.py @@ -3,10 +3,11 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any, Protocol, cast, runtime_checkable import hugr._serialization.ops as sops import hugr._serialization.tys as stys +import hugr.model as model from hugr import tys from hugr.utils import comma_sep_repr, comma_sep_str, ser_it @@ -36,6 +37,8 @@ def type_(self) -> tys.Type: """ ... # pragma: no cover + def to_model(self) -> model.Term: ... + @dataclass class Sum(Value): @@ -75,6 +78,28 @@ def __eq__(self, other: object) -> bool: and self.vals == other.vals ) + def to_model(self) -> model.Term: + variants = [ + model.List([type.to_model() for type in row]) + for row in self.typ.variant_rows + ] + types = [ + model.Apply("core.const", [cast(model.Term, type)]) + for type in variants[self.tag].parts + ] + values = [value.to_model() for value in self.vals] + + return model.Apply( + "core.const.adt", + [ + model.List(variants), + model.ExtSet(), + model.List(types), + model.Literal(self.tag), + model.Tuple(values), + ], + ) + class UnitSum(Sum): """Simple :class:`Sum` with each variant being an empty row. @@ -278,6 +303,9 @@ def _to_serial(self) -> sops.FunctionValue: hugr=self.body._to_serial(), ) + def to_model(self) -> model.Term: + return model.Func(self.body.to_model_region()) + @dataclass class Extension(Value): @@ -301,6 +329,13 @@ def _to_serial(self) -> sops.CustomValue: extensions=self.extensions, ) + def to_model(self) -> model.Term: + type = cast(model.Term, self.typ.to_model()) + json = sops.CustomConst(c=self.name, v=self.val).model_dump_json() + return model.Apply( + "compat.const_json", [type, model.ExtSet(), model.Literal(json)] + ) + class ExtensionValue(Value, Protocol): """Protocol which types can implement to be a HUGR extension value.""" @@ -314,3 +349,7 @@ def type_(self) -> tys.Type: def _to_serial(self) -> sops.CustomValue: return self.to_value()._to_serial() + + def to_model(self) -> model.Term: + # Fallback + return self.to_value().to_model() diff --git a/uv.lock b/uv.lock index bb4ac7e00..be5c305df 100644 --- a/uv.lock +++ b/uv.lock @@ -10,9 +10,7 @@ resolution-markers = [ members = [ "hugr", ] - -[manifest.dependency-groups] -dev = [ +requirements = [ { name = "mypy", specifier = ">=1.9.0,<2" }, { name = "pre-commit", specifier = ">=3.6.2,<4" }, { name = "pytest", specifier = ">=8.1.1,<9" },