Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for And #113

Merged
merged 16 commits into from
Jan 22, 2025
4 changes: 2 additions & 2 deletions .github/workflows/Test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- '1.9'
- '1.10'
- '1'
os:
- ubuntu-latest
arch:
Expand Down
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ONNX"
uuid = "d0dd6a25-fac6-55c0-abf7-829e0c774d20"
version = "0.2.7"
version = "0.3.0"

[deps]
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
Expand All @@ -18,5 +18,5 @@ EnumX = "1"
NNlib = "0.8, 0.9"
ProtoBuf = "1.0"
StaticArrays = "1"
Umlaut = "0.4, 0.5, 0.6, 0.7"
julia = "1.6"
Umlaut = "0.7"
julia = "1.10"
4 changes: 4 additions & 0 deletions src/load.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ function load_node!(tape::Tape, ::OpConfig{:ONNX, :Acosh}, args::VarVec, attrs::
return push_call!(tape, _acosh, args[1])
end

function load_node!(tape::Tape, ::OpConfig{:ONNX, :And}, args::VarVec, attrs::AttrDict)
return push_call!(tape, and, args...)
end

function load_node!(tape::Tape, nd::NodeProto, backend::Symbol)
args = [tape.c.name2var[name] for name in nd.input]
attrs = convert(Dict{Symbol, Any}, Dict(nd.attribute))
Expand Down
4 changes: 4 additions & 0 deletions src/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ function onnx_flatten(x; axis = 1)
return flatten(x; dim = dim)
end

function and(x, y)
return x .& y
end

add(xs...) = .+(xs...)
sub(xs...) = .-(xs...)
_sin(x) = sin.(x)
Expand Down
5 changes: 5 additions & 0 deletions src/save.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_acosh)}, op::Umlaut
push!(g.node, nd)
end

function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(and)}, op::Umlaut.Call)
nd = NodeProto("And", op)
push!(g.node, nd)
end

function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(*)}, op::Umlaut.Call)
nd = NodeProto(
input=[onnx_name(v) for v in reverse(op.args)],
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Umlaut = "92992a2b-8ce5-4a9c-bb9d-58be9a7dc841"

[compat]
Umlaut = "0.4"
Umlaut = "0.7"
10 changes: 10 additions & 0 deletions test/saveload.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@ import ONNX: NodeProto, ValueInfoProto, AttributeProto, onnx_name
ort_test(tape, args...)
end

@testset "And" begin
# Testing matricies of similar shape
args = rand(Bool, 3, 4), rand(Bool, 3, 4)
ort_test(ONNX.and, args...)

# Testing Numpy-style broadcasting
args = rand(Bool, 3, 3), rand(Bool, 1, 3)
ort_test(ONNX.and, args...)
end

@testset "Basic ops" begin
args = (rand(3, 4), rand(3, 4))
ort_test(ONNX.add, args...)
Expand Down
Loading