Skip to content

Commit

Permalink
rem ModelFrame, add wts obj
Browse files Browse the repository at this point in the history
  • Loading branch information
PharmCat committed Dec 24, 2023
1 parent 73298d1 commit caeaea2
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 51 deletions.
14 changes: 7 additions & 7 deletions src/dof_contain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ end
Minimum returned. If no random effect found N - rank(XZ) returned.
"""
function dof_contain(lmm, i)
ind = lmm.mm.assign[i]
sym = StatsModels.termvars(lmm.mf.f.rhs.terms[ind])
ind = lmm.modstr.assign[i]
sym = StatsModels.termvars(lmm.f.rhs.terms[ind])
rr = Vector{Int}(undef, 0)
for r = 1:length(lmm.covstr.random)
if length(intersect(sym, StatsModels.termvars(lmm.covstr.random[r].model))) > 0
Expand All @@ -57,12 +57,12 @@ function dof_contain(lmm, i)
end

function dof_contain(lmm)
dof = zeros(Int, length(lmm.mm.assign))
dof = zeros(Int, length(lmm.modstr.assign))
rrt = zeros(Int, length(lmm.covstr.random))
rz = 0
for i = 1:length(lmm.mm.assign)
ind = lmm.mm.assign[i]
sym = StatsModels.termvars(lmm.mf.f.rhs.terms[ind])
for i = 1:length(lmm.modstr.assign)
ind = lmm.modstr.assign[i]
sym = StatsModels.termvars(lmm.f.rhs.terms[ind])
rr = Vector{Int}(undef, 0)
for r = 1:length(lmm.covstr.random)
if length(intersect(sym, StatsModels.termvars(lmm.covstr.random[r].model))) > 0
Expand All @@ -87,7 +87,7 @@ end
"""
function dof_contain_f(lmm, i)
sym = StatsModels.termvars(lmm.mf.f.rhs.terms[i])
sym = StatsModels.termvars(lmm.f.rhs.terms[i])
rr = Vector{Int}(undef, 0)
for r = 1:length(lmm.covstr.random)
if length(intersect(sym, StatsModels.termvars(lmm.covstr.random[r].model))) > 0
Expand Down
57 changes: 43 additions & 14 deletions src/lmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ struct LMMLogMsg
msg::String
end


struct ModelStructure
assign::Vector{Int64}
end

"""
LMM(model, data; contrasts=Dict{Symbol,Any}(), random::Union{Nothing, VarEffect, Vector{VarEffect}} = nothing, repeated::Union{Nothing, VarEffect} = nothing)
Expand All @@ -24,31 +29,33 @@ See also: [`@lmmformula`](@ref)
"""
struct LMM{T<:AbstractFloat} <: MetidaModel
model::FormulaTerm
mf::ModelFrame
mm::ModelMatrix
f::FormulaTerm
modstr::ModelStructure
covstr::CovStructure
data::LMMData{T}
dv::LMMDataViews{T}
nfixed::Int
rankx::Int
result::ModelResult
maxvcbl::Int
wts::Union{Nothing, LMMWts}
log::Vector{LMMLogMsg}

function LMM(model::FormulaTerm,
mf::ModelFrame,
mm::ModelMatrix,
f::FormulaTerm,
modstr::ModelStructure,
covstr::CovStructure,
data::LMMData{T},
dv::LMMDataViews{T},
nfixed::Int,
rankx::Int,
result::ModelResult,
maxvcbl::Int,
wts::Union{Nothing, LMMWts},
log::Vector{LMMLogMsg}) where T
new{T}(model, mf, mm, covstr, data, dv, nfixed, rankx, result, maxvcbl, log)
new{T}(model, f, modstr, covstr, data, dv, nfixed, rankx, result, maxvcbl, wts, log)
end
function LMM(model, data; contrasts=Dict{Symbol,Any}(), random::Union{Nothing, VarEffect, Vector{VarEffect}} = nothing, repeated::Union{Nothing, VarEffect} = nothing)
function LMM(model, data; contrasts=Dict{Symbol,Any}(), random::Union{Nothing, VarEffect, Vector{VarEffect}} = nothing, repeated::Union{Nothing, VarEffect} = nothing, wts = nothing)
#need check responce - Float
if !Tables.istable(data) error("Data not a table!") end
if repeated === nothing && random === nothing
Expand All @@ -68,10 +75,15 @@ struct LMM{T<:AbstractFloat} <: MetidaModel
lmmlog = Vector{LMMLogMsg}(undef, 0)
sch = schema(model, data, contrasts)
f = apply_schema(model, sch, MetidaModel)
mf = ModelFrame(f, sch, data, MetidaModel)

