Skip to content

Commit

Permalink
Adding support for Equal (#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
dstarkenburg authored Feb 7, 2025
1 parent 6132a74 commit 7b3c4f9
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/load.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ function load_node!(tape::Tape, ::OpConfig{:ONNX, :Where}, args::VarVec, attrs::
return push_call!(tape, _where, args...)
end

function load_node!(tape::Tape, ::OpConfig{:ONNX, :Equal}, args::VarVec, attrs::AttrDict)
return push_call!(tape, _equal, args...)
end

function load_node!(tape::Tape, nd::NodeProto, backend::Symbol)
args = [tape.c.name2var[name] for name in nd.input]
attrs = convert(Dict{Symbol, Any}, Dict(nd.attribute))
Expand Down
4 changes: 4 additions & 0 deletions src/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ function _where(condition, x, y)
return ifelse.(condition, x, y)
end

function _equal(x, y)
return x .== y
end

add(xs...) = .+(xs...)
sub(xs...) = .-(xs...)
_sin(x) = sin.(x)
Expand Down
5 changes: 5 additions & 0 deletions src/save.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_where)}, op::Umlaut
push!(g.node, nd)
end

function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_equal)}, op::Umlaut.Call)
nd = NodeProto("Equal", op)
push!(g.node, nd)
end

function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(*)}, op::Umlaut.Call)
nd = NodeProto(
input=[onnx_name(v) for v in reverse(op.args)],
Expand Down
6 changes: 6 additions & 0 deletions test/saveload.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ import ONNX: NodeProto, ValueInfoProto, AttributeProto, onnx_name
ort_test(ONNX._where, condition, A, B)
end

@testset "Equal" begin
A = rand(Bool, (1, 20))
B = rand(Bool, (1, 20))
ort_test(ONNX._equal, A, B)
end

@testset "Gemm" begin
A, B, C = (rand(3, 4), rand(3, 4), rand(3, 3))
ort_test(ONNX.onnx_gemm, A, B')
Expand Down

0 comments on commit 7b3c4f9

Please sign in to comment.