Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add CircuitHistory #778

Draft
wants to merge 1 commit into
base: feat/circdiff
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -26,11 +26,11 @@ missing_docs = "warn"
[patch.crates-io]

# Uncomment to use unreleased versions of hugr
hugr = { git = "https://github.com/CQCL/hugr", rev = "0efc806f" }
hugr-core = { git = "https://github.com/CQCL/hugr", rev = "0efc806f" }
hugr-passes = { git = "https://github.com/CQCL/hugr", rev = "0efc806f" }
hugr-cli = { git = "https://github.com/CQCL/hugr", rev = "0efc806f" }
hugr-model = { git = "https://github.com/CQCL/hugr", rev = "0efc806f" }
hugr = { git = "https://github.com/CQCL/hugr", rev = "739cf944" }
hugr-core = { git = "https://github.com/CQCL/hugr", rev = "739cf944" }
hugr-passes = { git = "https://github.com/CQCL/hugr", rev = "739cf944" }
hugr-cli = { git = "https://github.com/CQCL/hugr", rev = "739cf944" }
hugr-model = { git = "https://github.com/CQCL/hugr", rev = "739cf944" }
# portgraph = { git = "https://github.com/CQCL/portgraph", rev = "68b96ac737e0c285d8c543b2d74a7aa80a18202c" }

[workspace.dependencies]
178 changes: 176 additions & 2 deletions tket2/src/diff.rs
Original file line number Diff line number Diff line change
@@ -8,8 +8,12 @@
//! Diffs can be created on top of existing diffs, resulting in an acyclic
//! history of circuit transformations.
pub mod experimental;
mod history;
mod replacement;

pub use history::CircuitHistory;

