diff --git a/Project.toml b/Project.toml
index 8d85cd8..0d25d46 100644
--- a/Project.toml
+++ b/Project.toml
@@ -10,9 +10,11 @@ MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"

 [weakdeps]
 GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"

 [extensions]
 OmeletteGLMExt = "GLM"
+OmeletteLuxExt = "Lux"

 [compat]
 Distributions = "0.25"
diff --git a/ext/OmeletteLuxExt.jl b/ext/OmeletteLuxExt.jl
new file mode 100644
index 0000000..e78155c
--- /dev/null
+++ b/ext/OmeletteLuxExt.jl
@@ -0,0 +1,31 @@
+# Copyright (c) 2024: Oscar Dowson and contributors
+#
+# Use of this source code is governed by an MIT-style license that can be found
+# in the LICENSE.md file or at https://opensource.org/licenses/MIT.
+
+module OmeletteLuxExt
+
+import Omelette
+import Lux
+
+function _add_predictor(predictor::Omelette.Pipeline, layer::Lux.Dense, p)
+    push!(predictor.layers, Omelette.LinearRegression(p.weight, vec(p.bias)))
+    if layer.activation === identity
+        # Do nothing
+    elseif layer.activation === Lux.NNlib.relu
+        push!(predictor.layers, Omelette.ReLU(layer.out_dims, 1e6))
+    else
+        error("Unsupported activation function: $(layer.activation)")
+    end
+    return
+end
+
+function Omelette.Pipeline(x::Lux.Experimental.TrainState)
+    predictor = Omelette.Pipeline(Omelette.AbstractPredictor[])
+    for (layer, parameter) in zip(x.model.layers, x.parameters)
+        _add_predictor(predictor, layer, parameter)
+    end
+    return predictor
+end
+
+end #module
diff --git a/src/models/LinearLayer.jl b/src/models/LinearLayer.jl
new file mode 100644
index 0000000..2cd2b2c
--- /dev/null
+++ b/src/models/LinearLayer.jl
@@ -0,0 +1,52 @@
+# Copyright (c) 2024: Oscar Dowson and contributors
+#
+# Use of this source code is governed by an MIT-style license that can be found
+# in the LICENSE.md file or at https://opensource.org/licenses/MIT.
+
+"""
+    LinearLayer(weights::Matrix{Float64}, bias::Vector{Float64})
+
+Represents the affine relationship:
+```math
+f(x) = W x + b
+```
+where \$W\$ is the \$m \\times n\$ matrix `weights` and \$b\$ is the length-\$m\$ vector `bias`.
+
+## Example
+
+```jldoctest
+julia> using JuMP, Omelette
+
+julia> model = Model();
+
+julia> @variable(model, x[1:2]);
+
+julia> f = Omelette.LinearLayer([2.0 3.0], [0.0])
+Omelette.LinearLayer([2.0 3.0], [0.0])
+
+julia> y = Omelette.add_predictor(model, f, x)
+1-element Vector{VariableRef}:
+ omelette_y[1]
+
+julia> print(model)
+Feasibility
+Subject to
+ 2 x[1] + 3 x[2] - omelette_y[1] = 0
+```
+"""
+struct LinearLayer <: AbstractPredictor
+    weights::Matrix{Float64}
+    bias::Vector{Float64}
+end
+
+Base.size(x::LinearLayer) = size(x.weights)
+
+function _add_predictor_inner(
+    model::JuMP.Model,
+    predictor::LinearLayer,
+    x::Vector{JuMP.VariableRef},
+    y::Vector{JuMP.VariableRef},
+)
+    JuMP.@constraint(model, predictor.weights * x .+ predictor.bias .== y)
+    return
+end
diff --git a/src/models/LinearRegression.jl b/src/models/LinearRegression.jl
index 235f890..8a03c18 100644
--- a/src/models/LinearRegression.jl
+++ b/src/models/LinearRegression.jl
@@ -4,13 +4,16 @@
 # in the LICENSE.md file or at https://opensource.org/licenses/MIT.

 """
-    LinearRegression(parameters::Matrix)
+    LinearRegression(
+        A::Matrix{Float64},
+        b::Vector{Float64} = zeros(size(A, 1)),
+    )

 Represents the linear relationship:
 ```math
-f(x) = A x
+f(x) = A x + b
 ```
-where \$A\$ is the \$m \\times n\$ matrix `parameters`.
+where \$A\$ is the \$m \\times n\$ matrix `A` and \$b\$ is the length-\$m\$ vector `b`.

 ## Example

@@ -22,7 +25,7 @@ julia> model = Model();

 julia> @variable(model, x[1:2]);

 julia> f = Omelette.LinearRegression([2.0, 3.0])
-Omelette.LinearRegression([2.0 3.0])
+Omelette.LinearRegression([2.0 3.0], [0.0])

 julia> y = Omelette.add_predictor(model, f, x)
 1-element Vector{VariableRef}:
@@ -35,14 +38,19 @@ julia> print(model)
 ```
 """
 struct LinearRegression <: AbstractPredictor
-    parameters::Matrix{Float64}
+    A::Matrix{Float64}
+    b::Vector{Float64}
 end

-function LinearRegression(parameters::Vector{Float64})
-    return LinearRegression(reshape(parameters, 1, length(parameters)))
+function LinearRegression(A::Matrix{Float64})
+    return LinearRegression(A, zeros(size(A, 1)))
 end

-Base.size(f::LinearRegression) = size(f.parameters)
+function LinearRegression(A::Vector{Float64})
+    return LinearRegression(reshape(A, 1, length(A)), [0.0])
+end
+
+Base.size(f::LinearRegression) = size(f.A)

 function _add_predictor_inner(
     model::JuMP.Model,
@@ -50,6 +58,6 @@ function _add_predictor_inner(
     x::Vector{JuMP.VariableRef},
     y::Vector{JuMP.VariableRef},
 )
-    JuMP.@constraint(model, predictor.parameters * x .== y)
+    JuMP.@constraint(model, predictor.A * x .+ predictor.b .== y)
     return
 end
diff --git a/src/models/Pipeline.jl b/src/models/Pipeline.jl
new file mode 100644
index 0000000..b4a5c16
--- /dev/null
+++ b/src/models/Pipeline.jl
@@ -0,0 +1,26 @@
+# Copyright (c) 2024: Oscar Dowson and contributors
+#
+# Use of this source code is governed by an MIT-style license that can be found
+# in the LICENSE.md file or at https://opensource.org/licenses/MIT.
+
+struct Pipeline <: AbstractPredictor
+    layers::Vector{AbstractPredictor}
+end
+
+Base.size(x::Pipeline) = (size(last(x.layers), 1), size(first(x.layers), 2))
+
+function _add_predictor_inner(
+    model::JuMP.Model,
+    predictor::Pipeline,
+    x::Vector{JuMP.VariableRef},
+    y::Vector{JuMP.VariableRef},
+)
+    for (i, layer) in enumerate(predictor.layers)
+        if i == length(predictor.layers)
+            add_predictor!(model, layer, x, y)
+        else
+            x = add_predictor(model, layer, x)
+        end
+    end
+    return
+end
diff --git a/src/models/ReLU.jl b/src/models/ReLU.jl
new file mode 100644
index 0000000..fa5ca04
--- /dev/null
+++ b/src/models/ReLU.jl
@@ -0,0 +1,26 @@
+# Copyright (c) 2024: Oscar Dowson and contributors
+#
+# Use of this source code is governed by an MIT-style license that can be found
+# in the LICENSE.md file or at https://opensource.org/licenses/MIT.
+
+struct ReLU <: AbstractPredictor
+    dimension::Int
+    M::Float64
+end
+
+Base.size(x::ReLU) = (x.dimension, x.dimension)
+
+function _add_predictor_inner(
+    model::JuMP.Model,
+    predictor::ReLU,
+    x::Vector{JuMP.VariableRef},
+    y::Vector{JuMP.VariableRef},
+)
+    # y = max(0, x)
+    z = JuMP.@variable(model, [1:length(x)], Bin)
+    JuMP.@constraint(model, y .>= 0)
+    JuMP.@constraint(model, y .>= x)
+    JuMP.@constraint(model, y .<= predictor.M * z)
+    JuMP.@constraint(model, y .<= x .+ predictor.M * (1 .- z))
+    return
+end
diff --git a/test/Project.toml b/test/Project.toml
index 137ee72..aa85605 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,10 +1,17 @@
 [deps]
+ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
 HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
 Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
 JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 Omelette = "e52c2cb8-508e-4e12-9dd2-9c4755b60e73"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [compat]
 GLM = "1"
diff --git a/test/test_Lux.jl b/test/test_Lux.jl
new file mode 100644
index 0000000..cdcc138
--- /dev/null
+++ b/test/test_Lux.jl
@@ -0,0 +1,89 @@
+# Copyright (c) 2024: Oscar Dowson and contributors
+#
+# Use of this source code is governed by an MIT-style license that can be found
+# in the LICENSE.md file or at https://opensource.org/licenses/MIT.
+
+module LuxTests
+
+using JuMP
+using Test
+
+import ADTypes
+import HiGHS
+import Lux
+import Omelette
+import Optimisers
+import Random
+import Statistics
+import Zygote
+
+is_test(x) = startswith(string(x), "test_")
+
+function runtests()
+    @testset "$name" for name in filter(is_test, names(@__MODULE__; all = true))
+        getfield(@__MODULE__, name)()
+    end
+    return
+end
+
+function generate_data(rng::Random.AbstractRNG, n = 128)
+    x = range(-2.0, 2.0, n)
+    y = -2 .* x .+ x .^ 2 .+ 0.1 .* randn(rng, n)
+    return reshape(collect(x), (1, n)), reshape(y, (1, n))
+end
+
+function loss_function_mse(model, ps, state, (input, output))
+    y_pred, updated_state = Lux.apply(model, input, ps, state)
+    loss = Statistics.mean(abs2, y_pred .- output)
+    return loss, updated_state, ()
+end
+
+function train_cpu(
+    model,
+    input,
+    output;
+    loss_function::Function = loss_function_mse,
+    vjp = ADTypes.AutoZygote(),
+    rng,
+    optimizer,
+    epochs::Int,
+)
+    state = Lux.Experimental.TrainState(rng, model, optimizer)
+    data = (input, output) .|> Lux.cpu_device()
+    for epoch in 1:epochs
+        grads, loss, stats, state =
+            Lux.Experimental.compute_gradients(vjp, loss_function, data, state)
+        state = Lux.Experimental.apply_gradients!(state, grads)
+    end
+    return state
+end
+
+function test_end_to_end()
+    rng = Random.MersenneTwister()
+    Random.seed!(rng, 12345)
+    x, y = generate_data(rng)
+    model = Lux.Chain(Lux.Dense(1 => 16, Lux.relu), Lux.Dense(16 => 1))
+    state = train_cpu(
+        model,
+        x,
+        y;
+        rng = rng,
+        optimizer = Optimisers.Adam(0.03f0),
+        epochs = 250,
+    )
+    f = Omelette.Pipeline(state)
+    model = Model(HiGHS.Optimizer)
+    set_silent(model)
+    @variable(model, x)
+    y = Omelette.add_predictor(model, f, [x])
+    @constraint(model, only(y) <= 4)
+    @objective(model, Min, x)
+    optimize!(model)
+    @assert is_solved_and_feasible(model)
+    @test isapprox(value(x), -1.24; atol = 1e-2)
+    return
+end
+
+end # module
+
+LuxTests.runtests()
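
For reference, the new predictor types can also be composed by hand, without Lux. Below is a minimal illustrative sketch, assuming only the API introduced above (`Omelette.Pipeline`, `Omelette.LinearRegression(A, b)`, `Omelette.ReLU(dimension, M)`, and `Omelette.add_predictor`); the weights, variable bounds, and big-M value are arbitrary choices for the example.

```julia
using JuMP
import HiGHS
import Omelette

# A two-layer ReLU network built directly from the new predictor types:
# hidden = max.(0, A1 * x .+ b1); output = A2 * hidden .+ b2.
f = Omelette.Pipeline(
    Omelette.AbstractPredictor[
        Omelette.LinearRegression([1.0 2.0; -1.0 0.5], [0.0, 0.0]),
        Omelette.ReLU(2, 100.0),  # dimension 2, big-M bound of 100
        Omelette.LinearRegression([1.0 1.0], [0.0]),
    ],
)

model = Model(HiGHS.Optimizer)
set_silent(model)
@variable(model, 0 <= x[1:2] <= 1)
# Adds the mixed-integer encoding of `f` and returns the output variables.
y = Omelette.add_predictor(model, f, x)
@objective(model, Max, only(y))
optimize!(model)
```

Because `ReLU` uses the big-M binary encoding from `src/models/ReLU.jl` (`y <= M z` and `y <= x + M (1 - z)`), `M` must be a valid bound on the magnitude of the pre-activation values for the formulation to be exact.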