rmf, lmf = modelcols(f, data)

assign = StatsModels.asgn(f)

#mf = ModelFrame(f, sch, data, MetidaModel)
#mf = ModelFrame(model, data; contrasts = contrasts)
mm = ModelMatrix(mf)
nfixed = nterms(mf)
#mm = ModelMatrix(mf)
nfixed = fixedeffn(f)
if repeated === nothing
repeated = NOREPEAT
end
Expand All @@ -86,9 +98,9 @@ struct LMM{T<:AbstractFloat} <: MetidaModel
lmmlog!(lmmlog, 1, LMMLogMsg(:WARN, "Repeated effect not a constant, but covariance type is SI. "))
end
end
rmf = response(mf)
#rmf = response(mf)
if !(eltype(rmf) <: AbstractFloat) @warn "Response variable not <: AbstractFloat" end
lmmdata = LMMData(modelmatrix(mf), rmf)
lmmdata = LMMData(lmf, rmf)

covstr = CovStructure(random, repeated, data)
coefn = size(lmmdata.xv, 2)
Expand All @@ -97,11 +109,24 @@ struct LMM{T<:AbstractFloat} <: MetidaModel
@warn "Fixed-effect matrix not full-rank!"
lmmlog!(lmmlog, 1, LMMLogMsg(:WARN, "Fixed-effect matrix not full-rank!"))
end

if isnothing(wts)
lmmwts = nothing
else
if length(lmmdata.yv) == length(wts)
lmmwts = LMMWts(wts, covstr.vcovblock)

Check warning on line 117 in src/lmm.jl

View check run for this annotation

Codecov / codecov/patch

src/lmm.jl#L116-L117

Added lines #L116 - L117 were not covered by tests
else
@warn "wts count not equal observations count! wts not used."
lmmwts = nothing

Check warning on line 120 in src/lmm.jl

View check run for this annotation

Codecov / codecov/patch

src/lmm.jl#L119-L120

Added lines #L119 - L120 were not covered by tests
end
end

mres = ModelResult(false, nothing, fill(NaN, covstr.tl), NaN, fill(NaN, coefn), nothing, fill(NaN, coefn, coefn), fill(NaN, coefn), nothing, false)
LMM(model, mf, mm, covstr, lmmdata, LMMDataViews(lmmdata.xv, lmmdata.yv, covstr.vcovblock), nfixed, rankx, mres, findmax(length, covstr.vcovblock)[1], lmmlog)

LMM(model, f, ModelStructure(assign), covstr, lmmdata, LMMDataViews(lmmdata.xv, lmmdata.yv, covstr.vcovblock), nfixed, rankx, mres, findmax(length, covstr.vcovblock)[1], lmmwts, lmmlog)
end
function LMM(f::LMMformula, data; contrasts=Dict{Symbol,Any}(), kwargs...)
LMM(f.formula, data; contrasts=contrasts, random = f.random, repeated = f.repeated)
function LMM(f::LMMformula, data; kwargs...)
LMM(f.formula, data; random = f.random, repeated = f.repeated, kwargs...)
end
end

Expand Down Expand Up @@ -150,6 +175,10 @@ end
function maxblocksize(mm::MetidaModel)
mm.maxvcbl
end
function assign(lmm::LMM)
lmm.modstr.assign
end

################################################################################
function lmmlog!(io, lmmlog::Vector{LMMLogMsg}, verbose, vmsg)
if verbose == 1
Expand Down
14 changes: 14 additions & 0 deletions src/lmmdata.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,17 @@ struct LMMDataViews{T<:AbstractFloat} <: AbstractLMMDataBlocks
return LMMDataViews(lmm.data.xv, lmm.data.yv, lmm.covstr.vcovblock)
end
end

struct LMMWts{T<:AbstractFloat}
sqrtwts::Vector{Vector{T}}
function LMMWts(sqrtwts::Vector{Vector{T}}) where T
new{T}(sqrtwts)