use std::{
cell::RefCell,
cmp::Ordering,
@@ -19,7 +23,9 @@ use std::{

use derive_more::{Display, Error, From};
use derive_where::derive_where;
use hugr::{hugr::SimpleReplacementError, Hugr, HugrView, IncomingPort, Node, Wire};
use hugr::{
hugr::SimpleReplacementError, Direction, Hugr, HugrView, IncomingPort, Node, Port, Wire,
};
use itertools::Itertools;
use relrc::RelRc;

@@ -38,7 +44,8 @@ use crate::{
/// Use [`CircuitDiff::try_from_circuit`] to create a new "root" diff, i.e.
/// without any parents, and use [`CircuitDiff::apply_replacement`] to create
/// new diffs as children of the current diff.
#[derive(Clone)]
#[derive(From)]
#[derive_where(Clone)]
pub struct CircuitDiff<H = Hugr>(RelRc<CircuitDiffData<H>, InvalidNodes>);

#[derive(Clone)]
@@ -152,6 +159,15 @@ pub struct WireEquivalence<H = Hugr> {
wire_to_children: RefCell<BTreeMap<Wire, BTreeSet<ChildWire<H>>>>,
}

type CircuitDiffPtr<H> = *const relrc::node::InnerData<CircuitDiffData<H>, InvalidNodes>;

impl<H> CircuitDiff<H> {
/// Get the pointer to the inner data of the diff
fn as_ptr(&self) -> CircuitDiffPtr<H> {
self.0.as_ptr()
}
}

impl<H: HugrView> CircuitDiff<H> {
/// Create a new circuit diff from a circuit
pub fn try_from_circuit(circuit: Circuit<H>) -> Result<Self, HashError> {
@@ -167,6 +183,11 @@ impl<H: HugrView> CircuitDiff<H> {
self.0.value().circuit.circuit()
}

/// Get the io nodes of the diff
pub fn io_nodes(&self) -> [Node; 2] {
self.as_circuit().io_nodes()
}

/// Get the diff circuit as a hugr
pub fn as_hugr(&self) -> &H {
self.as_circuit().hugr()
@@ -211,6 +232,106 @@ impl<H: HugrView> CircuitDiff<H> {

new
}

fn wire_to_children(&self, wire: Wire) -> Vec<(Self, Wire)> {
let mut w = self
.0
.value()
.equivalent_wires
.wire_to_children
.borrow_mut();
let Some(wire_to_children_mut) = w.get_mut(&wire) else {
return vec![];
};
let mut wire_to_children = Vec::new();
wire_to_children_mut.retain(|child_wire| {
// remove edges to no longer existing nodes
let Some(target) = child_wire.edge.target().upgrade() else {
return false;
};
wire_to_children.push((CircuitDiff(target), child_wire.wire));
true
});

wire_to_children
}

fn input_to_parent(&self, wire: Wire) -> Option<ParentWire> {
self.0
.value()
.equivalent_wires
.input_to_parent
.get(&wire)
.copied()
}

fn output_to_parent(&self, wire: Wire) -> Option<&BTreeSet<ParentWire>> {
self.0.value().equivalent_wires.output_to_parent.get(&wire)
}

fn get_parent(&self, parent_wire: &ParentWire) -> Self {
let edge = self
.0
.incoming(parent_wire.incoming_index)
.expect("invalid parent index");
CircuitDiff(edge.source().clone())
}

fn all_parents(&self) -> impl ExactSizeIterator<Item = Self> + '_ {
self.0.all_parents().cloned().map_into()
}

/// Get equivalent ports in children of a given port using the wire equivalences
fn equivalent_children_ports<'a>(
&'a self,
node: Node,
port: Port,
) -> impl Iterator<Item = Owned<H, (Node, Port)>> + 'a {
let Ok(wire) = port_to_wire(node, port, self.as_hugr()) else {
return None.into_iter().flatten();
};
let iter = self
.wire_to_children(wire)
.into_iter()
.flat_map(move |(child, wire)| {
let to_owned = |data| Owned {
owner: child.clone(),
data,
};
wire_to_ports(wire, port.direction(), child.as_hugr())
.map(to_owned)
.collect_vec()
});
Some(iter).into_iter().flatten()
}

/// Get equivalent ports in parents of a given port using the wire equivalences
// TODO: make stronger assumptions on the kinds of wires in `input_to_parent`
// and `output_to_parent` to make this more efficient
fn equivalent_parent_ports<'a>(
&'a self,
node: Node,
port: Port,
) -> impl Iterator<Item = Owned<H, (Node, Port)>> + 'a {
let Ok(wire) = port_to_wire(node, port, self.as_hugr()) else {
return None.into_iter().flatten();
};
let inputs = self.input_to_parent(wire).into_iter();
let outputs = self.output_to_parent(wire).into_iter().flatten().copied();
let iter = inputs.chain(outputs).flat_map(move |parent_wire| {
let parent = self.get_parent(&parent_wire);
let to_owned = |data| Owned {
owner: parent.clone(),
data,
};
wire_to_ports(parent_wire.wire, port.direction(), parent.as_hugr())
.map(to_owned)
.collect_vec()
});
Some(iter.unique_by(|o| (o.owner.as_ptr(), o.data)))
.into_iter()
.flatten()
}
}

impl<H> WireEquivalence<H> {
@@ -251,4 +372,57 @@ pub enum CircuitDiffError {
/// Error when a cycle is detected in the dfg
#[display("cycle detected in dfg")]
Cycle,
/// Error when merging two diffs
#[display("conflicting diffs")]
ConflictingDiffs,
/// Error when a history is empty
#[display("empty history")]
EmptyHistory,
/// Error when merging two diffs with different roots
#[display("distinct roots")]
DistinctRoots,
}

fn port_to_wire(
node: Node,
port: impl Into<Port>,
hugr: &impl HugrView,
) -> Result<Wire, CircuitDiffError> {
let port: Port = port.into();

use itertools::Either::{Left, Right};
match port.as_directed() {
Left(incoming) => {
let (node, outgoing) = hugr
.single_linked_output(node, incoming)
.ok_or(CircuitDiffError::NoUniqueOutput(node, incoming))?;
Ok(Wire::new(node, outgoing))
}
Right(outgoing) => Ok(Wire::new(node, outgoing)),
}
}

fn wire_to_ports(
wire: Wire,
dir: Direction,
hugr: &impl HugrView,
) -> impl Iterator<Item = (Node, Port)> + '_ {
use itertools::Either::{Left, Right};
let iter = match dir {
Direction::Incoming => Left(
hugr.linked_inputs(wire.node(), wire.source())
.map(|(node, port)| (node, port.into())),
),
Direction::Outgoing => Right([(wire.node(), wire.source().into())]),
};
iter.into_iter()
}

