Skip to content

Commit

Permalink
Merge pull request #323 from Neuroblox/fix_striatal_learning
Browse files Browse the repository at this point in the history
striatal learning rule modified
  • Loading branch information
anandpathak31 authored Jan 6, 2024
2 parents d9b727c + 250c731 commit 75ead7b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
12 changes: 8 additions & 4 deletions src/blox/reinforcement_learning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ end
mutable struct HebbianModulationPlasticity <: AbstractLearningRule
const K
const decay
const α
const θₘ
state_pre
state_post
t_pre
Expand All @@ -42,11 +44,11 @@ mutable struct HebbianModulationPlasticity <: AbstractLearningRule
modulator

function HebbianModulationPlasticity(;
K, decay, modulator=nothing,
K, decay, α, θₘ, modulator=nothing,
state_pre=nothing, state_post=nothing,
t_pre=nothing, t_post=nothing, t_mod=nothing,
)
new(K, decay, state_pre, state_post, t_pre, t_post, t_mod, modulator)
new(K, decay, α, θₘ, state_pre, state_post, t_pre, t_post, t_mod, modulator)
end
end

Expand All @@ -55,8 +57,10 @@ dlogistic(x) = logistic(x) * (1 - logistic(x))
function (hmp::HebbianModulationPlasticity)(val_pre, val_post, val_modulator, w, feedback)
DA = hmp.modulator(val_modulator, feedback)
DA_baseline = hmp.modulator.κ_DA * hmp.modulator.N_time_blocks

Δw = hmp.K * val_post * val_pre * DA * (DA - DA_baseline) * dlogistic(DA) - hmp.decay * w
ϵ = feedback - (hmp.modulator.κ_DA - DA)

# Δw = hmp.K * val_post * val_pre * DA * (DA - DA_baseline) * dlogistic(DA) - hmp.decay * w
Δw = hmp.K * val_post * val_pre * ϵ *+ hmp.θₘ) * dlogistic(hmp.α *+ hmp.θₘ)) - hmp.decay * w

return Δw
end
Expand Down
2 changes: 1 addition & 1 deletion test/reinforcement_learning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ using CSV
bloxs = [VAC, PFC, STR_L, STR_R, SNcb, TAN_pop, AS, stim]
d = Dict(b => i for (i,b) in enumerate(bloxs))

hebbian_mod = HebbianModulationPlasticity(K=0.2, decay=0.01, modulator=SNcb, t_pre=t_trial, t_post=t_trial, t_mod=0.31*t_trial)
hebbian_mod = HebbianModulationPlasticity(K=0.2, decay=0.01, α=3, θₘ=1, modulator=SNcb, t_pre=t_trial, t_post=t_trial, t_mod=0.31*t_trial)
hebbian = HebbianPlasticity(K=0.2, W_lim=2, t_pre=t_trial, t_post=t_trial)

g = MetaDiGraph()
Expand Down

0 comments on commit 75ead7b

Please sign in to comment.