Check warning on line 41 in src/lmmdata.jl

View check run for this annotation

Codecov / codecov/patch

src/lmmdata.jl#L40-L41

Added lines #L40 - L41 were not covered by tests
end
function LMMWts(wts::Vector{T}, vcovblock) where T
sqrtwts = Vector{Vector{T}}(undef, length(vcovblock))
for i in eachindex(vcovblock)
y[i] = sqrt.(view(wts, vcovblock[i]))
end
LMMWts(sqrtwts)

Check warning on line 48 in src/lmmdata.jl

View check run for this annotation

Codecov / codecov/patch

src/lmmdata.jl#L43-L48

Added lines #L43 - L48 were not covered by tests
end
end
26 changes: 15 additions & 11 deletions src/miboot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@ Multiple imputation model.
"""
struct MILMM{T} <: MetidaModel
lmm::LMM{T}
mf::ModelFrame
mm::ModelMatrix
f::FormulaTerm
modstr::ModelStructure
covstr::CovStructure
data::LMMData{T}
dv::LMMDataViews{T}
maxvcbl::Int
mrs::MRS
wts::Union{Nothing, LMMWts}
log::Vector{LMMLogMsg}
function MILMM(lmm::LMM{T}, data) where T
if !Tables.istable(data) error("Data not a table!") end
Expand All @@ -42,15 +43,18 @@ struct MILMM{T} <: MetidaModel
replace!(rcol, missing => NaN)
data = merge(NamedTuple{(rv,)}((convert(Vector{Float64}, rcol),)), datam)
lmmlog = Vector{LMMLogMsg}(undef, 0)
mf = ModelFrame(lmm.mf.f, lmm.mf.schema, data, MetidaModel)
mm = ModelMatrix(mf)
mmf = mm.m
lmmdata = LMMData(mmf, data[rv])
rmf, lmf = modelcols(lmm.f, data)

#mf = ModelFrame(lmm.f, lmm.mf.schema, data, MetidaModel)
#mm = ModelMatrix(mf)
#mmf = mm.m

lmmdata = LMMData(lmf, data[rv])
covstr = CovStructure(lmm.covstr.random, lmm.covstr.repeated, data)
dv = LMMDataViews(mmf, lmmdata.yv, covstr.vcovblock)
dv = LMMDataViews(lmf, lmmdata.yv, covstr.vcovblock)
mb = missblocks(dv.yv)
dist = mrsdist(lmm, mb, covstr, dv.xv, dv.yv)
new{T}(lmm, mf, mm, covstr, lmmdata, dv, findmax(length, covstr.vcovblock)[1], MRS(mb, dist), lmmlog)
new{T}(lmm, lmm.f, lmm.modstr, covstr, lmmdata, dv, findmax(length, covstr.vcovblock)[1], MRS(mb, dist), lmm.wts, lmmlog)
end
end
struct MILMMResult{T}
Expand Down Expand Up @@ -225,7 +229,7 @@ function bootstrap_(lmm::LMM{T}; n, verbose, init, rng, del) where T


mres = ModelResult(false, nothing, fill(NaN, thetalength(lmm)), NaN, fill(NaN, coefn(lmm)), nothing, fill(NaN, coefn(lmm), coefn(lmm)), fill(NaN, coefn(lmm)), nothing, false)
lmmb = LMM(lmm.model, lmm.mf, lmm.mm, lmm.covstr, lmm.data, LMMDataViews(lmm.dv.xv, deepcopy(lmm.dv.yv)), lmm.nfixed, lmm.rankx, mres, lmm.maxvcbl, Vector{LMMLogMsg}(undef, 0))
lmmb = LMM(lmm.model, lmm.f, lmm.modstr, lmm.covstr, lmm.data, LMMDataViews(lmm.dv.xv, deepcopy(lmm.dv.yv)), lmm.nfixed, lmm.rankx, mres, lmm.maxvcbl, lmm.wts, Vector{LMMLogMsg}(undef, 0))

vi = findall(x-> x == :var, lmm.covstr.ct)
tlmm = theta_(lmm) .^ 2
Expand Down Expand Up @@ -289,7 +293,7 @@ function dbootstrap_(lmm::LMM{T}; n, verbose, init, rng, del) where T
log = Vector{LMMLogMsg}(undef, 0)

mres = ModelResult(false, nothing, fill(NaN, thetalength(lmm)), NaN, fill(NaN, coefn(lmm)), nothing, fill(NaN, coefn(lmm), coefn(lmm)), fill(NaN, coefn(lmm)), nothing, false)
lmmb = LMM(lmm.model, lmm.mf, lmm.mm, lmm.covstr, lmm.data, LMMDataViews(lmm.dv.xv, deepcopy(lmm.dv.yv)), lmm.nfixed, lmm.rankx, mres, lmm.maxvcbl, Vector{LMMLogMsg}(undef, 0))
lmmb = LMM(lmm.model, lmm.f, lmm.modstr, lmm.covstr, lmm.data, LMMDataViews(lmm.dv.xv, deepcopy(lmm.dv.yv)), lmm.nfixed, lmm.rankx, mres, lmm.maxvcbl, lmm.wts, Vector{LMMLogMsg}(undef, 0))

vi = findall(x-> x == :var, lmm.covstr.ct)
tlmm = theta_(lmm) .^ 2
Expand Down Expand Up @@ -432,7 +436,7 @@ function milmm(mi::MILMM; n = 100, verbose = true, rng = default_rng())
ty = Vector{Float64}(undef, max)
for i = 1:n
data, dv = generate_mi(rng, mi.data, mi.dv, mi.covstr.vcovblock, mi.mrs, rb, ty)
lmmi = LMM(mi.lmm.model, mi.mf, mi.mm, mi.covstr, data, dv, mi.lmm.nfixed, mi.lmm.rankx, deepcopy(mi.lmm.result), mi.maxvcbl, mi.log)
lmmi = LMM(mi.lmm.model, mi.f, mi.modstr, mi.covstr, data, dv, mi.lmm.nfixed, mi.lmm.rankx, deepcopy(mi.lmm.result), mi.maxvcbl, mi.wts, mi.log)
lmm[i] = lmmi
end
p = Progress(n, dt = 0.5,
Expand Down
5 changes: 2 additions & 3 deletions src/statsbase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ end
Coefficients names.
"""
StatsBase.coefnames(lmm::LMM) = StatsBase.coefnames(lmm.mf)
StatsBase.coefnames(lmm::LMM) = StatsBase.coefnames(lmm.f)[2]