/// Data in a circuit diff, along with its owner [`CircuitDiff`]
#[derive_where(Clone; D)]
pub struct Owned<H, D> {
/// The owner of the data
pub owner: CircuitDiff<H>,
/// The data
pub data: D,
}
316 changes: 316 additions & 0 deletions tket2/src/diff/experimental.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,316 @@
//! Experimental implementation of HugrView for [`CircuitHistory`].
//!
//! Access the implementation using the [`ExperimentalHugrWrapper`] type.
//!
//! ## Limitations
//! - Panics on histories with more than 2^16 = 65 536 diffs or which contain
//! a diff with more than 2^16 - 1 nodes.
//! - Does not implement [`HugrInternals::portgraph`] or
//! [`HugrInternals::base_hugr`], as these are not well defined for the
//! history as a whole.
//! - [`HugrView::nodes`] and [`HugrView::node_count`] are inefficient: they
//! iterate over the entire hugr. Currently, no implementation of
//! [`HugrView::edge_count`] is provided.
//!
//! Better support for this would require modifications to the [`HugrView`]
//! and [`HugrInternals`] traits.
use std::collections::BTreeSet;

use derive_more::{From, Into};
use derive_where::derive_where;
use hugr::{
hugr::views::ExtractHugr,
ops::{OpType, DEFAULT_OPTYPE},
Direction, Hugr, HugrView, Node, NodeIndex, Port,
};
use hugr_core::hugr::internal::HugrInternals;
use itertools::Itertools;

use crate::{
diff::{CircuitDiff, Owned},
Circuit,
};

use super::CircuitHistory;

/// A wrapper around a [`CircuitHistory`] which implements [`HugrView`].
///
/// This is experimental and has significant limitations.
///
/// ## Limitations
/// - Panics on histories with more than 2^16 = 65 536 diffs or which contain
/// a diff with more than 2^16 - 1 nodes.
/// - Does not implement [`HugrInternals::portgraph`] or
/// [`HugrInternals::base_hugr`], as these are not well defined for the
/// history as a whole.
/// - [`HugrView::nodes`] and [`HugrView::node_count`] are inefficient: they
/// iterate over the entire hugr. Currently, no implementation of
/// [`HugrView::edge_count`] is provided.
#[derive(Clone, From, Into)]
pub struct ExperimentalHugrWrapper<H: HugrView>(pub CircuitHistory<H>);

impl<H: HugrView> ExperimentalHugrWrapper<H> {
/// View the history as a circuit
pub fn as_circuit(&self) -> Circuit<&ExperimentalHugrWrapper<H>> {
Circuit::new(self, self.root())
}

/// Get the underlying hugr of a diff
fn get_diff_hugr(&self, diff_index: usize) -> &H {
// Get the diff_index-th element from all_nodes()
let &diff_id = self
.0
.diffs
.all_nodes()
.iter()
.nth(diff_index)
.expect("invalid diff index");
let diff = self.0.diffs.get_node(diff_id);
diff.value().circuit.circuit().hugr()
}

fn get_diff(&self, diff_index: usize) -> CircuitDiff<H> {
// Get the diff_index-th element from all_nodes()
let &diff_id = self
.0
.diffs
.all_nodes()
.iter()
.nth(diff_index)
.expect("invalid diff index");
CircuitDiff(self.0.diffs.get_node_rc(diff_id))
}

fn get_diff_index(&self, diff: &CircuitDiff<H>) -> Option<usize> {
self.0
.diffs
.all_nodes()
.iter()
.position(|id| self.0.diffs.get_node_rc(*id).ptr_eq(&diff.0))
}
}

