Commit dd75db1

Deployed db11f5d with MkDocs version: 1.3.0

francesco-innocenti committed Nov 27, 2024
1 parent b354ae7 commit dd75db1

Showing 35 changed files with 427 additions and 1,536 deletions.
14 changes: 0 additions & 14 deletions 404.html
@@ -306,20 +306,6 @@
-            <li class="md-nav__item">
-              <a href="/extending_jpc/" class="md-nav__link">
-                Extending JPC
-              </a>
-            </li>
             </ul>
           </nav>
         </li>
14 changes: 0 additions & 14 deletions FAQs/index.html
@@ -306,20 +306,6 @@
-            <li class="md-nav__item">
-              <a href="../extending_jpc/" class="md-nav__link">
-                Extending JPC
-              </a>
-            </li>
             </ul>
           </nav>
         </li>
92 changes: 73 additions & 19 deletions advanced_usage/index.html
@@ -78,6 +78,11 @@
             <label class="md-overlay" for="__drawer"></label>
             <div data-md-component="skip">
+              <a href="#advanced-usage" class="md-skip">
+                Skip to content
+              </a>
             </div>
             <div data-md-component="announce">
@@ -306,6 +311,8 @@
             <input class="md-nav__toggle md-toggle" data-md-toggle="toc" type="checkbox" id="__toc">

             <a href="./" class="md-nav__link md-nav__link--active">
               Advanced usage
@@ -316,20 +323,6 @@
-            <li class="md-nav__item">
-              <a href="../extending_jpc/" class="md-nav__link">
-                Extending JPC
-              </a>
-            </li>
             </ul>
           </nav>
         </li>
@@ -719,9 +712,70 @@
-<h1>Advanced usage</h1>
+<h1 id="advanced-usage">Advanced usage<a class="headerlink" href="#advanced-usage" title="Permanent link">¤</a></h1>

Added page content:
Advanced users can access all the underlying functions of `jpc.make_pc_step` as
well as additional features. A custom PC training step looks like the following:

```python
import jpc

# 1. initialise activities with a feedforward pass
activities = jpc.init_activities_with_ffwd(model=model, input=x)

# 2. run inference to equilibrium
equilibrated_activities = jpc.solve_inference(
    params=(model, None),
    activities=activities,
    output=y,
    input=x
)

# 3. update parameters at the activities' solution with PC
param_update_result = jpc.update_params(
    params=(model, None),
    activities=equilibrated_activities,
    optim=param_optim,
    opt_state=param_opt_state,
    output=y,
    input=x
)

# updated model and optimiser
model = param_update_result["model"]
param_optim = param_update_result["optim"]
param_opt_state = param_update_result["opt_state"]
```
These steps can be embedded in a jitted function together with any other
computations (a sketch follows the next example). One can also use any Optax
optimiser to equilibrate the inference dynamics by replacing the function in
step 2, as shown below.
<div class="highlight"><pre><span></span><code><span class="n">activity_optim</span> <span class="o">=</span> <span class="n">optax</span><span class="o">.</span><span class="n">sgd</span><span class="p">(</span><span class="mf">1e-3</span><span class="p">)</span>

<span class="c1"># 1. initialise activities</span>
<span class="o">...</span>

<span class="c1"># 2. infer with gradient descent</span>
<span class="n">activity_opt_state</span> <span class="o">=</span> <span class="n">activity_optim</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">activities</span><span class="p">)</span>

<span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">T</span><span class="p">):</span>
<span class="n">activity_update_result</span> <span class="o">=</span> <span class="n">jpc</span><span class="o">.</span><span class="n">update_activities</span><span class="p">(</span>
<span class="n">params</span><span class="o">=</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="kc">None</span><span class="p">),</span>
<span class="n">activities</span><span class="o">=</span><span class="n">activities</span><span class="p">,</span>
<span class="n">optim</span><span class="o">=</span><span class="n">activity_optim</span><span class="p">,</span>
<span class="n">opt_state</span><span class="o">=</span><span class="n">activity_opt_state</span><span class="p">,</span>
<span class="n">output</span><span class="o">=</span><span class="n">y</span><span class="p">,</span>
<span class="nb">input</span><span class="o">=</span><span class="n">x</span>
<span class="p">)</span>
<span class="c1"># updated activities and optimiser</span>
<span class="n">activities</span> <span class="o">=</span> <span class="n">activity_update_result</span><span class="p">[</span><span class="s2">&quot;activities&quot;</span><span class="p">]</span>
<span class="n">activity_optim</span> <span class="o">=</span> <span class="n">activity_update_result</span><span class="p">[</span><span class="s2">&quot;optim&quot;</span><span class="p">]</span>
<span class="n">activity_opt_state</span> <span class="o">=</span> <span class="n">activity_update_result</span><span class="p">[</span><span class="s2">&quot;opt_state&quot;</span><span class="p">]</span>

<span class="c1"># 3. update parameters at the activities&#39; solution with PC</span>
<span class="o">...</span>
</code></pre></div>
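As noted above, the whole step can be wrapped in a single jitted training function. Below is a minimal sketch, assuming an Equinox model and an Optax parameter optimiser as in the JPC examples; the `train_step` name and its exact signature are illustrative, not part of the JPC API.

```python
import equinox as eqx
import jpc
import optax

param_optim = optax.adam(1e-3)

@eqx.filter_jit
def train_step(model, param_opt_state, x, y):
    # initialise activities, infer to equilibrium, then update parameters
    activities = jpc.init_activities_with_ffwd(model=model, input=x)
    equilibrated_activities = jpc.solve_inference(
        params=(model, None), activities=activities, output=y, input=x
    )
    result = jpc.update_params(
        params=(model, None),
        activities=equilibrated_activities,
        optim=param_optim,
        opt_state=param_opt_state,
        output=y,
        input=x
    )
    return result["model"], result["opt_state"]
```

Returning the updated model and optimiser state follows the usual Equinox/Optax pattern, so the loop over batches stays outside the compiled function.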
JPC also comes with some analytical tools that can be used to study and
potentially diagnose issues with PCNs (see the
[docs](https://thebuckleylab.github.io/jpc/api/Analytical%20tools/) and an
[example notebook](https://thebuckleylab.github.io/jpc/examples/linear_net_theoretical_energy/)).



@@ -757,13 +811,13 @@
-          <a href="../extending_jpc/" class="md-footer__link md-footer__link--next" aria-label="Next: Extending JPC" rel="next">
+          <a href="../examples/discriminative_pc/" class="md-footer__link md-footer__link--next" aria-label="Next: Discriminative PC" rel="next">
             <div class="md-footer__title">
               <div class="md-ellipsis">
                 <span class="md-footer__direction">
                   Next
                 </span>
-                Extending JPC
+                Discriminative PC
               </div>
             </div>
             <div class="md-footer__button md-icon">
14 changes: 0 additions & 14 deletions api/Analytical tools/index.html
@@ -311,20 +311,6 @@
-            <li class="md-nav__item">
-              <a href="../../extending_jpc/" class="md-nav__link">
-                Extending JPC
-              </a>
-            </li>
             </ul>
           </nav>
         </li>
14 changes: 0 additions & 14 deletions api/Energy functions/index.html
@@ -311,20 +311,6 @@
-            <li class="md-nav__item">
-              <a href="../../extending_jpc/" class="md-nav__link">
-                Extending JPC
-              </a>
-            </li>
             </ul>
           </nav>
         </li>
70 changes: 56 additions & 14 deletions api/Gradients/index.html
@@ -311,20 +311,6 @@
-            <li class="md-nav__item">
-              <a href="../../extending_jpc/" class="md-nav__link">
-                Extending JPC
-              </a>
-            </li>
             </ul>
           </nav>
         </li>
@@ -700,6 +686,13 @@
               neg_activity_grad()
             </a>
           </li>
+          <li class="md-nav__item">
+            <a href="#jpc.compute_activity_grad" class="md-nav__link">
+              compute_activity_grad()
+            </a>
+          </li>
           <li class="md-nav__item">
@@ -773,6 +766,13 @@
<h1 id="gradients">Gradients</h1>

Note: There are two similar functions for computing the activity gradient:
`jpc.neg_activity_grad` and `jpc.compute_activity_grad`. The first is used by
`jpc.solve_inference` as a gradient flow, while the second provides
compatibility with discrete Optax optimisers such as gradient descent.
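As a rough illustration of the distinction (not the JPC implementation; the toy energy and the variables `w`, `z`, `y` below are made up for this sketch), the two quantities differ only in sign: a gradient-flow solver integrates \(\mathrm{d}\mathbf{z}/\mathrm{d}t = -\partial \mathcal{F}/\partial \mathbf{z}\), whereas a discrete optimiser consumes \(\partial \mathcal{F}/\partial \mathbf{z}\) directly.

```python
import jax
import jax.numpy as jnp

def toy_energy(z, w, y):
    # toy single-layer free energy: squared prediction error (illustrative only)
    return 0.5 * jnp.sum((y - w @ z) ** 2)

w = jnp.eye(3)
z = jnp.ones(3)
y = jnp.zeros(3)

dFdz = jax.grad(toy_energy)(z, w, y)   # the gradient a discrete optimiser consumes
neg_dFdz = -dFdz                       # the negative gradient a gradient-flow solver integrates
```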


<div class="doc doc-object doc-function">
Expand Down Expand Up @@ -815,6 +815,48 @@ <h4 id="jpc.neg_activity_grad" class="doc doc-heading">



<h4 id="jpc.compute_activity_grad" class="doc doc-heading">
<code class="highlight language-python"><span class="n">jpc</span><span class="o">.</span><span class="n">compute_activity_grad</span><span class="p">(</span><span class="n">params</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">PyTree</span><span class="p">[</span><span class="n">Callable</span><span class="p">],</span> <span class="n">Optional</span><span class="p">[</span><span class="n">PyTree</span><span class="p">[</span><span class="n">Callable</span><span class="p">]]],</span> <span class="n">activities</span><span class="p">:</span> <span class="n">PyTree</span><span class="p">[</span><span class="n">ArrayLike</span><span class="p">],</span> <span class="n">y</span><span class="p">:</span> <span class="n">ArrayLike</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">ArrayLike</span><span class="p">],</span> <span class="n">loss_id</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;MSE&#39;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">PyTree</span><span class="p">[</span><span class="n">Array</span><span class="p">]</span></code>


<a href="#jpc.compute_activity_grad" class="headerlink" title="Permanent link">¤</a></h4>

<div class="doc doc-contents first">

<p>Computes the gradient of the energy with respect to the activities <span class="arithmatex">\(\partial \mathcal{F} / \partial \mathbf{z}\)</span>.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>This function differs from <code>neg_activity_grad</code>, which computes the
negative gradients, and is called in <code>update_activities</code> for use of
any Optax optimiser.</p>
</div>
<p><strong>Main arguments:</strong></p>
<ul>
<li><code>params</code>: Tuple with callable model layers and optional skip connections.</li>
<li><code>activities</code>: List of activities for each layer free to vary.</li>
<li><code>y</code>: Observation or target of the generative model.</li>
<li><code>x</code>: Optional prior of the generative model.</li>
</ul>
Other arguments:

- `loss_id`: Loss function for the output layer (mean squared error 'MSE' vs
  cross-entropy 'CE').
- `energy_fn`: Free energy to take the gradient of.

Returns:

List of gradients of the energy w.r.t. the activities.
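For illustration, a hypothetical manual equivalent of what `update_activities` does with these gradients, assuming `model`, `x`, `y`, and `activities` are set up as in the advanced-usage example; the variable names are illustrative.

```python
import jpc
import optax

activity_optim = optax.sgd(1e-3)
activity_opt_state = activity_optim.init(activities)

# one gradient-descent step on the activities using the energy gradients
grads = jpc.compute_activity_grad(
    params=(model, None), activities=activities, y=y, x=x
)
updates, activity_opt_state = activity_optim.update(
    grads, activity_opt_state, activities
)
activities = optax.apply_updates(activities, updates)
```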

---

jpc.compute_pc_param_grads(
    params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]],
    activities: PyTree[ArrayLike],
    y: ArrayLike,
    x: Optional[ArrayLike] = None,
    loss_id: str = 'MSE'
) -> Tuple[PyTree[Array], PyTree[Array]]
14 changes: 0 additions & 14 deletions api/Inference/index.html
@@ -311,20 +311,6 @@
-            <li class="md-nav__item">
-              <a href="../../extending_jpc/" class="md-nav__link">
-                Extending JPC
-              </a>
-            </li>
             </ul>
           </nav>
         </li>
19 changes: 5 additions & 14 deletions api/Initialisation/index.html
@@ -311,20 +311,6 @@
-            <li class="md-nav__item">
-              <a href="../../extending_jpc/" class="md-nav__link">
-                Extending JPC
-              </a>
-            </li>
             </ul>
           </nav>
         </li>
@@ -773,6 +759,11 @@
<h1 id="initialisation">Initialisation</h1>

Info: JPC provides three standard ways of initialising the activities: with a
feedforward pass, randomly, or using an amortised network.


<div class="doc doc-object doc-function">
Expand Down
14 changes: 0 additions & 14 deletions api/Testing/index.html
@@ -311,20 +311,6 @@
-            <li class="md-nav__item">
-              <a href="../../extending_jpc/" class="md-nav__link">
-                Extending JPC
-              </a>
-            </li>
             </ul>
           </nav>
         </li>