"""
StatsBase.nobs(lmm::MetiaModel)
Expand Down Expand Up @@ -238,8 +238,7 @@ end
Responce varible name.
"""
function StatsBase.responsename(lmm::LMM)
cnm = coefnames(lmm.mf.f.lhs)
return isa(cnm, Vector{String}) ? first(cnm) : cnm
StatsBase.coefnames(lmm.f)[1]
end


Expand Down
2 changes: 1 addition & 1 deletion src/statsmodels.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@

StatsModels.formula(lmm::LMM) = lmm.mf.f
StatsModels.formula(lmm::LMM) = lmm.f
4 changes: 2 additions & 2 deletions src/typeiii.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ Type III table.
"""
function typeiii(lmm::LMM; ddf::Symbol = :satter)
if !isfitted(lmm) error("Model not fitted!") end
c = length(lmm.mf.f.rhs.terms)
c = length(lmm.f.rhs.terms)
d = Vector{Int}(undef, 0)
fac = Vector{String}(undef, c)
F = Vector{Float64}(undef,c)
df = Vector{Float64}(undef, c)
ndf = Vector{Float64}(undef, c)
pval = Vector{Float64}(undef, c)
for i = 1:c
iterm = lmm.mf.f.rhs.terms[i]
iterm = lmm.f.rhs.terms[i]

if typeof(iterm) <: InterceptTerm{false}
push!(d, i)
Expand Down
32 changes: 21 additions & 11 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@ function initvar(y::Vector, X::Matrix{T}) where T
dot(r, r)/(length(r) - size(X, 2)), β
end
################################################################################
function nterms(lmm::LMM)
nterms(lmm.mf)
function fixedeffn(f::FormulaTerm)
length(f.rhs.terms) - !StatsModels.hasintercept(f)
end
function fixedeffn(lmm::LMM)
fixedeffn(lmm.f)
end
#=
function nterms(mf::ModelFrame)
mf.schema.schema.count
end
=#
function nterms(rhs::Union{Tuple{Vararg{AbstractTerm}}, Nothing, AbstractTerm})
if isa(rhs, Term)
p = 1
Expand All @@ -26,7 +30,6 @@ function nterms(rhs::Union{Tuple{Vararg{AbstractTerm}}, Nothing, AbstractTerm})
end
p
end

"""
Rerm name.
"""
Expand All @@ -50,16 +53,16 @@ end
L-contrast matrix for `i` fixed effect.
"""
function lcontrast(lmm::LMM, i::Int)
n = length(lmm.mf.f.rhs.terms)
n = length(lmm.f.rhs.terms)
p = size(lmm.data.xv, 2)
if i > n || n < 1 error("Factor number out of range 1-$(n)") end
inds = findall(x -> x==i, lmm.mm.assign)
if typeof(lmm.mf.f.rhs.terms[i]) <: CategoricalTerm
mxc = zeros(size(lmm.mf.f.rhs.terms[i].contrasts.matrix, 1), p)
inds = findall(x -> x==i, assign(lmm))
if typeof(lmm.f.rhs.terms[i]) <: CategoricalTerm
mxc = zeros(size(lmm.f.rhs.terms[i].contrasts.matrix, 1), p)
mxcv = view(mxc, :, inds)
mxcv .= lmm.mf.f.rhs.terms[i].contrasts.matrix
mx = zeros(size(lmm.mf.f.rhs.terms[i].contrasts.matrix, 1) - 1, p)
for i = 2:size(lmm.mf.f.rhs.terms[i].contrasts.matrix, 1)
mxcv .= lmm.f.rhs.terms[i].contrasts.matrix
mx = zeros(size(lmm.f.rhs.terms[i].contrasts.matrix, 1) - 1, p)
for i = 2:size(lmm.f.rhs.terms[i].contrasts.matrix, 1)
mx[i-1, :] .= mxc[i, :] - mxc[1, :]
end
else
Expand Down Expand Up @@ -429,3 +432,10 @@ function StatsModels.termvars(ve::Vector{VarEffect})
end