impl<H: HugrView> HugrView for ExperimentalHugrWrapper<H> {
fn contains_node(&self, node: Node) -> bool {
let node: CircuitHistoryNode = node.into();
// Get the diff_index-th element from all_nodes()
let diff = node.diff(self);
if !self.0.is_root(&diff) && diff.io_nodes().contains(&node.node()) {
// Non-root IO nodes are not part of the hugr
return false;
}
diff.as_hugr().contains_node(node.node())
}

fn node_count(&self) -> usize {
self.nodes().count()
}

fn edge_count(&self) -> usize {
unimplemented!()
}

fn nodes(&self) -> impl Iterator<Item = Node> + Clone {
let current = self.get_io(self.root()).unwrap().to_vec();
let children = NodesIter {
visited: BTreeSet::default(),
current,
history: self,
};
[self.root_node()].into_iter().chain(children)
}

fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator<Item = Port> + Clone {
let node: CircuitHistoryNode = node.into();
let hugr = self.get_diff_hugr(node.diff_index as usize);
hugr.node_ports(node.node(), dir)
}

fn all_node_ports(&self, node: Node) -> impl Iterator<Item = Port> + Clone {
let node: CircuitHistoryNode = node.into();
let hugr = self.get_diff_hugr(node.diff_index as usize);
hugr.all_node_ports(node.node())
}

fn linked_ports(
&self,
node: Node,
port: impl Into<Port>,
) -> impl Iterator<Item = (Node, Port)> + Clone {
let node: CircuitHistoryNode = node.into();
let node_port = Owned {
owner: node.diff(self),
data: (node.node(), port.into()),
};
let into_node = |node_port: Owned<H, (Node, Port)>| {
let Owned { owner, data } = node_port;
let diff_index = self.get_diff_index(&owner).unwrap();
let node = CircuitHistoryNode::try_new(diff_index, data.0)
.expect("diff_index or node_index too large for CircuitHistoryNode");
(node.into(), data.1)
};
self.0.linked_ports(node_port).map(into_node)
}

fn node_connections(&self, node: Node, other: Node) -> impl Iterator<Item = [Port; 2]> + Clone {
let ports = self
.node_inputs(node)
.map_into()
.chain(self.node_outputs(node).map_into());
ports.flat_map(move |p| {
self.linked_ports(node, p)
.filter(move |(n, _)| n == &other)
.map(move |(_, other_p)| [p, other_p])
})
}

fn num_ports(&self, node: Node, dir: Direction) -> usize {
let node: CircuitHistoryNode = node.into();
let hugr = self.get_diff_hugr(node.diff_index as usize);
hugr.num_ports(node.node(), dir)
}

fn children(&self, node: Node) -> impl DoubleEndedIterator<Item = Node> + Clone {
let node: CircuitHistoryNode = node.into();
let hugr = self.get_diff_hugr(node.diff_index as usize);
let to_owned = move |n| {
CircuitHistoryNode::try_new(node.diff_index as usize, n)
.unwrap()
.into()
};
hugr.children(node.node()).map(to_owned)
}

fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator<Item = Node> + Clone {
self.node_ports(node, dir)
.flat_map(move |p| self.linked_ports(node, p))
.map(|(n, _)| n)
}

fn all_neighbours(&self, node: Node) -> impl Iterator<Item = Node> + Clone {
self.neighbours(node, Direction::Incoming)
.chain(self.neighbours(node, Direction::Outgoing))
}

/// Returns the operation type of a node.
fn get_optype(&self, node: Node) -> &OpType {
match self.contains_node(node) {
true => {
let node: CircuitHistoryNode = node.into();
// Get the diff_index-th element from all_nodes()
let hugr = self.get_diff_hugr(node.diff_index as usize);
hugr.get_optype(node.node())
}
false => &DEFAULT_OPTYPE,
}
}
}

