
Add inline-all-calls pass #1886

Open
doug-q opened this issue Jan 21, 2025 · 2 comments

Comments

@doug-q
Collaborator

doug-q commented Jan 21, 2025

hugr-qir will need to inline all functions to produce valid QIR.

We already have a Callgraph in hugr-passes, which is required for inlining.

I suggest that the inlining pass take a set of Call nodes as input and return an error if these calls contain a cycle; otherwise it inlines all of those calls. That way we don't dictate the inlining policy, and in particular hugr-qir can easily "inline everything".
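A minimal sketch of that policy-free interface, under stated assumptions: functions and calls are modelled by hypothetical `Func`/`CallId`/`Call` stand-ins rather than HUGR nodes, and `inline_order` is an illustrative name, not an existing hugr-passes API. It builds the call graph restricted to the selected calls, rejects cycles (including self-recursion) using Kahn's algorithm, and otherwise returns the calls callees-first, so each function body is already free of selected calls by the time it is copied into a caller.

```rust
use std::collections::{HashMap, VecDeque};

/// Hypothetical stand-ins for HUGR function and Call nodes.
type Func = &'static str;
type CallId = u32;

/// A selected call: call `id`, located in `caller`, targets `callee`.
#[derive(Clone, Copy, Debug)]
struct Call {
    id: CallId,
    caller: Func,
    callee: Func,
}

/// Orders the selected calls callees-first, or errors if they form a cycle.
fn inline_order(calls: &[Call]) -> Result<Vec<CallId>, String> {
    // In-degree = number of selected calls targeting each function.
    let mut indegree: HashMap<Func, usize> = HashMap::new();
    let mut outgoing: HashMap<Func, Vec<&Call>> = HashMap::new();
    for c in calls {
        indegree.entry(c.caller).or_insert(0);
        *indegree.entry(c.callee).or_insert(0) += 1;
        outgoing.entry(c.caller).or_default().push(c);
    }

    // Kahn's algorithm over caller -> callee edges, starting from functions no
    // selected call targets (sorted only to make the output deterministic).
    let mut roots: Vec<Func> = indegree
        .iter()
        .filter(|&(_, &d)| d == 0)
        .map(|(&f, _)| f)
        .collect();
    roots.sort();
    let mut ready: VecDeque<Func> = roots.into();
    let mut topo: Vec<Func> = Vec::new();
    while let Some(f) = ready.pop_front() {
        topo.push(f);
        for c in outgoing.get(&f).into_iter().flatten() {
            let d = indegree.get_mut(&c.callee).unwrap();
            *d -= 1;
            if *d == 0 {
                ready.push_back(c.callee);
            }
        }
    }
    // Functions on a cycle, or reachable from one, never reach in-degree 0.
    if topo.len() != indegree.len() {
        return Err("selected calls contain a cycle".to_string());
    }

    // Reverse topological order: calls inside a callee come before calls into it.
    let mut order = Vec::new();
    for f in topo.iter().rev() {
        if let Some(cs) = outgoing.get(f) {
            order.extend(cs.iter().map(|c| c.id));
        }
    }
    Ok(order)
}
```

With `main` calling `a` (call 1) and `b` (call 2), and `a` calling `b` (call 3), this returns `[3, 1, 2]`: the call inside `a` is inlined before `a` itself is copied into `main`. "Inline everything" is then just passing every Call node in the module.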

Prototype, which predates our CallGraph and may be out of date in other ways:

```rust
use std::collections::HashMap;

use hugr_core::{
    extension::ExtensionRegistry,
    hugr::{
        hugrmut::HugrMut,
        views::{DescendantsGraph, ExtractHugr as _, HierarchyView},
        HugrError, Rewrite, ValidationError,
    },
    ops::{DataflowOpTrait as _, OpTrait, DFG},
    Direction, HugrView, Node,
};
use itertools::Itertools as _;
use petgraph::visit::EdgeRef as _;
use thiserror::Error;

use crate::validation::ValidationLevel;

#[derive(Debug, Clone, Default)]
/// Inlines every `Call` to a `FuncDefn` (failing if the call graph contains a
/// cycle), then removes inlined functions left without callers, except `main`.
pub struct InlinePass {
    validation: ValidationLevel,
}

impl InlinePass {
    /// Sets the validation level used before and after the pass is run
    pub fn validation_level(mut self, level: ValidationLevel) -> Self {
        self.validation = level;
        self
    }

    pub fn run(
        &self,
        hugr: &mut impl HugrMut,
        registry: &ExtensionRegistry,
    ) -> Result<(), Box<dyn std::error::Error>> {
        self.validation
            .run_validated_pass_mut(hugr, registry, |hugr, _| {
                // (call, callee) pairs in topological order of callers, reversed
                // so that calls inside leaf functions are inlined first.
                let mut calls = {
                    let cg = CallGraph::new(hugr);
                    let Some(calls) = cg.iter_nonrecursive() else {
                        Err("InlinePass: recursion")?
                    };
                    let mut calls = calls.collect_vec();
                    calls.reverse();
                    calls
                };

                // Calls whose target is not a FuncDefn (e.g. a FuncDecl) are skipped.
                let rewrites = calls
                    .iter()
                    .filter_map(|(call, _)| InlineRewrite::try_new(hugr, *call, registry).ok())
                    .collect_vec();

                for rewrite in rewrites {
                    hugr.apply_rewrite(rewrite).unwrap();
                }

                calls.reverse();

                // Delete inlined functions that no longer have callers. Use `unique`
                // rather than `dedup`: repeated callees need not be adjacent.
                for func_node in calls.into_iter().map(|x| x.1).unique() {
                    // Declarations have no body and were not inlined.
                    let Some(func) = hugr.get_optype(func_node).as_func_defn() else {
                        continue;
                    };
                    if hugr.linked_inputs(func_node, 0).count() == 0 && func.name != "main" {
                        let func_hugr = DescendantsGraph::<Node>::try_new(hugr, func_node).unwrap();
                        let to_delete = func_hugr.nodes().dedup().collect_vec();
                        for n in to_delete {
                            hugr.remove_node(n);
                        }
                    }
                }
                hugr.validate(registry)?;
                Ok(())
            })
    }
}

/// One node per FuncDefn/FuncDecl; one edge (weighted by the Call node) per
/// call from the enclosing function to its callee.
pub struct CallGraph {
    g: petgraph::graph::Graph<Node, Node>,
}

fn func_of_node(hugr: &impl HugrView, node: Node) -> Option<Node> {
    let mut n = node;
    while let Some(parent) = hugr.get_parent(n) {
        if hugr.get_optype(parent).is_func_defn() {
            return Some(parent);
        }
        n = parent;
    }
    None
}

impl CallGraph {
    pub fn new(hugr: &impl HugrView) -> Self {
        let mut g: petgraph::graph::Graph<Node, Node> = Default::default();

        let node_to_cg: HashMap<_, _> = hugr
            .nodes()
            .filter(|&n| (hugr.get_optype(n).is_func_decl() || hugr.get_optype(n).is_func_defn()))
            .map(|n| (n, g.add_node(n)))
            .collect();

        for n in hugr.nodes() {
            if let Some(call) = hugr.get_optype(n).as_call() {
                if let Some(caller_func) = func_of_node(hugr, n) {
                    if let Some((callee_func, _)) =
                        hugr.single_linked_output(n, call.called_function_port())
                    {
                        g.add_edge(node_to_cg[&caller_func], node_to_cg[&callee_func], n);
                    }
                }
            }
        }

        Self { g }
    }

    /// (call, callee) pairs in topological order of their callers, or `None`
    /// if the call graph contains a cycle.
    pub fn iter_nonrecursive(&self) -> Option<impl Iterator<Item = (Node, Node)> + '_> {
        let funcs = petgraph::algo::toposort(&self.g, None).ok()?;

        Some(funcs.into_iter().flat_map(move |f| {
            self.g
                .edges(f)
                .map(move |e| (*e.weight(), self.g[e.target()]))
        }))
    }
}

pub struct InlineRewrite<'a> {
    call: Node,
    func: Node,
    registry: &'a ExtensionRegistry,
}

impl<'a> InlineRewrite<'a> {
    pub fn try_new(
        hugr: &impl HugrView,
        call: Node,
        registry: &'a ExtensionRegistry,
    ) -> Result<Self, InlineRewriteError> {
        if !hugr.valid_node(call) {
            Err(InlineRewriteError::InvalidCall)?
        }
        let Some(call_ot) = hugr.get_optype(call).as_call() else {
            Err(InlineRewriteError::InvalidCall)?
        };

        let Some((func, _)) = hugr.single_linked_output(call, call_ot.called_function_port())
        else {
            Err(InlineRewriteError::InvalidCall)?
        };

        if !hugr.get_optype(func).is_func_defn() {
            Err(InlineRewriteError::InvalidFunction)?
        }

        let r = Self {
            call,
            func,
            registry,
        };
        debug_assert!(r.verify(hugr).is_ok());

        Ok(r)
    }
}

#[derive(Debug, Clone, Error)]
pub enum InlineRewriteError {
    #[error("Invalid Function")]
    InvalidFunction,
    #[error("Invalid Call")]
    InvalidCall,
    #[error("Call does not target func")]
    Invalid,
    #[error(transparent)]
    HugrError(#[from] HugrError),
    #[error(transparent)]
    Validation(#[from] ValidationError),
}

impl<'a> Rewrite for InlineRewrite<'a> {
    type Error = InlineRewriteError;

    type ApplyResult = ();

    const UNCHANGED_ON_FAILURE: bool = true;

    fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> {
        let Some(call) = h.get_optype(self.call).as_call() else {
            Err(InlineRewriteError::InvalidCall)?
        };
        if !call.type_args.is_empty() {
            Err(InlineRewriteError::InvalidCall)?
        }
        let Some(_) = h.get_optype(self.func).as_func_defn() else {
            Err(InlineRewriteError::InvalidFunction)?
        };

        if let Some((n, _)) = h.single_linked_output(self.call, call.called_function_port()) {
            if self.func != n {
                Err(InlineRewriteError::Invalid)?
            }
        }

        Ok(())
    }

    fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
        self.verify(h)?;

        // Copy the callee's body next to the Call...
        let func_hugr = DescendantsGraph::<Node>::try_new(h, self.func)
            .map_err(|_| InlineRewriteError::InvalidFunction)?
            .extract_hugr();
        func_hugr.validate(self.registry)?;

        let call = h.get_optype(self.call).as_call().unwrap().to_owned();
        let call_parent = h.get_parent(self.call).unwrap();

        let signature = call.signature();

        let insertion = h.insert_hugr(call_parent, func_hugr);

        // ...and turn the copied FuncDefn root into a DFG with the Call's signature.
        let dfg_node = insertion.new_root;
        let dfg = DFG { signature };
        h.set_num_ports(
            dfg_node,
            dfg.signature().input_count() + dfg.non_df_port_count(Direction::Incoming),
            dfg.signature().output_count() + dfg.non_df_port_count(Direction::Outgoing),
        );
        h.replace_op(dfg_node, dfg)?;

        // Move the Call's edges (except its static function input) onto the DFG.
        let connections = h
            .node_inputs(self.call)
            .filter(|&x| x != call.called_function_port())
            .flat_map(|in_p| {
                h.linked_outputs(self.call, in_p)
                    .map(move |(out_n, out_p)| (out_n, out_p, dfg_node, in_p))
            })
            .chain(h.node_outputs(self.call).flat_map(|out_p| {
                h.linked_inputs(self.call, out_p)
                    .map(move |(in_n, in_p)| (dfg_node, out_p, in_n, in_p))
            }))
            .collect_vec();

        for (from_n, from_p, to_n, to_p) in connections {
            h.connect(from_n, from_p, to_n, to_p)
        }

        h.remove_node(self.call);
        Ok(())
    }

    fn invalidation_set(&self) -> impl Iterator<Item = Node> {
        [self.call, self.func].into_iter()
    }
}
```
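The prototype's overall strategy (inline leaf functions first, then delete inlined functions with no remaining callers, keeping `main`) can be checked on a toy module. This is a sketch under assumptions: `Inst`, `Module`, `expand` and `inline_all` are hypothetical names, functions are flat instruction lists rather than HUGR regions, and memoised recursion stands in for the reversed topological order of the call graph. As in the prototype's `filter_map`, calls to functions without a definition (declarations) are left in place.

```rust
use std::collections::{HashMap, HashSet};

/// Toy instruction: a primitive op or a call to a named function (hypothetical).
#[derive(Clone, Debug, PartialEq)]
enum Inst {
    Op(String),
    Call(String),
}

/// Function name -> body.
type Module = HashMap<String, Vec<Inst>>;

/// Returns the fully inlined body of `name`, memoising finished functions in
/// `done` and rejecting recursion via the `visiting` set.
fn expand(
    name: &str,
    module: &Module,
    done: &mut HashMap<String, Vec<Inst>>,
    visiting: &mut HashSet<String>,
) -> Result<Vec<Inst>, String> {
    if let Some(body) = done.get(name) {
        return Ok(body.clone());
    }
    if !visiting.insert(name.to_string()) {
        return Err(format!("recursion through {name}"));
    }
    let mut out = Vec::new();
    for inst in &module[name] {
        match inst {
            // The callee is expanded (made call-free) before being copied in.
            Inst::Call(callee) if module.contains_key(callee) => {
                out.extend(expand(callee, module, done, visiting)?)
            }
            // Ops, and calls to declarations without a body, are kept as-is.
            _ => out.push(inst.clone()),
        }
    }
    visiting.remove(name);
    done.insert(name.to_string(), out.clone());
    Ok(out)
}

/// Inlines every call to a defined function, then drops the inlined
/// functions (which now have no callers) except `main`.
fn inline_all(module: &Module) -> Result<Module, String> {
    let mut done = HashMap::new();
    let mut visiting = HashSet::new();
    let mut names: Vec<&String> = module.keys().collect();
    names.sort();
    for name in names {
        expand(name, module, &mut done, &mut visiting)?;
    }
    // Every defined function that was the target of some call.
    let callees: HashSet<String> = module
        .values()
        .flatten()
        .filter_map(|inst| match inst {
            Inst::Call(c) if module.contains_key(c) => Some(c.clone()),
            _ => None,
        })
        .collect();
    done.retain(|name, _| name.as_str() == "main" || !callees.contains(name));
    Ok(done)
}
```

For `main = [a, call f, b]`, `f = [call g, c]` and `g = [d]`, this yields `main = [a, d, c, b]` and removes `f` and `g`, while a never-called `unused` function is kept, matching the prototype's rule of only deleting functions that were inlined somewhere.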

@acl-cqc
Contributor

acl-cqc commented Jan 21, 2025

I also have some inlining code in the static evaluator, but both of these will break if the function has static edges incoming from other functions. We need #1833 first.

@acl-cqc acl-cqc self-assigned this Feb 26, 2025
@acl-cqc acl-cqc changed the title Add a function inlining pass Add InlineCall rewrite Feb 26, 2025
github-merge-queue bot pushed a commit that referenced this issue Mar 4, 2025

Including adding a new method for copying descendants of another node
(preserving edges incoming to that subtree), with optional Substitution.

The choice of copying descendants rather than a subtree and its root may
seem a little strange. The minor difference is that in most use cases
(others are peeling tail loops, and monomorphization, which can be
refactored to use this too) we're going to update the root node optype too,
so we might as well just make the new optype first rather than
copy+overwrite. The bigger differences are:
* Edges between the copied subtree and the root (e.g. recursive Calls)
want to point to the original FuncDefn, not the inlined copy (which will
be a DFG). This helps the inline-call use case, may help
monomorphization (recursive calls may have different type args, so we
should monomorphize their targets too), and makes no difference for loop
peeling (there are no edges to/from a TailLoop from inside).
* Edges already existing to the new root (the old Call node) can be left
untouched, rather than having to be moved (and the old Call node
deleted). Again I think this makes no difference for loop peeling, nor
monomorphization.
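To make the edge-handling rules above concrete, here is a toy model of copying the strict descendants of `src` under a new parent `dst`. It is a sketch under assumptions: `Graph`, `add_node`, `is_strict_descendant` and `copy_descendants` are hypothetical (a parent array plus a flat edge list), not the hugr-core API or the method added in that commit, and optional Substitution is omitted. Edges inside the subtree are remapped to the copies; edges entering the subtree from outside, including edges from `src` itself such as the static edge of a recursive Call, keep their original source, so a recursive call in the inlined body still targets the original FuncDefn rather than the new DFG.

```rust
use std::collections::HashMap;

/// Toy hierarchical graph (hypothetical): a parent per node plus a flat edge list.
#[derive(Default)]
struct Graph {
    parent: Vec<Option<usize>>,
    label: Vec<String>,
    edges: Vec<(usize, usize)>,
}

impl Graph {
    fn add_node(&mut self, label: &str, parent: Option<usize>) -> usize {
        // Parents must already exist, so a parent's id is always below its children's.
        if let Some(p) = parent {
            assert!(p < self.parent.len());
        }
        self.parent.push(parent);
        self.label.push(label.to_string());
        self.parent.len() - 1
    }

    /// True if `n` lies strictly below `root` in the hierarchy.
    fn is_strict_descendant(&self, mut n: usize, root: usize) -> bool {
        while let Some(p) = self.parent[n] {
            if p == root {
                return true;
            }
            n = p;
        }
        false
    }

    /// Copies the strict descendants of `src` (not `src` itself) under `dst`,
    /// returning the old -> new node map.
    fn copy_descendants(&mut self, src: usize, dst: usize) -> HashMap<usize, usize> {
        let mut map: HashMap<usize, usize> = HashMap::new();
        // Increasing id order visits parents before children; the range is
        // fixed up front, so the new copies are not revisited.
        for n in 0..self.parent.len() {
            if self.is_strict_descendant(n, src) {
                let old_parent = self.parent[n].unwrap();
                let new_parent = if old_parent == src { dst } else { map[&old_parent] };
                let label = self.label[n].clone();
                let copy = self.add_node(&label, Some(new_parent));
                map.insert(n, copy);
            }
        }
        // Copy every edge whose target is in the subtree: internal sources are
        // remapped, outside sources (including `src`) are kept unchanged.
        // Edges leaving the subtree are not copied.
        let new_edges: Vec<(usize, usize)> = self
            .edges
            .iter()
            .filter_map(|&(a, b)| {
                let b_new = *map.get(&b)?;
                let a_new = map.get(&a).copied().unwrap_or(a);
                Some((a_new, b_new))
            })
            .collect();
        self.edges.extend(new_edges);
        map
    }
}
```

Copying the body of a recursive `f` into a DFG under `main` yields a copied Call whose function edge still comes from `f`, and an edge from a Const outside `f` into the body is preserved from that same Const.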

closes #1833, #1886

---------

Co-authored-by: Douglas Wilson <douglas.wilson@quantinuum.com>
@acl-cqc acl-cqc changed the title Add InlineCall rewrite Add inline-all-calls pass Mar 19, 2025
@acl-cqc
Contributor

acl-cqc commented Mar 19, 2025

#1934 handles inlining a single call; the description here is about a pass for all calls, hence updating the title (and unassigning; @trvto, are you working on this, and if so, would you like to take the issue?)

@acl-cqc acl-cqc removed their assignment Mar 19, 2025