Skip to content

Commit

Permalink
Merge pull request #359 from Neuroblox/minorfixes
Browse files Browse the repository at this point in the history
some sDCM fixes, OU process, and made amenable to SDEs.
  • Loading branch information
david-hofmann authored Jul 26, 2024
2 parents 9c605ec + 6ed1783 commit 28c9d6b
Show file tree
Hide file tree
Showing 8 changed files with 204 additions and 148 deletions.
6 changes: 3 additions & 3 deletions src/Neurographs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ function system_from_graph(g::MetaDiGraph, bc::BloxConnector; name, t_block=miss
connection_eqs = get_equations_with_state_lhs(bc)

cbs = identity.(get_callbacks(g, bc; t_block))
return compose(ODESystem(connection_eqs, t, [], params(bc); name, discrete_events = cbs), blox_syss)
return compose(System(connection_eqs, t, [], params(bc); name, discrete_events = cbs), blox_syss)
end

function system_from_graph(g::MetaDiGraph, bc::BloxConnector, p::Vector{Num}; name, t_block=missing)
Expand All @@ -72,11 +72,11 @@ function system_from_graph(g::MetaDiGraph, bc::BloxConnector, p::Vector{Num}; na

connection_eqs = get_equations_with_state_lhs(bc)
cbs = identity.(get_callbacks(g, bc; t_block))
return compose(ODESystem(connection_eqs, t, [], vcat(params(bc), p); name, discrete_events = cbs), blox_syss)
return compose(System(connection_eqs, t, [], vcat(params(bc), p); name, discrete_events = cbs), blox_syss)
end

function system_from_parts(parts::AbstractVector; name)
return compose(ODESystem(Equation[], t, [], []; name), get_sys.(parts))
return compose(System(Equation[], t; name), get_sys.(parts))
end

function action_selection_from_graph(g::MetaDiGraph)
Expand Down
23 changes: 9 additions & 14 deletions src/blox/blox_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,19 +257,15 @@ end
- `sys`: MTK system
Returns:
- `sts` : states of the system that are neither external inputs nor measurements, i.e. these are the dynamic states
- `idx_u`: indices of states that represent external inputs
- `idx_m`: indices of states that represent measurements
- `sts`: states/unknowns of the system that are neither external inputs nor measurements, i.e. these are the dynamic states
- `idx`: indices of these states
"""
function get_dynamic_states(sys)
sts = []
idx = []
for (i, s) in enumerate(unknowns(sys))
if !((getdescription(s) == "ext_input") || (getdescription(s) == "measurement"))
push!(sts, s)
push!(idx, i)
end
itr = Iterators.filter(enumerate(unknowns(sys))) do (_, s)
!((getdescription(s) == "ext_input") || (getdescription(s) == "measurement"))
end
sts = map(x -> x[2], itr)
idx = map(x -> x[1], itr)
return sts, idx
end

Expand All @@ -288,11 +284,11 @@ function get_eqidx_tagged_vars(sys, tag)
for s in Symbolics.get_variables(e)
if string(s) == string(v)
push!(idx, i)
end
end
end
end
end
return idx
return idx, vars
end

function get_idx_tagged_vars(sys, tag)
Expand All @@ -303,7 +299,6 @@ function get_idx_tagged_vars(sys, tag)
end
end
return idx
return idx
end

"""
Expand Down Expand Up @@ -340,7 +335,7 @@ function get_connection_rule(kwargs, bloxout, bloxin, w)
cr = kwargs[:connection_rule]
else
name_blox1 = nameof(bloxout)
name_blox1 = nameof(bloxin)
name_blox2 = nameof(bloxin)
@warn "Neuron connection rule from $name_blox1 to $name_blox2 is not specified. It is assumed that there is a basic weighted connection."
cr = "basic"
end
Expand Down
2 changes: 1 addition & 1 deletion src/blox/canonicalmicrocircuit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ mutable struct CanonicalMicroCircuitBlox <: CompositeBlox

add_edge!(g, 1, 1, :weight, -800.0)
add_edge!(g, 2, 1, :weight, -800.0)
add_edge!(g, 3, 1, :weight, -800.0)
add_edge!(g, 3, 1, :weight, -1600.0)
add_edge!(g, 1, 2, :weight, 800.0)
add_edge!(g, 2, 2, :weight, -800.0)
add_edge!(g, 1, 3, :weight, 800.0)
Expand Down
100 changes: 45 additions & 55 deletions src/blox/stochastic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,65 +13,55 @@ returns:
"""
mutable struct OUBlox <: NeuralMassBlox
# all parameters are Num as to allow symbolic expressions
μ::Num
σ::Num
τ::Num
stochastic::Bool
connector::Num
noDetail::Vector{Num}
detail::Vector{Num}
initial::Dict{Num, Tuple{Float64, Float64}}
odesystem::ODESystem
function OUBlox(;name, μ=0.0, σ=1.0, τ=1.0)
params = @parameters μ=μ τ=τ σ=σ
states = @variables x(t)=1.0 jcn(t)=0.0
namespace
stochastic
output
input
odesystem
function OUBlox(;name, namespace=nothing, μ=0.0, σ=1.0, τ=1.0)
p = paramscoping=μ, τ=τ, σ=σ)
μ, τ, σ = p
sts = @variables x(t)=0.0 [output=true] jcn(t)=0.0 [input=true]
@brownian w

eqs = [D(x) ~ -(x-μ)/τ + jcn + sqrt(2/τ)*σ*w]
sys = System(eqs, t, states, params; name=name)
new(μ, σ, τ, true, sys.x,[sys.x],[sys.x],
Dict(sys.x => (-1.0,1.0)),
sys)
eqs = [D(x) ~ -(x-μ)/τ + jcn + sqrt(2/τ)*σ*w]
sys = System(eqs, t; name=name)
new(namespace, true, sts[1], sts[2], sys)
end
end

"""
Ornstein-Uhlenbeck Coupling Blox
This blox takes an input and multiplies that input with
a OU process of mean μ and variance τ*σ^2/2
# """
# Ornstein-Uhlenbeck Coupling Blox
# This blox takes an input and multiplies that input with
# a OU process of mean μ and variance τ*σ^2/2

This blox allows to create edges that have fluctuating weights
# This blox allows to create edges that have fluctuating weights

variables:
x(t): value
jcn: input
parameters:
τ: relaxation time
μ: average value
σ: random noise (variance of OU process is τ*σ^2/2)
returns:
an ODE System (but with brownian parameters)
"""
mutable struct OUCouplingBlox <: NeuralMassBlox
# all parameters are Num as to allow symbolic expressions
μ::Num
σ::Num
τ::Num
stochastic::Bool
connector::Num
noDetail::Vector{Num}
detail::Vector{Num}
initial::Dict{Num, Tuple{Float64, Float64}}
odesystem::ODESystem
function OUCouplingBlox(;name, μ=0.0, σ=1.0, τ=1.0)
params = @parameters μ=μ τ=τ σ=σ
states = @variables x(t)=1.0 jcn(t)=0.0
@brownian w

eqs = [D(x) ~ -(x-μ)/τ + sqrt(2/τ)*σ*w]
sys = System(eqs, t, states, params; name=name)
new(μ, σ, τ, true, sys.jcn*sys.x,[sys.jcn*sys.x],[sys.jcn*sys.x],
Dict(sys.x => (-1.0,1.0)),
sys)
end
end
# variables:
# x(t): value
# jcn: input
# parameters:
# τ: relaxation time
# μ: average value
# σ: random noise (variance of OU process is τ*σ^2/2)
# returns:
# an ODE System (but with brownian parameters)
# """
# mutable struct OUCouplingBlox <: NeuralMassBlox
# # all parameters are Num as to allow symbolic expressions
# namespace
# stochastic
# output
# input
# odesystem
# function OUCouplingBlox(;name, namespace, μ=0.0, σ=1.0, τ=1.0)
# p = paramscoping(μ=μ, τ=τ, σ=σ)
# μ, τ, σ = p
# sts = @variables x(t)=0.0 [output=true] jcn(t)=0.0 [input=true]
# @brownian w

# eqs = [D(x) ~ -(x-μ)/τ + sqrt(2/τ)*σ*w]
# sys = System(eqs, t; name=name)
# new(namespace, true, sts[2]*sts[1], sts[2], sys)
# end
# end
46 changes: 26 additions & 20 deletions src/datafitting/spectralDCM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ mutable struct VLState
dF::Vector{Float64} # predicted free energy changes (store at each iteration)
λ::Vector{Float64} # hyperparameter
ϵ_θ::Vector{Float64} # prediction error of parameters θ
reset_state::Vector{Any} # store state to reset to [ϵ_θ and λ] when the free energy deteriorates
reset_state::Vector{Any} # store state to reset to [ϵ_θ and λ] when the free energy gets worse rather than better
μθ_po::Vector{Float64} # posterior expectation value of parameters
Σθ_po::Matrix{Float64} # posterior covariance matrix of parameters
dFdθ::Vector{Float64} # free energy gradient w.r.t. parameters
Expand All @@ -32,7 +32,7 @@ end

struct VLSetup{Model, N}
model_at_x0::Model # model evaluated at initial conditions
y_csd::Array{ComplexF64, N} # cross-spectral density approximated by fitting MARs to data
y_csd::Array{ComplexF64, N} # cross-spectral density approximated by fitting MARs to data
tolerance::Float64 # convergence criterion
systemnums::Vector{Int} # several integers -> np: n. parameters, ny: n. datapoints, nq: n. Q matrices, nh: n. hyperparameters
systemvecs::Vector{Vector{Float64}} # μθ_pr: prior expectation values of parameters and μλ_pr: prior expectation values of hyperparameters
Expand Down Expand Up @@ -105,10 +105,13 @@ function LinearAlgebra.eigen(M::Matrix{Dual{T, P, np}}) where {T, P, np}
end

function transferfunction_fmri(ω, derivatives, params, indices)
# nr = length(indices[:u])
# pars = params[indices[:dspars]]
# ∂f = derivatives([pars[1:nr^2], pars[nr^2+1:end]...])
∂f = derivatives(params[indices[:dspars]])
∂f∂x = ∂f[indices[:sts], indices[:sts]]
∂f∂u = ∂f[indices[:sts], indices[:u]]
∂g∂x = ∂f[indices[:bold], indices[:sts]]
∂g∂x = ∂f[indices[:m], indices[:sts]]

F = eigen(∂f∂x)
Λ = F.values
Expand Down Expand Up @@ -309,7 +312,9 @@ function setup_sDCM(data, model, initcond, csdsetup, priors, hyperpriors, indice
dt = csdsetup[:dt]; # order of MAR. Hard-coded in SPM12 with this value. We will use the same for now.
ω = csdsetup[:freq]; # frequencies at which the CSD is evaluated
p = csdsetup[:p]; # order of MAR
mar = mar_ml(Matrix(data), p); # compute MAR from time series y and model order p
_, vars = get_eqidx_tagged_vars(model, "measurement")
data = Matrix(data[:, Symbol.(vars)]) # make sure the column order is consistent with the ordering of variables of the model that represent the measurements
mar = mar_ml(data, p); # compute MAR from time series y and model order p
y_csd = mar2csd(mar, ω, dt^-1); # compute cross spectral densities from MAR parameters at specific frequencies freqs, dt^-1 is sampling rate of data
jac_fg = generate_jacobian(model, expression = Val{false})[1] # compute symbolic jacobian.

Expand All @@ -331,27 +336,28 @@ function setup_sDCM(data, model, initcond, csdsetup, priors, hyperpriors, indice

# variational laplace state variables
vlstate = VLState(
0, # iter
-4, # log ascent rate
[-Inf], # free energy
Float64[], # delta free energy
8*ones(nh), # metaparameter, initial condition. TODO: why are we not just using the prior mean?
zeros(np), # parameter estimation error ϵ_θ
[zeros(np), 8*ones(nh)], # memorize reset state
μθ_pr, # parameter posterior mean
Σθ_pr, # parameter posterior covariance
0, # iter
-4, # log ascent rate
[-Inf], # free energy
Float64[], # delta free energy
hyperpriors[:μλ_pr], # metaparameter, initial condition. TODO: why are we not just using the prior mean?
zeros(np), # parameter estimation error ϵ_θ
[zeros(np), hyperpriors[:μλ_pr]], # memorize reset state
μθ_pr, # parameter posterior mean
Σθ_pr, # parameter posterior covariance
zeros(np),
zeros(np, np)
)

# variational laplace setup
vlsetup = VLSetup(
f, # function that computes the cross-spectral density at fixed point 'initcond'
y_csd, # empirical cross-spectral density
1e-1, # tolerance
[np, ny, nq, nh], # number of parameters, number of data points, number of Qs, number of hyperparameters
[μθ_pr, hyperpriors[:μλ_pr]], # parameter and hyperparameter prior mean
[inv(Σθ_pr), hyperpriors[:Πλ_pr]], # parameter and hyperparameter prior precision matrices
Q # components of data precision matrix
f, # function that computes the cross-spectral density at fixed point 'initcond'
y_csd, # empirical cross-spectral density
1e-1, # tolerance
[np, ny, nq, nh], # number of parameters, number of data points, number of Qs, number of hyperparameters
[μθ_pr, hyperpriors[:μλ_pr]], # parameter and hyperparameter prior mean
[inv(Σθ_pr), hyperpriors[:Πλ_pr]], # parameter and hyperparameter prior precision matrices
Q # components of data precision matrix
)
return (vlstate, vlsetup)
end
Expand Down
2 changes: 1 addition & 1 deletion src/measurementmodels/fmri.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ struct BalloonModel <: ObserverBlox
p = paramscoping(lnκ=lnκ, lnτ=lnτ, lnϵ=lnϵ) # finally compile all parameters
lnκ, lnτ, lnϵ = p # assign the modified parameters

sts = @variables s(t)=1.0 lnu(t)=0.0 lnν(t)=0.0 lnq(t)=0.0 bold(t) [irreducible=true, output=true, description="measurement"] jcn(t)=0.0 [input=true]
sts = @variables s(t)=1.0 lnu(t)=0.0 lnν(t)=0.0 lnq(t)=0.0 bold(t)=0.0 [irreducible=true, output=true, description="measurement"] jcn(t)=0.0 [input=true]

eqs = [
D(s) ~ jcn - H[1]*exp(lnκ)*s - H[2]*(exp(lnu) - 1),
Expand Down
Loading

0 comments on commit 28c9d6b

Please sign in to comment.