Skip to content

Commit

Permalink
mask changes
Browse files Browse the repository at this point in the history
  • Loading branch information
montyvesselinov committed Jan 17, 2025
1 parent f50fe65 commit 7e6112d
Showing 1 changed file with 43 additions and 43 deletions.
86 changes: 43 additions & 43 deletions src/SVRfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ function fit(y::AbstractArray{T}, x::AbstractArray{T}; kw...) where {T <: Number
return yp
end

# Float64 method: train an SVR model on the cases where the prediction mask is false,
# predict for every column of `x`, and score the result; optionally repeat.
# NOTE(review): this is a rendered commit diff. Each changed line appears twice: the
# pre-commit line first, then its post-commit replacement (`ratio` -> `ratio_prediction`,
# `pm` -> `mask_prediction`). The `Expand All @@` line is a hunk marker; the code it
# elides (presumably where `a`, `m`, `y_pra` and `ya` are set up) is not visible here.
# Returns `(y_pr, mask_prediction, Statistics.mean(m))`: predictions and mask are those
# of the last repeat, and the third value is the mean of the per-repeat scores.
function fit_test(y::AbstractVector{Float64}, x::AbstractArray{Float64}; ratio::Number=0.1, repeats::Number=1, pm=nothing, keepcases::Union{BitArray,Nothing}=nothing, scale::Bool=false, ymin::Number=minimum(y), ymax::Number=maximum(y), quiet::Bool=false, veryquiet::Bool=true, total::Bool=false, rmse::Bool=true, callback::Function=(y::AbstractVector, y_pr::AbstractVector, pm::AbstractVector)->nothing, kw...)
function fit_test(y::AbstractVector{Float64}, x::AbstractArray{Float64}; ratio_prediction::Number=0.1, repeats::Number=1, mask_prediction=nothing, keepcases::Union{BitArray,Nothing}=nothing, scale::Bool=false, ymin::Number=minimum(y), ymax::Number=maximum(y), quiet::Bool=false, veryquiet::Bool=true, total::Bool=false, rmse::Bool=true, callback::Function=(y::AbstractVector, y_pr::AbstractVector, mask_prediction::AbstractVector)->nothing, kw...)
# When given, `keepcases` must have one entry per column of `x` (cases are columns).
if !isnothing(keepcases)
@assert length(keepcases) == size(x, 2)
end
Expand All @@ -127,64 +127,64 @@ function fit_test(y::AbstractVector{Float64}, x::AbstractArray{Float64}; ratio::
# Accumulator for the masks of all repeats (grown with `vcat` below).
pma = Vector{Bool}(undef, 0)
local y_pr
for r in 1:repeats
# A new mask is generated when no mask was supplied OR when `repeats > 1`.
# NOTE(review): a caller-supplied mask is therefore ignored (and never validated)
# whenever `repeats > 1` — confirm this is intended.
if repeats > 1 || isnothing(pm)
pm = get_prediction_mask(length(y), ratio; keepcases=keepcases)
if repeats > 1 || isnothing(mask_prediction)
mask_prediction = get_prediction_mask(length(y), ratio_prediction; keepcases=keepcases)
else
@assert length(pm) == size(x, 2)
@assert eltype(pm) <: Bool
@assert length(mask_prediction) == size(x, 2)
@assert eltype(mask_prediction) <: Bool
end
# `ic` = number of training cases (mask entries that are false).
ic = sum(.!pm)
ic = sum(.!mask_prediction)
if !quiet && repeats == 1 && length(y) > ic
# NOTE(review): the post-commit message reads "prediction ratio_prediction <value>",
# which looks like a search-and-replace artifact of the rename.
@info("Training on $(ic) out of $(length(y)) (prediction ratio $ratio) ...")
@info("Training on $(ic) out of $(length(y)) (prediction ratio_prediction $ratio_prediction) ...")
end
# Train on the unmasked cases only; `a` comes from the elided lines
# (presumably `y`, possibly scaled — TODO confirm).
pmodel = train(a[.!pm], x[:,.!pm]; kw...)
pmodel = train(a[.!mask_prediction], x[:,.!mask_prediction]; kw...)
# Predict for all cases, training and held-out alike, then release the model.
y_pr = predict(pmodel, x)
freemodel(pmodel)
if any(isnan.(y_pr))
@warn("SVR output contains NaN's!")
end
# Score on all cases when `total` is true, otherwise on the masked (held-out) cases only.
# NOTE(review): inside this method `rmse` is the Bool keyword from the signature, so
# `rmse(y_pr, a)` calls a Bool rather than a metric function; this looks like the
# keyword shadows the intended RMSE function — confirm and rename one of them.
if rmse
m[r] = total ? rmse(y_pr, a) : rmse(y_pr[pm], a[pm])
m[r] = total ? rmse(y_pr, a) : rmse(y_pr[mask_prediction], a[mask_prediction])
else
m[r] = total ? r2(y_pr, a) : r2(y_pr[pm], a[pm])
m[r] = total ? r2(y_pr, a) : r2(y_pr[mask_prediction], a[mask_prediction])
end
if !veryquiet && repeats > 1
println("Repeat $r: $(m[r])")
end
# Append this repeat's predictions, targets and mask to the accumulators.
# Note that `ya` collects the original `y`, while scoring above uses `a`.
y_pra = vcat(y_pra, y_pr)
ya = vcat(ya, y)
pma = vcat(pma, pm)
pma = vcat(pma, mask_prediction)
end
# Map predictions back to the [ymin, ymax] range. In the visible lines this is applied
# unconditionally; whether `scale` gates a matching forward transform is in the
# elided hunk — TODO confirm.
y_pra = y_pra * (ymax - ymin) .+ ymin
# The callback receives the targets, predictions and masks of all repeats concatenated.
callback(ya, y_pra, pma)
y_pr = y_pr * (ymax - ymin) .+ ymin
return y_pr, pm, Statistics.mean(m)
return y_pr, mask_prediction, Statistics.mean(m)
end
# Generic-number wrapper: converts `y` and `x` to Float64, delegates to the Float64
# method, and converts the predictions back to element type `T`.
# NOTE(review): rendered commit diff — the first three lines are the pre-commit version
# (`ratio`, `pm`), the next three their post-commit replacements; both share the final `end`.
# The local named `rmse` holds whatever score the callee returned (it is passed through
# unchanged), so it is not necessarily an RMSE.
# NOTE(review): `T.(y_pr)` converts Float64 predictions back to `T`; for an integer `T`
# a non-integer prediction would raise an InexactError — confirm integer input is not expected.
function fit_test(y::AbstractVector{T}, x::AbstractArray{T}; ratio::Number=0.1, kw...) where {T <: Number}
y_pr, pm, rmse = fit_test(Float64.(y), Float64.(x); ratio=ratio, kw...)
return T.(y_pr), pm, rmse
function fit_test(y::AbstractVector{T}, x::AbstractArray{T}; ratio_prediction::Number=0.1, kw...) where {T <: Number}
y_pr, mask_prediction, rmse = fit_test(Float64.(y), Float64.(x); ratio_prediction=ratio_prediction, kw...)
return T.(y_pr), mask_prediction, rmse
end
# Multi-target method: fits every column of `y` separately against the same `x`,
# using one prediction mask for all columns.
# NOTE(review): rendered commit diff — each changed line appears twice, the pre-commit
# line first and its post-commit replacement directly below it.
# Returns `(yp, mask, score)` where `yp` has the same shape and element type as `y`.
function fit_test(y::AbstractArray{T}, x::AbstractArray{T}; ratio::Number=0.1, pm=nothing, keepcases::Union{BitArray,Nothing}=nothing, kw...) where {T <: Number}
function fit_test(y::AbstractArray{T}, x::AbstractArray{T}; ratio_prediction::Number=0.1, mask_prediction=nothing, keepcases::Union{BitArray,Nothing}=nothing, kw...) where {T <: Number}
# Rows of `y` and columns of `x` both index the cases.
@assert size(y, 1) == size(x, 2)
if !isnothing(keepcases)
@assert length(keepcases) == length(x, 2)
end
# Build the mask once so that all columns can be fitted with the same split.
# `keepcases` is used only here; it is not forwarded to the per-column calls.
if isnothing(pm)
pm = get_prediction_mask(size(y, 1), ratio; keepcases=keepcases)
if isnothing(mask_prediction)
mask_prediction = get_prediction_mask(size(y, 1), ratio_prediction; keepcases=keepcases)
end
yp = similar(y)
for i = 1:size(y, 2)
# NOTE(review): `rmse` is assigned only inside this `for` body, which is its own local
# scope in Julia, so the `rmse` in the `return` below does not see this value — it
# resolves to a global named `rmse` if one exists, otherwise it is undefined. Even if
# hoisted, only the last column's score would be returned — confirm the intent.
yp[:,i], _, rmse = fit_test(vec(y[:,i]), x; ratio=ratio, pm=pm, kw...)
yp[:,i], _, rmse = fit_test(vec(y[:,i]), x; ratio_prediction=ratio_prediction, mask_prediction=mask_prediction, kw...)
end
return yp, pm, rmse
return yp, mask_prediction, rmse
end
# One-dimensional grid search: runs `fit_test` once per value in `vattr`, passing that
# value as the keyword named by `attr`, then refits with the selected value.
# NOTE(review): rendered commit diff — each changed line appears twice (pre-commit line,
# then post-commit replacement). The `Expand All @@` line is a hunk marker; the code it
# elides (how `c` from `check(ma)` and the best index `i` are used) is not visible here.
# Returns `(m, vattr[i], refit...)`, where `refit...` are the return values of the final
# `fit_test` call splatted into the tuple.
function fit_test(y::AbstractVector{T}, x::AbstractArray{T}, vattr::Union{AbstractVector,AbstractRange}; ratio::Number=0.1, attr=:gamma, rmse::Bool=true, check::Function=(v::AbstractVector)->nothing, kw...) where {T <: Number}
function fit_test(y::AbstractVector{T}, x::AbstractArray{T}, vattr::Union{AbstractVector,AbstractRange}; ratio_prediction::Number=0.1, attr=:gamma, rmse::Bool=true, check::Function=(v::AbstractVector)->nothing, kw...) where {T <: Number}
@assert length(vattr) > 0
# NOTE(review): the post-commit message reads "prediction ratio_prediction <value>",
# which looks like a search-and-replace artifact of the rename.
@info("Grid search on $attr with prediction ratio $ratio ...")
@info("Grid search on $attr with prediction ratio_prediction $ratio_prediction ...")
# One score per grid value.
ma = Vector{T}(undef, length(vattr))
for (i, g) in enumerate(vattr)
# `k...` is splatted after `kw...`, so the grid value overrides a same-named keyword
# in `kw`; `quiet=true` comes last and overrides both.
k = Dict(attr=>g)
y_pr, pm, ma[i] = fit_test(y, x; ratio=ratio, rmse=rmse, kw..., k..., quiet=true)
y_pr, mask_prediction, ma[i] = fit_test(y, x; ratio_prediction=ratio_prediction, rmse=rmse, kw..., k..., quiet=true)
@info("$attr=>$g: $(ma[i])")
end
# User hook on the score vector; its result is consumed in the elided lines.
c = check(ma)
Expand All @@ -195,64 +195,64 @@ function fit_test(y::AbstractVector{T}, x::AbstractArray{T}, vattr::Union{Abstra
m = ma[i]
end
# Final refit with the selected value; `repeats=1` is forced for this call.
# NOTE(review): the refit is a separate run, so its own score (last element of the
# splatted tuple) may differ from `m` if the split is regenerated — TODO confirm.
k = Dict(attr=>vattr[i])
return m, vattr[i], fit_test(y, x; ratio=ratio, rmse=rmse, kw..., k..., repeats=1)...
return m, vattr[i], fit_test(y, x; ratio_prediction=ratio_prediction, rmse=rmse, kw..., k..., repeats=1)...
end
# Two-dimensional grid search: runs `fit_test` for every (`vattr1`, `vattr2`) pair,
# passing the values as the keywords named by `attr1` and `attr2`, then refits with
# the best pair.
# NOTE(review): rendered commit diff — each changed line appears twice, the pre-commit
# line first and its post-commit replacement directly below it.
# Returns `(m, best1, best2, refit...)`, where `refit...` are the return values of the
# final `fit_test` call splatted into the tuple.
function fit_test(y::AbstractVector{T}, x::AbstractArray{T}, vattr1::Union{AbstractVector,AbstractRange}, vattr2::Union{AbstractVector,AbstractRange}; ratio::Number=0.1, attr1=:gamma, attr2=:epsilon, rmse::Bool=true, kw...) where {T <: Number}
function fit_test(y::AbstractVector{T}, x::AbstractArray{T}, vattr1::Union{AbstractVector,AbstractRange}, vattr2::Union{AbstractVector,AbstractRange}; ratio_prediction::Number=0.1, attr1=:gamma, attr2=:epsilon, rmse::Bool=true, kw...) where {T <: Number}
@assert length(vattr1) > 0
@assert length(vattr2) > 0
# NOTE(review): the post-commit message reads "prediction ratio_prediction <value>",
# which looks like a search-and-replace artifact of the rename.
@info("Grid search on $attr1/$attr2 with prediction ratio $ratio ...")
@info("Grid search on $attr1/$attr2 with prediction ratio_prediction $ratio_prediction ...")
# Score table: rows follow `vattr1`, columns follow `vattr2`.
ma = Matrix{T}(undef, length(vattr1), length(vattr2))
for (i, a1) in enumerate(vattr1)
for (j, a2) in enumerate(vattr2)
# `k...` is splatted after `kw...`, so grid values override same-named keywords in `kw`.
# Only the score is kept; the predictions and mask of each trial are discarded.
# `quiet` is not forced here, so it is whatever the caller passed through `kw`.
k = Dict(attr1=>a1, attr2=>a2)
y_pr, pm, ma[i, j] = fit_test(y, x; ratio=ratio, rmse=rmse, kw..., k...)
y_pr, mask_prediction, ma[i, j] = fit_test(y, x; ratio_prediction=ratio_prediction, rmse=rmse, kw..., k...)
@info("$attr1=>$a1 $attr2=>$a2: $(ma[i,j])")
end
end
# Lowest score wins when `rmse` is true, highest otherwise. `i` is a CartesianIndex
# and `i.I` is its tuple of per-dimension indices.
m, i = rmse ? findmin(ma) : findmax(ma)
# Final refit with the best pair; `repeats=1` is forced for this call.
k = Dict(attr1=>vattr1[i.I[1]], attr2=>vattr2[i.I[2]])
return m, vattr1[i.I[1]], vattr2[i.I[2]], fit_test(y, x; ratio=ratio, rmse=rmse, kw..., k..., repeats=1)...
return m, vattr1[i.I[1]], vattr2[i.I[2]], fit_test(y, x; ratio_prediction=ratio_prediction, rmse=rmse, kw..., k..., repeats=1)...
end
# Three-dimensional grid search: runs `fit_test` for every (`vattr1`, `vattr2`, `vattr3`)
# triple, passing the values as the keywords named by `attr1`, `attr2` and `attr3`,
# then refits with the best triple.
# NOTE(review): rendered commit diff — each changed line appears twice, the pre-commit
# line first and its post-commit replacement directly below it.
# Returns `(m, best1, best2, best3, refit...)`, where `refit...` are the return values
# of the final `fit_test` call splatted into the tuple.
function fit_test(y::AbstractVector{T}, x::AbstractArray{T}, vattr1::Union{AbstractVector,AbstractRange}, vattr2::Union{AbstractVector,AbstractRange}, vattr3::Union{AbstractVector,AbstractRange}; ratio::Number=0.1, attr1=:gamma, attr2=:epsilon, attr3=:C, rmse::Bool=true, kw...) where {T <: Number}
function fit_test(y::AbstractVector{T}, x::AbstractArray{T}, vattr1::Union{AbstractVector,AbstractRange}, vattr2::Union{AbstractVector,AbstractRange}, vattr3::Union{AbstractVector,AbstractRange}; ratio_prediction::Number=0.1, attr1=:gamma, attr2=:epsilon, attr3=:C, rmse::Bool=true, kw...) where {T <: Number}
@assert length(vattr1) > 0
@assert length(vattr2) > 0
@assert length(vattr3) > 0
# NOTE(review): the post-commit message reads "prediction ratio_prediction <value>",
# which looks like a search-and-replace artifact of the rename.
@info("Grid search on $attr1/$attr2/$attr3 with prediction ratio $ratio ...")
@info("Grid search on $attr1/$attr2/$attr3 with prediction ratio_prediction $ratio_prediction ...")
# Score cube indexed by (vattr1, vattr2, vattr3) position.
ma = Array{T}(undef, length(vattr1), length(vattr2), length(vattr3))
for (i, a1) in enumerate(vattr1)
for (j, a2) in enumerate(vattr2)
for (k, a3) in enumerate(vattr3)
# The loop index is `k` here, so the keyword Dict is named `kk` to avoid clobbering it.
# `kk...` is splatted after `kw...`, so grid values override same-named keywords in `kw`.
# Only the score is kept; the predictions and mask of each trial are discarded.
kk = Dict(attr1=>a1, attr2=>a2, attr3=>a3)
y_pr, pm, ma[i, j, k] = fit_test(y, x; ratio=ratio, rmse=rmse, kw..., kk...)
y_pr, mask_prediction, ma[i, j, k] = fit_test(y, x; ratio_prediction=ratio_prediction, rmse=rmse, kw..., kk...)
@info("$attr1=>$a1 $attr2=>$a2 $attr3=>$a3: $(ma[i,j,k])")
end
end
end
# Lowest score wins when `rmse` is true, highest otherwise. `i` is a CartesianIndex
# and `i.I` is its tuple of per-dimension indices.
m, i = rmse ? findmin(ma) : findmax(ma)
# Final refit with the best triple; `repeats=1` is forced for this call.
k = Dict(attr1=>vattr1[i.I[1]], attr2=>vattr2[i.I[2]], attr3=>vattr3[i.I[3]])
return m, vattr1[i.I[1]], vattr2[i.I[2]], vattr3[i.I[3]], fit_test(y, x; ratio=ratio, rmse=rmse, kw..., k..., repeats=1)...
return m, vattr1[i.I[1]], vattr2[i.I[2]], vattr3[i.I[3]], fit_test(y, x; ratio_prediction=ratio_prediction, rmse=rmse, kw..., k..., repeats=1)...
end

"""
Get prediction mask
$(DocumentFunction.documentfunction(get_prediction_mask;
argtext=Dict("ns"=>"number of samples",
"ratio"=>"prediction ratio")))
"ratio_prediction"=>"prediction ratio_prediction")))
Return:
- prediction mask
"""
# Build a Boolean mask of length `ns` in which `true` marks a case held out for
# prediction and `false` marks a training case (see the debug messages at the end).
# NOTE(review): rendered commit diff — each changed line appears twice (pre-commit line,
# then post-commit replacement). The `Expand All @@` line is a hunk marker; the code it
# elides (the `else` branch and the definition of `ir`) is not visible here.
function get_prediction_mask(ns::Number, ratio::Number; keepcases::Union{AbstractVector,Nothing}=nothing, debug::Bool=false)
function get_prediction_mask(ns::Number, ratio_prediction::Number; keepcases::Union{AbstractVector,Nothing}=nothing, debug::Bool=false)
# `nsi` tracks how many cases are still available for selection.
nsi = copy(ns)
# Start with every case marked for prediction; training cases are switched to false below.
pm = trues(ns)
# `ic` = number of training cases, rounded up.
ic = convert(Int64, ceil(ns * (1. - ratio)))
mask_prediction = trues(ns)
ic = convert(Int64, ceil(ns * (1. - ratio_prediction)))
if !isnothing(keepcases)
@assert length(keepcases) == length(pm)
@assert length(keepcases) == length(mask_prediction)
# `kn` = number of cases flagged in `keepcases`.
kn = sum(keepcases)
# When the training quota exceeds the kept cases, force all kept cases into the
# training set and reduce the remaining quota and pool accordingly.
if ic > kn && ns > kn
pm[keepcases] .= false
mask_prediction[keepcases] .= false
ic -= kn
nsi -= kn
else
Expand All @@ -264,16 +264,16 @@ function get_prediction_mask(ns::Number, ratio::Number; keepcases::Union{Abstrac
# `ir` comes from the elided lines (presumably the indices chosen for training
# among the `nsi` remaining cases — TODO confirm).
if !isnothing(keepcases) && ic > kn
# Build a mask over the non-kept cases only, then write it into their positions.
m = trues(nsi)
m[ir] .= false
pm[.!keepcases] .= m
mask_prediction[.!keepcases] .= m
else
pm[ir] .= false
mask_prediction[ir] .= false
end
end
if debug
@info("Number of cases for training: $(ns - sum(pm))")
@info("Number of cases for prediction: $(sum(pm))")
@info("Number of cases for training: $(ns - sum(mask_prediction))")
@info("Number of cases for prediction: $(sum(mask_prediction))")
end
return pm
return mask_prediction
end

"""
Expand Down

0 comments on commit 7e6112d

Please sign in to comment.