diff --git a/examples/RF_learning_using_BLOX.jl b/examples/RF_learning_using_BLOX.jl
index 1f197551..a3155f59 100644
--- a/examples/RF_learning_using_BLOX.jl
+++ b/examples/RF_learning_using_BLOX.jl
@@ -210,12 +210,18 @@ end
 # ╔═╡ b7e84b20-0b80-478c-bc88-2883f80bcbb4
 begin
+    # sys = Neuroblox.get_sys(agent)
+    # learning_rules = agent.learning_rules
+    # weights = Dict{Num, Float64}()
+    # for w in keys(learning_rules)
+    #     weights[w] = defs[w]
+    # end
-    for ii = 1:N_trials
         prob3 = agent.problem
         stim_params = Neuroblox.get_trial_stimulus(env)
-        prob3 = remake(prob3; p = merge(weights, stim_params), u0 = u0,tspan=(0,1600))
+        new_params = ModelingToolkit.MTKParameters(sys, merge(defs, weights, stim_params))
+        prob3 = remake(prob3; p = new_params, u0=u0, tspan=(0,env.t_trial))
         @info env.current_trial
         sol2 = solve(prob3, Vern7())
         agent.problem = prob3
diff --git a/src/Neuroblox.jl b/src/Neuroblox.jl
index 8f3aa517..4e08c815 100644
--- a/src/Neuroblox.jl
+++ b/src/Neuroblox.jl
@@ -44,6 +44,7 @@ using Symbolics: @register_symbolic, getdefaultval
 using DelimitedFiles: readdlm
 using CSV: read
 using DataFrames
+using JLD2
 
 using Peaks: argmaxima, peakproms!, peakheights!
diff --git a/src/blox/reinforcement_learning.jl b/src/blox/reinforcement_learning.jl
index 7d2044cf..7c6045dd 100644
--- a/src/blox/reinforcement_learning.jl
+++ b/src/blox/reinforcement_learning.jl
@@ -311,3 +311,81 @@ function run_trial!(agent::Agent, env::ClassificationEnvironment, weights::Dict{
         # u0 = sol[1:end,end]
     end
 end
+
+function run_experiment!(agent::Agent, env::ClassificationEnvironment, save_path::String, t_warmup=200.0; kwargs...)
+    N_trials = env.N_trials
+    t_trial = env.t_trial
+    tspan = (0, t_trial)
+    sys = get_sys(agent)
+    prob = agent.problem
+
+    if t_warmup > 0
+        # Warm-up simulation so that every trial starts from a settled state.
+        prob = remake(prob; tspan=(0, t_warmup))
+        if haskey(kwargs, :alg)
+            sol = solve(prob, kwargs[:alg]; kwargs...)
+        else
+            sol = solve(prob; alg_hints = [:stiff], kwargs...)
+        end
+        u0 = sol[1:end, end] # last value of the state vector
+        prob = remake(prob; tspan=tspan, u0=u0)
+    else
+        prob = remake(prob; tspan)
+        u0 = []
+    end
+
+    action_selection = agent.action_selection
+    learning_rules = agent.learning_rules
+
+    defs = ModelingToolkit.get_defaults(sys)
+    weights = Dict{Num, Float64}()
+    for w in keys(learning_rules)
+        weights[w] = defs[w]
+    end
+
+    for trial_num in Base.OneTo(N_trials)
+
+        stim_params = get_trial_stimulus(env)
+
+        # Rebuild the parameter object with the current weights and this trial's stimulus.
+        to_update = merge(weights, stim_params)
+        new_params = ModelingToolkit.MTKParameters(sys, merge(defs, to_update))
+
+        prob = remake(prob; p = new_params, u0=u0)
+        if haskey(kwargs, :alg)
+            sol = solve(prob, kwargs[:alg]; kwargs...)
+        else
+            sol = solve(prob; alg_hints = [:stiff], kwargs...)
+        end
+
+        # u0 = sol[1:end,end] # the next run would continue where the last one ended
+        # In the paper we assume a long enough interval before the next stimulus that the
+        # system returns to steady state, so we don't continue from the previous trial's endpoint.
+
+        if isnothing(action_selection)
+            feedback = 1
+        else
+            action = action_selection(sol)
+            feedback = env(action)
+        end
+
+        for (w, rule) in learning_rules
+            w_val = weights[w]
+            Δw = weight_gradient(rule, sol, w_val, feedback)
+            weights[w] += Δw
+        end
+        increment_trial!(env)
+
+        if !isnothing(save_path)
+            save_voltages(sol, save_path, trial_num)
+        end
+
+    end
+
+    agent.problem = prob
+end
+
+function save_voltages(sol, filepath, numtrial)
+    df = DataFrame(sol)
+    fname = "sim" * lpad(numtrial, 4, "0") * ".csv"
+    fullpath = joinpath(filepath, fname)
+    CSV.write(fullpath, df)
+end
\ No newline at end of file
diff --git a/test/reinforcement_learning_flattening.jl b/test/reinforcement_learning_flattening.jl
new file mode 100644
index 00000000..ce7b5d92
--- /dev/null
+++ b/test/reinforcement_learning_flattening.jl
@@ -0,0 +1,88 @@
+using Neuroblox
+using DifferentialEquations
+using Test
+using Graphs
+using MetaGraphs
+using DataFrames
+using CSV
+using ModelingToolkit: getp
+
+@testset "RL test with save" begin
+    t_trial = 2 # ms
+    time_block_dur = 0.01 # ms
+    N_trials = 3
+
+    global_ns = :g # global namespace
+    @named VAC = CorticalBlox(N_wta=3, N_exci=3, namespace=global_ns, density=0.1, weight=1)
+    @named PFC = CorticalBlox(N_wta=2, N_exci=3, namespace=global_ns, density=0.1, weight=1)
+    @named STR_L = Striatum(N_inhib=2, namespace=global_ns)
+    @named STR_R = Striatum(N_inhib=2, namespace=global_ns)
+    @named SNcb = SNc(namespace=global_ns, N_time_blocks=t_trial/time_block_dur)
+    @named TAN_pop = TAN(;namespace=global_ns)
+
+    @named AS = GreedyPolicy(namespace=global_ns, t_decision=0.31*t_trial)
+
+    fn = joinpath(@__DIR__, "../examples/image_example.csv")
+    data = CSV.read(fn, DataFrame)
+    @named stim = ImageStimulus(data[1:N_trials,:]; namespace=global_ns, t_stimulus=0.4*t_trial, t_pause=0.6*t_trial)
+
+    bloxs = [VAC, PFC, STR_L, STR_R, SNcb, TAN_pop, AS, stim]
+    d = Dict(b => i for (i,b) in enumerate(bloxs))
+
+    hebbian_mod = HebbianModulationPlasticity(K=0.2, decay=0.01, α=3, θₘ=1, modulator=SNcb, t_pre=t_trial, t_post=t_trial, t_mod=0.31*t_trial)
+    hebbian = HebbianPlasticity(K=0.2, W_lim=2, t_pre=t_trial, t_post=t_trial)
+
+    g = MetaDiGraph()
+    add_blox!.(Ref(g), bloxs)
+
+    add_edge!(g, d[stim], d[VAC], Dict(:weight => 1, :density => 0.1))
+    add_edge!(g, d[VAC], d[PFC], Dict(:weight => 1, :density => 0.1, :learning_rule => hebbian))
+    add_edge!(g, d[PFC], d[STR_L], Dict(:weight => 1, :density => 0.1, :learning_rule => hebbian_mod))
+    add_edge!(g, d[PFC], d[STR_R], Dict(:weight => 1, :density => 0.1, :learning_rule => hebbian_mod))
+    add_edge!(g, d[STR_R], d[STR_L], Dict(:weight => 1, :t_event => 0.3*t_trial))
+    add_edge!(g, d[STR_L], d[STR_R], Dict(:weight => 1, :t_event => 0.3*t_trial))
+    add_edge!(g, d[STR_L], d[SNcb], Dict(:weight => 1))
+    add_edge!(g, d[STR_R], d[SNcb], Dict(:weight => 1))
+    add_edge!(g, d[STR_L], d[AS])
+    add_edge!(g, d[STR_R], d[AS])
+    add_edge!(g, d[STR_L], d[TAN_pop], Dict(:weight => 1))
+    add_edge!(g, d[STR_R], d[TAN_pop], Dict(:weight => 1))
+    add_edge!(g, d[TAN_pop], d[STR_L], Dict(:weight => 1, :t_event => 0.1*t_trial))
+    add_edge!(g, d[TAN_pop], d[STR_R], Dict(:weight => 1, :t_event => 0.1*t_trial))
+
+    agent = Agent(g; name=:ag, t_block = t_trial/5);
+    ps = parameters(agent.odesystem)
+
+    map_idxs = Int.(ModelingToolkit.varmap_to_vars([ps[i] => i for i in eachindex(ps)], ps))
+    idxs_weight = findall(x -> occursin("w_", String(Symbol(x))), ps)
+    idx_stim = findall(x -> occursin("stim₊", String(Symbol(x))), ps)
+    idx_jcn = findall(x -> occursin("jcn", String(Symbol(x))), ps)
+    idx_spikes = findall(x -> occursin("spikes", String(Symbol(x))), ps)
+    idx_H = findall(x -> occursin("H", String(Symbol(x))), ps)
+    idx_I_bg = findall(x -> occursin("I_bg", String(Symbol(x))), ps)
+    idxs_other_params = setdiff(eachindex(ps), vcat(idxs_weight, idx_stim, idx_jcn, idx_spikes, idx_H, idx_I_bg))
+
+    params_at(idxs) = getp(agent.problem, parameters(agent.odesystem)[idxs])(agent.problem)
+    init_params_all = params_at(:)
+    init_params_idxs_weight = params_at(idxs_weight)
+    init_params_idxs_other_params = params_at(idxs_other_params)
+
+    env = ClassificationEnvironment(stim; name=:env, namespace=global_ns)
+    run_experiment!(agent, env, "./"; alg=Vern7(), reltol=1e-9, abstol=1e-9)
+
+    final_params = reduce(vcat, agent.problem.p)
+    # At least some weights need to be different.
+    @test any(init_params_idxs_weight .!= params_at(idxs_weight))
+    # @test any(init_params[map_idxs[idxs_weight]] .!= final_params[map_idxs[idxs_weight]])
+    # All non-weight parameters need to be the same.
+    @test all(init_params_idxs_other_params .== params_at(idxs_other_params))
+    # @test all(init_params[map_idxs[idxs_other_params]] .== final_params[map_idxs[idxs_other_params]])
+
+    reset!(agent)
+    @test all(init_params_all .== params_at(:))
+    @show setdiff(init_params_all, params_at(:))
+    reset!(env)
+    @test env.current_trial == 1
+end
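
For reference, a minimal sketch of how the per-trial voltage traces written by save_voltages could be read back for analysis. The directory and trial count below are placeholder values chosen to match the test above, not part of the patch.

using CSV, DataFrames

save_dir = "./"   # same directory passed to run_experiment! in the test (placeholder)
n_trials = 3      # same N_trials as in the test (placeholder)

# save_voltages names each file "sim" * lpad(trial, 4, "0") * ".csv"
trials = [CSV.read(joinpath(save_dir, "sim" * lpad(i, 4, "0") * ".csv"), DataFrame) for i in 1:n_trials]

# Each DataFrame holds one trial's solution: a time column plus one column per
# state variable, as produced by DataFrame(sol) inside save_voltages.
first(trials[1], 5)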