/// Iterator over all nodes in the graph, using a simple depth-first search
#[derive_where(Clone)]
struct NodesIter<'h, H: HugrView> {
visited: BTreeSet<Node>,
current: Vec<Node>,
history: &'h ExperimentalHugrWrapper<H>,
}

impl<'h, H: HugrView> Iterator for NodesIter<'h, H> {
type Item = Node;

fn next(&mut self) -> Option<Self::Item> {
while let Some(node) = self.current.pop() {
debug_assert!(self.history.contains_node(node));
if self.visited.insert(node) {
// Add all neighbours to current
self.current.extend(
self.history
.all_neighbours(node)
.filter(|n| !self.visited.contains(n)),
);
return Some(node);
}
}
None
}
}

impl<H: HugrView> HugrInternals for ExperimentalHugrWrapper<H> {
type Portgraph<'p> = H::Portgraph<'p> where Self: 'p;

fn portgraph(&self) -> Self::Portgraph<'_> {
unimplemented!("no single portgraph for history")
}

fn base_hugr(&self) -> &hugr::Hugr {
unimplemented!("no single base hugr for history")
}

fn root_node(&self) -> hugr::Node {
let diff_index = self.get_diff_index(&self.0.root).unwrap();
let node = CircuitHistoryNode::try_new(diff_index, self.0.root.as_hugr().root_node())
.expect("diff_index or node_index too large for CircuitHistoryNode");
node.into()
}
}

impl<H: HugrView> ExtractHugr for ExperimentalHugrWrapper<H> {
fn extract_hugr(self) -> Hugr {
self.0.extract_hugr()
}
}

/// A node in the history, that can disguise as a [`Node`]
///
/// This is an ugly hack, where we encode the two integers of information
/// (the diff index in the history and the node index) into a single
/// [Node] by using the lower and upper 16 bits of the index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct CircuitHistoryNode {
/// The diff index this node belongs to
diff_index: u16,
/// The node index within the diff
node_index: u16,
}

impl CircuitHistoryNode {
fn try_new(diff_index: usize, node: Node) -> Option<Self> {
let diff_index = u16::try_from(diff_index).ok()?;
let node_index = u16::try_from(node.index()).ok()?;

Some(Self {
diff_index,
node_index,
})
}

fn node(&self) -> Node {
Node::from(portgraph::NodeIndex::new(self.node_index as usize))
}

fn diff<H: HugrView>(&self, history: &ExperimentalHugrWrapper<H>) -> CircuitDiff<H> {
history.get_diff(self.diff_index as usize)
}
}

impl From<Node> for CircuitHistoryNode {
fn from(node: Node) -> Self {
let index = node.index();
// The node index is the lower 16 bits of the node index
let node_index = (index & 0xFFFF) as u16;
// The diff index is the upper 16 bits of the node index
let diff_index = (index >> 16) as u16;
Self {
diff_index,
node_index,
}
}
}

impl From<CircuitHistoryNode> for Node {
fn from(node: CircuitHistoryNode) -> Self {
let mut index = node.node_index as usize;
index |= (node.diff_index as usize) << 16;
Node::from(portgraph::NodeIndex::new(index))
}
}
518 changes: 518 additions & 0 deletions tket2/src/diff/history.rs

Large diffs are not rendered by default.

33 changes: 9 additions & 24 deletions tket2/src/diff/replacement.rs
Original file line number Diff line number Diff line change
@@ -7,14 +7,16 @@ use hugr::{
rewrite::{HostPort, ReplacementPort},
views::PetgraphWrapper,
},
HugrView, Node, Port, SimpleReplacement, Wire,
HugrView, Node, Port, SimpleReplacement,
};
use itertools::{izip, Either, Itertools};
use itertools::{izip, Itertools};
use petgraph::visit::{depth_first_search, Control};

