diff --git a/src/blox/reinforcement_learning.jl b/src/blox/reinforcement_learning.jl
index 32f06179..466866bf 100644
--- a/src/blox/reinforcement_learning.jl
+++ b/src/blox/reinforcement_learning.jl
@@ -34,6 +34,8 @@ end
 mutable struct HebbianModulationPlasticity <: AbstractLearningRule
     const K
     const decay
+    const α
+    const θₘ
     state_pre
     state_post
     t_pre
@@ -42,11 +44,11 @@ mutable struct HebbianModulationPlasticity <: AbstractLearningRule
     modulator
 
     function HebbianModulationPlasticity(;
-        K, decay, modulator=nothing,
+        K, decay, α, θₘ, modulator=nothing,
         state_pre=nothing, state_post=nothing,
         t_pre=nothing, t_post=nothing, t_mod=nothing,
     )
-        new(K, decay, state_pre, state_post, t_pre, t_post, t_mod, modulator)
+        new(K, decay, α, θₘ, state_pre, state_post, t_pre, t_post, t_mod, modulator)
     end
 end
 
@@ -55,8 +57,10 @@ dlogistic(x) = logistic(x) * (1 - logistic(x))
 function (hmp::HebbianModulationPlasticity)(val_pre, val_post, val_modulator, w, feedback)
     DA = hmp.modulator(val_modulator, feedback)
     DA_baseline = hmp.modulator.κ_DA * hmp.modulator.N_time_blocks
-
-    Δw = hmp.K * val_post * val_pre * DA * (DA - DA_baseline) * dlogistic(DA) - hmp.decay * w
+    ϵ = feedback - (hmp.modulator.κ_DA - DA)
+
+    # Δw = hmp.K * val_post * val_pre * DA * (DA - DA_baseline) * dlogistic(DA) - hmp.decay * w
+    Δw = hmp.K * val_post * val_pre * ϵ * (ϵ + hmp.θₘ) * dlogistic(hmp.α * (ϵ + hmp.θₘ)) - hmp.decay * w
 
     return Δw
 end
diff --git a/test/reinforcement_learning.jl b/test/reinforcement_learning.jl
index ba920f26..25e0e915 100644
--- a/test/reinforcement_learning.jl
+++ b/test/reinforcement_learning.jl
@@ -28,7 +28,7 @@ using CSV
     bloxs = [VAC, PFC, STR_L, STR_R, SNcb, TAN_pop, AS, stim]
     d = Dict(b => i for (i,b) in enumerate(bloxs))
 
-    hebbian_mod = HebbianModulationPlasticity(K=0.2, decay=0.01, modulator=SNcb, t_pre=t_trial, t_post=t_trial, t_mod=0.31*t_trial)
+    hebbian_mod = HebbianModulationPlasticity(K=0.2, decay=0.01, α=3, θₘ=1, modulator=SNcb, t_pre=t_trial, t_post=t_trial, t_mod=0.31*t_trial)
     hebbian = HebbianPlasticity(K=0.2, W_lim=2, t_pre=t_trial, t_post=t_trial)
 
     g = MetaDiGraph()