HydraLM
HydraLM is a hybrid sub-quadratic language model that combines Gated DeltaNet, Sliding-Window Attention, and chunk-sparse Retrieval Attention to preserve fast scaling while recovering precise long-range information.
Why HydraLM exists
Pure linear models scale well but can blur exact recall. Pure softmax models recall well but scale poorly. HydraLM combines both lanes so you can keep speed while restoring local and long-range precision.
What changed in the docs
This page now behaves like a documentation workspace: a persistent desktop sidebar, a mobile drawer, and hash-based sections that open directly to the topic you share.
Where to go next
Start with installation if you are new, quick start if you already have the repo, or architecture if you want the model design first.
- Installation (research/hydralm/): Editable install, test suite, and the right repo path for the Python package.
- Learn the retrieval stack (docs/retrieval.md): Chunk routing, top-k selection, chunk-size tradeoffs, and long-range recall.
- Review claims and measurements (RESULTS.md): Benchmarks, the claims suite, and the scripts that regenerate published results.
- Browse the public API (docs/api.md): HydraConfig, HydraLM, generation helpers, retrieval modules, and deployment utilities.

The canonical Markdown sources live in the GitHub repository. This HTML layer is an optimized, browser-friendly shell for browsing them on hydralm.pages.dev.
Installation
Clone the repo and work from the package root.
HydraLM ships as a pure-PyTorch package. The main implementation lives in
research/hydralm, so install from there rather than from the website folder.
Use the canonical repository path
```shell
git clone https://github.com/byte271/HydraLM.git
cd HydraLM/research/hydralm
pip install -e ".[dev]"
pytest
```
What the extras do
- `[dev]` installs test and lint dependencies.
- `[train]` adds training-specific optimizer dependencies.
- `[hf]` adds the HuggingFace adapter path.
Quick start
Get a retrieval-enabled model running in a few lines.
The example below uses the 0.3.0 retrieval and multi-token prediction features while keeping the configuration surface approachable.
Minimal 0.3.0 setup
```python
from hydralm import HydraConfig, HydraLM

cfg = HydraConfig(
    vocab_size=32_000,
    d_model=768,
    n_layers=12,
    n_heads=12,
    swa_window=512,
    swa_every=4,
    retrieval_every=3,
    retrieval_chunk_size=128,
    retrieval_top_k=8,
    mtp_depth=2,
    mtp_loss_weight=0.1,
)
model = HydraLM(cfg)
print(cfg.summary())
```
Inspect logits and auxiliary loss
```python
import torch

input_ids = torch.randint(0, cfg.vocab_size, (2, 1024))
out = model(input_ids, compute_mtp=True)
out["logits"]        # per-position next-token logits
out["mtp_aux_loss"]  # auxiliary MTP loss; None when mtp_depth == 0
```
Configuration
One config controls the whole mixer schedule.
Most of the day-to-day work happens through HydraConfig.
The retrieval and MTP paths are intentionally opt-in so the legacy 0.2.0
behavior stays intact until you enable them.
| Field | Default | Why it matters |
|---|---|---|
| `swa_every` | 4 | Places an exact sliding-window attention layer every Nth block. |
| `retrieval_every` | 0 | Activates retrieval layers when set above zero. |
| `retrieval_chunk_size` | 128 | Controls routing granularity and chunk-local attention span. |
| `retrieval_top_k` | 8 | Limits how many prior chunks are pulled into each retrieval block. |
| `mtp_depth` | 0 | Enables the auxiliary next-k prediction head when positive. |
| `layer_types` | None | Accepts a manual layer schedule if you want to override the automatic placement. |
Reach for `layer_types` only after you understand the balance you want between DeltaNet, SWA, and retrieval blocks.
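As an illustration, a hand-written schedule might look like the list below. The values are hypothetical, assuming the accepted layer-type strings match the layer names in the Architecture table.

```python
# Hypothetical manual schedule for an 8-layer model, assuming the
# layer-type strings match the names in the Architecture table
# ("deltanet", "swa", "retrieval"); values chosen for illustration only.
manual_schedule = [
    "deltanet", "deltanet", "retrieval", "swa",
    "deltanet", "deltanet", "retrieval", "swa",
]
```

You would then pass this as `layer_types=manual_schedule` when constructing `HydraConfig`; its length should match `n_layers`.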
Architecture
A hybrid backbone with three complementary memory paths.
HydraLM interleaves DeltaNet, sliding-window attention, and optional retrieval attention inside the same pre-norm residual structure. The schedule is where most of the model’s personality comes from.
How each layer earns its cost
- DeltaNet carries the fast recurrent lane and keeps the backbone efficient.
- SWA restores exact recent-token recall inside a bounded local window.
- Retrieval reaches far back into the sequence without full attention over the whole history.
Automatic layer placement
1. Place SWA at every swa_every position.
2. Place retrieval at retrieval_every positions not already occupied by SWA.
3. Fill remaining blocks with DeltaNet.
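The three placement rules can be sketched in plain Python. This is an illustrative reconstruction assuming 1-based positions; the actual placement logic inside HydraConfig may index differently.

```python
def build_layer_schedule(n_layers, swa_every=4, retrieval_every=3):
    """Sketch of the automatic placement rules described above."""
    schedule = []
    for i in range(n_layers):
        pos = i + 1  # 1-based, so "every Nth block" reads literally
        if swa_every and pos % swa_every == 0:
            schedule.append("swa")          # rule 1: SWA claims its slots first
        elif retrieval_every and pos % retrieval_every == 0:
            schedule.append("retrieval")    # rule 2: retrieval takes free slots
        else:
            schedule.append("deltanet")     # rule 3: DeltaNet fills the rest
    return schedule
```

With the quick-start values (`n_layers=12`, `swa_every=4`, `retrieval_every=3`), this yields three SWA layers, three retrieval layers, and six DeltaNet layers.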
| Layer | State | Inference profile |
|---|---|---|
deltanet |
Short-conv cache + recurrent matrix state | Fast per-token recurrent inference. |
swa |
Rolling K/V cache of width W |
Exact within the local window. |
retrieval |
Chunk bank + partial chunk buffer | Top-k sparse access to distant history. |
Retrieval Attention
Chunk-sparse long-range recall without a full quadratic pass.
Retrieval Attention routes each query chunk to a small set of relevant prior chunks, then performs exact causal attention over only that subset plus the local chunk.
```python
cfg = HydraConfig(
    ...,
    retrieval_every=3,
    retrieval_chunk_size=128,
    retrieval_top_k=8,
    retrieval_learned_summary=False,
)
```
What the streaming path stores
- `bank_k`, `bank_v` for committed chunk memories.
- `bank_sum` for routing summaries.
- `buf_k`, `buf_v` for the current in-flight chunk.
- `pos` and buffer-length bookkeeping.
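The routing step itself can be illustrated with a stdlib-only sketch: score each committed chunk summary against the current query-chunk summary by dot product, keep the top-k indices, and attend over those chunks in causal order. `route_chunks` and its signature are illustrative, not the package API.

```python
def route_chunks(query_summary, bank_summaries, top_k):
    """Pick the top_k prior chunks whose summaries best match the query
    chunk's summary (dot-product scoring; illustrative only)."""
    scores = [
        sum(q * s for q, s in zip(query_summary, summary))
        for summary in bank_summaries
    ]
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:top_k])  # attend to selected chunks in causal order
```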
Compressive Memory
Keep serving memory bounded while the context keeps growing.
Compressive Memory wraps a softmax-attention-style K/V stream in three tiers: an exact recent window, a compressed summary pool, and a FIFO tombstone beyond the pool limit.
```python
from hydralm import CompressiveMemory

mem = CompressiveMemory(
    head_dim=64,
    n_heads=12,
    exact_window=512,
    compress_every=4,
    n_summary=256,
)
```
- Exact attention over the freshest tokens.
- Learned compressed summaries for older context.
- Constant-memory serving once the summary pool is full.
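The tiering policy can be sketched with a toy numeric stand-in. Mean pooling replaces the learned compression, and the class and its parameters are illustrative, not the CompressiveMemory internals.

```python
from collections import deque

class TieredKVSketch:
    """Toy three-tier buffer: exact recent window, summary pool with a
    fixed capacity, and FIFO eviction beyond the pool limit."""

    def __init__(self, exact_window, compress_every, n_summary):
        self.exact = deque()                       # tier 1: exact recent items
        self.summaries = deque(maxlen=n_summary)   # tier 2: FIFO summary pool
        self.exact_window = exact_window
        self.compress_every = compress_every

    def append(self, kv):
        self.exact.append(kv)
        # once a full group has accumulated beyond the exact window,
        # fold the oldest group into a single summary (mean as stand-in
        # for the learned compression)
        while len(self.exact) > self.exact_window + self.compress_every:
            group = [self.exact.popleft() for _ in range(self.compress_every)]
            self.summaries.append(sum(group) / len(group))
```

Because `summaries` has a fixed `maxlen`, the oldest summary is silently dropped once the pool is full, which is what keeps serving memory constant.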
Multi-Token Prediction
Densify the training signal and unlock self-drafting.
The MTP head predicts the next k tokens in parallel from each
backbone hidden state. That gives training denser supervision and creates
a built-in draft path for speculative decoding.
```python
cfg = HydraConfig(..., mtp_depth=2, mtp_loss_weight=0.1)
out = model(input_ids, compute_mtp=True)
if out["mtp_aux_loss"] is not None:
    loss = token_loss + out["mtp_aux_loss"]
```
```python
from hydralm import speculative_generate

out = speculative_generate(
    draft=model,
    target=model,
    prompt_ids=prompt_ids,
    max_new_tokens=256,
    k=cfg.mtp_depth,
)
```
Training
Plain AdamW is the default trainer path; hybrid Muon is opt-in.
The built-in TrainingConfig trainer defaults to AdamW and can
switch to the hybrid Muon + AdamW optimizer via optimizer="muon".
It also supports warmup, cosine decay, gradient accumulation, FSDP, and
optional MTP auxiliary loss integration.
```python
from hydralm import build_hybrid_optimizer

optim = build_hybrid_optimizer(
    model,
    muon_lr=5e-3,
    muon_momentum=0.95,
    adamw_lr=3e-4,
    adamw_betas=(0.9, 0.95),
    adamw_weight_decay=0.1,
)
```
Streaming
Prefill large contexts, then step token-by-token.
StreamingEngine threads recurrent state across calls, making
long prefill plus incremental decoding practical without reprocessing the
full history on every step.
```python
from hydralm.streaming import StreamingEngine

engine = StreamingEngine(model, chunk_size=4096)
stats = engine.process(long_input_ids)
print(stats.summary())

generated = engine.extend_and_generate(
    prompt=long_input_ids,
    max_new_tokens=128,
    temperature=0.7,
    top_k=50,
)
```
Deployment
Choose between a compiled decoder and a HuggingFace-shaped adapter.
The current deploy surface centers on CompiledDecoder for
low-latency batched token generation and HydraLMForCausalLM
when you want a Transformers-style wrapper with save/load helpers.
```python
from hydralm.deploy import CompiledDecoder, Request

decoder = CompiledDecoder(model, compile=True)
reqs = [
    Request(prompt=input_ids[0], max_new_tokens=128, temperature=0.8),
]
decoded = decoder.decode(reqs)
```

```python
from hydralm.deploy import HydraLMForCausalLM

hf_model = HydraLMForCausalLM(cfg)
hf_model.model.load_state_dict(model.state_dict())
hf_model.save_pretrained("hf/hydralm-160m")
```
Evaluation
The evaluation harness is designed to explain failure modes, not just scores.
HydraLM includes evaluation tasks for MQAR, needle-in-a-haystack, multi-fact long-context QA, online learning, and the broader claims suite.
```shell
python scripts/run_mqar.py --seq-len 2048 --n-pairs 64
python scripts/needle_in_haystack.py --seq-len 1048576
python scripts/long_context_qa.py \
  --seq-len 16384 \
  --num-facts 32 --num-queries 8 \
  --use-retrieval --retrieval-every 3 \
  --retrieval-chunk-size 128 --retrieval-top-k 8 \
  --mtp-depth 2
```
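To make the MQAR task concrete, here is a stdlib-only sketch of how such an instance can be constructed: interleaved key-value pairs form the context, and the model must later recall each key's value. The function name and data layout are illustrative; the actual `run_mqar.py` format may differ.

```python
import random

def make_mqar_example(n_pairs, vocab_size, seed=0):
    """Build one toy MQAR instance: a context of interleaved key-value
    pairs, plus the answer map a model must recover when queried."""
    rng = random.Random(seed)
    keys = rng.sample(range(vocab_size // 2), n_pairs)   # distinct keys
    values = [rng.randrange(vocab_size // 2, vocab_size) for _ in keys]
    context = [tok for pair in zip(keys, values) for tok in pair]
    return context, dict(zip(keys, values))
```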
Benchmarks
Claims-backed measurements keep the project grounded.
Published numbers are tied to reproducible scripts and CI checks. The project currently tracks nine formal claims plus a provisional retrieval-era addition for multi-fact long-context QA.
| Artifact | Purpose |
|---|---|
| `RESULTS.md` | Human-readable summary of each claim's status. |
| `results.json` | Raw structured measurements generated by the claim runner. |
| `scripts/reproduce_claims.py` | Canonical entry point for recomputing the report. |
| `docs/claims.md` | Defines what each published claim means. |
```shell
cd research/hydralm
python scripts/reproduce_claims.py --budget paper --out RESULTS.md
```
API reference
The public surface stays intentionally small.
The package exposes a compact top-level API: configuration, the model itself, generation helpers, retrieval-era modules, and deployment hooks.
| Name | Kind | Purpose |
|---|---|---|
| `HydraConfig` | dataclass | All hyperparameters, schedules, and optional 0.3.0 features. |
| `HydraLM` | nn.Module | The top-level model with training and streaming entry points. |
| `generate` | function | Greedy or sampling generation loop. |
| `speculative_generate` | function | Draft-then-verify decoding, including self-drafting via MTP. |
| `RetrievalAttention` | nn.Module | Chunk-sparse retrieval mixer for long-range recall. |
| `CompressiveMemory` | nn.Module | Three-tier KV memory wrapper for bounded-memory serving. |
Roadmap
The next steps stay close to the same design philosophy.
Future work centers on making the current hybrid design faster to run, easier to serve, and more formally measured at larger context scales.
- Fused CUDA or Triton kernels for the DeltaNet recurrence.
- Paged recurrent state for mixed prompt-length serving.
- Longer-horizon formal retrieval claims at million-token scale.
- Additional memory policies for FactBank compaction and TTL.