feat: add PositiveDefinite #89

Draft · wants to merge 11 commits into base: main
10 changes: 10 additions & 0 deletions docs/ref.bib
@@ -109,3 +109,13 @@ @misc{zagoruyko2017wideresidualnetworks
primaryclass = {cs.CV},
url = {https://arxiv.org/abs/1605.07146}
}

@misc{gaby2022lyapunovnetdeepneuralnetwork,
    title = {Lyapunov-Net: A Deep Neural Network Architecture for Lyapunov Function Approximation},
    author = {Nathan Gaby and Fumin Zhang and Xiaojing Ye},
    year = {2022},
    eprint = {2109.13359},
    archiveprefix = {arXiv},
    primaryclass = {cs.LG},
    url = {https://arxiv.org/abs/2109.13359}
}
4 changes: 3 additions & 1 deletion src/layers/Layers.jl
@@ -31,6 +31,7 @@ const NORM_LAYER_DOC = "Function with signature `f(i::Integer, dims::Integer, ac

include("attention.jl")
include("conv_norm_act.jl")
include("containers.jl")
include("dynamic_expressions.jl")
include("encoder.jl")
include("embeddings.jl")
@@ -42,6 +43,7 @@ include("tensor_product.jl")
@compat(public,
(ClassTokens, ConvBatchNormActivation, ConvNormActivation, DynamicExpressionsLayer,
HamiltonianNN, MultiHeadSelfAttention, MLP, PatchEmbedding, PeriodicEmbedding,
SplineLayer, TensorProductLayer, ViPosEmbedding, VisionTransformerEncoder))
PositiveDefinite, ShiftTo, SplineLayer, TensorProductLayer, ViPosEmbedding,
VisionTransformerEncoder))

end
135 changes: 135 additions & 0 deletions src/layers/containers.jl
@@ -0,0 +1,135 @@
"""
PositiveDefinite(model, x0; ψ, r)
PositiveDefinite(model; in_dims, ψ, r)

Constructs a Lyapunov-Net [gaby2022lyapunovnetdeepneuralnetwork](@citep), which is positive
definite about `x0` whenever `ψ` and `r` meet certain conditions described below.

For a model `ϕ`,
`PositiveDefinite(ϕ, x0; ψ, r)(x, ps, st) = ψ(ϕ(x, ps, st) - ϕ(x0, ps, st)) + r(x, x0)`.
This results in a model which maps `x0` to `0` and any other input to a positive number
(i.e., a model which is positive definite about `x0`) whenever `ψ` is positive definite
about zero and `r` returns a positive number for any non-equal inputs and zero for equal
inputs.

## Arguments
- `model`: the underlying model being transformed into a positive definite function
- `x0`: The unique input that will be mapped to zero instead of a positive number

## Keyword Arguments
- `in_dims`: the number of input dimensions if `x0` is not provided; uses
`x0 = zeros(in_dims)`
- `ψ`: a positive definite function (about zero); defaults to ``ψ(x) = ||x||^2``
- `r`: a bivariate function such that `r(x0, x0) = 0` and
`r(x, x0) > 0` whenever `x ≠ x0`; defaults to ``r(x, y) = ||x - y||^2``

## Inputs
- `x`: will be passed directly into `model`, so it must meet the input requirements of
  `model`

## Returns
- The output of the positive definite model
- The state of the positive definite model. If the underlying model changes its state, the
  state will be updated according to the call with the input `x`, not the call using
  `x0`.

## States
- `st`: a `NamedTuple` containing the state of the underlying `model` and the `x0` value

## Parameters
- Same as the underlying `model`
"""
@concrete struct PositiveDefinite <: AbstractLuxWrapperLayer{:model}
model <: AbstractLuxLayer
x0 <: AbstractVector
Review comment (Member): Don't store a vector here. Instead, pass in an initialization function (ideally from WeightInitializers.jl) and construct the vector inside `initialstates`. (A sketch of this suggestion follows the `initialstates` definition below.)

ψ <: Function
r <: Function

function PositiveDefinite(model, x0::AbstractVector; ψ = Base.Fix1(sum, abs2),
r = Base.Fix1(sum, abs2) ∘ -)
return PositiveDefinite(model, x0, ψ, r)
end
function PositiveDefinite(model; in_dims::Integer, ψ = Base.Fix1(sum, abs2),
r = Base.Fix1(sum, abs2) ∘ -)
return PositiveDefinite(model, zeros(in_dims), ψ, r)
end
end

function LuxCore.initialstates(rng::AbstractRNG, pd::PositiveDefinite)
return (; model=LuxCore.initialstates(rng, pd.model), x0=pd.x0)
end
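A minimal sketch of how the suggestion in the review comment above could look (the field names `init_x0` and `in_dims` are hypothetical, and `zeros32` is used here as an example initializer from WeightInitializers.jl):

# Hypothetical refactor following the review comment: keep an initializer function in
# the layer and only materialize `x0` when `initialstates` is called.
@concrete struct PositiveDefinite <: AbstractLuxWrapperLayer{:model}
model <: AbstractLuxLayer
init_x0 <: Function   # e.g. `zeros32` from WeightInitializers.jl
in_dims <: Integer
ψ <: Function
r <: Function
end

function LuxCore.initialstates(rng::AbstractRNG, pd::PositiveDefinite)
return (; model=LuxCore.initialstates(rng, pd.model), x0=pd.init_x0(rng, pd.in_dims))
end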

function (pd::PositiveDefinite)(x::AbstractVector, ps, st)
out, new_st = pd(reshape(x, :, 1), ps, st)
return vec(out), new_st
end

function (pd::PositiveDefinite)(x::AbstractMatrix, ps, st)
ϕ0, _ = pd.model(st.x0, ps, st.model)
ϕx, new_model_st = pd.model(x, ps, st.model)
return (
mapreduce(hcat, zip(eachcol(x), eachcol(ϕx))) do (x, ϕx)
pd.ψ(ϕx - ϕ0) + pd.r(x, st.x0)
end,
merge(st, (; model = new_model_st))
)
end
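A minimal usage sketch of `PositiveDefinite` (assuming `Lux`, `NNlib`, and `StableRNGs` are loaded, mirroring the test setup further below):

mlp = Layers.MLP(2, (4, 4, 2), NNlib.gelu)
pd = Layers.PositiveDefinite(mlp; in_dims=2)   # x0 defaults to zeros(2)
ps, st = Lux.setup(StableRNG(0), pd)

y, _ = pd(randn(StableRNG(1), Float32, 2, 4), ps, st)   # 1×4 row of positive values
y0, _ = pd(zeros(Float32, 2), ps, st)                    # ≈ [0.0]: x0 maps to zero by construction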

"""
ShiftTo(model, in_val, out_val)

Vertically shifts the output of `model` to output `out_val` when the input is `in_val`.

For a model `ϕ`, `ShiftTo(ϕ, in_val, out_val)(x, ps, st) = ϕ(x, ps, st) + Δϕ`,
where `Δϕ = out_val - ϕ(in_val, ps, st)`.

## Arguments
- `model`: the underlying model whose output is being shifted
- `in_val`: The input that will be mapped to `out_val`
- `out_val`: The value that the output will be shifted to when the input is `in_val`

## Inputs
- `x`: will be passed directly into `model`, so it must meet the input requirements of
  `model`

## Returns
- The output of the shifted model
- The state of the shifted model. If the underlying model changes its state, the
  state will be updated according to the call with the input `x`, not the call using
  `in_val`.

## States
- `st`: a `NamedTuple` containing the state of the underlying `model` and the `in_val` and
`out_val` values

## Parameters
- Same as the underlying `model`
"""
@concrete struct ShiftTo <: AbstractLuxWrapperLayer{:model}
model <: AbstractLuxLayer
in_val <: AbstractVector
out_val <: AbstractVector
Review comment (Member), on lines +110 to +111: Same as above — store an initialization function rather than concrete `in_val`/`out_val` vectors, and construct them inside `initialstates`.
end

function LuxCore.initialstates(rng::AbstractRNG, s::ShiftTo)
return (;
model=LuxCore.initialstates(rng, s.model),
in_val=s.in_val,
out_val=s.out_val
)
end

function (s::ShiftTo)(x::AbstractVector, ps, st)
out, new_st = s(reshape(x, :, 1), ps, st)
return vec(out), new_st
end

function (s::ShiftTo)(x::AbstractMatrix, ps, st)
ϕ0, _ = s.model(st.in_val, ps, st.model)
Δϕ = st.out_val .- ϕ0
ϕx, new_model_st = s.model(x, ps, st.model)
return (
ϕx .+ Δϕ,
merge(st, (; model = new_model_st))
)
end
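Similarly, a minimal usage sketch of `ShiftTo` (same assumptions as the `PositiveDefinite` sketch above):

mlp = Layers.MLP(2, (4, 4, 2), NNlib.gelu)
s = Layers.ShiftTo(mlp, ones(Float32, 2), zeros(Float32, 2))
ps, st = Lux.setup(StableRNG(0), s)

y0, _ = s(ones(Float32, 2), ps, st)                    # ≈ zeros(2): in_val maps to out_val
y, _ = s(randn(StableRNG(1), Float32, 2, 3), ps, st)   # 2×3 shifted outputs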
44 changes: 44 additions & 0 deletions test/layer_tests.jl
@@ -282,3 +282,47 @@ end
end
end
end

@testitem "Positive Definite Container" setup=[SharedTestSetup] tags=[:layers] begin
using NNlib

@testset "$(mode)" for (mode, aType, dev, ongpu) in MODES
model = Layers.MLP(2, (4, 4, 2), NNlib.gelu)
pd = Layers.PositiveDefinite(model; in_dims=2)
ps, st = Lux.setup(StableRNG(0), pd) |> dev

x = randn(StableRNG(0), Float32, 2, 2) |> aType
x0 = zeros(Float32, 2) |> aType

y, _ = pd(x, ps, st)
z, _ = model(x, ps, st.model)
z0, _ = model(x0, ps, st.model)
y2 = sum(abs2, z .- z0; dims = 1) .+ sum(abs2, x .- x0; dims = 1)
@test maximum(abs, y - y2) < 1.0f-8

@jet pd(x, ps, st)

__f = (x, ps) -> sum(first(pd(x, ps, st)))
@test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3)
end
end

@testitem "ShiftTo Container" setup=[SharedTestSetup] tags=[:layers] begin
using NNlib

@testset "$(mode)" for (mode, aType, dev, ongpu) in MODES
model = Layers.MLP(2, (4, 4, 2), NNlib.gelu)
s = Layers.ShiftTo(model, ones(2), zeros(2))
ps, st = Lux.setup(StableRNG(0), s) |> dev

x0 = ones(Float32, 2) |> aType
y0, _ = s(x0, ps, st)
@test maximum(abs, y0) < 1.0f-8

x = randn(StableRNG(0), Float32, 2, 2) |> aType
@jet s(x, ps, st)

__f = (x, ps) -> sum(first(s(x, ps, st)))
@test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3)
end
end