Merge pull request #386 from Neuroblox/agent-learning-hooks
Adds the ability to save per-trial simulation output from the agent learning loop
agchesebro authored Aug 26, 2024
2 parents a0f1ca0 + 34a9790 commit 1592089
Showing 4 changed files with 175 additions and 2 deletions.
10 changes: 8 additions & 2 deletions examples/RF_learning_using_BLOX.jl
@@ -210,12 +210,18 @@ end

# ╔═╡ b7e84b20-0b80-478c-bc88-2883f80bcbb4
begin
    # sys = Neuroblox.get_sys(agent)
    # learning_rules = agent.learning_rules
    # weights = Dict{Num, Float64}()
    # for w in keys(learning_rules)
    #     weights[w] = defs[w]
    # end


    for ii = 1:N_trials
        prob3 = agent.problem
        stim_params = Neuroblox.get_trial_stimulus(env)
        prob3 = remake(prob3; p = merge(weights, stim_params), u0 = u0, tspan=(0,1600))
        new_params = ModelingToolkit.MTKParameters(sys, merge(defs, weights, stim_params))
        prob3 = remake(prob3; p = new_params, u0=u0, tspan=(0, env.t_trial))
        @info env.current_trial
        sol2 = solve(prob3, Vern7())
        agent.problem = prob3
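For orientation, a minimal sketch of the weight-initialization pattern that the commented block above (and run_experiment! below) relies on; it assumes agent is an Agent whose connections carry learning rules, and is illustrative rather than part of the committed change:

sys = Neuroblox.get_sys(agent)
defs = ModelingToolkit.get_defaults(sys)   # default value of every parameter in the system
learning_rules = agent.learning_rules      # maps plastic weight parameters to their plasticity rules
weights = Dict{Num, Float64}()
for w in keys(learning_rules)
    weights[w] = defs[w]                   # seed each plastic weight at its default value
end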
1 change: 1 addition & 0 deletions src/Neuroblox.jl
@@ -44,6 +44,7 @@ using Symbolics: @register_symbolic, getdefaultval
using DelimitedFiles: readdlm
using CSV: read
using DataFrames
using JLD2

using Peaks: argmaxima, peakproms!, peakheights!

78 changes: 78 additions & 0 deletions src/blox/reinforcement_learning.jl
@@ -311,3 +311,81 @@ function run_trial!(agent::Agent, env::ClassificationEnvironment, weights::Dict{
# u0 = sol[1:end,end]
end
end

function run_experiment!(agent::Agent, env::ClassificationEnvironment, save_path::String, t_warmup=200.0; kwargs...)
    N_trials = env.N_trials
    t_trial = env.t_trial
    tspan = (0, t_trial)
    sys = get_sys(agent)
    prob = agent.problem

    if t_warmup > 0
        prob = remake(prob; tspan=(0,t_warmup))
        if haskey(kwargs, :alg)
            sol = solve(prob, kwargs[:alg]; kwargs...)
        else
            sol = solve(prob; alg_hints = [:stiff], kwargs...)
        end
        u0 = sol[1:end,end] # last value of state vector
        prob = remake(prob; tspan=tspan, u0=u0)
    else
        prob = remake(prob; tspan)
        u0 = []
    end

    action_selection = agent.action_selection
    learning_rules = agent.learning_rules

    defs = ModelingToolkit.get_defaults(sys)
    weights = Dict{Num, Float64}()
    for w in keys(learning_rules)
        weights[w] = defs[w]
    end

    for trial_num in Base.OneTo(N_trials)

        stim_params = get_trial_stimulus(env)

        to_update = merge(weights, stim_params)
        new_params = ModelingToolkit.MTKParameters(sys, merge(defs, weights, stim_params))

        prob = remake(prob; p = new_params, u0=u0)
        if haskey(kwargs, :alg)
            sol = solve(prob, kwargs[:alg]; kwargs...)
        else
            sol = solve(prob; alg_hints = [:stiff], kwargs...)
        end

        # u0 = sol[1:end,end] # next run should continue where the last one ended
        # In the paper we assume a sufficient time interval before the next stimulus so that
        # the system returns to steady state, so we don't continue from the previous trial's endpoint

        if isnothing(action_selection)
            feedback = 1
        else
            action = action_selection(sol)
            feedback = env(action)
        end

        for (w, rule) in learning_rules
            w_val = weights[w]
            Δw = weight_gradient(rule, sol, w_val, feedback)
            weights[w] += Δw
        end
        increment_trial!(env)

        if !isnothing(save_path)
            save_voltages(sol, save_path, trial_num)
        end

    end

    agent.problem = prob
end

function save_voltages(sol, filepath, numtrial)
    df = DataFrame(sol)
    fname = "sim"*lpad(numtrial, 4, "0")*".csv"
    fullpath = joinpath(filepath, fname)
    CSV.write(fullpath, df)
end
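A minimal usage sketch of the new entry point (assuming agent::Agent and env::ClassificationEnvironment have already been constructed, e.g. as in the test below); the solver and tolerances are illustrative choices, not requirements:

using DifferentialEquations

save_dir = mktempdir()   # any writable directory works
run_experiment!(agent, env, save_dir; alg=Vern7(), reltol=1e-9, abstol=1e-9)
# After the run, save_dir contains one CSV per trial (sim0001.csv, sim0002.csv, ...)
# and agent.problem holds the problem object from the final trial.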
88 changes: 88 additions & 0 deletions test/reinforcement_learning_flattening.jl
@@ -0,0 +1,88 @@
using Neuroblox
using DifferentialEquations
using Test
using Graphs
using MetaGraphs
using DataFrames
using CSV
using ModelingToolkit: getp

@testset "RL test with save" begin
t_trial = 2 # ms
time_block_dur = 0.01 # ms
N_trials = 3

global_ns = :g # global namespace
@named VAC = CorticalBlox(N_wta=3, N_exci=3, namespace=global_ns, density=0.1, weight=1)
@named PFC = CorticalBlox(N_wta=2, N_exci=3, namespace=global_ns, density=0.1, weight=1)
@named STR_L = Striatum(N_inhib=2, namespace=global_ns)
@named STR_R = Striatum(N_inhib=2, namespace=global_ns)
@named SNcb = SNc(namespace=global_ns, N_time_blocks=t_trial/time_block_dur)
@named TAN_pop = TAN(;namespace=global_ns)

@named AS = GreedyPolicy(namespace=global_ns, t_decision=0.31*t_trial)

fn = joinpath(@__DIR__, "../examples/image_example.csv")
data = CSV.read(fn, DataFrame)
@named stim = ImageStimulus(data[1:N_trials,:]; namespace=global_ns, t_stimulus=0.4*t_trial, t_pause=0.6*t_trial)

bloxs = [VAC, PFC, STR_L, STR_R, SNcb, TAN_pop, AS, stim]
d = Dict(b => i for (i,b) in enumerate(bloxs))

hebbian_mod = HebbianModulationPlasticity(K=0.2, decay=0.01, α=3, θₘ=1, modulator=SNcb, t_pre=t_trial, t_post=t_trial, t_mod=0.31*t_trial)
hebbian = HebbianPlasticity(K=0.2, W_lim=2, t_pre=t_trial, t_post=t_trial)

g = MetaDiGraph()
add_blox!.(Ref(g), bloxs)

add_edge!(g, d[stim], d[VAC], Dict(:weight => 1, :density => 0.1))
add_edge!(g, d[VAC], d[PFC], Dict(:weight => 1, :density => 0.1, :learning_rule => hebbian))
add_edge!(g, d[PFC], d[STR_L], Dict(:weight => 1, :density => 0.1, :learning_rule => hebbian_mod))
add_edge!(g, d[PFC], d[STR_R], Dict(:weight => 1, :density => 0.1, :learning_rule => hebbian_mod))
add_edge!(g, d[STR_R], d[STR_L], Dict(:weight => 1, :t_event => 0.3*t_trial))
add_edge!(g, d[STR_L], d[STR_R], Dict(:weight => 1, :t_event => 0.3*t_trial))
add_edge!(g, d[STR_L], d[SNcb], Dict(:weight => 1))
add_edge!(g, d[STR_R], d[SNcb], Dict(:weight => 1))
add_edge!(g, d[STR_L], d[AS])
add_edge!(g, d[STR_R], d[AS])
add_edge!(g, d[STR_L], d[TAN_pop], Dict(:weight => 1))
add_edge!(g, d[STR_R], d[TAN_pop], Dict(:weight => 1))
add_edge!(g, d[TAN_pop], d[STR_L], Dict(:weight => 1, :t_event => 0.1*t_trial))
add_edge!(g, d[TAN_pop], d[STR_R], Dict(:weight => 1, :t_event => 0.1*t_trial))

agent = Agent(g; name=:ag, t_block = t_trial/5);
ps = parameters(agent.odesystem)


map_idxs = Int.(ModelingToolkit.varmap_to_vars([ps[i] => i for i in eachindex(ps)], ps))
idxs_weight = findall(x -> occursin("w_", String(Symbol(x))), ps)
idx_stim = findall(x -> occursin("stim₊", String(Symbol(x))), ps)
idx_jcn = findall(x -> occursin("jcn", String(Symbol(x))), ps)
idx_spikes = findall(x -> occursin("spikes", String(Symbol(x))), ps)
idx_H = findall(x -> occursin("H", String(Symbol(x))), ps)
idx_I_bg = findall(x -> occursin("I_bg", String(Symbol(x))), ps)
idxs_other_params = setdiff(eachindex(ps), vcat(idxs_weight, idx_stim, idx_jcn, idx_spikes, idx_H, idx_I_bg))

params_at(idxs) = getp(agent.problem, parameters(agent.odesystem)[idxs])(agent.problem)
init_params_all = params_at(:)
init_params_idxs_weight = params_at(idxs_weight)
init_params_idxs_other_params = params_at(idxs_other_params)

env = ClassificationEnvironment(stim; name=:env, namespace=global_ns)
run_experiment!(agent, env, "./"; alg=Vern7(), reltol=1e-9,abstol=1e-9)

final_params = reduce(vcat, agent.problem.p)
# At least some weights need to be different.
@test any(init_params_idxs_weight .!= params_at(idxs_weight))
# @test any(init_params[map_idxs[idxs_weight]] .!= final_params[map_idxs[idxs_weight]])
# All non-weight parameters need to be the same.
@test all(init_params_idxs_other_params .== params_at(idxs_other_params))
# @test all(init_params[map_idxs[idxs_other_params]] .== final_params[map_idxs[idxs_other_params]])

reset!(agent)
@test all(init_params_all .== params_at(:))
@show setdiff(init_params_all, params_at(:))
reset!(env)
@test env.current_trial == 1
end
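As an illustrative follow-up (not part of the test), one way to inspect the per-trial CSVs that save_voltages writes during the run above; file names follow the sim0001.csv pattern produced by save_voltages:

using CSV, DataFrames

trial_files = sort(filter(f -> startswith(f, "sim") && endswith(f, ".csv"), readdir("./")))
@show trial_files            # expected here: ["sim0001.csv", "sim0002.csv", "sim0003.csv"]
df = CSV.read(joinpath("./", first(trial_files)), DataFrame)
@show size(df)               # rows = saved time points, columns = time plus state variables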
