Nitrobrew: Fast, Lossless Distillation for Free

4.28.2026
Dhruv Pai*,  Timor Averbuch*,  Alec Dewulf*,  Ben Keigwin,  Ashley Zhang
* Core Contributor; Correspondence to dhruv@tilderesearch.com

TL;DR

  • Distillation, especially on-policy distillation, has become a crucial component in reasoning model post-training workflows.
  • Logit distillation at modern vocabulary sizes is bottlenecked by communication and memory, not compute.
  • Nitrobrew exploits the fact that teacher logits are generated from a much lower-dimensional hidden state through the unembedding matrix.
    • Sends hidden states instead of logits (up to 60× less communication)
    • Computes divergence online without materializing the full logit tensor (37× less memory).
  • Nitrobrew is fully lossless and exact, unlike top-k style approximations.
  • For NemoRL and VeRL, Nitrobrew achieves 1.5–3× faster end-to-end throughput for on-policy distillation.
  • We open-source Nitrobrew and a VeRL PR for support.
  • We also investigate spectral compression approaches for even stronger savings.

Nitrobrew Overview

Note: Shortly before this post went live, DeepSeek-V4 independently reported the same core idea for full-vocabulary OPD: caching teacher hidden states and reconstructing logits on the fly. Nitrobrew was developed independently and differs in several respects: open-source framework integrations (NeMo RL, VeRL), explicit online divergence algorithms for forward/reverse KL and JSD, detailed profiling across model scales, and spectral compression experiments (SVD-Nitrobrew) for further reducing communication below $d_{\text{model}}$.

Introduction

In distillation, a student model is trained to reproduce the behaviour of a more capable teacher model. The teacher learns a compressed representation of the data distribution, which is then transferred to the student through supervision on soft labels, logits, or hidden states [1] [2]. Soft teacher outputs often provide richer supervision than hard labels because they expose the teacher's learned approximation to the noisy data-generating process, including its uncertainty, calibration, and relative preferences among alternatives [1].

Distillation methods can broadly be classified along two axes: data source (off-policy or on-policy) and training objective (token-level or sequence-level [3]). We will mostly focus on the first axis, which usually has a larger impact on the resulting implementation complexity.

Off-policy distillation trains the student on a fixed dataset of teacher outputs. Logit distillation typically minimizes a KL divergence between the teacher and student distributions, while sequence-level distillation uses cross-entropy on teacher-generated target sequences. Implementing off-policy distillation usually requires only small modifications to standard training pipelines, but full-logit losses can be prohibitively expensive for modern vocabulary sizes and model scales [4] [5].

Figure: Off-policy vs on-policy distillation. On-policy distillation leverages student-generated rollouts.

On-policy distillation (OPD) instead trains the student on its own rollouts [6] [7]. Because the rollout distribution depends on the current student, fresh data must be generated at every step rather than precomputed, and student generation and teacher scoring must alternate throughout training.

$$
\begin{aligned}
&\textbf{Algorithm: General On-Policy Distillation} \\
&\textbf{Require: } \text{Student policy } \pi_S, \text{ teacher policy } \pi_T, \text{ prompt dataset } \mathcal{D} \\
&\textbf{Require: } \text{Distillation loss } d(\cdot,\cdot), \text{ e.g.\ KL between softmax distributions} \\
&\textbf{for } x \sim \mathcal{D} \textbf{ do} \\
&\quad y_S \sim \pi_S(\cdot \mid x) \quad \triangleright \text{Sample student rollout} \\
&\quad z_S \gets \pi_S(x, y_S) \quad \triangleright \text{Student logits along sampled rollout} \\
&\quad z_T \gets \pi_T(x, y_S) \quad \triangleright \text{Teacher logits on same rollout} \\
&\quad \ell_x \gets d(z_T, z_S) \quad \triangleright \text{Distillation loss} \\
&\quad \text{Update } \pi_S \text{ using } \nabla \ell_x \\
&\textbf{end for}
\end{aligned}
$$
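To make the loop above concrete, here is a minimal toy sketch in plain Python. Everything in it is an illustrative assumption rather than part of Nitrobrew: the "models" are hypothetical bigram tables of next-token logits, the divergence is taken as KL from the student to the teacher, and the update is plain gradient descent on the student logits with a hand-derived gradient.

```python
import math
import random


def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def kl_and_grad(z_student, z_teacher):
    """KL(p_S || p_T) and its gradient with respect to the student logits."""
    p, q = softmax(z_student), softmax(z_teacher)
    log_ratio = [math.log(pi / qi) for pi, qi in zip(p, q)]
    kl = sum(pi * r for pi, r in zip(p, log_ratio))
    grad = [pi * (r - kl) for pi, r in zip(p, log_ratio)]
    return kl, grad


def total_kl(student, teacher):
    return sum(kl_and_grad(s, t)[0] for s, t in zip(student, teacher))


rng = random.Random(0)
VOCAB, ROLLOUT_LEN, STEPS, LR = 4, 8, 300, 0.2  # hypothetical toy settings

# Hypothetical toy "models": one row of next-token logits per previous token.
teacher = [[rng.gauss(0, 1) for _ in range(VOCAB)] for _ in range(VOCAB)]
student = [[rng.gauss(0, 1) for _ in range(VOCAB)] for _ in range(VOCAB)]
kl_before = total_kl(student, teacher)

for step in range(STEPS):
    prev = 0  # a fixed start token plays the role of the prompt x
    for pos in range(ROLLOUT_LEN):
        z_s = student[prev]              # student logits at this position
        z_t = teacher[prev]              # teacher scores the same position
        _, grad = kl_and_grad(z_s, z_t)  # gradient of the distillation loss
        # The rollout is sampled from the student, not from the teacher.
        nxt = rng.choices(range(VOCAB), weights=softmax(z_s))[0]
        student[prev] = [z - LR * g for z, g in zip(z_s, grad)]
        prev = nxt

kl_after = total_kl(student, teacher)
print(f"summed KL before: {kl_before:.4f}, after: {kl_after:.6f}")
```

The point of the sketch is only the data flow: the student chooses which positions get visited, and the teacher is queried on exactly those positions. In a real system the two logit vectors live on different devices, which is where the communication cost discussed next comes from.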

In distributed on-policy setups, the student and teacher often run on different devices because of memory constraints or parallelism strategy. This adds communication overhead: the teacher's logits $z_T$ must be sent to the same device as the student logits $z_S$ to compute the loss each step. The size of the logit vectors scales with the vocabulary size, which is often quite large (usually 128k to 200k tokens). To reduce this overhead, traditional approaches subsample logits (e.g. take the top-k logits), but this results in its own set of pathologies, which we discuss further in The Pathology of Top-k Truncation.

We introduce Nitrobrew, which addresses communication costs without introducing any of the pathologies of top-$k$, and avoids full-logit materialization in both the on- and off-policy settings. In the on-policy setting, Nitrobrew communicates the teacher's hidden state (concurrent with DeepSeek-V4, which independently adopts the same principle), which is much smaller than $z_T$. For loss computation, Nitrobrew computes teacher and student logits tile-by-tile inside a fused online divergence kernel, avoiding materialization of the full logit tensor. In the off-policy and self-distillation settings [8] [9], Nitrobrew can store cached teacher hidden states instead of full logits, reconstructing teacher logits on the fly during the divergence computation. We find that Nitrobrew improves OPD throughput by 3x for a fixed communication budget while being fully lossless. We open-source our code and have opened a PR to VeRL for Nitrobrew support.

The Pathology of Top-k Truncation

Top-$k$ is the standard lossy approach for reducing the cost of (on-policy) distillation.

Figure: Overview of existing approaches to OPD. Existing approaches are forced to make approximation errors in order to enjoy favorable compute/communication/memory tradeoffs.

The student and teacher logit tensors are compressed along the vocabulary dimension by discarding all but the $k$ largest entries. We highlight two significant problems with this approach:

The distortion is input-dependent and discontinuous. The rank-$k$ and rank-$(k+1)$ tokens may have nearly identical probabilities, but top-$k$ keeps one and completely discards the other. This discontinuity means that a token may go from providing no signal (masked by top-$k$) to providing a large signal because of a small perturbation in the hidden state $h$. Furthermore, renormalization of the kept tokens will inflate gradients when the original teacher distribution had high entropy. The student will thus receive stronger updates when the teacher is uncertain, which is undesirable.

Most of the teacher's calibration lives in the tail of the distribution, which is discarded. The long tail of low-probability tokens encodes the teacher's uncertainty and calibration, including its implicit ranking of alternatives, and soft knowledge about what not to predict [1]. The student will only learn from those tokens about which the teacher is most confident, which are probably easiest to predict.

Figure: Top-k truncation + renormalization results in a distorted distribution. It is not an unbiased estimator of the true distribution.
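Both effects are easy to reproduce on toy numbers. The sketch below uses plain Python and made-up logits (the vectors, `k`, and the vocabulary size are illustrative assumptions): it first nudges two nearly tied logits by 1e-3 so that they swap ranks across the top-k cutoff, then applies top-k to a maximum-entropy teacher to show how renormalization inflates the kept probabilities.

```python
import math


def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def topk_renorm(p, k):
    """Keep the k largest probabilities, zero the rest, renormalize."""
    keep = set(sorted(range(len(p)), key=lambda i: p[i], reverse=True)[:k])
    mass = sum(p[i] for i in keep)
    return [p[i] / mass if i in keep else 0.0 for i in range(len(p))]


# 1) Discontinuity: tokens 1 and 2 are nearly tied at ranks k and k+1.
#    Moving each logit by 1e-3 swaps which one survives the cutoff.
q_a = topk_renorm(softmax([2.0, 1.0, 1.001, 0.0]), k=2)
q_b = topk_renorm(softmax([2.0, 1.001, 1.0, 0.0]), k=2)
jump = q_b[1] - q_a[1]  # token 1 goes from fully masked to a large share

# 2) Inflation: a uniform (maximum-entropy) teacher over 1000 tokens, k = 10.
p_uniform = softmax([0.0] * 1000)
q_uniform = topk_renorm(p_uniform, k=10)
inflation = max(q_uniform) / max(p_uniform)  # kept tokens are scaled by V / k

print(f"token-1 target before/after the nudge: {q_a[1]:.3f} -> {q_b[1]:.3f}")
print(f"inflation of kept probabilities: {inflation:.1f}x")
```

In the first case a perturbation far below any meaningful logit difference moves a quarter of the target mass between two tokens. In the second, the teacher has no preference at all, yet the truncated target asserts ten confident choices.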

Previous work has identified several shortcomings of top-$k$ truncation for distillation. In particular, top-$k$ truncation results in miscalibrated student distributions, catastrophic forgetting, and a lack of generalization [10][11]. NVIDIA Minitron directly found that $k \leq 100$ resulted in a large accuracy drop, and larger $k$ still never outperformed full logits [5].

We seek an approach that addresses these issues while still achieving the communication/memory benefits of top-$k$ truncation.

Introducing Nitrobrew: Efficient, Lossless Distillation

For a vocabulary of size $V$, batch size $B$, and sequence length $T$, naive on-policy distillation requires materializing, communicating, and computing divergences over tensors of shape $B \times T \times V$. This adds three distinct overheads on top of the student and teacher forward passes: (1) communicating the teacher's logits, (2) HBM allocation for the teacher and student logit tensors, and (3) an element-wise divergence computation between the logits. Nitrobrew addresses all three of these issues: (1) by communicating hidden states instead of logits, and (2) and (3) with a fused online divergence kernel. Our approach is illustrated below.

Figure: Nitrobrew approach overview. Unlike previous approaches, Nitrobrew is both lossless and memory/compute efficient.

One of the core observations motivating Nitrobrew is that the unembedding matrix has rank far below the vocabulary size. Recall that logits $z$ are computed by taking the final layer's hidden state $h$ and applying the unembedding matrix $W_U$:

$$z = W_U \, h, \qquad W_U \in \mathbb{R}^{V \times d_{\text{model}}}, \quad h \in \mathbb{R}^{d_{\text{model}}}$$

The logit vector $z$ lives in the column space of $W_U$, which has rank at most $d_{\text{model}}$. In modern transformers, it is usually the case that $d_{\text{model}} \ll V$. For example, Qwen3.5-397B-A17B has $d_{\text{model}} = 4096$ and $V = 248{,}320$. In this case, the rank of the unembedding matrix is at most 4096, so the logits lie in a $d_{\text{model}}$-dimensional subspace of $\mathbb{R}^V$. Luckily, we already have access to a compressed version of $z$: the teacher's final hidden states.

Nitrobrew exploits this fact by communicating $h$ and reconstructing $z = W_U^{(T)} h$ locally using a copy of the teacher's unembedding, $W_U^{(T)}$. It reconstructs logits directly on the device where the loss is computed, adding no asymptotic computation over the standard full-logit loss, while avoiding communication of the full $V$-dimensional teacher logit vector. For Qwen3.5-397B-A17B, this yields a 60x reduction in communication overhead.
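The exchange is simple enough to mimic in a few lines of plain Python. The sketch below is illustrative only: the sizes, the random weights, and the helper `unembed` are hypothetical, and "shipping" is modeled as copying a list. It shows that rebuilding logits from a shipped hidden state gives the same vector the teacher would have sent, and recomputes the float-count ratio for the sizes quoted above.

```python
import random


def unembed(W, h):
    """Logits z = W h for one position (W is V x d, h has d entries)."""
    return [sum(w * x for w, x in zip(row, h)) for row in W]


rng = random.Random(0)
V_TOY, D_TOY = 16, 4  # toy sizes, chosen only to keep the example small
W_teacher = [[rng.gauss(0, 1) for _ in range(D_TOY)] for _ in range(V_TOY)]
h_teacher = [rng.gauss(0, 1) for _ in range(D_TOY)]

# Naive: the teacher device unembeds and ships V floats per position.
z_shipped = unembed(W_teacher, h_teacher)

# Hidden-state transfer: ship d floats; the receiving device already holds
# a copy of the teacher's unembedding and rebuilds the logits locally.
h_shipped = list(h_teacher)
z_rebuilt = unembed(W_teacher, h_shipped)

toy_ratio = len(z_shipped) / len(h_shipped)  # floats saved in the toy setup
quoted_ratio = 248_320 / 4096                # V / d_model for the sizes in the text

print(f"identical logits: {z_rebuilt == z_shipped}")
print(f"floats per position, V / d_model: {quoted_ratio:.3f}")
```

The one-time cost is copying the teacher's unembedding to the device that computes the loss; after that, every position costs $d_{\text{model}}$ floats instead of $V$.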

This approach also provides a straightforward way to reduce the memory bottleneck introduced by the divergence computation. We fuse the teacher and student unembedding matmuls into the divergence kernel: rather than constructing the full $B \times T \times V$ logit tensors and then computing their divergence, we tile over the last dimension and process one $\text{VBLOCK}$-sized chunk at a time. Within each tile, teacher and student logits are computed on the fly, consumed by the running divergence accumulators, and then discarded. This reduces peak memory cost from $O(BTV)$ to $O(BT \cdot \text{VBLOCK})$ because it avoids ever materializing the full logit tensors. The full Nitrobrew algorithm is given below.

$$
\begin{aligned}
&\textbf{Algorithm: On-policy distillation with Nitrobrew} \\
&\textbf{Require: } \text{Student } \pi_S, \text{ teacher } \pi_T, \text{ dataset } \mathcal{D}, \text{ teacher unembed } W_U^{(T)} \\
&\textbf{Require: } \text{Optional: SVD rank } k, \text{ projection } V_k \in \mathbb{R}^{d \times k} \\
&\textbf{for } x \sim \mathcal{D} \textbf{ do} \\
&\quad y_S \gets \pi_S(x) \quad \triangleright \text{Student rollout} \\
&\quad h_T \gets \pi_T(x, y_S).\text{hidden\_states} \quad \triangleright \text{Teacher scores rollout} \\
&\quad \textbf{if } \text{SVD compression} \textbf{ then} \\
&\qquad \tilde{h}_T \gets V_k^\top h_T \quad \triangleright \text{Project to } k \text{ floats} \\
&\quad \textbf{else} \\
&\qquad \tilde{h}_T \gets h_T \\
&\quad \textbf{end if} \\
&\quad \textbf{send } \tilde{h}_T \text{ to student device} \\
&\quad h_S \gets \pi_S(x, y_S).\text{hidden\_states} \\
&\quad \ell \gets \text{Nitrobrew}(h_S, \tilde{h}_T, W_U^{(S)}, W_U^{(T)}) \quad \triangleright \text{Alg.\ 2 or 3} \\
&\quad \text{Update } \pi_S \text{ using } \nabla \ell \quad \triangleright \text{Backward via Alg.\ 4} \\
&\textbf{end for}
\end{aligned}
$$

Cost Anatomy of On-Policy Distillation

Before presenting the online kernel, it is worth understanding where the time and memory go in a single on-policy distillation step. As a concrete example, consider distilling a Qwen3-32B teacher into a Qwen3-8B student, with batch size $B = 4$, sequence length $T = 8192$, vocabulary size $V = 152{,}064$, and $d_{\text{model}} = 4096$.

The table below breaks down the three dominant costs — communication, memory, and compute — for each method. We assume bf16 for communication and fp32 for the divergence working set.

Table: Cost comparison for on-policy distillation methods.

| Aspect | Naive | Top-k (k=128) | Nitrobrew | SVD-Nitrobrew (k=512) |
|---|---|---|---|---|
| Floats sent per position | V (152k) | 2k (256)¹ | d (4096) | k (512) |
| Bytes per step | 9.4 GB | 15.8 MB | 256 MB | 32 MB |
| Divergence working set | 2×BTV (37.6 GB) | 2×BTk (32 MB)² | 2×BT·VBLK (1 GB) | 2×BT·VBLK (1 GB) |
| Full-vocabulary signal | Yes | No | Yes | Approximate |
| Offline precomputation | None | None | Copy teacher W_U to student | SVD of teacher W_U |

¹ float32 value + int32 index.
² Top-k reduces the divergence working set, but this is misleading: the full [B,T,V] logit tensor must still be materialized to compute the top-k indices.

Communication is the most variable and severe cost. The naive approach sends 9.4 GB of logits per step. At 50 GB/s inter-node bandwidth, that is 190 ms of pure transfer time. Nitrobrew sends only 256 MB, which is more than an order of magnitude less communication. Of note, the distinction matters most in the inter-node large-scale regime, where interconnects are the bottleneck. Intranode, NVLink bandwidth is more than high enough to tolerate the naive approach.

Memory is where Nitrobrew's fused kernel has the largest impact. The naive divergence computation requires both the teacher and student logit tensors to be resident simultaneously. At 37.6 GB, this alone can exceed the free HBM on an 80 GB GPU after model weights and activations are loaded. Nitrobrew reduces peak memory usage to only 1 GB.
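The Nitrobrew-side entries of the table can be reproduced with a few lines of arithmetic. The sketch below assumes binary units (1 MB read as 2^20 bytes) and a vocabulary tile of 4096, which is what the 1 GB working-set entry implies; both are assumptions on our part. Under the same assumptions the naive figures come out within a few percent of the rounded values in the table.

```python
B, T, V, D_MODEL = 4, 8192, 152_064, 4096  # example configuration from the text
K_SVD, VBLOCK = 512, 4096                  # SVD rank; tile size is an assumption
BF16_BYTES, FP32_BYTES = 2, 4
MIB, GIB = 2**20, 2**30

positions = B * T  # number of token positions per step

# Communication is assumed to be bf16, as stated in the text.
comm_bytes = {
    "naive": positions * V * BF16_BYTES,
    "nitrobrew": positions * D_MODEL * BF16_BYTES,
    "svd_nitrobrew": positions * K_SVD * BF16_BYTES,
}

# The divergence working set is assumed to be fp32, teacher plus student.
working_set_bytes = {
    "naive": 2 * positions * V * FP32_BYTES,
    "nitrobrew": 2 * positions * VBLOCK * FP32_BYTES,
}

for name, nbytes in comm_bytes.items():
    print(f"{name:>14} communication: {nbytes / MIB:10.1f} MiB")
for name, nbytes in working_set_bytes.items():
    print(f"{name:>14} working set:   {nbytes / GIB:10.2f} GiB")

comm_ratio = comm_bytes["naive"] / comm_bytes["nitrobrew"]
memory_ratio = working_set_bytes["naive"] / working_set_bytes["nitrobrew"]
print(f"communication ratio: {comm_ratio:.3f}, working-set ratio: {memory_ratio:.3f}")
```

Both ratios reduce to the same quantity for this configuration, $V / 4096 \approx 37$: communication shrinks by $V / d_{\text{model}}$ and the working set by $V / \text{VBLOCK}$.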

While theoretical compute is roughly equivalent across methods, we will see that kernel compilation enables tiled, coupled matmuls to get strong throughput gains at long context.

Online Divergence

We describe our online divergence kernel for the forward KL case, but the algorithm is analogous for other divergences such as reverse KL or Jensen-Shannon divergence (JSD), which only require some additional bookkeeping. Recall that for $p = \mathrm{softmax}(z_s)$ and $q = \mathrm{softmax}(z_t)$, forward KL decomposes as

$$\mathrm{KL}(p \,\|\, q) = \sum_i p_i \log p_i - \sum_i p_i \log q_i = \mathbb{E}_p[z_s] - \mathbb{E}_p[z_t] - \log Z_s + \log Z_t$$

where $Z_s$ and $Z_t$ are the partition functions of $z_s$ and $z_t$ respectively. Every term here is either a log-sum-exp or an expectation under the softmax, and can be computed in a single streaming pass over the vocabulary, using the online softmax/log-sum-exp trick from FlashAttention [12], analogous to fused large-vocabulary cross-entropy kernels that avoid full-logit materialization [13].
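For completeness, the decomposition follows directly from the definition of softmax; the intermediate steps are spelled out below. The per-token subscripts $z_{s,i}$ and $z_{t,i}$ are notation introduced here for illustration.

```latex
\begin{aligned}
\log p_i &= z_{s,i} - \log Z_s, \qquad \log q_i = z_{t,i} - \log Z_t,
  \qquad Z_s = \sum_j e^{z_{s,j}}, \quad Z_t = \sum_j e^{z_{t,j}} \\
\mathrm{KL}(p \,\|\, q)
  &= \sum_i p_i \,(z_{s,i} - \log Z_s) - \sum_i p_i \,(z_{t,i} - \log Z_t) \\
  &= \sum_i p_i \, z_{s,i} - \sum_i p_i \, z_{t,i} - \log Z_s + \log Z_t
  \qquad \Big(\text{using } \sum_i p_i = 1\Big) \\
  &= \mathbb{E}_p[z_s] - \mathbb{E}_p[z_t] - \log Z_s + \log Z_t
\end{aligned}
```

Note that both expectations are taken under the same distribution $p$, so a single set of softmax weights serves both sums; only the two partition functions need separate running maxima.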

$$
\begin{aligned}
&\textbf{Algorithm: Online forward KL, } \mathrm{KL}(p_s \| p_t) \textbf{ (forward pass)} \\
&\textbf{Require: } x_s \in \mathbb{R}^{N \times D_s},\; x_t \in \mathbb{R}^{N \times D_t},\; W_s \in \mathbb{R}^{V \times D_s},\; W_t \in \mathbb{R}^{V \times D_t},\; \text{temperature } T \\
&\textbf{Ensure: } \mathrm{KL} \in \mathbb{R}^N,\; \log Z_s \in \mathbb{R}^N,\; \log Z_t \in \mathbb{R}^N \\
&m_s, m_t \gets -\infty;\quad s_s, s_t, t_s, u_s \gets 0 \quad \triangleright \text{All } [N] \text{ vectors} \\
&\textbf{for } v_0 = 0,\; \text{VBlk},\; 2 \cdot \text{VBlk},\; \ldots \text{ while } v_0 < V \textbf{ do} \\
&\quad v_1 \gets \min(v_0 + \text{VBlk},\; V) \\
&\quad z_s \gets x_s \, W_s[v_0\!:\!v_1]^\top / T \quad \triangleright [N, \text{VBlk}] \\
&\quad z_t \gets x_t \, W_t[v_0\!:\!v_1]^\top / T \\
&\quad \triangleright \text{Student online softmax update} \\
&\quad m_s' \gets \max(m_s,\; \max_{\text{tile}}(z_s)) \\
&\quad \alpha \gets \exp(m_s - m_s') \\
&\quad s_s \gets \alpha \cdot s_s + \textstyle\sum_{\text{tile}} \exp(z_s - m_s') \\
&\quad t_s \gets \alpha \cdot t_s + \textstyle\sum_{\text{tile}} \exp(z_s - m_s') \cdot z_s \\
&\quad u_s \gets \alpha \cdot u_s + \textstyle\sum_{\text{tile}} \exp(z_s - m_s') \cdot z_t \\
&\quad m_s \gets m_s' \\
&\quad \triangleright \text{Teacher online softmax update} \\
&\quad m_t' \gets \max(m_t,\; \max_{\text{tile}}(z_t)) \\
&\quad s_t \gets \exp(m_t - m_t') \cdot s_t + \textstyle\sum_{\text{tile}} \exp(z_t - m_t') \\
&\quad m_t \gets m_t' \\
&\textbf{end for} \\
&\log Z_s \gets m_s + \log s_s;\quad \log Z_t \gets m_t + \log s_t \\
&\mathrm{KL} \gets (t_s - u_s) / s_s - \log Z_s + \log Z_t \\
&\textbf{Return } \mathrm{KL},\; \log Z_s,\; \log Z_t
\end{aligned}
$$

We maintain six running accumulators as we tile over vocabulary chunks of size VBLOCK. The entire computation touches each logit exactly once and stores $O(1)$ state per position. When fused with the unembedding matmuls from the Nitrobrew framework, the logits are computed tile-by-tile as $W_U^{(\cdot)} h^{(\cdot)}$ and immediately consumed by the accumulator updates. The only tensors that survive are the six scalars per position. The algorithms for online reverse KL and JSD can be found in Appendix B.
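As a sanity check on the streaming recurrence, the sketch below implements it for a single position in plain Python and compares it against a reference that materializes both logit vectors. It is an illustration under stated assumptions, not the kernel itself: the toy sizes and random weights are made up, temperature is fixed at 1, and the tile logits are computed with ordinary dot products.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def logsumexp(z):
    m = max(z)
    return m + math.log(sum(math.exp(v - m) for v in z))


def reference_kl(zs, zt):
    """KL(softmax(zs) || softmax(zt)) with both logit vectors materialized."""
    lzs, lzt = logsumexp(zs), logsumexp(zt)
    return sum(math.exp(a - lzs) * ((a - lzs) - (b - lzt)) for a, b in zip(zs, zt))


def online_kl(xs, xt, Ws, Wt, vblock):
    """Streaming version for one position: logits exist one tile at a time."""
    ms = mt = float("-inf")
    ss = st = ts = us = 0.0
    for v0 in range(0, len(Ws), vblock):
        zs = [dot(row, xs) for row in Ws[v0:v0 + vblock]]  # student tile
        zt = [dot(row, xt) for row in Wt[v0:v0 + vblock]]  # teacher tile
        # Student update: rescale old sums to the new running max.
        new_ms = max(ms, max(zs))
        alpha = math.exp(ms - new_ms)
        w = [math.exp(z - new_ms) for z in zs]
        ss = alpha * ss + sum(w)
        ts = alpha * ts + dot(w, zs)
        us = alpha * us + dot(w, zt)
        ms = new_ms
        # Teacher update: only the partition function is needed.
        new_mt = max(mt, max(zt))
        st = math.exp(mt - new_mt) * st + sum(math.exp(z - new_mt) for z in zt)
        mt = new_mt
    log_zs, log_zt = ms + math.log(ss), mt + math.log(st)
    return (ts - us) / ss - log_zs + log_zt


rng = random.Random(0)
V, DS, DT = 37, 3, 5  # toy sizes; V is deliberately not a multiple of the tile
Ws = [[rng.gauss(0, 1) for _ in range(DS)] for _ in range(V)]
Wt = [[rng.gauss(0, 1) for _ in range(DT)] for _ in range(V)]
xs = [rng.gauss(0, 1) for _ in range(DS)]
xt = [rng.gauss(0, 1) for _ in range(DT)]

expected = reference_kl([dot(r, xs) for r in Ws], [dot(r, xt) for r in Wt])
results = {vb: online_kl(xs, xt, Ws, Wt, vb) for vb in (1, 8, 37, 64)}
max_error = max(abs(v - expected) for v in results.values())
print(f"reference KL = {expected:.6f}, max deviation across tile sizes = {max_error:.2e}")
```

The result is independent of the tile size up to floating-point rounding, which is what makes the tile width a pure performance knob.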

Kernel Implementation

We experimented with optimized Triton and TileLang kernels for Nitrobrew. Similar to FlashAttention, we can fuse large matrix multiplies (unembeds) with their subsequent nonlinearities to avoid excessive round trips to HBM.

We ultimately found, however, that for larger models and longer context lengths, an optimized torch implementation was superior. The savings from full fusion are dwarfed by the efficiency of cuBLAS matmuls for the unembedding projections. A hybrid approach, where only the online KL itself is fused, is a low-hanging direction for future work. For reference, a similar approach has found success in the fused_linear_crossentropy kernel from Quack, written in the CuTe DSL.

Below is the example forward pass for the performant torch implementation used.

from typing import Tuple

import torch


@torch.compile
def _nitrobrew_fwd_chunked(
    xs: torch.Tensor,
    xt: torch.Tensor,
    ws: torch.Tensor,
    wt: torch.Tensor,
    temperature: float,
    chunk_V: int = 4096,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Chunked forward online-softmax KL.

    xs: [N, D_s]  student hidden states
    xt: [N, D_t]  teacher hidden states
    ws: [V, D_s]  student unembed
    wt: [V, D_t]  teacher unembed

    Returns kl, logZs, logZt as flat [N] fp32 tensors.
    """
    N = xs.shape[0]
    V = ws.shape[0]
    inv_temp = 1.0 / temperature

    ms = torch.full((N,), float("-inf"), dtype=torch.float32, device=xs.device)
    mt = torch.full((N,), float("-inf"), dtype=torch.float32, device=xs.device)
    ss = torch.zeros(N, dtype=torch.float32, device=xs.device)
    st = torch.zeros(N, dtype=torch.float32, device=xs.device)
    ts = torch.zeros(N, dtype=torch.float32, device=xs.device)
    us = torch.zeros(N, dtype=torch.float32, device=xs.device)

    for v0 in range(0, V, chunk_V):
        v1 = min(v0 + chunk_V, V)

        zs_tile = torch.mm(xs, ws[v0:v1].T).float().mul_(inv_temp)
        zt_tile = torch.mm(xt, wt[v0:v1].T).float().mul_(inv_temp)

        # Student online softmax update
        tile_ms = zs_tile.max(dim=1).values
        new_ms = torch.maximum(ms, tile_ms)
        alpha_s = (ms - new_ms).exp_()
        ss.mul_(alpha_s)
        ts.mul_(alpha_s)
        us.mul_(alpha_s)
        p_tile = (zs_tile - new_ms.unsqueeze(1)).exp_()
        ss.add_(p_tile.sum(dim=1))
        ts.add_((p_tile * zs_tile).sum(dim=1))
        us.add_((p_tile * zt_tile).sum(dim=1))
        ms = new_ms

        # Teacher online softmax update
        tile_mt = zt_tile.max(dim=1).values
        new_mt = torch.maximum(mt, tile_mt)
        alpha_t = (mt - new_mt).exp_()
        st.mul_(alpha_t)
        st.add_((zt_tile - new_mt.unsqueeze(1)).exp_().sum(dim=1))
        mt = new_mt

    logZs = ms + ss.log()
    logZt = mt + st.log()
    kl = (ts - us) / ss - logZs + logZt

    return kl, logZs, logZt

Isolated Profiling Results

Below are the profiling results for the optimized chunked torch implementation on Hopper. The tested setting had a fixed vocabulary size $V = 132\text{k}$, and $d_{\text{model}} \in \{2048, 4096, 8192\}$ for the model sizes.

Figure: Isolated profiling results.

Nitrobrew is significantly faster than the naive torch implementation with much lower peak memory demand. While shorter sequences do not benefit from tiling over the vocabulary dimension, longer sequence lengths are compute-bound thanks to the lack of full logit tensor materialization. At a practical sequence length of 16k, Nitrobrew accelerates distillation loss calculation for a 70B model by 100x with 50% less memory!

In addition to isolated profiling, we also implemented and tested Nitrobrew inside two popular RL frameworks: NemoRL [14] & VeRL [15].

On-policy Distillation Results

We first profiled the end-to-end step time in the NemoRL [14] framework. Details on the profiling setting can be found in Appendix C.

Figure: NemoRL profiling results. Profiling results under various distillation settings with NemoRL + Nitrobrew.

With $d_{\text{model}}$ floats per token, which is the standard Nitrobrew setup, our approach is 1.5–4.5x faster in end-to-end wall-clock time than top-k. Notably, even at very few floats per token, our approach outperforms top-k on speed. For a fixed step time, full Nitrobrew achieves equivalent throughput to top-k while sending an order of magnitude more floats, thereby preserving information and remaining fully lossless.

We can also analyze the breakdown of where time is spent within each step.

Figure: Step time breakdown. Breakdown of savings by computation category within a step for a sample distillation setting. Savings are concentrated in teacher inference and policy training.

As seen above, the fused, chunked KL implementation is doing most of the work at this scale. Distillation becomes communication-bound at larger scales, but at smaller scales the primary savings come from eliminating the costly unembed + top-k from teacher inference, and the unembed + unfused KL from student policy training. Similar breakdowns for other distillation setups can be found in Appendix D.

We also implemented Nitrobrew in VeRL, and found even stronger speedups. We are releasing a PR to VeRL for Nitrobrew support which can be found in Appendix A.

Figure: VeRL profiling results. Profiling results under various distillation settings with VeRL + Nitrobrew.

For our direct OPD experiments, we trained the student Qwen 3 1.7B Base with the teacher Qwen 3 8B on the MATH dataset. We timed one iteration through the entire dataset, on the order of 50 steps. Further training details can be found in Appendix E.

Figure: OPD training results.

The model achieves 3x the step throughput of a top-k approach, even with 4x the float communication. Nitrobrew is lossless and performs competitively with the top-k baseline. For further notes on convergence, refer to Appendix F.

Towards Spectral Logit Compression

Nitrobrew reduces the number of communicated floats from $V$ to $d_{\text{model}}$ completely losslessly, but in some practical settings, such as multi-node distillation with tight interconnect budgets, even $d_{\text{model}}$ floats per position may be too many. For this setting, we propose SVD-Nitrobrew, which uses an SVD of the teacher’s unembedding to identify low-energy directions to discard.

We report here some preliminary investigations into spectral compression, but this direction warrants further research.

SVD-Nitrobrew

Consider the thin SVD of the teacher's unembedding matrix:

$$W_U = U \Sigma V^\top, \qquad U \in \mathbb{R}^{V \times d}, \quad \Sigma \in \mathbb{R}^{d \times d}, \quad V \in \mathbb{R}^{d \times d}$$

Empirically, the spectral energy of $W_U$ tends to be concentrated in the top few singular values, i.e. its effective rank tends to be much smaller than $d$. SVD-Nitrobrew exploits this fact by projecting the teacher's hidden state onto the top-$k$ right singular vectors of $W_U$ before it is communicated.

SVD-Nitrobrew

  1. Precompute (once, offline): the rank-$k$ truncated SVD $W_U \approx U_k \Sigma_k V_k^\top$, where $V_k \in \mathbb{R}^{d \times k}$ are the top-$k$ right singular vectors.
  2. Communicate: $\tilde{h} = V_k^\top h \in \mathbb{R}^k$
  3. Reconstruct: construct approximate logits $\tilde{z} = U_k \Sigma_k \tilde{h}$ online in the divergence kernel as before.

Unlike top-k, this method produces a smooth, full-vocabulary approximation of the logits and does not require any renormalization. The reconstruction error is bounded by the energy of the spectral tail, $\sum_{i > k} \sigma_i^2$, which is small for sufficiently large $k$. The per-token communication cost is exactly $k$ floats, comparable to top-k (which sends $k$ values plus $k$ indices), and the SVD is done offline, so it adds no per-step overhead.
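The three steps can be traced on a tiny example. To stay dependency-free, the sketch below does not compute an SVD numerically; instead it builds a 4×2 matrix from hand-picked singular values and orthonormal singular vectors, so the factors are known by construction. All values are illustrative assumptions, and a real pipeline would obtain the factors from a linear-algebra library.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


# Hand-built thin SVD (an assumption that avoids a linear-algebra dependency):
# W = sum_i sigma_i * u_i v_i^T with orthonormal u_i in R^4 and v_i in R^2.
sigma = [3.0, 0.5]
U = [[0.5, 0.5, 0.5, 0.5], [0.5, -0.5, 0.5, -0.5]]  # u_1, u_2
theta = 0.3
Vt = [[math.cos(theta), math.sin(theta)],
      [-math.sin(theta), math.cos(theta)]]          # v_1, v_2
W = [[sum(sigma[i] * U[i][row] * Vt[i][col] for i in range(2)) for col in range(2)]
     for row in range(4)]

h = [0.8, -1.7]                      # toy hidden state
z = [dot(w_row, h) for w_row in W]   # exact logits (4 floats)

k = 1
# Step 2: project onto the top-k right singular vectors and ship k floats.
h_tilde = [dot(Vt[i], h) for i in range(k)]
# Step 3: rebuild approximate logits from the shipped coefficients.
z_tilde = [sum(sigma[i] * U[i][row] * h_tilde[i] for i in range(k))
           for row in range(4)]

err = math.sqrt(sum((a - b) ** 2 for a, b in zip(z, z_tilde)))
bound = sigma[k] * math.sqrt(dot(h, h))  # sigma_{k+1} * ||h||_2

print(f"floats shipped: {len(h_tilde)} instead of {len(h)}")
print(f"logit error {err:.4f} <= bound {bound:.4f}")
```

Every entry of `z_tilde` is nonzero, which is the sense in which the approximation is "full-vocabulary": the error is spread smoothly over all tokens instead of being concentrated on the masked ones.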

Figure: SVD-Nitrobrew.

Comparison to Top-k Compression

To compare SVD-Nitrobrew with top-$k$, we need a notion of approximation quality. We will start by considering L2 error on logits. For any hidden state $h$, truncated SVD gives the worst-case reconstruction bound

$$\| z - \tilde{z} \|_2^2 = \| (W_U - U_k \Sigma_k V_k^\top) h \|_2^2 \leq \sigma_{k+1}^2 \| h \|_2^2$$

By Eckart–Young, the truncated SVD used by SVD-Nitrobrew is the optimal rank-$k$ approximation in operator norm, so no other rank-$k$ linear map can improve this worst-case bound. However, since top-$k$ truncation is nonlinear and input-dependent, this optimality result does not imply that SVD is always better than top-$k$.
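The worst-case bound can be verified in two lines by expanding the residual in the singular basis. Here $u_i$ and $v_i$ denote the $i$-th left and right singular vectors of $W_U$, notation introduced for this derivation.

```latex
\begin{aligned}
(W_U - U_k \Sigma_k V_k^\top)\, h
  &= \sum_{i > k} \sigma_i \, u_i \, (v_i^\top h) \\
\| z - \tilde{z} \|_2^2
  &= \sum_{i > k} \sigma_i^2 \, (v_i^\top h)^2
  && \text{(the } u_i \text{ are orthonormal)} \\
  &\leq \sigma_{k+1}^2 \sum_{i > k} (v_i^\top h)^2
  && \text{(singular values are sorted in decreasing order)} \\
  &\leq \sigma_{k+1}^2 \, \| h \|_2^2
  && \text{(the } v_i \text{ are orthonormal)}
\end{aligned}
```

The first inequality is tight only when $h$ lies entirely along $v_{k+1}$, so for typical hidden states the actual error depends on how much of $h$ falls in the discarded directions, not just on the singular values.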

Furthermore, the relevant quantity for distillation is the divergence between the teacher distribution induced by the true logits and the approximate distribution induced by the reconstructed logits. This divergence defines the student's training signal, so compression should preserve it even when the logits are only approximately reconstructed. We can derive a simple illustrative bound from the Lipschitz continuity of softmax. Let $\tilde{p} = \mathrm{softmax}(\tilde{z}/\tau)$ and suppose $\tilde{p}_i \geq \epsilon$ for all tokens. Then

$$\mathrm{KL}(p \,\|\, \tilde{p}) \leq \frac{1}{\epsilon} \|p - \tilde{p}\|_2^2 \leq \frac{1}{\tau^2 \epsilon} \| z - \tilde{z} \|_2^2$$

where $p = \mathrm{softmax}(z/\tau)$ is the true teacher distribution and $\tau$ is the softmax temperature. Combining this with the SVD reconstruction bound gives

$$\mathrm{KL}(p \,\|\, \tilde{p}) \leq \frac{1}{\tau^2 \epsilon} \sigma_{k+1}^2 \| h \|_2^2$$
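The chain of inequalities behind this bound is short. The sketch below fills in the intermediate steps using two standard facts: KL divergence is upper-bounded by the chi-squared divergence, and softmax is Lipschitz in the $\ell_2$ norm with constant at most 1 (so at most $1/\tau$ after temperature scaling).

```latex
\begin{aligned}
\mathrm{KL}(p \,\|\, \tilde{p})
  &\leq \log\!\big(1 + \chi^2(p \,\|\, \tilde{p})\big)
   \leq \chi^2(p \,\|\, \tilde{p})
   = \sum_i \frac{(p_i - \tilde{p}_i)^2}{\tilde{p}_i}
  && \text{(Jensen, then } \log(1 + x) \leq x\text{)} \\
  &\leq \frac{1}{\epsilon} \, \| p - \tilde{p} \|_2^2
  && \text{(since } \tilde{p}_i \geq \epsilon\text{)} \\
  &\leq \frac{1}{\epsilon} \, \Big\| \frac{z}{\tau} - \frac{\tilde{z}}{\tau} \Big\|_2^2
   = \frac{1}{\tau^2 \epsilon} \, \| z - \tilde{z} \|_2^2
  && \text{(softmax is Lipschitz)}
\end{aligned}
```

Each step gives up slack, and the $1/\epsilon$ factor is the costly one: with a vocabulary of size $V$, the smallest probability is at most $1/V$, so $1/\epsilon \geq V$.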

This bound is loose in large-vocabulary settings because the minimum token probability can be very small. We therefore use it only as a qualitative guide: reducing spectral reconstruction error should reduce distributional distortion, but KL is much more sensitive to errors on high-probability tokens than logit MSE alone suggests.

Comparison of Compression Strategies

We then sought to improve upon the spectral compression strategy employed. For each compression strategy, we tested its KL against the true teacher distribution and MSE against the true teacher logit vector. For details on compression experiments, refer to Appendix G.

Naive SVD: As outlined above in SVD-Nitrobrew, the default approach simply uses the truncated right singular matrix for projection.

Figure: Naive SVD compression results.

We find that naive SVD truncation is a poor compressor. For most compression budgets, it has significantly higher KL divergence from the ground truth relative to top-k.

Importance SVD:

Analyzing the spectral properties of the unembedding in isolation ignores the geometry of the input activations to the unembedding, which are also highly structured and have dominant directions of variance. We can instead ask for dominant directions shared between the post-RMSNorm final activation and the unembedding.

We measure this through the effective logit importance of each singular direction $i$:

$$I_i = \sigma_i^2 \cdot \mathbb{E}\left[(v_i \cdot h)^2\right]$$

where $\sigma_i$ is the $i$-th singular value of $W$, $v_i$ is the corresponding right singular vector, and $h$ is the post-RMSNorm hidden state at the final layer. The quantity $\mathbb{E}[(v_i \cdot h)^2]$ is the average activation energy projected onto direction $v_i$, measured over a held-out corpus of model completions. The product $I_i$ captures the actual contribution of direction $i$ to the logits.
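One way to see why $I_i$ is the right quantity to rank by: the expected squared logit error from dropping a set of singular directions is exactly the sum of their importances. In the sketch below, $S$ denotes the set of kept directions and $\tilde{z}_S$ the logits rebuilt from them, both notation introduced here for illustration.

```latex
\begin{aligned}
z - \tilde{z}_S &= \sum_{i \notin S} \sigma_i \, u_i \, (v_i \cdot h) \\
\mathbb{E}\,\| z - \tilde{z}_S \|_2^2
  &= \sum_{i \notin S} \sigma_i^2 \, \mathbb{E}\big[(v_i \cdot h)^2\big]
   = \sum_{i \notin S} I_i
  && \text{(the } u_i \text{ are orthonormal)}
\end{aligned}
```

Among all ways of keeping $k$ singular directions, keeping the $k$ largest $I_i$ therefore minimizes the expected logit MSE. Ordering by $\sigma_i$ alone is optimal only when the activation energy $\mathbb{E}[(v_i \cdot h)^2]$ is roughly equal across directions.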

Figure: Importance U-shape.

We observe a U-shaped phenomenon, where the directions with the largest and the smallest singular values carry the most logit importance. This motivates an approach that orders singular values (and directions) by their effective importance and truncates accordingly.

Figure: Importance-ordered truncation crossover. The importance-ordered truncation has a nontrivial crossover point in KL, $d^* < d_{\text{model}}$, allowing for compression.

When we investigate the source of the remaining discrepancy, we find that most tokens in generation are low-entropy, and SVD has a hard time matching these. Furthermore, the optimal linear map in the MSE sense is not necessarily the optimal linear map in the KL sense. The latter, for example, cares far more about sharpness around the top candidates than accuracy on the tail.

Probability-reweighted SVD:

Motivated by this observation, we attempt a final approach which reweights the SVD by the square root of the average token probability:

$$\bar{p} = \mathbb{E}[\mathrm{softmax}(Wh)] \in \mathbb{R}^V$$

Here $\bar{p}$ is the marginal probability of each token in the vocabulary, averaged over our held-out corpus of completions. Tokens with high $\bar{p}_j$ are the ones to which the model frequently assigns significant probability mass; these are precisely the tokens whose logit reconstruction errors dominate KL divergence.

The reweighted logit map is:

$$M = \mathrm{diag}(\sqrt{\bar{p}}) \cdot W \cdot C^{1/2}$$

where $C = \mathbb{E}[hh^\top]$ is the empirical activation covariance matrix.

The left factor $\mathrm{diag}(\sqrt{\bar{p}})$ scales each row of $W$ by the square root of that token's average probability. The right factor $C^{1/2}$ weights input directions by their empirical activation scale and covariance. Performing SVD on $M$ finds directions that maximize the probability-weighted logit variance under the empirical activation distribution.

The SVD of this reweighted matrix yields:

$$B = \tilde{V}[:r, :] \; C^{-1/2} \quad (r \times D \text{ encoder})$$

$$A = \text{diag}(1/\sqrt{\bar{p}}) \cdot \tilde{U}[:, :r] \cdot \text{diag}(\tilde{S}[:r]) \quad (V \times r \text{ decoder})$$

The decoder includes the inverse weighting $\text{diag}(1/\sqrt{\bar{p}})$ to undo the scaling, ensuring that at full rank ($r = D$), we recover $AB = W$ exactly. At reduced rank, the approximation error is concentrated on low-probability tokens, which are least important for KL. The results are shown below.

Probability-reweighted SVD results

We finally have a nontrivial truncation point at $r = 741$. We can compress our Nitrobrew state spectrally and still expect to have more information than top-k.
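The probability-reweighted construction can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the released implementation: the helper names `sym_power` and `reweighted_factors` are hypothetical, $C^{\pm 1/2}$ is computed by eigendecomposition with a small eigenvalue floor, and the data is random toy data.

```python
import numpy as np

def sym_power(C, power, eps=1e-12):
    """Power of a symmetric PSD matrix via eigendecomposition (eigenvalues floored at eps)."""
    w, Q = np.linalg.eigh(C)
    w = np.maximum(w, eps)
    return (Q * w ** power) @ Q.T

def reweighted_factors(W, H, r):
    """W: (V, D) unembedding, H: (N, D) calibration hidden states, r: rank kept.

    Returns decoder A (V, r) and encoder B (r, D).
    """
    logits = H @ W.T
    logits -= logits.max(axis=1, keepdims=True)      # stabilize the softmax
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    p_bar = probs.mean(axis=0)                        # marginal token probability, (V,)
    C = H.T @ H / H.shape[0]                          # empirical E[h h^T], (D, D)
    M = (np.sqrt(p_bar)[:, None] * W) @ sym_power(C, 0.5)
    U, S, Vh = np.linalg.svd(M, full_matrices=False)  # rows of Vh are right singular vectors
    B = Vh[:r] @ sym_power(C, -0.5)                   # encoder
    A = (U[:, :r] * S[:r]) / np.sqrt(p_bar)[:, None]  # decoder undoes the row scaling
    return A, B

# Toy sizes and random data, purely illustrative.
rng = np.random.default_rng(0)
V, D, N = 40, 6, 300
W = rng.normal(size=(V, D))
H = rng.normal(size=(N, D))

A, B = reweighted_factors(W, H, r=D)       # full rank
A_low, B_low = reweighted_factors(W, H, r=3)
```

At full rank the product `A @ B` reproduces `W` up to floating-point error, provided the empirical covariance is well conditioned. At reduced rank only `B @ h` (here 3 floats per token) would need to be transmitted, with `A` applied on the student side.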

There are significantly more effective ways to compress the final hidden state, particularly nonlinear approaches such as autoencoders, that we did not investigate here but remain a low-hanging direction for future gains.

Takeaways + Looking Forward

Nitrobrew makes full-vocabulary logit distillation cheap enough to be the default — no top-k truncation or biased, lossy approximation.

We would like to stress two guiding principles:

  • Simplicity: Nitrobrew requires minimal changes to existing training pipelines, and even the kernel admits a very efficient compiled torch implementation.
  • Generality: Nitrobrew is effective for both on- and off-policy distillation. It can dramatically reduce the storage costs for caching teacher predictions in off-policy distillation, which was the original motivation for the tool internally.

In practice, Nitrobrew yields 100× faster loss computation at long context, 37× less peak memory, and 2.5–3× faster training steps.

Spectral compression of logits for distillation is an important avenue of future work that we are quite excited about. Instead of post-hoc strategies, this line of work directly leverages a deeper understanding of model architecture & dynamics to improve practical performance. At Tilde, that's precisely the type of approach we look for.

Code is available at Nitrobrew-Release. We have submitted a PR to VeRL for direct integration as well.

Appendix

A. Code

A simple torch implementation is available at Nitrobrew-Release. The VeRL PR can be found here.

B. Online Divergence Algorithms

Reverse KL

$$
\begin{aligned}
&\textbf{Algorithm: Online reverse KL, } \mathrm{KL}(p_t \| p_s) \textbf{ (forward pass)} \\
&\textbf{Require: } x_s \in \mathbb{R}^{N \times D_s},\; x_t \in \mathbb{R}^{N \times D_t},\; W_s \in \mathbb{R}^{V \times D_s},\; W_t \in \mathbb{R}^{V \times D_t},\; \text{temperature } T \\
&\textbf{Ensure: } \mathrm{KL} \in \mathbb{R}^N,\; \log Z_s \in \mathbb{R}^N,\; \log Z_t \in \mathbb{R}^N \\
&m_s, m_t \gets -\infty;\quad s_s, s_t, t_t, u_t \gets 0 \\
&\textbf{for } v_0 = 0,\; \text{VBlk},\; 2 \cdot \text{VBlk},\; \ldots,\; V \textbf{ do} \\
&\quad z_s \gets x_s \, W_s[v_0\!:\!v_1]^\top / T \\
&\quad z_t \gets x_t \, W_t[v_0\!:\!v_1]^\top / T \\
&\quad \triangleright \text{Student: partition function only} \\
&\quad m_s' \gets \max(m_s,\; \max_{\text{tile}}(z_s)) \\
&\quad s_s \gets \exp(m_s - m_s') \cdot s_s + \textstyle\sum_{\text{tile}} \exp(z_s - m_s') \\
&\quad m_s \gets m_s' \\
&\quad \triangleright \text{Teacher: weighted accumulators} \\
&\quad m_t' \gets \max(m_t,\; \max_{\text{tile}}(z_t)) \\
&\quad \alpha \gets \exp(m_t - m_t') \\
&\quad s_t \gets \alpha \cdot s_t + \textstyle\sum_{\text{tile}} \exp(z_t - m_t') \\
&\quad t_t \gets \alpha \cdot t_t + \textstyle\sum_{\text{tile}} \exp(z_t - m_t') \cdot z_t \\
&\quad u_t \gets \alpha \cdot u_t + \textstyle\sum_{\text{tile}} \exp(z_t - m_t') \cdot z_s \\
&\quad m_t \gets m_t' \\
&\textbf{end for} \\
&\log Z_s \gets m_s + \log s_s;\quad \log Z_t \gets m_t + \log s_t \\
&\mathrm{KL} \gets (t_t - u_t) / s_t - \log Z_t + \log Z_s \\
&\textbf{Return } \mathrm{KL},\; \log Z_s,\; \log Z_t
\end{aligned}
$$
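As a readable reference for the forward pass above, here is a NumPy sketch that mirrors the accumulators one for one and checks the result against a dense computation. It is not the released kernel. The tile end is taken as `v1 = min(v0 + vblk, V)`, which is an assumption since the pseudocode leaves it implicit, and the function names and toy sizes are hypothetical.

```python
import numpy as np

def online_reverse_kl(xs, xt, Ws, Wt, T=1.0, vblk=16):
    """Per-row KL(p_t || p_s), holding at most an (N, vblk) logit tile at a time."""
    N, V = xs.shape[0], Ws.shape[0]
    ms, mt = np.full(N, -np.inf), np.full(N, -np.inf)   # running maxima
    ss, st, tt, ut = (np.zeros(N) for _ in range(4))    # running sums
    for v0 in range(0, V, vblk):
        v1 = min(v0 + vblk, V)                          # assumed tile end
        zs = xs @ Ws[v0:v1].T / T
        zt = xt @ Wt[v0:v1].T / T
        # Student: partition function only.
        ms_new = np.maximum(ms, zs.max(axis=1))
        ss = np.exp(ms - ms_new) * ss + np.exp(zs - ms_new[:, None]).sum(axis=1)
        ms = ms_new
        # Teacher: rescale old sums by alpha, then add this tile's contribution.
        mt_new = np.maximum(mt, zt.max(axis=1))
        alpha = np.exp(mt - mt_new)
        e = np.exp(zt - mt_new[:, None])
        st = alpha * st + e.sum(axis=1)
        tt = alpha * tt + (e * zt).sum(axis=1)
        ut = alpha * ut + (e * zs).sum(axis=1)
        mt = mt_new
    log_zs, log_zt = ms + np.log(ss), mt + np.log(st)
    kl = (tt - ut) / st - log_zt + log_zs
    return kl, log_zs, log_zt

def dense_reverse_kl(xs, xt, Ws, Wt, T=1.0):
    """Reference that materializes the full (N, V) logits, for checking only."""
    def log_softmax(z):
        z = z - z.max(axis=1, keepdims=True)
        return z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    ls, lt = log_softmax(xs @ Ws.T / T), log_softmax(xt @ Wt.T / T)
    return (np.exp(lt) * (lt - ls)).sum(axis=1)

# Toy sizes; V is deliberately not a multiple of vblk.
rng = np.random.default_rng(0)
N, Ds, Dt, V = 5, 6, 10, 50
xs, xt = rng.normal(size=(N, Ds)), rng.normal(size=(N, Dt))
Ws, Wt = rng.normal(size=(V, Ds)), rng.normal(size=(V, Dt))

kl_online, log_zs, log_zt = online_reverse_kl(xs, xt, Ws, Wt, T=2.0, vblk=16)
kl_dense = dense_reverse_kl(xs, xt, Ws, Wt, T=2.0)
```

The online and dense results agree to floating-point precision, which is the sense in which the tiled computation is exact rather than approximate.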
Backward for Forward & Reverse KL

$$
\begin{aligned}
&\textbf{Algorithm: Nitrobrew backward pass} \\
&\textbf{Require: } x_s,\; x_t,\; W_s,\; W_t,\; \log Z_s,\; \log Z_t,\; \text{upstream gradient } \bar{\ell},\; \text{temperature } T \\
&\textbf{Require: } \text{Direction} \in \{\text{forward},\; \text{reverse}\} \\
&\textbf{Ensure: } \partial \ell / \partial x_s \in \mathbb{R}^{N \times D_s},\; \partial \ell / \partial W_s \in \mathbb{R}^{V \times D_s} \\
&\mathrm{d}x_s \gets 0;\quad \mathrm{d}W_s \gets \emptyset \\
&\textbf{for } v_0 = 0,\; \text{VBlk},\; 2 \cdot \text{VBlk},\; \ldots,\; V \textbf{ do} \\
&\quad z_s \gets x_s \, W_s[v_0\!:\!v_1]^\top / T \quad \triangleright \text{Recompute logit tiles} \\
&\quad z_t \gets x_t \, W_t[v_0\!:\!v_1]^\top / T \\
&\quad p_s \gets \exp(z_s - \log Z_s) \\
&\quad p_t \gets \exp(z_t - \log Z_t) \\
&\quad \textbf{if } \text{forward KL} \textbf{ then} \\
&\qquad g \gets p_s \cdot (\log p_s - \log p_t - \mathrm{KL}) \cdot \bar{\ell} / T \quad \triangleright [N, \text{VBlk}] \\
&\quad \textbf{else} \quad \triangleright \text{reverse KL} \\
&\qquad g \gets (p_s - p_t) \cdot \bar{\ell} / T \\
&\quad \textbf{end if} \\
&\quad \mathrm{d}x_s \mathrel{+}= g \, W_s[v_0\!:\!v_1] \\
&\quad \mathrm{d}W_s[v_0\!:\!v_1] \gets g^\top x_s \\
&\textbf{end for} \\
&\textbf{Return } \mathrm{d}x_s,\; \mathrm{d}W_s
\end{aligned}
$$
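The backward pass can be sanity-checked with a NumPy sketch that recomputes logit tiles and compares the resulting gradients against central finite differences of a dense loss. This is an illustration, not the released kernel. The branch labels follow the pseudocode, and the divergence each branch differentiates is inferred from its gradient formula. The per-row KL needed by the forward branch is passed in explicitly as an assumed extra input, and all function names are hypothetical.

```python
import numpy as np

def log_partition(x, W, T):
    """Dense log Z per row (checking only; the forward pass returns this online)."""
    z = x @ W.T / T
    m = z.max(axis=1)
    return m + np.log(np.exp(z - m[:, None]).sum(axis=1))

def dense_kl(xs, xt, Ws, Wt, T, direction):
    """Dense per-row divergence, materializing (N, V) log-probs. For checking only."""
    ls = xs @ Ws.T / T - log_partition(xs, Ws, T)[:, None]
    lt = xt @ Wt.T / T - log_partition(xt, Wt, T)[:, None]
    if direction == "forward":                   # matches g = p_s (log p_s - log p_t - KL)
        return (np.exp(ls) * (ls - lt)).sum(axis=1)
    return (np.exp(lt) * (lt - ls)).sum(axis=1)  # matches g = p_s - p_t

def tiled_backward(xs, xt, Ws, Wt, log_zs, log_zt, grad_out, T, direction, kl=None, vblk=16):
    """Recompute logit tiles and accumulate student gradients tile by tile."""
    V = Ws.shape[0]
    dxs, dWs = np.zeros_like(xs), np.zeros_like(Ws)
    for v0 in range(0, V, vblk):
        v1 = min(v0 + vblk, V)                   # assumed tile end
        log_ps = xs @ Ws[v0:v1].T / T - log_zs[:, None]
        log_pt = xt @ Wt[v0:v1].T / T - log_zt[:, None]
        ps, pt = np.exp(log_ps), np.exp(log_pt)
        if direction == "forward":
            g = ps * (log_ps - log_pt - kl[:, None])
        else:
            g = ps - pt
        g = g * grad_out[:, None] / T            # (N, tile) gradient w.r.t. x_s W_s^T
        dxs += g @ Ws[v0:v1]
        dWs[v0:v1] = g.T @ xs
    return dxs, dWs

# Toy problem; V is not a multiple of vblk, and row 33 of W_s sits in the last tile.
rng = np.random.default_rng(0)
N, Ds, Dt, V, T = 4, 5, 7, 40, 1.5
xs, xt = rng.normal(size=(N, Ds)), rng.normal(size=(N, Dt))
Ws, Wt = rng.normal(size=(V, Ds)), rng.normal(size=(V, Dt))
grad_out = rng.normal(size=N)

def max_fd_error(direction, eps=1e-6):
    """Compare one entry each of dxs and dWs with central finite differences."""
    loss = lambda x, W: (grad_out * dense_kl(x, xt, W, Wt, T, direction)).sum()
    dxs, dWs = tiled_backward(xs, xt, Ws, Wt, log_partition(xs, Ws, T),
                              log_partition(xt, Wt, T), grad_out, T, direction,
                              kl=dense_kl(xs, xt, Ws, Wt, T, direction))
    xp, xm, Wp, Wm = xs.copy(), xs.copy(), Ws.copy(), Ws.copy()
    xp[1, 2] += eps; xm[1, 2] -= eps
    Wp[33, 3] += eps; Wm[33, 3] -= eps
    fd_x = (loss(xp, Ws) - loss(xm, Ws)) / (2 * eps)
    fd_W = (loss(xs, Wp) - loss(xs, Wm)) / (2 * eps)
    return max(abs(fd_x - dxs[1, 2]), abs(fd_W - dWs[33, 3]))

err_reverse = max_fd_error("reverse")
err_forward = max_fd_error("forward")
```

Both errors are at finite-difference noise level, so the two branches are consistent with differentiating KL(p_t ‖ p_s) and KL(p_s ‖ p_t) respectively with respect to the student inputs.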

C. Profiling Setup

For within-trainer profiling, we followed a standard baseline setup. We tested on three student–teacher pairs:

Student–teacher pairs for profiling

|        | Student    | Teacher   |
|--------|------------|-----------|
| Pair 1 | Qwen3 0.6B | Qwen3 8B  |
| Pair 2 | Qwen3 1.7B | Qwen3 32B |
| Pair 3 | Qwen3 4B   | Qwen3 32B |

We use the DAPO-Math-17k [16] dataset. We test a max generation length of 8192, 256 prompts per step, and 1 generation per prompt. We start timing after 5 steps of training have already completed and average step time over the subsequent 10 steps.

For floats per token $r < d_{\text{model}}$, we adopt SVD-Nitrobrew as described in the SVD-Nitrobrew section.

D. Extended Profiling Results

1.7B to 32B breakdown Breakdown of within-step savings for 1.7B→32B distillation setup.

4B to 32B breakdown Breakdown of within-step savings for 4B→32B distillation setup.

E. Training Setup

We mostly follow the setup of Jin et al. The teacher is Qwen3-8B and the student is Qwen3-1.7B-Base, trained on the MATH dataset (~7,500 problems) for 2 epochs (~116 steps) using the verl-rl framework.

We use supervised forward KL distillation (use_policy_gradient=False) with a GRPO advantage estimator, cosine learning rate schedule peaking at 3×10⁻⁶, and per-token loss clamping at 10.0 nats. Rollout generation uses temperature 0.6.

For the top-k baselines, we evaluate $k \in \{64, 512\}$, transmitting $2k$ floats per token (index–logprob pairs). For Nitrobrew, we transmit PCA-compressed teacher hidden states at d_comp = D_model = 4096 (i.e., full rank, lossless compression). Validation accuracy is measured on MATH-lighteval every 10 steps.
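For a sense of scale, the per-token communication budgets in this setup can be tallied with a few lines of arithmetic. The vocabulary size of 151,936 is an assumption (the Qwen3 tokenizer size, which is not stated in this section); the other numbers come from the text above.

```python
# Floats transmitted per token position under each scheme.
VOCAB_SIZE = 151_936   # ASSUMED Qwen3 vocabulary size, not stated in this section
D_MODEL = 4096         # teacher hidden size used for Nitrobrew in this setup

def topk_floats(k):
    """Top-k sends one index and one logprob per kept token."""
    return 2 * k

budgets = {
    "top-k, k=64": topk_floats(64),
    "top-k, k=512": topk_floats(512),
    "Nitrobrew (full rank)": D_MODEL,
    "full logits": VOCAB_SIZE,
}
reduction_vs_full_logits = VOCAB_SIZE / D_MODEL

for name, n in budgets.items():
    print(f"{name:>22}: {n:>7} floats/token")
```

Under the assumed vocabulary size, full-rank Nitrobrew sends about 37× fewer floats than full logits while remaining exact. The top-k baselines send fewer floats still, but discard everything outside the top k.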

F. OPD Convergence

For training convergence, we tested on-policy distillation with Qwen3-1.7B/8B on MATH and found Nitrobrew competitive with top-k baselines on downstream accuracy. However, this is a small-scale regime — the benefits of full-vocabulary distillation (calibration, tail signal) manifest more clearly at larger model and data scales, where the information discarded by top-k truncation becomes a binding constraint and students are more capable of learning from the tail.

We observed that full-vocabulary KL and top-k KL are qualitatively different loss functions. They require independent hyperparameter tuning, and practitioners adopting Nitrobrew should expect to retune rather than reuse settings chosen for top-k.

G. Compression Experiment Details

Model and data. We use Qwen3-8B [17] as the teacher. For compression strategies that require calibration data (importance SVD, probability-reweighted SVD), we collect teacher activations by running Qwen3-8B on a subset of OpenThoughts3-1.2M [18]. We record the post-RMSNorm final-layer hidden states $h \in \mathbb{R}^{d_{\text{model}}}$, empirical activation covariance $C = \mathbb{E}[hh^\top]$, and marginal token probabilities $\bar{p} = \mathbb{E}[\mathrm{softmax}(W_U h)]$.

Evaluation setting. Compression quality is measured in a realistic on-policy distillation setting rather than on the calibration data. We generate rollouts from the student (Qwen3-1.7B-Base [17]) on the MATH dataset [19], then run the teacher (Qwen3-8B) on these student trajectories to produce ground-truth hidden states $h$ and logits $z = W_U h$.

For each compression strategy at rank $r$, we compute the approximate hidden state $\tilde{h}$ and reconstruct logits $\tilde{z}$.

Notably, for top-k, we do not count the cost of transmitting indices in the analysis. In reality $k_{\mathrm{eff}} \approx 2k$.

Metrics. We report two metrics, averaged over all token positions in the evaluation set:

KL divergence between the true and approximate teacher distributions:

$$\mathrm{KL}(p \| \tilde{p}) = \sum_{v=1}^{V} p(v) \log \frac{p(v)}{\tilde{p}(v)}, \qquad p = \mathrm{softmax}(z), \quad \tilde{p} = \mathrm{softmax}(\tilde{z})$$

This measures how much the compression distorts the distributional signal that the student receives during training.

Logit MSE between the true and approximate teacher logits:

$$\mathrm{MSE} = \frac{1}{V} \| z - \tilde{z} \|_2^2$$

This measures raw reconstruction fidelity independent of the softmax nonlinearity.

H. Note on Throughput Variation

The 3× headline figure is drawn from the mid-scale NemoRL configuration (Qwen3-1.7B student, Qwen3-32B teacher) and is consistent with end-to-end training time in VeRL for the OPD runs. In practice, the speedup varies substantially, from 1.5× to over 14×, depending on sequence length, teacher model size, and framework-specific overhead. We report 3× as a representative figure; for many practical configurations, it is conservative.

References

  1. Hinton, Geoffrey and Vinyals, Oriol and Dean, Jeff (2015).
  2. Romero, Adriana (2014).
  3. Kim, Yoon and Rush, Alexander M (2016).
  4. Muralidharan, Saurav and Turuvekere Sreenivas, Sharath and Joshi, Raviraj and Chochowski, Marcin and Patwary, Mostofa and Shoeybi, Mohammad and Catanzaro, Bryan and Kautz, Jan and Molchanov, Pavlo (2024).
  5. Agarwal, Rishabh and Vieillard, Nino and Zhou, Yongchao and Stanczyk, Piotr and Garea, Sabela Ramos and Geist, Matthieu and Bachem, Olivier (2024).
  6. Thinking Machines.
  7. Furlanello, Tommaso and Lipton, Zachary and Tschannen, Michael and Itti, Laurent and Anandkumar, Anima (2018).
  8. Shenfeld, Idan and Damani, Mehul and Hübotter, Jonas and Agrawal, Pulkit (2026).
  9. Anshumann, Anshumann and Zaidi, Mohd Abbas and Kedia, Akhil and Ahn, Jinwoo and Kwon, Taehwak and Lee, Kangwook and Lee, Haejun and Lee, Joohyung (2025).
  10. Dasgupta, Sayantan and Cohn, Trevor and Baldwin, Timothy (2026).
  11. Dao, Tri and Fu, Dan and Ermon, Stefano and Rudra, Atri and Ré, Christopher (2022).
  12. Wijmans, Erik and Huval, Brody and Hertzberg, Alexander and Koltun, Vladlen and Krähenbühl, Philipp (2024).
  13. Sheng, Guangming and Zhang, Chi and Ye, Ziling and Wu, Xibin and Zhang, Wang and Zhang, Ru and Peng, Yanghua and Lin, Haibin and Wu, Chuan (2024).
  14. BytedTsinghua-SIA (2025).
  15. Yang, An and others (2025).
  16. OpenThoughts (2025).
  17. Hendrycks, Dan and Burns, Collin and Kadavath, Saurav and Arora, Akul and Basart, Steven and Tang, Eric and Song, Dawn and Steinhardt, Jacob (2021).

Footnotes

  1. We may allow some policy drift by doing a few rollouts before updating the student in a relaxed version.