
The Gradient That Changed Everything
It was almost coincidental. On Christmas Eve this year, I came across a math problem asking me to compute the partial derivatives of Word2Vec’s naive softmax loss — standard fare for any NLP course. But something compelled me to keep going, to really understand what these update rules were doing.
The result was deceptively simple:
$$ \frac{\partial J}{\partial v_c} = -u_o + \sum_{w\in V} \Pr[w|c] \, u_w = U(\hat{y} - y) $$

What struck me wasn’t the math itself, which is elegant but straightforward. What caught my attention was the structure of the learning process. The formulation says that updating the center word vector $v_c$ requires knowing the current state of all context vectors $U$. But updating $U$ requires knowing $v_c$ (illustrated by the partial derivative with respect to $U$, the other piece of the puzzle). This chicken-and-egg dependency — where each parameter set treats the other as temporarily fixed — is reminiscent of Expectation-Maximization algorithms. It isn’t EM in the formal sense, but the alternating dependence, treating one parameter block as fixed while updating the other, shares the same structural intuition.
The softmax probabilities $\hat{y}$ effectively act as posterior responsibilities, like the E-step in EM. The gradient updates perform parameter optimization, analogous to the M-step. Word2Vec, stripped to its mathematical essence, is performing latent factor analysis through alternating optimization.
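A few lines of JAX make the closed form easy to check numerically (everything below is arbitrary illustration: the dimensions, vocabulary size, and random values; U follows the convention above, with the outside vectors as its columns):

```python
import jax
import jax.numpy as jnp

d, V, o = 8, 50, 3                       # embedding dim, vocab size, index of the true outside word
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
U = jax.random.normal(k1, (d, V))        # columns are the outside vectors u_w
v_c = jax.random.normal(k2, (d,))        # center word vector

def naive_softmax_loss(v_c, U, o):
    y_hat = jax.nn.softmax(U.T @ v_c)    # Pr[w|c] for every w in the vocabulary
    return -jnp.log(y_hat[o])

grad_vc = jax.grad(naive_softmax_loss)(v_c, U, o)

y = jnp.zeros(V).at[o].set(1.0)          # one-hot true distribution
closed_form = U @ (jax.nn.softmax(U.T @ v_c) - y)

print(jnp.allclose(grad_vc, closed_form, atol=1e-5))   # True
```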
This realization hit me while sitting in a coffee shop, surrounded by finals week chaos. And it crystallized something I’d been feeling for the entire year: there’s a vast difference between using language models and understanding them. This is why I’m spending my winter break implementing GPT-2 from scratch.
The Practitioner’s Gap
It’s December 2025, and large language models are everywhere. They’re classifying patient notes, generating code, summarizing legal documents, writing marketing copy. The API abstraction is powerful: a call like model.generate(prompt) hides extraordinary complexity behind a simple interface.
But this abstraction is dangerous when we’re building production systems in high-stakes domains.
Consider the practitioner’s dilemma: You need to fine-tune a pre-trained model for your specific task.
- How much data do you need? Which layers should you freeze?
- What learning rate (or what dynamic adjustments) makes sense?
- When is the model overfitting versus genuinely learning task-specific patterns?
- How do you diagnose why it’s failing on certain inputs?
API-level understanding can’t answer these questions. You need implementation-level intuition — the kind that only comes from building these systems yourself, from watching gradients flow backward through attention heads, from debugging why your validation loss suddenly diverges at epoch 4.
This gap between usage and understanding is what motivated me to commit one week of my end-of-year break (December 25 through January 1) to a focused deep-dive: working through Stanford’s CS224n curriculum, implementing GPT-2 components from scratch, and fine-tuning the model for three distinct NLP tasks.
Let me make this clear upfront: This isn’t about completing a course for credentials. It’s about building the mental models necessary to work confidently with transformer architectures in production. And I’m documenting the journey publicly — partly for accountability, partly because I suspect other data scientists face the same gap.
Why GPT-2 in 2025?
You may have an obvious question or objection at this point: “GPT-2 is from 2019. What’s the point of revisiting it when GPT-5.2 came out earlier this month?”
Let me address this head-on, because the choice still makes sense to me, and I suspect it will to many of you.
The Pedagogical Sweet Spot
GPT-2 occupies a unique position on the complexity-comprehensibility spectrum. It’s complex enough to be non-trivial — multi-head attention, layer normalization, positional encoding, the full transformer decoder stack. But it’s simple enough to fully implement in one week: 12 to 48 layers and at most 1.5 billion parameters depending on the variant (versus the hundreds of billions in today’s frontier models), with architecture choices that are well-documented and understood.
The attention mechanism in GPT-2 is functionally equivalent to the attention in modern SOTA transformer-based LLMs. Layer normalization remains conceptually the same (even if RMSNorm is now popular). Positional encoding strategies have evolved, but the core principle — making position information available to a position-agnostic architecture — hasn’t changed.
Every modern transformer shares this core architecture. Understanding GPT-2 deeply means you can read papers about any modern LLM and immediately recognize the architectural components. The variations — Flash Attention, Rotary Position Embeddings, different normalization schemes — become understandable modifications rather than mysterious incantations.
It’s like learning classical mechanics and electromagnetism before moving on to quantum physics or general relativity. The foundations transfer.
Pre-Training Problem Solved
However, training a language model from scratch is extraordinarily expensive — not just computationally (though $50k-100k+ for even modest-sized models is prohibitive), but in terms of expertise. Corpus curation, cleaning, deduplication, and balancing together form an entire specialized skill set. Getting it wrong leads to models that memorize datasets, amplify biases, or fail to generalize.
The good news: pre-trained weights are freely available. This lets me focus on what actually matters for practitioners: understanding the architecture and learning how to adapt it for specific tasks.
Fine-tuning is the transferable skill. It’s what you’ll do in production. It’s where most of the practical challenges live, whether or not we actually use GPT-2 as the base model (e.g., in some high-volume settings with limited compute, even BERT can be a rational choice in 2025).
Documentation Advantage
And here’s another benefit I expect. GPT-2 has been studied exhaustively for six years. There are hundreds of tutorials, blog posts, paper analyses, and GitHub implementations. When my attention implementation produces nonsensical gradients, I’ll be able to find multiple explanations of what went wrong and how to fix it.
Bleeding-edge models lack this accumulated knowledge. Every bug becomes an expedition into uncertainty: “Is this a fundamental misunderstanding of the architecture, or am I hitting an undocumented edge case?”
Learning requires being able to distinguish “I’m wrong” from “the reference is wrong.” With GPT-2, the references are battle-tested.
The Learning Strategy: CS224n + JAX/Flax Rewrite
CS224n as Framework
Stanford’s CS224n provides the curriculum structure: lectures covering fundamentals (tokenization, word embeddings, recurrence, attention, transformers), assignments that run from theory to implementation, and a final project that synthesizes everything.
The course has been refined over almost two decades. The assignments are battle-tested. The progression from simple bag-of-words models to full transformers follows a pedagogical logic that makes complex concepts buildable.
But I’m adding a twist of my own.
The JAX/Flax Rewrite as Forcing Function
The standard CS224n final project uses PyTorch. I’m rewriting everything in JAX/Flax instead.
This isn’t an arbitrary tool choice. It’s a deliberate curiosity-led learning constraint, similar to my recent approach with F# — using a more restrictive paradigm to force deeper understanding.
Here’s why JAX matters:
State Explicitness
PyTorch’s imperative style allows implicit state mutations:
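(A minimal sketch of what I mean; model, optimizer, loss_fn, and batch stand in for whatever you happen to be training.)

```python
import torch

def train_step(model, optimizer, loss_fn, batch):
    inputs, targets = batch
    optimizer.zero_grad()             # clears .grad buffers hidden inside the model/optimizer
    outputs = model(inputs)           # builds an implicit computational graph as a side effect
    loss = loss_fn(outputs, targets)
    loss.backward()                   # writes gradients into the parameters as a side effect
    optimizer.step()                  # mutates the model weights in place
    return loss.item()
```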
This is convenient but obscures what’s actually happening. Where do gradients live? Which parameters changed? What’s the computational graph topology?
JAX makes state threading explicit:
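(Another minimal sketch, this time assuming optax for the optimizer; apply_fn is the model’s forward function.)

```python
import jax
import optax

def train_step(params, opt_state, batch, apply_fn, optimizer, loss_fn):
    inputs, targets = batch

    def batch_loss(p):
        return loss_fn(apply_fn(p, inputs), targets)

    loss, grads = jax.value_and_grad(batch_loss)(params)             # gradients as a return value
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss                           # new state out; nothing mutated
```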
The function signature documents everything: what state comes in, what state goes out, what transformations occurred. Far fewer things mutate — you get new values, and it’s explicit what happened to produce them.
(Note: I said “fewer”, not “no”. Bear with it for now; I’ll write a separate post on where functional purity breaks down in real NN systems.)
This explicitness is particularly valuable while learning. Implicit mutations can’t quietly trip me up: every gradient computation, every parameter update, every state transition must be consciously threaded through the program.
Mathematical Alignment
It’s worth noting that neural networks are fundamentally functional in their mathematical formulation. A forward pass is function composition:
$$ f(x; \theta) = f_L(f_{L-1}(\cdots f_2(f_1(x; \theta_1); \theta_2) \cdots; \theta_{L-1}); \theta_L) $$

Backpropagation is function composition of derivatives:

$$ \frac{\partial L}{\partial \theta_i} = \frac{\partial L}{\partial f_L} \circ \frac{\partial f_L}{\partial f_{L-1}} \circ \cdots \circ \frac{\partial f_{i+1}}{\partial f_i} \circ \frac{\partial f_i}{\partial \theta_i} $$

Writing this in functional style synchronizes code structure with mathematical structure. The code becomes a nearly direct translation of the mathematics. This makes it easier to verify correctness — if my implementation diverges from the math, the type signatures won’t align.
Consider a concrete example. The multi-head attention mechanism in PyTorch typically looks like this:
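(The version below is a standard simplified sketch with causal masking and no dropout; the module and variable names are mine rather than anything canonical.)

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)   # parameters live on the module
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        q, k, v = self.qkv(x).split(C, dim=-1)
        # reshape each to (B, n_heads, T, d_head)
        q, k, v = (t.view(B, T, self.n_heads, self.d_head).transpose(1, 2) for t in (q, k, v))
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)
        mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
        att = att.masked_fill(~mask, float("-inf"))   # causal: no attending to future tokens
        att = F.softmax(att, dim=-1)
        out = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(out)
```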
The JAX/Flax version makes state explicit:
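(To keep the purity front and center, this sketch skips Flax modules entirely and passes the parameters in as a plain dict; it also handles a single sequence, which you would vmap over a batch.)

```python
import math
import jax
import jax.numpy as jnp

def multi_head_attention(params: dict, x: jnp.ndarray, n_heads: int) -> jnp.ndarray:
    """Causal multi-head self-attention as a pure function.

    params = {"w_qkv": (C, 3C), "w_proj": (C, C)}; x has shape (T, C).
    """
    T, C = x.shape
    d_head = C // n_heads

    q, k, v = jnp.split(x @ params["w_qkv"], 3, axis=-1)                        # each (T, C)
    split_heads = lambda t: t.reshape(T, n_heads, d_head).transpose(1, 0, 2)    # (H, T, d_head)
    q, k, v = map(split_heads, (q, k, v))

    att = (q @ k.transpose(0, 2, 1)) / math.sqrt(d_head)                        # (H, T, T)
    causal_mask = jnp.tril(jnp.ones((T, T), dtype=bool))
    att = jax.nn.softmax(jnp.where(causal_mask, att, -jnp.inf), axis=-1)

    out = (att @ v).transpose(1, 0, 2).reshape(T, C)                            # back to (T, C)
    return out @ params["w_proj"]
```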
This is a pure function. Given the same params and x, it always returns the same output. There are no hidden mutations, no implicit state, no side effects. Testing becomes trivial:
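(Assuming the multi_head_attention sketch above, a test can be as plain as this.)

```python
import jax
import jax.numpy as jnp

def test_attention_is_pure_and_shape_preserving():
    T, C, n_heads = 5, 16, 4
    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
    params = {
        "w_qkv": 0.02 * jax.random.normal(k1, (C, 3 * C)),
        "w_proj": 0.02 * jax.random.normal(k2, (C, C)),
    }
    x = jax.random.normal(k3, (T, C))

    out1 = multi_head_attention(params, x, n_heads)
    out2 = multi_head_attention(params, x, n_heads)

    assert out1.shape == (T, C)          # shape preserved
    assert jnp.allclose(out1, out2)      # same params and x in, same output out
```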
No need to instantiate modules, manage device placement, reset gradients, or navigate PyTorch’s stateful training/eval modes.
Practical Benefits
The functional style prevents certain classes of bugs entirely:
- No accidental mutations: Can’t forget to reset gradients or accidentally modify shared state
- Reproducibility: Pure functions with explicit random keys mean truly deterministic training
- Composability: Functions compose cleanly without worrying about side effects
- Parallelization: Pure functions are trivially parallelizable (JAX’s jax.pmap exploits this)
And jax.grad offers a different mental model than PyTorch’s autograd. Instead of accumulating gradients as side effects, you explicitly compute gradient functions:
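(A toy example; the least-squares loss_fn here is a placeholder, not the GPT-2 loss.)

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

grad_fn = jax.grad(loss_fn)                      # a new function: (params, x, y) -> gradient pytree
loss_and_grad_fn = jax.value_and_grad(loss_fn)   # or both at once
fast_grad_fn = jax.jit(grad_fn)                  # composes with jit like any other function

params = {"w": jnp.ones((3, 1)), "b": jnp.zeros((1,))}
x, y = jnp.ones((4, 3)), jnp.zeros((4, 1))
grads = grad_fn(params, x, y)                    # same pytree structure as params
```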
This makes the gradient computation itself a first-class object you can inspect, compose, and reason about.
Acknowledged Caveats
I’m learning JAX idioms and deepening NLP understanding simultaneously. This adds cognitive overhead. But I have mitigating factors:
- PyTorch familiarity means I understand the underlying ML concepts
- Working knowledge of NLP means lectures are review and deepening, not first exposure
- JAX crash tutorials already completed — the syntax isn’t entirely foreign
The learning curve might be real, but the clarity gains justify the investment. And the functional programming perspective aligns with my recent trajectory: Python → Rust → F# → back to Python with an FP lens.
Prioritizing Depth Over Completion
CS224n has four assignments. I’ve already completed two. My strategy for the remaining two: prioritize theory and mathematical problems over implementation exercises.
The Word2Vec gradient derivation (not required by the assignment) led to the EM insight that opened this essay. That’s the value of slowing down to understand the mathematics deeply. Implementation problems, by contrast, can be replaced with conceptual sketches if time runs short — the learning comes from having attempted them seriously first, not from typing every line of code.
For the final project, I’m being strategic about scope.
Progress So Far: Domain Modeling GPT-2 in JAX
The first step in any rewrite is domain modeling: what are the core abstractions? What state needs to exist? How does it flow through the system?
Explicit State Types
I am designing explicit types for every piece of state. Illustrative examples below:
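(The fields and defaults below are indicative rather than final; the config numbers are the published GPT-2 small hyperparameters.)

```python
from dataclasses import dataclass
from typing import Any

import jax.numpy as jnp

Params = Any     # a pytree of jnp arrays (nested dicts of weights)
OptState = Any   # whatever the optimizer returns (e.g. an optax state pytree)

@dataclass(frozen=True)
class GPT2Config:
    vocab_size: int = 50257
    n_layers: int = 12
    n_heads: int = 12
    d_model: int = 768
    max_seq_len: int = 1024

@dataclass(frozen=True)
class TrainingState:
    step: int
    params: Params            # model parameters, threaded explicitly
    opt_state: OptState       # optimizer state (momenta, etc.)
    rng_key: jnp.ndarray      # PRNG key for dropout / sampling
```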
These types document assumptions. frozen=True enforces immutability at the Python level. Every training step returns a new TrainingState — nothing mutates.
Function Signatures as Documentation
The function signatures for core operations make data flow explicit:
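(Bodies elided; Params and TrainingState are the types sketched above.)

```python
import jax.numpy as jnp

def forward(params: Params, token_ids: jnp.ndarray) -> jnp.ndarray:
    """(batch, seq_len) int tokens -> (batch, seq_len, vocab_size) logits."""
    ...

def loss_fn(params: Params, batch: dict[str, jnp.ndarray]) -> jnp.ndarray:
    """Scalar cross-entropy loss for one batch; depends only on its arguments."""
    ...

def update_step(state: TrainingState, batch: dict[str, jnp.ndarray]) -> tuple[TrainingState, jnp.ndarray]:
    """Consumes a TrainingState, returns a new TrainingState plus the loss."""
    ...
```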
Compare this to PyTorch’s typical pattern:
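(Hypothetical counterparts, for contrast; the names are illustrative.)

```python
import torch
import torch.nn as nn

class GPT2Model(nn.Module):
    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        ...   # parameters, buffers, and train/eval mode all live on self

def train_one_batch(model: nn.Module, optimizer: torch.optim.Optimizer, batch: dict) -> float:
    ...   # mutates model parameters, gradient buffers, and optimizer state in place
```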
The JAX signatures document many things: inputs, outputs, shapes (via type annotations), and the guarantee that nothing outside the function’s return value changes.
Data Flow Restructuring
The training loop becomes a pure state transformation:
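(A sketch that builds on the TrainingState and loss_fn pieces above, with optax’s AdamW standing in for the optimizer; jit-compiling the step would additionally require registering TrainingState as a pytree, which I’ve left out for brevity.)

```python
import jax
import jax.numpy as jnp
import optax

# Illustrative optimizer; the hyperparameters are placeholders, not tuned values.
optimizer = optax.adamw(learning_rate=3e-4, weight_decay=0.01)

def update_step(state: TrainingState, batch) -> tuple[TrainingState, jnp.ndarray]:
    # Pure step: consumes a state, returns a brand-new one plus the loss.
    loss, grads = jax.value_and_grad(loss_fn)(state.params, batch)
    updates, new_opt_state = optimizer.update(grads, state.opt_state, state.params)
    new_params = optax.apply_updates(state.params, updates)
    new_state = TrainingState(
        step=state.step + 1,
        params=new_params,
        opt_state=new_opt_state,
        rng_key=jax.random.fold_in(state.rng_key, state.step),
    )
    return new_state, loss

def train(state: TrainingState, batches) -> TrainingState:
    for batch in batches:
        state, loss = update_step(state, batch)   # rebinding the name is the only "mutation"
    return state
```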
No hidden mutations. No model.train() / model.eval() mode switching. No forgetting to call optimizer.zero_grad(). The state flows explicitly through the computation, and you always know exactly what state you’re in.
What This Clarity Reveals
This explicit structure has already paid dividends during implementation:
- Gradient flow is transparent: I can trace exactly how gradients propagate backward through the architecture because there are no hidden accumulations.
- Testing is straightforward: Pure functions with explicit inputs/outputs are trivial to unit test. No fixtures, no setup/teardown, no managing global state.
- Debugging is easier: When something goes wrong, I can inspect the exact state at any point without worrying about whether some hidden mutation corrupted things earlier.
- Mathematical structure is visible: The code mirrors the mathematics. When I read papers describing transformer variants, I can see exactly where in my code the modifications would occur.
Final Project Scope
The CS224n final project has two main components:
- Core implementations from scratch: Custom attention mechanisms, positional encoding strategies, optimizer (AdamW with weight decay)
- Fine-tuning for three tasks: Sentiment classification, paraphrase detection, text generation
My strategic priority: Start with fine-tuning, backfill core implementations as time permits.
Why this ordering? Fine-tuning teaches task adaptation — the most transferable skill for practitioners. Understanding how to prepare datasets, design task-specific heads, tune hyperparameters, and diagnose overfitting matters more than implementing Adam from scratch (which is well-understood and has canonical implementations).
If time runs short, I’d rather have a working fine-tuned model using JAX’s built-in optimizer than a custom optimizer with no fine-tuning experiments.
What’s Ahead: Three Tasks, Three Challenges
The final project requires fine-tuning GPT-2 for three distinct NLP tasks. Each tests different aspects of the model’s capability and requires different training strategies.
Task 1: Sentiment Classification (SST + CFIMDB)
Stanford Sentiment Treebank (SST)
The SST dataset contains 8,544 training examples, 1,101 dev examples, and 2,210 test examples. Each example is a single sentence from a movie review, labeled by three human judges with one of five sentiment classes: negative, somewhat negative, neutral, somewhat positive, or positive.
The challenge here is fine-grained distinction. It’s relatively easy to distinguish “This movie is terrible” (negative) from “This movie is brilliant” (positive). But what about “This movie is okay”? Is that neutral, or somewhat positive? Human annotators disagree on these boundaries, which means the model must learn subtle linguistic cues — hedging language, qualifier words, negation scope.
CFIMDB Dataset
The CFIMDB dataset is smaller but more polarized: 1,701 training examples, 245 dev, 488 test. Each example is a longer movie review (often multiple sentences) labeled simply as negative or positive.
The smaller dataset size (roughly 5× smaller than SST) combined with longer sequences creates a different challenge: overfitting. With fewer examples, the model might memorize specific patterns in the training set rather than learning generalizable sentiment representations. I’ll need aggressive regularization — dropout, weight decay, early stopping based on validation performance.
What This Tests
Can the model learn task-specific sentiment features? How does performance differ between fine-grained 5-class classification (SST) versus coarse binary classification (CFIMDB)?
More interestingly: when I fine-tune on SST first, can I then adapt to CFIMDB with minimal additional training (transfer learning within sentiment tasks)? Or do the different granularities require fundamentally different representations?
Task 2: Paraphrase Detection (Quora Question Pairs)
This dataset is substantially larger: 141,506 training examples, 20,215 dev, 40,431 test.
The task: given two questions from Quora, determine if they convey the same semantic meaning. The CS224n formulation treats this as a cloze-style task where the model generates “yes” (paraphrases) or “no” (different meanings).
For example:
- “How do I lose weight?” vs “What’s the best diet?” → might seem related but differ in intent (general weight loss vs specific dietary advice)
- “What’s the capital of France?” vs “What is France’s capital city?” → clearly paraphrases despite different wording
The challenge is semantic similarity beyond surface-level word matching. Paraphrases might share no words at all (“How to shed pounds?” vs “Weight loss methods?”), while non-paraphrases might share many words (“How do I learn Python?” vs “How do I learn French?”).
The 100× Dataset Size Difference
Quora has roughly 100× more training data than CFIMDB (141k vs 1.7k examples). This fundamentally changes the training dynamics:
- CFIMDB: Every epoch sees the same 1,700 examples. The model quickly memorizes the training set. Success requires preventing overfitting through regularization.
- Quora: With 141k examples, the model might not memorize the entire training set even after many epochs. The risk shifts from overfitting the training data to overfitting the pre-trained model’s representations — the model might rely too heavily on GPT-2’s existing knowledge rather than learning task-specific paraphrase detection.
This dataset size disparity may teach an important lesson: fine-tuning isn’t one-size-fits-all. Different dataset sizes require different:
- Learning rates (smaller datasets → lower learning rates to avoid catastrophic forgetting)
- Regularization strategies (small datasets → aggressive dropout, large datasets → minimal dropout)
- Training duration (small datasets → few epochs, large datasets → more epochs but with early stopping)
- Layer freezing decisions (small datasets → freeze more layers, large datasets → fine-tune more layers)
JAX’s functional training loops make it easier to experiment with these variations. I can write different update_step functions for different tasks without worrying about state management complications.
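For instance, something like the sketch below, where the hyperparameter values are placeholders rather than recommendations and the loss function is assumed to accept a dropout rate and a PRNG key:

```python
from dataclasses import dataclass
import jax
import optax

@dataclass(frozen=True)
class FinetuneConfig:
    learning_rate: float
    weight_decay: float
    dropout_rate: float

# Placeholder values purely for illustration: smaller datasets get gentler, more regularized updates.
CFIMDB_CFG = FinetuneConfig(learning_rate=1e-5, weight_decay=0.1, dropout_rate=0.3)
QUORA_CFG = FinetuneConfig(learning_rate=5e-5, weight_decay=0.01, dropout_rate=0.1)

def make_update_step(cfg: FinetuneConfig, loss_fn):
    optimizer = optax.adamw(cfg.learning_rate, weight_decay=cfg.weight_decay)

    def update_step(params, opt_state, batch, rng):
        loss, grads = jax.value_and_grad(loss_fn)(params, batch, cfg.dropout_rate, rng)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state, loss

    return update_step, optimizer
```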
Task 3: Text Generation (Sonnets)
This is the most experimental task: 143 sonnets for training, 12 held out for testing.
The challenge: given the first 3 lines of a sonnet, generate the remaining 11 lines while maintaining:
- Rhyme scheme: ABAB CDCD EFEF GG (the Shakespearean structure)
- Meter: Iambic pentameter (ten syllables per line, unstressed-stressed pattern)
- Semantic coherence: The generated lines should continue the theme/ideas from the prompt
With only 143 training examples, this enters few-shot learning territory. The model has seen billions of tokens during pre-training, including plenty of poetry. The question is whether fine-tuning on this tiny dataset can teach it Shakespearean-specific patterns without catastrophically forgetting everything else.
Evaluation Challenges
How do I even measure success? Standard metrics like perplexity measure how well the model predicts the next token given the context, but low perplexity doesn’t guarantee good sonnets. I could have perfectly predictable sonnets that are boring or nonsensical.
My evaluation strategy will likely be multi-faceted:
- Quantitative baselines:
  - Perplexity on held-out sonnets (does the model assign high probability to actual Shakespeare?)
  - Rhyme scheme accuracy (do the generated rhymes match ABAB CDCD EFEF GG? a rough check is sketched after this list)
  - Syllable count distribution (are we close to 10 syllables per line?)
- Qualitative analysis:
  - Manual reading of generated sonnets (are they comprehensible? evocative? Shakespearean in tone?)
  - Comparison to baseline (what happens if I generate from the pre-trained model with no fine-tuning?)
  - Comparison to temperature sampling (does higher randomness produce more interesting results at the cost of coherence?)
- Ablation studies:
  - What if I fine-tune on only the first 100 sonnets and test on 43 held-out examples?
  - What if I use the full 143 for training but evaluate on my own prompts (not from the held-out set)?
  - Does data augmentation help? (e.g., paraphrasing lines while preserving rhyme/meter)
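For the rhyme and syllable checks, I have in mind something like the sketch below. It assumes the third-party pronouncing package (a thin wrapper around the CMU Pronouncing Dictionary) and simply skips out-of-vocabulary words; a real evaluation would need to handle those more carefully.

```python
import pronouncing

def last_word(line: str) -> str:
    return line.strip().rstrip(".,;:!?").split()[-1].lower()

def rhyming_part(word: str) -> str | None:
    phones = pronouncing.phones_for_word(word)            # [] if the word is out of vocabulary
    return pronouncing.rhyming_part(phones[0]) if phones else None

def line_syllables(line: str) -> int | None:
    total = 0
    for w in line.strip().rstrip(".,;:!?").split():
        phones = pronouncing.phones_for_word(w.lower())
        if not phones:
            return None                                    # unknown word: skip this line
        total += pronouncing.syllable_count(phones[0])
    return total

def rhyme_scheme_accuracy(sonnet_lines: list[str]) -> float:
    # Shakespearean scheme ABAB CDCD EFEF GG, as zero-indexed line pairs that must rhyme.
    pairs = [(0, 2), (1, 3), (4, 6), (5, 7), (8, 10), (9, 11), (12, 13)]
    hits = []
    for i, j in pairs:
        a, b = rhyming_part(last_word(sonnet_lines[i])), rhyming_part(last_word(sonnet_lines[j]))
        if a and b:
            hits.append(a == b)
    return sum(hits) / len(hits) if hits else 0.0
```

Perplexity comes more or less for free from the fine-tuned model’s loss on the held-out sonnets, so these two checks are the main extra machinery.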
This task will teach me as much about evaluation methodology as about generation itself. And that’s valuable — in production, figuring out how to measure success on creative or open-ended tasks is often harder than training the model.
Cross-Task Insights
These three tasks span the classic NLP challenge space:
- Classification (sentiment, paraphrase): Clear success metrics, supervised learning
- Generation (sonnets): Ambiguous success, creative output, subtle constraints
They also span dataset sizes:
- Tiny (143 sonnets)
- Small (1.7k CFIMDB)
- Medium (8.5k SST)
- Large (141k Quora)
Each dataset size teaches different lessons about overfitting, regularization, and fine-tuning strategies. By the end, I should have intuition for how dataset scale affects training dynamics — intuition that transfers to any fine-tuning task.
And because I’m implementing this in JAX with explicit functional patterns, I can experiment with different training loops, loss functions, and regularization schemes without getting tangled in stateful complexity.
The One-Week Constraint
Let me be realistic about what’s achievable in seven days.
I’ve already invested:
- 4 lectures completed (covering word vectors, language modeling, NN basics, dependency parsing)
- 2 assignments finished (covering co-occurrence matrices and their decomposition, GloVe practice/inspection, Word2Vec gradients, NN backprop, and dependency parsing)
- Initial GPT-2 domain modeling (types, function signatures, data flow design)
Remaining work:
- 2 assignments (covering RNNs, neural machine translation, attention, transformers, etc.)
- Core implementations (attention, positional encoding, optimizer)
- Data pipelines for all three tasks
- Training loops and task-specific heads for three distinct problems
- Hyperparameter tuning and evaluation
That’s roughly:
- 2-3 days for assignments + lectures (finishing the curriculum)
- 3-4 days for final project implementation
- 1 day buffer for debugging, unexpected challenges, documentation
Realistic Outcomes
Best case: All three tasks working well, core implementations complete, comparative analysis of JAX vs PyTorch in a follow-up post, deep understanding of transformer internals.
Expected case: Two tasks performing solidly (likely SST and Quora, given dataset sizes and clearer evaluation metrics), experimental results on sonnets, partial core implementations (maybe custom attention but using JAX’s built-in optimizer).
Minimum viable: One classification task working well (SST), one generation task producing interesting failures (sonnets), deep architectural understanding even if implementations are incomplete.
Strategic Priorities
My task ordering:
- SST sentiment classification first: Clearest success metric (5-class accuracy), medium dataset size, validates the entire pipeline (tokenization → model → training → evaluation).
- Quora paraphrase second: Different task type tests transfer learning, largest dataset teaches training dynamics at scale, still has clear metrics.
- Sonnets third: Most experimental, smallest dataset, qualitative evaluation harder to validate quickly, but most interesting for learning about generation.
For core implementations, I’ll fill in components as needed:
- If attention from scratch is required for understanding, implement it
- If I can use Flax’s built-in attention and still learn deeply, use it
- Same principle for optimizer, positional encoding, etc.
The goal isn’t “implement every component from scratch.” The goal is “understand transformers well enough to make informed architectural choices and debug production issues.”
Purpose & What Comes Next
This essay serves multiple purposes.
It’s a public commitment device. Documenting my plan before executing it increases follow-through. There’s social pressure (even self-imposed) to actually do the work once you’ve told people about it.
It’s a signal of seriousness. Not just “I’m interested in NLP” but “I’m willing to invest focused time in deep implementation.” For collaborators, employers, or colleagues evaluating my capabilities, this demonstrates both commitment and my approach to problem solving in this area.
And it’s an invitation. If you’re a data scientist considering a similar deep-dive into NLP, transformers, or any complex technical domain, maybe this provides a template: structured curriculum + deliberate constraints + public documentation.
Stay tuned for my follow-up post, sometime in mid-January, likely documenting:
- What actually worked (did I finish all three tasks? where did I spend unexpected time?)
- JAX vs PyTorch comparison (what did the functional perspective reveal? where was it frustrating?)
- Fine-tuning insights (how did different dataset sizes affect training? what hyperparameter lessons transferred across tasks?)
- Architectural surprises (what did implementing attention teach me that reading about it didn’t?)
Closing: Understanding Through Implementation
I return to where this began: that Word2Vec gradient derivation that wasn’t required but revealed the EM structure underneath.
That’s the philosophy driving this project. What I cannot create, I do not understand. Not in the sense of “I must reinvent every wheel” — I’m using pre-trained GPT-2 weights, after all. But in the sense that until I’ve implemented the core components, debugged the failures, and made the architectural choices myself, my understanding remains surface-level.
The goal is to build intuition. And intuition emerges from implementation, from gradients flowing backward through layers of abstraction until the mathematics and the code become one unified mental model.
One week. Three tasks. JAX functional purity as a learning constraint. Let’s see what understanding emerges.