use crate::{rewrite::CircuitRewrite, Circuit};

use super::{CircuitDiff, CircuitDiffData, CircuitDiffError, ParentWire, WireEquivalence};
use super::{
port_to_wire, CircuitDiff, CircuitDiffData, CircuitDiffError, ParentWire, WireEquivalence,
};

impl CircuitDiff {
/// Apply a simple replacement.
@@ -52,7 +54,7 @@ impl CircuitDiff {
let mut equivalent_wires = WireEquivalence::new();
let to_parent_wire = |port: HostPort<Port>| -> Result<ParentWire, CircuitDiffError> {
let HostPort(node, port) = port;
let wire = to_wire(node, port, self.as_hugr())?;
let wire = port_to_wire(node, port, self.as_hugr())?;
let parent_wire = ParentWire {
incoming_index: 0, // `self` is the unique parent
wire,
@@ -64,7 +66,7 @@ impl CircuitDiff {
for (src, tgt) in replacement.incoming_boundary(self.as_hugr()) {
let src_parent_wire = to_parent_wire(src.into())?;
let ReplacementPort(tgt_node, tgt_port) = tgt;
let tgt_child_wire = to_wire(tgt_node, tgt_port, replacement.replacement())?;
let tgt_child_wire = port_to_wire(tgt_node, tgt_port, replacement.replacement())?;
let ret = equivalent_wires
.input_to_parent
.insert(tgt_child_wire, src_parent_wire);
@@ -74,7 +76,7 @@ impl CircuitDiff {
// 2. Add equivalences for replacement output wires
for (src, tgt) in replacement.outgoing_boundary(self.as_hugr()) {
let ReplacementPort(src_node, src_port) = src;
let src_child_wire = to_wire(src_node, src_port, replacement.replacement())?;
let src_child_wire = port_to_wire(src_node, src_port, replacement.replacement())?;
let tgt_parent_wire = to_parent_wire(tgt.into())?;
equivalent_wires
.output_to_parent
@@ -93,7 +95,7 @@ impl CircuitDiff {
.expect("invalid host to host edge");
let src_parent_wire = to_parent_wire(src.into())?;
let tgt_parent_wire = to_parent_wire(tgt.into())?;
let io_child_wire = to_wire(rep_output, incoming, replacement.replacement())?;
let io_child_wire = port_to_wire(rep_output, incoming, replacement.replacement())?;
let ret = equivalent_wires
.input_to_parent
.insert(io_child_wire, src_parent_wire);
@@ -112,23 +114,6 @@ impl CircuitDiff {
}
}

fn to_wire(
node: Node,
port: impl Into<Port>,
hugr: &impl HugrView,
) -> Result<Wire, CircuitDiffError> {
let port: Port = port.into();
match port.as_directed() {
Either::Left(incoming) => {
let (node, outgoing) = hugr
.single_linked_output(node, incoming)
.ok_or(CircuitDiffError::NoUniqueOutput(node, incoming))?;
Ok(Wire::new(node, outgoing))
}
Either::Right(outgoing) => Ok(Wire::new(node, outgoing)),
}
}

impl WireEquivalence {
/// The set of wires that are invalidated by a rewrite.
///
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
---
source: tket2/src/diff/history.rs
expression: history.extract_hugr().mermaid_string()
---
graph LR
subgraph 0 ["(0) DFG"]
direction LR
1["(1) Input"]
2["(2) Output"]
3["(3) tket2.quantum.Rz"]
4["(4) tket2.rotation.radd"]
5["(5) tket2.rotation.radd"]
6["(6) tket2.quantum.Rz"]
1--"0:0<br>qubit"-->6
1--"1:0<br>rotation"-->4
1--"1:0<br>rotation"-->5
1--"1:1<br>rotation"-->5
1--"2:1<br>rotation"-->4
3--"0:0<br>qubit"-->2
4--"0:1<br>rotation"-->3
5--"0:1<br>rotation"-->6
6--"0:0<br>qubit"-->3
end