Support for SProd, Prod, and Sum #253

Closed
astralcai opened this issue May 9, 2024 · 7 comments

@astralcai
Contributor

Sum

Sum should be handled the same way as Hamiltonian and LinearCombination, which was partially addressed in #252, but the same treatment should be applied to translate_result_type and translate_result in translation.py as well.

Note: Sum.ops is deprecated, so instead of measurement.obs.ops, do _, ops = measurement.obs.terms(), and then use ops.

SProd and Prod

Since SProd and Prod could be nested, they are not guaranteed to be single-term observables. For example, an SProd could be 0.1 * (qml.Z(0) + qml.X(1)), in which case it's actually a Sum. Similarly, a Prod could be qml.Z(0) @ (qml.X(0) + qml.Y(1)).

This means that the same treatment for Hamiltonian, LinearCombination and Sum should extend to SProd and Prod as well, including _translate_observable, which should register Sum, SProd and Prod all under the same dispatch function as Hamiltonian, which uses H.terms().

Caveat: Prod.terms() will resolve to itself if the Prod only contains one term. For example:

>>> op = qml.X(0) @ qml.Y(1)
>>> op.terms()
([1.0], [X(0) @ Y(1)])

This may result in infinite recursion in _translate_observable, so a base case should be added to return reduce(lambda x, y: x @ y, [_translate_observable(factor) for factor in H.operands]) if H is a Prod with a single term.
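The dispatch scheme described above, including the single-term Prod base case, can be sketched without PennyLane. This is a minimal, hypothetical model: `Pauli`, `Sum`, `SProd`, `Prod`, and `translate` are stand-ins that only mimic the shape of the problem (nested operands, and a single-term `Prod` whose `terms()` returns an equal `Prod`). The string output stands in for real Braket observables.

```python
from dataclasses import dataclass
from functools import reduce, singledispatch
from itertools import product


# Hypothetical stand-ins for PennyLane observables; not the real classes.
@dataclass(frozen=True)
class Pauli:
    name: str
    wire: int

    def terms(self):
        return [1.0], [self]


@dataclass(frozen=True)
class Sum:
    operands: tuple

    def terms(self):
        coeffs, ops = [], []
        for operand in self.operands:
            c, o = operand.terms()
            coeffs += c
            ops += o
        return coeffs, ops


@dataclass(frozen=True)
class SProd:
    scalar: float
    base: object

    def terms(self):
        c, o = self.base.terms()
        return [self.scalar * x for x in c], o


@dataclass(frozen=True)
class Prod:
    operands: tuple

    def terms(self):
        # Distribute the product over any nested sums. A Prod of plain factors
        # yields a single term equal to itself, like the caveat above.
        coeffs, ops = [], []
        for combo in product(*(zip(*f.terms()) for f in self.operands)):
            coeff, factors = 1.0, []
            for c, op in combo:
                coeff *= c
                factors.extend(op.operands if isinstance(op, Prod) else [op])
            coeffs.append(coeff)
            ops.append(Prod(tuple(factors)))
        return coeffs, ops


@singledispatch
def translate(obs):
    raise TypeError(f"unsupported observable: {obs!r}")


@translate.register
def _(obs: Pauli):
    return f"{obs.name.lower()}({obs.wire})"


def _translate_multi_term(obs):
    coeffs, ops = obs.terms()
    if isinstance(obs, Prod) and len(ops) == 1:
        # Base case: a single-term Prod resolves to itself via terms(),
        # so translate its factors directly instead of recursing forever.
        return reduce(lambda x, y: f"{x} @ {y}", [translate(f) for f in obs.operands])
    return " + ".join(f"{c} * {translate(op)}" for c, op in zip(coeffs, ops))


# Sum, SProd and Prod all share the terms()-based translation.
for _cls in (Sum, SProd, Prod):
    translate.register(_cls, _translate_multi_term)
```

Without the `len(ops) == 1` check, translating `Prod((Pauli("X", 0), Pauli("Y", 1)))` would keep calling `translate` on an equal `Prod` until Python's recursion limit is hit.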

Note: The terms() function will unwrap any nested structures but also simplify the observable. For example:

>>> op = qml.X(0) @ qml.I(1)
>>> op.terms()
([1.0], [X(0)])

This creates a mismatch between the number of targets in the translated observable and in the original observable. We plan to address this in PennyLane by having terms() recursively unwrap the observable without doing any simplification, but for now, do not use circuit.measurements directly in _pl_to_braket_circuit. Instead, do something like:

measurements = []
for mp in circuit.measurements:
    obs = mp.obs
    if isinstance(obs, (Hamiltonian, LinearCombination, Sum, SProd, Prod)):
        obs = obs.simplify()
        mp = type(mp)(obs)
    measurements.append(mp)

Then use measurements instead of circuit.measurements from this point on. The list of simplified measurements should also be passed into _apply_gradient_result_type and used there.
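A stand-alone model of this pre-pass may help show why it removes the target mismatch. The classes `Op`, `Prod`, and `ExpectationMP` and the helper `simplify_measurements` are hypothetical stand-ins, not PennyLane's; here `simplify()` only drops identity factors, which is just enough to reproduce the `X(0) @ I(1)` example above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    """Hypothetical single-qubit observable such as X(0) or I(1)."""
    name: str
    wire: int

    @property
    def wires(self):
        return [self.wire]


@dataclass(frozen=True)
class Prod:
    operands: tuple

    @property
    def wires(self):
        return [w for op in self.operands for w in op.wires]

    def simplify(self):
        # Drop identity factors, mimicking the simplification terms() performs.
        kept = tuple(op for op in self.operands if op.name != "I")
        return kept[0] if len(kept) == 1 else Prod(kept)

    def terms(self):
        return [1.0], [self.simplify()]


@dataclass(frozen=True)
class ExpectationMP:
    obs: object

    @property
    def wires(self):
        return self.obs.wires


def simplify_measurements(measurements):
    # Rebuild each composite measurement from its simplified observable so the
    # measurement's wires agree with the terms that terms() will later produce.
    simplified = []
    for mp in measurements:
        if isinstance(mp.obs, Prod):
            mp = type(mp)(mp.obs.simplify())
        simplified.append(mp)
    return simplified
```

Before the pre-pass, the measurement on `X(0) @ I(1)` reports wires `[0, 1]` while its only term acts on `[0]`; afterwards both agree on `[0]`.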

Device

Now, since SProd, Prod, and Sum can all be nested, multi-term observables, they should be removed from the list of supported observables and added back only when no shots are present:

@property
def observables(self) -> frozenset[str]:
    base_observables = frozenset(super().observables - {"Prod", "SProd", "Sum"})
    # Amazon Braket only supports coefficients and multiple terms when shots==0
    if not self.shots:
        return base_observables.union({"Hamiltonian", "LinearCombination", "Prod", "SProd", "Sum"})
    return base_observables
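To check how this property behaves for the two shot settings, here is a self-contained version in which a hypothetical `BaseDevice` with a made-up observable set plays the role of the parent class; only the set arithmetic mirrors the snippet above.

```python
class BaseDevice:
    """Hypothetical parent device; the real base class and its observable set differ."""

    observables = frozenset({"PauliX", "PauliZ", "Hermitian", "Prod", "SProd", "Sum"})


class BraketDeviceSketch(BaseDevice):
    def __init__(self, shots=None):
        self.shots = shots

    @property
    def observables(self) -> frozenset:
        # Composite observables may hide several terms, so drop them by default.
        base_observables = frozenset(super().observables - {"Prod", "SProd", "Sum"})
        # Multi-term observables are only advertised when shots is None or 0.
        if not self.shots:
            return base_observables.union(
                {"Hamiltonian", "LinearCombination", "Prod", "SProd", "Sum"}
            )
        return base_observables
```

With `shots=1000` the device advertises only the single-term observables, so composite observables are not advertised as supported on shot-based runs.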
@AbeCoull
Contributor

AbeCoull commented May 9, 2024

Hi @astralcai, thank you for raising this. I shall start looking into a fix.

@speller26
Member

This means that the same treatment for Hamiltonian, LinearCombination and Sum should extend to SProd and Prod as well, including _translate_observable, which should register Sum, SProd and Prod all under the same dispatch function as Hamiltonian, which uses H.terms().

The current _translate_observable implementations for Sum, SProd and Prod recursively call _translate_observable on their operands:

@_translate_observable.register
def _(t: qml.operation.Tensor):
    return reduce(lambda x, y: x @ y, [_translate_observable(factor) for factor in t.obs])


@_translate_observable.register
def _(t: qml.ops.Prod):
    return reduce(lambda x, y: x @ y, [_translate_observable(factor) for factor in t.operands])


@_translate_observable.register
def _(t: qml.ops.SProd):
    return t.scalar * _translate_observable(t.base)

Shouldn't that take care of the nesting problem?

@speller26
Member

measurements = []
for mp in circuit.measurements:
    obs = mp.obs
    if isinstance(obs, (Hamiltonian, LinearCombination, Sum, SProd, Prod)):
        obs = obs.simplify()
        mp = type(mp)(obs)
    measurements.append(mp)

I'm noticing that simplify alters the order of operands (at least in Prod); is this intentional?

@astralcai
Contributor Author

This means that the same treatment for Hamiltonian, LinearCombination and Sum should extend to SProd and Prod as well, including _translate_observable, which should register Sum, SProd and Prod all under the same dispatch function as Hamiltonian, which uses H.terms().

The current _translate_observable implementations for Sum, SProd and Prod recursively call _translate_observable on their operands:

@_translate_observable.register
def _(t: qml.operation.Tensor):
    return reduce(lambda x, y: x @ y, [_translate_observable(factor) for factor in t.obs])


@_translate_observable.register
def _(t: qml.ops.Prod):
    return reduce(lambda x, y: x @ y, [_translate_observable(factor) for factor in t.operands])


@_translate_observable.register
def _(t: qml.ops.SProd):
    return t.scalar * _translate_observable(t.base)

Shouldn't that take care of the nesting problem?

It should, but as I recall, it didn't. I looked into it some time ago and couldn't make it work, which is why I suggested using the same approach for all potential multi-term observables. You can give it a try. I don't remember exactly what the issue was, but I believe it had something to do with the Braket backend being unable to parse scalar products.

/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/interpreter.py:545: in _
    parsed = self.context.parse_pragma(node.command)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/program_context.py:455: in parse_pragma
    return parse_braket_pragma(pragma_body, self.qubit_mapping)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/braket_pragmas.py:216: in parse_braket_pragma
    visited = BraketPragmaNodeVisitor(qubit_table).visit(tree)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/antlr4/tree/Tree.py:34: in visit
    return tree.accept(self)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/generated/BraketPragmasParser.py:861: in accept
    return visitor.visitBraketPragma(self)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/generated/BraketPragmasParserVisitor.py:14: in visitBraketPragma
    return self.visitChildren(ctx)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/antlr4/tree/Tree.py:44: in visitChildren
    childResult = c.accept(self)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/generated/BraketPragmasParser.py:1226: in accept
    return visitor.visitBraketResultPragma(self)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/generated/BraketPragmasParserVisitor.py:39: in visitBraketResultPragma
    return self.visitChildren(ctx)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/antlr4/tree/Tree.py:44: in visitChildren
    childResult = c.accept(self)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/generated/BraketPragmasParser.py:1290: in accept
    return visitor.visitResultType(self)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/generated/BraketPragmasParserVisitor.py:44: in visitResultType
    return self.visitChildren(ctx)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/antlr4/tree/Tree.py:44: in visitChildren
    childResult = c.accept(self)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/generated/BraketPragmasParser.py:1867: in accept
    return visitor.visitObservableResultType(self)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/braket_pragmas.py:98: in visitObservableResultType
    observables, targets = self.visit(ctx.observable())
E   TypeError: cannot unpack non-iterable NoneType object
----------------------------- Captured stderr call -----------------------------
line 1:26 mismatched input '0.1' expecting {'x', 'y', 'z', 'i', 'h', 'hermitian'}

This occurred when trying to parse the scalar product of an observable. See this run: https://github.com/PennyLaneAI/plugin-test-matrix/actions/runs/9018042395/job/24777766316

@astralcai
Contributor Author

measurements = []
for mp in circuit.measurements:
    obs = mp.obs
    if isinstance(obs, (Hamiltonian, LinearCombination, Sum, SProd, Prod)):
        obs = obs.simplify()
        mp = type(mp)(obs)
    measurements.append(mp)

I'm noticing that simplify alters the order of operands (at least in Prod); is this intentional?

Simplify does not preserve the original order of operands.

@speller26
Member

speller26 commented Sep 4, 2024

Sorry for the delay; I finally managed to return to this, and I think I've found the actual issues. Looking at the device test run in #264, we observe two types of failures:

1. TypeError: cannot unpack non-iterable NoneType object
   mismatched input '0.1' expecting {'x', 'y', 'z', 'i', 'h', 'hermitian'}

   This is due to attempting to run Braket Sum observables on the local simulator, which does not support them. This is fixed by your suggestion of expanding the treatment of Hamiltonians to CompositeOp and SProd.

2. ValueError: Sum observable's target shape must be a nested list where each term's target length is equal to the observable term's qubits count.

   This is because we pass the MeasurementProcess's wires, which form a flat list, into translate_result_type, which expects a list of lists for its targets:

       dev_wires = self.map_wires(measurement.wires).tolist()
       translated = translate_result_type(
           measurement, dev_wires, self._braket_result_types
       )

   This is fixed by mapping the wires of the MeasurementProcess itself and using those wires instead of passing in wires separately.
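The shape mismatch behind the ValueError can be illustrated in plain Python. `nested_targets` and `check_sum_targets` are hypothetical helpers, not Braket or PennyLane functions; the check only restates the rule from the error message (one inner target list per term, each as long as that term's qubit count). A flat list of the measurement's wires has lost the per-term grouping, so it cannot satisfy the rule.

```python
def nested_targets(term_wires, wire_map):
    """Map each term's wires to device wires, keeping one inner list per term."""
    return [[wire_map[w] for w in wires] for wires in term_wires]


def check_sum_targets(term_qubit_counts, targets):
    """Hypothetical restatement of the shape rule quoted in the ValueError."""
    ok = len(targets) == len(term_qubit_counts) and all(
        isinstance(t, list) and len(t) == n
        for t, n in zip(targets, term_qubit_counts)
    )
    if not ok:
        raise ValueError("targets must be a nested list matching each term's qubit count")
    return targets


# Example observable: x(a) @ y(b) + z(c), with logical wires mapped to device qubits.
term_wires = [["a", "b"], ["c"]]
wire_map = {"a": 0, "b": 1, "c": 2}
flat = [wire_map[w] for wires in term_wires for w in wires]  # [0, 1, 2]
nested = nested_targets(term_wires, wire_map)  # [[0, 1], [2]]
check_sum_targets([2, 1], nested)  # passes; check_sum_targets([2, 1], flat) raises
```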

@speller26
Member

Fixed in #275
