diff --git a/Project.toml b/Project.toml
index 0d25d46..0ed3371 100644
--- a/Project.toml
+++ b/Project.toml
@@ -19,6 +19,7 @@ OmeletteLuxExt = "Lux"
 [compat]
 Distributions = "0.25"
 GLM = "1.9"
 JuMP = "1"
+Lux = "0.5"
 MathOptInterface = "1"
 julia = "1.9"
diff --git a/ext/OmeletteLuxExt.jl b/ext/OmeletteLuxExt.jl
index e78155c..8b52b8b 100644
--- a/ext/OmeletteLuxExt.jl
+++ b/ext/OmeletteLuxExt.jl
@@ -13,7 +13,7 @@ function _add_predictor(predictor::Omelette.Pipeline, layer::Lux.Dense, p)
     if layer.activation === identity
         # Do nothing
     elseif layer.activation === Lux.NNlib.relu
-        push!(predictor.layers, Omelette.ReLU(layer.out_dims, 1e6))
+        push!(predictor.layers, Omelette.ReLUBigM(layer.out_dims, 1e6))
     else
         error("Unsupported activation function: $x")
     end
diff --git a/src/models/Pipeline.jl b/src/models/Pipeline.jl
index b4a5c16..d46103b 100644
--- a/src/models/Pipeline.jl
+++ b/src/models/Pipeline.jl
@@ -3,6 +3,44 @@
 # Use of this source code is governed by an MIT-style license that can be found
 # in the LICENSE.md file or at https://opensource.org/licenses/MIT.
 
+"""
+    Pipeline(layers::Vector{AbstractPredictor})
+
+A pipeline of nested layers:
+```math
+f(x) = (l_N \\circ \\ldots \\circ l_2 \\circ l_1)(x)
+```
+
+## Example
+
+```jldoctest
+julia> using JuMP, Omelette
+
+julia> model = Model();
+
+julia> @variable(model, x[1:2]);
+
+julia> f = Omelette.Pipeline([
+           Omelette.LinearLayer([1.0 2.0], [0.0]),
+           Omelette.ReLUQuadratic(1),
+       ])
+Omelette.Pipeline(Omelette.AbstractPredictor[Omelette.LinearLayer([1.0 2.0], [0.0]), Omelette.ReLUQuadratic(1)])
+
+julia> y = Omelette.add_predictor(model, f, x)
+1-element Vector{VariableRef}:
+ omelette_y[1]
+
+julia> print(model)
+Feasibility
+Subject to
+ -x[1] - 2 x[2] + omelette_y[1] = 0
+ omelette_y[1] - _z[1]+ + _z[1]- = 0
+ omelette_y[1] - _z[1]+ = 0
+ _z[1]+*_z[1]- = 0
+ _z[1]+ ≥ 0
+ _z[1]- ≥ 0
+```
+"""
 struct Pipeline <: AbstractPredictor
     layers::Vector{AbstractPredictor}
 end
diff --git a/src/models/ReLU.jl b/src/models/ReLU.jl
index fa5ca04..f7269cc 100644
--- a/src/models/ReLU.jl
+++ b/src/models/ReLU.jl
@@ -3,20 +3,60 @@
 # Use of this source code is governed by an MIT-style license that can be found
 # in the LICENSE.md file or at https://opensource.org/licenses/MIT.
 
-struct ReLU <: AbstractPredictor
+"""
+    ReLUBigM(dimension::Int, M::Float64)
+
+Represents the rectified linear unit relationship:
+```math
+f(x) = \\max(0, x)
+```
+where `M` is an upper bound on the absolute value of each element of `x`.
+
+## Example
+
+```jldoctest
+julia> using JuMP, Omelette
+
+julia> model = Model();
+
+julia> @variable(model, x[1:2]);
+
+julia> f = Omelette.ReLUBigM(2, 100.0)
+Omelette.ReLUBigM(2, 100.0)
+
+julia> y = Omelette.add_predictor(model, f, x)
+2-element Vector{VariableRef}:
+ omelette_y[1]
+ omelette_y[2]
+
+julia> print(model)
+Feasibility
+Subject to
+ omelette_y[1] ≥ 0
+ omelette_y[2] ≥ 0
+ -x[1] + omelette_y[1] ≥ 0
+ -x[2] + omelette_y[2] ≥ 0
+ omelette_y[1] - 100 _[5] ≤ 0
+ omelette_y[2] - 100 _[6] ≤ 0
+ -x[1] + omelette_y[1] + 100 _[5] ≤ 100
+ -x[2] + omelette_y[2] + 100 _[6] ≤ 100
+ _[5] binary
+ _[6] binary
+```
+"""
+struct ReLUBigM <: AbstractPredictor
     dimension::Int
     M::Float64
 end
 
-Base.size(x::ReLU) = (x.dimension, x.dimension)
+Base.size(x::ReLUBigM) = (x.dimension, x.dimension)
 
 function _add_predictor_inner(
     model::JuMP.Model,
-    predictor::ReLU,
+    predictor::ReLUBigM,
     x::Vector{JuMP.VariableRef},
     y::Vector{JuMP.VariableRef},
 )
-    # y = max(0, x)
     z = JuMP.@variable(model, [1:length(x)], Bin)
     JuMP.@constraint(model, y .>= 0)
     JuMP.@constraint(model, y .>= x)
@@ -24,3 +64,138 @@ function _add_predictor_inner(
     JuMP.@constraint(model, y .<= x .+ predictor.M * (1 .- z))
     return
 end
+
+"""
+    ReLUSOS1(dimension::Int)
+
+Implements the ReLU constraint \$y = \\max(0, x)\$ by the reformulation:
+```math
+\\begin{aligned}
+x = x^+ - x^- \\\\
+y = x^+ \\\\
+[x^+, x^-] \\in SOS1 \\\\
+x^+, x^- \\ge 0
+\\end{aligned}
+```
+
+## Example
+
+```jldoctest
+julia> using JuMP, Omelette
+
+julia> model = Model();
+
+julia> @variable(model, x[1:2]);
+
+julia> f = Omelette.ReLUSOS1(2)
+Omelette.ReLUSOS1(2)
+
+julia> y = Omelette.add_predictor(model, f, x)
+2-element Vector{VariableRef}:
+ omelette_y[1]
+ omelette_y[2]
+
+julia> print(model)
+Feasibility
+Subject to
+ x[1] - _[5] + _[6] = 0
+ omelette_y[1] - _[5] = 0
+ x[2] - _[7] + _[8] = 0
+ omelette_y[2] - _[7] = 0
+ [_[5], _[6]] ∈ MathOptInterface.SOS1{Float64}([1.0, 2.0])
+ [_[7], _[8]] ∈ MathOptInterface.SOS1{Float64}([1.0, 2.0])
+ _[5] ≥ 0
+ _[6] ≥ 0
+ _[7] ≥ 0
+ _[8] ≥ 0
+```
+"""
+struct ReLUSOS1 <: AbstractPredictor
+    dimension::Int
+end
+
+Base.size(x::ReLUSOS1) = (x.dimension, x.dimension)
+
+function _add_predictor_inner(
+    model::JuMP.Model,
+    predictor::ReLUSOS1,
+    x::Vector{JuMP.VariableRef},
+    y::Vector{JuMP.VariableRef},
+)
+    for i in 1:length(x)
+        z = JuMP.@variable(model, [1:2], lower_bound = 0)
+        JuMP.@constraint(model, x[i] == z[1] - z[2])
+        # Link the output to the positive part so that y = max(0, x).
+        JuMP.@constraint(model, y[i] == z[1])
+        JuMP.@constraint(model, z in MOI.SOS1([1.0, 2.0]))
+    end
+    return
+end
+
+"""
+    ReLUQuadratic(dimension::Int)
+
+Implements the ReLU constraint \$y = \\max(0, x)\$ by the reformulation:
+```math
+\\begin{aligned}
+x = x^+ - x^- \\\\
+y = x^+ \\\\
+x^+ \\times x^- = 0 \\\\
+x^+, x^- \\ge 0
+\\end{aligned}
+```
+
+## Example
+
+```jldoctest
+julia> using JuMP, Omelette
+
+julia> model = Model();
+
+julia> @variable(model, x[1:2]);
+
+julia> f = Omelette.ReLUQuadratic(2)
+Omelette.ReLUQuadratic(2)
+
+julia> y = Omelette.add_predictor(model, f, x)
+2-element Vector{VariableRef}:
+ omelette_y[1]
+ omelette_y[2]
+
+julia> print(model)
+Feasibility
+Subject to
+ x[1] - _z[1]+ + _z[1]- = 0
+ omelette_y[1] - _z[1]+ = 0
+ x[2] - _z[2]+ + _z[2]- = 0
+ omelette_y[2] - _z[2]+ = 0
+ _z[1]+*_z[1]- = 0
+ _z[2]+*_z[2]- = 0
+ _z[1]+ ≥ 0
+ _z[1]- ≥ 0
+ _z[2]+ ≥ 0
+ _z[2]- ≥ 0
+```
+"""
+struct ReLUQuadratic <: AbstractPredictor
+    dimension::Int
+end
+
+Base.size(x::ReLUQuadratic) = (x.dimension, x.dimension)
+
+function _add_predictor_inner(
+    model::JuMP.Model,
+    predictor::ReLUQuadratic,
+    x::Vector{JuMP.VariableRef},
+    y::Vector{JuMP.VariableRef},
+)
+    for i in 1:length(x)
+        z_pos = JuMP.@variable(model, lower_bound = 0, base_name = "_z[$i]+")
+        z_neg = JuMP.@variable(model, lower_bound = 0, base_name = "_z[$i]-")
+        JuMP.@constraint(model, x[i] == z_pos - z_neg)
+        # Link the output to the positive part so that y = max(0, x).
+        JuMP.@constraint(model, y[i] == z_pos)
+        JuMP.@constraint(model, z_pos * z_neg == 0)
+    end
+    return
+end
diff --git a/test/Project.toml b/test/Project.toml
index aa85605..a48b52d 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -14,9 +14,16 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
+ADTypes = "1"
 GLM = "1"
 HiGHS = "1"
 Ipopt = "1"
 JuMP = "1"
+Lux = "0.5"
+Optimisers = "0.3"
+Printf = "<0.0.1, 1.10"
+Random = "<0.0.1, 1.10"
+Statistics = "<0.0.1, 1.10"
 Test = "<0.0.1, 1.6"
+Zygote = "0.6"
 julia = "1.9"
diff --git a/test/test_ReLU.jl b/test/test_ReLU.jl
new file mode 100644
index 0000000..96d294d
--- /dev/null
+++ b/test/test_ReLU.jl
@@ -0,0 +1,61 @@
+# Copyright (c) 2024: Oscar Dowson and contributors
+#
+# Use of this source code is governed by an MIT-style license that can be found
+# in the LICENSE.md file or at https://opensource.org/licenses/MIT.
+
+module ReLUTests
+
+using JuMP
+using Test
+
+import Omelette
+
+is_test(x) = startswith(string(x), "test_")
+
+function runtests()
+    @testset "$name" for name in filter(is_test, names(@__MODULE__; all = true))
+        getfield(@__MODULE__, name)()
+    end
+    return
+end
+
+function test_ReLU_BigM()
+    model = Model()
+    @variable(model, x[1:2])
+    f = Omelette.ReLUBigM(2, 100.0)
+    @test size(f) == (2, 2)
+    y = Omelette.add_predictor(model, f, x)
+    @test length(y) == 2
+    @test num_variables(model) == 6
+    @test num_constraints(model, AffExpr, MOI.LessThan{Float64}) == 4
+    @test num_constraints(model, AffExpr, MOI.GreaterThan{Float64}) == 4
+    return
+end
+
+function test_ReLU_SOS1()
+    model = Model()
+    @variable(model, x[1:2])
+    f = Omelette.ReLUSOS1(2)
+    @test size(f) == (2, 2)
+    y = Omelette.add_predictor(model, f, x)
+    @test length(y) == 2
+    @test num_variables(model) == 8
+    @test num_constraints(model, Vector{VariableRef}, MOI.SOS1{Float64}) == 2
+    return
+end
+
+function test_ReLU_Quadratic()
+    model = Model()
+    @variable(model, x[1:2])
+    f = Omelette.ReLUQuadratic(2)
+    @test size(f) == (2, 2)
+    y = Omelette.add_predictor(model, f, x)
+    @test length(y) == 2
+    @test num_variables(model) == 8
+    @test num_constraints(model, QuadExpr, MOI.EqualTo{Float64}) == 2
+    return
+end
+
+end
+
+ReLUTests.runtests()
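A quick end-to-end check of the new API (not part of the diff). This is a minimal sketch, assuming only the `Omelette.ReLUBigM` and `Omelette.add_predictor` methods added above, plus HiGHS, which is already a test dependency:

```julia
# A minimal sketch: exercise the ReLUBigM formulation introduced in this
# patch. With the inputs fixed, the big-M constraints pin y to max.(0, x).
using JuMP
import HiGHS
import Omelette

model = Model(HiGHS.Optimizer)
set_silent(model)
@variable(model, -1 <= x[1:2] <= 1)
# Encode y[i] = max(0, x[i]) with one binary per element and M = 100.
y = Omelette.add_predictor(model, Omelette.ReLUBigM(2, 100.0), x)
fix.(x, [0.5, -0.3]; force = true)
optimize!(model)
value.(y)  # returns approximately [0.5, 0.0]
```

Of the three formulations, `ReLUBigM` yields a mixed-integer linear model whose tightness depends on the choice of `M`, `ReLUSOS1` avoids choosing an `M` but requires a solver with SOS1 support, and `ReLUQuadratic` produces a nonconvex complementarity constraint suited to nonlinear solvers such as Ipopt.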