
Commit

Add first draft of a basic NN
odow committed Jun 6, 2024
1 parent c3ef77a commit 417323d
Showing 8 changed files with 250 additions and 9 deletions.
2 changes: 2 additions & 0 deletions Project.toml
@@ -10,9 +10,11 @@ MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"

[weakdeps]
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"

[extensions]
OmeletteGLMExt = "GLM"
OmeletteLuxExt = "Lux"

[compat]
Distributions = "0.25"
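For context, the paired `[weakdeps]` / `[extensions]` entries mean that `OmeletteLuxExt` is compiled and loaded only when Lux is installed and loaded alongside Omelette, so Omelette itself takes no hard dependency on Lux. A minimal sketch (not part of this commit) of how the extension is activated:

```julia
import Omelette  # loads Omelette; no Lux-related code yet
import Lux       # Julia now also loads the OmeletteLuxExt extension
```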
31 changes: 31 additions & 0 deletions ext/OmeletteLuxExt.jl
@@ -0,0 +1,31 @@
# Copyright (c) 2024: Oscar Dowson and contributors
#
# Use of this source code is governed by an MIT-style license that can be found
# in the LICENSE.md file or at https://opensource.org/licenses/MIT.

module OmeletteLuxExt

import Omelette
import Lux

function _add_predictor(predictor::Omelette.Pipeline, layer::Lux.Dense, p)
push!(predictor.layers, Omelette.LinearRegression(p.weight, vec(p.bias)))
if layer.activation === identity
# Do nothing
elseif layer.activation === Lux.NNlib.relu
push!(predictor.layers, Omelette.ReLU(layer.out_dims, 1e6))
else
error("Unsupported activation function: $(layer.activation)")
end
return
end

function Omelette.Pipeline(x::Lux.Experimental.TrainState)
predictor = Omelette.Pipeline(Omelette.AbstractPredictor[])
for (layer, parameter) in zip(x.model.layers, x.parameters)
_add_predictor(predictor, layer, parameter)
end
return predictor
end

end #module
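As a rough sketch of what the constructor above produces (not part of the commit; `state` is assumed to be the `Lux.Experimental.TrainState` of a trained `Lux.Chain(Lux.Dense(1 => 16, Lux.relu), Lux.Dense(16 => 1))`, as in test/test_Lux.jl below):

```julia
f = Omelette.Pipeline(state)
# f.layers is now, roughly:
#   [LinearRegression(W1, b1), ReLU(16, 1.0e6), LinearRegression(W2, b2)]
# where W1, b1, W2, b2 are the trained weights and biases of the two Dense layers.
```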
52 changes: 52 additions & 0 deletions src/models/LinearLayer.jl
@@ -0,0 +1,52 @@
# Copyright (c) 2024: Oscar Dowson and contributors
#
# Use of this source code is governed by an MIT-style license that can be found
# in the LICENSE.md file or at https://opensource.org/licenses/MIT.

"""
LinearLayer(weights::Matrix, bias::Vector)
Represents the affine relationship:
```math
f(x) = W x + b
```
where \$W\$ is the \$m \\times n\$ matrix `weights` and \$b\$ is the length-\$m\$ vector `bias`.
## Example
```jldoctest
julia> using JuMP, Omelette
julia> model = Model();
julia> @variable(model, x[1:2]);
julia> f = Omelette.LinearLayer([2.0 3.0], [0.0])
Omelette.LinearLayer([2.0 3.0], [0.0])
julia> y = Omelette.add_predictor(model, f, x)
1-element Vector{VariableRef}:
omelette_y[1]
julia> print(model)
Feasibility
Subject to
omelette_y[1] - 2 x[1] - 3 x[2] = 0
```
"""
struct LinearLayer <: AbstractPredictor
weights::Matrix{Float64}
bias::Vector{Float64}
end

Base.size(x::LinearLayer) = size(x.weights)

function _add_predictor_inner(
model::JuMP.Model,
predictor::LinearLayer,
x::Vector{JuMP.VariableRef},
y::Vector{JuMP.VariableRef},
)
JuMP.@constraint(model, y .== predictor.weights * x .+ predictor.bias)
return
end
26 changes: 17 additions & 9 deletions src/models/LinearRegression.jl
@@ -4,13 +4,16 @@
# in the LICENSE.md file or at https://opensource.org/licenses/MIT.

"""
LinearRegression(parameters::Matrix)
LinearRegression(
A::Matrix{Float64},
b::Vector{Float64} = zeros(size(A, 1)),
)
Represents the linear relationship:
```math
f(x) = A x
f(x) = A x + b
```
where \$A\$ is the \$m \\times n\$ matrix `parameters`.
where \$A\$ is the \$m \\times n\$ matrix `A` and \$b\$ is the length-\$m\$ vector `b`.
## Example
@@ -22,7 +25,7 @@ julia> model = Model();
julia> @variable(model, x[1:2]);
julia> f = Omelette.LinearRegression([2.0, 3.0])
Omelette.LinearRegression([2.0 3.0])
Omelette.LinearRegression([2.0 3.0], [0.0])
julia> y = Omelette.add_predictor(model, f, x)
1-element Vector{VariableRef}:
@@ -35,21 +38,26 @@ julia> print(model)
```
"""
struct LinearRegression <: AbstractPredictor
parameters::Matrix{Float64}
A::Matrix{Float64}
b::Vector{Float64}
end

function LinearRegression(parameters::Vector{Float64})
return LinearRegression(reshape(parameters, 1, length(parameters)))
function LinearRegression(A::Matrix{Float64})
return LinearRegression(A, zeros(size(A, 1)))
end

Base.size(f::LinearRegression) = size(f.parameters)
function LinearRegression(A::Vector{Float64})
return LinearRegression(reshape(A, 1, length(A)), [0.0])
end

Base.size(f::LinearRegression) = size(f.A)

function _add_predictor_inner(
model::JuMP.Model,
predictor::LinearRegression,
x::Vector{JuMP.VariableRef},
y::Vector{JuMP.VariableRef},
)
JuMP.@constraint(model, predictor.parameters * x .== y)
JuMP.@constraint(model, predictor.A * x .+ predictor.b .== y)
return
end
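A short sketch (not in the commit) of the new two-argument constructor with an explicit bias term:

```julia
f = Omelette.LinearRegression([2.0 3.0], [1.0])  # f(x) = 2 x[1] + 3 x[2] + 1
```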
26 changes: 26 additions & 0 deletions src/models/Pipeline.jl
@@ -0,0 +1,26 @@
# Copyright (c) 2024: Oscar Dowson and contributors
#
# Use of this source code is governed by an MIT-style license that can be found
# in the LICENSE.md file or at https://opensource.org/licenses/MIT.

struct Pipeline <: AbstractPredictor
layers::Vector{AbstractPredictor}
end

Base.size(x::Pipeline) = (size(last(x.layers), 1), size(first(x.layers), 2))

function _add_predictor_inner(
model::JuMP.Model,
predictor::Pipeline,
x::Vector{JuMP.VariableRef},
y::Vector{JuMP.VariableRef},
)
for (i, layer) in enumerate(predictor.layers)
if i == length(predictor.layers)
# The final layer writes directly into the output variables `y`.
add_predictor!(model, layer, x, y)
else
# Intermediate layers create new variables, which become the input to
# the next layer.
x = add_predictor(model, layer, x)
end
end
return
end
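A minimal usage sketch (not part of the commit) of composing the predictors defined in this commit by hand; the matrix, dimensions, and big-M value are illustrative only:

```julia
using JuMP, Omelette

model = Model()
@variable(model, x[1:2])
f = Omelette.Pipeline(Omelette.AbstractPredictor[
    Omelette.LinearRegression([1.0 2.0; 3.0 4.0]),  # affine map from R^2 to R^2
    Omelette.ReLU(2, 100.0),                        # elementwise max(0, x)
])
y = Omelette.add_predictor(model, f, x)  # 2-element Vector{VariableRef}
```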
26 changes: 26 additions & 0 deletions src/models/ReLU.jl
@@ -0,0 +1,26 @@
# Copyright (c) 2024: Oscar Dowson and contributors
#
# Use of this source code is governed by an MIT-style license that can be found
# in the LICENSE.md file or at https://opensource.org/licenses/MIT.

struct ReLU <: AbstractPredictor
dimension::Int
M::Float64
end

Base.size(x::ReLU) = (x.dimension, x.dimension)

function _add_predictor_inner(
model::JuMP.Model,
predictor::ReLU,
x::Vector{JuMP.VariableRef},
y::Vector{JuMP.VariableRef},
)
# y = max(0, x)
z = JuMP.@variable(model, [1:length(x)], Bin)
JuMP.@constraint(model, y .>= 0)
JuMP.@constraint(model, y .>= x)
JuMP.@constraint(model, y .<= predictor.M * z)
JuMP.@constraint(model, y .<= x .+ predictor.M * (1 .- z))
return
end
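For reference, the constraints above are the standard big-M encoding of y = max.(0, x), valid under the assumption that every component of x satisfies -M <= x[i] <= M:

```math
y \ge 0, \quad y \ge x, \quad y \le M z, \quad y \le x + M (1 - z), \quad z \in \{0, 1\}^n
```

When z[i] = 0, the first and third constraints force y[i] = 0; when z[i] = 1, the second and fourth force y[i] = x[i].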
7 changes: 7 additions & 0 deletions test/Project.toml
@@ -1,10 +1,17 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Omelette = "e52c2cb8-508e-4e12-9dd2-9c4755b60e73"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
GLM = "1"
89 changes: 89 additions & 0 deletions test/test_Lux.jl
@@ -0,0 +1,89 @@
# Copyright (c) 2024: Oscar Dowson and contributors
#
# Use of this source code is governed by an MIT-style license that can be found
# in the LICENSE.md file or at https://opensource.org/licenses/MIT.

module LuxTests

using JuMP
using Test

import ADTypes
import HiGHS
import Lux
import Omelette
import Optimisers
import Random
import Statistics
import Zygote

is_test(x) = startswith(string(x), "test_")

function runtests()
@testset "$name" for name in filter(is_test, names(@__MODULE__; all = true))
getfield(@__MODULE__, name)()
end
return
end

function generate_data(rng::Random.AbstractRNG, n = 128)
x = range(-2.0, 2.0, n)
y = -2 .* x .+ x .^ 2 .+ 0.1 .* randn(rng, n)
return reshape(collect(x), (1, n)), reshape(y, (1, n))
end

function loss_function_mse(model, ps, state, (input, output))
y_pred, updated_state = Lux.apply(model, input, ps, state)
loss = Statistics.mean(abs2, y_pred .- output)
return loss, updated_state, ()
end

function train_cpu(
model,
input,
output;
loss_function::Function = loss_function_mse,
vjp = ADTypes.AutoZygote(),
rng,
optimizer,
epochs::Int,
)
state = Lux.Experimental.TrainState(rng, model, optimizer)
data = (input, output) .|> Lux.cpu_device()
for epoch in 1:epochs
grads, loss, stats, state =
Lux.Experimental.compute_gradients(vjp, loss_function, data, state)
state = Lux.Experimental.apply_gradients!(state, grads)
end
return state
end

function test_end_to_end()
rng = Random.MersenneTwister()
Random.seed!(rng, 12345)
x, y = generate_data(rng)
model = Lux.Chain(Lux.Dense(1 => 16, Lux.relu), Lux.Dense(16 => 1))
state = train_cpu(
model,
x,
y;
rng = rng,
optimizer = Optimisers.Adam(0.03f0),
epochs = 250,
)
f = Omelette.Pipeline(state)
model = Model(HiGHS.Optimizer)
set_silent(model)
@variable(model, x)
y = Omelette.add_predictor(model, f, [x])
@constraint(model, only(y) <= 4)
@objective(model, Min, x)
optimize!(model)
@assert is_solved_and_feasible(model)
@test isapprox(value(x), -1.24; atol = 1e-2)
return
end

end # module

LuxTests.runtests()
