Model2Vec, A Static-Embedding Distillation Algorithm, Explained

How to compress a sentence-transformer into a static lookup table that runs on CPU in three lines of NumPy. A re-implementation guide covering the full pipeline: vocab prep, per-token forward pass, PCA, SIF/Zipf weighting, optional vocab quantization, and dtype quantization, with a ~60-line reference implementation.

A static-embedding distillation algorithm, for ML engineers who want to re-implement it.

Core intuition & why it works

A sentence-transformer $f_\theta$ maps a token sequence $\mathbf{x} = (x_1,\dots,x_n)$ to a vector

\[\mathbf{e}(\mathbf{x}) \;=\; \text{Pool}\big(f_\theta(\mathbf{x})\big) \in \mathbb{R}^d.\]

The model is contextual: each token’s hidden state depends on its neighbours through attention. Self-attention costs $\mathcal{O}(n^2 d)$ per layer, so a forward pass through $L$ layers costs $\mathcal{O}(L n^2 d)$ per sentence, and you need a GPU to make it pleasant.

Model2Vec’s claim: for many downstream tasks (classification, retrieval, clustering), a sentence embedding is well-approximated by a weighted average of context-free token embeddings:

\[\mathbf{e}(\mathbf{x}) \;\approx\; \hat{\mathbf{e}}(\mathbf{x}) \;=\; \frac{1}{n}\sum_{i=1}^{n} \mathbf{E}[x_i],\]
where $\mathbf{E} \in \mathbb{R}^{|V| \times d'}$ is a static lookup table, one vector per vocabulary entry, with any per-token weights already folded into its rows. Inference becomes gather + mean, i.e. $\mathcal{O}(nd')$: no attention, no torch, CPU-friendly.

Three observations make this work better than you’d naively expect:

  1. Modern subword tokenizers carry most of the compositionality. BPE/WordPiece/Unigram split rare words into meaningful pieces, so a token-level table already encodes morphology.
  2. Mean-pooled transformer embeddings live in a near-isotropic subspace that is close to the mean of the model’s input embeddings, plus a learned correction. Distilling that correction into the table recovers most of the signal.
  3. Post-hoc whitening (PCA) and frequency down-weighting (SIF, Arora et al. 2017) close the remaining gap. These are the same tricks classical sentence embeddings used; they work here too.

Mental model. We’re asking the teacher: “if this single token were a sentence on its own, what would your output be?” We do that once per token, cache the answer, and average at query time. PCA + SIF clean up the cache.

Pipeline overview

Figure: the whole algorithm. Vocab quantization (step 5) is optional; every other step is mandatory.

  Input: teacher encoder f_θ, tokenizer T
  1. Vocab prep: clean + prefix-space
  2. Per-token forward: E_v = Pool(f_θ(v))
  3. PCA: d → d', decorrelate
  4. SIF / Zipf weights: w_i = a/(a+p_i)
  5. (optional) Vocab quantization: KMeans → centroids
  6. Dtype quantization: fp32 → fp16 / int8
  Output: StaticModel (E, w, map)
  Inference: tokenize → gather rows → (apply w) → mean-pool → L2-normalize

Step 1: Vocabulary preparation

Start from the teacher’s tokenizer $T$ with vocabulary $V_T$. We produce a cleaned vocabulary $V$ on which we’ll generate static vectors.

Implementation note. You don’t need skeletoken to reproduce this; any tokenizer library that lets you mutate the vocab will do. The substantive transformations are: regex-filter, set the prefix-space flag, de-dup.
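The regex-filter and de-dup transformations can be sketched in plain Python. This is a minimal illustration, not the library’s code: `clean_vocab`, the `UNUSED` pattern, and the whitespace-stripping `normalize` default are all hypothetical names and choices, and the prefix-space flag is left out because it is tokenizer-specific.

```python
import re

# Hypothetical filter: drop BERT-style placeholder tokens such as "[unused12]".
UNUSED = re.compile(r"^\[unused\d+\]$")

def clean_vocab(tokens, normalize=str.strip):
    """Filter placeholder tokens and de-duplicate on a normalized form.

    tokens: token strings in id order.
    normalize: assumed normalization used only to detect duplicates.
    """
    seen = set()
    kept = []
    for token in tokens:
        if UNUSED.match(token):
            continue                      # regex-filter
        key = normalize(token)
        if not key or key in seen:
            continue                      # de-dup (and drop empty tokens)
        seen.add(key)
        kept.append(token)
    return kept

cleaned = clean_vocab(["[PAD]", "[unused0]", "the", " the", "cat"])
```

The first occurrence of a duplicate wins, which keeps the lower (more frequent) id when the input is in id order.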

Step 2: Per-token forward pass

For each token $v \in V$, construct the single-token input $\mathbf{x}_v$ (just the token id, no special tokens). Run the teacher and pool:

\[\mathbf{E}_v \;=\; \text{Pool}\big(f_\theta(\mathbf{x}_v)\big) \in \mathbb{R}^d.\]
Stack into $\mathbf{E} \in \mathbb{R}^{|V| \times d}$.

Pooling choice depends on the teacher’s training objective:

| Mode | Formula | Use when |
|---|---|---|
| MEAN | $\frac{1}{n}\sum_i h_i$ | BERT-style encoder trained with mean-pooling (most sentence-transformers). |
| FIRST | $h_{[\text{CLS}]}$ | BERT with CLS-token training. |
| LAST | $h_n$ (last non-pad) | Decoder-style LLMs (Qwen-embed, LLM2Vec). |
| POOLER | pooler_output | Models that expose a trained pooler head. |
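The first three modes can be sketched in NumPy on a batch of hidden states. This is an illustrative helper, not the library’s API: `pool` and its argument names are assumptions, the LAST branch assumes right-padding, and POOLER is omitted because it needs the teacher’s own head.

```python
import numpy as np

def pool(hidden, mask, mode="mean"):
    """hidden: (B, n, d) hidden states; mask: (B, n), 1 for real tokens, 0 for padding."""
    if mode == "mean":
        m = mask[:, :, None].astype(hidden.dtype)
        # masked mean: padded positions contribute nothing
        return (hidden * m).sum(axis=1) / np.maximum(m.sum(axis=1), 1.0)
    if mode == "first":
        return hidden[:, 0]
    if mode == "last":
        last = mask.sum(axis=1).astype(int) - 1     # assumes right-padding
        return hidden[np.arange(hidden.shape[0]), last]
    raise ValueError(f"unknown pooling mode: {mode}")

hidden = np.arange(12.0).reshape(2, 3, 2)
mask = np.array([[1, 1, 0], [1, 1, 1]])
mean_vec = pool(hidden, mask, "mean")
first_vec = pool(hidden, mask, "first")
last_vec = pool(hidden, mask, "last")
```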

Practical tip: batch by length-sorting and pad; you’ll save 5–10× compute vs. naïve batching because token lengths in $V$ are highly skewed.
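A minimal sketch of length-sorted batching, assuming each vocabulary entry has already been converted to a list of ids of varying length. The function name and `pad_id` default are hypothetical; the returned indices let you scatter results back into the original order.

```python
def length_sorted_batches(seqs, batch_size, pad_id=0):
    """Yield (original_indices, padded_batch), grouping sequences of similar length."""
    order = sorted(range(len(seqs)), key=lambda i: len(seqs[i]))
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        width = max(len(seqs[i]) for i in idx)
        # pad only up to the longest sequence in this batch
        padded = [seqs[i] + [pad_id] * (width - len(seqs[i])) for i in idx]
        yield idx, padded

seqs = [[1, 2, 3], [4], [5, 6], [7]]
batches = list(length_sorted_batches(seqs, batch_size=2))
padded_cells = sum(len(row) for _, batch in batches for row in batch)
```

On this toy input the sorted batches hold 8 padded cells in total, against 10 when batching in the original order.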

Step 3: PCA

Compute the principal components of $\mathbf{E}$ and project:

\[\mathbf{E}' \;=\; (\mathbf{E} - \boldsymbol{\mu})\,\mathbf{W}_{\text{PCA}}, \quad \mathbf{W}_{\text{PCA}} \in \mathbb{R}^{d \times d'}.\]

Two cases:

Why PCA is the biggest quality lever. The teacher’s hidden space has a “dominant direction”: most tokens cluster along a single principal axis (the “rogue dimension” phenomenon). Subtracting the mean and rotating to the PC basis spreads the cloud out so that cosine distances are informative again.
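The projection $(\mathbf{E} - \boldsymbol{\mu})\,\mathbf{W}_{\text{PCA}}$ can be sketched with a plain SVD. This is an illustration under the assumption that no per-component variance rescaling is applied, matching the formula above; `pca_project` is a hypothetical name.

```python
import numpy as np

def pca_project(E, dims):
    """Center E and project onto its top `dims` principal directions."""
    mu = E.mean(axis=0, keepdims=True)
    X = E - mu
    # rows of Vt are the principal directions, sorted by singular value
    _, _, Vt = np.linalg.svd(X, full_matrices=False)
    W = Vt[:dims].T                      # (d, d')
    return X @ W, mu, W

rng = np.random.default_rng(0)
E = rng.normal(size=(50, 8)) + 5.0       # toy table with a large shared offset
Z, mu, W = pca_project(E, 3)
```

The projected columns are zero-mean and mutually uncorrelated, which is the "decorrelate" part of the step.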

Step 4: SIF / Zipf weighting

We don’t have a corpus to estimate token frequencies, but the tokenizer’s token id ordering already approximates frequency rank for BPE/WordPiece (more frequent → earlier merge → lower id). So we approximate the distribution by Zipf’s law on rank $r$:

\[p_r \;=\; \frac{1/(r+1)}{\sum_{k=1}^{|V|} 1/(k+1)}.\]

Then apply Arora et al.’s SIF (smooth inverse frequency) weight with hyperparameter $a$ (default $10^{-4}$):

\[w_r \;=\; \frac{a}{a + p_r}.\]

Multiply each row: $\mathbf{E}'_v \leftarrow w_{\text{rank}(v)} \cdot \mathbf{E}'_v$.

Figure: $p_r$ (Zipf) and $w_r = a/(a+p_r)$ plotted against rank $r$ (low rank = frequent stopwords). SIF down-weights the head of frequent tokens; the product $w_r \cdot \mathbf{E}'_v$ suppresses stopword contributions to sentence vectors.

Detail worth implementing carefully. The original Model2Vec uses inv_rank = 1 / np.arange(2, n+2), so the first token gets an unnormalized inverse rank of $1/2$, not $1$. With 0-indexed token ids, $1/r$ would be singular at $r=0$; the offset avoids that and gives the head of the distribution a gentler slope. Don’t use $1/r$ directly.
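A small numeric sketch of the weighting, using the same offset inverse rank. `sif_weights` is a hypothetical helper name; the only inputs are the vocabulary size and $a$.

```python
import numpy as np

def sif_weights(n_tokens, a=1e-4):
    """Zipf-approximated SIF weights, one per rank (row 0 = most frequent)."""
    inv_rank = 1.0 / np.arange(2, n_tokens + 2)   # 1/2, 1/3, ..., no 1/0
    p = inv_rank / inv_rank.sum()                 # normalized Zipf probabilities
    return a / (a + p)

w = sif_weights(1000)
w_small = sif_weights(3, a=1.0)   # p = [6/13, 4/13, 3/13]
```

Weights grow monotonically with rank, so frequent tokens at the head are scaled down the most.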

Step 5: Vocabulary quantization (optional)

When $|V|$ is large (say, 250k for multilingual models), the embedding table dominates model size. Replace it with $K \ll |V|$ centroids via KMeans:
\[\text{KMeans}(\mathbf{E}'_\text{norm},\,K) \;\to\; \{\mathbf{c}_1,\dots,\mathbf{c}_K\},\;\; \pi: V \to \{1,\dots,K\}.\]

Two implementation subtleties:

  1. Cluster on direction, not magnitude. Normalize rows to unit norm before fitting; store the original norm $\|\mathbf{E}'_v\|$ as a per-token weight $w_v$ alongside the centroids. This preserves the “this token is a strong signal” information that magnitudes carry.
  2. Apply PCA after clustering in this branch. The order matters: PCA on $K$ centroids is cheaper, and the SIF weights are no longer mixed into the rows; they’re kept separately as per-token multipliers.
The final stored objects are: centroid matrix $\mathbf{C} \in \mathbb{R}^{K \times d'}$, token-to-centroid map $\pi \in \mathbb{Z}^{|V|}$, per-token weight $\mathbf{w} \in \mathbb{R}^{|V|}$.
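A toy sketch of the direction-only clustering, with a few iterations of plain Lloyd’s k-means in NumPy standing in for a library KMeans. Everything here is an assumption for illustration: the function name, the iteration count, and the random initialization. The dense distance matrix is only practical for small tables.

```python
import numpy as np

def quantize_vocab(E, k, iters=20, seed=0):
    """Cluster row directions into k centroids; keep row norms as per-token weights."""
    norms = np.linalg.norm(E, axis=1)
    unit = E / np.maximum(norms, 1e-12)[:, None]   # cluster on direction only

    def nearest(centroids):
        # squared Euclidean distance from every row to every centroid
        d2 = ((unit[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        return d2.argmin(axis=1)

    rng = np.random.default_rng(seed)
    centroids = unit[rng.choice(len(unit), size=k, replace=False)].copy()
    for _ in range(iters):
        assign = nearest(centroids)
        for j in range(k):
            members = unit[assign == j]
            if len(members):                       # leave empty clusters unchanged
                centroids[j] = members.mean(axis=0)
    return centroids, nearest(centroids), norms

rng = np.random.default_rng(0)
E = rng.normal(size=(40, 4))
C, pi, w = quantize_vocab(E, k=5)
approx = w[:, None] * C[pi]    # reconstructed table: weight times centroid
```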

Step 6: Dtype quantization

Cast the final embedding matrix from float32 to a smaller dtype. Default fp16 halves memory at essentially no quality cost; int8 (with per-row min/max scaling) is also supported. This is the last step because PCA and clustering both want fp32 numerics.

| Dtype | Bytes / weight | Quality impact |
|---|---|---|
| float32 | 4 | Baseline |
| float16 | 2 | Negligible |
| int8 | 1 | Small (<1% on MTEB) |
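One way to realize per-row min/max int8 scaling is sketched below. It is not necessarily the scheme the reference library uses; the function names and the choice to store `scale` and `lo` as float32 side arrays are assumptions.

```python
import numpy as np

def quantize_int8(E):
    """Map each row's [min, max] range linearly onto the int8 range [-128, 127]."""
    lo = E.min(axis=1, keepdims=True)
    hi = E.max(axis=1, keepdims=True)
    scale = np.maximum(hi - lo, 1e-12) / 255.0
    q = np.round((E - lo) / scale) - 128
    return q.astype(np.int8), scale.astype(np.float32), lo.astype(np.float32)

def dequantize_int8(q, scale, lo):
    """Invert quantize_int8 up to rounding error."""
    return (q.astype(np.float32) + 128.0) * scale + lo

rng = np.random.default_rng(0)
E = rng.normal(size=(10, 16)).astype(np.float32)
q, scale, lo = quantize_int8(E)
err = np.abs(dequantize_int8(q, scale, lo) - E)
```

The round-trip error per entry is bounded by roughly half a quantization step of its row.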

Inference path

The whole point: at query time, encoding a sentence $\mathbf{x}=(x_1,\dots,x_n)$ is

\[\hat{\mathbf{e}}(\mathbf{x}) \;=\; \mathcal{N}\!\Bigg(\frac{1}{n}\sum_{i=1}^{n} w_{x_i} \cdot \mathbf{E}_{\pi(x_i)}\Bigg),\]

where $\mathcal{N}(\cdot)$ is optional L2-normalization, $\pi$ is the identity if there is no vocab quantization (with it, $\mathbf{E}$ is the centroid matrix $\mathbf{C}$), and $w_{x_i}=1$ if SIF weights were folded into the rows during distillation.

That’s three lines of NumPy. No torch, no transformers, just the tokenizer (one HF dependency) and a matrix.
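A sketch of the formula above on already-tokenized input, with the optional $w$ and $\pi$ made explicit. `encode_ids` and its argument names are hypothetical; tokenization is assumed to have happened elsewhere.

```python
import numpy as np

def encode_ids(E, ids, w=None, pi=None):
    """Gather rows, optionally weight them, mean-pool, L2-normalize."""
    ids = np.asarray(ids)
    rows = E[ids if pi is None else pi[ids]]       # gather (via centroid map if given)
    if w is not None:
        rows = rows * w[ids][:, None]              # per-token weights
    v = rows.mean(axis=0)                          # mean-pool
    return v / (np.linalg.norm(v) + 1e-12)         # L2-normalize

E = np.eye(3)
plain = encode_ids(E, [0, 1])
quantized = encode_ids(E, [0, 1], w=np.array([1.0, 3.0, 1.0]), pi=np.array([2, 2, 0]))
```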

Reference implementation (≈60 lines)

Distillation, end-to-end. Error handling is stripped for clarity.

import numpy as np
import torch
from sklearn.decomposition import PCA
from transformers import AutoModel, AutoTokenizer

def distill(model_name: str, pca_dims: int = 256, sif_a: float = 1e-4,
            quantize_to=np.float16) -> dict:
    # --- Step 1: vocab prep ---
    tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModel.from_pretrained(model_name).eval()
    vocab = tok.get_vocab()                                 # {token_str: id}
    tokens = sorted(vocab, key=vocab.get)                   # ordered by id
    keep = [t for t in tokens if not t.startswith("[unused")]
    ids  = [vocab[t] for t in keep]

    # --- Step 2: one forward pass per token ---
    # Length-sort and batch; here we use trivial batches for clarity.
    E = np.zeros((len(ids), model.config.hidden_size), dtype=np.float32)
    with torch.inference_mode():
        for i in range(0, len(ids), 256):
            batch = ids[i:i+256]
            x = torch.tensor([[tid] for tid in batch])      # shape (B, 1)
            mask = torch.ones_like(x)
            out = model(input_ids=x, attention_mask=mask).last_hidden_state
            # mean pool (B,1,d) -> (B,d)
            E[i:i+len(batch)] = out.mean(dim=1).float().numpy()

    # --- Step 3: PCA ---
    pca = PCA(n_components=pca_dims, svd_solver="full")
    E = pca.fit_transform(E)                                # (|V|, d')

    # --- Step 4: SIF / Zipf weights ---
    rank = np.arange(2, len(E) + 2)                         # avoid 1/0 at r=0
    inv  = 1.0 / rank
    p    = inv / inv.sum()
    w    = sif_a / (sif_a + p)                              # (|V|,)
    E    = E * w[:, None]                                   # fold into rows

    # --- Step 6: dtype quantize (skipping Step 5 for brevity) ---
    E = E.astype(quantize_to)

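    # Rows of E follow `keep`, not the original ids, so store an id -> row map.
    # Ids filtered out in Step 1 map to -1.
    row_of_id = np.full(max(vocab.values()) + 1, -1, dtype=np.int64)
    row_of_id[ids] = np.arange(len(ids))
    return {"embedding": E, "tokens": keep, "tokenizer": tok,
            "row_of_id": row_of_id}


def encode(static: dict, sentences: list[str]) -> np.ndarray:
    tok, E = static["tokenizer"], static["embedding"].astype(np.float32)
    row_of_id = static["row_of_id"]
    out = []
    for s in sentences:
        ids = tok(s, add_special_tokens=False)["input_ids"]
        rows = row_of_id[np.asarray(ids, dtype=np.int64)]
        rows = rows[rows >= 0]                              # drop filtered tokens
        v = E[rows].mean(axis=0) if rows.size else np.zeros(E.shape[1], dtype=np.float32)
        out.append(v / (np.linalg.norm(v) + 1e-12))
    return np.stack(out)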

That’s it: roughly 40 lines for distill and 10 for encode. To match the reference library, add:

Adaptation notes

Quality knobs, ranked by impact

  1. PCA: biggest single win, especially the rotation effect.
  2. SIF / Zipf weighting: meaningful gain on retrieval-like tasks.
  3. Vocabulary choice: domain vocabularies (code, medical) can outperform generic ones on in-domain tasks.
  4. Pooling mode: match the teacher’s training objective; getting this wrong is silent quality loss.
  5. Dtype: fp16 is free; int8 has a small cost.

Adapting to other domains / modalities

What Model2Vec is not good at. Anything that requires word-order or syntactic information: NLI, certain reranking tasks, long-range coreference. Bag-of-(weighted)-embeddings throws away order by construction. If your downstream task is order-sensitive, this technique caps your quality.
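The order-blindness is easy to see on a toy table: any permutation of the same token ids yields the same vector. The table and ids below are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
E = rng.normal(size=(10, 4))          # toy static table with 10 tokens

def embed(ids):
    """Mean of static rows; addition is commutative, so order cannot matter."""
    return E[ids].mean(axis=0)

a = embed([1, 5, 7])                  # think "dog bites man"
b = embed([7, 5, 1])                  # think "man bites dog"
same = bool(np.allclose(a, b))
```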


Reference: github.com/MinishLab/model2vec · SIF: Arora, Liang, Ma 2017 · Anisotropy: Ethayarajh 2019