Skip to content

Commit

Permalink
Remove add_predictor!
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed Jun 6, 2024
1 parent c3ef77a commit 0900d10
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 87 deletions.
5 changes: 1 addition & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@ Use `add_predictor`:
```julia
y = Omelette.add_predictor(model, predictor, x)
```
or:
```julia
Omelette.add_predictor!(model, predictor, x, y)
```

### LinearRegression

Expand All @@ -54,6 +50,7 @@ predictor = Omelette.LogisticRegression(model_glm)
## Other constraints

### UnivariateNormalDistribution

```julia
using JuMP, Omelette
model = Model();
Expand Down
71 changes: 28 additions & 43 deletions src/Omelette.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,64 +12,49 @@ import MathOptInterface as MOI
"""
    abstract type AbstractPredictor end

An abstract type representing different types of prediction models.

## Methods

All subtypes must implement:

 * `add_predictor`
 * `Base.size`
"""
abstract type AbstractPredictor end

# Convenience overload: `size(x, i)` forwards to the `i`-th entry of `size(x)`,
# mirroring the `Base.size(::AbstractArray, ::Int)` convention.
Base.size(x::AbstractPredictor, i::Int) = size(x)[i]

"""
    add_predictor!(
        model::JuMP.Model,
        predictor::AbstractPredictor,
        x::Vector{JuMP.VariableRef},
        y::Vector{JuMP.VariableRef},
    )::Nothing

Add the constraint `predictor(x) .== y` to the optimization model `model`.

## Throws

 * `DimensionMismatch` if `length(x)` or `length(y)` disagrees with
   `size(predictor)`.
"""
function add_predictor!(
    model::JuMP.Model,
    predictor::AbstractPredictor,
    x::Vector{JuMP.VariableRef},
    y::Vector{JuMP.VariableRef},
)
    # Validate dimensions up front so subtype implementations can assume
    # well-sized inputs.
    output_n, input_n = size(predictor)
    if length(x) != input_n
        throw(DimensionMismatch("Input vector x is length $(length(x)), expected $input_n"))
    end
    if length(y) != output_n
        throw(DimensionMismatch("Output vector y is length $(length(y)), expected $output_n"))
    end
    # Dispatch to the subtype-specific constraint builder.
    _add_predictor_inner(model, predictor, x, y)
    return nothing
end

"""
    add_predictor(
        model::JuMP.Model,
        predictor::AbstractPredictor,
        x::Vector{JuMP.VariableRef},
    )::Vector{JuMP.VariableRef}

Return a `Vector{JuMP.VariableRef}` representing `y` such that
`y = predictor(x)`.

## Example

```jldoctest
julia> using JuMP, Omelette

julia> model = Model();

julia> @variable(model, x[1:2]);

julia> f = Omelette.LinearRegression([2.0, 3.0])
Omelette.LinearRegression([2.0 3.0])

julia> y = Omelette.add_predictor(model, f, x)
1-element Vector{VariableRef}:
 omelette_y[1]

julia> print(model)
Feasibility
Subject to
 2 x[1] + 3 x[2] - omelette_y[1] = 0
