Representation Learning
The Foundation of Everything
Every deep learning model — transformers, GNNs, diffusion models, LLMs — is, at its core, a representation learner. Before you can understand attention, contrastive loss, or RLHF, you need to deeply understand what a representation is, why it matters, and what makes one representation better than another.
What Is a Representation?
Let's start with something you already understand. Imagine you're trying to recognize whether an email is spam or not spam. The email is just a bunch of words. But words as raw text are useless to a machine — it can't do math on the word "lottery".
So what do we do? We convert the email into numbers. Maybe we count how many times "free", "win", "click" appear. That list of numbers — [3, 1, 5, 0, 2, ...] — is a representation of the email.
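A minimal sketch of this counting idea — the keyword list and example email are invented for illustration, not taken from a real spam filter:

```python
# Hypothetical spam-indicative keywords (illustrative only)
keywords = ["free", "win", "click", "money", "lottery"]

def represent(email: str) -> list[int]:
    """Turn raw text into a vector of keyword counts."""
    words = email.lower().split()
    return [words.count(k) for k in keywords]

email = "Click here to win a free lottery ticket. Free money if you click now!"
vector = represent(email)
print(vector)  # → [2, 1, 2, 1, 1]
```

That output list is a representation: the email is now a point in 5-dimensional space, and a model can do math on it.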
A representation is just a way of describing a real-world thing (image, text, node in a graph) as a list of numbers — a vector — so that a machine can reason about it mathematically.
This idea is so simple that it's easy to miss how profound it is. The entire field of deep learning is fundamentally about one question:
"What is the best way to represent a piece of data as numbers, so that useful patterns become easy to find?"
A More Concrete Example
Suppose you have three fruits: an apple, a banana, and a strawberry. How do you represent them as numbers?
Option A — Arbitrary IDs: apple=1, banana=2, strawberry=3. This is terrible because it implies banana is "between" apple and strawberry, which is meaningless.
Option B — Feature vector: Represent each fruit by [redness, sweetness, size]: apple = [0.9, 0.7, 0.5], banana = [0.1, 0.6, 0.6], strawberry = [0.95, 0.8, 0.1].
Notice: apple [0.9, 0.7, 0.5] and strawberry [0.95, 0.8, 0.1] are now close to each other as numbers because they're both red and sweet. This is what a good representation does — it captures real-world similarity as mathematical closeness.
Why Does Representation Matter So Much?
In machine learning, your model can only be as smart as the representation it starts with. Given a bad representation, even the most sophisticated model will struggle. Given a great representation, even a simple linear model can do remarkably well.
This is exactly why deep learning has been so transformative: deep neural networks learn representations automatically from data, instead of humans having to hand-craft them.
Before deep learning (pre-2012), ML engineers spent most of their time on feature engineering — manually designing representations. For images, this meant things like SIFT descriptors, HOG features, etc. Deep learning automated this process. Now the model learns what numbers to use.
The Manifold Hypothesis
Here's a puzzle. A typical image is 224×224 pixels with 3 color channels. That's 224 × 224 × 3 = 150,528 numbers per image. Theoretically, "image space" has 150,528 dimensions.
But think about it: what fraction of all possible 150,528-dimensional vectors actually look like a real image of a cat? Almost none. Most random vectors look like pure static noise.
Real-world data (images, text, audio) — despite living in a very high-dimensional space — actually lies on or near a low-dimensional manifold. Meaningful variation in natural data only occupies a tiny corner of the full mathematical space.
Why This Matters for Learning
If the data lies on a low-dimensional manifold, then a good representation should unfold that manifold — take the data from its complicated high-dimensional space and map it to a simple, flat, low-dimensional space where patterns are visible.
This is exactly what the encoder part of every neural network does. When you pass an image through ResNet or ViT, the early layers detect edges, then shapes, then parts of objects. Each layer is finding a better, more disentangled representation of the manifold the data lives on.
The Geometric View: Embeddings as Points in Space
This is the single most important mental model in all of deep learning. Let's build it carefully.
Every representation is a vector — an ordered list of numbers. A vector with n numbers lives in n-dimensional space. We call this an embedding space or latent space.
Once data is represented as points in space, geometry becomes meaningful:
| Geometric concept | What it means for representations | Example |
|---|---|---|
| Distance | How different two things are | "cat" and "dog" closer than "cat" and "airplane" |
| Direction | Captures relationships / attributes | king − man + woman ≈ queen (Word2Vec) |
| Cluster | Group of similar things | All animal embeddings cluster together |
| Linear separability | Whether classes can be split by a line | Good representations make classification easy |
In 2013, Mikolov et al. found that word embeddings trained on text had this property: embedding("king") − embedding("man") + embedding("woman") ≈ embedding("queen"). This wasn't programmed in — it emerged from learning good representations. This is why geometry in embedding space is so powerful.
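We can act out this arithmetic with a tiny set of hand-made embeddings. The toy vectors below are invented for illustration (dimension 0 ≈ "royalty", dimension 1 ≈ "gender") — real Word2Vec embeddings are learned from billions of words, not hand-crafted:

```python
import numpy as np

# Toy 2-dim embeddings, invented for illustration
emb = {
    "king":  np.array([1.0,  1.0]),
    "queen": np.array([1.0, -1.0]),
    "man":   np.array([0.0,  1.0]),
    "woman": np.array([0.0, -1.0]),
}

# The famous analogy: king − man + woman ≈ ?
result = emb["king"] - emb["man"] + emb["woman"]

# Find the word whose embedding is closest to the result vector
nearest = min(emb, key=lambda w: np.linalg.norm(emb[w] - result))
print(nearest)  # → queen
```

The point is not the toy numbers but the mechanism: when directions in the space encode attributes, relationships become vector arithmetic.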
Similarity = Distance in Space
The key insight: similar things should be close, dissimilar things should be far apart. This is how every modern model is trained — by shaping the embedding space so that meaningful structure is reflected in the geometry.
The Mathematics (From Scratch)
Don't worry — we're going to build this from the very bottom. No prior math assumed.
What Is a Vector?
A vector is just an ordered list of numbers. That's it. For example, v = [v₁, v₂, …, vₙ] is a vector with n components.
We say this vector lives in ℝⁿ — meaning the space of all n-dimensional real number vectors.
The Dot Product: Measuring Similarity
The most important operation on vectors. Take two vectors, multiply their components pair-by-pair, then add everything up:

a · b = a₁b₁ + a₂b₂ + ⋯ + aₙbₙ
Why does this measure similarity? Because when two vectors point in the same direction, all the products aᵢbᵢ are large and positive, so the sum is large. When they point in opposite directions, the products are negative and the sum is small or negative.
The Norm: Length of a Vector

The norm measures how long a vector is — its distance from the origin:

‖a‖ = √(a₁² + a₂² + ⋯ + aₙ²) = √(a · a)
Cosine Similarity: The Gold Standard for Representation Similarity
Raw dot products depend on the magnitude (length) of vectors — a very long vector will have a large dot product even with something dissimilar. We fix this by normalizing:

cos(a, b) = (a · b) / (‖a‖ ‖b‖)
In NLP and representation learning, we care about the direction a vector points, not its length. Two documents with similar content shouldn't look different just because one is much longer. Cosine similarity removes the length factor, leaving only directional similarity. This is why it's used everywhere — semantic search, contrastive learning, attention mechanisms.
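A quick sketch of this length-invariance claim, using made-up keyword-count vectors: tripling a document's counts (as if it were written three times as long) triples its dot products but leaves its cosine similarity untouched.

```python
import numpy as np

def cosine(a, b):
    """Cosine similarity: dot product of the two vectors, divided by their lengths."""
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

# Hypothetical count vectors (illustrative): doc_long is doc_short repeated 3×
doc_short = np.array([3.0, 1.0, 5.0])
doc_long  = 3 * doc_short
other     = np.array([1.0, 4.0, 0.0])

print(np.dot(doc_short, other), np.dot(doc_long, other))   # dot product triples
print(cosine(doc_short, other), cosine(doc_long, other))   # cosine is identical
```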
Euclidean Distance: How Far Apart Are Two Points?

The straight-line distance between two embeddings, viewed as points in space:

d(a, b) = ‖a − b‖ = √((a₁ − b₁)² + ⋯ + (aₙ − bₙ)²)
The Representation Function
A representation learning model is formally just a function fθ (parameterized by θ, meaning its learnable weights) that maps input x to embedding z:

z = fθ(x), where x ∈ ℝᴰ is the raw input and z ∈ ℝᵈ is its embedding (usually d ≪ D).
The goal of training is to find θ such that the function fθ produces embeddings z where similar inputs end up close together and dissimilar inputs end up far apart.
Linear vs Non-Linear Representations
Linear: PCA
The simplest representation learner is Principal Component Analysis (PCA). It finds a low-dimensional linear subspace that captures most of the variance in the data.
PCA is powerful but limited: it can only find linear structure. If your data lies on a curved manifold (a spiral, a sphere), PCA can't unfold it properly.
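A small numerical sketch of this limitation, on synthetic data. PCA's principal directions can be read off the singular values of the centered data (an assumption-free way to compute it with NumPy alone): for points near a line, one component captures nearly all the variance; for points on a circle — the simplest curved manifold — no single direction does.

```python
import numpy as np

rng = np.random.default_rng(0)

def first_pc_variance_ratio(X):
    """Fraction of total variance captured by the first principal component."""
    Xc = X - X.mean(axis=0)                      # center the data
    s = np.linalg.svd(Xc, compute_uv=False)      # singular values, descending
    return s[0] ** 2 / np.sum(s ** 2)

# Linear structure: points near the line y = 2x → one direction suffices
t = rng.normal(size=200)
line = np.column_stack([t, 2 * t + 0.05 * rng.normal(size=200)])

# Curved structure: points on the unit circle → variance is split between axes
theta = rng.uniform(0, 2 * np.pi, size=200)
circle = np.column_stack([np.cos(theta), np.sin(theta)])

print(f"line:   {first_pc_variance_ratio(line):.3f}")    # close to 1.0
print(f"circle: {first_pc_variance_ratio(circle):.3f}")  # close to 0.5
```

The circle is intrinsically 1-dimensional (it's parameterized by a single angle θ), but PCA can't see that, because no flat 1-dimensional subspace contains it.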
Non-Linear: Autoencoders and Neural Nets
Neural networks can learn non-linear representations by stacking layers with non-linear activation functions. Each layer transforms the representation, gradually unfolding the complex manifold into something linearly separable.
The autoencoder is forced to compress the input into a small bottleneck z, then reconstruct it. To do this well, z must capture the essential structure of the data — throwing away noise, keeping meaning. This bottleneck is the learned representation.
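A minimal autoencoder sketch. All dimensions and training settings here are illustrative, and the synthetic data is built to secretly live on a 2-dimensional subspace of ℝ⁸, so a 2-dimensional bottleneck is enough to reconstruct it:

```python
import torch
import torch.nn as nn

class TinyAutoencoder(nn.Module):
    """Compress 8-dim inputs through a 2-dim bottleneck, then reconstruct."""
    def __init__(self, input_dim=8, bottleneck_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 16), nn.ReLU(),
            nn.Linear(16, bottleneck_dim),      # the bottleneck z
        )
        self.decoder = nn.Sequential(
            nn.Linear(bottleneck_dim, 16), nn.ReLU(),
            nn.Linear(16, input_dim),           # reconstruct the input
        )

    def forward(self, x):
        z = self.encoder(x)         # z is the learned representation
        return self.decoder(z), z

torch.manual_seed(0)

# Synthetic data: 256 points that secretly live on a 2-dim subspace of ℝ⁸
basis = torch.randn(2, 8)
data = torch.randn(256, 2) @ basis

model = TinyAutoencoder()
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

for step in range(500):
    recon, z = model(data)
    loss = nn.functional.mse_loss(recon, data)  # reconstruction error
    opt.zero_grad()
    loss.backward()
    opt.step()

print(f"final reconstruction loss: {loss.item():.4f}")  # small: 2 dims were enough
print(f"bottleneck shape: {tuple(z.shape)}")            # (256, 2)
```

Because the data really only varies along 2 directions, the 2-dim bottleneck can reconstruct it well — the compression forced the network to find that hidden structure.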
What Makes a Representation "Good"?
This is actually a research-level question with no single answer. But there are properties that most researchers agree on:
| Property | What it means | Why it matters |
|---|---|---|
| Disentangled | Each dimension captures one independent factor (e.g., one axis = shape, another = color) | Easier to interpret, generalize, and manipulate |
| Compact | Low-dimensional — only uses the dimensions it needs | Avoids the curse of dimensionality |
| Invariant | Doesn't change under irrelevant transformations (e.g., rotating an image shouldn't change "cat" embedding) | Robustness, generalization |
| Transferable | Learned on one task, useful for other tasks | This is why pretrained models are so powerful |
| Linearly separable | Classes can be separated with a simple linear classifier on top | Proxy measure: if a linear probe works well, the representation is rich |
People often think "more dimensions = better representation." This is false. Too many dimensions leads to the curse of dimensionality — distances become meaningless, and models overfit. Good representations are as compact as possible while preserving the information needed for the task.
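The "linear probe" idea from the table above can be sketched directly. The embeddings below are synthetic clusters standing in for a frozen encoder's outputs, and the probe is a tiny logistic regression trained by plain gradient descent (all names and numbers here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic "embeddings": two well-separated clusters, standing in for
# the outputs of a pretrained encoder on cat vs. dog images.
n = 200
cats = rng.normal(loc=[+2.0, 0.0], scale=0.5, size=(n, 2))
dogs = rng.normal(loc=[-2.0, 0.0], scale=0.5, size=(n, 2))
X = np.vstack([cats, dogs])
y = np.array([1] * n + [0] * n)

# Linear probe: logistic regression via full-batch gradient descent
w, b = np.zeros(2), 0.0
for _ in range(200):
    p = 1 / (1 + np.exp(-(X @ w + b)))   # predicted probability of class 1
    w -= 1.0 * (X.T @ (p - y)) / len(y)  # gradient step on weights
    b -= 1.0 * np.mean(p - y)            # gradient step on bias

pred = (1 / (1 + np.exp(-(X @ w + b)))) > 0.5
accuracy = np.mean(pred == y)
print(f"linear probe accuracy: {accuracy:.2f}")  # near 1.00 on separable clusters
```

If a probe this simple scores high, the representation has already done the hard work — the classes are linearly separable in embedding space.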
Code: See It In Action
Let's implement the core concepts from scratch in Python. No fancy libraries needed — just NumPy.
```python
import numpy as np

# ── 1. Create some simple representations ────────────────────────
# Represent fruits as [redness, sweetness, size]
apple      = np.array([0.9, 0.7, 0.5])
strawberry = np.array([0.95, 0.8, 0.1])
banana     = np.array([0.1, 0.6, 0.6])
lemon      = np.array([0.05, 0.2, 0.4])

# ── 2. Dot product ───────────────────────────────────────────────
def dot_product(a, b):
    """Multiply element-wise and sum. Measures raw similarity."""
    return np.sum(a * b)

print("Apple · Strawberry:", dot_product(apple, strawberry))  # large (similar)
print("Apple · Lemon:     ", dot_product(apple, lemon))       # smaller

# ── 3. L2 Norm (vector length) ───────────────────────────────────
def norm(v):
    """Euclidean length of a vector: sqrt(x1² + x2² + ...)"""
    return np.sqrt(np.sum(v ** 2))

print("\nApple norm: ", norm(apple))

# ── 4. Cosine Similarity ─────────────────────────────────────────
def cosine_similarity(a, b):
    """
    Normalize both vectors, then take dot product.
    Result always in [-1, +1].
    +1 = identical direction, 0 = perpendicular, -1 = opposite
    """
    return dot_product(a, b) / (norm(a) * norm(b))

print("\nCosine Similarities:")
print(f"Apple ↔ Strawberry: {cosine_similarity(apple, strawberry):.4f}")  # should be high
print(f"Apple ↔ Banana:     {cosine_similarity(apple, banana):.4f}")
print(f"Apple ↔ Lemon:      {cosine_similarity(apple, lemon):.4f}")       # should be low

# ── 5. Euclidean Distance ────────────────────────────────────────
def euclidean_distance(a, b):
    """Straight-line distance between two points in embedding space."""
    return norm(a - b)

print("\nEuclidean Distances from Apple:")
print(f"Apple → Strawberry: {euclidean_distance(apple, strawberry):.4f}")  # close
print(f"Apple → Lemon:      {euclidean_distance(apple, lemon):.4f}")       # far

# ── 6. Simple Neural Network Encoder (representation learner) ────
import torch
import torch.nn as nn

class SimpleEncoder(nn.Module):
    """
    A minimal encoder: maps high-dim input to low-dim representation.
    This is the core of every representation learning model.
    """
    def __init__(self, input_dim, embed_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),   # layer 1: compress
            nn.ReLU(),                   # non-linearity
            nn.Linear(128, embed_dim),   # layer 2: compress more
        )

    def forward(self, x):
        z = self.encoder(x)                     # z is the representation
        z = nn.functional.normalize(z, dim=-1)  # unit-normalize for cosine sim
        return z

# Example: 784-dim input (e.g., MNIST image) → 32-dim representation
encoder = SimpleEncoder(input_dim=784, embed_dim=32)

x1 = torch.randn(1, 784)   # fake image 1
x2 = torch.randn(1, 784)   # fake image 2

z1 = encoder(x1)           # shape: [1, 32] — the learned representation
z2 = encoder(x2)

sim = (z1 * z2).sum()      # cosine similarity (vectors are normalized)

print(f"\nRandom encoder similarity: {sim.item():.4f}")
print(f"Embedding shape: {z1.shape}")  # [1, 32]
print("This is the representation. 784 numbers → 32 numbers.")
```
The SimpleEncoder class above has the exact same structure as the encoder in BERT, CLIP, SimCLR — they're all just deeper, wider versions of the same idea: map input to a smaller vector that captures meaning. The sophistication comes from how you train this encoder (what loss function you use), not its architecture.
Blog Post Summary
- A representation is a list of numbers (vector) that describes a real-world object in a way a machine can reason about mathematically.
- The goal of representation learning is to find a function fθ(x) = z that maps raw inputs to vectors where similar things are close and dissimilar things are far apart in vector space.
- The Manifold Hypothesis says that real data (images, text, graphs) lives on a low-dimensional curved surface inside a high-dimensional space. Good representations "unfold" this surface.
- Key math: dot product (raw similarity), L2 norm (vector length), cosine similarity (normalized, directional similarity = the gold standard), Euclidean distance (geometric separation).
- Linear representations (PCA) find flat subspaces. Non-linear representations (neural networks) can unfold curved manifolds — this is their superpower.
- A good representation is disentangled (each dimension = one factor), compact, invariant to irrelevant transformations, and transferable to new tasks.
- Every major model (BERT, GPT, CLIP, SimCLR) is fundamentally the same structure: an encoder fθ that maps inputs to a useful vector space.