<?xml version="1.0" encoding="UTF-8"?>
<rss  xmlns:atom="http://www.w3.org/2005/Atom" 
      xmlns:media="http://search.yahoo.com/mrss/" 
      xmlns:content="http://purl.org/rss/1.0/modules/content/" 
      xmlns:dc="http://purl.org/dc/elements/1.1/" 
      version="2.0">
<channel>
<title>Greg&#39;s blog</title>
<link>https://gcerar.github.io/blog.html</link>
<atom:link href="https://gcerar.github.io/blog.xml" rel="self" type="application/rss+xml"/>
<description>Greg&#39;s personal blog</description>
<generator>quarto-1.9.37</generator>
<lastBuildDate>Tue, 16 Dec 2025 23:00:00 GMT</lastBuildDate>
<item>
  <title>Reinforcement Learning: Tabular Q-Learning</title>
  <dc:creator>Gregor Cerar</dc:creator>
  <link>https://gcerar.github.io/posts/reinforcement-learning/2025-12-17-tabular-q-learning.html</link>
  <description><![CDATA[ 





<p>As I started exploring reinforcement learning, a colleague suggested that I start with <strong>Q-learning</strong>, one of the simplest and most widely used algorithms in the field. To get a hands-on feel for the fundamentals, I decided to replicate the official <a href="https://gymnasium.farama.org/tutorials/training_agents/frozenlake_q_learning/">Solving Frozenlake with Tabular Q-Learning</a> tutorial from the <a href="https://gymnasium.farama.org/">Gymnasium docs</a>.</p>
<p>This post captures that learning journey: walking through the environment, understanding the Q-learning steps, and getting comfortable with the Gymnasium library along the way.</p>
<section id="the-frozen-lake-environment" class="level1">
<h1>The Frozen Lake Environment</h1>
<p>Frozen Lake is a small, grid-based reinforcement learning environment. We play as an elf whose goal is to cross a frozen lake from the starting tile (top-left corner) to a present (bottom-right corner) without falling into any holes along the way.</p>
<p>To make the task more interesting (and more realistic), the lake can be <strong>slippery</strong>. When slipperiness is enabled, the elf does not always move in the intended direction: it moves as intended with probability 1/3 and otherwise slips into one of the two perpendicular directions.</p>
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/reinforcement-learning/frozenlake-sample.png" class="img-fluid figure-img"></p>
<figcaption>Frozen Lake (3x3) sample</figcaption>
</figure>
</div>
<section id="action-space" class="level2">
<h2 class="anchored" data-anchor-id="action-space">Action Space</h2>
<p>The action space is simple and discrete. At each timestep, the agent can choose one of four actions:</p>
<ul>
<li>move left (0),</li>
<li>move down (1),</li>
<li>move right (2),</li>
<li>move up (3).</li>
</ul>
<p>Formally, the action space is <code>Discrete(4)</code>: each action is a single integer taken from the set <img src="https://latex.codecogs.com/png.latex?%5C%7B%200,%201,%202,%203%20%5C%7D">.</p>
<p>Discrete action spaces like this are a general abstraction used throughout Gymnasium. In more complex environments, a single action may encode multiple simultaneous commands. For example, in a game like <em>Super Mario</em>, a player can jump while moving left or right. Such combinations can still be encoded as a single discrete action.</p>
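<p>To make the integer encoding concrete, here is a minimal sketch that maps each action to a name and a grid displacement. It assumes Gymnasium's convention (0 = left, 1 = down, 2 = right, 3 = up) and that bumping into the edge of the grid leaves the elf in place; <code>ACTION_NAMES</code>, <code>ACTION_DELTAS</code> and <code>move()</code> are hypothetical helpers, not Gymnasium API.</p>

```python
# Sketch of FrozenLake's action encoding, assuming 0 = left, 1 = down,
# 2 = right, 3 = up. ACTION_NAMES, ACTION_DELTAS and move() are
# illustrative helpers, not part of Gymnasium.

ACTION_NAMES = {0: "left", 1: "down", 2: "right", 3: "up"}

# (row change, column change) for each action; row 0 is the top of the grid.
ACTION_DELTAS = {0: (0, -1), 1: (1, 0), 2: (0, 1), 3: (-1, 0)}


def move(row: int, col: int, action: int, n_rows: int, n_cols: int) -> tuple[int, int]:
    """Apply a (non-slippery) action; moving into a wall keeps the agent in place."""
    d_row, d_col = ACTION_DELTAS[action]
    new_row = min(max(row + d_row, 0), n_rows - 1)
    new_col = min(max(col + d_col, 0), n_cols - 1)
    return new_row, new_col


# From the top-left tile of a 3x3 grid, "right" moves one column over,
# while "up" bumps into the wall and leaves the agent where it was.
print(ACTION_NAMES[2], move(0, 0, 2, 3, 3))  # right (0, 1)
print(ACTION_NAMES[3], move(0, 0, 3, 3, 3))  # up (0, 0)
```

<p>Clamping the coordinates with <code>min</code>/<code>max</code> is one simple way to express "walls block movement" without special-casing each edge.</p>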
</section>
<section id="observation-space" class="level2">
<h2 class="anchored" data-anchor-id="observation-space">Observation Space</h2>
<p>The observation returned by the environment represents the agent’s <strong>current position</strong> on the grid. Since Frozen Lake consists of a finite number of discrete tiles, each tile is assigned a unique integer identifier.</p>
<p>For example, a <img src="https://latex.codecogs.com/png.latex?3%5Ctimes%7B%7D3"> grid is indexed as:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Bmatrix%7D%0A0%20&amp;%201%20&amp;%202%20%5C%5C%0A3%20&amp;%204%20&amp;%205%20%5C%5C%0A6%20&amp;%207%20&amp;%208%0A%5Cend%7Bmatrix%7D%0A"></p>
<p>More generally, the tile index can be computed as:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Ctext%7Btile%7D(r,c)%20=%20r%20%5Ccdot%20N_%7B%5Ctext%7Bcols%7D%7D%20+%20c,%0A"></p>
<p>where:</p>
<ul>
<li><img src="https://latex.codecogs.com/png.latex?r"> is the row index,</li>
<li><img src="https://latex.codecogs.com/png.latex?c"> is the column index,</li>
<li><img src="https://latex.codecogs.com/png.latex?N_%7B%5Ctext%7Bcols%7D%7D"> is the number of columns in the grid.</li>
</ul>
<p>This discrete state representation makes Frozen Lake particularly well-suited for tabular methods such as Q-learning.</p>
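<p>The tile-index formula and its inverse fit in a couple of lines of Python. This is a sketch for illustration; <code>to_state()</code> and <code>to_row_col()</code> are hypothetical names, and the inverse relies on <code>divmod</code> returning the quotient (row) and remainder (column).</p>

```python
# Sketch of tile(r, c) = r * N_cols + c and its inverse.
# to_state() and to_row_col() are hypothetical helper names.

def to_state(row: int, col: int, n_cols: int) -> int:
    """Flatten a (row, col) grid position into a single state index."""
    return row * n_cols + col


def to_row_col(state: int, n_cols: int) -> tuple[int, int]:
    """Recover (row, col) from a state index via integer division and remainder."""
    return divmod(state, n_cols)


# On the 3x3 grid above, the centre tile (row 1, column 1) is state 4,
# and state 7 sits in the bottom row, middle column.
print(to_state(1, 1, 3))  # 4
print(to_row_col(7, 3))   # (2, 1)
```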
</section>
<section id="rewards" class="level2">
<h2 class="anchored" data-anchor-id="rewards">Rewards</h2>
<p>The default reward structure is sparse:</p>
<ul>
<li><img src="https://latex.codecogs.com/png.latex?+1"> for reaching the goal tile,</li>
<li><img src="https://latex.codecogs.com/png.latex?0"> for stepping onto a frozen tile,</li>
<li><img src="https://latex.codecogs.com/png.latex?0"> for falling into a hole.</li>
</ul>
<p>In other words, the agent receives a reward only when it successfully reaches the goal. This sparse reward setting makes the problem deceptively challenging and highlights the importance of exploration in reinforcement learning.</p>
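<p>To get a feel for how sparse this reward is, the following sketch plays the lake with uniformly random actions and measures how often an episode ends with a reward of 1. It assumes the default 4x4 layout as I recall it from Gymnasium's map presets, non-slippery moves, and a cap of 100 steps per episode; <code>LAKE</code> and <code>random_episode()</code> are hypothetical names, and the exact success rate depends on the seed.</p>

```python
# Sketch: how often does uniformly random play reach the goal?
# Assumptions: the 4x4 layout below (S = start, F = frozen, H = hole,
# G = goal), non-slippery moves, and at most 100 steps per episode.
import random

LAKE = ["SFFF", "FHFH", "FFFH", "HFFG"]
DELTAS = [(0, -1), (1, 0), (0, 1), (-1, 0)]  # left, down, right, up


def random_episode(rng: random.Random, max_steps: int = 100) -> float:
    """Play one episode with random actions and return its (sparse) reward."""
    n_rows, n_cols = len(LAKE), len(LAKE[0])
    row, col = 0, 0  # start in the top-left corner
    for _ in range(max_steps):
        d_row, d_col = rng.choice(DELTAS)
        row = min(max(row + d_row, 0), n_rows - 1)  # walls block movement
        col = min(max(col + d_col, 0), n_cols - 1)
        tile = LAKE[row][col]
        if tile == "G":
            return 1.0  # the only non-zero reward in the environment
        if tile == "H":
            return 0.0  # falling into a hole ends the episode with nothing
    return 0.0  # ran out of steps without reaching the goal


rng = random.Random(0)
n_episodes = 10_000
successes = sum(random_episode(rng) for _ in range(n_episodes))
print(f"random policy success rate: {successes / n_episodes:.3f}")
```

<p>Most random episodes end in a hole, so early in training the agent sees very few non-zero rewards to learn from. This is why the exploration strategy matters so much here.</p>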
<p>For full details, see the official <a href="https://gymnasium.farama.org/environments/toy_text/frozen_lake/">Frozen Lake environment documentation</a>.</p>
</section>
</section>
<section id="reinforcement-learning-formulation" class="level1">
<h1>Reinforcement Learning Formulation</h1>
<p>Frozen Lake can be formalized as a <strong>finite Markov Decision Process (MDP)</strong> defined by the tuple <img src="https://latex.codecogs.com/png.latex?(%5Cmathcal%7BS%7D,%20%5Cmathcal%7BA%7D,%20%5Cmathcal%7BP%7D,%20%5Cmathcal%7BR%7D,%20%5Cgamma)">.</p>
<section id="state-space-mathcals" class="level2">
<h2 class="anchored" data-anchor-id="state-space-mathcals">State Space <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BS%7D"></h2>
<p>The state space consists of all discrete tiles on the grid:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Cmathcal%7BS%7D%20=%20%5C%7B%200,%201,%20%5Cdots,%20(N_%5Ctext%7Brows%7D%5Ccdot%7B%7DN_%5Ctext%7Bcols%7D-1)%20%5C%7D.%0A"></p>
<p>Each state uniquely represents the agent’s current position in the lake.</p>
</section>
<section id="action-space-mathcala" class="level2">
<h2 class="anchored" data-anchor-id="action-space-mathcala">Action Space <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BA%7D"></h2>
<p>At each time step, the agent can choose one of four actions:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Cmathcal%7BA%7D%20=%20%5C%7B%20%5Ctext%7Bleft%7D,%20%5Ctext%7Bdown%7D,%20%5Ctext%7Bright%7D,%20%5Ctext%7Bup%7D%20%5C%7D.%0A"></p>
<p>These actions describe the agent's intended direction of movement; when the lake is slippery, the actual transition may still be stochastic.</p>
</section>
<section id="transition-dynamics-pssa" class="level2">
<h2 class="anchored" data-anchor-id="transition-dynamics-pssa">Transition Dynamics <img src="https://latex.codecogs.com/png.latex?P(s'%7Cs,a)"></h2>
<p>The transition function defines the probability of moving from state <img src="https://latex.codecogs.com/png.latex?s"> to state <img src="https://latex.codecogs.com/png.latex?s'"> after taking action <img src="https://latex.codecogs.com/png.latex?a">.</p>
<ul>
<li>In the <strong>non-slippery</strong> version of the environment, transitions are deterministic.</li>
<li>In the <strong>slippery</strong> version, the intended action may fail: the agent moves in the intended direction with probability 1/3 and in each of the two perpendicular directions with probability 1/3.</li>
</ul>
<p>This stochasticity makes Frozen Lake a useful testbed for algorithms that must learn under uncertainty.</p>
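<p>The transition function can be written out explicitly for a small map. This is a sketch under stated assumptions rather than Gymnasium's own code: actions 0 to 3 are left, down, right, up; with slipping enabled the agent takes the intended direction or one of the two perpendicular ones, each with probability 1/3; walls keep the agent in place; and holes and the goal are treated as absorbing. <code>transitions()</code> and the 3x3 <code>lake</code> are hypothetical.</p>

```python
# Sketch of P(s' | s, a) on a small lake.
# Assumptions: actions 0..3 = left, down, right, up; slipping picks the
# intended or a perpendicular direction with probability 1/3 each;
# holes (H) and the goal (G) are absorbing. Not the Gymnasium API.
from collections import defaultdict

DELTAS = [(0, -1), (1, 0), (0, 1), (-1, 0)]  # left, down, right, up


def transitions(lake: list[str], state: int, action: int, slippery: bool) -> dict[int, float]:
    """Return {next_state: probability} for taking `action` in `state`."""
    n_rows, n_cols = len(lake), len(lake[0])
    row, col = divmod(state, n_cols)
    if lake[row][col] in "HG":  # terminal tile: the episode is already over
        return {state: 1.0}

    # With this encoding, (a - 1) % 4 and (a + 1) % 4 are perpendicular to a.
    directions = [(action - 1) % 4, action, (action + 1) % 4] if slippery else [action]
    probs = defaultdict(float)
    for d in directions:
        d_row, d_col = DELTAS[d]
        new_row = min(max(row + d_row, 0), n_rows - 1)  # walls keep the agent in place
        new_col = min(max(col + d_col, 0), n_cols - 1)
        probs[new_row * n_cols + new_col] += 1.0 / len(directions)
    return dict(probs)


lake = ["SFF", "FHF", "FFG"]  # hypothetical 3x3 layout
print(transitions(lake, 0, 2, slippery=False))  # {1: 1.0}
print(transitions(lake, 0, 2, slippery=True))   # right, down, or bump into the top wall
```

<p>Note how slipping from the start tile while trying to go right can also leave the agent where it is: the "up" outcome hits the wall.</p>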
</section>
<section id="reward-function-rs-a-s" class="level2">
<h2 class="anchored" data-anchor-id="reward-function-rs-a-s">Reward Function <img src="https://latex.codecogs.com/png.latex?R(s,%20a,%20s')"></h2>
<p>The reward function is sparse and simple:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0AR(s,%20a,%20s')%20=%20%5Cbegin%7Bcases%7D%0A%20%20%20%201,%20&amp;%20%5Ctext%7Bif%20$s'$%20is%20the%20goal%20state%7D,%5C%5C%0A%20%20%20%200,%20&amp;%20%5Ctext%7Botherwise%7D.%0A%5Cend%7Bcases%7D%0A"></p>
<p>Episodes terminate when the agent reaches the goal or falls into a hole.</p>
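<p>Because the reward depends only on the tile the agent lands on, reward and termination can be read straight off the map. A minimal sketch, assuming the map is given as rows of <code>S</code>/<code>F</code>/<code>H</code>/<code>G</code> characters; <code>step_outcome()</code> is a hypothetical helper name.</p>

```python
# Sketch of the sparse reward R(s, a, s') and the termination rule.
# Assumption: the lake is rows of S/F/H/G characters; step_outcome() is hypothetical.

def step_outcome(lake: list[str], next_state: int) -> tuple[float, bool]:
    """Return (reward, terminated) for landing on `next_state`."""
    n_cols = len(lake[0])
    tile = lake[next_state // n_cols][next_state % n_cols]
    reward = 1.0 if tile == "G" else 0.0  # only the goal pays out
    terminated = tile in ("G", "H")       # goal or hole ends the episode
    return reward, terminated


lake = ["SFF", "FHF", "FFG"]  # hypothetical 3x3 layout
print(step_outcome(lake, 1))  # (0.0, False)  frozen tile
print(step_outcome(lake, 4))  # (0.0, True)   hole
print(step_outcome(lake, 8))  # (1.0, True)   goal
```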
<div id="7b7582c1" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> os</span>
<span id="cb1-2"></span>
<span id="cb1-3"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># get rid of the audio warnings</span></span>
<span id="cb1-4">os.environ[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"SDL_AUDIODRIVER"</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"dummy"</span></span></code></pre></div></div>
</div>
<div id="2344f8c3" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> dataclasses <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> dataclass</span>
<span id="cb2-2"></span>
<span id="cb2-3"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> gymnasium <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> gym</span>
<span id="cb2-4"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> matplotlib.pyplot <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> plt</span>
<span id="cb2-5"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> numpy <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> np</span>
<span id="cb2-6"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> pandas <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> pd</span>
<span id="cb2-7"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> seaborn <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> sns</span>
<span id="cb2-8"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> aquarel <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> load_theme</span>
<span id="cb2-9"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> gymnasium.envs.toy_text.frozen_lake <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> generate_random_map</span>
<span id="cb2-10"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> tqdm.auto <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> tqdm, trange</span>
<span id="cb2-11"></span>
<span id="cb2-12"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span>config InlineBackend.figure_formats <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'retina'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'png'</span>}</span></code></pre></div></div>
</div>
<div id="229571f8" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">@dataclass</span>(frozen<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, slots<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb3-2"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">class</span> Params:</span>
<span id="cb3-3">    n_runs: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">20</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># number of runs from scratch</span></span>
<span id="cb3-4">    total_episodes: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2_000</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># total episodes (# of playthroughs) in the same run</span></span>
<span id="cb3-5">    learning_rate: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.8</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Q-Learning learning rate</span></span>
<span id="cb3-6">    gamma: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.95</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># discounting rate</span></span>
<span id="cb3-7">    epsilon: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.1</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
# probability of exploration">
font-style: inherit;"># probability of taking a random (exploratory) action</span></span>
<span id="cb3-8">    proba_frozen: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.9</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># probability that a tile is frozen (not a hole)</span></span>
<span id="cb3-9">    is_slippery: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">bool</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># enables slipping: 1/3 forward, 1/3 left, 1/3 right</span></span>
<span id="cb3-10">    seed: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">123</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
# seed for reproducability">
font-style: inherit;"># seed for reproducibility</span></span></code></pre></div></div>
</div>
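<p>The <code>gamma</code> parameter discounts future rewards: a reward received <em>k</em> steps in the future is worth <em>gamma</em><sup><em>k</em></sup> now. Since Frozen Lake only pays out at the goal, discounting is what makes shorter routes more valuable than longer ones. A quick check, assuming the discounted return is counted from the first step and using two hypothetical routes of 6 and 12 steps; <code>discounted_return()</code> is a hypothetical helper.</p>

```python
# Discounted return G = sum_t gamma**t * r_(t+1) for a sparse-reward episode.
# Assumption: the only non-zero reward (1.0) arrives on the final step.

def discounted_return(rewards: list[float], gamma: float) -> float:
    """Compute the discounted return of a reward sequence, starting at t = 0."""
    return sum(gamma**t * r for t, r in enumerate(rewards))


gamma = 0.95                    # same value as the Params default
short_path = [0.0] * 5 + [1.0]  # goal reached on the 6th step
long_path = [0.0] * 11 + [1.0]  # goal reached on the 12th step
print(round(discounted_return(short_path, gamma), 4))  # 0.7738
print(round(discounted_return(long_path, gamma), 4))   # 0.5688
```

<p>Both routes collect the same undiscounted reward, but with <code>gamma = 0.95</code> the shorter one is worth noticeably more.</p>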
<div id="e9e1d9c2" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1">SHOW_PROGRESS: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">bool</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span></span></code></pre></div></div>
</div>
</section>
</section>
<section id="the-implementation" class="level1">
<h1>The Implementation</h1>
<div id="c90869fb" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">class</span> Qlearning:</span>
<span id="cb5-2">    qtable: np.ndarray</span>
<span id="cb5-3"></span>
<span id="cb5-4">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, lr: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>, gamma: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>, state_size: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>, action_size: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb5-5">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> lr</span>
<span id="cb5-6">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.gamma <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> gamma</span>
<span id="cb5-7">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.state_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> state_size</span>
<span id="cb5-8">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.action_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> action_size</span>
<span id="cb5-9">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.reset_qtable()</span>
<span id="cb5-10"></span>
<span id="cb5-11">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> update(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, state: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>, action: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>, reward: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>, new_state: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>:</span>
<span id="cb5-12">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Update Q(s,a) := Q(s,a) + lr * [R(s,a) + gamma * max Q(s',a') - Q(s,a)]"""</span></span>
<span id="cb5-13">        delta <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> reward <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.gamma <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> np.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.qtable[new_state, :]) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.qtable[state, action]</span>
<span id="cb5-14"></span>
<span id="cb5-15">        q_update <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.qtable[state, action] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> delta</span>
<span id="cb5-16">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> q_update</span>
<span id="cb5-17"></span>
<span id="cb5-18">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> reset_qtable(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb5-19">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Reset the Q-table."""</span></span>
<span id="cb5-20">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.qtable <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.zeros((<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.state_size, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.action_size))</span>
<span id="cb5-21"></span>
<span id="cb5-22"></span>
<span id="cb5-23"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">class</span> EpsilonGreedy:</span>
<span id="cb5-24">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, epsilon: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>, seed: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb5-25">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.eps <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> epsilon</span>
<span id="cb5-26">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.rng <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.random.default_rng(seed)</span>
<span id="cb5-27"></span>
<span id="cb5-28">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> choose_action(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, action_space: gym.spaces.Space, state: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>, qtable: np.ndarray) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>:</span>
<span id="cb5-29">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Choose an action `a` in the current world state (s)."""</span></span>
<span id="cb5-30">        action: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span></span>
<span id="cb5-31"></span>
<span id="cb5-32">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># random number decides whether we do ...</span></span>
<span id="cb5-33">        explore_exploit_tradeoff <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.rng.uniform(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb5-34">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> explore_exploit_tradeoff <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.eps:  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># ... exploration (random action) ...</span></span>
<span id="cb5-35">            action <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> action_space.sample()</span>
<span id="cb5-36">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span>:  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># ... or exploitation (use direction with the biggest Q-value for this state)</span></span>
<span id="cb5-37">            (max_ids,) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.where(qtable[state, :] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>(qtable[state, :]))</span>
<span id="cb5-38">            action <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.rng.choice(max_ids)  <span class="co" style="color: #5E5E5E;
background-color: null;
# pick one if multiple">
font-style: inherit;"># break ties randomly among actions sharing the maximal Q-value</span></span>
<span id="cb5-39"></span>
<span id="cb5-40">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> action</span></code></pre></div></div>
</div>
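<p>To see how the update rule in <code>Qlearning.update</code> spreads the goal reward backwards, here is a hand-worked example in plain Python. It assumes a tiny hypothetical chain of states 0, 1, 2 (where 2 is the goal), two actions per state, and the <code>Params</code> defaults <code>lr = 0.8</code> and <code>gamma = 0.95</code>; <code>q_update()</code> is a stand-in that mirrors the method's formula, not the class itself.</p>

```python
# Worked example of Q(s,a) <- Q(s,a) + lr * [r + gamma * max_a' Q(s',a') - Q(s,a)].
# Assumptions: a hypothetical chain 0 -> 1 -> 2 (goal), two actions per state,
# lr = 0.8 and gamma = 0.95 as in Params.

def q_update(q: list[list[float]], s: int, a: int, r: float, s_next: int,
             lr: float, gamma: float) -> float:
    """Return the updated Q(s, a) without writing it back (like Qlearning.update)."""
    td_target = r + gamma * max(q[s_next])
    return q[s][a] + lr * (td_target - q[s][a])


lr, gamma = 0.8, 0.95
q = [[0.0, 0.0] for _ in range(3)]  # 3 states x 2 actions, all zeros

# Episode 1: the step 1 -> 2 reaches the goal and earns reward 1.
q[1][0] = q_update(q, s=1, a=0, r=1.0, s_next=2, lr=lr, gamma=gamma)
print(round(q[1][0], 3))  # 0.8

# Episode 2: the step 0 -> 1 earns no reward, but state 1 now has value 0.8,
# so part of that value propagates back to state 0: 0.8 * (0.95 * 0.8) = 0.608.
q[0][0] = q_update(q, s=0, a=0, r=0.0, s_next=1, lr=lr, gamma=gamma)
print(round(q[0][0], 3))  # 0.608
```

<p>The goal state's Q-values are never updated because no action is ever taken from a terminal state, so they stay at zero. That is why the update needs no special case for terminal transitions: <code>max Q(s', a')</code> is simply 0 there.</p>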
</section>
<section id="define-training-loop" class="level1">
<h1>Define Training Loop</h1>
<div id="93d8076b" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> run_env(env: gym.Env, learner: Qlearning, explorer: EpsilonGreedy, p: Params, state_size: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>, action_size: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>):</span>
<span id="cb6-2">    rewards <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.zeros((p.total_episodes, p.n_runs), dtype<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>)</span>
<span id="cb6-3">    steps <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.zeros((p.total_episodes, p.n_runs), dtype<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>)</span>
<span id="cb6-4">    episodes <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.arange(p.total_episodes, dtype<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>)</span>
<span id="cb6-5"></span>
<span id="cb6-6">    qtables <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.zeros((p.n_runs, state_size, action_size), dtype<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>)</span>
<span id="cb6-7"></span>
<span id="cb6-8">    all_states: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> []</span>
<span id="cb6-9">    all_actions: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> []</span>
<span id="cb6-10"></span>
<span id="cb6-11">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> run <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> trange(p.n_runs, leave<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>, disable<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">not</span> SHOW_PROGRESS)):</span>
<span id="cb6-12">        learner.reset_qtable()</span>
<span id="cb6-13"></span>
<span id="cb6-14">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> episode <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> tqdm(episodes, leave<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>, disable<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">not</span> SHOW_PROGRESS)):</span>
<span id="cb6-15">            state, _ <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> env.reset(seed<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>p.seed)</span>
<span id="cb6-16">            step: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span>
<span id="cb6-17">            done: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">bool</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span></span>
<span id="cb6-18">            total_rewards: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.0</span></span>
<span id="cb6-19"></span>
<span id="cb6-20">            <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">while</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">not</span> done:</span>
<span id="cb6-21">                action <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> explorer.choose_action(action_space<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>env.action_space, state<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>state, qtable<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>learner.qtable)</span>
<span id="cb6-22"></span>
<span id="cb6-23">                <span class="co" style="color: #5E5E5E;
background-color: null;
# log all the stats">
font-style: inherit;"># log all the states and actions</span></span>
<span id="cb6-24">                all_states.append(state)</span>
<span id="cb6-25">                all_actions.append(action)</span>
<span id="cb6-26"></span>
<span id="cb6-27">                <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># take the action $a$ and observe the outcome state $s'$ and reward $r$</span></span>
<span id="cb6-28">                new_state, reward, terminated, truncated, info <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> env.step(action)</span>
<span id="cb6-29"></span>
<span id="cb6-30">                <span class="co" style="color: #5E5E5E;
background-color: null;
# mark as done whether">
font-style: inherit;"># done if the episode was terminated (goal or hole) or truncated (step limit reached)</span></span>
<span id="cb6-31">                done <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> terminated <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">or</span> truncated</span>
<span id="cb6-32"></span>
<span id="cb6-33">                <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># learner updates Q-table</span></span>
<span id="cb6-34">                learner.qtable[state, action] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> learner.update(state, action, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>(reward), new_state)</span>
<span id="cb6-35"></span>
<span id="cb6-36">                total_rewards <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>(reward)</span>
<span id="cb6-37">                step <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span></span>
<span id="cb6-38"></span>
<span id="cb6-39">                <span class="co" style="color: #5E5E5E;
background-color: null;
# our new state is state">
font-style: inherit;"># the outcome state becomes the current state</span></span>
<span id="cb6-40">                state <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> new_state</span>
<span id="cb6-41"></span>
<span id="cb6-42">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># log all rewards and steps</span></span>
<span id="cb6-43">            rewards[episode, run] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> total_rewards</span>
<span id="cb6-44">            steps[episode, run] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> step</span>
<span id="cb6-45"></span>
<span id="cb6-46">        qtables[run, :, :] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> learner.qtable</span>
<span id="cb6-47"></span>
<span id="cb6-48">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> rewards, steps, episodes, qtables, all_states, all_actions</span></code></pre></div></div>
</div>
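<p>To see the structure of this loop in isolation, here is a minimal, dependency-free sketch that mirrors it on a hypothetical five-tile corridor instead of FrozenLake. The <code>Corridor</code> class, the helper names, and the illustrative values <code>lr=0.8</code> and <code>gamma=0.95</code> are assumptions made for this example, not the learner and explorer classes used in this post. It also makes the episode-ending flags concrete: <code>terminated</code> means the goal was reached, <code>truncated</code> only means the step limit ran out, and walking into a wall simply leaves the agent where it is.</p>

```python
import random

# Hypothetical toy environment (NOT part of Gymnasium): a 1-D corridor.
# The agent starts on tile 0; reaching the last tile ends the episode with reward 1.
N_STATES = 5
ACTIONS = (0, 1)  # 0 = left, 1 = right
MAX_STEPS = 20    # step limit; reaching it sets truncated=True


class Corridor:
    """Mimics the Gymnasium reset/step return signatures."""

    def reset(self):
        self.pos, self.t = 0, 0
        return self.pos, {}

    def step(self, action):
        move = 1 if action == 1 else -1
        # moving left on tile 0 keeps the agent in place, like walking into a wall
        self.pos = max(0, min(N_STATES - 1, self.pos + move))
        self.t += 1
        terminated = self.pos == N_STATES - 1  # reached the goal
        truncated = self.t == MAX_STEPS        # ran out of time
        reward = 1.0 if terminated else 0.0
        return self.pos, reward, terminated, truncated, {}


def q_update(q, s, a, r, s_next, lr=0.8, gamma=0.95):
    """Q-learning: move Q(s, a) toward r + gamma * max_a' Q(s', a')."""
    td_target = r + gamma * max(q[s_next])
    return q[s][a] + lr * (td_target - q[s][a])


def choose_action(q, s, epsilon, rng):
    """Epsilon-greedy with random tie-breaking among equally good actions."""
    if rng.random() < epsilon:
        return rng.choice(ACTIONS)
    best = max(q[s])
    return rng.choice([a for a in ACTIONS if q[s][a] == best])


def run(n_episodes=200, epsilon=0.1, seed=0):
    rng = random.Random(seed)
    env = Corridor()
    q = [[0.0 for _ in ACTIONS] for _ in range(N_STATES)]
    rewards, steps = [], []
    for _ in range(n_episodes):
        state, _ = env.reset()
        done, step, total = False, 0, 0.0
        while not done:
            action = choose_action(q, state, epsilon, rng)
            new_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            q[state][action] = q_update(q, state, action, reward, new_state)
            total += reward
            step += 1
            state = new_state
        rewards.append(total)
        steps.append(step)
    return q, rewards, steps


q, rewards, steps = run()
print("greedy action per non-goal tile:", [row.index(max(row)) for row in q[:-1]])
```

<p>The goal row of the Q-table is never updated, because the episode ends as soon as the agent reaches it; its values stay at zero, so bootstrapping from the goal contributes only the immediate reward.</p>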
<div id="859ca2cc" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> postprocess(episodes: np.ndarray, params: Params, rewards: np.ndarray, steps: np.ndarray, map_size: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>):</span>
<span id="cb7-2">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Convert the results of the simulation into dataframes."""</span></span>
<span id="cb7-3"></span>
<span id="cb7-4">    res <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> pd.DataFrame(</span>
<span id="cb7-5">        data<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>{</span>
<span id="cb7-6">            <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Episodes"</span>: np.tile(episodes, reps<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>params.n_runs),</span>
<span id="cb7-7">            <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Rewards"</span>: rewards.flatten(order<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"F"</span>),</span>
<span id="cb7-8">            <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Steps"</span>: steps.flatten(order<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"F"</span>),</span>
<span id="cb7-9">        }</span>
<span id="cb7-10">    )</span>
<span id="cb7-11">    res[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"cum_rewards"</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> rewards.cumsum(axis<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>).flatten(order<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"F"</span>)</span>
<span id="cb7-12">    res[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"map_size"</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.repeat(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>map_size<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">x</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>map_size<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>, res.shape[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>])</span>
<span id="cb7-13"></span>
<span id="cb7-14">    st <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> pd.DataFrame(data<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>{<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Episodes"</span>: episodes, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Steps"</span>: steps.mean(axis<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)})</span>
<span id="cb7-15">    st[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"map_size"</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.repeat(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>map_size<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">x</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>map_size<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>, st.shape[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>])</span>
<span id="cb7-16"></span>
<span id="cb7-17">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> res, st</span></code></pre></div></div>
</div>
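<p>The flattening in <code>postprocess</code> works because column-major order (<code>order="F"</code>) lines up with <code>np.tile</code>: position <code>i</code> of each flat column corresponds to episode <code>i % n_episodes</code> of run <code>i // n_episodes</code>. A small sketch with toy values (3 episodes, 2 runs; not results from the simulation) shows the alignment:</p>

```python
import numpy as np

# Toy arrays in the same (episode, run) layout as `rewards` and `steps` in run_env
n_episodes, n_runs = 3, 2
episodes = np.arange(n_episodes)
rewards = np.array([[0.0, 1.0],
                    [1.0, 0.0],
                    [1.0, 1.0]])
steps = np.array([[7, 4],
                  [5, 6],
                  [4, 4]])

# order="F" reads column by column: all episodes of run 0, then all of run 1 ...
flat_rewards = rewards.flatten(order="F")
# ... so the episode index simply repeats once per run
flat_episodes = np.tile(episodes, reps=n_runs)
# cumsum(axis=0) accumulates over episodes, separately for each run
cum_rewards = rewards.cumsum(axis=0).flatten(order="F")
# the `st` dataframe instead averages steps across runs for each episode
mean_steps = steps.mean(axis=1)

print(flat_episodes)  # [0 1 2 0 1 2]
print(flat_rewards)   # [0. 1. 1. 1. 0. 1.]
print(cum_rewards)    # [0. 1. 2. 1. 1. 2.]
print(mean_steps)     # [5.5 5.5 4. ]
```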
<div id="2feb8a4d" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb8" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> qtable_directions_map(qtable: np.ndarray, map_size: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>):</span>
<span id="cb8-2">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Get the best learned action &amp; map it to arrows."""</span></span>
<span id="cb8-3"></span>
<span id="cb8-4">    eps <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.finfo(qtable.dtype).eps  <span class="co" style="color: #5E5E5E;
background-color: null;
# minimum float number">
font-style: inherit;"># machine epsilon: gap between 1.0 and the next larger float</span></span>
<span id="cb8-5">    directions <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>: <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"←"</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>: <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"↓"</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>: <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"→"</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>: <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"↑"</span>}</span>
<span id="cb8-6"></span>
<span id="cb8-7">    qtable_val_max <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> qtable.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>(axis<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>).reshape(map_size, map_size)</span>
<span id="cb8-8">    qtable_best_action <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.argmax(qtable, axis<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>).reshape(map_size, map_size)</span>
<span id="cb8-9">    qtable_directions <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.empty(qtable_best_action.size, dtype<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>)</span>
<span id="cb8-10"></span>
<span id="cb8-11">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> idx, val <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(qtable_best_action.flat):</span>
<span id="cb8-12">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> qtable_val_max.flat[idx] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;</span> eps:</span>
<span id="cb8-13">            <span class="co" style="color: #5E5E5E;
background-color: null;
# Assign an arrow only if">
font-style: inherit;"># Assign an arrow only if a non-negligible Q-value has been learned;</span></span>
<span id="cb8-14">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># otherwise, since argmax returns 0 (LEFT) for an all-zero row, an arrow</span></span>
<span id="cb8-15">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># would also appear on tiles where nothing was learned</span></span>
<span id="cb8-16">            qtable_directions[idx] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> directions[val]</span>
<span id="cb8-17"></span>
<span id="cb8-18">    qtable_directions <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> qtable_directions.reshape(map_size, map_size)</span>
<span id="cb8-19">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> qtable_val_max, qtable_directions</span></code></pre></div></div>
</div>
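<p>The <code>eps</code> check matters because <code>np.argmax</code> returns index 0 (LEFT) for a row of all zeros, so every unvisited tile would otherwise show "←". The same mapping can be written without an explicit loop using <code>np.where</code>. The function below is a hypothetical alternative sketch, not the tutorial's code, checked on a toy 2×2 Q-table:</p>

```python
import numpy as np

# Same index-to-arrow mapping as the `directions` dict: 0=LEFT, 1=DOWN, 2=RIGHT, 3=UP
ARROWS = np.array(["←", "↓", "→", "↑"])


def directions_vectorized(qtable, map_size):
    """Hypothetical vectorized variant of qtable_directions_map."""
    eps = np.finfo(qtable.dtype).eps
    qmax = qtable.max(axis=1)
    best = qtable.argmax(axis=1)
    # an arrow where something was learned, an empty string elsewhere
    arrows = np.where(qmax > eps, ARROWS[best], "")
    return qmax.reshape(map_size, map_size), arrows.reshape(map_size, map_size)


# Toy 2x2 map: 4 states x 4 actions; state 3 (the goal) is never updated
qtable = np.zeros((4, 4))
qtable[0, 2] = 0.5  # RIGHT is best in state 0
qtable[1, 1] = 0.7  # DOWN is best in state 1
qtable[2, 2] = 0.9  # RIGHT is best in state 2

print(np.argmax(qtable[3]))  # 0, i.e. LEFT, even though nothing was learned
qmax, arrows = directions_vectorized(qtable, map_size=2)
print(arrows.tolist())
```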
<div id="b79c7eff" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb9" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> plot_q_values_map(qtable: np.ndarray, env: gym.Env, map_size: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>):</span>
<span id="cb9-2">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Plot the last frame of the simulation and the policy learned."""</span></span>
<span id="cb9-3"></span>
<span id="cb9-4">    qtable_val_max, qtable_directions <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> qtable_directions_map(qtable, map_size)</span>
<span id="cb9-5"></span>
<span id="cb9-6">    fig, ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(nrows<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, ncols<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, figsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">5.5</span>), constrained_layout<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb9-7"></span>
<span id="cb9-8">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].imshow(env.render(), aspect<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"equal"</span>, interpolation<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"none"</span>)</span>
<span id="cb9-9">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].axis(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"off"</span>)</span>
<span id="cb9-10">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].set_title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Last frame"</span>)</span>
<span id="cb9-11"></span>
<span id="cb9-12">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Plot the policy</span></span>
<span id="cb9-13">    sns.heatmap(</span>
<span id="cb9-14">        qtable_val_max,</span>
<span id="cb9-15">        annot<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>qtable_directions,</span>
<span id="cb9-16">        fmt<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">""</span>,</span>
<span id="cb9-17">        square<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>,</span>
<span id="cb9-18">        ax<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>],</span>
<span id="cb9-19">        cmap<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>sns.color_palette(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Blues"</span>, as_cmap<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>),</span>
<span id="cb9-20">        linewidths<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>,</span>
<span id="cb9-21">        linecolor<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"black"</span>,</span>
<span id="cb9-22">        xticklabels<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>[],</span>
<span id="cb9-23">        yticklabels<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>[],</span>
<span id="cb9-24">    )</span>
<span id="cb9-25">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">set</span>(title<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Learned Q-values</span><span class="ch" style="color: #20794D;
background-color: null;
font-style: inherit;">\n</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">Arrows represent best action"</span>)</span>
<span id="cb9-26">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].axis(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"off"</span>)</span>
<span id="cb9-27"></span>
<span id="cb9-28">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># autoscale annotation font size</span></span>
<span id="cb9-29">    rows, cols <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> qtable_val_max.shape</span>
<span id="cb9-30">    bbox <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].get_window_extent().transformed(fig.dpi_scale_trans.inverted())</span>
<span id="cb9-31">    width_in, height_in <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> bbox.width, bbox.height</span>
<span id="cb9-32"></span>
<span id="cb9-33">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Heuristic scaling factor (tweak as needed)</span></span>
<span id="cb9-34">    scale <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">min</span>(width_in <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> cols, height_in <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> rows)</span>
<span id="cb9-35">    fontsize <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> scale <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">50</span></span>
<span id="cb9-36"></span>
<span id="cb9-37">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Apply new font size</span></span>
<span id="cb9-38">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> text <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].texts:</span>
<span id="cb9-39">        text.set_fontsize(fontsize)</span>
<span id="cb9-40"></span>
<span id="cb9-41">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> _, spine <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].spines.items():</span>
<span id="cb9-42">        spine.set_visible(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb9-43">        spine.set_linewidth(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.7</span>)</span>
<span id="cb9-44">        spine.set_color(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"black"</span>)</span>
<span id="cb9-45"></span>
<span id="cb9-46">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> fig, ax</span></code></pre></div></div>
</div>
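<p>The annotation font size is a simple proportionality: it scales with the side length of one heatmap cell in inches. Since one inch is 72 points, the factor of 50 makes the arrow's nominal size roughly 70% of a cell. A minimal sketch of the same arithmetic, using hypothetical axes dimensions rather than values measured from a real figure, shows why larger maps get smaller arrows (<code>annotation_fontsize</code> and <code>points_per_inch</code> are illustrative names):</p>

```python
def annotation_fontsize(width_in, height_in, rows, cols, points_per_inch=50):
    """Font size in points, proportional to the side of one heatmap cell (inches).

    `points_per_inch` plays the role of the hard-coded 50 in plot_q_values_map.
    """
    # the smaller cell dimension limits how large a glyph can be
    cell_in = min(width_in / cols, height_in / rows)
    return cell_in * points_per_inch


# Hypothetical 4.5 x 4.5 inch axes
print(annotation_fontsize(4.5, 4.5, rows=4, cols=4))  # 56.25
print(annotation_fontsize(4.5, 4.5, rows=8, cols=8))  # 28.125
```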
<div id="0abce4f2" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb10" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> plot_states_actions_distribution(states: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>], actions: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>], map_size: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>):</span>
<span id="cb10-2">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Plot the distributions of states and actions."""</span></span>
<span id="cb10-3">    labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"LEFT"</span>: <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"DOWN"</span>: <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"RIGHT"</span>: <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"UP"</span>: <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>}</span>
<span id="cb10-4"></span>
<span id="cb10-5">    fig, ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(nrows<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, ncols<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, figsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">11</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>), constrained_layout<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb10-6">    sns.histplot(data<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>states, ax<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], kde<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb10-7">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].set_title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"States"</span>)</span>
<span id="cb10-8"></span>
<span id="cb10-9">    sns.histplot(data<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>actions, ax<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>])</span>
<span id="cb10-10">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].set_xticks(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>(labels.values()), labels<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>labels.keys())</span>
<span id="cb10-11">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].set_title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Actions"</span>)</span>
<span id="cb10-12"></span>
<span id="cb10-13">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> fig, ax</span></code></pre></div></div>
</div>
<div id="15260b23" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> plot_steps_and_rewards(rewards_df: pd.DataFrame, steps_df: pd.DataFrame):</span>
<span id="cb11-2">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Plot the steps and rewards from dataframes."""</span></span>
<span id="cb11-3">    fig, ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(nrows<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, ncols<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, figsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">11</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>), constrained_layout<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb11-4">    sns.lineplot(data<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>rewards_df, x<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Episodes"</span>, y<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"cum_rewards"</span>, hue<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"map_size"</span>, linewidth<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.7</span>, ax<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>])</span>
<span id="cb11-5">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">set</span>(ylabel<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Cumulated rewards"</span>)</span>
<span id="cb11-6"></span>
<span id="cb11-7">    sns.lineplot(data<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>steps_df, x<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Episodes"</span>, y<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Steps"</span>, hue<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"map_size"</span>, linewidth<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.7</span>, ax<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>])</span>
<span id="cb11-8">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">set</span>(ylabel<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Averaged steps number"</span>)</span>
<span id="cb11-9"></span>
<span id="cb11-10">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> axi <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> ax:</span>
<span id="cb11-11">        axi.legend(title<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"map size"</span>)</span>
<span id="cb11-12"></span>
<span id="cb11-13">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> fig, ax</span></code></pre></div></div>
</div>
<div id="f1dfe3a7" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb12" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> collections.abc <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> Callable</span>
<span id="cb12-2"></span>
<span id="cb12-3">EnvFactory <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> Callable[[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>], gym.Env]</span>
<span id="cb12-4"></span>
<span id="cb12-5"></span>
<span id="cb12-6"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> run_experiments(make_env: EnvFactory, params: Params, map_sizes: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>):</span>
<span id="cb12-7">    res_all <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> pd.DataFrame()</span>
<span id="cb12-8">    st_all <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> pd.DataFrame()</span>
<span id="cb12-9"></span>
<span id="cb12-10">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(map_sizes, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>):</span>
<span id="cb12-11">        map_sizes <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [map_sizes]</span>
<span id="cb12-12"></span>
<span id="cb12-13">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> map_size <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> map_sizes:</span>
<span id="cb12-14">        env <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> make_env(map_size)</span>
<span id="cb12-15"></span>
<span id="cb12-16">        action_size: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">getattr</span>(env.action_space, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"n"</span>, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>)</span>
<span id="cb12-17">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">assert</span> action_size <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">is</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">not</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span></span>
<span id="cb12-18"></span>
<span id="cb12-19">        state_size: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">getattr</span>(env.observation_space, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"n"</span>, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>)</span>
<span id="cb12-20">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">assert</span> state_size <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">is</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">not</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span></span>
<span id="cb12-21"></span>
<span id="cb12-22">        env.action_space.seed(params.seed)  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Set the seed to get reproducible results when sampling the action space</span></span>
<span id="cb12-23">        learner <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> Qlearning(</span>
<span id="cb12-24">            lr<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>params.learning_rate, gamma<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>params.gamma, state_size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>state_size, action_size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>action_size</span>
<span id="cb12-25">        )</span>
<span id="cb12-26">        explorer <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> EpsilonGreedy(epsilon<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>params.epsilon, seed<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>params.seed)</span>
<span id="cb12-27"></span>
<span id="cb12-28">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"Map size: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>map_size<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">x</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>map_size<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>)</span>
<span id="cb12-29">        rewards, steps, episodes, qtables, all_states, all_actions <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> run_env(</span>
<span id="cb12-30">            env, learner, explorer, params, state_size, action_size</span>
<span id="cb12-31">        )</span>
<span id="cb12-32"></span>
<span id="cb12-33">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Save the results in dataframes</span></span>
<span id="cb12-34">        res, st <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> postprocess(episodes, params, rewards, steps, map_size)</span>
<span id="cb12-35">        res_all <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> pd.concat([res_all, res])</span>
<span id="cb12-36">        st_all <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> pd.concat([st_all, st])</span>
<span id="cb12-37">        qtable <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> qtables.mean(axis<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>)  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Average the Q-table between runs</span></span>
<span id="cb12-38"></span>
<span id="cb12-39">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> load_theme(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"ambivalent"</span>):</span>
<span id="cb12-40">            plot_states_actions_distribution(states<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>all_states, actions<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>all_actions, map_size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>map_size)</span>
<span id="cb12-41">        plt.show()</span>
<span id="cb12-42"></span>
<span id="cb12-43">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> load_theme(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"ambivalent"</span>):</span>
<span id="cb12-44">            plot_q_values_map(qtable, env, map_size)</span>
<span id="cb12-45">        plt.show()</span>
<span id="cb12-46"></span>
<span id="cb12-47">        env.close()</span>
<span id="cb12-48"></span>
<span id="cb12-49">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> load_theme(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"ambivalent"</span>):</span>
<span id="cb12-50">        plot_steps_and_rewards(res_all, st_all)</span>
<span id="cb12-51">    plt.show()</span></code></pre></div></div>
</div>
<div id="2098758b" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb13" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> make_frozenlake_env(params: Params) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> EnvFactory:</span>
<span id="cb13-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> _factory(map_size: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> gym.Env:</span>
<span id="cb13-3">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> gym.make(</span>
<span id="cb13-4">            <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"FrozenLake-v1"</span>,</span>
<span id="cb13-5">            is_slippery<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>params.is_slippery,</span>
<span id="cb13-6">            render_mode<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"rgb_array"</span>,</span>
<span id="cb13-7">            desc<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>generate_random_map(size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>map_size, p<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>params.proba_frozen, seed<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>params.seed),</span>
<span id="cb13-8">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># reward_schedule=(10.0, -1.0, -0.01),  # reach goal, reach hole, reach frozen (includes Start)</span></span>
<span id="cb13-9">        )</span>
<span id="cb13-10"></span>
<span id="cb13-11">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> _factory</span>
<span id="cb13-12"></span>
<span id="cb13-13"></span>
<span id="cb13-14">map_sizes <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">7</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">9</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">11</span>]</span>
<span id="cb13-15">params <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> Params()</span>
<span id="cb13-16"></span>
<span id="cb13-17">run_experiments(make_frozenlake_env(params), params, map_sizes)</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>Map size: 4x4</code></pre>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/reinforcement-learning/2025-12-17-tabular-q-learning_files/figure-html/cell-15-output-2.png" width="1111" height="511" class="figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/reinforcement-learning/2025-12-17-tabular-q-learning_files/figure-html/cell-15-output-3.png" width="1010" height="557" class="figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-stdout">
<pre><code>Map size: 7x7</code></pre>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/reinforcement-learning/2025-12-17-tabular-q-learning_files/figure-html/cell-15-output-5.png" width="1111" height="511" class="figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/reinforcement-learning/2025-12-17-tabular-q-learning_files/figure-html/cell-15-output-6.png" width="1010" height="559" class="figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-stdout">
<pre><code>Map size: 9x9</code></pre>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/reinforcement-learning/2025-12-17-tabular-q-learning_files/figure-html/cell-15-output-8.png" width="1111" height="511" class="figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/reinforcement-learning/2025-12-17-tabular-q-learning_files/figure-html/cell-15-output-9.png" width="1010" height="557" class="figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-stdout">
<pre><code>Map size: 11x11</code></pre>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/reinforcement-learning/2025-12-17-tabular-q-learning_files/figure-html/cell-15-output-11.png" width="1111" height="511" class="figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/reinforcement-learning/2025-12-17-tabular-q-learning_files/figure-html/cell-15-output-12.png" width="1010" height="557" class="figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/reinforcement-learning/2025-12-17-tabular-q-learning_files/figure-html/cell-15-output-13.png" width="1111" height="511" class="figure-img"></p>
</figure>
</div>
</div>
</div>
</section>
<section id="appendix" class="level1">
<h1>Appendix</h1>
<div id="519f6955" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb18" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1">SHOW_PROGRESS <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span></span>
<span id="cb18-2"></span>
<span id="cb18-3"></span>
<span id="cb18-4"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> make_frozenlake_env(params: Params) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> EnvFactory:</span>
<span id="cb18-5">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> _factory(map_size: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> gym.Env:</span>
<span id="cb18-6">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> gym.make(</span>
<span id="cb18-7">            <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"FrozenLake-v1"</span>,</span>
<span id="cb18-8">            is_slippery<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>params.is_slippery,</span>
<span id="cb18-9">            render_mode<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"rgb_array"</span>,</span>
<span id="cb18-10">            desc<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>generate_random_map(size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>map_size, p<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>params.proba_frozen, seed<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>params.seed),</span>
<span id="cb18-11">            reward_schedule<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">10.0</span>, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">10.0</span>, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.01</span>),  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># reach goal, reach hole, reach frozen (includes Start)</span></span>
<span id="cb18-12">        )</span>
<span id="cb18-13"></span>
<span id="cb18-14">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> _factory</span>
<span id="cb18-15"></span>
<span id="cb18-16"></span>
<span id="cb18-17">params <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> Params()</span>
<span id="cb18-18"></span>
<span id="cb18-19">run_experiments(make_frozenlake_env(params), params, map_sizes<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">25</span>])</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>Map size: 5x5</code></pre>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/reinforcement-learning/2025-12-17-tabular-q-learning_files/figure-html/cell-16-output-2.png" width="1111" height="511" class="figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/reinforcement-learning/2025-12-17-tabular-q-learning_files/figure-html/cell-16-output-3.png" width="1010" height="558" class="figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-stdout">
<pre><code>Map size: 25x25</code></pre>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/reinforcement-learning/2025-12-17-tabular-q-learning_files/figure-html/cell-16-output-5.png" width="1111" height="511" class="figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/reinforcement-learning/2025-12-17-tabular-q-learning_files/figure-html/cell-16-output-6.png" width="1011" height="561" class="figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/reinforcement-learning/2025-12-17-tabular-q-learning_files/figure-html/cell-16-output-7.png" width="1111" height="511" class="figure-img"></p>
</figure>
</div>
</div>
</div>


</section>

<div id="quarto-appendix" class="default"><section class="quarto-appendix-contents" id="quarto-reuse"><h2 class="anchored quarto-appendix-heading">Reuse</h2><div class="quarto-appendix-contents"><div><a rel="license" href="https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en_us">CC BY-NC-SA 4.0</a></div></div></section></div> ]]></description>
  <category>code</category>
  <category>python</category>
  <category>mdp</category>
  <guid>https://gcerar.github.io/posts/reinforcement-learning/2025-12-17-tabular-q-learning.html</guid>
  <pubDate>Tue, 16 Dec 2025 23:00:00 GMT</pubDate>
</item>
<item>
  <title>Bernoulli Multi-Armed Bandit Problem</title>
  <dc:creator>Gregor Cerar</dc:creator>
  <link>https://gcerar.github.io/posts/reinforcement-learning/2025-12-16-multi-armed-bandit-problem.html</link>
  <description><![CDATA[ 





<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Note
</div>
</div>
<div class="callout-body-container callout-body">
<p>Credit goes to <a href="https://lilianweng.github.io/posts/2018-01-23-multi-armed-bandit/">Lil’s blog post</a>. I slightly reworked and extended it to help myself better understand the statistical terms.</p>
</div>
</div>
<p>The exploitation-exploration dilemma shows up in many aspects of our lives. For instance, if you always pick your favourite option (<em>e.g.,</em> restaurant, chatbot, artist, music band), you are confident of what you will get, but you miss the chance to discover an even better one. If you try new options all the time, however, you will very likely end up with an unpleasant experience now and then. Not every new option pays off.</p>
<p>This trade-off becomes especially important when we operate under <strong>incomplete information</strong>. Without full knowledge of our environment, we must gather information while simultaneously making good decisions. Exploitation uses what we’ve learned, while exploration risks a short-term loss to gain long-term insight.</p>
<p>To see how this plays out in a clean mathematical setting, we turn to a classic model.</p>
<section id="what-is-a-multi-armed-bandit" class="level1">
<h1>What is a Multi-Armed Bandit?</h1>
<p>The multi-armed bandit (MAB) captures this dilemma elegantly. Imagine a row of slot machines (<em>i.e.,</em> “<a href="https://en.wiktionary.org/wiki/one-armed_bandit">one-armed bandits</a>”), each with an unknown payout probability. The goal is to maximize the total reward over time. Each pull (<em>i.e.,</em> action) gives you information, but it also costs you the chance to pull a better machine.</p>
</section>
<section id="the-environment" class="level1">
<h1>The Environment</h1>
<p>Let’s consider the simplest version of the problem. You face several slot machines, each with an unknown Bernoulli reward distribution: every play either yields a fixed reward or nothing. You have plenty of trials, and your choices don’t change the underlying probabilities.</p>
<p>The question is: <em>What is the best strategy to achieve the highest long-term reward?</em></p>
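<p>To make this setup concrete, here is a minimal sketch of such an environment. It assumes NumPy, and the <code>BernoulliBandit</code> class is a hypothetical illustration, not necessarily the implementation used in the code below: each arm <em>k</em> pays a reward of 1 with a hidden probability θ<sub>k</sub> and 0 otherwise, and pulling an arm never changes these probabilities.</p>
<div class="sourceCode" style="background: #f1f3f5;"><pre class="sourceCode python"><code class="sourceCode python">import numpy as np


class BernoulliBandit:
    """Hypothetical K-armed Bernoulli bandit with fixed, hidden payout probabilities."""

    def __init__(self, probas: list[float], seed: int | None = None):
        # True success probability of each arm; a learning agent must not read these.
        self.probas = np.asarray(probas, dtype=float)
        self.n_arms = len(self.probas)
        # Kept only for evaluation (e.g., computing regret).
        self.best_proba = float(self.probas.max())
        self._rng = np.random.default_rng(seed)

    def pull(self, arm: int):
        """Return 1 with probability probas[arm], else 0. Probabilities never change."""
        return int(self._rng.binomial(1, self.probas[arm]))


bandit = BernoulliBandit([0.2, 0.5, 0.75], seed=0)
rewards = [bandit.pull(2) for _ in range(10_000)]
print(np.mean(rewards))  # empirical payout rate of arm 2, close to 0.75
</code></pre></div>
<p>The agent only observes the 0/1 outcomes of <code>pull()</code>; estimating <code>probas</code> from those outcomes is exactly the learning problem.</p>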
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Note
</div>
</div>
<div class="callout-body-container callout-body">
<p>For newcomers to reinforcement learning (as I was when writing this), the following clarifications help.</p>
<p>First, <strong>regret</strong> measures how much reward you lost compared to always choosing the best option in hindsight. It quantifies the “if only I had known…” feeling.</p>
<p>Second, the reward probabilities are <em>not known ahead of time</em>. You discover them through experience. This is what makes the problem interesting.</p>
</div>
</div>
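<p>To make the regret definition concrete, here is a small worked example. The three arm probabilities and the sequence of pulls are made up for illustration. Mirroring the <code>update_regret</code> method defined later in this post, regret is measured against the arms’ <em>true</em> probabilities (known only to the simulator), not against the rewards actually observed.</p>

```python
import numpy as np

# Hypothetical true reward probabilities of three arms (known only to the simulator).
probas = np.array([0.2, 0.5, 0.8])
best_proba = probas.max()  # the best arm pays off with probability 0.8

# A hypothetical sequence of arms the player actually pulled.
actions = np.array([0, 2, 1, 2])

# Per-step regret: expected reward lost versus always pulling the best arm.
regrets = best_proba - probas[actions]  # 0.6, 0.0, 0.3, 0.0 (up to float rounding)
cumulative_regret = regrets.sum()       # 0.9 (up to float rounding)
print(regrets, cumulative_regret)
```

<p>Pulling the best arm costs nothing; every other pull adds the gap between its probability and the best one.</p>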
<div id="cell-5" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> abc <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> ABC, abstractmethod</span>
<span id="cb1-2"></span>
<span id="cb1-3"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> matplotlib.pyplot <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> plt</span>
<span id="cb1-4"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> matplotlib.ticker <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> ticker</span>
<span id="cb1-5"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> numpy <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> np</span>
<span id="cb1-6"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> numpy.typing <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> npt</span>
<span id="cb1-7"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> aquarel <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> load_theme</span>
<span id="cb1-8"></span>
<span id="cb1-9"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span>config InlineBackend.figure_formats <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'retina'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'png'</span>}</span></code></pre></div></div>
</div>
<div id="cell-6" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">class</span> BaseBandit(ABC):</span>
<span id="cb2-2">    k: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># number of arms</span></span>
<span id="cb2-3">    best_proba: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> np.float64  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># hidden to solver; for regret calculation, highest possible reward probability</span></span>
<span id="cb2-4">    probas: npt.NDArray[np.float64]  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># hidden to solver; reward probabilities</span></span>
<span id="cb2-5"></span>
<span id="cb2-6">    <span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">@abstractmethod</span></span>
<span id="cb2-7">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> generate_reward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, i: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>:</span>
<span id="cb2-8">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Returns reward after lever `i` is pulled."""</span></span>
<span id="cb2-9">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">raise</span> <span class="pp" style="color: #AD0000;
background-color: null;
font-style: inherit;">NotImplementedError</span></span>
<span id="cb2-10"></span>
<span id="cb2-11"></span>
<span id="cb2-12"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">class</span> BaseSolver(ABC):</span>
<span id="cb2-13">    bandit: BaseBandit  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># reference to the bandit instance</span></span>
<span id="cb2-14">    counts: npt.NDArray[np.int64]  <span class="co" style="color: #5E5E5E;
background-color: null;
# hold stats of pulled levers</span></span>">
font-style: inherit;"># number of times each lever was pulled</span></span>
<span id="cb2-15"></span>
<span id="cb2-16">    actions: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>]</span>
<span id="cb2-17">    rewards: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>]</span>
<span id="cb2-18">    regrets: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>]</span>
<span id="cb2-19"></span>
<span id="cb2-20">    <span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">@abstractmethod</span></span>
<span id="cb2-21">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, bandit: BaseBandit) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb2-22">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""bandit (BaseBandit): the target bandit to solve."""</span></span>
<span id="cb2-23">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">assert</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(bandit, BaseBandit)</span>
<span id="cb2-24">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.bandit <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> bandit</span>
<span id="cb2-25"></span>
<span id="cb2-26">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.counts <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.zeros(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.bandit.k, dtype<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>np.int64)</span>
<span id="cb2-27"></span>
<span id="cb2-28">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.actions <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> []  <span class="co" style="color: #5E5E5E;
background-color: null;
# a history of lever ids, 0 to bandit n-1.</span></span>">
font-style: inherit;"># a history of lever ids, 0 to k-1.</span></span>
<span id="cb2-29">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.rewards <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> []  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># a history of collected rewards.</span></span>
<span id="cb2-30">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.regrets <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> []  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># a history of regrets for taken actions.</span></span>
<span id="cb2-31"></span>
<span id="cb2-32">    <span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">@property</span></span>
<span id="cb2-33">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> num_steps(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>:</span>
<span id="cb2-34">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.actions)</span>
<span id="cb2-35"></span>
<span id="cb2-36">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> update_regret(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, i: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb2-37">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Update the regret after the lever `i` is pulled."""</span></span>
<span id="cb2-38">        regret <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.bandit.best_proba <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.bandit.probas[i]</span>
<span id="cb2-39">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.regrets.append(regret)</span>
<span id="cb2-40"></span>
<span id="cb2-41">    <span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">@property</span></span>
<span id="cb2-42">    <span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">@abstractmethod</span></span>
<span id="cb2-43">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> estimated_probas(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> npt.NDArray[np.float64]:</span>
<span id="cb2-44">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Retrieve learned reward probability for each arm `n` of the bandit."""</span></span>
<span id="cb2-45">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">raise</span> <span class="pp" style="color: #AD0000;
background-color: null;
font-style: inherit;">NotImplementedError</span></span>
<span id="cb2-46"></span>
<span id="cb2-47">    <span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">@abstractmethod</span></span>
<span id="cb2-48">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> run_one_step(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">tuple</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>]:</span>
<span id="cb2-49">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Return solver's selected action and bandit's outcome reward."""</span></span>
<span id="cb2-50">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">raise</span> <span class="pp" style="color: #AD0000;
background-color: null;
font-style: inherit;">NotImplementedError</span></span>
<span id="cb2-51"></span>
<span id="cb2-52">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> run(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, num_steps: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb2-53">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Run simulation for `num_steps` steps."""</span></span>
<span id="cb2-54">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> _ <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(num_steps):</span>
<span id="cb2-55">            i, r <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.run_one_step()</span>
<span id="cb2-56"></span>
<span id="cb2-57">            <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.counts[i] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span></span>
<span id="cb2-58">            <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.actions.append(i)</span>
<span id="cb2-59">            <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.update_regret(i)</span>
<span id="cb2-60">            <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.rewards.append(r)</span></code></pre></div></div>
</div>
<section id="formal-definition" class="level2">
<h2 class="anchored" data-anchor-id="formal-definition">Formal Definition</h2>
<p>With the intuition in place, we can now describe the Bernoulli multi-armed bandit more formally. A bandit problem is defined as a tuple <img src="https://latex.codecogs.com/png.latex?%5Clangle%20%5Cmathcal%7BA%7D,%20%5Cmathcal%7BR%7D%20%5Crangle">, where:</p>
<ul>
<li>We have <img src="https://latex.codecogs.com/png.latex?K"> machines (or levers) with reward probabilities <img src="https://latex.codecogs.com/png.latex?%5C%7B%20%5Ctheta_%7B1%7D,%20%5Cldots,%20%5Ctheta_%7BK%7D%20%5C%7D">.</li>
<li>At each time step <img src="https://latex.codecogs.com/png.latex?t">, we take an action <img src="https://latex.codecogs.com/png.latex?a_t"> on one slot machine and receive a reward <img src="https://latex.codecogs.com/png.latex?r_t">.</li>
<li><img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BA%7D"> is a set of possible actions. The value of an action is expected reward, <img src="https://latex.codecogs.com/png.latex?Q(a)%20=%20%5Cmathbb%7BE%7D%5Br%7Ca%5D%20=%20%5Ctheta">. If action <img src="https://latex.codecogs.com/png.latex?a_%7Bt%7D"> corresponds to machine <img src="https://latex.codecogs.com/png.latex?i">, then <img src="https://latex.codecogs.com/png.latex?Q(a_%7Bt%7D)%20=%20%5Ctheta_%7Bi%7D">.</li>
<li><img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BR%7D"> is the reward function. In a Bernoulli bandit, each pull yields a reward of <img src="https://latex.codecogs.com/png.latex?1"> with probability <img src="https://latex.codecogs.com/png.latex?Q(a_%7Bt%7D)">, and <img src="https://latex.codecogs.com/png.latex?0"> otherwise.</li>
</ul>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Note
</div>
</div>
<div class="callout-body-container callout-body">
<p>Recall that a <a href="https://en.wikipedia.org/wiki/Bernoulli_distribution">Bernoulli distribution</a> is a discrete probability distribution, which takes the value <img src="https://latex.codecogs.com/png.latex?1"> with probability <img src="https://latex.codecogs.com/png.latex?p"> and <img src="https://latex.codecogs.com/png.latex?0"> with probability <img src="https://latex.codecogs.com/png.latex?1%20-%20p">.</p>
<p>The symbol <img src="https://latex.codecogs.com/png.latex?%5Cmathbb%7BE%7D%5B%5Ccdot%5D"> denotes the <a href="https://en.wikipedia.org/wiki/Expected_value">expected value</a>, a generalized weighted average. The expression <img src="https://latex.codecogs.com/png.latex?%5Cmathbb%7BE%7D%5Br%7Ca%5D"> reads as “<em>the expected reward <img src="https://latex.codecogs.com/png.latex?r">, given that we took action <img src="https://latex.codecogs.com/png.latex?a"></em>.”</p>
<p>Crucially, the probabilities <img src="https://latex.codecogs.com/png.latex?%5C%7B%5Ctheta_%7Bi%7D%5C%7D"> are <strong>NOT known in advance</strong>. They must be estimated through interaction.</p>
</div>
</div>
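<p>A quick way to see that E[r | a] equals θ for a Bernoulli arm is to simulate pulls and average them. The sketch uses a hypothetical helper <code>estimate_expected_reward</code> and an arbitrary θ = 0.3; as the number of pulls grows, the sample mean drifts toward θ, which is exactly the kind of estimate a solver has to build from experience.</p>

```python
import numpy as np


def estimate_expected_reward(theta: float, n_pulls: int, seed: int = 0) -> float:
    """Average of `n_pulls` simulated Bernoulli(theta) rewards (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    # Each pull pays 1 with probability theta and 0 otherwise.
    rewards = (rng.random(n_pulls) < theta).astype(int)
    return float(rewards.mean())


theta = 0.3  # made-up reward probability of a single arm
for n in (10, 100, 10_000):
    print(f"{n:>6} pulls -> estimated E[r|a] = {estimate_expected_reward(theta, n):.3f}")
```

<p>With few pulls the estimate can be far off, which is why early decisions in a bandit problem are so uncertain.</p>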
<p>A Bernoulli bandit can be seen as a simplified Markov decision process (MDP) without a state space. The objective is to maximize the total reward <img src="https://latex.codecogs.com/png.latex?%5Csum_%7Bt=1%7D%5E%7BT%7D%20r_%7Bt%7D">. This is equivalent to minimizing the <strong>regret</strong> from not always choosing the optimal action, <em>i.e.,</em> the action with the highest reward probability.</p>
<p>Let <img src="https://latex.codecogs.com/png.latex?%5Ctheta%5E%7B*%7D"> denote the reward probability of the optimal action <img src="https://latex.codecogs.com/png.latex?a%5E%7B*%7D">:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Ctheta%5E%7B*%7D%20=%20Q(a%5E%7B*%7D)%20=%20%5Cmax_%7Ba%20%5Cin%20%5Cmathcal%7BA%7D%7D%20Q(a)%20=%20%5Cmax_%7B1%20%5Cleq%20i%20%5Cleq%20K%7D%20%5Ctheta_%7Bi%7D%0A"></p>
<p>The expected cumulative regret up to time <img src="https://latex.codecogs.com/png.latex?T"> is then:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Cmathcal%7BL%7D_%7BT%7D%20=%20%5Cmathbb%7BE%7D%5Cleft%5B%5Csum_%7Bt=1%7D%5E%7BT%7D(%5Ctheta%5E%7B*%7D%20-%20Q(a_%7Bt%7D))%5Cright%5D%0A"></p>
<div id="cell-8" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">class</span> BernoulliBandit(BaseBandit):</span>
<span id="cb3-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(</span>
<span id="cb3-3">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>,</span>
<span id="cb3-4">        k: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>,</span>
<span id="cb3-5">        probas: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> npt.NDArray[np.float64] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>,</span>
<span id="cb3-6">        seed: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>,</span>
<span id="cb3-7">    ):</span>
<span id="cb3-8">        <span class="co" style="color: #5E5E5E;
background-color: null;
# sanity check:">
font-style: inherit;"># sanity check: `probas` needs to be None or of size `k`.</span></span>
<span id="cb3-9">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">assert</span> probas <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">is</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">or</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(probas) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> k</span>
<span id="cb3-10"></span>
<span id="cb3-11">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.k <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> k  <span class="co" style="color: #5E5E5E;
background-color: null;
# save number of bandits">
font-style: inherit;"># save number of arms</span></span>
<span id="cb3-12"></span>
<span id="cb3-13">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.rng <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.random.default_rng(seed<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>seed)</span>
<span id="cb3-14"></span>
<span id="cb3-15">        <span class="co" style="color: #5E5E5E;
background-color: null;
# random probabilities">
font-style: inherit;"># draw random probabilities if they are not explicitly given</span></span>
<span id="cb3-16">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> probas <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">is</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb3-17">            probas <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.rng.random(size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.k)</span>
<span id="cb3-18"></span>
<span id="cb3-19">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># convert to numpy array for easier operations later</span></span>
<span id="cb3-20">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.probas <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.asarray(probas)</span>
<span id="cb3-21"></span>
<span id="cb3-22">        <span class="co" style="color: #5E5E5E;
background-color: null;
# in case of Bernoulli MAB">
font-style: inherit;"># for a Bernoulli MAB, the highest probability is the optimal expected reward</span></span>
<span id="cb3-23">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.best_proba <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.probas)</span>
<span id="cb3-24"></span>
<span id="cb3-25">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> generate_reward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, i: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>:</span>
<span id="cb3-26">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># The player selected the i-th machine.</span></span>
<span id="cb3-27">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.rng.random() <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.probas[i])</span></code></pre></div></div>
</div>
</section>
<section id="bandit-strategies" class="level2">
<h2 class="anchored" data-anchor-id="bandit-strategies">Bandit Strategies</h2>
<p>With the bandit problem formally defined, the next question is: <strong>how should we choose actions over time?</strong> Different strategies encode different assumptions about how exploration should be handled. Broadly, we can distinguish three categories:</p>
<ul>
<li><strong>No exploration:</strong> always exploit the best-known action (naive and generally poor).</li>
<li><strong>Random exploration:</strong> explore uniformly at random.</li>
<li><strong>Informed exploration:</strong> explore more often when uncertainty is high.</li>
</ul>
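<p>To see why having no exploration is naive, this sketch (hypothetical helper <code>run_policy</code>, made-up arm probabilities) compares a purely greedy policy with uniform random exploration on the same Bernoulli bandit. It assumes all estimates start at zero and that ties are broken toward the lowest arm index (NumPy’s <code>argmax</code> behaviour). Under those assumptions the greedy policy locks onto the first arm it tries and never discovers the better ones.</p>

```python
import numpy as np


def run_policy(probas, T: int, explore: bool, seed: int = 0) -> float:
    """Total regret after T steps of a greedy (explore=False) or uniformly random policy."""
    probas = np.asarray(probas, dtype=float)
    rng = np.random.default_rng(seed)
    k = len(probas)
    counts = np.zeros(k)
    values = np.zeros(k)  # empirical mean reward per arm, starting at 0
    regret = 0.0
    for _ in range(T):
        if explore:
            a = int(rng.integers(k))  # random exploration: any arm, uniformly
        else:
            a = int(np.argmax(values))  # no exploration: best-known arm, ties -> lowest index
        r = float(rng.random() < probas[a])  # Bernoulli reward: 1.0 or 0.0
        counts[a] += 1
        values[a] += (r - values[a]) / counts[a]  # running sample mean
        regret += probas.max() - probas[a]
    return regret


arms = [0.2, 0.5, 0.8]  # made-up probabilities; arm 2 is best
print("greedy:", run_policy(arms, T=1000, explore=False))
print("random:", run_policy(arms, T=1000, explore=True))
```

<p>Here the greedy policy never leaves arm 0 and pays 0.6 regret on every step, while random exploration averages about 0.3 per step: better, but it keeps wasting pulls on bad arms forever.</p>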
<p>A simple and widely used example of random exploration is the <strong><img src="https://latex.codecogs.com/png.latex?%5Cepsilon">-greedy</strong> algorithm.</p>
</section>
</section>
<section id="epsilon-greedy-algorithm" class="level1">
<h1>Epsilon-Greedy Algorithm</h1>
<p>The <img src="https://latex.codecogs.com/png.latex?%5Cepsilon">-greedy algorithm balances exploitation and exploration by choosing the currently best action most of the time, while occasionally exploring at random.</p>
<section id="information-state" class="level4">
<h4 class="anchored" data-anchor-id="information-state">Information State</h4>
<p>At time step <img src="https://latex.codecogs.com/png.latex?t">, the algorithm maintains:</p>
<ul>
<li>empirical action-value estimates <img src="https://latex.codecogs.com/png.latex?%5Chat%7BQ%7D_t(a)">,</li>
<li>action counts <img src="https://latex.codecogs.com/png.latex?N_t(a)">,</li>
</ul>
<p>summarizing all past interactions.</p>
<p>The empirical value estimate for action <img src="https://latex.codecogs.com/png.latex?a"> is defined as:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Chat%7BQ%7D_t(a)%20=%20%5Cfrac%7B1%7D%7BN_t(a)%7D%20%5Csum_%7B%5Ctau%20=%201%7D%5Et%20r_%5Ctau%20%5Ccdot%20%5Cmathbb%7B%F0%9D%9F%99%7D%5Ba_%5Ctau%20=%20a%5D%0A"></p>
<p>where:</p>
<ul>
<li><img src="https://latex.codecogs.com/png.latex?r_%5Ctau"> is the reward received at time step <img src="https://latex.codecogs.com/png.latex?%5Ctau">. For a Bernoulli bandit, this is either <img src="https://latex.codecogs.com/png.latex?1"> (success) or <img src="https://latex.codecogs.com/png.latex?0"> (no reward).</li>
<li><img src="https://latex.codecogs.com/png.latex?%5Cmathbb%7B%F0%9D%9F%99%7D%5Ba_%5Ctau%20=%20a%5D"> is an indicator function equal to <img src="https://latex.codecogs.com/png.latex?1"> when action <img src="https://latex.codecogs.com/png.latex?a"> was taken at time <img src="https://latex.codecogs.com/png.latex?%5Ctau">, and <img src="https://latex.codecogs.com/png.latex?0"> otherwise.</li>
<li><img src="https://latex.codecogs.com/png.latex?N_t(a)"> is the number of times action <img src="https://latex.codecogs.com/png.latex?a"> has been selected: <img src="https://latex.codecogs.com/png.latex?%0AN_%7Bt%7D(a)%20=%20%5Csum_%7B%5Ctau%20=%201%7D%5Et%20%5Cmathbb%7B%F0%9D%9F%99%7D%5Ba_%5Ctau%20=%20a%5D%0A"></li>
</ul>
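<p>The indicator-sum definitions of N<sub>t</sub>(a) and the empirical estimate can be sketched in NumPy by building the indicator as a (steps × arms) matrix. The helper name <code>empirical_estimates</code> is hypothetical, and returning 0 for arms that were never pulled is an assumption, since the formula is undefined when N<sub>t</sub>(a) = 0.</p>

```python
import numpy as np


def empirical_estimates(actions, rewards, k: int):
    """Return (N_t(a), Q_hat_t(a)) for all k arms from an action/reward history."""
    actions = np.asarray(actions)
    rewards = np.asarray(rewards, dtype=float)
    # indicator[tau, a] is 1.0 when arm a was taken at step tau, else 0.0
    indicator = (actions[:, None] == np.arange(k)[None, :]).astype(float)
    counts = indicator.sum(axis=0)  # N_t(a) = sum over tau of the indicator
    reward_sums = (rewards[:, None] * indicator).sum(axis=0)  # sum of r_tau * indicator
    # Q_hat_t(a) = reward_sums / N_t(a); arms never pulled are left at 0 (assumption)
    q_hat = np.divide(reward_sums, counts, out=np.zeros(k), where=counts > 0)
    return counts, q_hat


counts, q_hat = empirical_estimates([0, 2, 2, 1, 2], [1, 0, 1, 0, 1], k=3)
print(counts, q_hat)  # arm 2 was pulled 3 times and paid off twice
```
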
</section>
<section id="policy" class="level4">
<h4 class="anchored" data-anchor-id="policy">Policy</h4>
<p>The <img src="https://latex.codecogs.com/png.latex?%5Cepsilon">-greedy policy defines a stochastic action-selection rule:</p>
<ul>
<li>with probability <img src="https://latex.codecogs.com/png.latex?1%20-%20%5Cepsilon">, the greedy action is selected: <img src="https://latex.codecogs.com/png.latex?%0A%5Chat%7Ba%7D%5E%7B*%7D_t%20=%20%5Carg%5Cmax_%7Ba%5Cin%5Cmathcal%7BA%7D%7D%20%5Chat%7BQ%7D_t(a)%0A"></li>
<li>with probability <img src="https://latex.codecogs.com/png.latex?%5Cepsilon">, an action is selected uniformly at random.</li>
</ul>
<p>Equivalently, the policy can be written as:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Cpi(a%7Ch_t)%20=%20%5Cbegin%7Bcases%7D%0A%20%201%20-%20%5Cepsilon%20+%20%5Cfrac%7B%5Cepsilon%7D%7B%7C%5Cmathcal%7BA%7D%7C%7D,%20&amp;%20a%20=%20a%5E*_t,%20%5C%5C%0A%20%20%5Cfrac%7B%5Cepsilon%7D%7B%7C%5Cmathcal%7BA%7D%7C%7D,%20&amp;%20%5Ctext%7Botherwise%7D.%0A%5Cend%7Bcases%7D%0A"></p>
</section>
<section id="update-rule" class="level4">
<h4 class="anchored" data-anchor-id="update-rule">Update Rule</h4>
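<p>After selecting action <img src="https://latex.codecogs.com/png.latex?a_t"> and observing reward <img src="https://latex.codecogs.com/png.latex?r_t">, only the estimate of the chosen arm changes. It is updated incrementally, so that it remains the sample mean of that arm’s rewards:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Chat%7BQ%7D_t(a_t)%20=%20%5Chat%7BQ%7D_%7Bt-1%7D(a_t)%20+%20%5Cfrac%7B1%7D%7BN_t(a_t)%7D%20%5Cleft(r_t%20-%20%5Chat%7BQ%7D_%7Bt-1%7D(a_t)%5Cright).%0A"></p>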
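<p>The <code>EpsilonGreedy</code> class in this section moves the estimate toward each new reward with step size 1/N instead of storing all past rewards. A small sketch (the function name <code>incremental_mean</code> is illustrative, and Bernoulli rewards with an optimistic initial value of 1.0 are assumed) checks that this running update lands on the ordinary sample mean, and that the optimistic initial value is fully overwritten by the first reward because the first step size is 1/1.</p>

```python
import numpy as np


def incremental_mean(rewards, init: float = 1.0) -> float:
    """Illustrative running mean: Q <- Q + (1 / N) * (r - Q)."""
    q = init  # optimistic initial estimate, replaced entirely by the first reward
    for n, r in enumerate(rewards, start=1):
        q += (r - q) / n
    return q


rewards = [1, 0, 0, 1, 1, 0, 1]
print(incremental_mean(rewards), np.mean(rewards))  # both approximately 0.5714
```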
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Note
</div>
</div>
<div class="callout-body-container callout-body">
<p>Despite its simplicity, <img src="https://latex.codecogs.com/png.latex?%5Cepsilon">-greedy often performs reasonably well. However, because exploration is random and does not depend on uncertainty, it can waste trials on clearly suboptimal actions.</p>
</div>
</div>
<div id="cell-11" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">class</span> EpsilonGreedy(BaseSolver):</span>
<span id="cb4-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, bandit: BaseBandit, eps: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>, init_proba: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.0</span>, seed: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb4-3">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""</span></span>
<span id="cb4-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        eps (float): probability of exploring (picking a random arm) at each time step.</span></span>
<span id="cb4-5"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        init_proba (float): initial estimate for every arm; defaults to 1.0 (optimistic initialization).</span></span>
<span id="cb4-6"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        """</span></span>
<span id="cb4-7">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(bandit)</span>
<span id="cb4-8"></span>
<span id="cb4-9">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">assert</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.0</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;=</span> eps <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.0</span></span>
<span id="cb4-10">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.eps <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> eps</span>
<span id="cb4-11"></span>
<span id="cb4-12">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># optimistic initialization</span></span>
<span id="cb4-13">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.estimates <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.full(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.bandit.k, fill_value<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>init_proba, dtype<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>np.float64)</span>
<span id="cb4-14"></span>
<span id="cb4-15">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># define random generator with seed for reproducibility</span></span>
<span id="cb4-16">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.rng <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.random.default_rng(seed<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>seed)</span>
<span id="cb4-17"></span>
<span id="cb4-18">    <span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">@property</span></span>
<span id="cb4-19">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> estimated_probas(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> npt.NDArray[np.float64]:</span>
<span id="cb4-20">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.estimates</span>
<span id="cb4-21"></span>
<span id="cb4-22">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> run_one_step(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">tuple</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>]:</span>
<span id="cb4-23">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># With probability eps explore a uniformly random arm; otherwise exploit the best-known arm.</span></span>
<span id="cb4-24"></span>
<span id="cb4-25">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.rng.random() <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.eps:</span>
<span id="cb4-26">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># pure random exploration</span></span>
<span id="cb4-27">            i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.rng.integers(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.bandit.k))</span>
<span id="cb4-28">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span>:</span>
<span id="cb4-29">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># greedy selection with random tie-breaking</span></span>
<span id="cb4-30">            candidates <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.flatnonzero(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.estimates <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.estimates.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>())</span>
<span id="cb4-31">            i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.rng.choice(candidates))</span>
<span id="cb4-32"></span>
<span id="cb4-33">        r <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.bandit.generate_reward(i)</span>
<span id="cb4-34">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.estimates[i] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.0</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> (<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.counts[i] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> (r <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.estimates[i])</span>
<span id="cb4-35"></span>
<span id="cb4-36">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> i, r</span></code></pre></div></div>
</div>
</section>
</section>
<section id="upper-confidence-bounds-ucb" class="level1">
<h1>Upper Confidence Bounds (UCB)</h1>
<p>Random exploration gives us the opportunity to try actions we know little about. However, pure randomness can also waste time re-exploring actions for which we already have strong evidence that they are suboptimal (bad luck still happens!). Two broad alternatives exist:</p>
<ol type="1">
<li><strong>Decay <img src="https://latex.codecogs.com/png.latex?%5Cepsilon"> over time</strong> in <img src="https://latex.codecogs.com/png.latex?%5Cepsilon">-greedy, making exploration less frequent, or</li>
<li><strong>Act optimistically for uncertain actions,</strong> favoring actions where our estimates are still unreliable.</li>
</ol>
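<p>As an illustration of the first option, one common choice (an assumption here, not something used elsewhere in this post) is to decay ε exponentially toward a small floor. The function name, parameter names, and default values below are hypothetical.</p>

```python
import math


def decayed_eps(t: int, eps0: float = 1.0, eps_min: float = 0.01, decay: float = 1e-3) -> float:
    """Hypothetical schedule: eps_t = max(eps_min, eps0 * exp(-decay * t))."""
    # the floor eps_min keeps a small amount of exploration alive forever
    return max(eps_min, eps0 * math.exp(-decay * t))


for t in (0, 1_000, 5_000):
    print(t, round(decayed_eps(t), 4))  # 1.0, then about 0.3679, then the 0.01 floor
```

<p>The floor matters: if ε decays all the way to zero, a single unlucky early estimate can lock the agent onto a suboptimal arm.</p>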
<p>The second idea leads to the class of <strong>Upper Confidence Bound (UCB)</strong> algorithms. The key intuition is simple:</p>
<blockquote class="blockquote">
<p>If we are unsure about an action’s value, we pretend it could be good until proven otherwise.</p>
</blockquote>
<p>More formally, UCB adds to each estimate <img src="https://latex.codecogs.com/png.latex?%5Chat%7BQ%7D_t(a)"> an <strong>uncertainty bonus</strong> <img src="https://latex.codecogs.com/png.latex?%5Chat%7BU%7D_t(a)"> that measures how unreliable the estimate still is, chosen so that their sum is an <strong>upper confidence bound</strong>. With high probability, the true value satisfies:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0AQ(a)%20%5Cleq%20%5Chat%7BQ%7D_t(a)%20+%20%5Chat%7BU%7D_t(a).%0A"></p>
<p>The uncertainty term <img src="https://latex.codecogs.com/png.latex?%5Chat%7BU%7D_t(a)"> must shrink as we gather more data. Thus, it is a decreasing function of <img src="https://latex.codecogs.com/png.latex?N_t(a)">: the more we pull an arm, the more confident we become, and the smaller its uncertainty bonus should be.</p>
<p>Given this, the UCB policy selects the action whose <em>optimistic estimate</em> is highest:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0Aa_%7Bt%7D%5E%7B%5Ctextrm%7BUCB%7D%7D%20=%20%5Carg%5Cmax_%7Ba%20%5Cin%20%5Cmathcal%7BA%7D%7D%20%5Cleft%5B%20%5Chat%7BQ%7D_%7Bt%7D(a)%20+%20%5Chat%7BU%7D_%7Bt%7D(a)%20%5Cright%5D%0A"></p>
<p>This ensures a natural balance: well-explored actions rely mostly on <img src="https://latex.codecogs.com/png.latex?%5Chat%7BQ%7D_%7Bt%7D(a)">, while poorly explored actions get an extra boost from their larger uncertainty term.</p>
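<p>The selection rule itself does not depend on how the bonus is computed, so it can be sketched with a pluggable bonus function. The bonus 1/√N used below is a hypothetical placeholder chosen only because it decreases with the count; it is not the bonus derived later in this post.</p>

```python
import numpy as np


def ucb_select(estimates, counts, bonus) -> int:
    """Pick argmax_a [Q_hat(a) + U_hat(a)], where U_hat(a) = bonus(N(a))."""
    scores = np.asarray(estimates, dtype=float) + bonus(np.asarray(counts, dtype=float))
    return int(np.argmax(scores))  # np.argmax returns the first index on ties


def shrinking_bonus(n: np.ndarray) -> np.ndarray:
    """Hypothetical placeholder bonus that decreases as the count grows."""
    return 1.0 / np.sqrt(n)


# arm 0 looks better but is well explored; arm 1 is slightly worse but barely tried
print(ucb_select(estimates=[0.60, 0.50], counts=[100, 4], bonus=shrinking_bonus))  # 1
```

<p>Scores here are 0.60 + 0.10 = 0.70 for arm 0 and 0.50 + 0.50 = 1.00 for arm 1, so the under-sampled arm wins. With equal counts, the arm with the higher estimate is chosen.</p>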
<section id="unified-definition" class="level2">
<h2 class="anchored" data-anchor-id="unified-definition">Unified Definition</h2>
<section id="information-state-1" class="level4">
<h4 class="anchored" data-anchor-id="information-state-1">Information State</h4>
<p>At time step <img src="https://latex.codecogs.com/png.latex?t">, a UCB algorithm maintains:</p>
<ul>
<li>empirical action-value estimates <img src="https://latex.codecogs.com/png.latex?%5Chat%7BQ%7D_t(a)">,</li>
<li>action counts <img src="https://latex.codecogs.com/png.latex?N_t(a)">.</li>
</ul>
<p>These quantities summarize the full interaction history.</p>
</section>
<section id="policy-1" class="level4">
<h4 class="anchored" data-anchor-id="policy-1">Policy</h4>
<p>UCB defines a deterministic policy (apart from ties, which can be broken at random):</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Cpi(a%7Ch_t)%20=%20%5Cbegin%7Bcases%7D%0A%20%20%20%201,%20&amp;%20a%20=%20%5Carg%5Cmax_%7Ba'%7D%20%5Cleft%5B%20%5Chat%7BQ%7D_t(a')%20+%20%5Chat%7BU%7D_t(a')%20%5Cright%5D,%20%5C%5C%0A%20%20%20%200,%20&amp;%20%5Ctext%7Botherwise%7D.%0A%5Cend%7Bcases%7D%0A"></p>
<p>Unlike <img src="https://latex.codecogs.com/png.latex?%5Cepsilon">-greedy, exploration is not injected explicitly. Instead, it emerges through optimism in the face of uncertainty.</p>
</section>
<section id="action-selection" class="level4">
<h4 class="anchored" data-anchor-id="action-selection">Action Selection</h4>
<p>At each time step, the selected action is:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0Aa_t%20=%20%5Carg%5Cmax_%7Ba%20%5Cin%20%5Cmathcal%7BA%7D%7D%20%5Cleft%5B%20%5Chat%7BQ%7D_t(a)%20+%20%5Chat%7BU%7D_t(a)%20%5Cright%5D.%0A"></p>
</section>
<section id="update-rule-1" class="level4">
<h4 class="anchored" data-anchor-id="update-rule-1">Update Rule</h4>
<p>After selecting action <img src="https://latex.codecogs.com/png.latex?a_t"> and observing reward <img src="https://latex.codecogs.com/png.latex?r_t">, the algorithm updates:</p>
<ul>
<li>the action count <img src="https://latex.codecogs.com/png.latex?N_t(a_t)">, incremented by one,</li>
<li>the empirical estimate <img src="https://latex.codecogs.com/png.latex?%5Chat%7BQ%7D_t(a_t)">, as the running sample mean of that action’s rewards.</li>
</ul>
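<p>The information state and its update fit in a few lines. The class name <code>UCBState</code> is hypothetical; it simply stores the counts and estimates and applies the incremental sample-mean update to the selected action only.</p>

```python
import numpy as np


class UCBState:
    """Hypothetical container for the UCB information state: N(a) and Q_hat(a)."""

    def __init__(self, k: int) -> None:
        self.counts = np.zeros(k, dtype=np.int64)      # N_t(a)
        self.estimates = np.zeros(k, dtype=np.float64)  # Q_hat_t(a)

    def update(self, action: int, reward: float) -> None:
        # increment N(a_t) first, then move Q_hat(a_t) toward r_t with step 1 / N(a_t)
        self.counts[action] += 1
        self.estimates[action] += (reward - self.estimates[action]) / self.counts[action]


state = UCBState(k=3)
for a, r in [(0, 1.0), (0, 0.0), (2, 1.0)]:
    state.update(a, r)
print(state.counts, state.estimates)  # counts [2, 0, 1], estimates [0.5, 0.0, 1.0]
```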
</section>
<section id="choosing-the-uncertainty-bound" class="level4">
<h4 class="anchored" data-anchor-id="choosing-the-uncertainty-bound">Choosing the Uncertainty Bound</h4>
<p>The remaining design choice is how to define the uncertainty term <img src="https://latex.codecogs.com/png.latex?%5Chat%7BU%7D_t(a)">. Different choices lead to different members of the UCB family, such as <strong>UCB1</strong>, which derives its bound from Hoeffding’s inequality.</p>
</section>
</section>
<section id="hoeffdings-inequality" class="level2">
<h2 class="anchored" data-anchor-id="hoeffdings-inequality">Hoeffding’s Inequality</h2>
<p>If we do not want to assume anything about the shape of the reward distribution (<em>e.g.,</em> Gaussian, exponential), we can rely on <strong>Hoeffding’s Inequality</strong>. This theorem applies to <strong>any bounded distribution</strong>.</p>
<p>A random variable is said to follow a <strong>bounded distribution</strong> if all its values lie within a fixed finite interval <img src="https://latex.codecogs.com/png.latex?%5Ba,b%5D">. In our case, Bernoulli rewards always lie in <img src="https://latex.codecogs.com/png.latex?%5B0,1%5D">, so the boundedness assumption is naturally satisfied.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Note
</div>
</div>
<div class="callout-body-container callout-body">
<p>Here are a few examples for intuition:</p>
<ul>
<li>A Bernoulli distribution is bounded on interval <img src="https://latex.codecogs.com/png.latex?%5B0,1%5D">.</li>
<li>A uniform distribution on a finite interval, <em>e.g.,</em> <img src="https://latex.codecogs.com/png.latex?%5B2,5%5D">, is bounded.</li>
<li>A Gaussian distribution is <em>not</em> bounded because of its infinite tails.</li>
</ul>
</div>
</div>
<section id="hoeffdings-inequality-informal-version" class="level4">
<h4 class="anchored" data-anchor-id="hoeffdings-inequality-informal-version">Hoeffding’s Inequality (informal version)</h4>
<p>Let <img src="https://latex.codecogs.com/png.latex?X_1,%20%5Cldots,%20X_t"> be i.i.d. (independent and identically distributed) random variables, all bounded in the interval <img src="https://latex.codecogs.com/png.latex?%5B0,1%5D">. The sample mean is</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Coverline%7BX%7D_t%20=%20%5Cfrac%7B1%7D%7Bt%7D%20%5Csum_%7B%5Ctau%20=%201%7D%5E%7Bt%7D%20X_%7B%5Ctau%7D.%0A"></p>
<p>Then for any <img src="https://latex.codecogs.com/png.latex?u%20%5Cgt%200">, Hoeffding’s inequality states:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Cmathbb%7BP%7D%5Cleft%5B%5Cmathbb%7BE%7D%5BX%5D%20%5Cgt%20%5Coverline%7BX%7D_%7Bt%7D%20+%20u%20%5Cright%5D%20%5Cleq%20%5Cmathrm%7Be%7D%5E%7B-2tu%5E2%7D.%0A"></p>
<p>This inequality bounds the probability that the true mean exceeds the empirical mean by more than <img src="https://latex.codecogs.com/png.latex?u">.</p>
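<p>A quick Monte Carlo experiment can make the inequality tangible. The sketch below assumes Bernoulli(0.5) samples, t = 50 and u = 0.1 (illustrative values only), repeats the experiment many times, and compares how often the event "true mean exceeds sample mean plus u" occurs with Hoeffding’s bound.</p>

```python
import numpy as np

rng = np.random.default_rng(0)

mu, t, u = 0.5, 50, 0.1  # true mean, sample size, margin (illustrative values)
n_trials = 20_000

# each row is one experiment of t Bernoulli(mu) draws, all bounded in [0, 1]
samples = rng.random((n_trials, t)) < mu
sample_means = samples.mean(axis=1)

empirical = np.mean(mu > sample_means + u)  # frequency of the event E[X] > X_bar + u
bound = np.exp(-2 * t * u**2)               # Hoeffding's bound e^(-2 t u^2)
print(f"empirical {empirical:.4f}, bound {bound:.4f}")
```

<p>The empirical frequency comes out far below the bound. Hoeffding’s inequality is valid for every bounded distribution, and that generality is exactly what makes it loose for any particular one.</p>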
</section>
<section id="applying-hoeffdings-inequality-to-bandit-rewards" class="level4">
<h4 class="anchored" data-anchor-id="applying-hoeffdings-inequality-to-bandit-rewards">Applying Hoeffding’s Inequality to Bandit Rewards</h4>
<p>To apply this result to the multi-armed bandit setting, we observe that <strong>each fixed action <img src="https://latex.codecogs.com/png.latex?a"> defines its own random reward-generating process</strong>. Every time we select action <img src="https://latex.codecogs.com/png.latex?a">, we obtain a reward drawn independently from the same bounded distribution. Therefore, Hoeffding’s inequality applies directly to each arm.</p>
<p>For a fixed target action <img src="https://latex.codecogs.com/png.latex?a">, define:</p>
<ul>
<li><img src="https://latex.codecogs.com/png.latex?r_%7B%5Ctau%7D(a)"> as the reward random variable,</li>
<li><img src="https://latex.codecogs.com/png.latex?Q(a)"> as the true mean reward,</li>
<li><img src="https://latex.codecogs.com/png.latex?%5Chat%7BQ%7D_%7Bt%7D(a)"> as the sample mean reward,</li>
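<li>and <img src="https://latex.codecogs.com/png.latex?u%20=%20U_%7Bt%7D(a)"> as the confidence width, so that <img src="https://latex.codecogs.com/png.latex?%5Chat%7BQ%7D_%7Bt%7D(a)%20+%20U_%7Bt%7D(a)"> is the upper confidence bound.</li>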
</ul>
<p>By directly identifying Hoeffding’s variables with the bandit quantities:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0AX_%7B%5Ctau%7D%20%5Cleftrightarrow%20r_%5Ctau(a),%5Cquad%0A%5Cmathbb%7BE%7D%5BX%5D%20%5Cleftrightarrow%20Q(a),%5Cquad%0A%5Coverline%7BX%7D_%7Bt%7D%20%5Cleftrightarrow%20%5Chat%7BQ%7D_%7Bt%7D(a),%5Cquad%0At%20%5Cleftrightarrow%20N_%7Bt%7D(a)%0A"></p>
<p>we obtain:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Cmathbb%7BP%7D%20%5Cleft%5B%20Q(a)%20%5Cgt%20%5Chat%7BQ%7D_%7Bt%7D(a)%20+%20U_%7Bt%7D(a)%20%5Cright%5D%20%5Cleq%20%5Cmathrm%7Be%7D%5E%7B-2%20N_%7Bt%7D(a)%20U_%7Bt%7D(a)%5E2%7D.%0A"></p>
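<p>This bounds the probability that the true mean reward of an action exceeds its empirical estimate by more than <img src="https://latex.codecogs.com/png.latex?U_%7Bt%7D(a)">, and the bound shrinks exponentially as the action is sampled more often.</p>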
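<p>Because the bound depends on each arm’s own count, the same confidence width means very different things for different arms. A small sketch, assuming four hypothetical arms with counts 1, 10, 100 and 1000 and a shared width of 0.1, evaluates the right-hand side per arm.</p>

```python
import numpy as np

counts = np.array([1, 10, 100, 1000])  # N_t(a) for four hypothetical arms
u = 0.1                                # the same confidence width for every arm

# Hoeffding's bound on P[Q(a) > Q_hat(a) + u], evaluated per arm
tail_bounds = np.exp(-2 * counts * u**2)
for n, b in zip(counts, tail_bounds):
    print(f"N = {n:4d}: bound = {b:.3g}")
```

<p>For a single pull the bound is close to 1 and says almost nothing, while after 1000 pulls the same width is violated with negligible probability.</p>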
</section>
<section id="choosing-the-upper-confidence-bound" class="level4">
<h4 class="anchored" data-anchor-id="choosing-the-upper-confidence-bound">Choosing the Upper Confidence Bound</h4>
<p>We want to choose the confidence width so that the probability of underestimating the true mean is very small. To keep this probability below a small threshold <img src="https://latex.codecogs.com/png.latex?p">, we set the right-hand side of the bound equal to <img src="https://latex.codecogs.com/png.latex?p">:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Cmathrm%7Be%7D%5E%7B-2N_%7Bt%7D(a)U_%7Bt%7D(a)%5E2%7D%20=%20p.%0A"></p>
<p>Solving for <img src="https://latex.codecogs.com/png.latex?U_%7Bt%7D(a)">, we obtain:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0AU_%7Bt%7D(a)%20=%20%5Csqrt%7B%5Cfrac%7B-%5Cln%7Bp%7D%7D%7B2N_%7Bt%7D(a)%7D%7D.%0A"></p>
<p>This expression defines how much optimism we should add to the empirical estimate based on how many times the action has been sampled.</p>
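<p>The solved expression can be checked numerically: plugging the width back into Hoeffding’s bound should return the chosen threshold. The helper name <code>hoeffding_bonus</code> and the threshold p = 0.05 are illustrative choices.</p>

```python
import math


def hoeffding_bonus(p: float, n: int) -> float:
    """Illustrative helper: width U = sqrt(-ln(p) / (2 n)) for n samples."""
    return math.sqrt(-math.log(p) / (2 * n))


p = 0.05
for n in (1, 10, 100, 1000):
    u = hoeffding_bonus(p, n)
    # exp(-2 n U^2) recovers the threshold p for every n
    print(n, round(u, 4), round(math.exp(-2 * n * u**2), 4))
```

<p>The width shrinks like 1/√N: pulling an arm 100 times more often makes its bonus 10 times smaller.</p>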
</section>
</section>
<section id="ucb1" class="level2">
<h2 class="anchored" data-anchor-id="ucb1">UCB1</h2>
<p>From the previous section, we obtained a general form of the confidence bound:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0AU_%7Bt%7D(a)%20=%20%5Csqrt%7B%5Cfrac%7B-%5Cln%7Bp%7D%7D%7B2N_%7Bt%7D(a)%7D%7D.%0A"></p>
<p>The remaining question is how to choose the threshold probability <img src="https://latex.codecogs.com/png.latex?p">. Intuitively, as time goes on and we collect more data, we want our confidence bounds to become <strong>tighter</strong> and failures to become increasingly unlikely. A simple and effective heuristic is to let the failure probability <strong>decrease with time</strong>.</p>
<p>A common choice is:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0Ap%20=%20t%5E%7B-4%7D,%0A"></p>
<p>which makes the failure probabilities summable over time and underpins UCB1’s logarithmic regret guarantee.</p>
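<p>Substituting this into the confidence bound and using <img src="https://latex.codecogs.com/png.latex?-%5Cln%7Bp%7D%20=%20-%5Cln%7Bt%5E%7B-4%7D%7D%20=%204%5Cln%7Bt%7D"> gives <img src="https://latex.codecogs.com/png.latex?%5Csqrt%7B%5Cfrac%7B4%5Cln%7Bt%7D%7D%7B2N_%7Bt%7D(a)%7D%7D">, which simplifies to:</p>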
<p><img src="https://latex.codecogs.com/png.latex?%0AU_%7Bt%7D(a)%20=%20%5Csqrt%7B%5Cfrac%7B2%5Cln%7Bt%7D%7D%7BN_%7Bt%7D(a)%7D%7D.%0A"></p>
<p>This yields the classic <strong>UCB1</strong> algorithm.</p>
<p>At each time step, UCB1 selects the action that maximizes the optimistic estimate of the reward:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0Aa_%7Bt%7D%5E%5Ctextrm%7BUCB1%7D%20=%20%5Carg%5Cmax_%7Ba%20%5Cin%20%5Cmathcal%7BA%7D%7D%20%5Cleft%5B%20%5Chat%7BQ%7D_t(a)%20+%20%5Csqrt%7B%5Cfrac%7B2%5Cln%7Bt%7D%7D%7BN_%7Bt%7D(a)%7D%7D%20%5Cright%5D.%0A"></p>
<p>Here:</p>
<ul>
<li><img src="https://latex.codecogs.com/png.latex?%5Chat%7BQ%7D_%7Bt%7D(a)"> promotes <strong>exploitation</strong>,</li>
<li>the square-root term promotes <strong>exploration</strong>, shrinking as <img src="https://latex.codecogs.com/png.latex?N_%7Bt%7D(a)"> increases,</li>
<li>and the <img src="https://latex.codecogs.com/png.latex?%5Cln%7Bt%7D"> term grows slowly with time, so the bonus of an action that is not being selected keeps increasing until the action is eventually revisited.</li>
</ul>
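<p>To see the two terms interact, the sketch below computes UCB1 scores for two hypothetical arms at t = 100. It assumes every arm has already been pulled at least once, so the count in the denominator is never zero; the function name <code>ucb1_scores</code> and the numbers are illustrative.</p>

```python
import numpy as np


def ucb1_scores(estimates, counts, t: int) -> np.ndarray:
    """UCB1 score Q_hat(a) + sqrt(2 ln t / N_t(a)); assumes every N_t(a) >= 1."""
    estimates = np.asarray(estimates, dtype=float)
    counts = np.asarray(counts, dtype=float)
    return estimates + np.sqrt(2.0 * np.log(t) / counts)


t = 100
scores = ucb1_scores(estimates=[0.70, 0.55], counts=[95, 5], t=t)
print(scores.round(3), int(np.argmax(scores)))  # the rarely pulled arm 1 is selected
```

<p>Arm 1 has the lower estimate but only 5 pulls, so its exploration bonus (about 1.36) outweighs the gap. If both arms had 50 pulls, the bonuses would be equal and the higher estimate would win.</p>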
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Note
</div>
</div>
<div class="callout-body-container callout-body">
<p><strong>Why this works (one-sentence intuition)</strong></p>
<p>UCB1 always chooses the action with the <strong>highest plausible reward</strong>, where “plausible” is defined by a confidence interval that shrinks as evidence accumulates.</p>
</div>
</div>
<div id="cell-17" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">class</span> UCB1(BaseSolver):</span>
<span id="cb5-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, bandit: BaseBandit, init_proba: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.0</span>, seed: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>):</span>
<span id="cb5-3">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(bandit)</span>
<span id="cb5-4">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.t <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># number of time steps</span></span>
<span id="cb5-5">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.estimates <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.full(shape<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.bandit.k, fill_value<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>init_proba, dtype<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>np.float64)</span>
<span id="cb5-6">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.rng <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.random.default_rng(seed)</span>
<span id="cb5-7"></span>
<span id="cb5-8">    <span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">@property</span></span>
<span id="cb5-9">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> estimated_probas(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> npt.NDArray[np.float64]:</span>
<span id="cb5-10">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.estimates</span>
<span id="cb5-11"></span>
<span id="cb5-12">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> run_one_step(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">tuple</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>]:</span>
<span id="cb5-13">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.t <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span></span>
<span id="cb5-14"></span>
<span id="cb5-15">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># UCB1 score; adding 1 to the counts avoids division by zero for arms not yet pulled.</span></span>
<span id="cb5-16">        ucb <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.estimates <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> np.sqrt(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> np.log(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.t) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.counts))</span>
<span id="cb5-17"></span>
<span id="cb5-18">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># tie-breaking</span></span>
<span id="cb5-19">        candidates <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.flatnonzero(ucb <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> ucb.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>())</span>
<span id="cb5-20">        i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.rng.choice(candidates))</span>
<span id="cb5-21"></span>
<span id="cb5-22">        r <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.bandit.generate_reward(i)</span>
<span id="cb5-23"></span>
<span id="cb5-24">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.estimates[i] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.0</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> (<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.counts[i] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> (r <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.estimates[i])</span>
<span id="cb5-25"></span>
<span id="cb5-26">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> i, r</span></code></pre></div></div>
</div>
</section>
</section>
<section id="bayesian-ucb" class="level1">
<h1>Bayesian UCB</h1>
<p>Bayesian UCB is an instance of the UCB principle in which uncertainty is quantified using the posterior distribution of the reward model.</p>
<p>In the UCB and UCB1 algorithms, we <strong>do not assume any specific form of the reward distribution</strong>. Because of this, we rely on <strong>Hoeffding’s inequality</strong>, which provides a very general but also somewhat loose confidence bound that works for <em>any</em> bounded distribution.</p>
<p>However, in some applications we may have <strong>prior knowledge</strong> about how rewards are distributed. When such information is available, we can replace Hoeffding’s generic bound with a <strong>distribution-aware confidence bound</strong>, leading to a more data-efficient strategy. This idea gives rise to <strong>Bayesian UCB</strong>.</p>
<section id="using-distributional-assumptions" class="level2">
<h2 class="anchored" data-anchor-id="using-distributional-assumptions">Using Distributional Assumptions</h2>
<p>For example, suppose we model the rewards of each slot machine with a <strong>Gaussian likelihood</strong> and place a Gaussian prior on its mean. Because this pair is conjugate, it induces a <strong>Gaussian posterior distribution</strong> over the mean reward of each action. After observing rewards for a given action <img src="https://latex.codecogs.com/png.latex?a">, the posterior is characterized by:</p>
<ul>
<li>a posterior mean <img src="https://latex.codecogs.com/png.latex?%5Cmu_%7Bt%7D(a)">,</li>
<li>and a posterior standard deviation <img src="https://latex.codecogs.com/png.latex?%5Csigma_%7Bt%7D(a)">.</li>
</ul>
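<p>In this case, a natural choice is to place the upper confidence bound at an <strong>upper quantile of the posterior</strong>, for instance the upper end of a central 95% credible interval. The corresponding exploration term is:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Chat%7BU%7D_%7Bt%7D(a)%20=%20c%5Csigma_%7Bt%7D(a),%0A"></p>
<p>where <img src="https://latex.codecogs.com/png.latex?c%20%5Capprox%202"> (more precisely 1.96) is the 97.5% quantile of the standard normal distribution, so <img src="https://latex.codecogs.com/png.latex?%5Cmu_%7Bt%7D(a)%20+%20c%5Csigma_%7Bt%7D(a)"> is the upper end of a central 95% credible interval for a Gaussian posterior.</p>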
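<p>The constant comes straight from the standard normal quantile function. A quick check with Python's standard-library <code>statistics.NormalDist</code> (an illustration with made-up posterior values, not code from this post's experiments):</p>

```python
from statistics import NormalDist

# Upper end of a central 95% interval = 97.5% quantile of N(0, 1).
c = NormalDist().inv_cdf(0.975)
print(f"c = {c:.4f}")  # c = 1.9600

# For a Gaussian posterior N(mu, sigma^2), mu + c * sigma is exactly that quantile.
mu, sigma = 0.4, 0.05  # illustrative posterior mean and standard deviation
upper = mu + c * sigma
assert abs(upper - NormalDist(mu, sigma).inv_cdf(0.975)) < 1e-12
print(f"upper bound = {upper:.4f}")  # upper bound = 0.4980
```

<p>Rounding 1.96 to 2 is common in practice; in Bayesian UCB, <code>c</code> is simply a tunable optimism parameter.</p>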
<p>The Bayesian UCB action selection rule then becomes:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0Aa_%7Bt%7D%5E%7BBayes%5Ctext%7B-%7DUCB%7D%20=%20%5Carg%5Cmax_%7Ba%20%5Cin%20%5Cmathcal%7BA%7D%7D%5Cleft%5B%20%5Cmu_%7Bt%7D(a)%20+%20c%20%5Csigma_%7Bt%7D(a)%20%5Cright%5D.%0A"></p>
<p>Interpretation:</p>
<ul>
<li><img src="https://latex.codecogs.com/png.latex?%5Cmu_%7Bt%7D(a)"> plays the role of <strong>exploitation</strong> (current best estimate),</li>
<li><img src="https://latex.codecogs.com/png.latex?%5Csigma_%7Bt%7D(a)"> captures <strong>uncertainty</strong> (how much we still do not know),</li>
<li>the constant <img src="https://latex.codecogs.com/png.latex?c"> controls how optimistic we are.</li>
</ul>
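<p>Compared to UCB1, whose exploration bonus depends only on the visit count <img src="https://latex.codecogs.com/png.latex?N_%7Bt%7D(a)"> and the time step <img src="https://latex.codecogs.com/png.latex?t">, Bayesian UCB uses the <strong>full posterior uncertainty</strong>, which often leads to <strong>faster learning</strong> when the model assumptions are correct. Because the benchmark in this post uses Bernoulli rewards, the implementation at the end of this section uses the conjugate Beta posterior instead of a Gaussian one: <img src="https://latex.codecogs.com/png.latex?%5Cmu_%7Bt%7D(a)"> is the Beta posterior mean and <img src="https://latex.codecogs.com/png.latex?%5Csigma_%7Bt%7D(a)"> is its standard deviation.</p>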
</section>
<section id="key-difference-from-ucb1" class="level2">
<h2 class="anchored" data-anchor-id="key-difference-from-ucb1">Key Difference from UCB1</h2>
<table class="caption-top table">
<thead>
<tr class="header">
<th>UCB1</th>
<th>Bayesian UCB</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>No distributional assumption</td>
<td>Explicit reward model</td>
</tr>
<tr class="even">
<td>Hoeffding bound</td>
<td>Posterior quantile</td>
</tr>
<tr class="odd">
<td>Worst-case guarantees</td>
<td>Model-dependent efficiency</td>
</tr>
</tbody>
</table>
<div id="cell-19" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">class</span> BayesianUCB(BaseSolver):</span>
<span id="cb6-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(</span>
<span id="cb6-3">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, bandit: BaseBandit, c: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, init_a: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, init_b: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, seed: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span></span>
<span id="cb6-4">    ) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb6-5">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(bandit)</span>
<span id="cb6-6"></span>
<span id="cb6-7">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.c <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> c</span>
<span id="cb6-8">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._as <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.full(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.bandit.k, fill_value<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>init_a, dtype<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>np.float64)</span>
<span id="cb6-9">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._bs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.full(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.bandit.k, fill_value<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>init_b, dtype<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>np.float64)</span>
<span id="cb6-10"></span>
<span id="cb6-11">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.t <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span>
<span id="cb6-12">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.rng <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.random.default_rng(seed)</span>
<span id="cb6-13"></span>
<span id="cb6-14">    <span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">@property</span></span>
<span id="cb6-15">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> estimated_probas(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> npt.NDArray[np.float64]:</span>
<span id="cb6-16">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._as <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> (<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._as <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._bs)</span>
<span id="cb6-17"></span>
<span id="cb6-18">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> run_one_step(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">tuple</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>]:</span>
<span id="cb6-19">        <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> scipy.stats <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> beta</span>
<span id="cb6-20"></span>
<span id="cb6-21">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.t <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span></span>
<span id="cb6-22"></span>
<span id="cb6-23">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># ensure each arm is tried at least once</span></span>
<span id="cb6-24">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.t <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.bandit.k:</span>
<span id="cb6-25">            i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.t <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span></span>
<span id="cb6-26">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span>:</span>
<span id="cb6-27">            mu <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._as <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> (<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._as <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._bs)  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># posterior mean</span></span>
<span id="cb6-28">            sigma <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> beta.std(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._as, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._bs)  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># posterior std Beta(alpha, beta)</span></span>
<span id="cb6-29">            confidence <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> mu <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.c <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> sigma</span>
<span id="cb6-30"></span>
<span id="cb6-31">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># tie-breaking</span></span>
<span id="cb6-32">            candidates <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.flatnonzero(confidence <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> confidence.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>())</span>
<span id="cb6-33">            i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.rng.choice(candidates))</span>
<span id="cb6-34"></span>
<span id="cb6-35">        r <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.bandit.generate_reward(i)</span>
<span id="cb6-36"></span>
<span id="cb6-37">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># update Beta posterior for Bernoulli reward</span></span>
<span id="cb6-38">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._as[i] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> r  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># successes</span></span>
<span id="cb6-39">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._bs[i] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> r  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># failures</span></span>
<span id="cb6-40"></span>
<span id="cb6-41">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> i, r</span></code></pre></div></div>
</div>
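<p>To make the table's "Hoeffding bound vs. posterior quantile" row concrete, here is a small sketch (an illustration, not part of the post's benchmark) that compares the standard UCB1 bonus <code>sqrt(2 ln t / n)</code> with <code>c</code> times the standard deviation of the Beta posterior under the same uniform Beta(1, 1) prior as the class above. The helper names <code>ucb1_bonus</code> and <code>beta_bonus</code> are hypothetical.</p>

```python
import math


def ucb1_bonus(t: int, n: int) -> float:
    """Standard UCB1 exploration bonus sqrt(2 ln t / n), derived from Hoeffding's inequality."""
    return math.sqrt(2.0 * math.log(t) / n)


def beta_bonus(successes: int, n: int, c: float = 2.0, a0: float = 1.0, b0: float = 1.0) -> float:
    """c times the standard deviation of the Beta(a0 + successes, b0 + failures) posterior."""
    a = a0 + successes
    b = b0 + (n - successes)
    var = a * b / ((a + b) ** 2 * (a + b + 1))
    return c * math.sqrt(var)


t = 10_000  # total number of steps so far (illustrative)
for n in (10, 100, 1000):  # number of pulls of this arm
    for p_hat in (0.05, 0.5):  # observed success rate of this arm
        s = round(p_hat * n)
        print(f"n={n:5d} p_hat={p_hat:.2f}  UCB1={ucb1_bonus(t, n):.3f}  Bayes={beta_bonus(s, n):.3f}")
```

<p>For these illustrative numbers the posterior bonus is much narrower, and narrowest when the observed success rate is near 0 or 1. Note that the Hoeffding bonus also grows with ln t while this posterior bonus does not, so the comparison only shows the width of the two bounds for the same number of pulls.</p>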
</section>
</section>
<section id="thompson-sampling" class="level1">
<h1>Thompson Sampling</h1>
<p>Thompson Sampling defines a stochastic policy that selects actions in proportion to their posterior probability of being optimal.</p>
<p>Bayesian UCB still follows the same basic philosophy as UCB1. It builds an <strong>explicit confidence bound</strong> and then acts optimistically with respect to that bound. Thompson Sampling takes a more direct and fully Bayesian approach. Instead of computing an upper bound, it <strong>samples directly from the posterior distribution and acts on that sample</strong>.</p>
<p>The idea is remarkably simple:</p>
<blockquote class="blockquote">
<p>Instead of asking <em>“Which action could be best?”</em>, Thompson Sampling asks <strong>“Which action is most likely to be the best right now?”</strong></p>
</blockquote>
<p>At each time step, we treat the unknown reward probability of each action as a random variable and maintain a <strong>posterior distribution</strong> over its value. Then:</p>
<ol type="1">
<li>We <strong>sample one plausible value of the reward probability</strong> from the posterior of each action.</li>
<li>We <strong>select the action with the highest sampled value</strong>.</li>
<li>We <strong>observe the reward and update the posterior</strong>.</li>
</ol>
<p>This naturally balances exploration and exploitation:</p>
<ul>
<li>actions with high uncertainty are more likely to occasionally produce large samples → <strong>exploration</strong>,</li>
<li>actions with high posterior mean consistently produce large samples → <strong>exploitation</strong>.</li>
</ul>
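<p>The exploration mechanism in the list above can be checked with a quick Monte Carlo sketch. The posterior counts below are illustrative assumptions: one arm has been pulled 100 times (60 wins), one only twice (1 win), and one 100 times with a clearly worse record (30 wins). Each comparison estimates how often that arm's sample beats the well-known arm's sample, which is how often Thompson Sampling would prefer it in a head-to-head choice.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n_draws = 100_000

# Posteriors under a uniform Beta(1, 1) prior (illustrative counts).
well_known = rng.beta(1 + 60, 1 + 40, size=n_draws)  # 60 wins / 40 losses, mean ~0.60
uncertain = rng.beta(1 + 1, 1 + 1, size=n_draws)     # 1 win / 1 loss, mean 0.50
known_bad = rng.beta(1 + 30, 1 + 70, size=n_draws)   # 30 wins / 70 losses, mean ~0.30

# Fraction of draws in which the other arm's sample beats the well-known arm's sample.
p_explore = float(np.mean(uncertain > well_known))
p_bad = float(np.mean(known_bad > well_known))
print(f"uncertain arm wins {p_explore:.1%} of draws")
print(f"well-known worse arm wins {p_bad:.3%} of draws")
```

<p>The barely tried arm still wins roughly a third of the comparisons despite its lower mean, while the arm that is confidently worse almost never does.</p>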
<p>No explicit exploration parameter or confidence bound is required.</p>
<section id="thompson-sampling-for-bernoulli-bandits-beta-bernoulli" class="level2">
<h2 class="anchored" data-anchor-id="thompson-sampling-for-bernoulli-bandits-beta-bernoulli">Thompson Sampling for Bernoulli Bandits (Beta-Bernoulli)</h2>
<p>In the Bernoulli bandit setting, the reward of each action is either <img src="https://latex.codecogs.com/png.latex?0"> or <img src="https://latex.codecogs.com/png.latex?1">. The conjugate prior for the Bernoulli distribution is the <strong>Beta distribution</strong>, so we model the unknown success probability <img src="https://latex.codecogs.com/png.latex?%5Ctheta_%7Ba%7D"> of each action as:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Ctheta_%7Ba%7D%20%5Csim%20%5Ctextrm%7BBeta%7D(%5Calpha_a,%20%5Cbeta_a),%0A"></p>
<p>where:</p>
<ul>
<li><img src="https://latex.codecogs.com/png.latex?%5Calpha_a"> accumulates observed successes (on top of its prior value),</li>
<li><img src="https://latex.codecogs.com/png.latex?%5Cbeta_a"> accumulates observed failures (on top of its prior value).</li>
</ul>
<p>Initially, we typically use the uniform, non-informative prior <img src="https://latex.codecogs.com/png.latex?%5Ctextrm%7BBeta%7D(1,%201)">:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Calpha_a%20=%201,%20%5Cquad%20%5Cbeta_a%20=%201.%0A"></p>
<section id="action-selection-1" class="level4">
<h4 class="anchored" data-anchor-id="action-selection-1">Action Selection</h4>
<p>At time <img src="https://latex.codecogs.com/png.latex?t">, Thompson Sampling draws:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Ctilde%7B%5Ctheta%7D_a%20%5Csim%20%5Ctextrm%7BBeta%7D(%5Calpha_a,%20%5Cbeta_a)%20%5Cquad%20%5Ctextrm%7Bfor%20each%20%7D%20a%20%5Cin%20%5Cmathcal%7BA%7D,%0A"></p>
<p>and selects:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0Aa_t%5E%5Ctextrm%7BTS%7D%20=%20%5Carg%5Cmax_%7Ba%20%5Cin%20%5Cmathcal%7BA%7D%7D%20%5Ctilde%7B%5Ctheta%7D_a.%0A"></p>
<p>That is, we draw one plausible value for each arm and act greedily with respect to this randomly sampled world.</p>
</section>
<section id="posterior-update" class="level4">
<h4 class="anchored" data-anchor-id="posterior-update">Posterior Update</h4>
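<p>After observing the reward <img src="https://latex.codecogs.com/png.latex?r_t%20%5Cin%20%5C%7B0,1%5C%7D"> for the selected action <img src="https://latex.codecogs.com/png.latex?a%20=%20a_t">, we update its posterior:</p>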
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Calpha_a%20%5Cleftarrow%20%5Calpha_a%20+%20r_t,%20%5Cquad%20%5Cbeta_a%20%5Cleftarrow%20%5Cbeta_a%20+%20(1%20-%20r_t).%0A"></p>
<p>This update is exact Bayesian inference for the Beta-Bernoulli model.</p>
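<p>As a worked example of this update rule (with illustrative rewards, not taken from the post's experiments), starting from the uniform prior and observing the rewards 1, 0, 1, 1 gives a Beta(4, 2) posterior with mean 4/6 ≈ 0.667:</p>

```python
# Uniform Beta(1, 1) prior for a single arm.
alpha_a, beta_a = 1, 1

# Observe four Bernoulli rewards and apply the conjugate update after each one.
rewards = [1, 0, 1, 1]
for r in rewards:
    alpha_a += r      # a success increments alpha_a
    beta_a += 1 - r   # a failure increments beta_a

print(alpha_a, beta_a)                        # 4 2
print(f"{alpha_a / (alpha_a + beta_a):.3f}")  # posterior mean: 0.667

# The order of observations does not matter: the batch update gives the same posterior.
assert (alpha_a, beta_a) == (1 + sum(rewards), 1 + len(rewards) - sum(rewards))
```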
</section>
<section id="why-thompson-sampling-works-so-well" class="level4">
<h4 class="anchored" data-anchor-id="why-thompson-sampling-works-so-well">Why Thompson Sampling Works so Well</h4>
<p>Thompson Sampling does not separate exploration from exploitation. Instead, exploration <strong>emerges naturally from uncertainty</strong> in the posterior:</p>
<ul>
<li>If an action is well understood, its posterior is sharp (little randomness).</li>
<li>If an action is uncertain, its posterior is wide (occasional optimistic samples).</li>
</ul>
<p>In contrast:</p>
<ul>
<li><img src="https://latex.codecogs.com/png.latex?%5Cepsilon">-greedy explores <strong>blindly</strong>,</li>
<li>UCB explores via <strong>deterministic optimism</strong>,</li>
<li>Thompson Sampling explores via <strong>probabilistic belief</strong>.</li>
</ul>
</section>
<section id="relationship-to-bayesian-ucb" class="level4">
<h4 class="anchored" data-anchor-id="relationship-to-bayesian-ucb">Relationship to Bayesian UCB</h4>
<p>Bayesian UCB selects actions using:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Cmu_t(a)%20+%20c%20%5Csigma_t(a),%0A"></p>
<p>which corresponds to choosing a fixed <strong>upper quantile</strong> of the posterior.</p>
<p>Thompson Sampling instead <strong>draws a random quantile at every time step</strong>. In this sense:</p>
<blockquote class="blockquote">
<p>Bayesian UCB is optimistic; Thompson Sampling is probabilistic.</p>
</blockquote>
<p>Both use Bayesian posteriors, but Thompson Sampling avoids manually choosing confidence levels.</p>
<div id="cell-21" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">class</span> ThompsonSampling(BaseSolver):</span>
<span id="cb7-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, bandit: BaseBandit, init_a: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, init_b: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, seed: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb7-3">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(bandit)</span>
<span id="cb7-4"></span>
<span id="cb7-5">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._as <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.full(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.bandit.k, fill_value<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>init_a, dtype<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>np.float64)</span>
<span id="cb7-6">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._bs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.full(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.bandit.k, fill_value<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>init_b, dtype<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>np.float64)</span>
<span id="cb7-7"></span>
<span id="cb7-8">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.rng <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.random.default_rng(seed)</span>
<span id="cb7-9"></span>
<span id="cb7-10">    <span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">@property</span></span>
<span id="cb7-11">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> estimated_probas(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> npt.NDArray[np.float64]:</span>
<span id="cb7-12">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._as <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> (<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._as <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._bs)</span>
<span id="cb7-13"></span>
<span id="cb7-14">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> run_one_step(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">tuple</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>]:</span>
<span id="cb7-15">        samples <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.rng.beta(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._as, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._bs)</span>
<span id="cb7-16"></span>
<span id="cb7-17">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># tie-breaking</span></span>
<span id="cb7-18">        candidates <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.flatnonzero(samples <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> samples.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>())</span>
<span id="cb7-19">        i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.rng.choice(candidates))</span>
<span id="cb7-20"></span>
<span id="cb7-21">        r <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.bandit.generate_reward(i)</span>
<span id="cb7-22"></span>
<span id="cb7-23">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._as[i] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> r</span>
<span id="cb7-24">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._bs[i] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> r</span>
<span id="cb7-25"></span>
<span id="cb7-26">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> i, r</span></code></pre></div></div>
</div>
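<p>The "random quantile" view from the comparison with Bayesian UCB can be checked with a small sketch. For simplicity it uses an illustrative Gaussian posterior through Python's <code>statistics.NormalDist</code> rather than the Beta posterior of the class above (for a Beta posterior, the inverse CDF would come from something like <code>scipy.stats.beta.ppf</code>). Inverse-CDF sampling turns a uniform quantile level into a posterior draw, so Thompson Sampling's draws reproduce the posterior's mean and spread, and only about 2.5% of them exceed the fixed 97.5% quantile that Bayesian UCB would use.</p>

```python
import random
from statistics import NormalDist, fmean, stdev

# Illustrative Gaussian posterior for one arm's mean reward.
posterior = NormalDist(mu=0.6, sigma=0.05)
rng = random.Random(0)

# Bayesian UCB: the same fixed upper quantile at every step (about mu + 1.96 * sigma).
ucb_value = posterior.inv_cdf(0.975)


def thompson_draw() -> float:
    """Pick a random quantile level u ~ Uniform(0, 1) and map it through the inverse CDF."""
    u = rng.uniform(1e-12, 1.0 - 1e-12)  # inv_cdf requires 0 < u < 1
    return posterior.inv_cdf(u)


draws = [thompson_draw() for _ in range(20_000)]
share_above = sum(d > ucb_value for d in draws) / len(draws)

print(f"Bayesian UCB value: {ucb_value:.3f}")
print(f"Thompson draws: mean={fmean(draws):.3f}, std={stdev(draws):.3f}")
print(f"share of draws above the UCB value: {share_above:.3f}")
```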
</section>
</section>
</section>
<section id="benchmark" class="level1">
<h1>Benchmark</h1>
<div id="cell-23" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb8" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1">N_STEPS <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10_000</span></span>
<span id="cb8-2">SEED <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bn" style="color: #AD0000;
background-color: null;
font-style: inherit;">0x42</span></span>
<span id="cb8-3">K <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span></span>
<span id="cb8-4"></span>
<span id="cb8-5">np.random.seed(SEED)</span>
<span id="cb8-6">rng <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.random.default_rng(SEED)</span>
<span id="cb8-7"></span>
<span id="cb8-8"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Probabilities {0.0, 0.1, ..., 0.9} then shuffle them</span></span>
<span id="cb8-9"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># probas = rng.uniform(0, 1, size=K)</span></span>
<span id="cb8-10">probas <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.linspace(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, K, endpoint<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>, dtype<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>np.float64)</span>
<span id="cb8-11"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(probas)</span>
<span id="cb8-12">rng.shuffle(probas)</span>
<span id="cb8-13"></span>
<span id="cb8-14">bbandit <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> BernoulliBandit(k<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>K, probas<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>probas, seed<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>SEED)</span>
<span id="cb8-15">epsgreedy <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> EpsilonGreedy(bbandit, eps<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.01</span>, seed<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>SEED)</span>
<span id="cb8-16">epsgreedy.run(N_STEPS)</span>
<span id="cb8-17"></span>
<span id="cb8-18"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Random is a special case of EpsilogGreedy</span></span>
<span id="cb8-19"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># bbandit = BernoulliBandit(k=K, probas=probas, seed=SEED)</span></span>
<span id="cb8-20"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># random = EpsilonGreedy(bbandit, eps=1.0, seed=SEED)</span></span>
<span id="cb8-21"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># random.run(N_STEPS)</span></span>
<span id="cb8-22"></span>
<span id="cb8-23">bbandit <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> BernoulliBandit(k<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>K, probas<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>probas, seed<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>SEED)</span>
<span id="cb8-24">ucb1 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> UCB1(bbandit, seed<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>SEED)</span>
<span id="cb8-25">ucb1.run(N_STEPS)</span>
<span id="cb8-26"></span>
<span id="cb8-27">bbandit <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> BernoulliBandit(k<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>K, probas<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>probas, seed<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>SEED)</span>
<span id="cb8-28">bayesian <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> BayesianUCB(bbandit, seed<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>SEED)</span>
<span id="cb8-29">bayesian.run(N_STEPS)</span>
<span id="cb8-30"></span>
<span id="cb8-31">bbandit <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> BernoulliBandit(k<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>K, probas<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>probas, seed<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>SEED)</span>
<span id="cb8-32">thompson <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ThompsonSampling(bbandit, seed<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>SEED)</span>
<span id="cb8-33">thompson.run(N_STEPS)</span></code></pre></div></div>
<div class="cell-output cell-output-stdout">
<pre><code>[0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]</code></pre>
</div>
</div>
<div id="cell-fig-benchmark" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb10" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> load_theme(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"ambivalent"</span>):</span>
<span id="cb10-2">    fig, ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(ncols<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>, nrows<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, figsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">12</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>), facecolor<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"none"</span>, layout<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"constrained"</span>)</span>
<span id="cb10-3"></span>
<span id="cb10-4">    solvers_labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {</span>
<span id="cb10-5">        <span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">r"</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">$</span><span class="er" style="color: #AD0000;
background-color: null;
font-style: inherit;">\</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">epsilon</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">$</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">-greedy"</span>: epsgreedy,</span>
<span id="cb10-6">        <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"UCB1"</span>: ucb1,</span>
<span id="cb10-7">        <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Bayesian"</span>: bayesian,</span>
<span id="cb10-8">        <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Thompson"</span>: thompson,</span>
<span id="cb10-9">    }</span>
<span id="cb10-10"></span>
<span id="cb10-11">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># --- 1) cumulative regret ---</span></span>
<span id="cb10-12">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> label, solver <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> solvers_labels.items():</span>
<span id="cb10-13">        ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].plot(np.cumsum(solver.regrets), label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>label, clip_on<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>)</span>
<span id="cb10-14"></span>
<span id="cb10-15">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].set_xlabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Time steps"</span>)</span>
<span id="cb10-16">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].set_ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Cumulative regret"</span>)</span>
<span id="cb10-17"></span>
<span id="cb10-18">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># --- shared x for action-ranked plots ---</span></span>
<span id="cb10-19">    sorted_indices <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.argsort(bbandit.probas)</span>
<span id="cb10-20">    x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.arange(bbandit.k)  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># 0..k-1 (rank after sorting)</span></span>
<span id="cb10-21">    p_true <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> bbandit.probas[sorted_indices]</span>
<span id="cb10-22"></span>
<span id="cb10-23">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># jitter for scatter points (so methods don't overlap)</span></span>
<span id="cb10-24">    n_methods <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(solvers_labels)</span>
<span id="cb10-25">    jit <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.12</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># horizontal separation between methods (in "x units")</span></span>
<span id="cb10-26">    offsets <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (np.arange(n_methods) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> (n_methods <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> jit</span>
<span id="cb10-27"></span>
<span id="cb10-28">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># --- 2) estimated probability per action (jittered scatter + true line) ---</span></span>
<span id="cb10-29">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].plot(</span>
<span id="cb10-30">        x,</span>
<span id="cb10-31">        p_true,</span>
<span id="cb10-32">        linestyle<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"-."</span>,</span>
<span id="cb10-33">        marker<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"o"</span>,</span>
<span id="cb10-34">        markersize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>,</span>
<span id="cb10-35">        label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"True $p(a)$"</span>,</span>
<span id="cb10-36">        zorder<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>,</span>
<span id="cb10-37">        clip_on<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>,</span>
<span id="cb10-38">    )</span>
<span id="cb10-39"></span>
<span id="cb10-40">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> off, (label, solver) <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">zip</span>(offsets, solvers_labels.items(), strict<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>):</span>
<span id="cb10-41">        ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].scatter(</span>
<span id="cb10-42">            x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> off,</span>
<span id="cb10-43">            solver.estimated_probas[sorted_indices],</span>
<span id="cb10-44">            s<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">35</span>,</span>
<span id="cb10-45">            label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>label,</span>
<span id="cb10-46">            alpha<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.8</span>,</span>
<span id="cb10-47">            zorder<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>,</span>
<span id="cb10-48">            clip_on<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>,</span>
<span id="cb10-49">        )</span>
<span id="cb10-50"></span>
<span id="cb10-51">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].set_xlabel(<span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">r"Actions sorted by </span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">$</span><span class="ch" style="color: #20794D;
background-color: null;
font-style: inherit;">\t</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">heta</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">$</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>)</span>
<span id="cb10-52">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].set_ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Estimated probability"</span>)</span>
<span id="cb10-53">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].set_xticks(x)</span>
<span id="cb10-54">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].set_xticklabels([<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>(i) <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> i <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> x])  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># or sorted_indices.astype(str) for original IDs</span></span>
<span id="cb10-55">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].set_ylim(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.0</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.0</span>)</span>
<span id="cb10-56"></span>
<span id="cb10-57">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># --- 3) action selection rate (grouped bars, centered on ranks) ---</span></span>
<span id="cb10-58">    width <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.18</span></span>
<span id="cb10-59">    bar_offsets <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (np.arange(n_methods) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> (n_methods <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> width</span>
<span id="cb10-60"></span>
<span id="cb10-61">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> off, (label, solver) <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">zip</span>(bar_offsets, solvers_labels.items(), strict<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>):</span>
<span id="cb10-62">        ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>].bar(</span>
<span id="cb10-63">            x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> off,</span>
<span id="cb10-64">            solver.counts[sorted_indices] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(solver.regrets) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">100.0</span>,</span>
<span id="cb10-65">            width<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>width,</span>
<span id="cb10-66">            label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>label,</span>
<span id="cb10-67">            alpha<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.85</span>,</span>
<span id="cb10-68">            clip_on<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>,</span>
<span id="cb10-69">        )</span>
<span id="cb10-70"></span>
<span id="cb10-71">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>].set_xlabel(<span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">r"Actions sorted by </span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">$</span><span class="ch" style="color: #20794D;
background-color: null;
font-style: inherit;">\t</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">heta</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">$</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>)</span>
<span id="cb10-72">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>].set_ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% o</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">f trials"</span>)</span>
<span id="cb10-73">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>].set_xticks(x)</span>
<span id="cb10-74">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>].set_xticklabels([<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>(i) <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> i <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> x])  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># or sorted_indices.astype(str)</span></span>
<span id="cb10-75">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>].set_ylim(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span>)</span>
<span id="cb10-76"></span>
<span id="cb10-77">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># (Optional) make the two right panels less "grid heavy" if your theme uses strong grids</span></span>
<span id="cb10-78">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> a <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> (ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>], ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>]):</span>
<span id="cb10-79">        a.grid(axis<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"y"</span>, alpha<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.25</span>)</span>
<span id="cb10-80">        a.set_axisbelow(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb10-81"></span>
<span id="cb10-82">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># --- single shared legend (deduplicated) ---</span></span>
<span id="cb10-83">    handles, labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [], []</span>
<span id="cb10-84">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> axis <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> fig.axes:</span>
<span id="cb10-85">        _handles, _labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> axis.get_legend_handles_labels()</span>
<span id="cb10-86">        handles.extend(_handles)</span>
<span id="cb10-87">        labels.extend(_labels)</span>
<span id="cb10-88"></span>
<span id="cb10-89">    by_label <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">dict</span>(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">zip</span>(labels, handles, strict<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>))</span>
<span id="cb10-90">    fig.legend(</span>
<span id="cb10-91">        by_label.values(),</span>
<span id="cb10-92">        by_label.keys(),</span>
<span id="cb10-93">        loc<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8</span>,</span>
<span id="cb10-94">        ncols<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(by_label),</span>
<span id="cb10-95">        bbox_to_anchor<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.1</span>),</span>
<span id="cb10-96">        fancybox<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>,</span>
<span id="cb10-97">        frameon<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>,</span>
<span id="cb10-98">    )</span>
<span id="cb10-99"></span>
<span id="cb10-100">plt.show()</span></code></pre></div></div>
<div class="cell-output cell-output-display">
<div id="fig-benchmark" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-benchmark-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<img src="https://gcerar.github.io/posts/reinforcement-learning/2025-12-16-multi-armed-bandit-problem_files/figure-html/fig-benchmark-output-1.png" width="1211" height="447" class="figure-img">
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-benchmark-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;1: Results of solving a Bernoulli bandit with K=10 slot machines whose reward probabilities are {0.0, 0.1, …, 0.9}. Each solver runs for 10,000 steps.
</figcaption>
</figure>
</div>
</div>
</div>
<p>The figure above shows a side-by-side comparison of four bandit strategies: <img src="https://latex.codecogs.com/png.latex?%5Cepsilon">-greedy, UCB1, Bayesian UCB, and Thompson Sampling. All algorithms are evaluated on the same 10-armed Bernoulli bandit. Each subplot highlights a different aspect of their behavior: regret, reward estimation, and exploration patterns. Together, they illustrate how the theoretical ideas introduced earlier play out in practice.</p>
<section id="cumulative-regret-over-time-left" class="level4">
<h4 class="anchored" data-anchor-id="cumulative-regret-over-time-left">1. Cumulative Regret Over Time (left)</h4>
<p>The left subplot shows how much regret each algorithm accumulates over 10,000 time steps. Lower curves indicate better performance.</p>
<ul>
<li><strong>Thompson Sampling</strong> performs best. Its regret curve rises early on and then quickly flattens, showing that it identifies the optimal arm fast and almost never leaves it afterward.</li>
<li><strong>Bayesian UCB</strong> is slightly worse but still competitive. Using posterior uncertainty leads to steady improvement without requiring an explicit exploration parameter.</li>
<li><strong><img src="https://latex.codecogs.com/png.latex?%5Cepsilon">-greedy</strong> suffers more early regret and coverges more slowly, since it explores randomly rather than strategically.</li>
<li><strong>UCB1</strong> explores aggressively and therefore incurs noticeably higher regret. This is expected when several arms have reward probabilities close to the best one (0.7 and 0.8 versus 0.9 here): small gaps take many pulls to resolve, and the optimistic bonus keeps those arms in contention for a long time.</li>
</ul>
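<p>In this run, the qualitative ordering from best to worst is <strong>Thompson Sampling</strong> <img src="https://latex.codecogs.com/png.latex?%5Cge"> <strong>Bayesian UCB</strong> <img src="https://latex.codecogs.com/png.latex?%5Cge"> <strong><img src="https://latex.codecogs.com/png.latex?%5Cepsilon">-greedy</strong> <img src="https://latex.codecogs.com/png.latex?%5Cge"> <strong>UCB1</strong>. The lead of the two Bayesian methods is consistent with what is commonly reported for Bernoulli bandits, but the <img src="https://latex.codecogs.com/png.latex?%5Cepsilon">-greedy vs. UCB1 ranking is specific to this 10,000-step horizon: with a fixed <img src="https://latex.codecogs.com/png.latex?%5Cepsilon">, <img src="https://latex.codecogs.com/png.latex?%5Cepsilon">-greedy never stops exploring, so its regret eventually grows linearly in time, whereas UCB1's regret grows only logarithmically.</p>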
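<p>The horizon effect can be made concrete with a little arithmetic. The sketch below assumes that per-step regret is the gap between the best arm's true probability and the chosen arm's true probability (the post's <code>solver.regrets</code> may be computed slightly differently), and <code>cumulative_regret</code> is a hypothetical helper. It then estimates the regret a fixed-<img src="https://latex.codecogs.com/png.latex?%5Cepsilon"> policy pays for exploration alone, even after its greedy choice has settled on the best arm: with probability <img src="https://latex.codecogs.com/png.latex?%5Cepsilon"> it pulls a uniformly random arm, whose expected gap is the mean gap over all arms.</p>

```python
import numpy as np

# True reward probabilities used in the benchmark: {0.0, 0.1, ..., 0.9}
probas = np.linspace(0, 1, 10, endpoint=False)


def cumulative_regret(chosen_arms: np.ndarray, probas: np.ndarray) -> np.ndarray:
    """Hypothetical helper: running sum of expected regret, theta* - theta[a_t]."""
    gaps = probas.max() - probas[chosen_arms]
    return np.cumsum(gaps)


# A policy that always pulls the best arm accumulates no regret
print(cumulative_regret(np.full(5, probas.argmax()), probas))  # [0. 0. 0. 0. 0.]

# Fixed-epsilon exploration cost once the greedy choice is already the best arm:
# with probability eps a uniformly random arm is pulled, whose expected gap is
# the mean gap over all arms.
eps, n_steps = 0.01, 10_000
mean_gap = np.mean(probas.max() - probas)   # 0.45 for this bandit
per_step = eps * mean_gap                   # constant cost per step
exploration_regret = per_step * n_steps     # grows linearly with n_steps
print(f"per step: {per_step:.4f}, over {n_steps} steps: {exploration_regret:.1f}")
# per step: 0.0045, over 10000 steps: 45.0
```

<p>Under these assumptions, exploration alone costs roughly 45 units of expected regret over 10,000 steps, and that figure doubles whenever the horizon doubles. Because UCB1's regret grows only logarithmically, a sufficiently long run would be expected to reverse the ranking of these two methods.</p>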
</section>
<section id="estimated-reward-probabilities-middle" class="level4">
<h4 class="anchored" data-anchor-id="estimated-reward-probabilities-middle">2. Estimated Reward Probabilities (middle)</h4>
<p>The middle subplot shows how accurately each method estimates the reward probability of each arm at the end of the run. Arms are sorted by their true <img src="https://latex.codecogs.com/png.latex?%5Ctheta"> values, and the dash-dotted line marks the true probabilities, i.e., perfect estimation.</p>
<ul>
<li><strong>Thompson Sampling</strong> and <strong>Bayesian UCB</strong> are concentrated near the diagonal. Their estimates are reasonably accurate even for suboptimal arms.</li>
<li><strong><img src="https://latex.codecogs.com/png.latex?%5Cepsilon">-greedy</strong> is more scattered. Because it explores uniformly at random, it revisits some arms only rarely, so several estimates remain biased or imprecise.</li>
<li><strong>UCB1</strong> tends to overestimate some suboptimal arms early on and then underexplore them later. UCB1’s deterministic optimism often leads to a characteristic estimation bias: once the bonus term shrinks, there is little incentive to revisit an arm, even if its estimate is wrong.</li>
</ul>
<p>This subplot highlights a key difference: <strong>good decision-making does not always require perfectly accurate models</strong>, but algorithms that maintain richer uncertainty estimates (Bayesian UCB and Thompson Sampling) tend to form more reliable estimates.</p>
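<p>The “richer uncertainty estimates” of the Bayesian methods can be made concrete with a Beta-Bernoulli posterior. The sketch below assumes Bernoulli rewards and a uniform Beta(1, 1) prior (an assumption; the post’s prior may differ) and uses made-up success/failure counts. The posterior mean is the kind of point estimate that would be compared against the dashed diagonal, while the posterior standard deviation shows how little is known about a rarely pulled arm. <code>beta_posterior</code> is a hypothetical helper.</p>

```python
import numpy as np


def beta_posterior(successes, failures, alpha0=1.0, beta0=1.0):
    """Posterior mean and standard deviation of Beta-Bernoulli arms.

    Assumes a Beta(alpha0, beta0) prior, uniform Beta(1, 1) by default.
    """
    a = alpha0 + np.asarray(successes, dtype=float)
    b = beta0 + np.asarray(failures, dtype=float)
    mean = a / (a + b)
    # Variance of a Beta(a, b) distribution
    var = a * b / ((a + b) ** 2 * (a + b + 1.0))
    return mean, np.sqrt(var)


# Hypothetical counts: a well-explored arm and a rarely pulled one
successes = np.array([700, 2])
failures = np.array([300, 1])
mean, std = beta_posterior(successes, failures)
for k, (m, s) in enumerate(zip(mean, std)):
    print(f"arm {k}: posterior mean = {m:.3f}, std = {s:.3f}")
```

<p>The rarely pulled arm’s standard deviation (0.200, versus about 0.014 for the well-explored arm) illustrates why its point estimate can land far from the diagonal, and why a method that keeps this uncertainty around can still decide sensibly.</p>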
</section>
<section id="fraction-of-pulls-per-arm-right" class="level4">
<h4 class="anchored" data-anchor-id="fraction-of-pulls-per-arm-right">3. Fraction of Pulls per Arm (right)</h4>
<p>The right subplot shows how often each algorithm selects each arm. This is where the behavioral differences are most visible.</p>
<ul>
<li><strong>Thompson Sampling</strong> plays the best arm almost exclusively, with its bar nearly reaching 100%.</li>
<li><strong>Bayesian UCB</strong> focuses heavily on the best arm but still allocates a small percentage of trials to others due to posterior uncertainty.</li>
<li><strong><img src="https://latex.codecogs.com/png.latex?%5Cepsilon">-greedy</strong> spreads its attention more broadly. Because exploration is random, even clearly suboptimal arms continue to receive occasional pulls.</li>
<li><strong>UCB1</strong> revisits several arms during the optimistic exploration phase. Once the bonus term shrinks, it commits strongly to the best arm, but the early exploration leaves a visible footprint.</li>
</ul>
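<p>Given the log of pulled arms, the per-arm fractions in the right subplot reduce to a normalized count. A minimal sketch with a hypothetical action log; <code>pull_fractions</code> is an illustrative helper, not the post’s code.</p>

```python
import numpy as np


def pull_fractions(actions, n_arms):
    """Fraction of all pulls that went to each arm."""
    # minlength keeps never-pulled arms in the result as explicit zeros
    counts = np.bincount(np.asarray(actions, dtype=int), minlength=n_arms)
    return counts / counts.sum()


# Hypothetical log: arm 2 dominates, arm 3 is never pulled
actions = [2, 2, 1, 2, 0, 2, 2, 2]
print(pull_fractions(actions, n_arms=4))
```

<p>Passing <code>minlength</code> matters for bar plots: without it, arms that were never pulled would silently disappear from the result instead of showing a zero-height bar.</p>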
<p>This subplot emphasizes how each strategy allocates exploration effort:</p>
<ul>
<li><strong><img src="https://latex.codecogs.com/png.latex?%5Cepsilon">-greedy</strong>: broad, unfocused exploration</li>
<li><strong>UCB1</strong>: early over-exploration, later commitment</li>
<li><strong>Bayesian UCB</strong>: exploration guided by posterior uncertainty</li>
<li><strong>Thompson Sampling</strong>: exploration in proportion to each arm’s probability of being optimal</li>
</ul>
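<p>The Thompson Sampling entry above refers to probability matching: at each step, the chance of pulling an arm equals its posterior probability of being the best arm. A Monte Carlo sketch, assuming Beta posteriors with hypothetical parameters: each sampled row plays the role of one Thompson step, so the win frequencies estimate those probabilities. <code>prob_optimal</code> is a hypothetical helper.</p>

```python
import numpy as np


def prob_optimal(alpha, beta, n_samples=100_000, seed=0):
    """Monte Carlo estimate of P(arm k has the highest theta) under Beta posteriors.

    Thompson Sampling plays the argmax of one posterior draw per arm, so each
    row below corresponds to one simulated Thompson decision.
    """
    rng = np.random.default_rng(seed)
    alpha = np.asarray(alpha, dtype=float)
    beta = np.asarray(beta, dtype=float)
    # Shape (n_samples, K): one posterior draw per arm per simulated step
    samples = rng.beta(alpha, beta, size=(n_samples, alpha.size))
    winners = samples.argmax(axis=1)
    return np.bincount(winners, minlength=alpha.size) / n_samples


# Hypothetical posteriors: arm 2 looks best, arm 1 is still very uncertain
alpha = [3, 4, 60]
beta = [9, 4, 30]
print(prob_optimal(alpha, beta))
```

<p>Arm 1’s posterior is wide, so it still wins a noticeable share of draws despite its lower mean. As more pulls narrow that posterior, its share shrinks, which is why Thompson Sampling’s pulls concentrate on the best arm over time.</p>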
</section>
<section id="putting-it-all-together" class="level4">
<h4 class="anchored" data-anchor-id="putting-it-all-together">Putting It All Together</h4>
<p>These three views (regret, estimation accuracy, and action frequencies) provide a comprehensive picture of each algorithm’s strengths and weaknesses:</p>
<ul>
<li><strong>Thompson Sampling</strong> is consistent and strong: low regret, accurate estimation, and efficient exploration.</li>
<li><strong>Bayesian UCB</strong> offers a principled middle ground and performs well when the prior is appropriate.</li>
<li><strong><img src="https://latex.codecogs.com/png.latex?%5Cepsilon">-greedy</strong> is simple but wasteful: random exploration leads to both under- and over-exploration.</li>
<li><strong>UCB1</strong> works as intended, but deterministic optimism causes large early regret when many arms have similar payoffs.</li>
</ul>
<p>The results shown correspond to a single random seed. Relative performance may vary across runs, but the qualitative behavior is consistent with theoretical expectations.</p>
<p>Overall, the benchmark illustrates a central message of the exploration-exploitation dilemma: <strong>better uncertainty modeling leads to more efficient learning</strong>.</p>
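<p>Because the figures come from a single seed, one way to check how stable the ordering is would be to repeat the experiment over several seeds and report the mean and spread of the final regret. A compact sketch for two of the four methods, assuming Bernoulli rewards, Beta(1, 1) priors for Thompson Sampling, <img src="https://latex.codecogs.com/png.latex?%5Cepsilon"> = 0.1, and hypothetical arm probabilities; it is not the benchmark code behind the figure, and <code>simulate</code> is a hypothetical name.</p>

```python
import numpy as np


def simulate(policy, thetas, horizon, seed, epsilon=0.1):
    """Run one Bernoulli-bandit episode and return its final pseudo-regret."""
    rng = np.random.default_rng(seed)
    thetas = np.asarray(thetas, dtype=float)
    k = thetas.size
    alpha, beta = np.ones(k), np.ones(k)        # Beta(1, 1) priors (Thompson Sampling)
    counts, values = np.zeros(k), np.zeros(k)   # pull counts and sample means (epsilon-greedy)
    regret = 0.0
    for _ in range(horizon):
        if policy == "thompson":
            arm = int(np.argmax(rng.beta(alpha, beta)))
        elif policy == "epsilon_greedy":
            if rng.random() < epsilon:
                arm = int(rng.integers(k))      # explore: uniformly random arm
            else:
                arm = int(np.argmax(values))    # exploit: best sample mean, ties go to lowest index
        else:
            raise ValueError(f"unknown policy: {policy}")
        reward = int(rng.binomial(1, thetas[arm]))
        alpha[arm] += reward
        beta[arm] += 1 - reward
        counts[arm] += 1
        values[arm] += (reward - values[arm]) / counts[arm]  # incremental mean update
        regret += thetas.max() - thetas[arm]
    return regret


# Hypothetical arm probabilities (not the ones used in the post)
thetas = [0.10, 0.30, 0.50, 0.60, 0.65]
for policy in ("thompson", "epsilon_greedy"):
    finals = [simulate(policy, thetas, horizon=1000, seed=s) for s in range(10)]
    print(f"{policy}: mean final regret {np.mean(finals):.1f} (std {np.std(finals):.1f})")
```

<p>The number of seeds, the horizon, and <img src="https://latex.codecogs.com/png.latex?%5Cepsilon"> are arbitrary here; the point of the design is that each episode gets its own seeded generator, so every run is reproducible and differences between methods are not an artifact of one lucky draw.</p>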
</section>
</section>
<section id="conclusions" class="level1">
<h1>Conclusions</h1>
<p>The benchmark highlights the core differences between bandit algorithms in practice:</p>
<ul>
<li><strong>Thompson Sampling</strong> achieves the lowest regret and concentrates almost all pulls on the optimal arm, reflecting efficient, uncertainty-aware exploration.</li>
<li><strong>Bayesian UCB</strong> performs similarly well, balancing optimism with Bayesian posterior uncertainty.</li>
<li><strong><img src="https://latex.codecogs.com/png.latex?%5Cepsilon">-greedy</strong> is simple but wasteful: random exploration leads to slower convergence and less accurate value estimates.</li>
<li><strong>UCB1</strong> explores aggressively early on, which increases regret in environments with many high-reward arms.</li>
</ul>
<p>Overall, algorithms that model uncertainty explicitly, such as Thompson Sampling and Bayesian UCB, deliver more focused exploration and stronger performance.</p>
</section>
<section id="appendix" class="level1">
<h1>Appendix</h1>
<table class="caption-top table">
<colgroup>
<col style="width: 23%">
<col style="width: 38%">
<col style="width: 16%">
<col style="width: 20%">
</colgroup>
<thead>
<tr class="header">
<th>Method</th>
<th>Exploration mechanism</th>
<th>Deterministic?</th>
<th>Uses Posterior?</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><img src="https://latex.codecogs.com/png.latex?%5Cepsilon">-greedy</td>
<td>Random with prob. <img src="https://latex.codecogs.com/png.latex?%5Cepsilon"></td>
<td>No</td>
<td>No</td>
</tr>
<tr class="even">
<td>UCB1</td>
<td>Optimism via bound</td>
<td>Yes</td>
<td>No</td>
</tr>
<tr class="odd">
<td>Bayesian UCB</td>
<td>Posterior quantile</td>
<td>Yes</td>
<td>Yes</td>
</tr>
<tr class="even">
<td>Thompson Sampling</td>
<td>Posterior sampling</td>
<td>No</td>
<td>Yes</td>
</tr>
</tbody>
</table>
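<p>The exploration mechanisms in the table can be sketched as one selection rule per method. This is a minimal sketch under stated assumptions: Bernoulli rewards, Beta(1, 1) priors for the Bayesian methods, the standard UCB1 bonus <img src="https://latex.codecogs.com/png.latex?%5Csqrt%7B2%20%5Cln%20t%20/%20n_k%7D">, and a <img src="https://latex.codecogs.com/png.latex?1%20-%201/t"> quantile schedule for Bayesian UCB. The post’s own implementation may differ, and all function names are hypothetical. It relies on <code>scipy.stats.beta.ppf</code> for the posterior quantile.</p>

```python
import numpy as np
from scipy.stats import beta as beta_dist


def epsilon_greedy(values, rng, epsilon=0.1):
    """Random: a uniformly random arm with probability epsilon, else the best sample mean."""
    if rng.random() < epsilon:
        return int(rng.integers(len(values)))
    return int(np.argmax(values))


def ucb1(values, counts, t):
    """Deterministic optimism: sample mean plus a sqrt(2 ln t / n_k) bonus."""
    untried = np.flatnonzero(counts == 0)
    if untried.size:                      # play every arm once before the bonus is defined
        return int(untried[0])
    bonus = np.sqrt(2.0 * np.log(t) / counts)
    return int(np.argmax(values + bonus))


def bayes_ucb(alpha, beta, t):
    """Deterministic: the (1 - 1/t) quantile of each Beta posterior; t is the 1-based step."""
    return int(np.argmax(beta_dist.ppf(1.0 - 1.0 / t, alpha, beta)))


def thompson(alpha, beta, rng):
    """Random: draw one sample from each Beta posterior and play the largest."""
    return int(np.argmax(rng.beta(alpha, beta)))


# Hypothetical state after 35 pulls: sample means, pull counts, Beta(1, 1)-based posteriors
rng = np.random.default_rng(0)
values = np.array([0.40, 0.55, 0.50])
counts = np.array([10, 20, 5])
alpha = 1 + values * counts
beta = 1 + (1 - values) * counts
print("epsilon-greedy:", epsilon_greedy(values, rng))
print("UCB1:          ", ucb1(values, counts, t=35))
print("Bayesian UCB:  ", bayes_ucb(alpha, beta, t=35))
print("Thompson:      ", thompson(alpha, beta, rng))
```

<p>With these hypothetical numbers, UCB1 picks the least-pulled arm (index 2) even though arm 1 has the higher sample mean, because the bonus for an arm with only 5 pulls dominates. This is the optimism-driven exploration that the regret plot penalizes.</p>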


</section>

<div id="quarto-appendix" class="default"><section class="quarto-appendix-contents" id="quarto-reuse"><h2 class="anchored quarto-appendix-heading">Reuse</h2><div class="quarto-appendix-contents"><div><a rel="license" href="https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en_us">CC BY-NC-SA 4.0</a></div></div></section></div> ]]></description>
  <category>python</category>
  <category>code</category>
  <category>MAB</category>
  <category>reinforcement learning</category>
  <guid>https://gcerar.github.io/posts/reinforcement-learning/2025-12-16-multi-armed-bandit-problem.html</guid>
  <pubDate>Mon, 15 Dec 2025 23:00:00 GMT</pubDate>
</item>
<item>
  <title>Building a No‑Fluff Report Template in LaTeX</title>
  <dc:creator>Gregor Cerar</dc:creator>
  <link>https://gcerar.github.io/posts/2025-05-07-ministate-v3/</link>
  <description><![CDATA[ 





<section id="why-a-new-template" class="level2">
<h2 class="anchored" data-anchor-id="why-a-new-template">Why a New Template?</h2>
<ol type="1">
<li><strong>Dense (less fluff)</strong> — Every square centimeter should serve the reader. Tighter vertical spacing and compact headings keep the narrative flowing.</li>
<li><strong>Optional titles</strong> — Some documents benefit from a title; others (like a brief update) do not. The template should let me toggle them off with a single flag.</li>
<li><strong>Flexible</strong> — Today, I might need a one‑pager. Tomorrow, a 20‑page appendix. Layout decisions (margins, font, color) should be parameterized — not hard‑wired.</li>
</ol>
<p>Existing classes like <code>article</code> or even <code>IEEEtran</code> come close, but they still burden the author with unnecessary baggage (abstract blocks, keywords, etc.). Then I stumbled upon the elegant <a href="https://www.overleaf.com/latex/templates/minimalstatement/pzgpkvvrzyqj"><code>ministate</code></a> class and adapted it.</p>
</section>
<section id="meet-ministate-v3.0" class="level2">
<h2 class="anchored" data-anchor-id="meet-ministate-v3.0">Meet <em>ministate</em> v3.0</h2>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>ministate.cls</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb1" data-filename="ministate.cls" style="background: #f1f3f5;"><pre class="sourceCode latex code-with-copy"><code class="sourceCode latex"><span id="cb1-1"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\ProvidesClass</span>{ministate}[2023/03/29 v3.0 Minimalist statement class]</span>
<span id="cb1-2"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\LoadClass</span>[11pt,a4paper]{article}</span>
<span id="cb1-3"></span>
<span id="cb1-4"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">\usepackage</span>[utf8]{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">inputenc</span>} <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% since 2018, UTF-8 is the default input encoding in LaTeX</span></span>
<span id="cb1-5"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">\usepackage</span>[T1]{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">fontenc</span>}</span>
<span id="cb1-6"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">\usepackage</span>{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">lmodern</span>}</span>
<span id="cb1-7"></span>
<span id="cb1-8"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">\usepackage</span>{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">microtype</span>}</span>
<span id="cb1-9"></span>
<span id="cb1-10"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">\usepackage</span>[margin=0.8in]{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">geometry</span>}</span>
<span id="cb1-11"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">\usepackage</span>{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">parskip</span>}</span>
<span id="cb1-12"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">\usepackage</span>{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">fancyhdr</span>}</span>
<span id="cb1-13"></span>
<span id="cb1-14"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\setlength</span>{<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\headheight</span>}{15.2pt}</span>
<span id="cb1-15"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\pagestyle</span>{fancy}</span>
<span id="cb1-16"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\fancyhf</span>{} <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% Clear all header and footer fields</span></span>
<span id="cb1-17"></span>
<span id="cb1-18"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%--------------------------------------------------%</span></span>
<span id="cb1-19"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%    Title, HeaderTitle, Author, HeaderAuthor,     %</span></span>
<span id="cb1-20"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%                 Custom Date                      %</span></span>
<span id="cb1-21"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%--------------------------------------------------%</span></span>
<span id="cb1-22"></span>
<span id="cb1-23"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\let\oldtitle\title</span></span>
<span id="cb1-24"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\let\oldauthor\author</span></span>
<span id="cb1-25"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\let\olddate\date</span></span>
<span id="cb1-26"></span>
<span id="cb1-27"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\def\@headertitle</span>{}</span>
<span id="cb1-28"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\def\@headerauthor</span>{}</span>
<span id="cb1-29"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\def\@headerdate</span>{}</span>
<span id="cb1-30"></span>
<span id="cb1-31"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% Redefine the \title and \author commands</span></span>
<span id="cb1-32"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\renewcommand</span>{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">\title</span>}[1]{<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-33">    <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\oldtitle</span>{#1}<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-34">    <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\ifx\@headertitle\@empty</span><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-35">        <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\relax\def\@headertitle</span>{#1}<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-36">    <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\fi</span><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-37">}</span>
<span id="cb1-38"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\renewcommand</span>{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">\author</span>}[1]{<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-39">    <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\oldauthor</span>{#1}<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-40">    <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\ifx\@headerauthor\@empty</span><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-41">        <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\relax\def\@headerauthor</span>{#1}<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-42">    <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\fi</span><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-43">}</span>
<span id="cb1-44"></span>
<span id="cb1-45"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\renewcommand</span>{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">\date</span>}[1]{<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-46">    <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\olddate</span>{#1}<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-47">    <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\ifx\@headerdate\@empty</span><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-48">        <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\relax\def\@headerdate</span>{#1}<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-49">    <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\fi</span><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-50">}</span>
<span id="cb1-51"></span>
<span id="cb1-52"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% Commands for explicitly setting the header title and header author</span></span>
<span id="cb1-53"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\newcommand</span>{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">\headertitle</span>}[1]{<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\def\@headertitle</span>{#1}}</span>
<span id="cb1-54"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\newcommand</span>{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">\headerauthor</span>}[1]{<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\def\@headerauthor</span>{#1}}</span>
<span id="cb1-55"></span>
<span id="cb1-56"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\fancypagestyle</span>{ministate}{<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-57">  <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\fancyhf</span>{}<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% clear everything</span></span>
<span id="cb1-58">  <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\fancyhead</span>[L]{<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\textbf</span>{<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\@headertitle</span>}<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\ifx\@headerdate\@empty\else\ </span>(<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\@headerdate</span>)<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\fi</span>}<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-59">  <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\fancyhead</span>[R]{<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\textbf</span>{<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\@headerauthor</span>}}<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-60">  <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\fancyfoot</span>[C]{<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\thepage</span>}<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-61">}</span>
<span id="cb1-62"></span>
<span id="cb1-63"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\pagestyle</span>{ministate}</span>
<span id="cb1-64"></span>
<span id="cb1-65"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% Apply header settings including the custom date</span></span>
<span id="cb1-66"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%\fancyhead[L]{\textbf{\@headertitle}\ifx\@headerdate\@empty\else\ (\@headerdate)\fi} % Title (Custom Date)</span></span>
<span id="cb1-67"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%\fancyhead[R]{\textbf{\@headerauthor}} % Author</span></span>
<span id="cb1-68"></span>
<span id="cb1-69"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%\fancyfoot{} % Override existing foot numbering</span></span>
<span id="cb1-70"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%\fancyfoot[C]{\thepage} % Page number at center of footer</span></span>
<span id="cb1-71"></span>
<span id="cb1-72"></span>
<span id="cb1-73"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%--------------------------------------------------%</span></span>
<span id="cb1-74"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%                   Document Body                  %</span></span>
<span id="cb1-75"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%--------------------------------------------------%</span></span>
<span id="cb1-76"></span>
<span id="cb1-77"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% Usage:</span></span>
<span id="cb1-78"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% \title{Your Title Here}</span></span>
<span id="cb1-79"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% \author{Author Name}</span></span>
<span id="cb1-80"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% \headertitle{Your Header Title Here} - For custom header title</span></span>
<span id="cb1-81"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% \headerauthor{Your Header Author Here} - For custom header author</span></span>
<span id="cb1-82"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% \date{Custom Date or Empty String} - To change or remove the date</span></span>
<span id="cb1-83"></span>
<span id="cb1-84"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% Comment out this block if you don't want the header on the first page</span></span>
<span id="cb1-85"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">\usepackage</span>{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">etoolbox</span>}   <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% load before you patch anything</span></span>
<span id="cb1-86"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\makeatletter</span></span>
<span id="cb1-87"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\patchcmd</span>{<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\maketitle</span>}<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%          the command to patch</span></span>
<span id="cb1-88">  {<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\thispagestyle</span>{plain}}<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%      code to replace</span></span>
<span id="cb1-89">  {<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\thispagestyle</span>{ministate}}<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%  replacement</span></span>
<span id="cb1-90">  {}{}                          <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% ← success / failure actions (empty)</span></span>
<span id="cb1-91"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\makeatother</span></span>
<span id="cb1-92"></span>
<span id="cb1-93"></span>
<span id="cb1-94"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\makeatletter</span></span>
<span id="cb1-95"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\def\@maketitle</span>{<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-96">  <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\newpage</span></span>
<span id="cb1-97">  <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">\begin</span>{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">center</span>}<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-98">    <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\let\footnote\thanks</span></span>
<span id="cb1-99">    {<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\LARGE</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\@title\par</span>}<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-100">    <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\vskip</span> 0.2em<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-101">    {<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\large</span><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">\begin</span>{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">tabular</span>}[t]{c}<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\@author</span><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">\end</span>{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">tabular</span>}<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\par</span>}<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-102">    <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\vskip</span> 0.2em<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-103">    {<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\large</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\@date</span>}<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\vskip</span> 0.2em<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% comment out this line to remove the date</span></span>
<span id="cb1-104">  <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">\end</span>{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">center</span>}<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-105">  <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\par</span></span>
<span id="cb1-106">}<span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span></span>
<span id="cb1-107"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\makeatother</span></span></code></pre></div></div>
</div>
<p>An example document:</p>
<div class="code-with-filename">
<div class="code-with-filename-file">
<pre><strong>example.tex</strong></pre>
</div>
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb2" data-filename="example.tex" style="background: #f1f3f5;"><pre class="sourceCode latex code-with-copy"><code class="sourceCode latex"><span id="cb2-1"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">\documentclass</span>[11pt,a4paper,nonatbib]{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">./ministate</span>}</span>
<span id="cb2-2"></span>
<span id="cb2-3"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% (optional) bibliography</span></span>
<span id="cb2-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%\usepackage[backend=biber,style=ieee,autocite=plain,sorting=none]{biblatex}</span></span>
<span id="cb2-5"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%\addbibresource{biblio.bib}</span></span>
<span id="cb2-6"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">\usepackage</span>[sfdefault]{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">atkinson</span>}</span>
<span id="cb2-7"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">\usepackage</span>{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">fontawesome</span>}</span>
<span id="cb2-8"></span>
<span id="cb2-9"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">\usepackage</span>{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">hyperref</span>}</span>
<span id="cb2-10"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">\usepackage</span>{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">url</span>}</span>
<span id="cb2-11"></span>
<span id="cb2-12"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">\usepackage</span>[english]{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">babel</span>}</span>
<span id="cb2-13"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">\usepackage</span>[autostyle,english=british]{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">csquotes</span>}</span>
<span id="cb2-14"></span>
<span id="cb2-15"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% (optional) prevent breaking words</span></span>
<span id="cb2-16"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">\usepackage</span>[none]{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">hyphenat</span>}</span>
<span id="cb2-17"><span class="fu" style="color: #4758AB;
background-color: null;
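font-style: inherit;">\interdisplaylinepenalty</span>=10000 <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% (optional) prevent page breaks inside multi-line displayed equations</span></span>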
<span id="cb2-18"></span>
<span id="cb2-19"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% lorem ipsum generator</span></span>
<span id="cb2-20"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">\usepackage</span>{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">kantlipsum</span>}</span>
<span id="cb2-21"></span>
<span id="cb2-22"></span>
<span id="cb2-23"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% ministate settings</span></span>
<span id="cb2-24"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\title</span>{The Summary of Lorem Ipsum}</span>
<span id="cb2-25"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\headertitle</span>{The Shorter Title}  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% (optional) will use \title if not used</span></span>
<span id="cb2-26"></span>
<span id="cb2-27"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\author</span>{Johnny English, PhD}</span>
<span id="cb2-28"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\headerauthor</span>{Johnny E., PhD} <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">% (optional) will use \author if not used</span></span>
<span id="cb2-29"></span>
<span id="cb2-30"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\date</span>{May 7, 2025}</span>
<span id="cb2-31"></span>
<span id="cb2-32"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">\begin</span>{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">document</span>}</span>
<span id="cb2-33"></span>
<span id="cb2-34"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\maketitle</span></span>
<span id="cb2-35"></span>
<span id="cb2-36"><span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">\kant</span></span>
<span id="cb2-37"></span>
<span id="cb2-38"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%\clearpage</span></span>
<span id="cb2-39"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%\printbibliography[title={Osebna bibliografija}]</span></span>
<span id="cb2-40"></span>
<span id="cb2-41"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">\end</span>{<span class="ex" style="color: null;
background-color: null;
font-style: inherit;">document</span>}</span></code></pre></div></div>
</div>
</section>
<section id="outcome" class="level2">
<h2 class="anchored" data-anchor-id="outcome">Outcome</h2>
<p>Below is the rendered PDF output from the code above:</p>
<embed src="./sample.pdf" width="100%" height="700px">


</section>

<div id="quarto-appendix" class="default"><section class="quarto-appendix-contents" id="quarto-reuse"><h2 class="anchored quarto-appendix-heading">Reuse</h2><div class="quarto-appendix-contents"><div><a rel="license" href="https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en_us">CC BY-NC-SA 4.0</a></div></div></section></div> ]]></description>
  <category>LaTeX</category>
  <guid>https://gcerar.github.io/posts/2025-05-07-ministate-v3/</guid>
  <pubDate>Tue, 06 May 2025 22:00:00 GMT</pubDate>
  <media:content url="https://gcerar.github.io/posts/2025-05-07-ministate-v3/featured.png" medium="image" type="image/png" height="144" width="144"/>
</item>
<item>
  <title>Visualizing Feature Maps from VGG11 and ResNet50 in PyTorch</title>
  <dc:creator>Gregor Cerar</dc:creator>
  <link>https://gcerar.github.io/posts/2025-05-06-feature-maps/</link>
  <description><![CDATA[ 





<section id="prerequisites" class="level2">
<h2 class="anchored" data-anchor-id="prerequisites">Prerequisites</h2>
<p>Before we start, we need to install the following libraries: <a href="https://numpy.org/">NumPy</a>, <a href="https://matplotlib.org/">Matplotlib</a>, <a href="https://pytorch.org/">PyTorch</a>, <a href="https://pytorch.org/vision/stable/index.html">Torchvision</a>, <a href="https://scikit-learn.org/">scikit-learn</a>, and <a href="https://ipython.org/">IPython</a>.</p>
<div id="e25e5496" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> math</span>
<span id="cb1-2"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> collections.abc <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> Callable</span>
<span id="cb1-3"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> pathlib <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> Path</span>
<span id="cb1-4"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> typing <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> Final, Literal</span>
<span id="cb1-5"></span>
<span id="cb1-6"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> numpy <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> np</span>
<span id="cb1-7"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> torch</span>
<span id="cb1-8"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> IPython.display <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> Markdown</span>
<span id="cb1-9"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> matplotlib <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> pyplot <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> plt</span>
<span id="cb1-10"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> sklearn.decomposition <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> PCA</span>
<span id="cb1-11"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> torch <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> Tensor, nn</span>
<span id="cb1-12"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> torchvision <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> models</span>
<span id="cb1-13"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> torchvision.io <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> decode_image</span>
<span id="cb1-14"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> torchvision.transforms <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> v2 <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> T</span></code></pre></div></div>
</div>
<p>In this article, we use pre-trained neural networks; more specifically, models with weights trained on the <a href="https://www.image-net.org/download.php">ImageNet-1K dataset</a>.</p>
<p>Before that, we need to prepare the input images. We resize each image so that its shorter side is 256 pixels, center-crop it to 224×224, and normalize each color channel with the ImageNet mean and standard deviation. This preprocessing makes our pictures statistically similar to the images the networks were trained on. See the link above for more details on why this step is necessary.</p>
<div id="f3b3920d" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><span class="co" style="color: #5E5E5E;
background-color: null;
# ImageNet normalization weights per channel</span></span>">
font-style: inherit;"># ImageNet per-channel normalization statistics (mean and std)</span></span>
<span id="cb2-2">IMAGENET1K_MEAN <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.485</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.456</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.406</span>]</span>
<span id="cb2-3">IMAGENET1K_STD <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.229</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.224</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.225</span>]</span>
<span id="cb2-4"></span>
<span id="cb2-5">transform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> T.Compose(</span>
<span id="cb2-6">    [</span>
<span id="cb2-7">        T.Resize(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">256</span>),</span>
<span id="cb2-8">        T.CenterCrop(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">224</span>),</span>
<span id="cb2-9">        T.ToImage(),</span>
<span id="cb2-10">        T.ToDtype(torch.float32, scale<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>),</span>
<span id="cb2-11">        T.Normalize(IMAGENET1K_MEAN, IMAGENET1K_STD),</span>
<span id="cb2-12">    ]</span>
<span id="cb2-13">)</span>
<span id="cb2-14"></span>
<span id="cb2-15"></span>
<span id="cb2-16"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> load_image(path: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> Path) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> Tensor:</span>
<span id="cb2-17">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Transform images into tensors</span></span>
<span id="cb2-18">    img: Tensor <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> transform(decode_image(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>(path)))</span>
<span id="cb2-19"></span>
<span id="cb2-20">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Add dimension to imitate batch size equal to 1: (C,H,W) -&gt; (B,C,H,W)</span></span>
<span id="cb2-21">    img <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> img.unsqueeze(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>)</span>
<span id="cb2-22">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> img</span></code></pre></div></div>
</div>
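<p>The transform does not squash an image straight to 224×224. <code>T.Resize(256)</code> scales the shorter side to 256 pixels while keeping the aspect ratio, and <code>T.CenterCrop(224)</code> then cuts out the middle. The sketch below reproduces that size arithmetic in plain Python. It assumes that torchvision truncates the longer side with <code>int()</code> and rounds the crop offsets, which is how I understand the library to behave. Treat any off-by-one-pixel detail as an assumption. The helper names are hypothetical.</p>

```python
def resize_shorter_side(h: int, w: int, size: int) -> tuple[int, int]:
    """Output (H, W) when the shorter side is scaled to `size`, keeping aspect ratio.

    Assumption: the longer side is truncated with int(), mirroring (as far as
    I know) torchvision's Resize with an integer size.
    """
    if h <= w:
        return size, int(size * w / h)
    return int(size * h / w), size


def center_crop_offsets(h: int, w: int, crop: int) -> tuple[int, int]:
    """(top, left) corner of a centered crop-by-crop window (rounded offsets assumed)."""
    top = int(round((h - crop) / 2.0))
    left = int(round((w - crop) / 2.0))
    return top, left


# Example: a landscape photo of 480 x 640 pixels (H x W)
h, w = resize_shorter_side(480, 640, 256)   # (256, 341)
top, left = center_crop_offsets(h, w, 224)  # (16, 58)
print(f"resized to {h}x{w}, crop starts at row {top}, column {left}")
```

<p>For a 4:3 landscape photo, roughly 58 to 59 columns are therefore discarded on each side. Objects near the left and right borders may not reach the network at all.</p>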
<div id="590ed8c9" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> inverse_normalize(</span>
<span id="cb3-2">    x_norm: Tensor,</span>
<span id="cb3-3">    mean: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> IMAGENET1K_MEAN,</span>
<span id="cb3-4">    std: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> IMAGENET1K_STD,</span>
<span id="cb3-5">) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> Tensor:</span>
<span id="cb3-6">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Ensure mean and std have the correct shape</span></span>
<span id="cb3-7">    _mean <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.as_tensor(mean).to(x_norm.device).view(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb3-8">    _std <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.as_tensor(std).to(x_norm.device).view(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb3-9">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Inverse normalization: x = x_normalized * std + mean</span></span>
<span id="cb3-10">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> x_norm.mul(_std).add(_mean)</span>
<span id="cb3-11"></span>
<span id="cb3-12"></span>
<span id="cb3-13">reverse_transform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> T.Compose(</span>
<span id="cb3-14">    [</span>
<span id="cb3-15">        T.Lambda(inverse_normalize),</span>
<span id="cb3-16">        T.Lambda(<span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">lambda</span> x: torch.clamp(x, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">min</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.0</span>, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.0</span>)),</span>
<span id="cb3-17">    ]</span>
<span id="cb3-18">)</span></code></pre></div></div>
</div>
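<p><code>inverse_normalize</code> works because <code>mean</code> and <code>std</code> are reshaped to <code>(1, C, 1, 1)</code>. Broadcasting then applies one statistic per channel across the batch, height and width dimensions. The following NumPy sketch shows the forward normalization, <code>(x - mean) / std</code> (what <code>T.Normalize</code> computes), next to its inverse, and checks that the round trip recovers the input. The names <code>normalize</code> and <code>denormalize</code> are illustrative and do not come from torchvision.</p>

```python
import numpy as np

# Same per-channel statistics as IMAGENET1K_MEAN / IMAGENET1K_STD above
MEAN = np.array([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)  # (1, C, 1, 1)
STD = np.array([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)


def normalize(x: np.ndarray) -> np.ndarray:
    """(B, C, H, W) in [0, 1]: subtract the ImageNet mean, divide by the std, per channel."""
    return (x - MEAN) / STD


def denormalize(x_norm: np.ndarray) -> np.ndarray:
    """Inverse of normalize(): x = x_norm * std + mean, then clip to [0, 1] for display."""
    return np.clip(x_norm * STD + MEAN, 0.0, 1.0)


rng = np.random.default_rng(0)
x = rng.random((1, 3, 4, 4))   # fake image batch with values in [0, 1)
x_norm = normalize(x)
x_back = denormalize(x_norm)
print(np.allclose(x, x_back))  # True: the round trip recovers the input
```

<p>The clipping step plays the same role as the <code>torch.clamp</code> lambda in <code>reverse_transform</code>. It only matters for values that have been modified after normalization, such as pixel values produced by a model.</p>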
<div id="1773381c" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1">sample <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> load_image(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"bridge.jpg"</span>)</span>
<span id="cb4-2">orig_sample <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> reverse_transform(sample)</span>
<span id="cb4-3"></span>
<span id="cb4-4">fig, ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(frameon<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>)</span>
<span id="cb4-5">fig.subplots_adjust()</span>
<span id="cb4-6">ax.imshow(orig_sample.squeeze(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>).permute(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>))</span>
<span id="cb4-7">ax.axis(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"off"</span>)</span>
<span id="cb4-8">plt.show()</span></code></pre></div></div>
<div class="cell-output cell-output-display">
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/2025-05-06-feature-maps/index_files/figure-html/cell-6-output-1.png" class="img-fluid figure-img"></p>
<figcaption>Original image, resized and center-cropped</figcaption>
</figure>
</div>
</div>
</div>
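<p>Matplotlib's <code>imshow</code> expects an <code>(H, W)</code> or <code>(H, W, C)</code> array, whereas PyTorch stores images as <code>(B, C, H, W)</code>. That mismatch is why the plotting code above calls <code>squeeze(0).permute(1, 2, 0)</code>. The small NumPy sketch below applies the same reshaping with a fake batch and shows that each pixel ends up holding its three channel values together.</p>

```python
import numpy as np

# A fake (B, C, H, W) batch: 1 image, 3 channels, 2x4 pixels
batch = np.arange(1 * 3 * 2 * 4).reshape(1, 3, 2, 4)

# NumPy equivalent of tensor.squeeze(0).permute(1, 2, 0)
hwc = batch.squeeze(0).transpose(1, 2, 0)

print(batch.shape, "->", hwc.shape)  # (1, 3, 2, 4) -> (2, 4, 3)
# Pixel (row 0, column 1) now holds all three channel values together
print(hwc[0, 1])                     # [ 1  9 17]
```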
<div id="4f899d6b" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> get_activation(name: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>, activations: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">dict</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>, Tensor]) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> Callable:</span>
<span id="cb5-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
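def</span> hook(model: nn.Module, tensor: Tensor, output: Tensor)">
font-style: inherit;">def</span> hook(module: nn.Module, inputs: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">tuple</span>[Tensor, ...], output: Tensor) <span class="op" style="color: #5E5E5E;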
background-color: null;
font-style: inherit;">-&gt;</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb5-3">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># map layer's `name` to layer's output value</span></span>
<span id="cb5-4">        activations[name] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> output.detach()</span>
<span id="cb5-5"></span>
<span id="cb5-6">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> hook</span>
<span id="cb5-7"></span>
<span id="cb5-8"></span>
<span id="cb5-9"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> set_hooks(model: nn.Module, layer_ids: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>], out: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">dict</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>, Tensor]) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb5-10">    layer_ids <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>(i) <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> i <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> layer_ids]</span>
<span id="cb5-11">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> name, module <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> model.named_modules():</span>
<span id="cb5-12">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> name <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> layer_ids:</span>
<span id="cb5-13">            module.register_forward_hook(get_activation(name, out))</span></code></pre></div></div>
</div>
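<p>PyTorch calls a forward hook as <code>hook(module, inputs, output)</code> after a module's forward pass, where <code>inputs</code> is a tuple of the positional arguments. In <code>get_activation</code>, a closure captures <code>name</code> and the shared <code>activations</code> dict, so each layer writes to its own key. <code>register_forward_hook</code> also returns a <code>RemovableHandle</code>. <code>set_hooks</code> above discards these handles, so the hooks stay attached, and calling it twice stacks duplicate hooks. To illustrate the mechanism without a framework, the sketch below uses hypothetical <code>ToyLayer</code> objects that only mimic how PyTorch invokes hooks. It is not PyTorch code.</p>

```python
from collections.abc import Callable


class ToyLayer:
    """Hypothetical stand-in that calls hooks the way PyTorch modules do."""

    def __init__(self, fn: Callable[[float], float]):
        self.fn = fn
        self._hooks: list[Callable] = []

    def register_forward_hook(self, hook: Callable) -> Callable[[], None]:
        self._hooks.append(hook)
        # Return a "remove" callback, playing the role of PyTorch's RemovableHandle
        return lambda: self._hooks.remove(hook)

    def __call__(self, x: float) -> float:
        out = self.fn(x)
        for hook in self._hooks:
            hook(self, (x,), out)  # inputs are passed as a tuple, as in PyTorch
        return out


def get_activation(name: str, activations: dict[str, float]) -> Callable:
    # Same closure pattern as above: `name` and `activations` are captured
    def hook(module, inputs, output) -> None:
        activations[name] = output

    return hook


layers = {"0": ToyLayer(lambda x: x + 1), "1": ToyLayer(lambda x: x * 10)}
acts: dict[str, float] = {}
removers = [layer.register_forward_hook(get_activation(n, acts)) for n, layer in layers.items()]

x = 2.0
for layer in layers.values():
    x = layer(x)
print(acts)  # {'0': 3.0, '1': 30.0}

for remove in removers:  # detach hooks so later forward passes are not recorded
    remove()
```

<p>With a real model, the equivalent would be to collect the handles returned by <code>module.register_forward_hook(...)</code> and call <code>.remove()</code> on each one once the activations have been captured.</p>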
<div id="f1cdfc24" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> visualize_feature_maps(</span>
<span id="cb6-2">    feature_map: Tensor <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> np.ndarray,</span>
<span id="cb6-3">    max_maps: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>,</span>
<span id="cb6-4">    max_cols: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8</span>,</span>
<span id="cb6-5">    figsize_per_plot: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.0</span>,</span>
<span id="cb6-6">    norm: Literal[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"linear"</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"log"</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"symlog"</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"logit"</span>, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>,</span>
<span id="cb6-7">    cmap: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"viridis"</span>,</span>
<span id="cb6-8">):</span>
<span id="cb6-9">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(feature_map, Tensor):</span>
<span id="cb6-10">        feature_map <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> feature_map.cpu().numpy()</span>
<span id="cb6-11"></span>
<span id="cb6-12">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> feature_map.ndim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>:</span>
<span id="cb6-13">        feature_map <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> feature_map.squeeze(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>)  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># remove batch dimension if present</span></span>
<span id="cb6-14">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">assert</span> feature_map.ndim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Expected tensor shape (C, H, W)"</span></span>
<span id="cb6-15"></span>
<span id="cb6-16">    C, H, W <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> feature_map.shape</span>
<span id="cb6-17"></span>
<span id="cb6-18">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> max_maps:</span>
<span id="cb6-19">        C <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">min</span>(C, max_maps)</span>
<span id="cb6-20"></span>
<span id="cb6-21">    n_cols <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">min</span>(C, max_cols)</span>
<span id="cb6-22">    n_rows <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> math.ceil(C <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> n_cols)</span>
<span id="cb6-23"></span>
<span id="cb6-24">    figsize <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (figsize_per_plot <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> n_cols, figsize_per_plot <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> n_rows)</span>
<span id="cb6-25"></span>
<span id="cb6-26">    fig, axes <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(nrows<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>n_rows, ncols<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>n_cols, figsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>figsize, frameon<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>, squeeze<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>)</span>
<span id="cb6-27">    fig.subplots_adjust(wspace<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.03</span>, hspace<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.03</span>)</span>
<span id="cb6-28"></span>
<span id="cb6-29">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> ax <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> axes.flat:</span>
<span id="cb6-30">        ax.axis(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"off"</span>)</span>
<span id="cb6-31"></span>
<span id="cb6-32">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> i <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(C):</span>
<span id="cb6-33">        t <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> feature_map[i]</span>
<span id="cb6-34">        axes.flat[i].imshow(t, cmap<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>cmap, norm<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>norm, aspect<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"equal"</span>, interpolation<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"none"</span>)</span>
<span id="cb6-35"></span>
<span id="cb6-36">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> fig, axes</span></code></pre></div></div>
</div>
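<p><code>visualize_feature_maps</code> lays out one Axes per channel on a grid with <code>n_cols = min(C, max_cols)</code> and <code>n_rows = ceil(C / n_cols)</code>, then hides the unused cells. An alternative, sketched below, applies the same grid arithmetic to tile all maps into a single 2-D NumPy array that can be passed to one <code>imshow</code> call. The helper <code>tile_feature_maps</code> and its <code>pad</code> parameter are hypothetical and not part of the original code. Empty slots are filled with NaN, which Matplotlib treats as invalid and typically renders as transparent.</p>

```python
import math

import numpy as np


def tile_feature_maps(fmap: np.ndarray, max_cols: int = 8, pad: int = 1) -> np.ndarray:
    """Tile a (C, H, W) stack into one 2-D mosaic, row-major, NaN-padded.

    Grid arithmetic matches visualize_feature_maps:
    n_cols = min(C, max_cols), n_rows = ceil(C / n_cols).
    """
    C, H, W = fmap.shape
    n_cols = min(C, max_cols)
    n_rows = math.ceil(C / n_cols)
    # Start with NaN everywhere so gaps and unused slots stay blank
    mosaic = np.full((n_rows * (H + pad) - pad, n_cols * (W + pad) - pad), np.nan)
    for i in range(C):
        r, c = divmod(i, n_cols)
        top, left = r * (H + pad), c * (W + pad)
        mosaic[top:top + H, left:left + W] = fmap[i]
    return mosaic


maps = np.random.default_rng(0).random((10, 4, 4))  # 10 fake 4x4 feature maps
mosaic = tile_feature_maps(maps, max_cols=8)
print(mosaic.shape)  # (9, 39): 2 rows x 8 columns of 4x4 tiles with 1-pixel gaps
```

<p>The trade-off is that a single <code>imshow</code> uses one color scale for every tile, whereas per-Axes plotting autoscales each map independently. Applying <code>minmax_scale_per_channel</code> before tiling would restore a comparable per-map contrast.</p>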
<div id="e785fb60" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> minmax_scale_per_channel(arr: np.ndarray, eps: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1e-5</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> np.ndarray:</span>
<span id="cb7-2">    <span class="co" style="color: #5E5E5E;
background-color: null;
&quot;&quot;&quot;Per-channel MinMax normalization. Expects (C, W, H).">
font-style: inherit;">"""Per-channel MinMax normalization. Expects (C, H, W)."""</span></span>
<span id="cb7-3">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">assert</span> arr.ndim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>, <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>arr<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">.</span>ndim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span></span>
<span id="cb7-4"></span>
<span id="cb7-5">    c_min <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> arr.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">min</span>(axis<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>), keepdims<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb7-6">    c_max <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> arr.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>(axis<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>), keepdims<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb7-7"></span>
<span id="cb7-8">    scaled <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (arr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> c_min) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> (c_max <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> c_min <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> eps)  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># avoid division by zero</span></span>
<span id="cb7-9">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> scaled</span>
<span id="cb7-10"></span>
<span id="cb7-11"></span>
<span id="cb7-12"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> pca_rgb(</span>
<span id="cb7-13">    feature_map: np.ndarray <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> Tensor,</span>
<span id="cb7-14">    n_components: Literal[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>,</span>
<span id="cb7-15">    normalize: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">bool</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>,</span>
<span id="cb7-16">    random_state: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>,</span>
<span id="cb7-17">) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> np.ndarray:</span>
<span id="cb7-18">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(feature_map, torch.Tensor):</span>
<span id="cb7-19">        feature_map <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> feature_map.detach().cpu().numpy()</span>
<span id="cb7-20"></span>
<span id="cb7-21">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> feature_map.ndim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>:</span>
<span id="cb7-22">        feature_map <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> feature_map.squeeze(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>)  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># remove batch dimension if present</span></span>
<span id="cb7-23">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">assert</span> feature_map.ndim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Expected array shape (C, H, W)"</span></span>
<span id="cb7-24"></span>
<span id="cb7-25">    C, H, W <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> feature_map.shape</span>
<span id="cb7-26">    pca <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> PCA(n_components<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>n_components, random_state<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>random_state)</span>
<span id="cb7-27">    flat <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> feature_map.reshape(C, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>).T</span>
<span id="cb7-28">    rgb <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> pca.fit_transform(flat).T.reshape(n_components, H, W)</span>
<span id="cb7-29"></span>
<span id="cb7-30">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> normalize:</span>
<span id="cb7-31">        rgb <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> minmax_scale_per_channel(rgb)</span>
<span id="cb7-32"></span>
<span id="cb7-33">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> rgb</span>
<span id="cb7-34"></span>
<span id="cb7-35"></span>
<span id="cb7-36"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> visualize_feature_maps_pca(</span>
<span id="cb7-37">    feature_maps: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">dict</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>, Tensor],</span>
<span id="cb7-38">    n_components: Literal[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>,</span>
<span id="cb7-39">    max_cols: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>,</span>
<span id="cb7-40">    figsize_per_plot: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">2.0</span>,</span>
<span id="cb7-41">    norm: Literal[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"linear"</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"log"</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"symlog"</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"logit"</span>, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>,</span>
<span id="cb7-42">    subtitles: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">bool</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>,</span>
<span id="cb7-43">    cmap: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"viridis"</span>,</span>
<span id="cb7-44">):</span>
<span id="cb7-45">    c <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(feature_maps)</span>
<span id="cb7-46">    n_cols <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">min</span>(c, max_cols)</span>
<span id="cb7-47">    n_rows <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> math.ceil(c <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> n_cols)</span>
<span id="cb7-48">    fig_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (figsize_per_plot <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> n_cols, figsize_per_plot <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> n_rows)</span>
<span id="cb7-49"></span>
<span id="cb7-50">    fig, axes <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(n_rows, n_cols, figsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>fig_size, squeeze<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>, frameon<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>)</span>
<span id="cb7-51">    fig.subplots_adjust(wspace<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.03</span>, hspace<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.20</span>, top<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.85</span>)</span>
<span id="cb7-52"></span>
<span id="cb7-53">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> ax <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> axes.flat:</span>
<span id="cb7-54">        ax.axis(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"off"</span>)</span>
<span id="cb7-55"></span>
<span id="cb7-56">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> ax, (layer, feature_map) <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">zip</span>(axes.flat, feature_maps.items(), strict<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>):</span>
<span id="cb7-57">        rgb_features <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> pca_rgb(feature_map, n_components<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>n_components)</span>
<span id="cb7-58">        rgb_features <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> rgb_features.transpose(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>)</span>
<span id="cb7-59">        rgb_features <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> rgb_features.squeeze()</span>
<span id="cb7-60"></span>
<span id="cb7-61">        ax.imshow(rgb_features, cmap<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>cmap, norm<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>norm, aspect<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"equal"</span>, interpolation<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"none"</span>)</span>
<span id="cb7-62">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> subtitles:</span>
<span id="cb7-63">            ax.set_title(layer, color<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"0.5"</span>)</span>
<span id="cb7-64"></span>
<span id="cb7-65">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> fig, axes</span></code></pre></div></div>
</div>
</section>
<section id="vgg" class="level2">
<h2 class="anchored" data-anchor-id="vgg">VGG</h2>
<p>The <a href="https://en.wikipedia.org/wiki/VGGNet">VGG</a> networks are deep convolutional neural networks introduced by <span class="citation" data-cites="simonyan2014very">(Simonyan and Zisserman 2014)</span>. They stack many small 3x3 convolution filters in sequence. This simple “deeper‑is‑better” design achieved top results on ImageNet in 2014 and showed that depth and a uniform layer structure can yield strong feature hierarchies, which made VGG a popular baseline for vision tasks and transfer learning. Today, VGG is considered outdated, having been largely superseded by more accurate and more efficient architectures.</p>
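<p>The appeal of small kernels can be checked with a little arithmetic, an argument also made in the VGG paper: two stacked 3x3 convolutions cover a 5x5 region and three cover 7x7, yet they need fewer weights than one large kernel and add a nonlinearity between layers. The sketch below assumes stride 1, the same channel count <code>C</code> for every layer's input and output, and ignores biases; the helper names are hypothetical.</p>

```python
def stacked_3x3(num_layers: int, channels: int) -> tuple[int, int]:
    """Receptive field and weight count of `num_layers` stacked 3x3 convs.

    Assumes stride 1, no dilation, and `channels` input and output channels
    in every layer. Biases are ignored.
    """
    receptive_field = 1 + 2 * num_layers  # each 3x3 layer widens the field by 2 pixels
    weights = num_layers * 3 * 3 * channels * channels
    return receptive_field, weights


def single_kxk(k: int, channels: int) -> int:
    """Weight count of a single kxk conv with `channels` in and out (no bias)."""
    return k * k * channels * channels


C = 512  # channel width of VGG's deepest blocks
for n in (2, 3):
    rf, stacked = stacked_3x3(n, C)
    single = single_kxk(rf, C)
    print(f"{n} stacked 3x3: {rf}x{rf} field, {stacked:,} weights; "
          f"one {rf}x{rf} conv: {single:,} weights")
```

For three layers this works out to 27·C² weights instead of 49·C², roughly 45% fewer, for the same 7x7 field of view.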
<div id="8b3d891d" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb8" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1">model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> models.vgg11(weights<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>models.VGG11_Weights.IMAGENET1K_V1).features</span>
<span id="cb8-2"></span>
<span id="cb8-3"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Let's inspect the layers of VGG's feature extractor</span></span>
<span id="cb8-4">model</span></code></pre></div></div>
<div class="cell-output cell-output-display">
<pre><code>Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): ReLU(inplace=True)
  (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (7): ReLU(inplace=True)
  (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (9): ReLU(inplace=True)
  (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (12): ReLU(inplace=True)
  (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (14): ReLU(inplace=True)
  (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (16): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (19): ReLU(inplace=True)
  (20): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)</code></pre>
</div>
</div>
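<p>The printout also explains why the deeper feature maps look coarser: each 3x3 convolution with <code>padding=(1, 1)</code> keeps the spatial size, while each <code>MaxPool2d(kernel_size=2, stride=2)</code> halves it. A small sketch that traces the spatial size through the layers listed above, assuming a square 224x224 input (the usual ImageNet resolution; the <code>sample</code> used in this post may differ). The layer table is transcribed by hand from the printout, and the helper names are hypothetical.</p>

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one spatial axis for Conv2d/MaxPool2d (dilation=1, ceil_mode=False)."""
    return (size + 2 * padding - kernel) // stride + 1


# (layer index, kernel, stride, padding) of every Conv2d and MaxPool2d in the
# printout; ReLU layers leave the shape unchanged, so they are omitted.
VGG11_FEATURES = [
    (0, 3, 1, 1), (2, 2, 2, 0),
    (3, 3, 1, 1), (5, 2, 2, 0),
    (6, 3, 1, 1), (8, 3, 1, 1), (10, 2, 2, 0),
    (11, 3, 1, 1), (13, 3, 1, 1), (15, 2, 2, 0),
    (16, 3, 1, 1), (18, 3, 1, 1), (20, 2, 2, 0),
]


def trace_sizes(input_size: int) -> dict[int, int]:
    """Spatial size (height = width) of each listed layer's output for a square input."""
    sizes = {}
    size = input_size
    for index, kernel, stride, padding in VGG11_FEATURES:
        size = conv_out(size, kernel, stride, padding)
        sizes[index] = size
    return sizes


sizes = trace_sizes(224)  # assumption: 224x224 input
for layer in (0, 3, 6, 8, 11, 13, 16, 18):
    print(f"layer {layer}: {sizes[layer]}x{sizes[layer]}")
# 224, 112, 56, 56, 28, 28, 14, 14 for the layers visualized below
```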
<div id="6f8d3ab5" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb10" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># cherry-pick the layers whose outputs we want to see</span></span>
<span id="cb10-2">selected_layers <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"0"</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"3"</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"6"</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"8"</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"11"</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"13"</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"16"</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"18"</span>]</span>
<span id="cb10-3"></span>
<span id="cb10-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># add forward hooks to the model</span></span>
<span id="cb10-5">vgg_activations <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {}</span>
<span id="cb10-6">set_hooks(model, selected_layers, vgg_activations)</span>
<span id="cb10-7"></span>
<span id="cb10-8"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># run a forward pass; the hooks fill vgg_activations</span></span>
<span id="cb10-9"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.no_grad():</span>
<span id="cb10-10">    model(sample)</span></code></pre></div></div>
</div>
<p>Below we visualize the feature maps produced by a few hand‑picked layers. A feature map (also called an <em>activation map</em>) is simply the tensor a layer outputs (for example, <code>output = conv(input)</code>). During training, each convolutional layer learns a set of spatial kernels that act as filters (<em>see <a href="https://en.wikipedia.org/wiki/Kernel_(image_processing)">kernels in image processing</a></em>), allowing the network to extract increasingly rich patterns from the feature maps produced by the preceding layers.</p>
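<p>To make “kernels act as filters” concrete, here is a minimal NumPy sketch of what a single‑channel, stride‑1 <code>nn.Conv2d</code> computes (a 2D cross‑correlation with zero padding). Unlike in a trained CNN, where the kernel weights are learned, the kernel here is a hand‑crafted Sobel filter, and <code>correlate2d</code> is a hypothetical helper, not a PyTorch function.</p>

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def correlate2d(image: np.ndarray, kernel: np.ndarray, padding: int = 1) -> np.ndarray:
    """Single-channel 2D cross-correlation with zero padding and stride 1."""
    padded = np.pad(image, padding)                      # zero padding, like padding=(1, 1)
    windows = sliding_window_view(padded, kernel.shape)  # shape (H, W, kh, kw)
    # Multiply every window element-wise with the kernel and sum: one output pixel per window.
    return np.einsum("ijkl,kl->ij", windows, kernel)


# Sobel kernel that responds to left-to-right intensity changes (vertical edges).
sobel_x = np.array([[-1, 0, 1],
                    [-2, 0, 2],
                    [-1, 0, 1]], dtype=float)

# Toy 6x6 image: dark left half, bright right half.
image = np.zeros((6, 6))
image[:, 3:] = 1.0

feature_map = correlate2d(image, sobel_x)
print(feature_map)
```

The resulting feature map is large only along the dark‑to‑bright boundary (columns 2 and 3). The negative values in the last column are a border effect: zero padding makes the bright right edge look like a bright‑to‑dark transition.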
<div id="591ee8bf" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> layer, filters <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> vgg_activations.items():</span>
<span id="cb11-2">    display(Markdown(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"### Layer #</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>layer<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>))</span>
<span id="cb11-3">    visualize_feature_maps(filters, max_maps<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8</span>, norm<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"linear"</span>)</span>
<span id="cb11-4">    plt.show()</span></code></pre></div></div>
<section id="layer-0" class="level3 cell-output cell-output-display cell-output-markdown">
<h3 class="anchored" data-anchor-id="layer-0">Layer #0</h3>
</section>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/2025-05-06-feature-maps/index_files/figure-html/cell-12-output-2.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<section id="layer-3" class="level3 cell-output cell-output-display cell-output-markdown">
<h3 class="anchored" data-anchor-id="layer-3">Layer #3</h3>
</section>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/2025-05-06-feature-maps/index_files/figure-html/cell-12-output-4.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<section id="layer-6" class="level3 cell-output cell-output-display cell-output-markdown">
<h3 class="anchored" data-anchor-id="layer-6">Layer #6</h3>
</section>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/2025-05-06-feature-maps/index_files/figure-html/cell-12-output-6.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<section id="layer-8" class="level3 cell-output cell-output-display cell-output-markdown">
<h3 class="anchored" data-anchor-id="layer-8">Layer #8</h3>
</section>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/2025-05-06-feature-maps/index_files/figure-html/cell-12-output-8.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<section id="layer-11" class="level3 cell-output cell-output-display cell-output-markdown">
<h3 class="anchored" data-anchor-id="layer-11">Layer #11</h3>
</section>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/2025-05-06-feature-maps/index_files/figure-html/cell-12-output-10.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<section id="layer-13" class="level3 cell-output cell-output-display cell-output-markdown">
<h3 class="anchored" data-anchor-id="layer-13">Layer #13</h3>
</section>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/2025-05-06-feature-maps/index_files/figure-html/cell-12-output-12.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<section id="layer-16" class="level3 cell-output cell-output-display cell-output-markdown">
<h3 class="anchored" data-anchor-id="layer-16">Layer #16</h3>
</section>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/2025-05-06-feature-maps/index_files/figure-html/cell-12-output-14.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<section id="layer-18" class="level3 cell-output cell-output-display cell-output-markdown">
<h3 class="anchored" data-anchor-id="layer-18">Layer #18</h3>
</section>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/2025-05-06-feature-maps/index_files/figure-html/cell-12-output-16.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p>As noted above, the number of visualizations grows with the number of filters, which quickly becomes overwhelming when a layer produces dozens of maps. To condense this information, we can project the feature maps with principal‑component analysis (PCA). We treat each spatial position as a sample whose features are its activations across all channels, fit PCA, and keep the projections onto the leading principal components. The result is a single image per layer that captures the directions of greatest variance across the entire stack of feature maps. It can be rendered as a 1‑channel image (shown with a colormap) or as a 3‑channel (RGB) image.</p>
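<p>The projection itself needs nothing beyond an SVD. The sketch below shows the same idea in plain NumPy instead of scikit‑learn's <code>PCA</code>; <code>pca_project</code> is a hypothetical name, and the sign of each component may differ from scikit‑learn's output because principal directions are only defined up to sign. The synthetic feature stack is made up for illustration.</p>

```python
import numpy as np


def pca_project(feature_map: np.ndarray, n_components: int = 3) -> tuple[np.ndarray, np.ndarray]:
    """Project a (C, H, W) feature stack onto its top principal components.

    Each of the H*W spatial positions is one sample with C features (channels).
    Returns the (n_components, H, W) projection and the explained-variance ratios.
    """
    C, H, W = feature_map.shape
    samples = feature_map.reshape(C, -1).T     # (H*W, C)
    centered = samples - samples.mean(axis=0)  # PCA operates on mean-centred data
    # Rows of vt are the principal directions, sorted by decreasing singular value.
    _, s, vt = np.linalg.svd(centered, full_matrices=False)
    projected = centered @ vt[:n_components].T  # (H*W, n_components)
    explained = s**2 / np.sum(s**2)
    return projected.T.reshape(n_components, H, W), explained[:n_components]


# Synthetic stack: 16 channels that are scaled copies of one 8x8 pattern, plus a little noise.
rng = np.random.default_rng(0)
pattern = rng.normal(size=(8, 8))
stack = np.stack([(i + 1) * pattern for i in range(16)]) + 0.01 * rng.normal(size=(16, 8, 8))

proj, ratio = pca_project(stack, n_components=3)
print(proj.shape, ratio.round(4))
```

Because all 16 channels carry essentially the same spatial pattern, the first component explains almost all of the variance; real feature maps spread their variance over more components, which is what the RGB rendering exploits.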
<div id="3f5e8e03" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb12" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1">visualize_feature_maps_pca(vgg_activations, max_cols<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>)</span>
<span id="cb12-2">plt.show()</span></code></pre></div></div>
<div class="cell-output cell-output-display">
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/2025-05-06-feature-maps/index_files/figure-html/cell-13-output-1.png" class="img-fluid figure-img"></p>
<figcaption>Principal‑component projections of the feature‑map stacks for the corresponding layers of VGG‑11.</figcaption>
</figure>
</div>
</div>
</div>
</section>
<section id="resnet" class="level2">
<h2 class="anchored" data-anchor-id="resnet">ResNet</h2>
<p>ResNets (Residual Networks), introduced by He et al. in 2015 (see also <span class="citation" data-cites="targ2016resnet">(Targ et al. 2016)</span>), add “skip” or residual connections that let inputs bypass one or more layers. These identity shortcuts make very deep CNNs (<em>e.g.,</em> ResNet‑50/101/152) easier to train by mitigating vanishing gradients, enabling state‑of‑the‑art accuracy with hundreds of layers. [<a href="https://en.wikipedia.org/wiki/Residual_neural_network">wiki</a>]</p>
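<p>The core idea fits in a few lines: a block outputs <code>relu(F(x) + shortcut(x))</code>, where the shortcut is the identity when shapes match and a 1x1 projection otherwise (the <code>downsample</code> branch in the printout). The NumPy sketch below is a deliberate simplification of torchvision's <code>Bottleneck</code>: the residual function uses only 1x1 convolutions (no 3x3 conv, no BatchNorm, no stride), and all names and weights are hypothetical.</p>

```python
import numpy as np


def relu(x: np.ndarray) -> np.ndarray:
    return np.maximum(x, 0.0)


def conv1x1(weight: np.ndarray, x: np.ndarray) -> np.ndarray:
    """A 1x1 convolution is a channel-mixing matrix applied at every pixel: (O, C) x (C, H, W)."""
    return np.einsum("oc,chw->ohw", weight, x)


def residual_block(x, w1, w2, w_shortcut=None):
    """y = relu(F(x) + shortcut(x)) with a simplified residual F(x) = w2 . relu(w1 . x)."""
    residual = conv1x1(w2, relu(conv1x1(w1, x)))
    # Identity shortcut when channel counts match, otherwise a 1x1 projection.
    shortcut = x if w_shortcut is None else conv1x1(w_shortcut, x)
    return relu(residual + shortcut)


rng = np.random.default_rng(0)
x = rng.normal(size=(64, 4, 4))

# With F's weights at zero the block reduces to relu(x): the input passes straight through.
y = residual_block(x, np.zeros((16, 64)), np.zeros((64, 16)))
print(np.allclose(y, relu(x)))  # True

# Going from 64 to 256 channels (as in layer1's first block) requires a projection shortcut.
w1 = 0.1 * rng.normal(size=(64, 64))
w2 = 0.1 * rng.normal(size=(256, 64))
wd = 0.1 * rng.normal(size=(256, 64))
print(residual_block(x, w1, w2, wd).shape)  # (256, 4, 4)
```

The first check is the intuition behind easier training: a block only has to learn a correction on top of the identity, and the shortcut gives gradients a direct path back to earlier layers.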
<div id="56cf3649" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb13" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1">model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> models.resnet50(weights<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>models.ResNet50_Weights.IMAGENET1K_V1)</span>
<span id="cb13-2"></span>
<span id="cb13-3"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># inspect layers within ResNet</span></span>
<span id="cb13-4">model</span></code></pre></div></div>
<div class="cell-output cell-output-display">
<pre><code>ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer2): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (3): Bottleneck(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer3): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (3): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (4): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (5): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)</code></pre>
</div>
</div>
<div id="6d13d56e" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb15" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1">selected_layers <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"conv1"</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"layer1"</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"layer2"</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"layer3"</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"layer4"</span>]</span>
<span id="cb15-2">resnet_feature_maps: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">dict</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>, Tensor] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {}</span>
<span id="cb15-3">set_hooks(model, selected_layers, resnet_feature_maps)</span>
<span id="cb15-4"></span>
<span id="cb15-5"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.no_grad():</span>
<span id="cb15-6">    model(sample)</span></code></pre></div></div>
</div>
<div id="5fb235f2" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb16" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> layer, filters <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> resnet_feature_maps.items():</span>
<span id="cb16-2">    display(Markdown(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'### Layer "</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>layer<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"'</span>))</span>
<span id="cb16-3">    visualize_feature_maps(filters, max_maps<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8</span>, norm<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"linear"</span>)</span>
<span id="cb16-4">    plt.show()</span></code></pre></div></div>
<section id="layer-conv1" class="level3 cell-output cell-output-display cell-output-markdown">
<h3 class="anchored" data-anchor-id="layer-conv1">Layer “conv1”</h3>
</section>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/2025-05-06-feature-maps/index_files/figure-html/cell-17-output-2.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<section id="layer-layer1" class="level3 cell-output cell-output-display cell-output-markdown">
<h3 class="anchored" data-anchor-id="layer-layer1">Layer “layer1”</h3>
</section>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/2025-05-06-feature-maps/index_files/figure-html/cell-17-output-4.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<section id="layer-layer2" class="level3 cell-output cell-output-display cell-output-markdown">
<h3 class="anchored" data-anchor-id="layer-layer2">Layer “layer2”</h3>
</section>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/2025-05-06-feature-maps/index_files/figure-html/cell-17-output-6.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<section id="layer-layer3" class="level3 cell-output cell-output-display cell-output-markdown">
<h3 class="anchored" data-anchor-id="layer-layer3">Layer “layer3”</h3>
</section>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/2025-05-06-feature-maps/index_files/figure-html/cell-17-output-8.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<section id="layer-layer4" class="level3 cell-output cell-output-display cell-output-markdown">
<h3 class="anchored" data-anchor-id="layer-layer4">Layer “layer4”</h3>
</section>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/2025-05-06-feature-maps/index_files/figure-html/cell-17-output-10.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<div id="34732906" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb17" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1">visualize_feature_maps_pca(resnet_feature_maps, max_cols<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>)</span>
<span id="cb17-2">plt.show()</span></code></pre></div></div>
<div class="cell-output cell-output-display">
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/2025-05-06-feature-maps/index_files/figure-html/cell-18-output-1.png" class="img-fluid figure-img"></p>
<figcaption>Principal‑component projections of the feature‑map stacks for the corresponding layers of ResNet-50.</figcaption>
</figure>
</div>
</div>
</div>
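<p>The implementation of <code>visualize_feature_maps_pca</code> is not shown in this section. To illustrate the idea behind such a projection, the sketch below treats every spatial location of a <code>(C, H, W)</code> feature-map stack as one sample with <code>C</code> channel values. It finds the top three principal axes with an SVD and rescales the projections to <code>[0, 1]</code> so they can be displayed as RGB. The function name <code>pca_project_feature_maps</code> is hypothetical, and this is not necessarily how the original helper works.</p>

```python
import numpy as np


def pca_project_feature_maps(fmap: np.ndarray, n_components: int = 3) -> np.ndarray:
    """Project a (C, H, W) feature-map stack onto its top principal components.

    Hypothetical helper: each pixel is one sample with C features (channels).
    Returns an (H, W, k) array with every component min-max scaled to [0, 1].
    """
    c, h, w = fmap.shape
    # Rows are pixels, columns are channels: shape (H*W, C). astype() copies.
    x = fmap.reshape(c, h * w).T.astype(np.float64)
    # Center each channel so the SVD yields principal directions.
    x -= x.mean(axis=0, keepdims=True)
    # Rows of vt are the principal axes in channel space, sorted by variance.
    _, _, vt = np.linalg.svd(x, full_matrices=False)
    k = min(n_components, vt.shape[0])
    proj = x @ vt[:k].T  # shape (H*W, k)
    # Rescale each component independently to [0, 1] for display as RGB.
    lo = proj.min(axis=0, keepdims=True)
    hi = proj.max(axis=0, keepdims=True)
    proj = (proj - lo) / np.maximum(hi - lo, 1e-12)
    return proj.reshape(h, w, k)


rng = np.random.default_rng(0)
rgb = pca_project_feature_maps(rng.standard_normal((64, 14, 14)))
print(rgb.shape)
```

<p>To use this with tensors captured by the hooks, convert a single sample's <code>(C, H, W)</code> tensor with <code>.cpu().numpy()</code> before calling the function.</p>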
</section>
<section id="conclusions" class="level2">
<h2 class="anchored" data-anchor-id="conclusions">Conclusions</h2>
<p>This article introduced a lightweight technique for visualizing the feature maps (layer‑wise outputs) of pre-selected neural network layers. These visualizations offer an intuitive window into <strong>what</strong> a convolutional network responds to at each processing stage.</p>
<p>For deeper, production‑grade interpretability, explore the rich ecosystem of <strong>explainability libraries and frameworks</strong>, such as Captum or SHAP, and take a broader look at the rapidly growing fields of eXplainable AI (XAI) and Responsible AI.</p>
<ul>
<li><a href="https://captum.ai/">Captum</a></li>
<li><a href="https://shap.readthedocs.io/en/latest/">SHAP</a></li>
</ul>



</section>

<div id="quarto-appendix" class="default"><section class="quarto-appendix-contents" id="quarto-bibliography"><h2 class="anchored quarto-appendix-heading">References</h2><div id="refs" class="references csl-bib-body hanging-indent">
<div id="ref-simonyan2014very" class="csl-entry">
Simonyan, Karen, and Andrew Zisserman. 2014. <span>“Very Deep Convolutional Networks for Large-Scale Image Recognition.”</span> <em>arXiv Preprint arXiv:1409.1556</em>.
</div>
<div id="ref-targ2016resnet" class="csl-entry">
Targ, Sasha, Diogo Almeida, and Kevin Lyman. 2016. <span>“Resnet in Resnet: Generalizing Residual Architectures.”</span> <em>arXiv Preprint arXiv:1603.08029</em>.
</div>
</div></section><section class="quarto-appendix-contents" id="quarto-reuse"><h2 class="anchored quarto-appendix-heading">Reuse</h2><div class="quarto-appendix-contents"><div><a rel="license" href="https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en_us">CC BY-NC-SA 4.0</a></div></div></section></div> ]]></description>
  <category>pytorch</category>
  <category>nn</category>
  <guid>https://gcerar.github.io/posts/2025-05-06-feature-maps/</guid>
  <pubDate>Mon, 05 May 2025 22:00:00 GMT</pubDate>
  <media:content url="https://gcerar.github.io/posts/2025-05-06-feature-maps/featured.webp" medium="image" type="image/webp"/>
</item>
<item>
  <title>Research Compute Infrastructure</title>
  <dc:creator>Gregor Cerar</dc:creator>
  <link>https://gcerar.github.io/posts/2023-11-20-research-compute-infrastructure/</link>
  <description><![CDATA[ 





<section id="introduction" class="level1">
<h1>Introduction</h1>
<p>Recent technological advances have transformed education, elevating the quality of teaching and learning. Jupyter Notebooks have emerged as a leading tool for interactive computing, programming, and data analysis <span class="citation" data-cites="perkel2018jupyter mendez2019toward granger2021storytelling">(Perkel 2018; <span class="nocase">Mendez et al.</span> 2019; <span class="nocase">Granger et al.</span> 2021)</span>. However, hardware limitations become a significant hurdle when handling larger research projects. Public cloud services are one option, but they come with notable drawbacks. In response, we developed a private cloud for our lab using Kubernetes. This solution addresses cost and security concerns while remaining adaptable, and it gives us efficient application management, scalability, and flexible resource allocation.</p>
<section id="jupyter-notebooks" class="level2">
<h2 class="anchored" data-anchor-id="jupyter-notebooks">Jupyter Notebooks</h2>
<p>A <a href="https://en.wikipedia.org/wiki/Project_Jupyter">Jupyter Notebook</a> is an open document format based on JSON<sup>1</sup>. Notebooks are organized into a sequence of cells, and each cell contains code, descriptive text, equations, or rich outputs (<em>e.g.</em>, text displays, tables, audio, images, and animations). Tools like JupyterLab provide a single interface for interactive code execution, data analysis, and documentation, and they save the resulting work as a Jupyter Notebook. Notebooks support various programming languages (<em>e.g.</em>, Python, R, Scala, C++) and allow users to write and execute code cells iteratively (using REPL<sup>2</sup> or WETL<sup>3</sup> approaches), so intermediate results are visible immediately. This facilitates the creation of narrative-driven data analyses, educational materials, and interactive presentations. Thanks to their versatility and interactivity, Jupyter Notebooks are a robust tool for teaching and learning, data science, and computational research.</p>
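<p>To make the cell structure concrete, the sketch below builds a tiny notebook as a plain Python dictionary following the nbformat&nbsp;4 layout. The notebook contains one Markdown cell and one code cell with a stored output. The sketch serializes it to JSON as it would be stored in an <code>.ipynb</code> file, then reads it back. It uses only the standard <code>json</code> module. Real notebooks written by Jupyter carry additional metadata, so treat this as a simplified example rather than a complete file.</p>

```python
import json

# Minimal notebook in the nbformat v4 layout (simplified: real files written
# by Jupyter contain more metadata, e.g. language_info).
notebook = {
    "nbformat": 4,
    "nbformat_minor": 5,
    "metadata": {
        "kernelspec": {"name": "python3", "display_name": "Python 3", "language": "python"}
    },
    "cells": [
        {
            "cell_type": "markdown",
            "id": "intro",
            "metadata": {},
            "source": ["# Analysis\n", "Some narrative text."],
        },
        {
            "cell_type": "code",
            "id": "compute",
            "metadata": {},
            "execution_count": 1,
            "source": ["print(2 + 2)"],
            # Outputs are stored next to the code that produced them.
            "outputs": [{"output_type": "stream", "name": "stdout", "text": ["4\n"]}],
        },
    ],
}

# Serialize as it would be written to an .ipynb file, then load it back.
text = json.dumps(notebook, indent=1)
loaded = json.loads(text)

# Walk the cells in order, as a notebook frontend would render them.
for cell in loaded["cells"]:
    print(cell["cell_type"], "->", "".join(cell["source"]))
```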
<p>Because of these remarkable features, we decided to incorporate Jupyter Notebooks into our research lab’s educational and research processes. We encouraged students and researchers to use Jupyter Notebooks to document their work and share it more easily with others.</p>
</section>
<section id="scalability" class="level2">
<h2 class="anchored" data-anchor-id="scalability">Scalability</h2>
<p>However, for large-scale projects involving heavy data processing, running Jupyter Notebooks on personal computers becomes a significant challenge. We frequently run into hardware limitations such as storage space, RAM, processing power, and access to compute accelerators, which can slow down or even halt our progress. These projects are typically in the early stages of research, analysis, or prototyping, where intensive optimization is impractical because it slows down experimental development. Two potential solutions emerge: running Jupyter Notebooks on grid/HPC infrastructure or on cloud services.</p>
<p>HPC infrastructure, like <a href="https://sling.si">SLING</a> in Slovenia or <a href="https://eurohpc-ju.europa.eu">EuroHPC</a> at the European level, offers immense computational power. However, because HPC systems are significant investments, workload managers such as SLURM are used to queue jobs and maximize utilization. Computation tasks must be pre-packaged with metadata (e.g., requested resources and time limit), code, and input data, and they then join a waiting queue. This approach does not align well with data-driven research, which relies on interactive programming and quick feedback, and it prevents full use of Jupyter Notebooks. Hence, cloud services have become the more common choice for running notebooks.</p>
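<p>To illustrate what “pre-packaging” a job means, the hypothetical helper below renders a minimal SLURM batch script. The resources, time limit, and command are all fixed before submission, and the scheduler decides when the job runs. There is no live session to interact with, which is what clashes with the notebook workflow. The job name, command, and resource values are placeholders. Real clusters usually also require site-specific options such as a partition or account.</p>

```python
def render_sbatch_script(job_name: str, command: str, *, hours: int = 1,
                         cpus: int = 4, mem_gb: int = 16, gpus: int = 0) -> str:
    """Render a minimal SLURM batch script (hypothetical helper).

    Everything the job needs is declared up front; once submitted, the job
    waits in the queue until the scheduler can allocate these resources.
    """
    lines = [
        "#!/bin/bash",
        f"#SBATCH --job-name={job_name}",
        f"#SBATCH --time={hours:02d}:00:00",  # wall-clock limit (HH:MM:SS)
        f"#SBATCH --cpus-per-task={cpus}",
        f"#SBATCH --mem={mem_gb}G",
    ]
    if gpus:
        # Request GPUs as a generic resource (GRES).
        lines.append(f"#SBATCH --gres=gpu:{gpus}")
    lines.append(command)  # the non-interactive payload
    return "\n".join(lines) + "\n"


script = render_sbatch_script("train-resnet", "python train.py --data /scratch/dataset", gpus=1)
print(script)
# On a cluster, the saved script would be submitted with: sbatch job.sh
```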
<p>Public cloud platforms like Google <a href="https://colab.research.google.com">Colab</a> and <a href="https://www.kaggle.com">Kaggle</a> have popularized Jupyter Notebook usage. Users can access these services at any time without queues, edit notebooks, and use cloud computing resources, all via a browser. Both services offer a free, limited tier. However, due to high user demand, these platforms sometimes limit computational resources, which affects service quality. Alternatives include custom paid services in the public cloud (e.g., <a href="https://aws.amazon.com">AWS</a>, <a href="https://azure.microsoft.com/en-us">Azure</a>, <a href="https://cloud.google.com">GCP</a>, <a href="https://alibabacloud.com">Alibaba Cloud</a>) that tailor infrastructure to customer needs. Public cloud services also have drawbacks, including high rental costs, unpredictable expenses driven by market prices, and security concerns when handling sensitive data.</p>
<p>Private clouds are an alternative to the public cloud that addresses its cost and security challenges. They are crucial for research labs and companies that handle sensitive data or require high adaptability, and they give organizations more transparency and control over costs, tailored to their needs and capabilities. Private clouds require upfront technical knowledge and infrastructure investment. In return, they offer enhanced security, control, and flexibility, leading to more predictable costs in the long run.</p>
<p>Several technologies are available to set up a private cloud, including commercial options (<em>e.g.,</em> VMware <a href="https://www.vmware.com/products/vsphere.html">vSphere</a>, Red Hat <a href="https://www.redhat.com/en/technologies/cloud-computing/openshift">OpenShift</a>, IBM <a href="https://www.ibm.com/docs/en/cloud-private/3.2.x?topic=started-cloud-private-overview">Cloud Private</a>) and open-source solutions (<em>e.g.,</em> <a href="https://tljh.jupyter.org/en/latest/index.html">The Littlest JupyterHub</a>, <a href="https://www.openstack.org">OpenStack</a>, <a href="https://www.eucalyptus.cloud">Eucalyptus</a>, <a href="https://kubernetes.io">Kubernetes</a>, or using Docker Compose [<a href="https://github.com/jupyterhub/jupyterhub-deploy-docker">reference design</a>, <a href="https://github.com/gcerar/jupyterhub-docker">gcerar/jupyterhub-docker</a>]). Among the open-source options, Kubernetes is <a href="https://trends.google.com/trends/explore?cat=5&amp;q=%2Fg%2F11b7lxp79d,%2Fm%2F0cm87w_,%2Fm%2F0cnx0mm&amp;hl=en-US">the most popular</a> solution.</p>
<p><a href="https://kubernetes.io">Kubernetes</a> (abbreviated as K8s) is an open-source platform for automating the deployment, scaling, and management of containerized applications. Its orchestration features provide efficient application management, automatic scaling, monitoring of application health, and high availability. It can simplify the development and maintenance of complex cloud-based applications.</p>
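<p>Kubernetes works declaratively. An operator describes the desired state, and Kubernetes controllers keep the cluster converging toward it. The sketch below builds a minimal <code>Deployment</code> manifest as a Python dictionary that asks for three replicas of a container. The name, image, port, and resource request are placeholder values chosen for illustration. Because <code>kubectl apply -f</code> accepts JSON as well as YAML, the printed output could be applied to a cluster directly.</p>

```python
import json


def make_deployment(name: str, image: str, replicas: int) -> dict:
    """Build a minimal apps/v1 Deployment manifest (illustrative values)."""
    labels = {"app": name}
    return {
        "apiVersion": "apps/v1",
        "kind": "Deployment",
        "metadata": {"name": name, "labels": labels},
        "spec": {
            # Desired state: Kubernetes keeps this many pods running and
            # replaces any that fail.
            "replicas": replicas,
            # The selector must match the pod template's labels.
            "selector": {"matchLabels": labels},
            "template": {
                "metadata": {"labels": labels},
                "spec": {
                    "containers": [
                        {
                            "name": name,
                            "image": image,
                            "ports": [{"containerPort": 80}],
                            # Requests tell the scheduler what each pod needs.
                            "resources": {"requests": {"cpu": "250m", "memory": "128Mi"}},
                        }
                    ]
                },
            },
        },
    }


manifest = make_deployment("demo-web", "nginx:1.27", replicas=3)
print(json.dumps(manifest, indent=2))
```

<p>Scaling then amounts to changing the desired replica count, for example with <code>kubectl scale deployment demo-web --replicas=5</code>, and Kubernetes creates or removes pods to match.</p>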
<p>Unlike <a href="https://www.docker.com/">Docker</a> and <a href="https://docs.docker.com/compose/">Docker Compose</a>, which primarily focus on building, storing, and running containers on a single host, Kubernetes offers a much more comprehensive platform for managing containers across large environments that span multiple computing nodes. Docker makes it easy to create and run individual containers, and Docker Compose allows multiple containers to be defined as one application unit. Kubernetes, in contrast, manages entire clusters of such application units throughout their life cycle. This includes automatic deployment, load-based scaling, recovery from failures, and more advanced service and network management.</p>
<p>In our research lab, due to the growing computational demands prevalent in data science and the desire to retain the recognizable workflow present in Jupyter Notebooks, we have developed our private cloud solution based on Kubernetes technology.</p>
<p>The following sections will present a private cloud setup featuring Jupyter Notebooks built on top of open-source solutions. The user experience closely resembles that of existing paid cloud services. The private cloud must meet the following requirements:</p>
<ul>
<li><p><strong>System Scalability:</strong> The cloud should allow for easily adding computing nodes to the cluster without disrupting the operational system, supporting larger research projects or teaching groups.</p></li>
<li><p><strong>Efficient Resource Management:</strong> The system must enable precise allocation of resources to users. In this context, an administrator can define a balance between a lax and strict resource allocation policy.</p></li>
<li><p><strong>Enhanced Collaboration Experience:</strong> The system should allow for straightforward sharing of Jupyter Notebooks among users, promoting collaboration on joint projects and idea exchange between researchers and students.</p></li>
<li><p><strong>No Waiting Queues:</strong> The system should eliminate waiting queues, giving users immediate access to computational resources as far as the available capacity allows.</p></li>
</ul>
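<p>One common way to realize the lax-versus-strict trade-off in Kubernetes is through resource <em>requests</em> and <em>limits</em>. The scheduler places a pod on a node only if the node's capacity covers the sum of requests. Limits cap what a container may actually use, and their total is allowed to exceed capacity (overcommitment). The simplified model below covers only CPU and uses hypothetical names and numbers. It compares a strict policy, where the request equals the limit, with a lax one that has a small request and a generous limit. In Zero-to-JupyterHub, the corresponding per-user settings are <code>singleuser.cpu.guarantee</code> and <code>singleuser.cpu.limit</code>.</p>

```python
from dataclasses import dataclass


@dataclass
class PodResources:
    """CPU settings of one user's pod (hypothetical, simplified model)."""
    name: str
    cpu_request: float  # cores reserved for the pod; used by the scheduler
    cpu_limit: float    # ceiling enforced at runtime by CPU throttling


def fits_on_node(pods: list[PodResources], node_cpus: float) -> bool:
    """Schedulable if the node's capacity covers the sum of the requests."""
    return node_cpus >= sum(p.cpu_request for p in pods)


def overcommit_ratio(pods: list[PodResources], node_cpus: float) -> float:
    """Sum of limits relative to capacity; above 1.0 means CPU is overcommitted."""
    return sum(p.cpu_limit for p in pods) / node_cpus


node_cpus = 32.0
# Strict policy: request == limit, so each user gets a guaranteed, fixed share.
strict = [PodResources(f"user{i}", cpu_request=4.0, cpu_limit=4.0) for i in range(8)]
# Lax policy: small guaranteed share, but users may burst up to 8 cores.
lax = [PodResources(f"user{i}", cpu_request=1.0, cpu_limit=8.0) for i in range(24)]

for label, pods in (("strict", strict), ("lax", lax)):
    print(f"{label}: {len(pods)} users, fits={fits_on_node(pods, node_cpus)}, "
          f"overcommit={overcommit_ratio(pods, node_cpus):.1f}x")
```

<p>Under the lax policy, three times as many users fit on the same node. The cost is contention when many of them burst at once.</p>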
</section>
</section>
<section id="architecture" class="level1">
<h1>Architecture</h1>
<p>We based our private cloud on the Kubernetes platform to meet the scalability and resource management requirements, and to improve the functionality, accessibility, and sharing of Jupyter Notebooks <span class="citation" data-cites="bussonnier2018jupyterhub">(Bussonnier 2018)</span> within it. In this section, we describe the system architecture that integrates these services and explain our design decisions. We then describe the individual services in our infrastructure.</p>
<div id="tbl-stack" class="hover quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-stack-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;1: Services and Selected Solutions.
</figcaption>
<div aria-describedby="tbl-stack-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="table-hover caption-top table">
<thead>
<tr class="header">
<th style="text-align: left;">Service</th>
<th style="text-align: left;">Solutions (<strong>Used in bold</strong>)</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td style="text-align: left;">Turnkey solution?</td>
<td style="text-align: left;"><strong>Custom</strong>, NVIDIA DeepOps</td>
</tr>
<tr class="even">
<td style="text-align: left;"><strong>Basic Infrastructure</strong></td>
<td style="text-align: left;"></td>
</tr>
<tr class="odd">
<td style="text-align: left;">Operating System</td>
<td style="text-align: left;"><strong>Ubuntu</strong>, RHEL, NixOS, Talos</td>
</tr>
<tr class="even">
<td style="text-align: left;">Data Storage</td>
<td style="text-align: left;"><strong>ZFS</strong>, GlusterFS, Lustre, Ceph, iSCSI</td>
</tr>
<tr class="odd">
<td style="text-align: left;">System Management</td>
<td style="text-align: left;"><strong>Ansible</strong>, Terraform, Puppet, Chef</td>
</tr>
<tr class="even">
<td style="text-align: left;"><strong>Internal Services</strong></td>
<td style="text-align: left;"></td>
</tr>
<tr class="odd">
<td style="text-align: left;">K8s Distribution</td>
<td style="text-align: left;"><strong>vanilla</strong>, MicroK8s, OpenShift, Rancher</td>
</tr>
<tr class="even">
<td style="text-align: left;">K8s Installation</td>
<td style="text-align: left;"><strong>Helm</strong>, Kustomize</td>
</tr>
<tr class="odd">
<td style="text-align: left;">Network Manager</td>
<td style="text-align: left;"><strong>Calico</strong>, Canal, Flannel, Weave</td>
</tr>
<tr class="even">
<td style="text-align: left;">Data Manager</td>
<td style="text-align: left;"><strong>csi-driver-nfs</strong>, Rook, OpenEBS</td>
</tr>
<tr class="odd">
<td style="text-align: left;">Traffic Balancing</td>
<td style="text-align: left;"><strong>MetalLB</strong>, cloud provider specific</td>
</tr>
<tr class="even">
<td style="text-align: left;">Traffic Manager</td>
<td style="text-align: left;"><strong>Nginx</strong>, Traefik</td>
</tr>
<tr class="odd">
<td style="text-align: left;">GPU Manager</td>
<td style="text-align: left;"><strong>NVIDIA GPU-Operator</strong></td>
</tr>
<tr class="even">
<td style="text-align: left;"><strong>Services for Users</strong></td>
<td style="text-align: left;"></td>
</tr>
<tr class="odd">
<td style="text-align: left;">JupyterHub Manager</td>
<td style="text-align: left;"><strong>Z2JH (Zero-to-JupyterHub)</strong></td>
</tr>
<tr class="even">
<td style="text-align: left;">Metrics and Monitoring</td>
<td style="text-align: left;"><strong>kube-prometheus-stack</strong>, InfluxDB</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
<p>The first column of Table&nbsp;1 lists the services required for system operation. The second column lists open-source solutions that can provide each service. Solutions in bold are the ones selected and used in our private cloud. We based our choices on several criteria. We first surveyed the technologies and solutions used in related projects. We then narrowed our selection to open-source solutions tested in private clouds running on on-premises infrastructure. The popularity of each project also weighed significantly in our decisions. We gauged popularity by the number of repository stars and forks and by the level of development activity on GitHub/GitLab. Rather than relying on a single empirical metric, we considered multiple factors together to ensure a comprehensive assessment of the solutions.</p>
<div id="fig-architecture" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-architecture-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<img src="https://gcerar.github.io/posts/2023-11-20-research-compute-infrastructure/figures/over10k-arch.png" class="img-fluid figure-img">
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-architecture-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;1: A three-tier logical infrastructure diagram. At the bottom is the foundational infrastructure, followed by internal Kubernetes services in the middle, and on top are the services exposed to users.
</figcaption>
</figure>
</div>
<p>The diagram in Figure&nbsp;1 provides a high-level representation of our private cloud and its infrastructure across three levels. The first level comprises heterogeneous computing nodes, which form the foundational infrastructure. Each node runs its own operating system and a part of the Kubernetes platform. The second level encompasses internal Kubernetes services, which are essential for operation but never accessed directly by users. The third and final level includes the services used by end users.</p>
<section id="turnkey-solution" class="level2">
<h2 class="anchored" data-anchor-id="turnkey-solution">Turnkey Solution?</h2>
<p>When planning the private cloud, we initially explored turnkey solutions, including NVIDIA <a href="https://github.com/NVIDIA/deepops">DeepOps</a>. Despite its advantages, we built our own custom solution for the following reasons. DeepOps is an excellent turnkey solution with actively maintained source code on GitHub and commercial support. However, its initial setup requires adjusting configuration files, including Ansible scripts for automated (re)configuration of the installed Linux distribution. This complexity discouraged us from investing more time in tinkering with it.</p>
<p>One of our biggest concerns was the intricate solution that tries to be versatile and “simple”. However, this inevitably leads to hiding functionalities and, in case of issues, jumping around documentation of multiple unrelated internally used tools. Despite proclaimed simplicity, troubleshooting or upgrade problems require manual intervention, where a thorough understanding of Linux, DeepOps, its internal tooling, and their interactions is necessary for system control. Therefore, we decided to start with a minimalist solution and, over time, plan to expand the system to understand the infrastructure’s operation better.</p>
</section>
<section id="foundation-infrastructure" class="level2">
<h2 class="anchored" data-anchor-id="foundation-infrastructure">Foundation Infrastructure</h2>
<p>In this section, we discuss the foundation infrastructure of our private cloud solution. We’ll go through these building blocks, including the selection of container management tools and resource sharing, which are vital for the operation of the Kubernetes platform.</p>
<p><strong>Operating System:</strong> We chose Ubuntu Server, which is based on the Debian Linux distribution. The advantage of widely used Debian-based distributions is the abundance of available knowledge resources and support, which makes problem-solving more accessible. Among the alternatives, such as the declarative, reproducible <a href="https://nixos.org">NixOS</a> and RHEL-based distributions, we also considered <a href="https://www.talos.dev">Talos</a>, a distribution specialized for Kubernetes. However, we stuck with Ubuntu Server because of the Talos project’s novelty and the associated risks.</p>
<p><strong>Container Management:</strong> For container management, we selected <a href="https://containerd.io">containerd</a>, which is also used in the DeepOps solution and officially supported by NVIDIA. It is an open-source container runtime that implements the Container Runtime Interface (CRI), through which Kubernetes manages containers on each node efficiently and reliably.</p>
<p><strong>Data Storage:</strong> For data storage, we chose <a href="https://en.wikipedia.org/wiki/OpenZFS">ZFS</a>, which resides on one of the nodes. Although solutions like <a href="https://en.wikipedia.org/wiki/Apache_Hadoop#HDFS">HDFS</a>, <a href="https://www.gluster.org/">Gluster</a>, <a href="https://www.lustre.org/">Lustre</a>, or <a href="https://ceph.io/en/">Ceph</a> are far more common in the HPC world, they require dedicated infrastructure and tools to offer the features ZFS provides out of the box. These features include checkpoints, data deduplication, compression, a copy-on-write (COW) design that prevents data loss during writes, protection against silent bit rot, redundancy across disks to survive drive failures, and the use of fast SSDs as a cache. ZFS also allows easy manual intervention in the event of incidents. However, at the time of writing, ZFS does not stretch across multiple nodes, so a malfunction of the data-storing node can bring down the cluster (a single point of failure). Note that OpenZFS’ distributed RAID (dRAID) [<a href="https://github.com/openzfs/zfs/commit/b2255edcc0099e62ad46a3dd9d64537663c6aee3">src</a>] spreads data and spare capacity across the disks of a single pool rather than across nodes, so it does not remove this single point of failure.</p>
<p>To access ZFS storage from Kubernetes, we used the in-kernel Linux NFS server. We chose <a href="https://en.wikipedia.org/wiki/Network_File_System">NFS</a> because it is one of the few volume types that support the <code>ReadWriteMany</code> access mode, which allows multiple containers, even on different nodes, to mount the same volume simultaneously (see <a href="https://kubernetes.io/docs/concepts/storage/persistent-volumes/#access-modes">table</a>).</p>
<p><strong>System Management:</strong> For remote management and node configuration, we use <a href="https://www.ansible.com/">Ansible</a>, maintained by Red Hat. We selected it because of its prevalence in other significant open-source projects and our positive experience with it in past projects.</p>
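<p>As a minimal sketch of what this means in practice, the Python snippet below builds a PersistentVolumeClaim that requests the <code>ReadWriteMany</code> access mode. The claim name, size, and the StorageClass name <code>nfs-csi</code> are hypothetical placeholders, not our actual configuration.</p>

```python
import json


def nfs_pvc(name: str, size: str, storage_class: str = "nfs-csi") -> dict:
    """Build a PersistentVolumeClaim that many pods can mount read-write at once."""
    return {
        "apiVersion": "v1",
        "kind": "PersistentVolumeClaim",
        "metadata": {"name": name},
        "spec": {
            # ReadWriteMany (RWX): the volume can be mounted read-write by many nodes
            # at the same time, which NFS-backed volumes support.
            "accessModes": ["ReadWriteMany"],
            # Hypothetical StorageClass name; it must match a class defined in the cluster.
            "storageClassName": storage_class,
            "resources": {"requests": {"storage": size}},
        },
    }


manifest = nfs_pvc("shared-datasets", "500Gi")  # hypothetical claim name and size
print(json.dumps(manifest, indent=2))
```

<p><code>kubectl apply -f</code> accepts JSON as well as YAML, so the printed manifest could be saved to a file and applied directly; every pod that references the claim would then see the same files.</p>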
</section>
<section id="kubernetes" class="level2">
<h2 class="anchored" data-anchor-id="kubernetes">Kubernetes</h2>
<p>In Kubernetes, everything operates as a service. These services provide various functionalities that enhance Kubernetes capabilities, such as storage access, CPU and GPU allocation, traffic management, and connecting services within a mesh network.</p>
<p>To support specific functionalities, appropriate services (much like operating system drivers) must be installed. These specialized services, often called “operators” in Kubernetes terminology [<a href="https://kubernetes.io/docs/concepts/extend-kubernetes/operator/">src</a>], are essential: they not only deploy and manage functionality but also respond to issues. Operators extend Kubernetes through its standardized, versioned APIs.</p>
<p>Put simply, operators are deployed as controller pods (containers) that watch for changes to custom Kubernetes resources and react accordingly. They function as an intermediary layer, implementing application-specific logic that extends Kubernetes beyond its built-in capabilities.</p>
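<p>The watch-and-react behavior can be illustrated with a toy reconcile loop. This is a simplified sketch, not a real operator: <code>Cluster</code> and <code>reconcile</code> are hypothetical names, and the in-memory dictionary stands in for state that a real operator would read from and write to the Kubernetes API. The property it shows is that each pass compares desired and observed state, so running it again when nothing has drifted does nothing.</p>

```python
from dataclasses import dataclass, field


@dataclass
class Cluster:
    """Toy stand-in for the cluster state an operator observes (not a real API)."""
    running: dict = field(default_factory=dict)  # app name -> running replicas


def reconcile(desired: dict, cluster: Cluster) -> list:
    """One pass of a controller loop: act only on differences between desired and observed state."""
    actions = []
    for app, want in desired.items():
        have = cluster.running.get(app, 0)
        if have != want:
            actions.append(f"scale {app}: {have} to {want}")
            cluster.running[app] = want  # a real operator would call the Kubernetes API here
    for app in list(cluster.running):
        if app not in desired:
            actions.append(f"delete {app}")
            del cluster.running[app]
    return actions


cluster = Cluster(running={"jupyterhub": 0, "old-app": 1})
desired = {"jupyterhub": 1, "grafana": 1}
print(reconcile(desired, cluster))  # first pass fixes the drift
print(reconcile(desired, cluster))  # second pass finds nothing to do: []
```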
<section id="internal-services" class="level3">
<h3 class="anchored" data-anchor-id="internal-services">Internal Services</h3>
<p>In Kubernetes, internal services are not intended for end users but are crucial for the system’s operation. These services operate in the background, ensuring vital functionalities that enable the stable operation and management of the container environment. In this subsection, we will introduce key services within Kubernetes and explain their role in our infrastructure. We will describe each service’s primary functionality and examine alternatives we explored in making our decision.</p>
<p><strong>Kubernetes Distribution:</strong> When choosing a Kubernetes distribution, we examined three options: Canonical <a href="https://microk8s.io">MicroK8s</a>, Red Hat OpenShift, and the basic “vanilla” Kubernetes distribution. “Vanilla” Kubernetes is the unaltered upstream release published by the Kubernetes project, without pre-installed applications or plugins. We went with the vanilla version because it gives us flexibility and freedom in choosing extensions.</p>
<p><a href="https://microk8s.io">MicroK8s</a> is an excellent solution for quick experimentation and setting up the system on smaller devices with limited resources (<em>e.g.,</em> Raspberry Pi). However, it has many pre-installed applications and uses Canonical’s Snap packaging system, which can complicate adjusting configuration files and accessing external services, such as the NFS server.</p>
<p>We ruled out <a href="https://en.wikipedia.org/wiki/OpenShift">OpenShift</a> because managing its security profiles was excessive for our use case, requiring substantial effort to implement a profile for each service. Therefore, we opted for the basic “vanilla” Kubernetes distribution, which offers more flexible and straightforward customization tailored to our needs.</p>
<p><strong>Kubernetes Package Deployment:</strong> The straightforward way to deploy a service in Kubernetes is to write YAML configuration files (also called manifests) and submit them to Kubernetes from the command line (<code>kubectl apply</code>). However, some services are quite complex, so developers bundle them into packages that are more general-purpose and customizable through parameters. The most widespread packaging system is <a href="https://helm.sh/">Helm</a>, whose packages (called charts) are portable and adaptable. Helm uses YAML files as templates (much like forms), which it fills out with the provided parameters before sending the result to Kubernetes.</p>
<p><strong>Network Operator:</strong> Services running in Kubernetes need a network to communicate with one another. We opted for the open-source Tigera <a href="https://github.com/tigera/operator">Calico</a> operator to manage these interconnections. Given its prevalence and functionality, we found it the most suitable solution.</p>
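<p>To illustrate the template-filling idea, here is a sketch that uses Python’s <code>string.Template</code> as a stand-in for Helm’s Go templates (Helm itself uses placeholders such as <code>{{ .Values.replicaCount }}</code>). The template, default values, and image tag are hypothetical.</p>

```python
from string import Template

# Stand-in for a chart template; Helm would use Go template syntax instead of $name.
DEPLOYMENT_TEMPLATE = Template("""\
apiVersion: apps/v1
kind: Deployment
metadata:
  name: $name
spec:
  replicas: $replicas
  selector:
    matchLabels:
      app: $name
  template:
    metadata:
      labels:
        app: $name
    spec:
      containers:
        - name: $name
          image: $image
""")

# Hypothetical chart defaults (the role of a chart's values.yaml).
DEFAULT_VALUES = {"replicas": 1, "image": "nginx:1.25"}


def render(name: str, **overrides) -> str:
    """Merge user overrides into the defaults (like `helm install --set`), then fill the template."""
    values = {**DEFAULT_VALUES, **overrides, "name": name}
    return DEPLOYMENT_TEMPLATE.substitute(values)


print(render("web", replicas=3))
```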
<p>Calico and <a href="https://github.com/flannel-io/flannel">Flannel</a> are the most common network operators. Flannel is more minimalistic and builds a flat, switch-like (layer 2) overlay network using technologies like <a href="https://en.wikipedia.org/wiki/Virtual_Extensible_LAN">VXLAN</a>. In contrast, Calico routes traffic like a network router (layer 3). Calico is the better choice especially for multi-cluster setups (<em>i.e.,</em> multiple physical locations) or hybrid cloud services.</p>
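<p>The layer-3 approach can be sketched with Python’s <code>ipaddress</code> module: each node owns a block of pod addresses, and a route table maps each block to the node that hosts it, with longest-prefix matching deciding the next hop. The pool <code>192.168.0.0/16</code> and the <code>/26</code> block size are assumptions modeled on Calico’s commonly used defaults; node names and IPs are hypothetical, and Calico itself typically distributes such routes via BGP rather than a Python dictionary.</p>

```python
import ipaddress

# Assumed values: pod pool and per-node block size; node names and IPs are hypothetical.
POD_POOL = ipaddress.ip_network("192.168.0.0/16")
NODES = {"node-a": "10.0.0.11", "node-b": "10.0.0.12"}


def allocate_blocks(pool, nodes, prefix=26):
    """Give each node its own block of pod addresses, as a routed (L3) design does."""
    blocks = pool.subnets(new_prefix=prefix)
    return {node: next(blocks) for node in nodes}


def build_routes(blocks, nodes):
    """Route table: pod block -> IP of the node that owns it (the next hop)."""
    return {block: nodes[node] for node, block in blocks.items()}


def next_hop(routes, pod_ip):
    """Longest-prefix match, the way a router selects a route; None if nothing matches."""
    addr = ipaddress.ip_address(pod_ip)
    matches = [net for net in routes if addr in net]
    if not matches:
        return None
    return routes[max(matches, key=lambda net: net.prefixlen)]


blocks = allocate_blocks(POD_POOL, NODES)
routes = build_routes(blocks, NODES)
print(blocks)                            # node-a: 192.168.0.0/26, node-b: 192.168.0.64/26
print(next_hop(routes, "192.168.0.70"))  # 10.0.0.12 (node-b)
```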
<p><strong>Storage Operator:</strong> For storage management within Kubernetes, we used <a href="https://github.com/kubernetes-csi/csi-driver-nfs">csi-driver-nfs</a>. It lets us use our existing NFS server and provides persistent storage to any service within our private cloud.</p>
<p>The <em>csi-driver-nfs</em> proved most suitable since we already had an NFS server on one of the nodes. It gives us straightforward, centralized storage management for all services within Kubernetes. Centralization brings numerous advantages but also challenges. The main drawback is the system’s vulnerability to an outage of the node that stores the data. On the other hand, centralization makes troubleshooting and backups easier.</p>
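<p>With csi-driver-nfs, volumes are typically provisioned through a StorageClass that points at the existing NFS export. The sketch below builds such a StorageClass as a Python dictionary; the provisioner name <code>nfs.csi.k8s.io</code> and the <code>server</code>/<code>share</code> parameters follow the driver’s documentation as we understand it, while the class name, server address, and export path are hypothetical.</p>

```python
import json


def nfs_storage_class(name: str, server: str, share: str) -> dict:
    """StorageClass whose volumes are provisioned as subdirectories of an existing NFS export."""
    return {
        "apiVersion": "storage.k8s.io/v1",
        "kind": "StorageClass",
        "metadata": {"name": name},
        "provisioner": "nfs.csi.k8s.io",  # CSI driver name registered by csi-driver-nfs
        "parameters": {"server": server, "share": share},
        "reclaimPolicy": "Retain",        # keep the data when a claim is deleted
        "volumeBindingMode": "Immediate",
    }


# Hypothetical server address and ZFS-backed export path
sc = nfs_storage_class("nfs-csi", "10.0.0.10", "/tank/k8s")
print(json.dumps(sc, indent=2))
```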
<p><strong>Bare-Metal Ingress Load-Balancer:</strong> To balance incoming (ingress) traffic among the entry points of our Kubernetes cluster, we decided to use the <a href="https://metallb.universe.tf">MetalLB</a> solution. After thorough research, we could not find a viable alternative: most online documentation (<em>e.g.,</em> tutorials, blogs) focuses on setting up infrastructure on public clouds such as AWS or Azure, using solutions tailored to those providers. Since our infrastructure runs on our own hardware (<em>i.e.,</em> bare metal), we opted for MetalLB, which has proven reliable and effective in routing traffic to our Kubernetes cluster’s entry points.</p>
<p><strong>Ingress Operator:</strong> While the network operator manages interconnections between services within Kubernetes, the ingress operator manages access to services from the outside world. For security reasons, direct access to the internal network is prohibited. While it is possible to reach the internal network through a proxy (<em>e.g.,</em> <code>kubectl proxy</code>), that is meant only for debugging. The ingress operator routes each request, based on the requested domain name, to the container and port described in the service’s YAML manifest. Routing by domain name has several advantages. Regardless of the service’s internal IP address, the ingress operator always directs traffic correctly. Under high traffic load, the ingress operator can also act as a load balancer, spreading traffic across multiple replicas of a service.</p>
<p>Among the most common ingress solutions are the <a href="https://github.com/nginxinc/nginx-ingress-helm-operator">NGINX</a> and <a href="https://doc.traefik.io/traefik/providers/kubernetes-ingress/">Traefik</a> Ingress operators. We chose NGINX, but the Ingress interface is standardized, so there are almost no differences between the solutions. Regardless of the selected solution, once a new service is deployed, the operator follows the service’s manifest and automatically routes traffic to the appropriate container.</p>
<p><strong>GPU Operator:</strong> For efficient management of access to compute accelerators, we decided to use the official NVIDIA <a href="https://github.com/NVIDIA/gpu-operator">GPU-Operator</a> suite of services. This suite offers two installation options for NVIDIA drivers: the first uses drivers installed on the host, while the second uses drivers packaged in containers. Initially, we opted for the first option because we also wanted to use the accelerators outside Kubernetes. However, due to issues with conflicting driver versions, we switched to the drivers provided by the GPU-Operator.</p>
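<p>Conceptually, MetalLB hands out external IP addresses from a configured pool to Kubernetes services of type <code>LoadBalancer</code> and then announces them on the local network (via ARP in layer-2 mode or via BGP). The sketch below covers only the allocation bookkeeping; the <code>AddressPool</code> class and the address range are hypothetical, and the announcement step is left out.</p>

```python
import ipaddress


class AddressPool:
    """Hypothetical model of the allocation MetalLB performs for LoadBalancer services."""

    def __init__(self, first: str, last: str):
        start = int(ipaddress.ip_address(first))
        end = int(ipaddress.ip_address(last))
        self.free = [ipaddress.ip_address(i) for i in range(start, end + 1)]
        self.assigned = {}

    def assign(self, service: str):
        """Return the service's external IP, allocating the next free one if needed."""
        if service in self.assigned:  # idempotent: a service keeps its address
            return self.assigned[service]
        if not self.free:
            raise RuntimeError("address pool exhausted")
        self.assigned[service] = self.free.pop(0)
        return self.assigned[service]

    def release(self, service: str):
        """Return a deleted service's address to the pool."""
        self.free.append(self.assigned.pop(service))


pool = AddressPool("192.168.10.240", "192.168.10.242")  # assumed address range
print(pool.assign("ingress-nginx"))  # 192.168.10.240
print(pool.assign("grafana"))        # 192.168.10.241
```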
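<p>A minimal sketch of host-based routing: the first function builds a <code>networking.k8s.io/v1</code> Ingress that maps one domain name to a service port, and the second mimics, in a much simplified way, how a controller picks a backend for a request’s host name. The domain names and service names are hypothetical, the <code>nginx</code> ingress class name is an assumption, and path matching is ignored.</p>

```python
def ingress(host: str, service: str, port: int) -> dict:
    """Build a networking.k8s.io/v1 Ingress that routes one host name to a Service port."""
    return {
        "apiVersion": "networking.k8s.io/v1",
        "kind": "Ingress",
        "metadata": {"name": f"{service}-ingress"},
        "spec": {
            "ingressClassName": "nginx",  # assumed name of the NGINX ingress class
            "rules": [{
                "host": host,
                "http": {"paths": [{
                    "path": "/",
                    "pathType": "Prefix",
                    "backend": {"service": {"name": service, "port": {"number": port}}},
                }]},
            }],
        },
    }


def route(ingresses: list, host: str):
    """Very simplified controller logic: return (service, port) for the matching host, else None."""
    for ing in ingresses:
        for rule in ing["spec"]["rules"]:
            if rule["host"] == host:
                backend = rule["http"]["paths"][0]["backend"]["service"]
                return backend["name"], backend["port"]["number"]
    return None


# Hypothetical host names and services
rules = [
    ingress("hub.example.org", "proxy-public", 80),
    ingress("grafana.example.org", "grafana", 80),
]
print(route(rules, "grafana.example.org"))  # ('grafana', 80)
```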
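<p>With the GPU-Operator installed, containers request accelerators as an extended resource. The sketch below builds a one-off pod that asks for a single GPU via the <code>nvidia.com/gpu</code> resource and runs <code>nvidia-smi</code> as a smoke test; the pod name and the CUDA image tag are assumptions.</p>

```python
import json


def gpu_pod(name: str, image: str, gpus: int = 1) -> dict:
    """Build a one-off Pod that requests whole GPUs advertised by the NVIDIA device plugin."""
    return {
        "apiVersion": "v1",
        "kind": "Pod",
        "metadata": {"name": name},
        "spec": {
            "restartPolicy": "Never",
            "containers": [{
                "name": name,
                "image": image,
                "command": ["nvidia-smi"],  # lists the GPU(s) visible inside the container
                # Extended resources such as nvidia.com/gpu are requested via limits;
                # the scheduler only places the pod on a node with enough free GPUs.
                "resources": {"limits": {"nvidia.com/gpu": gpus}},
            }],
        },
    }


# Hypothetical pod name; the image tag is an assumption, any CUDA base image would do.
pod = gpu_pod("gpu-smoke-test", "nvidia/cuda:12.2.0-base-ubuntu22.04")
print(json.dumps(pod, indent=2))
```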
</section>
<section id="user-services" class="level3">
<h3 class="anchored" data-anchor-id="user-services">User Services</h3>
<p>In this section, we introduce the selected services available to end users of our private cloud, enabling efficient execution and management of their research and educational projects.</p>
<p><strong>JupyterHub</strong> is one of the key services in our private cloud, providing users with easy access to computing resources, data, and Jupyter Notebooks for research and teaching. To implement JupyterHub, we use the <a href="https://z2jh.jupyter.org/en/stable/">Z2JH</a> (Zero-to-JupyterHub) Helm chart, developed by a team of researchers at the University of California, Berkeley, in collaboration with the Jupyter community. It makes setup and maintenance quick.</p>
<p>Each user logs in with a username and password or through an <a href="https://en.wikipedia.org/wiki/OAuth">OAuth</a> provider, such as GitHub, Google, or Auth0, and is given an isolated container instance. The instance offers a stripped-down Linux environment with limited internet access and without administrator permissions. Kubernetes then provides access to shared data resources, common directories, and compute accelerators.</p>
<p>The JupyterHub user interface is similar to Google Colab or Kaggle services. Upon entering the isolated instance, JupyterLab is already running, and the user also has access to the Linux terminal. Additional tools and software packages can be installed using <a href="https://pip.pypa.io/en/stable/">pip</a>, <a href="https://docs.conda.io/projects/conda/en/latest/user-guide/concepts/packages.html">conda</a>, or <a href="https://github.com/mamba-org/mamba">mamba</a> commands.</p>
<p><strong>Grafana</strong> gives users a straightforward view of the compute cluster’s current workload and the availability of compute accelerators. This data visualization platform presents the information clearly and transparently, helping users make decisions about resource usage and optimize their tasks. Data collection (Prometheus) and visualization (Grafana) are deployed with <a href="https://github.com/prometheus-community/helm-charts/tree/main/charts/kube-prometheus-stack">kube-prometheus-stack</a>.</p>
</section>
</section>
</section>
<section id="deployment" class="level1">
<h1>Deployment</h1>
<p>In this section, we present how we deployed our computing infrastructure. We first summarize the hardware decisions and caveats, and then describe the user experience, with some screenshots.</p>
<section id="hardware" class="level2">
<h2 class="anchored" data-anchor-id="hardware">Hardware</h2>
<div id="tbl-hardware" class="hover quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-tbl figure">
<figcaption class="quarto-float-caption-top quarto-float-caption quarto-float-tbl" id="tbl-hardware-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Table&nbsp;2: Hardware specifications of the computing node.
</figcaption>
<div aria-describedby="tbl-hardware-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<table class="table-hover caption-top table">
<thead>
<tr class="header">
<th>Hardware</th>
<th>Specifications</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Chassis</td>
<td>Supermicro A+ Server <a href="https://www.supermicro.com/en/aplus/system/4u/4124/as-4124gs-tnr.cfm">4124GS-TNR</a>, 4U size, up to 8 PCI-E GPUs</td>
</tr>
<tr class="even">
<td>CPU</td>
<td>2x AMD <a href="https://www.amd.com/en/products/cpu/amd-epyc-75f3">EPYC 75F3</a> (32C/64T, up to 4.0GHz, 256MB L3 cache)</td>
</tr>
<tr class="odd">
<td>Memory</td>
<td>1TB (16x64GB) REG ECC DDR4, 3200MHz</td>
</tr>
<tr class="even">
<td>System</td>
<td>2x 2TB SSD NVMe, software RAID-1 (mirror)</td>
</tr>
<tr class="odd">
<td>Storage</td>
<td>6x 8TB SSD SATA, software RAID-Z1 (1 disk redundancy)</td>
</tr>
<tr class="even">
<td>GPU</td>
<td>2x NVIDIA A100 80GB PCI-E</td>
</tr>
</tbody>
</table>
</div>
</figure>
</div>
<p>When we bought the hardware in early 2022, we chose third-generation AMD EPYC processors. Specifically, we went for the F-series, which trades fewer cores for higher base and boost frequencies (up to 4.0GHz). The chosen CPU has a TDP of 280W, the highest available. We installed server-grade registered ECC memory at the highest frequency the processor supports and populated all eight memory channels on both processors, for sixteen DIMMs in total. Although we considered Intel solutions, AMD EPYC processors offered a better price-to-performance ratio.</p>
<p>From the perspective of numerical performance, our main concern was Intel-optimized libraries, such as Intel MKL, which many numerical tools use. The library has a “bug” that makes it fall back to slower SSE code paths instead of the more advanced AVX vectorization instructions on non-Intel processors [<a href="https://danieldk.eu/Posts/2020-08-31-MKL-Zen.html">src</a>]. <a href="https://www.openblas.net/">OpenBLAS</a> is a good alternative but requires some effort to set up; for conda environments, see the conda-forge <a href="https://github.com/conda-forge/nomkl-feedstock">nomkl</a> package.</p>
<p>For the system drive, we chose two NVMe drives in a mirror configuration (RAID-1). For data storage, we selected six 8TB SATA SSDs configured as ZFS RAID-Z1, which tolerates the failure of one drive. As accelerators, we chose two A100 GPUs.</p>
<p>NVIDIA A100 GPUs come in two form factors: PCI-E and SXM4. The proprietary SXM4 form factor has a higher TDP and connects every GPU with high-bandwidth NVLink through NVSwitch hardware. Its downside is that it supports only Ampere-generation GPUs and requires a special motherboard. The PCI-E variant has a lower TDP, and NVLink can only connect pairs of GPUs. However, we wanted to avoid vendor lock-in, which would have limited us to one brand and GPU generation, so we went with the PCI-E variant.</p>
<p>We considered the most likely workflow scenarios and expected most communication to be CPU-to-GPU, with GPUs sliced into several instances via MIG (Multi-Instance GPU). When MIG mode is enabled, each GPU is partitioned into isolated instances, each with its own share of the GPU’s compute and memory resources; these instances cannot use NVLink interconnects. The slicing configuration can be changed at runtime by recreating the MIG instances.</p>
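<p>A back-of-the-envelope calculation shows why populating every channel matters. Assuming one 64GB DIMM per channel, eight channels per socket, and DDR4-3200 (3200 million transfers per second over a 64-bit channel), the theoretical peak bandwidth works out as follows; sustained bandwidth in practice is lower.</p>

```python
# Theoretical peak figures for the node in Table 2 (back-of-the-envelope, not measured).
MT_PER_S = 3200e6          # DDR4-3200: 3200 million transfers per second
BYTES_PER_TRANSFER = 8     # 64-bit data path per channel (ECC bits excluded)
CHANNELS_PER_SOCKET = 8
SOCKETS = 2
DIMM_GB = 64
CORES_PER_SOCKET = 32      # EPYC 75F3, two threads per core

per_channel = MT_PER_S * BYTES_PER_TRANSFER / 1e9   # GB/s
per_socket = per_channel * CHANNELS_PER_SOCKET
total = per_socket * SOCKETS
capacity = DIMM_GB * CHANNELS_PER_SOCKET * SOCKETS  # one DIMM per channel
cores = CORES_PER_SOCKET * SOCKETS

print(f"{per_channel:.1f} GB/s per channel")         # 25.6 GB/s per channel
print(f"{per_socket:.1f} GB/s per socket")           # 204.8 GB/s per socket
print(f"{total:.1f} GB/s total, {capacity} GB RAM")  # 409.6 GB/s total, 1024 GB RAM
print(f"{cores} cores, {cores * 2} threads")         # 64 cores, 128 threads
```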
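<p>To check which BLAS library a Python environment actually uses and get a rough throughput number, one can print NumPy’s build configuration and time a large matrix multiplication. This is a hedged sketch: the matrix size and repeat count are arbitrary, and the result is only meaningful when compared across environments (<em>e.g.,</em> MKL versus OpenBLAS) on the same machine.</p>

```python
import time

import numpy as np

# Show which BLAS/LAPACK implementation NumPy was built against (MKL, OpenBLAS, ...).
np.show_config()


def matmul_gflops(n: int = 2048, repeats: int = 3) -> float:
    """Best-of-N throughput of an n x n float64 matrix multiply, in GFLOP/s."""
    rng = np.random.default_rng(0)
    a = rng.standard_normal((n, n))
    b = rng.standard_normal((n, n))
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        a @ b
        best = min(best, time.perf_counter() - start)
    return 2 * n**3 / best / 1e9  # a dense matmul needs about 2*n^3 floating-point operations


print(f"{matmul_gflops():.1f} GFLOP/s")
```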
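<p>The capacity trade-off of these layouts can be estimated with simple arithmetic: a two-way mirror stores every block twice, while RAID-Z1 gives up roughly one disk’s worth of space to parity. The helper functions below are a rough sketch that ignores ZFS metadata and allocation overhead, so the real usable space is somewhat lower.</p>

```python
TB = 1e12   # drive vendors use decimal terabytes
TIB = 2**40


def raidz_usable(disks: int, disk_tb: float, parity: int = 1) -> float:
    """Rough usable capacity (TB) of one RAID-Z vdev: `parity` disks' worth of space goes to parity.
    Ignores ZFS metadata, padding, and reserved space, so real numbers come out lower."""
    if disks <= parity:
        raise ValueError("need more disks than parity devices")
    return (disks - parity) * disk_tb


def mirror_usable(disk_tb: float) -> float:
    """A two-way mirror (RAID-1) stores every block twice, so capacity equals one disk."""
    return disk_tb


data_tb = raidz_usable(6, 8, parity=1)
print(f"data pool: ~{data_tb:.0f} TB (~{data_tb * TB / TIB:.1f} TiB)")  # ~40 TB (~36.4 TiB)
print(f"system mirror: {mirror_usable(2):.0f} TB")                     # 2 TB
```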
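<p>As a rough illustration of MIG slicing, an A100 80GB exposes seven compute slices and eight memory slices, and each MIG profile consumes a fixed number of both. The sketch below checks whether a requested layout fits by summing slices; it deliberately ignores NVIDIA’s placement rules, so a layout it accepts is not guaranteed to be valid on real hardware.</p>

```python
# (compute slices, memory slices) per MIG profile on an A100 80GB,
# which has 7 compute slices and 8 memory slices in total.
A100_80GB_PROFILES = {
    "1g.10gb": (1, 1),
    "2g.20gb": (2, 2),
    "3g.40gb": (3, 4),
    "4g.40gb": (4, 4),
    "7g.80gb": (7, 8),
}
TOTAL_COMPUTE, TOTAL_MEMORY = 7, 8


def fits(layout: list) -> bool:
    """Simplified feasibility check: sums slices only and ignores placement rules."""
    compute = sum(A100_80GB_PROFILES[p][0] for p in layout)
    memory = sum(A100_80GB_PROFILES[p][1] for p in layout)
    return compute <= TOTAL_COMPUTE and memory <= TOTAL_MEMORY


print(fits(["3g.40gb", "2g.20gb", "1g.10gb", "1g.10gb"]))  # True
print(fits(["4g.40gb", "4g.40gb"]))                        # False
```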
<div id="fig-hardware" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-hardware-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<img src="https://gcerar.github.io/posts/2023-11-20-research-compute-infrastructure/figures/hardware-top-view-resize.jpg" class="img-fluid figure-img">
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-hardware-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;2: The computing node on my desk underwent final checks before being installed in the server rack.
</figcaption>
</figure>
</div>
</section>
<section id="user-experience" class="level2">
<h2 class="anchored" data-anchor-id="user-experience">User Experience</h2>
<p>After deploying the hardware and software stack, we conducted a month-long live test to stabilize the configuration. During this period, users were informed that we might reboot the system or make significant changes and could not take responsibility for potential data loss, though we aimed to minimize such occurrences.</p>
<p>We made two key decisions about resource allocation. First, users can use all available memory and CPU cores; when CPU demand is high, Kubernetes and the operating system schedule the competing tasks. Second, when memory runs short, the job consuming the most memory is terminated to protect the other running tasks.</p>
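<p>The memory policy can be sketched as a tiny victim-selection function: if the jobs’ combined memory usage exceeds the node’s capacity, the largest consumer is chosen. This is a simplification; the job names and sizes are hypothetical, and the actual Linux OOM killer ranks processes by <code>oom_score</code> (adjusted by <code>oom_score_adj</code>, which the kubelet sets based on each pod’s QoS class).</p>

```python
def pick_oom_victim(usage_gb: dict, capacity_gb: float):
    """Return the job to terminate when memory is oversubscribed, or None if everything fits.
    Simplified: ranks by raw usage instead of the kernel's oom_score."""
    if sum(usage_gb.values()) <= capacity_gb:
        return None
    return max(usage_gb, key=usage_gb.get)


# Hypothetical jobs on a node with 1024 GB of RAM
jobs = {"alice-notebook": 600, "bob-training": 350, "carol-eda": 120}
print(pick_oom_victim(jobs, 1024))  # alice-notebook
```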
<p>Feedback from students and researchers was overwhelmingly positive, highlighting the high speed, numerous cores, ample memory, and dedicated GPU access without interference.</p>
<p>During the testing phase, “testers” identified several shortcomings, which we promptly addressed by adding a shared folder with datasets and Jupyter Notebooks, a shared package cache, and better persistence of running tasks in JupyterLab.</p>
<section id="jupyterhub" class="level3">
<h3 class="anchored" data-anchor-id="jupyterhub">JupyterHub</h3>
<p>JupyterHub has become a crucial component of our research infrastructure and has significantly improved our workflow. Its smooth integration was largely due to its interface and functionality, which closely resemble the tools our researchers and students were already familiar with. This similarity played a key role in its quick adoption and high user satisfaction.</p>
<div id="fig-jhub-profiles" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-jhub-profiles-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<img src="https://gcerar.github.io/posts/2023-11-20-research-compute-infrastructure/figures/jhub-profiles.png" class="img-fluid figure-img">
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-jhub-profiles-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;3: JupyterHub offers a list of predefined containers, some of which include a GPU instance.
</figcaption>
</figure>
</div>
<p>Upon logging into JupyterHub, users are presented with a list of predefined containers (as shown in Figure&nbsp;3). Since our recent update, the options include:</p>
<ul>
<li>A basic minimal working environment.</li>
<li>A comprehensive data science environment equipped with multiple packages and support for Python, R, and Julia.</li>
<li>A selection of containers offering GPU instances.</li>
</ul>
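<p>In Z2JH, such a menu is typically defined through the <code>singleuser.profileList</code> setting, where each entry’s <code>kubespawner_override</code> changes the spawned container, for example its image or extra resource limits. The sketch below builds hypothetical values for three profiles as a Python dictionary; the display names and the GPU image are placeholders, and the two CPU images are public Jupyter Docker Stacks images used here only as examples.</p>

```python
import json

# Hypothetical Z2JH values: KubeSpawner applies each profile's kubespawner_override
# on top of the default single-user pod settings.
profile_list = [
    {
        "display_name": "Minimal environment",
        "description": "Python with a small set of packages.",
        "default": True,
        "kubespawner_override": {"image": "quay.io/jupyter/minimal-notebook:latest"},
    },
    {
        "display_name": "Data science (Python, R, Julia)",
        "kubespawner_override": {"image": "quay.io/jupyter/datascience-notebook:latest"},
    },
    {
        "display_name": "Data science + 1 GPU",
        "kubespawner_override": {
            "image": "registry.example.org/gpu-notebook:latest",  # placeholder image
            "extra_resource_limits": {"nvidia.com/gpu": "1"},
        },
    },
]

values = {"singleuser": {"profileList": profile_list}}
print(json.dumps(values, indent=2))
```

<p>Because JSON is valid YAML, the printed output could serve as a Helm values file for the Z2JH chart, though in practice these settings usually live in a hand-written YAML file.</p>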
<p>The development environment greets users with a layout similar to modern IDEs, featuring a file explorer on the left and code editor tabs on the right (see Figure&nbsp;4).</p>
<div id="fig-jhub-workspace" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-jhub-workspace-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<img src="https://gcerar.github.io/posts/2023-11-20-research-compute-infrastructure/figures/jhub-workspace.png" class="img-fluid figure-img">
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-jhub-workspace-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;4: JupyterLab workspace with familiar layout: a file explorer on the left and code editor tabs on the right.
</figcaption>
</figure>
</div>
</section>
<section id="grafana" class="level3">
<h3 class="anchored" data-anchor-id="grafana">Grafana</h3>
<p>For transparent insight into infrastructure availability, users have read-only access to the Grafana dashboard. The dashboard visualizes computing resource utilization, including total and per-container CPU usage, per-container memory usage, GPU utilization, temperatures, and storage I/O (see Figure&nbsp;5).</p>
<div id="fig-grafana-dashboard" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-grafana-dashboard-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<img src="https://gcerar.github.io/posts/2023-11-20-research-compute-infrastructure/figures/grafana-censored.png" class="img-fluid figure-img">
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-grafana-dashboard-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;5: Visualization of computing cluster utilization showing total and per-container CPU utilization, per-container memory utilization, GPU slice utilization, temperatures, and storage I/O.
</figcaption>
</figure>
</div>
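<p>The Prometheus data behind the dashboard can also be queried directly through Prometheus’ HTTP API. The sketch below builds an instant-query URL for average GPU utilization, assuming the <code>DCGM_FI_DEV_GPU_UTIL</code> metric exported by NVIDIA’s dcgm-exporter (typically deployed by the GPU-Operator), and parses a response in the API’s documented JSON format. The server address and the sample response values are made up.</p>

```python
import json
from urllib.parse import urlencode

PROMETHEUS = "http://prometheus.example.internal:9090"  # hypothetical address


def query_url(promql: str) -> str:
    """Instant-query URL for Prometheus' HTTP API (GET /api/v1/query)."""
    params = urlencode({"query": promql})
    return f"{PROMETHEUS}/api/v1/query?{params}"


def parse_vector(body) -> dict:
    """Turn an instant-vector response into {gpu label: sample value}."""
    payload = json.loads(body)
    if payload["status"] != "success":
        raise RuntimeError(f"query failed: {payload}")
    return {
        sample["metric"].get("gpu", "?"): float(sample["value"][1])
        for sample in payload["data"]["result"]
    }


url = query_url("avg by (gpu) (DCGM_FI_DEV_GPU_UTIL)")
# Made-up response body in the documented instant-vector format
sample = ('{"status":"success","data":{"resultType":"vector",'
          '"result":[{"metric":{"gpu":"0"},"value":[1700000000,"87"]}]}}')
print(url)
print(parse_vector(sample))  # {'0': 87.0}
```

<p>Against a live server, the body returned by <code>urllib.request.urlopen(url).read()</code> could be passed straight to <code>parse_vector</code>, since <code>json.loads</code> also accepts bytes.</p>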
</section>
</section>
</section>
<section id="conclusions" class="level1">
<h1>Conclusions</h1>
<p>This article introduced our private cloud solution based on Kubernetes technology. This solution offers a scalable environment for using Jupyter Notebooks, an effective educational tool for data-driven narrative analysis, creating learning materials, and interactive presentations. Additionally, the system allows for the concurrent sharing of computing resources, significantly enhancing the utilization of our entire infrastructure.</p>
<p>The JupyterHub service on the Kubernetes platform facilitates easy access to the work environment and ensures user isolation, allowing for uninterrupted work and research. Users benefit from storage space and shared folders for file sharing, promoting collaboration and teamwork. Users also have access to compute accelerators when available.</p>
<p>We discussed our solution’s key components, architecture, and design decisions, explaining the technology choices that led to efficient operation and an exceptional user experience. Our focus has been on open-source platforms that have proven reliable and effective in our environment. As the core platform, Kubernetes enables scalable container management and high availability, while JupyterHub provides easy access to services and simplifies user management.</p>
<p>We plan to enhance our solution with additional services and technologies to improve user experience and increase our cloud’s performance. We remain open to new technologies and approaches that contribute to the better functioning of our private cloud solution for research and education. Data from Prometheus will be crucial for analyzing infrastructure utilization and understanding the extent of user competition for computing resources.</p>



</section>


<div id="quarto-appendix" class="default"><section class="quarto-appendix-contents" id="quarto-bibliography"><h2 class="anchored quarto-appendix-heading">References</h2><div id="refs" class="references csl-bib-body hanging-indent">
<div id="ref-bussonnier2018jupyterhub" class="csl-entry">
Bussonnier, Matthias. 2018. <span>“<span class="nocase">Jupyter and HPC: Current state and future roadmap</span>.”</span> In <em>Exascale Computing Project</em>. <a href="https://www.exascaleproject.org/event/jupyter/">https://www.exascaleproject.org/event/jupyter/</a>.
</div>
<div id="ref-granger2021storytelling" class="csl-entry">
<span class="nocase">Granger, Brian E. et al.</span> 2021. <span>“<span class="nocase">Jupyter: Thinking and Storytelling With Code and Data</span>.”</span> <em>Computing in Science &amp; Engineering</em> 23 (2): 7–14. <a href="https://doi.org/10.1109/MCSE.2021.3059263">https://doi.org/10.1109/MCSE.2021.3059263</a>.
</div>
<div id="ref-mendez2019toward" class="csl-entry">
<span class="nocase">Mendez, Kevin M et al.</span> 2019. <span>“<span class="nocase">Toward collaborative open data science in metabolomics using Jupyter Notebooks and cloud computing</span>.”</span> <em>Metabolomics</em> 15 (10): 1–16.
</div>
<div id="ref-perkel2018jupyter" class="csl-entry">
Perkel, Jeffrey M. 2018. <span>“<span class="nocase">Why Jupyter is data scientists’ computational notebook of choice</span>.”</span> <em>Nature</em> 563 (7732): 145–47.
</div>
</div></section><section id="footnotes" class="footnotes footnotes-end-of-document"><h2 class="anchored quarto-appendix-heading">Footnotes</h2>

<ol>
<li id="fn1"><p>JSON: JavaScript Object Notation↩︎</p></li>
<li id="fn2"><p>REPL: read–eval–print loop↩︎</p></li>
<li id="fn3"><p>WETL: write-eval-think-loop↩︎</p></li>
</ol>
</section><section class="quarto-appendix-contents" id="quarto-reuse"><h2 class="anchored quarto-appendix-heading">Reuse</h2><div class="quarto-appendix-contents"><div><a rel="license" href="https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en_us">CC BY-NC-SA 4.0</a></div></div></section></div> ]]></description>
  <category>compute</category>
  <category>infrastructure</category>
  <guid>https://gcerar.github.io/posts/2023-11-20-research-compute-infrastructure/</guid>
  <pubDate>Sun, 19 Nov 2023 23:00:00 GMT</pubDate>
</item>
<item>
  <title>Generative Adversarial Networks</title>
  <dc:creator>Gregor Cerar</dc:creator>
  <link>https://gcerar.github.io/posts/2023-10-10-vanilla-GANs/</link>
  <description><![CDATA[ 





<section id="introduction" class="level1">
<h1>Introduction</h1>
<p>Generative Adversarial Networks (GANs) are an innovative class of unsupervised neural networks that have revolutionized the field of artificial intelligence. They were first introduced in <a href="https://arxiv.org/abs/1406.2661">Generative Adversarial Networks</a> <span class="citation" data-cites="goodfellow2014generative">(Goodfellow et al. 2014)</span> and consist of two separate neural networks: the <strong>generator</strong> (creates data) and the <strong>discriminator</strong> (evaluates data authenticity). The generator aims to fool the discriminator by producing realistic data, while the discriminator tries to differentiate real from fake. Over iterations, the generator’s data becomes more convincing.</p>
<p>As an analogy, consider two kids, one drawing counterfeit money (“Generator”) and another assessing its realism (“Discriminator”). Over time, the counterfeit drawings become increasingly convincing.</p>
</section>
<section id="vanilla-gan" class="level1">
<h1>Vanilla GAN</h1>
<p>The most fundamental variant of GAN is the “vanilla” GAN, where “vanilla” signifies the model in its original and most straightforward form rather than a flavor. To better understand its mechanism, I’ve illustrated its structure in Figure&nbsp;1.</p>
<div id="fig-architecture" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-architecture-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<img src="https://gcerar.github.io/posts/2023-10-10-vanilla-GANs/figures/GAN-architecture.png" class="img-fluid figure-img">
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-architecture-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;1: GAN architecture
</figcaption>
</figure>
</div>
<ul>
<li><strong>Generator</strong> <img src="https://latex.codecogs.com/png.latex?G(z;%20w_g)"> takes random noise <img src="https://latex.codecogs.com/png.latex?z"> as input and produces fabricated data <img src="https://latex.codecogs.com/png.latex?x_f">.
<ul>
<li><img src="https://latex.codecogs.com/png.latex?z"> represents the input vector, a noise vector sampled from a Gaussian distribution.</li>
<li><img src="https://latex.codecogs.com/png.latex?w_g"> denotes generator neural network weights.</li>
<li><img src="https://latex.codecogs.com/png.latex?x_f"> is a fabricated data sample meant for the discriminator.</li>
</ul></li>
<li><strong>Discriminator</strong> <img src="https://latex.codecogs.com/png.latex?D(x;%20w_d)"> differentiates between real and generated data.
<ul>
<li><img src="https://latex.codecogs.com/png.latex?x"> represents input vectors, which come from either a real dataset (<img src="https://latex.codecogs.com/png.latex?x_r%20%5Csim%20p_%5Ctextrm%7Bdata%7D(x)">) or from the set of fabricated samples (<img src="https://latex.codecogs.com/png.latex?x_f%20=%20G(z%20%5Csim%20p_z(z);%20w_g)">).</li>
<li><img src="https://latex.codecogs.com/png.latex?w_d"> denotes discriminator neural network weights.</li>
</ul></li>
</ul>
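<p>To make the notation concrete, here is a toy NumPy sketch of <code>G(z; w_g)</code> and <code>D(x; w_d)</code>. Each network is a single linear layer with random weights, an assumption made purely to show the shapes involved; a real GAN uses trained multi-layer networks.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
LATENT_DIM, DATA_DIM = 4, 2

# Hypothetical toy parameters: one linear layer each, only to make shapes concrete.
w_g = rng.normal(size=(LATENT_DIM, DATA_DIM))
w_d = rng.normal(size=(DATA_DIM,))


def G(z, w_g):
    """Generator: maps Gaussian noise z to fabricated samples x_f."""
    return np.tanh(z @ w_g)


def D(x, w_d):
    """Discriminator: probability that each sample in x is real (sigmoid output)."""
    return 1.0 / (1.0 + np.exp(-(x @ w_d)))


z = rng.standard_normal((8, LATENT_DIM))  # batch of 8 noise vectors, z ~ p_z
x_f = G(z, w_g)                           # fabricated samples, shape (8, DATA_DIM)
print(x_f.shape, D(x_f, w_d).shape)       # (8, 2) (8,)
```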
<section id="objective-function" class="level2">
<h2 class="anchored" data-anchor-id="objective-function">Objective Function</h2>
<p>The interaction between the Generator and the Discriminator can be quantified by their objective or loss functions:</p>
<ul>
<li><strong>Discriminator’s Objective:</strong> For real data <img src="https://latex.codecogs.com/png.latex?x">, <img src="https://latex.codecogs.com/png.latex?D"> wants <img src="https://latex.codecogs.com/png.latex?D(x)"> near <img src="https://latex.codecogs.com/png.latex?1">. For generated data <img src="https://latex.codecogs.com/png.latex?G(z)">, it wants <img src="https://latex.codecogs.com/png.latex?D(G(z))"> close to <img src="https://latex.codecogs.com/png.latex?0">. It therefore maximizes:</li>
</ul>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Cmathcal%7BL%7D(D)%20=%20%5Clog(D(x))%20+%20%5Clog(1%20-%20D(G(z))).%0A"></p>
<ul>
<li><strong>Generator’s Objective:</strong> <img src="https://latex.codecogs.com/png.latex?G"> wants <img src="https://latex.codecogs.com/png.latex?D(G(z))"> to approach <img src="https://latex.codecogs.com/png.latex?1">, so it minimizes:</li>
</ul>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Cmathcal%7BL%7D(G)%20=%20%5Clog(1%20-%20D(G(z))).%0A"></p>
<p>Both <img src="https://latex.codecogs.com/png.latex?G"> and <img src="https://latex.codecogs.com/png.latex?D"> continuously improve to outperform each other in this game.</p>
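<p>Both objectives can be evaluated directly from the Discriminator's probabilities. The sketch below uses made-up values for <img src="https://latex.codecogs.com/png.latex?D(x)"> and <img src="https://latex.codecogs.com/png.latex?D(G(z))">, averages each term over a batch, and checks that maximizing the Discriminator's objective is the same as minimizing binary cross-entropy with label 1 for real and 0 for fake samples:</p>

```python
import torch
import torch.nn.functional as F

# Hypothetical Discriminator outputs (probabilities) for 3 real and 3 fake samples
d_real = torch.tensor([0.9, 0.8, 0.7])  # D(x)
d_fake = torch.tensor([0.2, 0.1, 0.4])  # D(G(z))

# L(D) = log D(x) + log(1 - D(G(z))), averaged over the batch (to be maximized by D)
L_D = torch.log(d_real).mean() + torch.log(1 - d_fake).mean()

# L(G) = log(1 - D(G(z))), averaged over the batch (to be minimized by G)
L_G = torch.log(1 - d_fake).mean()

# Binary cross-entropy with labels 1 (real) and 0 (fake) equals -L(D)
bce = F.binary_cross_entropy(d_real, torch.ones_like(d_real)) + F.binary_cross_entropy(
    d_fake, torch.zeros_like(d_fake)
)
print(L_D.item(), -bce.item())  # the same value up to floating-point rounding
```

<p>This equivalence is why GAN implementations usually express the Discriminator update as an ordinary binary classification loss.</p>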
<section id="minimax-game-in-gans" class="level3">
<h3 class="anchored" data-anchor-id="minimax-game-in-gans">Minimax Game in GANs</h3>
<p>Vanilla GANs frame this competition as a two-player minimax game from game theory:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Cmin_%7BG%7D%5Cmax_%7BD%7D%20%5Cmathcal%7BL%7D(D,%20G)%20=%20%5Clog(D(x))%20+%20%5Clog(1%20-%20D(G(z)))%0A"></p>
<p>In essence:</p>
<ul>
<li><strong>Discriminator:</strong> Maximizes the objective, i.e., its ability to tell real data apart from generated data.</li>
<li><strong>Generator:</strong> Minimizes the same objective, i.e., the Discriminator’s success, by producing ever more convincing forgeries.</li>
</ul>
<p>The iterative competition refines both, targeting a proficient Generator and a perceptive Discriminator.</p>
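<p>The formulas above are written per sample. In the usual formulation of the vanilla GAN, each term is an expectation over the real-data distribution <img src="https://latex.codecogs.com/png.latex?p_%5Ctextrm%7Bdata%7D"> and the noise distribution <img src="https://latex.codecogs.com/png.latex?p_z">. Writing <img src="https://latex.codecogs.com/png.latex?p_g"> for the distribution of generated samples <img src="https://latex.codecogs.com/png.latex?G(z)">, the optimal Discriminator for a fixed Generator has a closed form, which is where the probability of one half at equilibrium comes from:</p>

```latex
V(D, G) = \mathbb{E}_{x \sim p_\textrm{data}(x)}\big[\log D(x)\big]
        + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big],
\qquad \min_{G} \max_{D} V(D, G)

D^{*}(x) = \frac{p_\textrm{data}(x)}{p_\textrm{data}(x) + p_g(x)}
\quad\Longrightarrow\quad
D^{*}(x) = \tfrac{1}{2} \;\text{ whenever }\; p_g = p_\textrm{data}
```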
</section>
</section>
<section id="prepare-components" class="level2">
<h2 class="anchored" data-anchor-id="prepare-components">Prepare Components</h2>
<p>In the upcoming sections, we’ll do the following steps to prepare the development environment:</p>
<ul>
<li>Import necessary libraries, primarily PyTorch and Matplotlib.</li>
<li>Define constants, including project path and seed, for consistency.</li>
<li>Determine the computational device (e.g., GPU).</li>
<li>Provide a weight initialization helper function.</li>
</ul>
<div id="cell-7" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> collections.abc <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> Callable, Sequence</span>
<span id="cb1-2"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> pathlib <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> Path</span>
<span id="cb1-3"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> typing <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> Any, Final</span>
<span id="cb1-4"></span>
<span id="cb1-5"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> joblib</span>
<span id="cb1-6"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> numpy <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> np</span>
<span id="cb1-7"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> matplotlib <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> pyplot <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> plt</span>
<span id="cb1-8"></span>
<span id="cb1-9"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span>config InlineBackend.figure_formats <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'retina'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'png'</span>}</span>
<span id="cb1-10"></span>
<span id="cb1-11"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> torch</span>
<span id="cb1-12"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> torch <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> Tensor, nn, optim</span>
<span id="cb1-13"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> torch.utils.data <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> ConcatDataset, DataLoader, Dataset</span>
<span id="cb1-14"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> torchinfo <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> summary</span>
<span id="cb1-15"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> torchvision <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> transforms <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> T</span>
<span id="cb1-16"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> torchvision.utils <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> make_grid</span>
<span id="cb1-17"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> tqdm <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> tqdm</span>
<span id="cb1-18"></span>
<span id="cb1-19">SEED: Final[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">42</span></span>
<span id="cb1-20"></span>
<span id="cb1-21">PROJECT_PATH <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> Path(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"."</span>).resolve()</span>
<span id="cb1-22">FIGURE_PATH <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> PROJECT_PATH <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"figures"</span></span>
<span id="cb1-23">DATASET_PATH <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> Path.home() <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"datasets"</span></span>
<span id="cb1-24"></span>
<span id="cb1-25"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Common constants for all experiments</span></span>
<span id="cb1-26">IMG_DIM: Final[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">tuple</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>]] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">28</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">28</span>)</span></code></pre></div></div>
</div>
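<p>Defining <code>SEED</code> alone does not make runs reproducible; the value still has to be passed to the random number generators. A minimal sketch using a hypothetical <code>seed_everything</code> helper (not part of the original code; bit-exact reproducibility on GPU may additionally require deterministic cuDNN settings):</p>

```python
import random

import numpy as np
import torch

SEED = 42  # same value as the SEED constant above


def seed_everything(seed: int) -> None:
    """Hypothetical helper: seed Python's, NumPy's and PyTorch's global RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the RNGs on all devices, including CUDA


seed_everything(SEED)
a = torch.randn(3)
seed_everything(SEED)
b = torch.randn(3)
print(torch.equal(a, b))  # re-seeding reproduces the same draws
```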
<div id="cell-9" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1">device <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.device(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"cpu"</span>)</span>
<span id="cb2-2"></span>
<span id="cb2-3"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> torch.cuda.is_available():</span>
<span id="cb2-4">    device <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.device(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"cuda"</span>)</span></code></pre></div></div>
</div>
<div id="cell-10" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> weights_init(net: nn.Module) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb3-2">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> m <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> net.modules():</span>
<span id="cb3-3">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(m, nn.Conv2d <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> nn.ConvTranspose2d):</span>
<span id="cb3-4">            nn.init.normal_(m.weight, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.0</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.02</span>)</span>
<span id="cb3-5">            <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> m.bias <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">is</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">not</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb3-6">                nn.init.constant_(m.bias, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.0</span>)</span>
<span id="cb3-7"></span>
<span id="cb3-8">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">elif</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(m, nn.BatchNorm1d <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> nn.BatchNorm2d):</span>
<span id="cb3-9">            nn.init.normal_(m.weight, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.0</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.02</span>)</span>
<span id="cb3-10">            <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> m.bias <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">is</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">not</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb3-11">                nn.init.constant_(m.bias, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.0</span>)</span>
<span id="cb3-12"></span>
<span id="cb3-13">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">elif</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(m, nn.Linear):</span>
<span id="cb3-14">            nn.init.normal_(m.weight, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.02</span>)</span>
<span id="cb3-15">            <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> m.bias <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">is</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">not</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb3-16">                nn.init.constant_(m.bias, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.0</span>)</span></code></pre></div></div>
</div>
<section id="generator" class="level3">
<h3 class="anchored" data-anchor-id="generator">Generator</h3>
<p>The Generator in GANs acts as an artist, crafting data.</p>
<ul>
<li><strong>Input:</strong> Takes random noise, typically from a standard normal distribution.</li>
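<li><strong>Architecture:</strong> Stacks dense layers with leaky ReLU activations that progressively widen the representation (here <code>nz</code> → <code>ngf</code> → <code>2 * ngf</code> → <code>4 * ngf</code>).</li>
<li><strong>Output:</strong> A final dense layer followed by <code>tanh</code>, so all values lie in [-1, 1]; the flat vector is then reshaped to the desired format (e.g., a (1, 28, 28) image).</li>
<li><strong>Objective:</strong> Generate data that the Discriminator cannot distinguish from real data.</li>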
</ul>
<div id="cell-13" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">class</span> Generator(nn.Module):</span>
<span id="cb4-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, out_dim: Sequence[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>], nz: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span>, ngf: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">256</span>, alpha: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.2</span>):</span>
<span id="cb4-3">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""</span></span>
<span id="cb4-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        :param out_dim: output image dimension / shape</span></span>
<span id="cb4-5"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        :param nz: size of the latent z vector $z$</span></span>
<span id="cb4-6"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        :param ngf: size of feature maps (units in the hidden layers) in the generator</span></span>
<span id="cb4-7"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        :param alpha: negative slope of leaky ReLU activation</span></span>
<span id="cb4-8"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        """</span></span>
<span id="cb4-9">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb4-10">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.out_dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> out_dim</span>
<span id="cb4-11">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Sequential(</span>
<span id="cb4-12">            nn.Linear(nz, ngf),</span>
<span id="cb4-13">            nn.LeakyReLU(alpha, inplace<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>),</span>
<span id="cb4-14">            nn.Linear(ngf, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> ngf),</span>
<span id="cb4-15">            nn.LeakyReLU(alpha, inplace<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>),</span>
<span id="cb4-16">            nn.Linear(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> ngf, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> ngf),</span>
<span id="cb4-17">            nn.LeakyReLU(alpha, inplace<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>),</span>
<span id="cb4-18">            nn.Linear(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> ngf, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>(np.prod(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.out_dim))),</span>
<span id="cb4-19">            nn.Tanh(),</span>
<span id="cb4-20">        )</span>
<span id="cb4-21"></span>
<span id="cb4-22">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x: Tensor) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> Tensor:</span>
<span id="cb4-23">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.model(x)</span>
<span id="cb4-24">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.reshape(x, (x.size(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>), <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.out_dim))</span>
<span id="cb4-25">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> x</span>
<span id="cb4-26"></span>
<span id="cb4-27"></span>
<span id="cb4-28">summary(Generator(out_dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">28</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">28</span>)), input_size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span>])</span></code></pre></div></div>
<div class="cell-output cell-output-display">
<pre><code>==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Generator                                [128, 1, 28, 28]          --
├─Sequential: 1-1                        [128, 784]                --
│    └─Linear: 2-1                       [128, 256]                25,856
│    └─LeakyReLU: 2-2                    [128, 256]                --
│    └─Linear: 2-3                       [128, 512]                131,584
│    └─LeakyReLU: 2-4                    [128, 512]                --
│    └─Linear: 2-5                       [128, 1024]               525,312
│    └─LeakyReLU: 2-6                    [128, 1024]               --
│    └─Linear: 2-7                       [128, 784]                803,600
│    └─Tanh: 2-8                         [128, 784]                --
==========================================================================================
Total params: 1,486,352
Trainable params: 1,486,352
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 190.25
==========================================================================================
Input size (MB): 0.05
Forward/backward pass size (MB): 2.64
Params size (MB): 5.95
Estimated Total Size (MB): 8.63
==========================================================================================</code></pre>
</div>
</div>
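<p>The parameter counts in the summary follow directly from the layer sizes: each <code>nn.Linear(n_in, n_out)</code> holds an <code>n_out × n_in</code> weight matrix plus <code>n_out</code> biases. A short check using the default <code>nz = 100</code>, <code>ngf = 256</code> and a 1×28×28 output (<code>linear_params</code> is a hypothetical helper):</p>

```python
def linear_params(n_in: int, n_out: int) -> int:
    # weight matrix (n_out x n_in) plus one bias per output unit
    return n_in * n_out + n_out


nz, ngf, img = 100, 256, 1 * 28 * 28
layers = [(nz, ngf), (ngf, 2 * ngf), (2 * ngf, 4 * ngf), (4 * ngf, img)]
counts = [linear_params(n_in, n_out) for n_in, n_out in layers]
print(counts)       # [25856, 131584, 525312, 803600]
print(sum(counts))  # 1486352
```

<p>Most of the parameters sit in the last layer, which has to produce all 784 pixels.</p>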
</section>
<section id="discriminator" class="level3">
<h3 class="anchored" data-anchor-id="discriminator">Discriminator</h3>
<p>The Discriminator is the GAN’s evaluator, distinguishing real from fake data.</p>
<ul>
<li><strong>Input:</strong> Takes either real data samples or those from the Generator.</li>
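<li><strong>Architecture:</strong> Stacks dense layers with leaky ReLU activations and dropout to classify the input as real or fake.</li>
<li><strong>Output:</strong> Conceptually, a sigmoid maps the final value to a score between 0 and 1 that reflects how likely the input is to be real. In the code below the <code>nn.Sigmoid()</code> layer is commented out, so the network returns raw logits and the sigmoid has to be applied later (e.g., inside a logits-based loss such as <code>nn.BCEWithLogitsLoss</code>).</li>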
<li><strong>Objective:</strong> Recognize real data and identify fake data from the Generator.</li>
</ul>
<div id="cell-15" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">class</span> Discriminator(nn.Module):</span>
<span id="cb6-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, input_dim: Sequence[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>], ndf: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span>, alpha: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.2</span>):</span>
<span id="cb6-3">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb6-4">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Sequential(</span>
<span id="cb6-5">            nn.Linear(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>(np.prod(input_dim)), <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> ndf),</span>
<span id="cb6-6">            nn.LeakyReLU(alpha, inplace<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>),</span>
<span id="cb6-7">            nn.Dropout(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.3</span>),</span>
<span id="cb6-8">            nn.Linear(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> ndf, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> ndf),</span>
<span id="cb6-9">            nn.LeakyReLU(alpha, inplace<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>),</span>
<span id="cb6-10">            nn.Dropout(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.3</span>),</span>
<span id="cb6-11">            nn.Linear(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> ndf, ndf),</span>
<span id="cb6-12">            nn.LeakyReLU(alpha, inplace<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>),</span>
<span id="cb6-13">            nn.Dropout(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.3</span>),</span>
<span id="cb6-14">            nn.Linear(ndf, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>),</span>
<span id="cb6-15">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># nn.Sigmoid(),</span></span>
<span id="cb6-16">        )</span>
<span id="cb6-17"></span>
<span id="cb6-18">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x: Tensor) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> Tensor:</span>
<span id="cb6-19">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.reshape(x, (x.size(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>), <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>))</span>
<span id="cb6-20">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.model(x)</span>
<span id="cb6-21"></span>
<span id="cb6-22"></span>
<span id="cb6-23">summary(Discriminator(input_dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">28</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">28</span>)), input_size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">28</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">28</span>])</span></code></pre></div></div>
<div class="cell-output cell-output-display">
<pre><code>==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Discriminator                            [128, 1]                  --
├─Sequential: 1-1                        [128, 1]                  --
│    └─Linear: 2-1                       [128, 512]                401,920
│    └─LeakyReLU: 2-2                    [128, 512]                --
│    └─Dropout: 2-3                      [128, 512]                --
│    └─Linear: 2-4                       [128, 256]                131,328
│    └─LeakyReLU: 2-5                    [128, 256]                --
│    └─Dropout: 2-6                      [128, 256]                --
│    └─Linear: 2-7                       [128, 128]                32,896
│    └─LeakyReLU: 2-8                    [128, 128]                --
│    └─Dropout: 2-9                      [128, 128]                --
│    └─Linear: 2-10                      [128, 1]                  129
==========================================================================================
Total params: 566,273
Trainable params: 566,273
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 72.48
==========================================================================================
Input size (MB): 0.40
Forward/backward pass size (MB): 0.92
Params size (MB): 2.27
Estimated Total Size (MB): 3.59
==========================================================================================</code></pre>
</div>
</div>
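<p>The parameter counts in the summary follow directly from the layer sizes. A fully connected layer with <code>n_in</code> inputs and <code>n_out</code> outputs has <code>n_in × n_out</code> weights plus <code>n_out</code> biases. The sketch below reproduces the numbers under two assumptions: the 1×28×28 input is flattened to 784 features before the first <code>Linear</code> layer, and parameters are stored as 32-bit floats. The reported sizes match 10<sup>6</sup> bytes per MB.</p>

```python
# Layer widths read off the torchinfo summary above.
# Assumption: 784 = 1 * 28 * 28 flattened pixels feed the first Linear layer.
layer_dims = [1 * 28 * 28, 512, 256, 128, 1]

# Each Linear layer holds n_in * n_out weights plus n_out biases.
linear_params = [n_in * n_out + n_out for n_in, n_out in zip(layer_dims, layer_dims[1:])]
total_params = sum(linear_params)

# float32 values take 4 bytes each; sizes expressed in units of 1e6 bytes.
params_size_mb = total_params * 4 / 1e6
input_size_mb = 128 * 1 * 28 * 28 * 4 / 1e6  # one batch of 128 float32 images

print(linear_params)             # [401920, 131328, 32896, 129]
print(total_params)              # 566273
print(round(params_size_mb, 2))  # 2.27
print(round(input_size_mb, 2))   # 0.4
```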
</section>
</section>
<section id="training-loop" class="level2">
<h2 class="anchored" data-anchor-id="training-loop">Training Loop</h2>
<p>GAN training alternates between two updates:</p>
<ul>
<li><strong>Update the Discriminator:</strong> With the Generator held fixed, improve the Discriminator’s ability to tell real samples from fake ones.</li>
<li><strong>Update the Generator:</strong> With the Discriminator held fixed, improve the Generator’s ability to fool it.</li>
</ul>
<p>Ideally, training continues until the Generator produces samples that are nearly indistinguishable from real data. At the theoretical equilibrium, the Discriminator can no longer separate real from fake and assigns a probability of <img src="https://latex.codecogs.com/png.latex?%5Cfrac%7B1%7D%7B2%7D"> to every input.</p>
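<p>Because the Discriminator outputs logits and the loss used later in this post is <code>nn.BCEWithLogitsLoss</code>, a probability of 1/2 corresponds to a logit of 0. The sketch below computes the losses from <code>train_step</code> at that idealized point: every term equals ln 2 ≈ 0.693, a handy reference value when reading the loss curves. Real training rarely sits exactly there.</p>

```python
import math

import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()

# Idealized equilibrium: D outputs logit 0 for every input, i.e. sigmoid(0) = 0.5.
logits = torch.zeros(128, 1)

loss_D_real = criterion(logits, torch.ones_like(logits))   # real samples labelled 1
loss_D_fake = criterion(logits, torch.zeros_like(logits))  # fake samples labelled 0
loss_D = (loss_D_real + loss_D_fake) / 2                   # same averaging as train_step
loss_G = criterion(logits, torch.ones_like(logits))        # G wants its fakes labelled 1

print(f"loss_D = {loss_D.item():.4f}, loss_G = {loss_G.item():.4f}, ln 2 = {math.log(2):.4f}")
```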
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Note
</div>
</div>
<div class="callout-body-container callout-body">
<p>Switching the networks between <code>.eval()</code> and <code>.train()</code> modes inside each training step initially seemed like a way to speed up training. However, these modes change the behavior of layers such as <code>BatchNorm2d</code> and <code>Dropout</code>, and doing so made the GAN diverge. Switching modes is also not entirely free, since it adds a small overhead on every step.</p>
</div>
</div>
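<p>To see why the mode matters, the sketch below shows how <code>nn.Dropout</code> behaves in each mode. In training mode it zeroes a random subset of activations and scales the survivors by 1/(1 − p); in evaluation mode it passes the input through unchanged. A Discriminator with dropout therefore computes a different function depending on its mode. The rate <code>p = 0.3</code> is an arbitrary illustrative value, not necessarily the one used in this post.</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
p = 0.3  # illustrative dropout rate
dropout = nn.Dropout(p=p)
x = torch.ones(1000)

dropout.train()
out_train = dropout(x)  # random zeros; surviving values scaled to 1 / (1 - p)

dropout.eval()
out_eval = dropout(x)   # identity: input returned unchanged

print(torch.unique(out_train))   # two values: 0 and roughly 1.4286
print(torch.equal(out_eval, x))  # True
```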
<div id="cell-17" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb8" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> train_step(</span>
<span id="cb8-2">    generator: nn.Module,</span>
<span id="cb8-3">    discriminator: nn.Module,</span>
<span id="cb8-4">    optim_G: optim.Optimizer,</span>
<span id="cb8-5">    optim_D: optim.Optimizer,</span>
<span id="cb8-6">    criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],</span>
<span id="cb8-7">    real_data: torch.Tensor,</span>
<span id="cb8-8">    noise_dim: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>,</span>
<span id="cb8-9">    device: torch.device,</span>
<span id="cb8-10">) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">tuple</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>]:</span>
<span id="cb8-11">    batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> real_data.size(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>)</span>
<span id="cb8-12">    real_data <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> real_data.to(device, non_blocking<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb8-13"></span>
<span id="cb8-14">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">### Train Discriminator</span></span>
<span id="cb8-15">    optim_D.zero_grad(set_to_none<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb8-16"></span>
<span id="cb8-17">    noise <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.randn(batch_size, noise_dim, device<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>device)</span>
<span id="cb8-18"></span>
<span id="cb8-19">    output_real <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> discriminator(real_data)</span>
<span id="cb8-20">    real_labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.ones_like(output_real)</span>
<span id="cb8-21">    loss_D_real <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> criterion(output_real, real_labels)</span>
<span id="cb8-22"></span>
<span id="cb8-23">    fake_data <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> generator(noise)</span>
<span id="cb8-24">    output_fake <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> discriminator(fake_data.detach())</span>
<span id="cb8-25">    fake_labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.zeros_like(output_fake)</span>
<span id="cb8-26">    loss_D_fake <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> criterion(output_fake, fake_labels)</span>
<span id="cb8-27"></span>
<span id="cb8-28">    loss_D <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (loss_D_real <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> loss_D_fake) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span></span>
<span id="cb8-29"></span>
<span id="cb8-30">    loss_D.backward()</span>
<span id="cb8-31">    optim_D.step()</span>
<span id="cb8-32"></span>
<span id="cb8-33">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">### Train Generator</span></span>
<span id="cb8-34">    optim_G.zero_grad(set_to_none<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb8-35"></span>
<span id="cb8-36">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Freeze D params so autograd does not waste work computing their grads</span></span>
<span id="cb8-37">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> p <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> discriminator.parameters():</span>
<span id="cb8-38">        p.requires_grad_(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>)</span>
<span id="cb8-39"></span>
<span id="cb8-40">    fake_data <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> generator(noise)</span>
<span id="cb8-41">    output_fake <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> discriminator(fake_data)</span>
<span id="cb8-42">    target_for_g <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.ones_like(output_fake)</span>
<span id="cb8-43">    loss_G <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> criterion(output_fake, target_for_g)</span>
<span id="cb8-44"></span>
<span id="cb8-45">    loss_G.backward()</span>
<span id="cb8-46">    optim_G.step()</span>
<span id="cb8-47"></span>
<span id="cb8-48">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> p <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> discriminator.parameters():</span>
<span id="cb8-49">        p.requires_grad_(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb8-50"></span>
<span id="cb8-51">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> loss_G.detach().item(), loss_D.detach().item()</span></code></pre></div></div>
</div>
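<p>The <code>requires_grad_(False)</code> loop in <code>train_step</code> freezes only the Discriminator’s parameters, not the computation graph: gradients still flow through D back to the Generator. A toy check with two hypothetical <code>nn.Linear</code> stand-ins (not the post’s actual models):</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
G = nn.Linear(4, 8)  # hypothetical stand-in for the Generator
D = nn.Linear(8, 1)  # hypothetical stand-in for the Discriminator

# Freeze D, as in the generator half of train_step.
for p in D.parameters():
    p.requires_grad_(False)

loss = D(G(torch.randn(2, 4))).sum()
loss.backward()

# Unfreeze D for the next discriminator update.
for p in D.parameters():
    p.requires_grad_(True)

print([p.grad is not None for p in G.parameters()])  # [True, True]
print([p.grad is not None for p in D.parameters()])  # [False, False]
```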
</section>
<section id="evaluation" class="level2">
<h2 class="anchored" data-anchor-id="evaluation">Evaluation</h2>
<p>Before evaluation, we fixed the learning rate (LR), the optimizer’s <img src="https://latex.codecogs.com/png.latex?%5Cbeta"> parameters, weight decay, number of epochs, batch size, and data loader settings shared by all experiments. We evaluated the model on the MNIST Digits and Fashion MNIST datasets.</p>
<div id="cell-19" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb9" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1">OPTIMIZER_LR <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.0002</span></span>
<span id="cb9-2">L2_NORM <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1e-5</span></span>
<span id="cb9-3">OPTIMIZER_BETAS <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.999</span>)</span>
<span id="cb9-4">N_EPOCHS <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span></span>
<span id="cb9-5">BATCH_SIZE <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span></span></code></pre></div></div>
</div>
<div id="cell-20" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb10" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1">g <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.Generator()</span>
<span id="cb10-2">g.manual_seed(SEED)</span>
<span id="cb10-3"></span>
<span id="cb10-4">loader_kwargs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {</span>
<span id="cb10-5">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"num_workers"</span>: joblib.cpu_count(only_physical_cores<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>),</span>
<span id="cb10-6">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"pin_memory"</span>: <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>,</span>
<span id="cb10-7">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"shuffle"</span>: <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>,</span>
<span id="cb10-8">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"batch_size"</span>: BATCH_SIZE,</span>
<span id="cb10-9">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"prefetch_factor"</span>: <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>,</span>
<span id="cb10-10">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"persistent_workers"</span>: <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>,</span>
<span id="cb10-11">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"worker_init_fn"</span>: seed_worker,</span>
<span id="cb10-12">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"generator"</span>: g,</span>
<span id="cb10-13">}</span></code></pre></div></div>
</div>
<section id="mnist-digits-dataset" class="level3">
<h3 class="anchored" data-anchor-id="mnist-digits-dataset">MNIST Digits Dataset</h3>
<p>The MNIST (Modified National Institute of Standards and Technology) dataset is a well-known collection of handwritten digits, widely used in machine learning and computer vision for training and benchmarking models. Its simplicity and small size make it a popular choice for introductory courses and image recognition experiments.</p>
<p>In total, the dataset contains 70,000 grayscale images of handwritten digits (0 to 9), split into 60,000 training and 10,000 test images. Each image is 28×28 pixels. The loader below combines both splits to get more training samples.</p>
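<p>Because <code>get_mnist_dataset</code> concatenates the 60,000 training and 10,000 test images, each epoch sees 70,000 samples. The download-free sketch below uses dummy <code>TensorDataset</code>s of the same sizes as placeholders for the real MNIST splits, and shows the combined length and the resulting number of batches per epoch with a batch size of 128.</p>

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

# Placeholders with the same number of samples as the MNIST splits (no pixel data needed).
trainset = TensorDataset(torch.zeros(60_000, 1))
testset = TensorDataset(torch.zeros(10_000, 1))

dataset = ConcatDataset([trainset, testset])
loader = DataLoader(dataset, batch_size=128, shuffle=True)

print(len(dataset))  # 70000
print(len(loader))   # 547 batches; the last one holds 70000 - 546 * 128 = 112 samples
```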
<div id="cell-22" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> get_mnist_dataset(transform: T.Compose <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> Dataset:</span>
<span id="cb11-2">    <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> torchvision.datasets <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> MNIST</span>
<span id="cb11-3"></span>
<span id="cb11-4">    root <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>(DATASET_PATH)</span>
<span id="cb11-5">    trainset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> MNIST(root<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>root, train<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, download<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, transform<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>transform)</span>
<span id="cb11-6">    testset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> MNIST(root<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>root, train<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>, download<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, transform<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>transform)</span>
<span id="cb11-7">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Combine train and test dataset for more samples.</span></span>
<span id="cb11-8">    dataset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ConcatDataset([trainset, testset])</span>
<span id="cb11-9">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> dataset</span></code></pre></div></div>
</div>
<div id="cell-24" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb12" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1">NOISE_DIM <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span></span></code></pre></div></div>
</div>
<div id="cell-25" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb13" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1">transform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> T.Compose([T.ToTensor(), T.Normalize(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>)])</span>
<span id="cb13-2"></span>
<span id="cb13-3">dataset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> get_mnist_dataset(transform<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>transform)</span>
<span id="cb13-4">dataloader <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> DataLoader(dataset, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span>loader_kwargs)</span>
<span id="cb13-5"></span>
<span id="cb13-6"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># set seed for random generators</span></span>
<span id="cb13-7">set_random_seed(seed<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>SEED)</span>
<span id="cb13-8"></span>
<span id="cb13-9"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># benchmark_noise stays fixed so the animation shows how the outputs evolve for the same noise vectors</span></span>
<span id="cb13-10">benchmark_noise <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.randn(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">16</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">16</span>, NOISE_DIM, device<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>device)</span>
<span id="cb13-11"></span>
<span id="cb13-12">generator <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> Generator(out_dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>IMG_DIM, nz<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>NOISE_DIM).to(device)</span>
<span id="cb13-13">generator.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">apply</span>(weights_init)</span>
<span id="cb13-14"></span>
<span id="cb13-15">discriminator <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> Discriminator(input_dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>IMG_DIM).to(device)</span>
<span id="cb13-16">discriminator.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">apply</span>(weights_init)</span>
<span id="cb13-17"></span>
<span id="cb13-18">optimizer_G <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> optim.AdamW(</span>
<span id="cb13-19">    generator.parameters(),</span>
<span id="cb13-20">    lr<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>OPTIMIZER_LR,</span>
<span id="cb13-21">    betas<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>OPTIMIZER_BETAS,</span>
<span id="cb13-22">    weight_decay<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>L2_NORM,</span>
<span id="cb13-23">)</span>
<span id="cb13-24"></span>
<span id="cb13-25">optimizer_D <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> optim.AdamW(</span>
<span id="cb13-26">    discriminator.parameters(),</span>
<span id="cb13-27">    lr<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>OPTIMIZER_LR,</span>
<span id="cb13-28">    betas<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>OPTIMIZER_BETAS,</span>
<span id="cb13-29">    weight_decay<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>L2_NORM,</span>
<span id="cb13-30">)</span>
<span id="cb13-31"></span>
<span id="cb13-32">criterion <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.BCEWithLogitsLoss().to(device)</span></code></pre></div></div>
</div>
<div id="cell-26" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb14" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1">animation: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>[Tensor] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> []</span>
<span id="cb14-2"></span>
<span id="cb14-3">g_losses: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> []</span>
<span id="cb14-4">d_losses: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> []</span>
<span id="cb14-5"></span>
<span id="cb14-6"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> _ <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> tqdm(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(N_EPOCHS), unit<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"epochs"</span>):</span>
<span id="cb14-7">    generator.train()</span>
<span id="cb14-8">    discriminator.train()</span>
<span id="cb14-9"></span>
<span id="cb14-10">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> samples_real, _ <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> dataloader:</span>
<span id="cb14-11">        g_loss, d_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> train_step(</span>
<span id="cb14-12">            generator,</span>
<span id="cb14-13">            discriminator,</span>
<span id="cb14-14">            optimizer_G,</span>
<span id="cb14-15">            optimizer_D,</span>
<span id="cb14-16">            criterion,</span>
<span id="cb14-17">            samples_real,</span>
<span id="cb14-18">            NOISE_DIM,</span>
<span id="cb14-19">            device,</span>
<span id="cb14-20">        )</span>
<span id="cb14-21"></span>
<span id="cb14-22">        g_losses.append(g_loss)</span>
<span id="cb14-23">        d_losses.append(d_loss)</span>
<span id="cb14-24"></span>
<span id="cb14-25">    generator.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">eval</span>()</span>
<span id="cb14-26">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.inference_mode():</span>
<span id="cb14-27">        images <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> generator(benchmark_noise)</span>
<span id="cb14-28">        images <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> images.cpu()</span>
<span id="cb14-29"></span>
<span id="cb14-30">        images <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> make_grid(images, nrow<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">16</span>, normalize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb14-31">        images <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (images <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">255</span>).clamp(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">255</span>).to(torch.uint8)</span>
<span id="cb14-32"></span>
<span id="cb14-33">        animation.append(images)</span></code></pre></div></div>
<div class="cell-output cell-output-stderr">
<pre><code>100%|██████████| 100/100 [05:44&lt;00:00,  3.45s/epochs]</code></pre>
</div>
</div>
<div id="cell-27" class="cell">
<div class="cell-output cell-output-display">
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/2023-10-10-vanilla-GANs/index_files/figure-html/cell-17-output-1.png" width="690" height="422" class="figure-img"></p>
<figcaption>Generator and Discriminator loss evolution over epochs using Vanilla GAN on the MNIST digit dataset.</figcaption>
</figure>
</div>
</div>
</div>
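<p>The training loop appends one loss value per batch, so <code>g_losses</code> and <code>d_losses</code> each end up with <code>N_EPOCHS × len(dataloader)</code> entries. For an epoch-level view, a hypothetical helper like the one below averages consecutive chunks of per-batch losses; it assumes the lists were filled exactly as in the loop above.</p>

```python
def per_epoch_means(losses: list[float], batches_per_epoch: int) -> list[float]:
    """Average consecutive chunks of `batches_per_epoch` per-batch losses."""
    if batches_per_epoch <= 0:
        raise ValueError("batches_per_epoch must be positive")
    return [
        sum(chunk) / len(chunk)
        for chunk in (
            losses[i : i + batches_per_epoch]
            for i in range(0, len(losses), batches_per_epoch)
        )
    ]


# Usage with the lists from the training loop (not run here):
# g_epoch = per_epoch_means(g_losses, len(dataloader))
# d_epoch = per_epoch_means(d_losses, len(dataloader))

print(per_epoch_means([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], batches_per_epoch=3))  # [2.0, 5.0]
```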
<div class="quarto-video"><video id="video_shortcode_videojs_video1" class="video-js vjs-default-skin vjs-big-play-centered vjs-fluid" controls="" preload="auto" data-setup="{}" title=""><source src="./figures/gan-mnist.mp4"></video></div>
</section>
<section id="fashion-mnist-dataset" class="level3">
<h3 class="anchored" data-anchor-id="fashion-mnist-dataset">Fashion MNIST Dataset</h3>
<p>The Fashion MNIST dataset is a collection of grayscale images of fashion items in 10 categories, designed as a more challenging drop-in alternative to the classic MNIST handwritten digits. Each image is 28×28 pixels. The categories are t-shirts/tops, trousers, pullovers, dresses, coats, sandals, shirts, sneakers, bags, and ankle boots. With 70,000 images (60,000 training and 10,000 test), Fashion MNIST is commonly used for benchmarking machine learning algorithms, especially in image classification.</p>
<div id="cell-31" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb16" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1">NOISE_DIM: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span></span></code></pre></div></div>
</div>
<div id="cell-32" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb17" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> get_mnist_fashion_dataset(transform: T.Compose <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> Dataset:</span>
<span id="cb17-2">    <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> torchvision.datasets <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> FashionMNIST</span>
<span id="cb17-3"></span>
<span id="cb17-4">    root <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>(DATASET_PATH)</span>
<span id="cb17-5">    trainset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> FashionMNIST(root<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>root, train<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, download<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, transform<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>transform)</span>
<span id="cb17-6">    testset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> FashionMNIST(root<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>root, train<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>, download<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, transform<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>transform)</span>
<span id="cb17-7">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Combine train and test dataset for more samples.</span></span>
<span id="cb17-8">    dataset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ConcatDataset([trainset, testset])</span>
<span id="cb17-9">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> dataset</span></code></pre></div></div>
</div>
<div id="cell-33" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb18" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1">transform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> T.Compose([T.ToTensor(), T.Normalize(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>)])</span>
<span id="cb18-2"></span>
<span id="cb18-3">data <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> get_mnist_fashion_dataset(transform<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>transform)</span>
<span id="cb18-4">dataloader <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> DataLoader(data, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span>loader_kwargs)</span>
<span id="cb18-5"></span>
<span id="cb18-6"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># set seed for random generators</span></span>
<span id="cb18-7">set_random_seed(seed<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>SEED)</span>
<span id="cb18-8"></span>
<span id="cb18-9"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># benchmark_noise stays fixed so the animation shows how outputs evolve for the same latent vectors</span></span>
<span id="cb18-10">benchmark_noise <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.randn(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">16</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">16</span>, NOISE_DIM, device<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>device)</span>
<span id="cb18-11"></span>
<span id="cb18-12">generator <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> Generator(out_dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>IMG_DIM, nz<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>NOISE_DIM).to(device)</span>
<span id="cb18-13">generator.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">apply</span>(weights_init)</span>
<span id="cb18-14"></span>
<span id="cb18-15">discriminator <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> Discriminator(input_dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>IMG_DIM).to(device)</span>
<span id="cb18-16">discriminator.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">apply</span>(weights_init)</span>
<span id="cb18-17"></span>
<span id="cb18-18">optimizer_G <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> optim.AdamW(</span>
<span id="cb18-19">    generator.parameters(),</span>
<span id="cb18-20">    lr<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>OPTIMIZER_LR,</span>
<span id="cb18-21">    betas<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>OPTIMIZER_BETAS,</span>
<span id="cb18-22">    weight_decay<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>L2_NORM,</span>
<span id="cb18-23">)</span>
<span id="cb18-24"></span>
<span id="cb18-25">optimizer_D <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> optim.AdamW(</span>
<span id="cb18-26">    discriminator.parameters(),</span>
<span id="cb18-27">    lr<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>OPTIMIZER_LR,</span>
<span id="cb18-28">    betas<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>OPTIMIZER_BETAS,</span>
<span id="cb18-29">    weight_decay<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>L2_NORM,</span>
<span id="cb18-30">)</span>
<span id="cb18-31"></span>
<span id="cb18-32">criterion <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.BCEWithLogitsLoss().to(device)</span></code></pre></div></div>
</div>
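<p>The transform above uses <code>T.Normalize(0.5, 0.5)</code>, which computes (x - 0.5) / 0.5 per pixel. This maps the [0, 1] output of <code>T.ToTensor()</code> onto [-1, 1], the same range as a generator that ends in <code>nn.Tanh()</code> (as the DCGAN generator does). In the training loop, <code>make_grid(..., normalize=True)</code> rescales generated images back to [0, 1] (by their min and max) before they become uint8 frames. The pure-Python sketch below shows only the per-pixel arithmetic. <code>normalize</code> and <code>denormalize</code> are illustrative helpers, not torchvision functions.</p>

```python
import math


def normalize(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Per-pixel arithmetic of T.Normalize(mean, std): (x - mean) / std."""
    return (x - mean) / std


def denormalize(y: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Inverse mapping, back to the [0, 1] range produced by T.ToTensor()."""
    return y * std + mean


# T.ToTensor() yields pixels in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1].
for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize(pixel))  # -1.0, 0.0 and 1.0 respectively

# tanh is bounded by -1 and 1, so real and generated images share the same range.
print(math.tanh(-10.0), math.tanh(10.0))
```

<p>Matching the data range to the generator's output activation means the discriminator cannot tell real from fake samples simply by their pixel range.</p>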
<div id="cell-34" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb19" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1">animation <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> []</span>
<span id="cb19-2"></span>
<span id="cb19-3">g_losses, d_losses <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [], []</span>
<span id="cb19-4"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> _ <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> tqdm(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(N_EPOCHS), unit<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"epochs"</span>):</span>
<span id="cb19-5">    generator.train()</span>
<span id="cb19-6">    discriminator.train()</span>
<span id="cb19-7"></span>
<span id="cb19-8">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> samples_real, _ <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> dataloader:</span>
<span id="cb19-9">        g_loss, d_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> train_step(</span>
<span id="cb19-10">            generator,</span>
<span id="cb19-11">            discriminator,</span>
<span id="cb19-12">            optimizer_G,</span>
<span id="cb19-13">            optimizer_D,</span>
<span id="cb19-14">            criterion,</span>
<span id="cb19-15">            samples_real,</span>
<span id="cb19-16">            NOISE_DIM,</span>
<span id="cb19-17">            device,</span>
<span id="cb19-18">        )</span>
<span id="cb19-19"></span>
<span id="cb19-20">        g_losses.append(g_loss)</span>
<span id="cb19-21">        d_losses.append(d_loss)</span>
<span id="cb19-22"></span>
<span id="cb19-23">    generator.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">eval</span>()</span>
<span id="cb19-24">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.inference_mode():</span>
<span id="cb19-25">        images <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> generator(benchmark_noise)</span>
<span id="cb19-26">        images <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> images.cpu()</span>
<span id="cb19-27"></span>
<span id="cb19-28">        images <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> make_grid(images, nrow<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">16</span>, normalize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb19-29">        images <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (images <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">255</span>).clamp(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">255</span>).to(torch.uint8)</span>
<span id="cb19-30"></span>
<span id="cb19-31">        animation.append(images)</span></code></pre></div></div>
<div class="cell-output cell-output-stderr">
<pre><code>100%|██████████| 100/100 [05:47&lt;00:00,  3.47s/epochs]</code></pre>
</div>
</div>
<div id="cell-35" class="cell">
<div class="cell-output cell-output-display">
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/2023-10-10-vanilla-GANs/index_files/figure-html/cell-23-output-1.png" width="690" height="422" class="figure-img"></p>
<figcaption>Generator and discriminator losses during training of the vanilla GAN on the Fashion-MNIST dataset.</figcaption>
</figure>
</div>
</div>
</div>
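<p>The training loop records one generator loss and one discriminator loss per batch. After training, <code>g_losses</code> and <code>d_losses</code> therefore each hold exactly <code>N_EPOCHS * len(dataloader)</code> values. To plot one point per epoch, you can average the per-batch values in chunks. The helper <code>per_epoch_mean</code> below is hypothetical and not part of the original code. It assumes that <code>train_step</code> returns plain Python floats and that every epoch has the same number of batches.</p>

```python
from statistics import fmean


def per_epoch_mean(losses: list[float], batches_per_epoch: int) -> list[float]:
    """Average a flat list of per-batch losses into one value per epoch.

    Assumes every epoch contributed exactly `batches_per_epoch` entries,
    which holds when the DataLoader length stays constant across epochs.
    """
    if batches_per_epoch <= 0:
        raise ValueError("batches_per_epoch must be positive")
    return [
        fmean(losses[i : i + batches_per_epoch])
        for i in range(0, len(losses), batches_per_epoch)
    ]


# Hypothetical numbers: 2 epochs with 3 batches each.
example_losses = [1.0, 2.0, 3.0, 0.5, 0.5, 0.5]
print(per_epoch_mean(example_losses, batches_per_epoch=3))  # [2.0, 0.5]
```

<p>With the objects from the cells above, this would be called as <code>per_epoch_mean(g_losses, len(dataloader))</code>.</p>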
<div class="quarto-video"><video id="video_shortcode_videojs_video2" class="video-js vjs-default-skin vjs-big-play-centered vjs-fluid" controls="" preload="auto" data-setup="{}" title=""><source src="./figures/gan-fashion.mp4"></video></div>
</section>
</section>
</section>
<section id="dcgan" class="level1">
<h1>DCGAN</h1>
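<p>DCGAN, short for Deep Convolutional Generative Adversarial Network, replaces the fully connected layers of the vanilla GAN with convolutional ones. The discriminator uses strided convolutions and the generator uses transposed convolutions. Convolutions share weights across spatial positions and preserve the 2D structure of the input, so DCGAN is better suited to image data. DCGAN also follows a few architectural guidelines: batch normalization, no fully connected hidden layers, ReLU activations in the generator with a Tanh output, and LeakyReLU in the discriminator. With these, it trains more stably and produces sharper images than vanilla GANs across a wider range of hyperparameters.</p>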
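<p>Weight sharing can be made concrete with a parameter count. The sketch below compares a hypothetical fully connected layer with a convolution that produce the same output shape. Both map a 1×28×28 image to a 32×14×14 feature map, which is what the first layer of the DCGAN discriminator does. The helper names are illustrative. The formulas are the standard parameter counts for <code>nn.Linear</code> and for <code>nn.Conv2d</code> with <code>groups=1</code>.</p>

```python
def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Parameter count of nn.Linear(in_features, out_features)."""
    return in_features * out_features + (out_features if bias else 0)


def conv2d_params(in_ch: int, out_ch: int, kernel: int, bias: bool = True) -> int:
    """Parameter count of nn.Conv2d(in_ch, out_ch, kernel) with a square kernel, groups=1."""
    return in_ch * out_ch * kernel * kernel + (out_ch if bias else 0)


# Same input (1x28x28) and output (32x14x14) shapes for both layers.
dense = linear_params(1 * 28 * 28, 32 * 14 * 14)
# Mirrors nn.Conv2d(nc, ndf, 4, 2, 1, bias=False) with nc=1, ndf=32.
conv = conv2d_params(1, 32, kernel=4, bias=False)

print(f"fully connected: {dense:,} parameters")  # 4,923,520
print(f"convolution:     {conv:,} parameters")   # 512
```

<p>The convolution reuses one 4×4 kernel per input/output channel pair at every position, so its size does not grow with the image resolution.</p>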
<section id="setting-up-dcgans" class="level2">
<h2 class="anchored" data-anchor-id="setting-up-dcgans">Setting Up DCGANs</h2>
<section id="generator-1" class="level3">
<h3 class="anchored" data-anchor-id="generator-1">Generator</h3>
<div id="cell-41" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb21" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb21-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">class</span> Generator(nn.Module):</span>
<span id="cb21-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, nz: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span>, ngf: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">32</span>, nc: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>):</span>
<span id="cb21-3">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""</span></span>
<span id="cb21-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        :param nz: size of the latent z vector</span></span>
<span id="cb21-5"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        :param ngf: size of feature maps in generator</span></span>
<span id="cb21-6"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        :param nc: number of channels in the training images.</span></span>
<span id="cb21-7"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        """</span></span>
<span id="cb21-8">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb21-9">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.layers <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Sequential(</span>
<span id="cb21-10">            nn.ConvTranspose2d(nz, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> ngf, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, bias<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>),</span>
<span id="cb21-11">            nn.BatchNorm2d(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> ngf),</span>
<span id="cb21-12">            nn.ReLU(inplace<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>),</span>
<span id="cb21-13">            nn.ConvTranspose2d(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> ngf, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> ngf, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, bias<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>),</span>
<span id="cb21-14">            nn.BatchNorm2d(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> ngf),</span>
<span id="cb21-15">            nn.ReLU(inplace<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>),</span>
<span id="cb21-16">            nn.ConvTranspose2d(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> ngf, ngf, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, bias<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>),</span>
<span id="cb21-17">            nn.BatchNorm2d(ngf),</span>
<span id="cb21-18">            nn.ReLU(inplace<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>),</span>
<span id="cb21-19">            nn.ConvTranspose2d(ngf, nc, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, bias<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>),</span>
<span id="cb21-20">            nn.Tanh(),</span>
<span id="cb21-21">        )</span>
<span id="cb21-22"></span>
<span id="cb21-23">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x: Tensor) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> Tensor:</span>
<span id="cb21-24">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.reshape(x, (x.size(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>), <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>))</span>
<span id="cb21-25">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.layers(x)</span>
<span id="cb21-26"></span>
<span id="cb21-27"></span>
<span id="cb21-28">summary(Generator(), input_size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span>))</span></code></pre></div></div>
<div class="cell-output cell-output-display">
<pre><code>==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Generator                                [128, 1, 28, 28]          --
├─Sequential: 1-1                        [128, 1, 28, 28]          --
│    └─ConvTranspose2d: 2-1              [128, 128, 4, 4]          204,800
│    └─BatchNorm2d: 2-2                  [128, 128, 4, 4]          256
│    └─ReLU: 2-3                         [128, 128, 4, 4]          --
│    └─ConvTranspose2d: 2-4              [128, 64, 7, 7]           73,728
│    └─BatchNorm2d: 2-5                  [128, 64, 7, 7]           128
│    └─ReLU: 2-6                         [128, 64, 7, 7]           --
│    └─ConvTranspose2d: 2-7              [128, 32, 14, 14]         32,768
│    └─BatchNorm2d: 2-8                  [128, 32, 14, 14]         64
│    └─ReLU: 2-9                         [128, 32, 14, 14]         --
│    └─ConvTranspose2d: 2-10             [128, 1, 28, 28]          512
│    └─Tanh: 2-11                        [128, 1, 28, 28]          --
==========================================================================================
Total params: 312,256
Trainable params: 312,256
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 1.76
==========================================================================================
Input size (MB): 0.05
Forward/backward pass size (MB): 24.26
Params size (MB): 1.25
Estimated Total Size (MB): 25.56
==========================================================================================</code></pre>
</div>
</div>
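<p>You can check the shapes and parameter counts in the summary by hand. The sketch below uses the <code>nn.ConvTranspose2d</code> output-size formula from the PyTorch documentation, specialised to <code>dilation=1</code> and <code>output_padding=0</code>. It assumes the default arguments <code>nz=100</code>, <code>ngf=32</code>, <code>nc=1</code>. Note the 3×3 kernel in the second layer, which turns 4×4 into 7×7. A 4×4 kernel there would give 8×8 and, after two more upsampling steps, 32×32 instead of the 28×28 that MNIST needs.</p>

```python
def convtranspose2d_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of nn.ConvTranspose2d (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel


# (in_channels, out_channels, kernel, stride, padding) of each ConvTranspose2d,
# using the Generator defaults nz=100, ngf=32, nc=1.
nz, ngf, nc = 100, 32, 1
layers = [
    (nz, 4 * ngf, 4, 1, 0),
    (4 * ngf, 2 * ngf, 3, 2, 1),
    (2 * ngf, ngf, 4, 2, 1),
    (ngf, nc, 4, 2, 1),
]

size, total = 1, 0  # forward() reshapes the latent vector to (nz, 1, 1)
for i, (c_in, c_out, k, s, p) in enumerate(layers):
    size = convtranspose2d_out(size, k, s, p)
    total += c_in * c_out * k * k  # bias=False, so weights only
    if i < len(layers) - 1:
        total += 2 * c_out  # BatchNorm2d: one weight and one bias per channel
    print(f"{c_out:>3} x {size} x {size}")

print(f"total params: {total:,}")  # total params: 312,256
```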
</section>
<section id="discriminator-1" class="level3">
<h3 class="anchored" data-anchor-id="discriminator-1">Discriminator</h3>
<div id="cell-43" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb23" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb23-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">class</span> Discriminator(nn.Module):</span>
<span id="cb23-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, ndf: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">32</span>, nc: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, alpha: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.2</span>):</span>
<span id="cb23-3">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb23-4">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.layers <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Sequential(</span>
<span id="cb23-5">            nn.Conv2d(nc, ndf, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, bias<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>),</span>
<span id="cb23-6">            nn.BatchNorm2d(ndf),</span>
<span id="cb23-7">            nn.LeakyReLU(alpha, inplace<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>),</span>
<span id="cb23-8">            nn.Conv2d(ndf, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> ndf, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, bias<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>),</span>
<span id="cb23-9">            nn.BatchNorm2d(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> ndf),</span>
<span id="cb23-10">            nn.LeakyReLU(alpha, inplace<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>),</span>
<span id="cb23-11">            nn.Conv2d(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> ndf, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> ndf, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, bias<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>),</span>
<span id="cb23-12">            nn.BatchNorm2d(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> ndf),</span>
<span id="cb23-13">            nn.LeakyReLU(alpha, inplace<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>),</span>
<span id="cb23-14">            nn.Conv2d(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> ndf, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, bias<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>),</span>
<span id="cb23-15">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># nn.Sigmoid() omitted: output raw logits for a logits-based loss such as BCEWithLogitsLoss</span></span>
<span id="cb23-16">        )</span>
<span id="cb23-17"></span>
<span id="cb23-18">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x: Tensor) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> Tensor:</span>
<span id="cb23-19">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.layers(x)</span>
<span id="cb23-20">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.reshape(x, (x.size(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>), <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>))</span>
<span id="cb23-21">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> x</span>
<span id="cb23-22"></span>
<span id="cb23-23"></span>
<span id="cb23-24">summary(Discriminator(), input_size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(BATCH_SIZE, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">28</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">28</span>))</span></code></pre></div></div>
<div class="cell-output cell-output-display">
<pre><code>==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Discriminator                            [128, 1]                  --
├─Sequential: 1-1                        [128, 1, 1, 1]            --
│    └─Conv2d: 2-1                       [128, 32, 14, 14]         512
│    └─BatchNorm2d: 2-2                  [128, 32, 14, 14]         64
│    └─LeakyReLU: 2-3                    [128, 32, 14, 14]         --
│    └─Conv2d: 2-4                       [128, 64, 7, 7]           32,768
│    └─BatchNorm2d: 2-5                  [128, 64, 7, 7]           128
│    └─LeakyReLU: 2-6                    [128, 64, 7, 7]           --
│    └─Conv2d: 2-7                       [128, 128, 4, 4]          73,728
│    └─BatchNorm2d: 2-8                  [128, 128, 4, 4]          256
│    └─LeakyReLU: 2-9                    [128, 128, 4, 4]          --
│    └─Conv2d: 2-10                      [128, 1, 1, 1]            2,048
==========================================================================================
Total params: 109,504
Trainable params: 109,504
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 369.68
==========================================================================================
Input size (MB): 0.40
Forward/backward pass size (MB): 23.46
Params size (MB): 0.44
Estimated Total Size (MB): 24.30
==========================================================================================</code></pre>
</div>
</div>
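<p>The shapes and parameter counts in the summary can be checked by hand with the standard <code>nn.Conv2d</code> output-size formula. In the sketch below, each layer's kernel size, stride and padding are my assumption. I reverse-engineered them from the parameter counts, for example 512 = 1 * 32 * 4 * 4 for a bias-free first convolution, rather than reading them from the class definition.</p>

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    # Output size of nn.Conv2d along one spatial dimension (dilation = 1)
    return (size + 2 * padding - kernel) // stride + 1

# (in_channels, out_channels, kernel, stride, padding) for each Conv2d layer.
# ASSUMPTION: inferred from the parameter counts and shapes in the summary,
# not copied from the Discriminator definition.
layers = [(1, 32, 4, 2, 1), (32, 64, 4, 2, 1), (64, 128, 3, 2, 1), (128, 1, 4, 1, 0)]

size, sizes, conv_params = 28, [], []
for c_in, c_out, k, s, p in layers:
    size = conv2d_out(size, k, s, p)
    sizes.append(size)
    conv_params.append(c_in * c_out * k * k)  # weights only (bias=False)

# Each BatchNorm2d learns one weight and one bias per channel
bn_params = sum(2 * c for c in (32, 64, 128))

print("spatial sizes:", sizes)
print("total params:", sum(conv_params) + bn_params)
```

<p>This reproduces the 14, 7, 4, 1 spatial sizes and the 109,504 total parameters reported by <code>summary</code>. That agreement suggests the assumed layer settings are consistent with the real model.</p>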
</section>
</section>
<section id="evaluation-1" class="level2">
<h2 class="anchored" data-anchor-id="evaluation-1">Evaluation</h2>
<section id="mnist-digits-dataset-1" class="level3">
<h3 class="anchored" data-anchor-id="mnist-digits-dataset-1">MNIST Digits Dataset</h3>
<div id="cell-46" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb25" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb25-1">NOISE_DIM <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span></span>
<span id="cb25-2"></span>
<span id="cb25-3">transform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> T.Compose(</span>
<span id="cb25-4">    [</span>
<span id="cb25-5">        T.ToTensor(),</span>
<span id="cb25-6">        T.Normalize(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>),</span>
<span id="cb25-7">    ]</span>
<span id="cb25-8">)</span>
<span id="cb25-9"></span>
<span id="cb25-10">data <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> get_mnist_dataset(transform)</span>
<span id="cb25-11">dataloader <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> DataLoader(data, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span>loader_kwargs)</span>
<span id="cb25-12"></span>
<span id="cb25-13"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># set seed for random generators</span></span>
<span id="cb25-14">set_random_seed()</span>
<span id="cb25-15"></span>
<span id="cb25-16"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># benchmark_noise is used for the animation to show how output evolve on same vector</span></span>
<span id="cb25-17">benchmark_noise <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.randn(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">16</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">16</span>, NOISE_DIM, device<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>device)</span>
<span id="cb25-18"></span>
<span id="cb25-19">generator <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> Generator(nz<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>NOISE_DIM, ngf<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">32</span>, nc<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>IMG_DIM[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>]).to(device)</span>
<span id="cb25-20">generator.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">apply</span>(weights_init)</span>
<span id="cb25-21"></span>
<span id="cb25-22">discriminator <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> Discriminator(ndf<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">32</span>, nc<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>IMG_DIM[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>]).to(device)</span>
<span id="cb25-23">discriminator.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">apply</span>(weights_init)</span>
<span id="cb25-24"></span>
<span id="cb25-25">optimizer_G <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> optim.AdamW(</span>
<span id="cb25-26">    generator.parameters(),</span>
<span id="cb25-27">    lr<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>OPTIMIZER_LR,</span>
<span id="cb25-28">    betas<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>OPTIMIZER_BETAS,</span>
<span id="cb25-29">    weight_decay<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>L2_NORM,</span>
<span id="cb25-30">)</span>
<span id="cb25-31"></span>
<span id="cb25-32">optimizer_D <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> optim.AdamW(</span>
<span id="cb25-33">    discriminator.parameters(),</span>
<span id="cb25-34">    lr<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>OPTIMIZER_LR,</span>
<span id="cb25-35">    betas<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>OPTIMIZER_BETAS,</span>
<span id="cb25-36">    weight_decay<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>L2_NORM,</span>
<span id="cb25-37">)</span>
<span id="cb25-38"></span>
<span id="cb25-39">criterion <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.BCEWithLogitsLoss().to(device)</span></code></pre></div></div>
</div>
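<p><code>T.ToTensor()</code> scales 8-bit pixels into [0, 1]. <code>T.Normalize(0.5, 0.5)</code> then computes (x - 0.5) / 0.5, so the training images end up in [-1, 1]. That range is usually chosen to match a Generator whose last layer is <code>Tanh</code>. Whether this post's Generator ends with <code>Tanh</code> is an assumption here, since its definition is not repeated in this section. The arithmetic written out with plain tensor operations:</p>

```python
import torch

# Three example 8-bit pixel intensities: black, mid-gray, white
pixels = torch.tensor([0, 128, 255], dtype=torch.uint8)

scaled = pixels.float() / 255          # what T.ToTensor() does to uint8 images
normalized = (scaled - 0.5) / 0.5      # what T.Normalize(0.5, 0.5) does

# The inverse mapping, useful for turning generated samples back into images
restored = normalized * 0.5 + 0.5

print(normalized.tolist())
```

<p>Black maps to -1 and white maps to 1.</p>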
<div id="cell-47" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb26" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb26-1">animation <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> []</span>
<span id="cb26-2"></span>
<span id="cb26-3">g_losses, d_losses <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [], []</span>
<span id="cb26-4"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> _ <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> tqdm(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(N_EPOCHS), unit<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"epochs"</span>):</span>
<span id="cb26-5">    generator.train()</span>
<span id="cb26-6">    discriminator.train()</span>
<span id="cb26-7"></span>
<span id="cb26-8">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> samples_real, _ <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> dataloader:</span>
<span id="cb26-9">        g_loss, d_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> train_step(</span>
<span id="cb26-10">            generator,</span>
<span id="cb26-11">            discriminator,</span>
<span id="cb26-12">            optimizer_G,</span>
<span id="cb26-13">            optimizer_D,</span>
<span id="cb26-14">            criterion,</span>
<span id="cb26-15">            samples_real,</span>
<span id="cb26-16">            NOISE_DIM,</span>
<span id="cb26-17">            device,</span>
<span id="cb26-18">        )</span>
<span id="cb26-19"></span>
<span id="cb26-20">        g_losses.append(g_loss)</span>
<span id="cb26-21">        d_losses.append(d_loss)</span>
<span id="cb26-22"></span>
<span id="cb26-23">    generator.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">eval</span>()</span>
<span id="cb26-24">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.inference_mode():</span>
<span id="cb26-25">        images <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> generator(benchmark_noise)</span>
<span id="cb26-26">        images <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> images.cpu()</span>
<span id="cb26-27"></span>
<span id="cb26-28">        images <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> make_grid(images, nrow<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">16</span>, normalize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb26-29">        images <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (images <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">255</span>).clamp(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">255</span>).to(torch.uint8)</span>
<span id="cb26-30"></span>
<span id="cb26-31">        animation.append(images)</span></code></pre></div></div>
<div class="cell-output cell-output-stderr">
<pre><code>100%|██████████| 100/100 [04:50&lt;00:00,  2.91s/epochs]</code></pre>
</div>
</div>
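<p>Before each animation frame is stored, <code>make_grid(..., normalize=True)</code> rescales the whole grid by its own minimum and maximum, which is its default when no <code>value_range</code> is given. The result is then multiplied by 255 and cast to <code>uint8</code>. The sketch below reproduces that post-processing with plain torch operations so the mapping is explicit. <code>to_uint8_frame</code> is a hypothetical helper, not part of torchvision or of this post.</p>

```python
import torch

def to_uint8_frame(images: torch.Tensor) -> torch.Tensor:
    # Min-max rescale over the whole tensor, mirroring make_grid(normalize=True)
    low, high = images.min(), images.max()
    scaled = (images - low) / (high - low).clamp(min=1e-5)
    # Map [0, 1] to [0, 255]; the uint8 cast truncates rather than rounds
    return (scaled * 255).clamp(0, 255).to(torch.uint8)

frame = to_uint8_frame(torch.tensor([-1.0, 0.0, 1.0]))
print(frame.tolist())
```

<p>The rescaling uses each grid's own minimum and maximum, so brightness is not strictly comparable between frames. Passing <code>value_range=(-1, 1)</code> to <code>make_grid</code> would fix the mapping instead, assuming the generator outputs lie in [-1, 1].</p>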
<div id="cell-48" class="cell">
<div class="cell-output cell-output-display">
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/2023-10-10-vanilla-GANs/index_files/figure-html/cell-29-output-1.png" width="690" height="422" class="figure-img"></p>
<figcaption>Generator and Discriminator loss evolution over epochs using DCGAN on the MNIST digit dataset.</figcaption>
</figure>
</div>
</div>
</div>
</section>
<section id="mnist-fashion-dataset" class="level3">
<h3 class="anchored" data-anchor-id="mnist-fashion-dataset">MNIST Fashion Dataset</h3>
<div id="cell-51" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb28" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb28-1">NOISE_DIM <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span></span>
<span id="cb28-2"></span>
<span id="cb28-3">transform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> T.Compose(</span>
<span id="cb28-4">    [</span>
<span id="cb28-5">        T.ToTensor(),</span>
<span id="cb28-6">        T.Normalize(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>),</span>
<span id="cb28-7">    ]</span>
<span id="cb28-8">)</span>
<span id="cb28-9"></span>
<span id="cb28-10">data <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> get_mnist_fashion_dataset(transform)</span>
<span id="cb28-11">dataloader <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> DataLoader(data, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span>loader_kwargs)</span>
<span id="cb28-12"></span>
<span id="cb28-13"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># set seed for random generators</span></span>
<span id="cb28-14">set_random_seed()</span>
<span id="cb28-15"></span>
<span id="cb28-16"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># benchmark_noise is used for the animation to show how output evolve on same vector</span></span>
<span id="cb28-17">benchmark_noise <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.randn(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">16</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">16</span>, NOISE_DIM, device<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>device)</span>
<span id="cb28-18"></span>
<span id="cb28-19">generator <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> Generator(nz<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>NOISE_DIM, ngf<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">32</span>, nc<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>IMG_DIM[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>]).to(device)</span>
<span id="cb28-20">generator.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">apply</span>(weights_init)</span>
<span id="cb28-21"></span>
<span id="cb28-22">discriminator <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> Discriminator(ndf<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">32</span>, nc<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>IMG_DIM[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>]).to(device)</span>
<span id="cb28-23">discriminator.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">apply</span>(weights_init)</span>
<span id="cb28-24"></span>
<span id="cb28-25">optimizer_G <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> optim.AdamW(</span>
<span id="cb28-26">    generator.parameters(),</span>
<span id="cb28-27">    lr<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>OPTIMIZER_LR,</span>
<span id="cb28-28">    betas<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>OPTIMIZER_BETAS,</span>
<span id="cb28-29">    weight_decay<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>L2_NORM,</span>
<span id="cb28-30">)</span>
<span id="cb28-31"></span>
<span id="cb28-32">optimizer_D <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> optim.AdamW(</span>
<span id="cb28-33">    discriminator.parameters(),</span>
<span id="cb28-34">    lr<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>OPTIMIZER_LR,</span>
<span id="cb28-35">    betas<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>OPTIMIZER_BETAS,</span>
<span id="cb28-36">    weight_decay<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>L2_NORM,</span>
<span id="cb28-37">)</span>
<span id="cb28-38"></span>
<span id="cb28-39">criterion <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.BCEWithLogitsLoss().to(device)</span></code></pre></div></div>
</div>
<div id="cell-52" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb29" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb29-1">animation <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> []</span>
<span id="cb29-2"></span>
<span id="cb29-3">g_losses, d_losses <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [], []</span>
<span id="cb29-4"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> _ <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> tqdm(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(N_EPOCHS), unit<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"epochs"</span>):</span>
<span id="cb29-5">    generator.train()</span>
<span id="cb29-6">    discriminator.train()</span>
<span id="cb29-7"></span>
<span id="cb29-8">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> samples_real, _ <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> dataloader:</span>
<span id="cb29-9">        g_loss, d_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> train_step(</span>
<span id="cb29-10">            generator,</span>
<span id="cb29-11">            discriminator,</span>
<span id="cb29-12">            optimizer_G,</span>
<span id="cb29-13">            optimizer_D,</span>
<span id="cb29-14">            criterion,</span>
<span id="cb29-15">            samples_real,</span>
<span id="cb29-16">            NOISE_DIM,</span>
<span id="cb29-17">            device,</span>
<span id="cb29-18">        )</span>
<span id="cb29-19"></span>
<span id="cb29-20">        g_losses.append(g_loss)</span>
<span id="cb29-21">        d_losses.append(d_loss)</span>
<span id="cb29-22"></span>
<span id="cb29-23">    generator.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">eval</span>()</span>
<span id="cb29-24">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.inference_mode():</span>
<span id="cb29-25">        images <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> generator(benchmark_noise)</span>
<span id="cb29-26">        images <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> images.cpu()</span>
<span id="cb29-27"></span>
<span id="cb29-28">        images <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> make_grid(images, nrow<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">16</span>, normalize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb29-29">        images <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (images <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">255</span>).clamp(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">255</span>).to(torch.uint8)</span>
<span id="cb29-30"></span>
<span id="cb29-31">        animation.append(images)</span></code></pre></div></div>
<div class="cell-output cell-output-stderr">
<pre><code>100%|██████████| 100/100 [04:55&lt;00:00,  2.96s/epochs]</code></pre>
</div>
</div>
<div id="cell-53" class="cell">
<div class="cell-output cell-output-display">
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/2023-10-10-vanilla-GANs/index_files/figure-html/cell-33-output-1.png" width="690" height="422" class="figure-img"></p>
<figcaption>Generator and Discriminator loss evolution over epochs using DCGAN on the MNIST fashion dataset.</figcaption>
</figure>
</div>
</div>
</div>
<div class="quarto-video"><video id="video_shortcode_videojs_video3" class="video-js vjs-default-skin vjs-big-play-centered vjs-fluid" controls="" preload="auto" data-setup="{}" title=""><source src="./figures/dcgan-fashion.mp4"></video></div>
</section>
</section>
</section>
<section id="conclusion" class="level1">
<h1>Conclusion</h1>
<p>Generative Adversarial Networks (GANs) are an innovative class of unsupervised neural networks that have significantly impacted the field of artificial intelligence (AI). They consist of two components: a Generator that learns to produce samples resembling the training data, and a Discriminator that learns to distinguish real samples from generated ones. The two networks are trained against each other, so progress by one pushes the other to improve. Ideally, this competitive yet symbiotic process reaches an equilibrium in which the Discriminator can no longer tell real from generated samples. In practice, however, training can oscillate instead of settling. This interplay of creative generation and critical assessment is what makes adversarial learning such a versatile tool in AI.</p>
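<p>For reference, <span class="citation" data-cites="goodfellow2014generative">(Goodfellow et al. 2014)</span> formalize this adversarial game as a two-player minimax problem over a value function <span class="math inline">\(V(D, G)\)</span>. Here <span class="math inline">\(p_z\)</span> is the noise prior fed to the Generator:</p>
<p><span class="math display">\[
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
\]</span></p>
<p>For a fixed Generator with sample distribution <span class="math inline">\(p_g\)</span>, the optimal Discriminator is <span class="math inline">\(D^*(x) = p_{\text{data}}(x) / \left(p_{\text{data}}(x) + p_g(x)\right)\)</span>. The global optimum is reached at <span class="math inline">\(p_g = p_{\text{data}}\)</span>, where <span class="math inline">\(D^* = 1/2\)</span> everywhere and <span class="math inline">\(V = -\log 4\)</span>. Maximizing <span class="math inline">\(V\)</span> over <span class="math inline">\(D\)</span> is equivalent to minimizing binary cross-entropy with label 1 for real samples and 0 for generated ones. This is why a binary cross-entropy criterion such as <code>nn.BCEWithLogitsLoss()</code> appears in the training code.</p>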
<p>In this post, I explored the original GAN, often referred to as the “vanilla” GAN, along with its convolutional variant, DCGAN. My goal was to understand the basic mechanics of GANs. Meanwhile, others have pushed this technology much further and applied it to a wide range of new areas, for example:</p>
<ul>
<li><a href="https://machinelearningmastery.com/impressive-applications-of-generative-adversarial-networks/">18 Impressive Applications of Generative Adversarial Networks (GANs)</a></li>
</ul>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Tip
</div>
</div>
<div class="callout-body-container callout-body">
<p><strong>TODO for future refactoring</strong>: Remove the final <code>nn.Sigmoid()</code> from both Discriminator classes so that they output raw logits, and use <code>nn.BCEWithLogitsLoss()</code> instead of <code>nn.BCELoss()</code>. <code>nn.BCEWithLogitsLoss()</code> fuses the Sigmoid activation and the binary cross-entropy loss into a single, numerically stable operation. PyTorch recommends this over a separate Sigmoid followed by <code>nn.BCELoss()</code>.</p>
</div>
</div>
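<p>A small sketch of the numerical issue behind this TODO. Take a confident but wrong logit of 200: in float32, the Sigmoid saturates to exactly 1.0, so <code>nn.BCELoss()</code> would need log(0). PyTorch clamps the log outputs of <code>nn.BCELoss()</code> at -100, which caps the loss at 100. <code>nn.BCEWithLogitsLoss()</code> works on the logit directly and returns the exact value of 200. The logit and target values are arbitrary illustrations.</p>

```python
import torch
from torch import nn

logits = torch.tensor([200.0])  # very confident "real" prediction...
target = torch.tensor([0.0])    # ...for a sample that is actually fake

# Separate Sigmoid + BCELoss: sigmoid(200) rounds to 1.0, log(1 - 1.0) is clamped
naive = nn.BCELoss()(torch.sigmoid(logits), target)
# Fused version: computed from the logit with the log-sum-exp trick
stable = nn.BCEWithLogitsLoss()(logits, target)
print(naive.item(), stable.item())

# For moderate logits, both formulations agree
x = torch.tensor([-2.0, 0.5, 3.0])
y = torch.tensor([0.0, 1.0, 1.0])
agree = torch.allclose(nn.BCELoss()(torch.sigmoid(x), y), nn.BCEWithLogitsLoss()(x, y), atol=1e-6)
print(agree)
```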



</section>

<div id="quarto-appendix" class="default"><section class="quarto-appendix-contents" id="quarto-bibliography"><h2 class="anchored quarto-appendix-heading">References</h2><div id="refs" class="references csl-bib-body hanging-indent">
<div id="ref-goodfellow2014generative" class="csl-entry">
Goodfellow, Ian J., Jean Pouget-Abadie, Mehdi Mirza, et al. 2014. <em>Generative Adversarial Networks</em>. <a href="https://arxiv.org/abs/1406.2661">https://arxiv.org/abs/1406.2661</a>.
</div>
</div></section><section class="quarto-appendix-contents" id="quarto-reuse"><h2 class="anchored quarto-appendix-heading">Reuse</h2><div class="quarto-appendix-contents"><div><a rel="license" href="https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en_us">CC BY-NC-SA 4.0</a></div></div></section></div> ]]></description>
  <category>pytorch</category>
  <category>GAN</category>
  <guid>https://gcerar.github.io/posts/2023-10-10-vanilla-GANs/</guid>
  <pubDate>Mon, 09 Oct 2023 22:00:00 GMT</pubDate>
  <media:content url="https://gcerar.github.io/posts/2023-10-10-vanilla-GANs/featured.webp" medium="image" type="image/webp"/>
</item>
<item>
  <title>Neural Style Transfer</title>
  <dc:creator>Gregor Cerar</dc:creator>
  <link>https://gcerar.github.io/posts/2023-09-15-neural-style-transfer/</link>
  <description><![CDATA[ 





<section id="introduction" class="level1">
<h1>Introduction</h1>
<p><a href="https://en.wikipedia.org/wiki/Neural_style_transfer">Neural Style Transfer</a> (NST) is a deep learning technique that combines the <strong>content</strong> of one image with the <strong>style</strong> of another, like giving your photo a Van Gogh-esque makeover.</p>
<p>NST uses a pretrained convolutional neural network to extract features from both images. It then iteratively optimizes a new image so that it keeps the structure of the content image while adopting the colors and textures of the style image. The technique quickly became popular thanks to its striking results. It was adopted by various apps and platforms and showcased deep learning’s strength in image transformation.</p>
<p>Originally introduced in “<a href="https://arxiv.org/abs/1508.06576">A Neural Algorithm of Artistic Style</a>” <span class="citation" data-cites="gatys2015neural">(Gatys et al. 2015)</span>, this method transfers artistic style between images. Eager to learn how it works, I implemented the original approach from scratch and present a few cherry-picked examples of transformed images.</p>
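<p>In the original paper, an image's content is represented by the CNN feature maps themselves. Its style is represented by the correlations between feature channels, collected in a Gram matrix. Below is a minimal sketch of both losses. It operates on feature tensors assumed to come from a pretrained network such as VGG, and omits the feature extraction itself. The 1/(h*w) normalization is one common choice, not necessarily the exact scaling used by Gatys et al. or later in this post.</p>

```python
import torch
from torch.nn import functional as F

def gram_matrix(features: torch.Tensor) -> torch.Tensor:
    # features: (batch, channels, height, width) activations from one CNN layer
    b, c, h, w = features.shape
    flat = features.reshape(b, c, h * w)
    # Inner products between every pair of channels, averaged over positions
    return flat @ flat.transpose(1, 2) / (h * w)

def content_loss(gen_feats: torch.Tensor, content_feats: torch.Tensor) -> torch.Tensor:
    # Content: match the feature maps directly, so spatial layout is preserved
    return F.mse_loss(gen_feats, content_feats)

def style_loss(gen_feats: torch.Tensor, style_feats: torch.Tensor) -> torch.Tensor:
    # Style: match channel correlations, which discard where features occur
    return F.mse_loss(gram_matrix(gen_feats), gram_matrix(style_feats))

feats = torch.randn(1, 64, 32, 32)
print(gram_matrix(feats).shape)  # one (channels x channels) matrix per image
```

<p>Because the Gram matrix sums over all spatial positions, two images can have matching style statistics even when their objects appear in completely different places. That is what allows style to be separated from content.</p>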
</section>
<section id="prerequisites" class="level1">
<h1>Prerequisites</h1>
<p>Before we get started, we need to install <a href="https://numpy.org/">NumPy</a>, <a href="https://matplotlib.org/">Matplotlib</a>, the <a href="https://pytorch.org/">PyTorch</a> deep learning framework, and the <a href="https://pytorch.org/vision/stable/index.html">Torchvision</a> library, as well as tqdm for progress bars.</p>
<div id="cell-3" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> collections.abc <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> Iterable, Sequence</span>
<span id="cb1-2"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> pathlib <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> Path</span>
<span id="cb1-3"></span>
<span id="cb1-4"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> numpy <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> np</span>
<span id="cb1-5"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> matplotlib <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> pyplot <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> plt</span>
<span id="cb1-6"></span>
<span id="cb1-7"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span>config InlineBackend.figure_formats <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'retina'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'png'</span>}</span>
<span id="cb1-8"></span>
<span id="cb1-9"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> torch</span>
<span id="cb1-10"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> torch <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> Tensor, nn, optim</span>
<span id="cb1-11"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> torch.nn <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> functional <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> F</span>
<span id="cb1-12"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> torchvision <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> models</span>
<span id="cb1-13"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> torchvision.io <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> decode_image</span>
<span id="cb1-14"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> torchvision.transforms <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> functional <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> VF</span>
<span id="cb1-15"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> torchvision.transforms <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> v2 <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> T</span>
<span id="cb1-16"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> torchvision.utils <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> make_grid</span>
<span id="cb1-17"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> tqdm <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> tqdm</span>
<span id="cb1-18"></span>
<span id="cb1-19"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Random seed for reproducibility</span></span>
<span id="cb1-20">SEED <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">42</span></span>
<span id="cb1-21"></span>
<span id="cb1-22"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Size of the output image</span></span>
<span id="cb1-23">IMG_SIZE <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">512</span></span></code></pre></div></div>
</div>
<p>Although it is possible to run neural networks on a CPU, compute accelerators such as GPUs perform the transformation much faster. Here, I use my NVIDIA RTX 3090 and take advantage of its tensor cores and the reduced-precision <a href="https://en.wikipedia.org/wiki/Bfloat16_floating-point_format">bfloat16</a> data type for even faster transformation. The code below falls back to the CPU when CUDA is unavailable and enables bfloat16 only when the GPU supports it.</p>
<div id="cell-6" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1">AMP_ENABLED <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span></span>
<span id="cb2-2"></span>
<span id="cb2-3">device <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.device(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"cpu"</span>)</span>
<span id="cb2-4"></span>
<span id="cb2-5"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> torch.cuda.is_available():</span>
<span id="cb2-6">    device <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.device(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"cuda"</span>)</span>
<span id="cb2-7"></span>
<span id="cb2-8">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> torch.cuda.is_bf16_supported():</span>
<span id="cb2-9">        AMP_ENABLED <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span></span></code></pre></div></div>
</div>
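<p>Setting <code>AMP_ENABLED</code> on its own does not change any computation. The flag has to be consumed somewhere, for example by wrapping the forward pass in an autocast region. Below is a minimal sketch of one way to do that with <code>torch.autocast</code>. The small convolutional network and the random image are hypothetical stand-ins for the VGG19 feature extractor and the real inputs used later in the post.</p>

```python
import torch
from torch import nn

# Same logic as above: use CUDA if present, enable bf16 autocast only if supported
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
AMP_ENABLED = device.type == "cuda" and torch.cuda.is_bf16_supported()

# Hypothetical stand-in for the feature extractor (VGG19 is set up later in the post)
net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU()).to(device)
image = torch.rand(1, 3, 64, 64, device=device)

# Inside the autocast region, eligible ops such as convolutions run in bfloat16;
# with enabled=False the context manager does nothing and everything stays float32.
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=AMP_ENABLED):
    features = net(image)

print(features.dtype)  # bfloat16 when autocast is active, float32 otherwise
```

<p>Keeping the flag as a plain boolean lets the same code run unchanged on machines without bf16 support.</p>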
</section>
<section id="implementation" class="level1">
<h1>Implementation</h1>
<div id="fig-architecture" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-architecture-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<img src="https://gcerar.github.io/posts/2023-09-15-neural-style-transfer/figures/neural-style-transfer.png" class="img-fluid figure-img">
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig" id="fig-architecture-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;1: The Neural Style Transfer framework introduced by Gatys <em>et al.</em> distinguishes style and content features from designated layers.
</figcaption>
</figure>
</div>
<p>Implementing NST was initially confusing because it does not follow the typical deep learning boilerplate: the network’s weights stay fixed, and the input image is what gets optimized. In the following sections, I’ll go through the implementation step by step, often referring back to Figure&nbsp;1. The steps are as follows:</p>
<ul>
<li>Prepare the content, style, and target images.</li>
<li>Prepare a pre-trained VGG neural network and prevent changes to its weights.</li>
<li>Define three loss terms: content, style, and total variation.</li>
<li>Use the neural network to extract features in each forward pass, then backpropagate the loss to update the pixels of the target image. The network’s weights stay unchanged throughout.</li>
<li>Iterate through this process.</li>
</ul>
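<p>To make this unusual setup concrete, here is a toy sketch of the loop described in the steps above: the network is frozen, and the only tensor the optimizer updates is the image itself. The two-layer network, the random content image, the single content term, and the Adam settings are hypothetical simplifications, not the configuration used in this post.</p>

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for the frozen VGG19 feature extractor
net = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.ReLU())
net = net.eval().requires_grad_(False)  # frozen: parameters receive no gradients
weights_before = [p.clone() for p in net.parameters()]

content = torch.rand(1, 3, 32, 32)
content_features = net(content)  # computed once and reused in every step

# The image being optimized is the only trainable tensor
target = torch.rand(1, 3, 32, 32).requires_grad_(True)
optimizer = torch.optim.Adam([target], lr=0.05)

losses = []
for step in range(100):
    optimizer.zero_grad()
    loss = F.mse_loss(net(target), content_features)  # content term only, for brevity
    loss.backward()   # gradients flow back to the pixels of `target`
    optimizer.step()  # updates `target`; the network is never touched
    losses.append(loss.item())

print(losses[0], losses[-1])
```

<p>The key design choice is handing <code>[target]</code>, rather than <code>net.parameters()</code>, to the optimizer, combined with <code>requires_grad_(False)</code> on the network.</p>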
<div id="cell-8" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Per-layer style weights w_l (Gatys et al. used equal weights of 1/5 per layer)</span></span>
<span id="cb3-2">STYLE_LAYERS_DEFAULT <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {</span>
<span id="cb3-3">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"conv1_1"</span>: <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.75</span>,</span>
<span id="cb3-4">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"conv2_1"</span>: <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>,</span>
<span id="cb3-5">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"conv3_1"</span>: <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.2</span>,</span>
<span id="cb3-6">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"conv4_1"</span>: <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.2</span>,</span>
<span id="cb3-7">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"conv5_1"</span>: <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.2</span>,</span>
<span id="cb3-8">}</span>
<span id="cb3-9"></span>
<span id="cb3-10">CONTENT_LAYERS_DEFAULT <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"conv5_2"</span>,)</span>
<span id="cb3-11"></span>
<span id="cb3-12">CONTENT_WEIGHT <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># "alpha" in the literature (default: 8)</span></span>
<span id="cb3-13">STYLE_WEIGHT <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">70</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># "beta" in the literature (default: 70)</span></span>
<span id="cb3-14">TV_WEIGHT <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># "gamma" in the literature (default: 10)</span></span>
<span id="cb3-15"></span>
<span id="cb3-16"></span>
<span id="cb3-17">LEARNING_RATE <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.004</span></span>
<span id="cb3-18">N_EPOCHS <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5_000</span></span></code></pre></div></div>
</div>
<section id="loss-metrics" class="level2">
<h2 class="anchored" data-anchor-id="loss-metrics">Loss metrics</h2>
<p>To effectively implement Neural Style Transfer, we need to quantify how well the generated image matches both the <strong>content</strong> and <strong>style</strong> of our source images. This is done using loss metrics. Let’s delve into the specifics of these metrics and how they drive the NST process.</p>
<section id="content-loss-metric" class="level3">
<h3 class="anchored" data-anchor-id="content-loss-metric">Content loss metric</h3>
<p>Content loss is the squared Euclidean distance (implemented below as the mean squared error, which differs only by a constant factor) between the intermediate feature representations <img src="https://latex.codecogs.com/png.latex?F%5El"> and <img src="https://latex.codecogs.com/png.latex?P%5El"> of the generated image <img src="https://latex.codecogs.com/png.latex?%5Cvec%7Bx%7D"> and the content image <img src="https://latex.codecogs.com/png.latex?%5Cvec%7Bp%7D"> at layer <img src="https://latex.codecogs.com/png.latex?l">.</p>
<p>More precisely, an image <img src="https://latex.codecogs.com/png.latex?%5Cvec%7Bx%7D"> is encoded in each layer of the CNN by the filter responses to that image. A layer with <img src="https://latex.codecogs.com/png.latex?N_l"> distinct filters has <img src="https://latex.codecogs.com/png.latex?N_l"> feature maps of size <img src="https://latex.codecogs.com/png.latex?M_l">, where <img src="https://latex.codecogs.com/png.latex?M_l"> is the height times the width of the feature map. So the responses in layer <img src="https://latex.codecogs.com/png.latex?l"> can be stored in a matrix <img src="https://latex.codecogs.com/png.latex?F%5El%20%5Cin%20%5Cmathbb%7BR%7D%5E%7BN_l%20%5Ctimes%20M_l%7D">, where <img src="https://latex.codecogs.com/png.latex?F_%7Bij%7D%5E%7Bl%7D"> is the activation of the <img src="https://latex.codecogs.com/png.latex?i%5E%7Bth%7D"> filter at position <img src="https://latex.codecogs.com/png.latex?j"> in layer <img src="https://latex.codecogs.com/png.latex?l">.</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Cmathcal%7BL%7D_%7Bcontent%7D(%5Cvec%7Bp%7D,%20%5Cvec%7Bx%7D,%20l)%20=%20%5Cfrac%7B1%7D%7B2%7D%5Csum_%7Bi,j%7D%20(F%5E%7Bl%7D_%7Bij%7D%20-%20P%5E%7Bl%7D_%7Bij%7D)%5E2%0A"></p>
<div id="cell-11" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> content_loss_func(target_features: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">dict</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>, Tensor], precomputed_content_features: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">dict</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>, Tensor]) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> Tensor:</span>
<span id="cb4-2">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Calculate the content loss over the given layers."""</span></span>
<span id="cb4-3"></span>
<span id="cb4-4">    device <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">next</span>(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">iter</span>(target_features.values())).device</span>
<span id="cb4-5">    content_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.tensor(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.0</span>, device<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>device)</span>
<span id="cb4-6"></span>
<span id="cb4-7">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> layer <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> precomputed_content_features:</span>
<span id="cb4-8">        target_feature <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> target_features[layer]</span>
<span id="cb4-9">        content_feature <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> precomputed_content_features[layer]</span>
<span id="cb4-10"></span>
<span id="cb4-11">        content_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> F.mse_loss(target_feature, content_feature)</span>
<span id="cb4-12"></span>
<span id="cb4-13">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> content_loss</span></code></pre></div></div>
</div>
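<p><code>F.mse_loss</code> averages the squared differences, whereas the equation above takes half of their sum. The quick check below, on hypothetical random feature maps for a single layer, shows that the two differ only by the constant factor <code>numel / 2</code> (that is, <img src="https://latex.codecogs.com/png.latex?N_l%20M_l%20/%202"> for one image), so the mean-based version is equivalent up to a rescaling that <code>CONTENT_WEIGHT</code> can absorb.</p>

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical feature maps for one layer: batch 1, N_l = 16 filters, 8x8 spatial (M_l = 64)
F_l = torch.randn(1, 16, 8, 8)  # generated image x
P_l = torch.randn(1, 16, 8, 8)  # content image p

# Paper's definition: half of the sum of squared differences
paper_loss = 0.5 * ((F_l - P_l) ** 2).sum()

# F.mse_loss averages instead of summing (reduction="mean" by default)
mse = F.mse_loss(F_l, P_l)

# The two agree once the mean is rescaled by numel / 2
print(paper_loss.item(), (mse * F_l.numel() / 2).item())
```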
</section>
<section id="style-loss" class="level3">
<h3 class="anchored" data-anchor-id="style-loss">Style loss</h3>
<p>The style loss is more involved than the content loss. We compute it by comparing the Gram matrices of the feature maps of the style image and the generated image.</p>
<p>First, let’s understand the <a href="https://en.wikipedia.org/wiki/Gram_matrix">Gram matrix</a>. Given the feature map <img src="https://latex.codecogs.com/png.latex?F"> of size <img src="https://latex.codecogs.com/png.latex?C%20%5Ctimes%20(H%20%5Ctimes%20W)">, where <img src="https://latex.codecogs.com/png.latex?C"> is the number of channels and <img src="https://latex.codecogs.com/png.latex?H%20%5Ctimes%20W"> are the spatial dimensions, the Gram matrix <img src="https://latex.codecogs.com/png.latex?G"> is of size <img src="https://latex.codecogs.com/png.latex?C%20%5Ctimes%20C"> and is computed as</p>
<p><img src="https://latex.codecogs.com/png.latex?%0AG%5El_%7Bij%7D%20=%20%5Csum_k%20F%5El_%7Bik%7D%20F%5El_%7Bjk%7D%0A"></p>
<p>where <img src="https://latex.codecogs.com/png.latex?G_%7Bij%7D"> is the inner product between vectorized feature maps <img src="https://latex.codecogs.com/png.latex?i"> and <img src="https://latex.codecogs.com/png.latex?j">. This results in a matrix that captures the correlation between different feature maps and, thus, the style information.</p>
<div id="cell-13" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> gram_matrix(tensor: Tensor) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> Tensor:</span>
<span id="cb5-2">    (b, c, h, w) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> tensor.size()</span>
<span id="cb5-3"></span>
<span id="cb5-4">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># reshape into ((B*C) x (H*W)); this is C x (H*W) for a single image (B = 1)</span></span>
<span id="cb5-5">    features <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> tensor.view(b <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> c, h <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> w)</span>
<span id="cb5-6"></span>
<span id="cb5-7">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># compute the gram product</span></span>
<span id="cb5-8">    gram <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.mm(features, features.t())</span>
<span id="cb5-9"></span>
<span id="cb5-10">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> gram</span></code></pre></div></div>
</div>
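<p>Because <code>gram_matrix</code> flattens the batch and channel dimensions together, it is only correct for a single image: with several images in a batch, the result would also contain inner products between feature maps of different images. If batching were ever needed, a variant along the lines of the hypothetical <code>batched_gram_matrix</code> below keeps one Gram matrix per image. It is a sketch, not part of the original implementation.</p>

```python
import torch

torch.manual_seed(0)

def batched_gram_matrix(tensor):
    """Per-image Gram matrices: (B, C, H, W) to (B, C, C)."""
    b, c, h, w = tensor.size()
    features = tensor.reshape(b, c, h * w)  # each row is one vectorized feature map
    return torch.bmm(features, features.transpose(1, 2))  # G_ij = sum_k F_ik F_jk

x = torch.randn(2, 3, 4, 4)
grams = batched_gram_matrix(x)

# Entry (i, j) of the first image equals the inner product of its feature maps i and j
i, j = 0, 2
manual = (x[0, i].flatten() * x[0, j].flatten()).sum()
print(grams.shape, torch.allclose(grams[0, i, j], manual, atol=1e-5))
```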
<p>The style loss between the Gram matrix of the generated image <img src="https://latex.codecogs.com/png.latex?G"> and that of the style image <img src="https://latex.codecogs.com/png.latex?A"> (at a specific layer <img src="https://latex.codecogs.com/png.latex?l">) is:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0AE_l%20=%20%5Cfrac%7B1%7D%7B4%20N%5E%7B2%7D_%7Bl%7D%20M%5E%7B2%7D_%7Bl%7D%7D%20%5Csum_%7Bi,j%7D(G%5El_%7Bij%7D%20-%20A%5El_%7Bij%7D)%5E2%0A"></p>
<p>where <img src="https://latex.codecogs.com/png.latex?E_l"> is the style loss for layer <img src="https://latex.codecogs.com/png.latex?l">, and <img src="https://latex.codecogs.com/png.latex?N_l"> and <img src="https://latex.codecogs.com/png.latex?M_l"> are the number of channels and the spatial size (height times width) of the feature representation at layer <img src="https://latex.codecogs.com/png.latex?l">, respectively. <img src="https://latex.codecogs.com/png.latex?G_%7Bij%7D%5El"> and <img src="https://latex.codecogs.com/png.latex?A_%7Bij%7D%5El"> are the entries of the Gram matrices of the generated image <img src="https://latex.codecogs.com/png.latex?%5Cvec%7Bx%7D"> and the style image <img src="https://latex.codecogs.com/png.latex?%5Cvec%7Ba%7D">, respectively.</p>
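<p>The total style loss is the weighted sum of <img src="https://latex.codecogs.com/png.latex?E_l"> over the selected style layers, with per-layer weights <img src="https://latex.codecogs.com/png.latex?w_l">:</p>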
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Cmathcal%7BL%7D_%7Bstyle%7D(%5Cvec%7Ba%7D,%20%5Cvec%7Bx%7D)%20=%20%5Csum_%7Bl=0%7D%5E%7BL%7D%20w_l%20E_l%0A"></p>
<div id="cell-15" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> style_loss_func(</span>
<span id="cb6-2">    target_features: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">dict</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>, Tensor], style_features: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">dict</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>, Tensor], precomputed_style_grams: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">dict</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>, Tensor]</span>
<span id="cb6-3">) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> Tensor:</span>
<span id="cb6-4">    device <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">next</span>(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">iter</span>(target_features.values())).device</span>
<span id="cb6-5">    style_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.tensor(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.0</span>, device<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>device)</span>
<span id="cb6-6"></span>
<span id="cb6-7">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> layer <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> style_features:</span>
<span id="cb6-8">        target_feature <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> target_features[layer]</span>
<span id="cb6-9">        target_gram <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> gram_matrix(target_feature)</span>
<span id="cb6-10"></span>
<span id="cb6-11">        style_gram <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> precomputed_style_grams[layer]</span>
<span id="cb6-12"></span>
<span id="cb6-13">        _, c, h, w <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> target_feature.shape</span>
<span id="cb6-14"></span>
<span id="cb6-15">        weight <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> STYLE_LAYERS_DEFAULT[layer]</span>
<span id="cb6-16">        layer_style_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> weight <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> F.mse_loss(target_gram, style_gram) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> (c <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> h <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> w)</span>
<span id="cb6-17">        style_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> layer_style_loss</span>
<span id="cb6-18"></span>
<span id="cb6-19">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> style_loss</span></code></pre></div></div>
</div>
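<p><code>style_loss_func</code> normalizes each layer differently from <img src="https://latex.codecogs.com/png.latex?E_l">: <code>F.mse_loss</code> divides the summed squared difference by <img src="https://latex.codecogs.com/png.latex?N_l%5E2"> and the code then divides by <img src="https://latex.codecogs.com/png.latex?N_l%20M_l">, whereas the equation divides by <img src="https://latex.codecogs.com/png.latex?4%20N_l%5E2%20M_l%5E2">. The sketch below (helper names are hypothetical) computes both for one random layer and confirms that they differ by a factor of <img src="https://latex.codecogs.com/png.latex?4%20M_l%20/%20N_l">. Since that factor changes from layer to layer, the weights in <code>STYLE_LAYERS_DEFAULT</code> are not directly comparable to the paper’s <img src="https://latex.codecogs.com/png.latex?w_l">.</p>

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def gram(features):
    # (1, C, H, W) to (C, C); assumes a single image, like gram_matrix above
    _, c, h, w = features.shape
    flat = features.reshape(c, h * w)
    return flat @ flat.t()

def paper_layer_style_loss(target_feature, style_feature):
    """E_l = sum_ij (G_ij - A_ij)^2 / (4 N_l^2 M_l^2), following the equation above."""
    _, n_l, h, w = target_feature.shape
    m_l = h * w
    G = gram(target_feature)
    A = gram(style_feature)
    return ((G - A) ** 2).sum() / (4 * n_l**2 * m_l**2)

def code_layer_style_loss(target_feature, style_feature):
    """Normalization used in style_loss_func, without the per-layer weight."""
    _, c, h, w = target_feature.shape
    return F.mse_loss(gram(target_feature), gram(style_feature)) / (c * h * w)

# Hypothetical features for one layer: N_l = 8 channels, 4x4 spatial (M_l = 16)
target_feature = torch.randn(1, 8, 4, 4)
style_feature = torch.randn(1, 8, 4, 4)

paper = paper_layer_style_loss(target_feature, style_feature)
code = code_layer_style_loss(target_feature, style_feature)
n_l, m_l = 8, 16
print(code.item() / paper.item(), 4 * m_l / n_l)
```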
</section>
<section id="total-variation-loss" class="level3">
<h3 class="anchored" data-anchor-id="total-variation-loss">Total Variation Loss</h3>
<p>Total Variation (TV) loss, also known as Total Variation Regularization, is commonly added to the Neural Style Transfer objective to encourage spatial smoothness in the generated image. Without it, the output might exhibit noise or oscillations, particularly in regions where the content and style objectives don’t offer much guidance.</p>
<p>Given an image <img src="https://latex.codecogs.com/png.latex?%5Cvec%7Bx%7D"> of size <img src="https://latex.codecogs.com/png.latex?H%20%5Ctimes%20W%20%5Ctimes%20C"> (height, width, channels), the Total Variation loss is defined as the sum of the absolute differences between neighboring pixel values:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Cmathcal%7BL%7D_%7BTV%7D(%5Cvec%7Bx%7D)%20=%20%5Csum_%7Bi,j%7D%20(%7Cx_%7Bi,j+1%7D%20-%20x_%7Bi,j%7D%7C%20+%20%7Cx_%7Bi+1,j%7D%20-%20x_%7Bi,j%7D%7C)%0A"></p>
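<p>where <img src="https://latex.codecogs.com/png.latex?x_%7Bi,j%7D"> is the pixel value at position <img src="https://latex.codecogs.com/png.latex?(i,j)">, and the sum also runs over all channels. The implementation below uses <code>F.l1_loss</code>, which averages the absolute differences instead of summing them; this only rescales the loss.</p>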
<p>In simple terms, this loss penalizes abrupt changes in value between neighboring pixels. Minimizing it makes the generated image smoother, reducing artifacts and unwanted noise. Combined with the content and style losses, the TV loss helps the result capture the content and style of the source images while still looking visually coherent.</p>
<div id="cell-17" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> total_variance_loss_func(target: Tensor) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> Tensor:</span>
<span id="cb7-2">    tv_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> F.l1_loss(target[:, :, :, :<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>], target[:, :, :, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>:]) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> F.l1_loss(</span>
<span id="cb7-3">        target[:, :, :<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, :], target[:, :, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>:, :]</span>
<span id="cb7-4">    )</span>
<span id="cb7-5"></span>
<span id="cb7-6">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> tv_loss</span></code></pre></div></div>
</div>
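<p>A quick way to see what the TV term rewards is to evaluate it on a few synthetic images. The sketch below reimplements the same computation as <code>total_variance_loss_func</code> under the hypothetical name <code>tv_loss</code>: a constant image scores exactly zero, a smooth horizontal ramp scores low, and independent random pixels score far higher.</p>

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def tv_loss(target):
    # Same computation as total_variance_loss_func: mean absolute difference between
    # horizontal neighbours plus mean absolute difference between vertical neighbours
    horizontal = F.l1_loss(target[:, :, :, :-1], target[:, :, :, 1:])
    vertical = F.l1_loss(target[:, :, :-1, :], target[:, :, 1:, :])
    return horizontal + vertical

flat = torch.full((1, 3, 32, 32), 0.5)                # constant image: no variation
ramp = torch.linspace(0, 1, 32).expand(1, 3, 32, 32)  # smooth left-to-right gradient
noise = torch.rand(1, 3, 32, 32)                      # independent random pixels

for name, img in [("flat", flat), ("ramp", ramp), ("noise", noise)]:
    print(name, tv_loss(img).item())
```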
</section>
<section id="total-loss" class="level3">
<h3 class="anchored" data-anchor-id="total-loss">Total Loss</h3>
<p>The total loss combines three loss metric components, each targeting a specific aspect of the image generation process. Let’s recap the components:</p>
<ol type="1">
<li><strong>Content Loss</strong>: Keeps the generated image close to the content image in its high-level features, such as objects and layout.</li>
<li><strong>Style Loss</strong>: Ensures the generated image captures the stylistic features of the style image.</li>
<li><strong>Total Variation Loss</strong>: Encourages spatial smoothness in the generated image, reducing artifacts and noise.</li>
</ol>
<p>Given the above components, the total loss <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BL%7D_%7Btotal%7D"> for Neural Style Transfer can be formulated as:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Cmathcal%7BL%7D_%7Btotal%7D(%5Cvec%7Bp%7D,%5Cvec%7Ba%7D,%5Cvec%7Bx%7D)%20=%20%5Calpha%5Cmathcal%7BL%7D_%7Bcontent%7D(%5Cvec%7Bp%7D,%5Cvec%7Bx%7D)%20+%20%5Cbeta%5Cmathcal%7BL%7D_%7Bstyle%7D(%5Cvec%7Ba%7D,%5Cvec%7Bx%7D)%20+%20%5Cgamma%5Cmathcal%7BL%7D_%7BTV%7D(%5Cvec%7Bx%7D)%0A"></p>
<p><img src="https://latex.codecogs.com/png.latex?%5Calpha">, <img src="https://latex.codecogs.com/png.latex?%5Cbeta">, and <img src="https://latex.codecogs.com/png.latex?%5Cgamma"> are weight factors that determine the relative importance of the content, style, and total variation losses, respectively. By adjusting these weights, one can control the balance between content preservation, style transfer intensity, and the smoothness of the generated image. The algorithm iteratively updates the pixels of the generated image to minimize the total loss.</p>
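<p>The raw loss terms can have very different magnitudes, so when tuning <img src="https://latex.codecogs.com/png.latex?%5Calpha">, <img src="https://latex.codecogs.com/png.latex?%5Cbeta">, and <img src="https://latex.codecogs.com/png.latex?%5Cgamma"> it can help to log each weighted term separately. A minimal sketch of such a helper follows; the name <code>weighted_total_loss</code> and the raw loss values are hypothetical, while the weights match <code>CONTENT_WEIGHT</code>, <code>STYLE_WEIGHT</code>, and <code>TV_WEIGHT</code> defined earlier.</p>

```python
import torch

def weighted_total_loss(content_loss, style_loss, tv_loss, alpha, beta, gamma):
    """L_total = alpha * L_content + beta * L_style + gamma * L_TV.

    Also returns each weighted term as a float, so their relative sizes can be logged.
    """
    terms = {
        "content": alpha * content_loss,
        "style": beta * style_loss,
        "tv": gamma * tv_loss,
    }
    total = terms["content"] + terms["style"] + terms["tv"]
    return total, {name: value.item() for name, value in terms.items()}

# Hypothetical raw loss values, weighted like CONTENT_WEIGHT, STYLE_WEIGHT, TV_WEIGHT
total, parts = weighted_total_loss(
    torch.tensor(0.9), torch.tensor(0.02), torch.tensor(0.1), alpha=8, beta=70, gamma=10
)
print(total.item(), parts)
```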
</section>
</section>
<section id="input-preparation" class="level2">
<h2 class="anchored" data-anchor-id="input-preparation">Input preparation</h2>
<p>Here we specify the paths to the content and style images:</p>
<div id="cell-21" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb8" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1">content_path <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"./bridge.jpg"</span></span>
<span id="cb8-2">style_path <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"./walking-in-the-rain.jpg"</span></span></code></pre></div></div>
</div>
</section>
<section id="neural-style-transfer-process" class="level2">
<h2 class="anchored" data-anchor-id="neural-style-transfer-process">Neural Style Transfer Process</h2>
<p>For feature extraction, we’ll use <a href="https://arxiv.org/abs/1409.1556">VGG19</a> pre-trained on <a href="https://www.image-net.org/">ImageNet</a>, as the original authors did. We put the model in evaluation mode and disable gradients for its parameters with <code>requires_grad_(False)</code>, so VGG19 only extracts features and its weights are never updated. We also move the neural network (NN) to the chosen device, ideally a GPU, for better performance.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Note
</div>
</div>
<div class="callout-body-container callout-body">
<p>An intriguing choice by Gatys <em>et al.</em> was to modify VGG19 by replacing max pooling with average pooling, which they found to give visually more appealing results. However, this poses a challenge: our NN was trained with <code>MaxPool2d</code> layers. Since the average of a pooling window never exceeds its maximum, swapping in average pooling shrinks the (non-negative, post-ReLU) activations that later layers receive. To counteract this, we’ve introduced a custom <code>ScaledAvgPool2d</code> that multiplies the pooled output by a constant factor.</p>
</div>
</div>
<div id="cell-23" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb9" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># We will use a frozen pre-trained VGG neural network for feature extraction.</span></span>
<span id="cb9-2"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># In the original paper, the authors used VGG19 (without batch normalization)</span></span>
<span id="cb9-3">model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> models.vgg19(weights<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>models.VGG19_Weights.IMAGENET1K_V1).features</span>
<span id="cb9-4"></span>
<span id="cb9-5"></span>
<span id="cb9-6"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Authors in the original paper suggested using AvgPool instead of MaxPool</span></span>
<span id="cb9-7"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># for more pleasing results. However, changing the pooling also affects</span></span>
<span id="cb9-8"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># activation, so the input needs to be scaled (can't find the original source).</span></span>
<span id="cb9-9"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">class</span> ScaledAvgPool2d(nn.Module):</span>
<span id="cb9-10">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, kernel_size, stride, padding<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, scale_factor<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">2.0</span>):</span>
<span id="cb9-11">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb9-12">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.avgpool <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.nn.AvgPool2d(kernel_size, stride, padding)</span>
<span id="cb9-13">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.scale_factor <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> scale_factor</span>
<span id="cb9-14"></span>
<span id="cb9-15">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x):</span>
<span id="cb9-16">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.avgpool(x) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.scale_factor</span>
<span id="cb9-17"></span>
<span id="cb9-18"></span>
<span id="cb9-19"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># (OPTIONAL) Replace max-pooling layers with custom average pooling layers</span></span>
<span id="cb9-20"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># for i, layer in enumerate(model):</span></span>
<span id="cb9-21"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">#   if isinstance(layer, torch.nn.MaxPool2d):</span></span>
<span id="cb9-22"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">#       model[i] = ScaledAvgPool2d(kernel_size=2, stride=2, padding=0)</span></span>
<span id="cb9-23"></span>
<span id="cb9-24">model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">eval</span>().requires_grad_(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>).to(device)</span></code></pre></div></div>
</div>
<p>The pretrained VGG model was trained on ImageNet images normalized with per-channel statistics, so we normalize our inputs the same way to get meaningful feature maps. The images look altered after normalization, but we revert them to their original colors after the NST process. Next, we’ll transform the content and style images by:</p>
<ul>
<li>Loading them from storage.</li>
<li>Resizing while maintaining aspect ratio.</li>
<li>Converting to tensors.</li>
<li>Normalizing with the ImageNet per-channel mean and standard deviation.</li>
</ul>
<div id="cell-25" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb10" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><span class="co" style="color: #5E5E5E;
background-color: null;
# ImageNet normalization weights per channel</span></span>">
font-style: inherit;"># ImageNet per-channel mean and standard deviation (RGB order)</span></span>
<span id="cb10-2">IMAGENET_MEAN <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.485</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.456</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.406</span>)</span>
<span id="cb10-3">IMAGENET_STD <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.229</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.224</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.225</span>)</span>
<span id="cb10-4"></span>
<span id="cb10-5">transform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> T.Compose(</span>
<span id="cb10-6">    [</span>
<span id="cb10-7">        T.ToImage(),</span>
<span id="cb10-8">        T.Resize(IMG_SIZE),  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Shorter edge of the image will be matched to `IMG_SIZE`</span></span>
<span id="cb10-9">        T.ToDtype(torch.float32, scale<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>),</span>
<span id="cb10-10">        T.Normalize(IMAGENET_MEAN, IMAGENET_STD),</span>
<span id="cb10-11">    ]</span>
<span id="cb10-12">)</span>
<span id="cb10-13"></span>
<span id="cb10-14"></span>
<span id="cb10-15"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> load_image(path: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> Path) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> Tensor:</span>
<span id="cb10-16">    img <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> decode_image(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>(path))</span>
<span id="cb10-17"></span>
<span id="cb10-18">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Transform images into tensors</span></span>
<span id="cb10-19">    img: Tensor <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> transform(img)</span>
<span id="cb10-20"></span>
<span id="cb10-21">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Add dimension to imitate batch size equal to 1: (C,H,W) -&gt; (B,C,H,W)</span></span>
<span id="cb10-22">    img <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> img.unsqueeze(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>)</span>
<span id="cb10-23">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> img</span></code></pre></div></div>
</div>
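<p>The normalization and its later inversion are simple per-channel affine maps: <code>x_norm = (x - mean) / std</code> and <code>x = x_norm * std + mean</code>. As a hedged illustration, the plain-Python sketch below applies both to one hypothetical RGB pixel (values already scaled to [0, 1], as after <code>T.ToDtype(..., scale=True)</code>) to show that the round trip restores the original colors. In the post itself, <code>T.Normalize</code> and the <code>InverseNormalize</code> class defined later do the same on whole tensors.</p>

```python
# Minimal sketch: per-channel ImageNet normalization and its inverse,
# applied to a single hypothetical RGB pixel with values in [0, 1].
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize(pixel, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    # x_norm = (x - mean) / std, channel by channel
    return tuple((x - m) / s for x, m, s in zip(pixel, mean, std))


def denormalize(pixel_norm, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    # x = x_norm * std + mean, which undoes normalize()
    return tuple(x * s + m for x, s, m in zip(pixel_norm, std, mean))


pixel = (1.0, 0.5, 0.0)  # hypothetical orange-ish RGB pixel
pixel_norm = normalize(pixel)
restored = denormalize(pixel_norm)

print([round(v, 3) for v in pixel_norm])  # → [2.249, 0.196, -1.804]
# Largest round-trip error: only floating-point noise remains
print(max(abs(a - b) for a, b in zip(restored, pixel)))
```

<p>The normalized values fall well outside [0, 1], which is why normalized images look odd if displayed directly and must be denormalized before viewing.</p>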
<p>The following code prepares the content <img src="https://latex.codecogs.com/png.latex?%5Cvec%7Bp%7D">, style <img src="https://latex.codecogs.com/png.latex?%5Cvec%7Ba%7D">, and target <img src="https://latex.codecogs.com/png.latex?%5Cvec%7Bx%7D"> images. The target image starts as a clone of the content image, and we enable gradient computation on it because its pixels are what the optimization updates.</p>
<div id="cell-27" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># The "style" image from which we obtain style</span></span>
<span id="cb11-2">style <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> load_image(style_path).to(device)</span>
<span id="cb11-3"></span>
<span id="cb11-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># The "content" image on which we apply style</span></span>
<span id="cb11-5">content <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> load_image(content_path).to(device)</span>
<span id="cb11-6"></span>
<span id="cb11-7"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># The "target" image to store the outcome</span></span>
<span id="cb11-8">target <span class="op" style="color: #5E5E5E;
background-color: null;
=</span> content.clone().requires_grad_(">
font-style: inherit;">=</span> content.clone().to(device).requires_grad_(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span></code></pre></div></div>
</div>
<p>The function below retrieves feature maps from designated layers. As shown in Figure&nbsp;1:</p>
<ul>
<li>Content feature map comes from <code>relu5_2</code>.</li>
<li>Style feature maps are sourced from <code>relu1_1</code>, <code>relu2_1</code>, <code>relu3_1</code>, <code>relu4_1</code>, and <code>relu5_1</code>.</li>
</ul>
<div id="cell-29" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb12" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> get_features(image: Tensor, model: nn.Module, layers: Iterable[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">dict</span>[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>, Tensor]:</span>
<span id="cb12-2">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> layers <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">is</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb12-3">        layers <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">tuple</span>(STYLE_LAYERS_DEFAULT.keys()) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> CONTENT_LAYERS_DEFAULT</span>
<span id="cb12-4"></span>
<span id="cb12-5">    features <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {}</span>
<span id="cb12-6">    block_num <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span></span>
<span id="cb12-7">    conv_num <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span>
<span id="cb12-8"></span>
<span id="cb12-9">    x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> image</span>
<span id="cb12-10"></span>
<span id="cb12-11">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> layer <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> model.children():</span>
<span id="cb12-12">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> layer(x)</span>
<span id="cb12-13"></span>
<span id="cb12-14">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(layer, nn.Conv2d):</span>
<span id="cb12-15">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># produce layer name to find matching convolutions from the paper</span></span>
<span id="cb12-16">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># and store their output for further processing.</span></span>
<span id="cb12-17">            conv_num <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span></span>
<span id="cb12-18">            name <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"conv</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>block_num<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">_</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>conv_num<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span></span>
<span id="cb12-19">            <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> name <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> layers:</span>
<span id="cb12-20">                features[name] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> x</span>
<span id="cb12-21"></span>
<span id="cb12-22">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">elif</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(layer, nn.MaxPool2d <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> nn.AvgPool2d <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> ScaledAvgPool2d):</span>
<span id="cb12-23">            <span class="co" style="color: #5E5E5E;
background-color: null;
# In VGG, each block ends with max/avg pooling layer.</span></span>">
font-style: inherit;"># In VGG, each block ends with a pooling layer (max or average).</span></span>
<span id="cb12-24">            block_num <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span></span>
<span id="cb12-25">            conv_num <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span>
<span id="cb12-26"></span>
<span id="cb12-27">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">elif</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(layer, nn.BatchNorm2d <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> nn.ReLU):</span>
<span id="cb12-28">            <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">pass</span></span>
<span id="cb12-29"></span>
<span id="cb12-30">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span>:</span>
<span id="cb12-31">            <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">raise</span> <span class="pp" style="color: #AD0000;
background-color: null;
font-style: inherit;">Exception</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"Unknown layer: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>layer<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>)</span>
<span id="cb12-32"></span>
<span id="cb12-33">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> features</span></code></pre></div></div>
</div>
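<p>The two counters in <code>get_features</code> are what turn the anonymous children of the VGG feature extractor into the <code>convX_Y</code> names used in the paper. The sketch below replays that counting logic on a string stand-in for the layer sequence, assuming a VGG19 layout (2, 2, 4, 4, 4 convolutions per block, each block ending in a pooling layer), so the generated names can be checked without loading the model. The layer list and the <code>layer_names</code> helper are hypothetical and not part of torchvision.</p>

```python
# Minimal sketch of the naming scheme in get_features(), assuming a VGG19
# layout. "C" = convolution, "R" = ReLU, "P" = pooling (illustrative stand-ins).
VGG19_CONVS_PER_BLOCK = [2, 2, 4, 4, 4]

layers = []
for n_convs in VGG19_CONVS_PER_BLOCK:
    layers += ["C", "R"] * n_convs  # each conv is followed by a ReLU
    layers.append("P")  # each block ends with a pooling layer


def layer_names(layers):
    names = []
    block_num, conv_num = 1, 0
    for kind in layers:
        if kind == "C":
            conv_num += 1
            names.append(f"conv{block_num}_{conv_num}")
        elif kind == "P":
            # A pooling layer closes the block: start counting a new one
            block_num += 1
            conv_num = 0
    return names


names = layer_names(layers)
print(len(names))  # → 16
print(names[:3])  # → ['conv1_1', 'conv1_2', 'conv2_1']
```

<p>Any requested name that never appears in this sequence (for example, a typo in <code>STYLE_LAYERS_DEFAULT</code>) is simply never matched, so <code>get_features</code> would silently return fewer feature maps.</p>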
<p>Since the content and style images never change, we can precompute their feature maps and the style Gram matrices once instead of recomputing them in every iteration, which speeds up the NST process.</p>
<div id="cell-31" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb13" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Precompute content features, style features, and style gram matrices.</span></span>
<span id="cb13-2">content_features <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> get_features(content, model, CONTENT_LAYERS_DEFAULT)</span>
<span id="cb13-3">style_features <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> get_features(style, model, STYLE_LAYERS_DEFAULT)</span>
<span id="cb13-4"></span>
<span id="cb13-5">style_grams <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {layer: gram_matrix(style_features[layer]) <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> layer <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> style_features}</span></code></pre></div></div>
</div>
<p>Next, we use the <a href="https://pytorch.org/docs/stable/generated/torch.optim.Adam.html">Adam</a> optimizer and pass it only the target image <img src="https://latex.codecogs.com/png.latex?%5Cvec%7Bx%7D">, so the target is the only tensor being optimized.</p>
<div id="cell-33" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb14" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1">optimizer <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> optim.Adam([target], lr<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>LEARNING_RATE)</span></code></pre></div></div>
</div>
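<p>Passing only <code>target</code> to the optimizer means each <code>optimizer.step()</code> nudges the image pixels while the frozen VGG weights stay untouched. For intuition about what that step does, the sketch below implements the core Adam update in plain Python on a single scalar standing in for one pixel, minimizing the toy loss <code>(x - 3) ** 2</code>. It is a hedged illustration, not PyTorch’s implementation (which adds options such as weight decay and AMSGrad); the betas and epsilon match <code>torch.optim.Adam</code>’s defaults, while the learning rate, step count, and toy loss are arbitrary assumptions.</p>

```python
import math

# Minimal sketch of the Adam update rule on one scalar "pixel" x.
# betas and eps mirror torch.optim.Adam defaults; lr and steps are arbitrary.


def adam_minimize(grad_fn, x, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8, steps=500):
    m, v = 0.0, 0.0  # running estimates of the first and second moments
    for t in range(1, steps + 1):
        g = grad_fn(x)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1**t)  # bias-corrected first moment
        v_hat = v / (1 - beta2**t)  # bias-corrected second moment
        x -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return x


# Gradient of the toy loss (x - 3) ** 2 with respect to x is 2 * (x - 3)
x_final = adam_minimize(lambda x: 2 * (x - 3), x=0.0)
print(round(x_final, 2))
```

<p>Because the step is divided by the running root-mean-square of past gradients, early steps have a size close to <code>lr</code> regardless of the raw gradient magnitude, which is one reason Adam is forgiving when the loss terms have very different scales.</p>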
<p>The final step of NST is to transfer the style using everything we’ve implemented. In each iteration, we extract the target’s feature maps, compute the total loss, take an optimizer step, and repeat this <code>N_EPOCHS</code> times. Because the model weights are frozen, the updates apply only to the target image.</p>
<p>To speed up NST notably, I used mixed precision with <code>bfloat16</code>, a 16-bit format supported natively on newer hardware. In my tests, traditional half-precision <code>float16</code> did not yield the same results, probably because <code>float16</code> usually needs gradient scaling, which this loop does not use.</p>
<div id="cell-35" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb15" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1">pbar <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> tqdm(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(N_EPOCHS))</span>
<span id="cb15-2"></span>
<span id="cb15-3"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> _ <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> pbar:</span>
<span id="cb15-4">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.autocast(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"cuda"</span>, dtype<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>torch.bfloat16, enabled<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>AMP_ENABLED):</span>
<span id="cb15-5">        target_features <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> get_features(target, model)</span>
<span id="cb15-6"></span>
<span id="cb15-7">        content_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> CONTENT_WEIGHT <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> content_loss_func(target_features, content_features)</span>
<span id="cb15-8">        style_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> STYLE_WEIGHT <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> style_loss_func(target_features, style_features, style_grams)</span>
<span id="cb15-9">        tv_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> TV_WEIGHT <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> total_variance_loss_func(target)</span>
<span id="cb15-10"></span>
<span id="cb15-11">        total_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> content_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> style_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> tv_loss</span>
<span id="cb15-12"></span>
<span id="cb15-13">    optimizer.zero_grad()</span>
<span id="cb15-14">    total_loss.backward()</span>
<span id="cb15-15"></span>
<span id="cb15-16">    optimizer.step()</span>
<span id="cb15-17"></span>
<span id="cb15-18">    pbar.set_postfix_str(</span>
<span id="cb15-19">        <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"total_loss=</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>total_loss<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">.</span>item()<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.2f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;"> "</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># noqa: E501</span></span>
<span id="cb15-20">        <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"content_loss=</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>content_loss<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">.</span>item()<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.2f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;"> "</span></span>
<span id="cb15-21">        <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"style_loss=</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>style_loss<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">.</span>item()<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.2f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;"> "</span></span>
<span id="cb15-22">        <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"tv_loss=</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>tv_loss<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">.</span>item()<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.2f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span></span>
<span id="cb15-23">    )</span></code></pre></div></div>
<div class="cell-output cell-output-stderr">
<pre><code>100%|██████████| 5000/5000 [01:37&lt;00:00, 51.45it/s, total_loss=43.91 content_loss=8.70 style_loss=29.11 tv_loss=6.11]     </code></pre>
</div>
</div>
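<p>Regarding the <code>float16</code> versus <code>bfloat16</code> observation, one likely factor is dynamic range. <code>float16</code> has only 5 exponent bits, so magnitudes well below about 6e-8 (its smallest positive value) round to zero and anything above 65504 overflows; that is why <code>float16</code> training typically pairs autocast with PyTorch’s <code>GradScaler</code>. <code>bfloat16</code> keeps float32’s 8 exponent bits and gives up mantissa precision instead. The pure-Python sketch below emulates both formats to make the difference visible; <code>to_fp16</code> and <code>to_bf16</code> are hypothetical helpers, not PyTorch APIs.</p>

```python
import math
import struct

# Minimal sketch: emulate float16 with struct's "e" format and bfloat16 by
# rounding a float32 to its upper 16 bits (round-to-nearest-even).


def to_fp16(x: float) -> float:
    try:
        return struct.unpack(">e", struct.pack(">e", x))[0]
    except OverflowError:
        # Hardware casts saturate to infinity instead of raising
        return math.copysign(math.inf, x)


def to_bf16(x: float) -> float:
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    # Round to nearest even on the 16 low bits that bfloat16 drops
    bits = ((bits + 0x7FFF + ((bits >> 16) % 2)) >> 16) << 16
    return struct.unpack(">f", struct.pack(">I", bits))[0]


for value in (1e-8, 1e-3, 70000.0):
    print(f"{value:g}: float16 = {to_fp16(value):g}, bfloat16 = {to_bf16(value):g}")
```

<p>Running it shows <code>float16</code> flushing 1e-8 to zero and turning 70000 into infinity, while <code>bfloat16</code> keeps both values with roughly two to three significant digits.</p>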
<p>As mentioned before, the images need to be denormalized (<em>i.e.</em>, the normalization reverted) to restore their correct colors. After that, we compare the content, style, and target images side by side.</p>
<div id="cell-37" class="cell">
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb17" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">class</span> InverseNormalize:</span>
<span id="cb17-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, mean: Sequence[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>], std: Sequence[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>]) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb17-3">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.mean <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.as_tensor(mean)</span>
<span id="cb17-4">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.std <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.as_tensor(std)</span>
<span id="cb17-5"></span>
<span id="cb17-6">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__call__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x_norm: Tensor) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> Tensor:</span>
<span id="cb17-7">        <span class="co" style="color: #5E5E5E;
background-color: null;
# Ensure mean and std have the correct shape</span></span>">
font-style: inherit;"># Reshape mean and std to (C, 1, 1) so they broadcast over (C, H, W)</span></span>
<span id="cb17-8">        mean <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.mean.to(x_norm.device).view(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb17-9">        std <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.std.to(x_norm.device).view(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb17-10"></span>
<span id="cb17-11">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Inverse normalization: x = x_normalized * std + mean</span></span>
<span id="cb17-12">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> x_norm.mul(std).add(mean)</span>
<span id="cb17-13">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> x</span>
<span id="cb17-14"></span>
<span id="cb17-15"></span>
<span id="cb17-16"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">class</span> Clip:</span>
<span id="cb17-17">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, vmin: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.0</span>, vmax: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.0</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb17-18">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.vmin <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> vmin</span>
<span id="cb17-19">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.vmax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> vmax</span>
<span id="cb17-20"></span>
<span id="cb17-21">    <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__call__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x: Tensor) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> Tensor:</span>
<span id="cb17-22">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> torch.clamp(x, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.vmin, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.vmax)</span>
<span id="cb17-23"></span>
<span id="cb17-24"></span>
<span id="cb17-25">inv_transform_preview <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> T.Compose(</span>
<span id="cb17-26">    [</span>
<span id="cb17-27">        InverseNormalize(IMAGENET_MEAN, IMAGENET_STD),</span>
<span id="cb17-28">        T.Resize(IMG_SIZE, antialias<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>),</span>
<span id="cb17-29">        T.CenterCrop((IMG_SIZE, IMG_SIZE)),</span>
<span id="cb17-30">        Clip(),</span>
<span id="cb17-31">    ]</span>
<span id="cb17-32">)</span>
<span id="cb17-33"></span>
<span id="cb17-34">imgs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [inv_transform_preview(i.detach().squeeze().cpu()) <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> i <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> (content, style, target)]</span>
<span id="cb17-35"></span>
<span id="cb17-36">grid <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> make_grid(imgs)</span>
<span id="cb17-37"></span>
<span id="cb17-38"></span>
<span id="cb17-39"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> show(imgs):</span>
<span id="cb17-40">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">not</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(imgs, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>):</span>
<span id="cb17-41">        imgs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [imgs]</span>
<span id="cb17-42"></span>
<span id="cb17-43">    fig, axs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(ncols<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(imgs), figsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">15</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>), squeeze<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>, dpi<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">92</span>, tight_layout<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>, frameon<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>)</span>
<span id="cb17-44">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> i, img <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(imgs):</span>
<span id="cb17-45">        img <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> img.detach()</span>
<span id="cb17-46">        img <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> VF.to_pil_image(img)</span>
<span id="cb17-47">        axs[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, i].imshow(np.asarray(img))</span>
<span id="cb17-48">        axs[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, i].<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">set</span>(xticklabels<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>[], yticklabels<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>[], xticks<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>[], yticks<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>[])</span>
<span id="cb17-49"></span>
<span id="cb17-50"></span>
<span id="cb17-51">show(grid)</span></code></pre></div></div>
<div class="cell-output cell-output-display">
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><img src="https://gcerar.github.io/posts/2023-09-15-neural-style-transfer/index_files/figure-html/cell-18-output-1.png" width="1312" height="450" class="figure-img"></p>
<figcaption>Successfully applied neural style transfer: the <strong>content</strong> image (left), the <strong>style</strong> image (center), and the final <strong>target</strong> image (right).</figcaption>
</figure>
</div>
</div>
</div>
</section>
</section>
<section id="conclusions" class="level1">
<h1>Conclusions</h1>
<p>Neural Style Transfer (NST) was a breakthrough deep learning approach for transferring the <strong>artistic style</strong> of one image onto another. My key takeaway from this experiment is how well neural networks can merge art and technology, blending the style of an artwork into an original image.</p>
<p>What stood out to me was the approach itself. A pre-trained neural network serves as a feature extractor, feature maps are taken from specific layers, and the content and style weights are balanced so that the result keeps the essence of the original image while convincingly imitating the artistic style.</p>
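<p>To make the content/style balance concrete, here is a minimal sketch of the weighted objective in PyTorch. It assumes the feature maps have already been extracted from the chosen layers, each with shape (channels, height, width) and the batch dimension removed. The function names <code>gram_matrix</code> and <code>nst_loss</code>, the Gram normalization, and the default weights are illustrative assumptions, not necessarily what my implementation above uses.</p>
<div class="sourceCode"><pre class="sourceCode python"><code class="sourceCode python">import torch
import torch.nn.functional as F


def gram_matrix(feat):
    # feat: one layer's feature map with shape (channels, height, width)
    c, h, w = feat.shape
    flat = feat.reshape(c, h * w)
    # Channel-by-channel correlations; the spatial layout is discarded
    return flat @ flat.T / (c * h * w)


def nst_loss(target_content, content, target_style, style,
             content_weight=1.0, style_weight=1e4):
    # Content term: compare raw feature maps from the content layer(s)
    content_loss = sum(F.mse_loss(t, c) for t, c in zip(target_content, content))
    # Style term: compare Gram matrices from the style layer(s)
    style_loss = sum(
        F.mse_loss(gram_matrix(t), gram_matrix(s))
        for t, s in zip(target_style, style)
    )
    # The two weights set the trade-off between keeping content and imitating style
    return content_weight * content_loss + style_weight * style_loss


# Usage with random stand-in feature maps (hypothetical shapes)
torch.manual_seed(0)
content = [torch.rand(64, 32, 32)]
style = [torch.rand(64, 32, 32), torch.rand(128, 16, 16)]
target = torch.rand(64, 32, 32)
loss = nst_loss([target], content, [target, torch.rand(128, 16, 16)], style)
print(loss.item())</code></pre></div>
<p>With the feature maps fixed, raising <code>style_weight</code> relative to <code>content_weight</code> makes the style term dominate the objective, so the optimized target drifts toward the style image's textures and colors at the expense of the content image's structure.</p>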
<p>Although NST achieves pleasing results, it was later overshadowed by faster and more capable methods, such as <a href="https://en.wikipedia.org/wiki/DALL-E">DALL-E</a>, <a href="https://en.wikipedia.org/wiki/Stable_Diffusion">Stable Diffusion</a>, and <a href="https://en.wikipedia.org/wiki/Midjourney">Midjourney</a>. Even so, it was a significant milestone on the path toward artistic and generative AI models.</p>
</section>
<section id="acknowledgements" class="level1">
<h1>Acknowledgements</h1>
<p>Articles and code repositories that helped me while writing my implementation:</p>
<ul>
<li>Gregor Koehler <em>et al.</em> <a href="https://nextjournal.com/gkoehler/pytorch-neural-style-transfer">gkoehler/pytorch-neural-style-transfer</a> (best resource in my opinion)</li>
<li>Ritul’s <a href="https://medium.com/udacity-pytorch-challengers/style-transfer-using-deep-nural-network-and-pytorch-3fae1c2dd73e">Medium article</a> (good resource)</li>
<li>Pragati Baheti’s <a href="https://www.v7labs.com/blog/neural-style-transfer">blog</a>, which visually presents style extraction</li>
<li>Aleksa Gordić (<a href="https://github.com/gordicaleksa/pytorch-neural-style-transfer">gordicaleksa/pytorch-neural-style-transfer</a>)</li>
<li><a href="https://github.com/ProGamerGov/neural-style-pt/blob/master/neural_style.py">ProGamerGov/neural-style-pt</a></li>
<li>Katherine Crowson (<a href="https://github.com/crowsonkb/style-transfer-pytorch/blob/master/style_transfer/style_transfer.py">crowsonkb/style-transfer-pytorch</a>)</li>
<li>Derrick Mwiti’s <a href="https://heartbeat.comet.ml/neural-style-transfer-with-pytorch-49e7c1fe3bea">Medium article</a></li>
<li>Aman Kumar Mallik’s <a href="https://towardsdatascience.com/implementing-neural-style-transfer-using-pytorch-fd8d43fb7bfa">Medium article</a></li>
</ul>
<p>I want to acknowledge the following artworks:</p>
<ul>
<li>“Gray Bridge and Trees” by Martin Damboldt</li>
<li>“Walking in the Rain” by Leonid Afremov</li>
<li>“The Starry Night” by Vincent van Gogh</li>
</ul>
<p>For a complete list of acknowledgments, please visit my GitHub repository:</p>
<ul>
<li><a href="https://github.com/gcerar/pytorch-neural-style-transfer/#acknowledgment">gcerar/pytorch-neural-style-transfer</a></li>
</ul>
</section>
<section id="appendix" class="level1">
<h1>Appendix</h1>
<section id="examples" class="level2">
<h2 class="anchored" data-anchor-id="examples">Examples</h2>
<p>A few cherry-picked examples of style transfer:</p>
<p><img src="https://gcerar.github.io/posts/2023-09-15-neural-style-transfer/examples/bridge+walking-in-the-rain.webp" class="img-fluid" alt="bridge + Walking in the Rain"> <img src="https://gcerar.github.io/posts/2023-09-15-neural-style-transfer/examples/walking-in-the-rain+bridge.webp" class="img-fluid" alt="Walking in the Rain + bridge"> <img src="https://gcerar.github.io/posts/2023-09-15-neural-style-transfer/examples/bridge+starry-night-crop.webp" class="img-fluid" alt="bridge + Starry Night"> <img src="https://gcerar.github.io/posts/2023-09-15-neural-style-transfer/examples/bridge+colorful-whirlpool.webp" class="img-fluid" alt="bridge + colorful whirlpool"></p>



</section>
</section>

<div id="quarto-appendix" class="default"><section class="quarto-appendix-contents" id="quarto-bibliography"><h2 class="anchored quarto-appendix-heading">References</h2><div id="refs" class="references csl-bib-body hanging-indent">
<div id="ref-gatys2015neural" class="csl-entry">
Gatys, Leon A, Alexander S Ecker, and Matthias Bethge. 2015. <span>“A Neural Algorithm of Artistic Style.”</span> <em>arXiv Preprint arXiv:1508.06576</em>.
</div>
</div></section><section class="quarto-appendix-contents" id="quarto-reuse"><h2 class="anchored quarto-appendix-heading">Reuse</h2><div class="quarto-appendix-contents"><div><a rel="license" href="https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en_us">CC BY-NC-SA 4.0</a></div></div></section></div> ]]></description>
  <category>pytorch</category>
  <category>NST</category>
  <guid>https://gcerar.github.io/posts/2023-09-15-neural-style-transfer/</guid>
  <pubDate>Thu, 14 Sep 2023 22:00:00 GMT</pubDate>
  <media:content url="https://gcerar.github.io/posts/2023-09-15-neural-style-transfer/featured.webp" medium="image" type="image/webp"/>
</item>
</channel>
</rss>