```
"""
function add_predictor(
    model::JuMP.Model,
    predictor::AbstractPredictor,
    x::Vector{JuMP.VariableRef},
)
    # One new output variable per predictor output row; add_predictor! then
    # attaches the constraints linking x and y and checks dimensions.
    y = JuMP.@variable(model, [1:size(predictor, 1)], base_name = "omelette_y")
    add_predictor!(model, predictor, x, y)
    return y
end
function add_predictor end

for file in readdir(joinpath(@__DIR__, "models"); join = true)
if endswith(file, ".jl")
Expand Down
9 changes: 4 additions & 5 deletions src/models/LinearRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,13 @@ function LinearRegression(parameters::Vector{Float64})
return LinearRegression(reshape(parameters, 1, length(parameters)))
end

# The size of a LinearRegression is the size of its coefficient matrix,
# i.e. `(num_outputs, num_inputs)`.
Base.size(f::LinearRegression) = size(f.parameters)

# Add the affine constraint `parameters * x .== y` to `model` and return the
# newly created output variables `y`.
function add_predictor(
    model::JuMP.Model,
    predictor::LinearRegression,
    x::Vector{JuMP.VariableRef},
)
    # One output variable per row of the coefficient matrix.
    m = size(predictor.parameters, 1)
    y = JuMP.@variable(model, [1:m], base_name = "omelette_y")
    JuMP.@constraint(model, predictor.parameters * x .== y)
    return y
end
7 changes: 4 additions & 3 deletions src/models/LogisticRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,13 @@ end

# The size of a LogisticRegression is the size of its coefficient matrix,
# i.e. `(num_outputs, num_inputs)`.
Base.size(f::LogisticRegression) = size(f.parameters)

# Add the constraint `y .== 1 / (1 + exp(-parameters * x))` (elementwise
# logistic of the affine transform) to `model` and return the newly created
# output variables `y`.
function add_predictor(
    model::JuMP.Model,
    predictor::LogisticRegression,
    x::Vector{JuMP.VariableRef},
)
    # One output variable per row of the coefficient matrix.
    m = size(predictor.parameters, 1)
    y = JuMP.@variable(model, [1:m], base_name = "omelette_y")
    JuMP.@constraint(model, 1 ./ (1 .+ exp.(-predictor.parameters * x)) .== y)
    return y
end
17 changes: 1 addition & 16 deletions test/test_LinearRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,30 +24,15 @@ end
# Build a 1-output linear regression over two variables and check that the
# single generated constraint is `2 x[1] + 3 x[2] - y[1] == 0`.
function test_LinearRegression()
    model = Model()
    @variable(model, x[1:2])
    f = Omelette.LinearRegression([2.0, 3.0])
    y = Omelette.add_predictor(model, f, x)
    cons = all_constraints(model; include_variable_in_set_constraints = false)
    obj = constraint_object(only(cons))
    @test obj.set == MOI.EqualTo(0.0)
    @test isequal_canonical(obj.func, 2.0 * x[1] + 3.0 * x[2] - y[1])
    return
end

# Check that add_predictor! throws DimensionMismatch when the lengths of the
# input vector `x` or output vector `y` disagree with `size(predictor)`.
function test_LinearRegression_dimension_mismatch()
model = Model()
@variable(model, x[1:3])
@variable(model, y[1:2])
# `f` maps 2 inputs to 1 output.
f = Omelette.LinearRegression([2.0, 3.0])
@test size(f) == (1, 2)
# x too long (3 vs 2); then y too long (2 vs 1).
@test_throws DimensionMismatch Omelette.add_predictor!(model, f, x, y[1:1])
@test_throws DimensionMismatch Omelette.add_predictor!(model, f, x[1:2], y)
# `g` maps 2 inputs to 3 outputs: both x and y have the wrong length here.
g = Omelette.LinearRegression([2.0 3.0; 4.0 5.0; 6.0 7.0])
@test size(g) == (3, 2)
@test_throws DimensionMismatch Omelette.add_predictor!(model, g, x, y)
return
end

function test_LinearRegression_GLM()
num_features = 2
num_observations = 10
Expand Down
17 changes: 1 addition & 16 deletions test/test_LogisticRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@ end
function test_LogisticRegression()
model = Model()
@variable(model, x[1:2])
@variable(model, y[1:1])
f = Omelette.LogisticRegression([2.0, 3.0])
Omelette.add_predictor!(model, f, x, y)
y = Omelette.add_predictor(model, f, x)
cons = all_constraints(model; include_variable_in_set_constraints = false)
obj = constraint_object(only(cons))
@test obj.set == MOI.EqualTo(0.0)
Expand All @@ -35,20 +34,6 @@ function test_LogisticRegression()
return
end

# Check that add_predictor! throws DimensionMismatch when the lengths of the
# input vector `x` or output vector `y` disagree with `size(predictor)`.
function test_LogisticRegression_dimension_mismatch()
model = Model()
@variable(model, x[1:3])
@variable(model, y[1:2])
# `f` maps 2 inputs to 1 output.
f = Omelette.LogisticRegression([2.0, 3.0])
@test size(f) == (1, 2)
# x too long (3 vs 2); then y too long (2 vs 1).
@test_throws DimensionMismatch Omelette.add_predictor!(model, f, x, y[1:1])
@test_throws DimensionMismatch Omelette.add_predictor!(model, f, x[1:2], y)
# `g` maps 2 inputs to 3 outputs: both x and y have the wrong length here.
g = Omelette.LogisticRegression([2.0 3.0; 4.0 5.0; 6.0 7.0])
@test size(g) == (3, 2)
@test_throws DimensionMismatch Omelette.add_predictor!(model, g, x, y)
return
end

function test_LogisticRegression_GLM()
num_features = 2
num_observations = 10
Expand Down

0 comments on commit 0900d10

Please sign in to comment.