################################################################################
#=
asgn(f::FormulaTerm) = asgn(f.rhs)
asgn(t) = mapreduce(((i,t), ) -> i*ones(StatsModels.width(t)),
append!,
enumerate(StatsModels.vectorize(t)),
init=Int[])
=#
11 changes: 9 additions & 2 deletions test/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,21 @@ include("testdata.jl")
)
@test Metida.m2logreml(lmm) 16.241112644506067 atol=1E-6

lmm = Metida.fit(Metida.LMM, Metida.@lmmformula(var~0+sequence+period+formulation,
random = formulation|subject:Metida.DIAG), df0)
@test Metida.fixedeffn(lmm) == 3
t3table = Metida.typeiii(lmm)
@test length(t3table.name) == 3

lmm = Metida.fit(Metida.LMM, Metida.@lmmformula(var~sequence+period+formulation,
random = formulation|subject:Metida.DIAG), df0)
@test Metida.m2logreml(lmm) 16.241112644506067 atol=1E-6
@test Metida.fixedeffn(lmm) == 4

t3table = Metida.typeiii(lmm; ddf = :contain) # NOT VALIDATED
t3table = Metida.typeiii(lmm; ddf = :residual)
t3table = Metida.typeiii(lmm)

@test length(t3table.name) == 4
############################################################################
############################################################################
# API test
Expand Down Expand Up @@ -93,7 +100,7 @@ include("testdata.jl")
@test isa(response(lmm), Vector)
@test sum(Metida.hessian(lmm)) 1118.160713481362 atol=1E-2
@test Metida.nblocks(lmm) == 5
@test length(coefnames(lmm)) == 6
@test coefnames(lmm) == ["(Intercept)", "sequence: 2", "period: 2", "period: 3", "period: 4", "formulation: 2"]
@test Metida.gmatrixipd(lmm)
@test Metida.confint(lmm)[end][1] -0.7630380758015894 atol=1E-4
@test Metida.confint(lmm, 6)[1] -0.7630380758015894 atol=1E-4
Expand Down

0 comments on commit caeaea2

Please sign in to comment.