Inference Server From Scratch - Part 3: Sampling
In Part 3, myserve trades pure greedy for a pluggable sampler: temperature scaling, top-k and top-p (nucleus) filtering, plus presence/frequency penalties that mirror OpenAI’s semantics. We add deterministic seed control via a device-matched torch.Generator, and compute logprobs—both the chosen token and the top-k alternatives—available in streaming and non-streaming responses. A tighter request parser maps OpenAI params 1:1 into a runtime config, and we ship parity tests against transformers.generate for one-token sampling to lock correctness. We’re still recomputing attention each step (no KV yet); this post is about API parity and correctness, not speed.
New/changed repo layout
The repo picks up the pieces needed for sampling and logprobs. In server/api/openai_types.py we extend the request schema with sampling params, logprobs fields, and a seed. A new server/core/sampling.py implements temperature, top-k/top-p filtering, presence/frequency penalties, and multinomial draws. generate.py now delegates token selection to this sampler instead of greedy argmax, while main.py parses OpenAI-style options and returns per-token logprobs (when requested) in both streaming and non-streaming paths. Tests add test_sampling_parity.py to compare against transformers.generate and test_logprobs_shape.py to validate the logprob tensor shapes and fields.
myserve/
server/
api/
openai_types.py # add logprobs fields + seed
core/
sampling.py # new: filters, penalties, sampling
generate.py # now calls sampler instead of greedy argmax
main.py # parses sampling params + returns logprobs
tests/
test_sampling_parity.py
test_logprobs_shape.py
API surface: request/response updates
We extend the OpenAI-style schema to cover sampling, penalties, and logprobs while staying tolerant to unknown fields. The request model adds temperature, top_p, optional top_k (not in OpenAI but handy), n, max_tokens, and a seed for reproducibility; penalties (presence_penalty, frequency_penalty) follow OpenAI semantics. If logprobs is true, responses include a logprobs.content list with one entry per emitted token—each holding the token string, raw bytes, its logprob, and top_logprobs (whose length is set by the top_logprobs request field, where 0 means only the chosen token). In streaming, each SSE delta may carry content and, when requested, a one-token logprobs.content payload; non-streaming returns the full message plus the aggregated logprobs object.
server/api/openai_types.py
from typing import List, Optional, Literal
from pydantic import BaseModel, ConfigDict
Role = Literal["system", "user", "assistant", "tool"]
class ChatMessage(BaseModel):
role: Role
content: str
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
# sampling
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
top_k: Optional[int] = None # OpenAI doesn’t expose top_k, but we support it
n: Optional[int] = 1
max_tokens: Optional[int] = 256
    seed: Optional[int] = None # OpenAI exposes seed as best-effort; here it drives exact determinism for tests
# penalties (OpenAI compatible semantics)
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
# logprobs
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = 0 # 0 → only chosen token
# streaming
stream: Optional[bool] = False
user: Optional[str] = None
model_config = ConfigDict(extra="ignore")
Response shape (non‑stream):
{
"object": "chat.completion",
"choices": [{
"message": {"role": "assistant", "content": "..."},
"finish_reason": "stop"
}],
"logprobs": { // included only if request.logprobs==true
"content": [
{"token": "Hello", "bytes": [72,101,...], "logprob": -0.21,
"top_logprobs": [{"token": "Hi", "logprob": -0.35}, ...]}
// one entry per emitted token
]
}
}
Response shape (streaming chunks): each SSE delta may include content and, if requested, a logprobs.content array with one token’s info.
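Concretely, a single streaming chunk with logprobs enabled looks roughly like this (pretty-printed for readability; on the wire it is one data: line per chunk, and the id, timestamp, token text, and numbers are purely illustrative):
data: {
  "id": "chatcmpl_…",
  "object": "chat.completion.chunk",
  "created": 1234567890,
  "model": "sshleifer/tiny-gpt2",
  "choices": [{
    "index": 0,
    "delta": {
      "content": " blue",
      "logprobs": {"content": [{
        "token": " blue", "bytes": [32, 98, 108, 117, 101], "logprob": -0.42,
        "top_logprobs": [{"token": " red", "logprob": -1.10}, {"token": " green", "logprob": -1.37}]
      }]}
    },
    "finish_reason": null
  }]
}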
Sampler implementation
server/core/sampling.py implements a clean, testable sampler around a SamplerCfg (temperature, top-p, optional top-k, presence/frequency penalties, and top_logprobs). The flow is: apply penalties in logit space using the already-generated token history (presence subtracts a constant if a token appeared; frequency subtracts count × penalty), scale by temperature, then run a combined top-k / top-p filter that masks unlikely tokens to -inf. We compute both logprobs (via log_softmax) and probs, then draw the next token with torch.multinomial, optionally using a device-matched torch.Generator for seeded determinism. The sampler returns (next_ids, chosen_logprobs, full_logprobs), enabling OpenAI-style per-token logprobs while keeping the selection logic modular. (Note: the penalty implementation counts tokens per batch naïvely for clarity—good for correctness now, with room to optimize later.)
server/core/sampling.py
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
@dataclass
class SamplerCfg:
temperature: float = 1.0
top_p: float = 1.0
top_k: Optional[int] = None
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
top_logprobs: int = 0
@torch.no_grad()
def apply_penalties(
logits: torch.Tensor, # [B, V]
generated_ids: torch.Tensor, # [B, T] (prefix + already sampled)
presence_penalty: float,
frequency_penalty: float,
) -> torch.Tensor:
if presence_penalty == 0.0 and frequency_penalty == 0.0:
return logits
B, V = logits.shape
# token counts per batch
# build counts on CPU for simplicity; move back to device
counts = torch.zeros((B, V), dtype=torch.int32)
for b in range(B):
ids = generated_ids[b].tolist()
for t in ids:
if 0 <= t < V:
counts[b, t] += 1
counts = counts.to(logits.device)
# presence penalty subtracts a constant if token ever appeared
presence_mask = (counts > 0).to(logits.dtype)
logits = logits - presence_penalty * presence_mask
# frequency penalty subtracts count * penalty
logits = logits - frequency_penalty * counts.to(logits.dtype)
return logits
@torch.no_grad()
def top_k_top_p_filter(logits: torch.Tensor, top_k: Optional[int], top_p: float) -> torch.Tensor:
# logits: [B, V]
if top_k is not None and top_k > 0:
top_k = min(top_k, logits.shape[-1])
kth_vals = torch.topk(logits, top_k, dim=-1).values[..., -1:]
logits = torch.where(logits < kth_vals, torch.full_like(logits, float('-inf')), logits)
if 0.0 < top_p < 1.0:
sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
probs = F.softmax(sorted_logits, dim=-1)
cumprobs = torch.cumsum(probs, dim=-1)
mask = cumprobs > top_p
# shift right to always keep the first token above threshold
mask[..., 1:] = mask[..., :-1].clone()
mask[..., 0] = False
sorted_logits = torch.where(mask, torch.full_like(sorted_logits, float('-inf')), sorted_logits)
# unsort
unsorted = torch.full_like(sorted_logits, float('-inf'))
unsorted.scatter_(dim=-1, index=sorted_idx, src=sorted_logits)
logits = unsorted
return logits
@torch.no_grad()
def sample_next(
logits: torch.Tensor, # [B, V] last‑step logits
cfg: SamplerCfg,
generated_ids: torch.Tensor, # [B, T]
gen: Optional[torch.Generator] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# penalties first (operate in logits space)
logits = apply_penalties(logits, generated_ids, cfg.presence_penalty, cfg.frequency_penalty)
# temperature
temperature = max(1e-5, float(cfg.temperature))
logits = logits / temperature
# filter and normalize
logits = top_k_top_p_filter(logits, cfg.top_k, cfg.top_p)
logprobs = F.log_softmax(logits, dim=-1)
probs = torch.exp(logprobs)
# sample
next_ids = torch.multinomial(probs, num_samples=1, generator=gen) # [B,1]
chosen_logprobs = logprobs.gather(-1, next_ids) # [B,1]
return next_ids.squeeze(-1), chosen_logprobs.squeeze(-1), logprobs
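To see the sampler in isolation, here is a minimal sketch that draws one seeded token. The vocabulary size, logits, and token history are made up, and it assumes the same server.core import path that main.py uses below:
import torch
from server.core.sampling import SamplerCfg, sample_next

cfg = SamplerCfg(temperature=0.7, top_p=0.9, top_k=50, top_logprobs=3)
logits = torch.randn(1, 32)                     # fake last-step logits, [B=1, V=32]
history = torch.randint(0, 32, (1, 5))          # fake token history consumed by the penalties
gen = torch.Generator(device="cpu").manual_seed(1234)

next_ids, chosen_lp, logprobs = sample_next(logits, cfg, history, gen)
print(next_ids.shape, chosen_lp.shape, logprobs.shape)  # [1], [1], [1, 32]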
Generation loop: call the sampler
sample_generate() replaces greedy argmax with the new sampler: given a model and [B, T] input_ids, it iteratively runs a forward pass, takes the last-step logits, and calls sample_next() with a SamplerCfg (temperature, top-k/top-p, presence/frequency penalties) and an optional seeded torch.Generator for reproducibility. It appends the sampled IDs to the sequence, optionally collects per-step telemetry—the chosen token ID/logprob and the top-k alternatives’ (id, logprob) pairs—and stops early if all sequences hit eos_token_id. The function returns the full token tensor and a per_step list structured for easy conversion into OpenAI-style logprobs in both streaming and non-streaming responses.
server/core/generate.py (replace greedy path)
from __future__ import annotations
from typing import Optional
import torch
from torch import nn
from .sampling import SamplerCfg, sample_next
@torch.no_grad()
def sample_generate(
model: nn.Module,
input_ids: torch.LongTensor,
max_new_tokens: int,
eos_token_id: Optional[int],
cfg: SamplerCfg,
gen: Optional[torch.Generator] = None,
collect_logprobs: bool = False,
):
"""Token‑by‑token sampling. Returns (all_ids, per_step) where per_step is a list of dicts
with keys: {"id": int, "logprob": float, "top_logprobs": List[(id, logprob)]} when requested.
"""
model.eval()
B = input_ids.size(0)
out = input_ids
per_step = []
for _ in range(max_new_tokens):
logits = model(out).logits[:, -1, :] # [B, V]
next_ids, chosen_logprobs, logprobs = sample_next(logits, cfg, out, gen)
out = torch.cat([out, next_ids.unsqueeze(1)], dim=1)
if collect_logprobs:
k = int(cfg.top_logprobs or 0)
step = []
for b in range(B):
item = {"id": int(next_ids[b]), "logprob": float(chosen_logprobs[b])}
if k > 0:
topv, topi = torch.topk(logprobs[b], k)
item["top_logprobs"] = [(int(topi[j]), float(topv[j])) for j in range(topv.numel())]
step.append(item)
per_step.append(step)
if eos_token_id is not None and torch.all(next_ids == eos_token_id):
break
return out, per_step
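As a quick smoke test outside the server (a sketch, not part of the repo; it assumes the server.core import path and the tiny model the tests use), you can drive sample_generate directly:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from server.core.sampling import SamplerCfg
from server.core.generate import sample_generate

tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
inp = tok("Hello", return_tensors="pt", add_special_tokens=False)["input_ids"]

cfg = SamplerCfg(temperature=0.8, top_p=0.95, top_logprobs=2)
gen = torch.Generator(device="cpu").manual_seed(7)
ids, per_step = sample_generate(model, inp, max_new_tokens=4, eos_token_id=tok.eos_token_id,
                                cfg=cfg, gen=gen, collect_logprobs=True)
print(tok.decode(ids[0, inp.size(1):]))  # the four sampled tokens
print(per_step[0][0])                    # {"id": ..., "logprob": ..., "top_logprobs": [(id, logprob), ...]}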
Wiring it into the server
This update threads sampling, logprobs, and seed control straight into the OpenAI endpoint without changing the client contract. On each /v1/chat/completions call, the server builds the prompt, loads a model from the registry, and constructs a SamplerCfg from OpenAI-style params (temperature, top_p, optional top_k, presence/frequency penalties, top_logprobs). If a seed is provided, it initializes a device-matched torch.Generator for reproducible draws. In streaming mode, the loop samples one token at a time and now attaches optional per-token logprobs via an extra payload merged into the SSE delta; in non-streaming, it uses sample_generate() to return the full text plus an aggregated logprobs.content array when requested. Echo fallback from earlier posts remains intact, but the “real model” path now speaks the full sampling + logprobs dialect end-to-end.
server/main.py (sampling + logprobs + seed)
# ... imports as in Post 2
import math
import torch
from server.core.sampling import SamplerCfg, sample_next
from server.core.generate import sample_generate
# env defaults (keep prior)
@app.post("/v1/chat/completions")
async def chat_completions(req: ChatCompletionRequest):
model_name = req.model
tokenizer = get_tokenizer(model_name)
prompt = render_messages(tokenizer, req.messages)
created = int(time.time())
rid = f"chatcmpl_{uuid.uuid4().hex[:24]}"
# Try to load a real model unless echo is forced
bundle = None
if not USE_ECHO_FALLBACK:
try:
bundle = REGISTRY.load(model_name, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE)
except Exception as e:
# fallback silently to echo mode; in real servers you would surface a 400
bundle = None
if bundle is None:
# --- Echo backend (Post 1) ---
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
max_new = max(0, int(req.max_tokens or 0)) or 128
output_ids = input_ids[:max_new]
if req.stream:
async def echo_stream() -> AsyncGenerator[bytes, None]:
yield _sse_chunk(rid, model_name, role="assistant")
for tid in output_ids:
piece = tokenizer.decode([tid], skip_special_tokens=True, clean_up_tokenization_spaces=False)
if piece:
yield _sse_chunk(rid, model_name, content=piece)
await asyncio.sleep(0.0)
yield _sse_done(rid, model_name)
return StreamingResponse(echo_stream(), media_type="text/event-stream")
text = tokenizer.decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
return JSONResponse(_non_stream_payload(rid, model_name, text))
# real model path (after bundle is loaded)
tok = bundle.tokenizer
eos = tok.eos_token_id
enc = tok(prompt, return_tensors="pt", add_special_tokens=False)
input_ids = enc["input_ids"].to(bundle.device)
# sampler config
cfg = SamplerCfg(
temperature=float(req.temperature or 1.0),
top_p=float(req.top_p or 1.0),
top_k=int(req.top_k) if req.top_k else None,
presence_penalty=float(req.presence_penalty or 0.0),
frequency_penalty=float(req.frequency_penalty or 0.0),
top_logprobs=int(req.top_logprobs or 0),
)
max_new = max(1, int(req.max_tokens or 16))
# seeded generator (device‑specific to avoid CPU/CUDA mismatch)
gen = None
if req.seed is not None:
gen = torch.Generator(device=bundle.device.type)
gen.manual_seed(int(req.seed))
if req.stream:
async def stream() -> AsyncGenerator[bytes, None]:
yield _sse_chunk(rid, model_name, role="assistant")
generated = input_ids
for _ in range(max_new):
logits = bundle.model(generated).logits[:, -1, :]
next_ids, chosen_lp, logprobs = sample_next(logits, cfg, generated, gen)
generated = torch.cat([generated, next_ids.unsqueeze(1)], dim=1)
piece = tok.decode(next_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
extra = None
if req.logprobs:
k = int(req.top_logprobs or 0)
content = [{
"token": piece,
"bytes": list(piece.encode("utf-8", errors="ignore")),
"logprob": float(chosen_lp[0]),
"top_logprobs": (
[{"token": tok.decode([int(i)], skip_special_tokens=True, clean_up_tokenization_spaces=False),
"logprob": float(v)} for v, i in zip(*torch.topk(logprobs[0], k))]
if k > 0 else []
),
}]
extra = {"logprobs": {"content": content}}
yield _sse_chunk(rid, model_name, content=piece, extra=extra)
if eos is not None and int(next_ids[0]) == eos:
break
await asyncio.sleep(0.0)
yield _sse_done(rid, model_name)
return StreamingResponse(stream(), media_type="text/event-stream")
# non‑stream
all_ids, per_step = sample_generate(
bundle.model, input_ids, max_new_tokens=max_new, eos_token_id=eos,
cfg=cfg, gen=gen, collect_logprobs=bool(req.logprobs)
)
new_ids = all_ids[0, input_ids.size(1):]
text = tok.decode(new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
payload = _non_stream_payload(rid, model_name, text)
if req.logprobs:
tokens = []
for step in per_step:
item = step[0]
tstr = tok.decode([item["id"]], skip_special_tokens=True, clean_up_tokenization_spaces=False)
toks = {"token": tstr, "bytes": list(tstr.encode("utf-8", errors="ignore")), "logprob": item["logprob"]}
if cfg.top_logprobs:
toks["top_logprobs"] = [
                    {"token": tok.decode([tid], skip_special_tokens=True, clean_up_tokenization_spaces=False), "logprob": (lp if math.isfinite(lp) else str(lp))}  # JSON cannot encode -inf, so stringify it
for (tid, lp) in item.get("top_logprobs", [])
]
tokens.append(toks)
payload["logprobs"] = {"content": tokens}
return JSONResponse(payload)
The SSE helper _sse_chunk(...) now accepts an extra dict that it merges into choices[0].delta.
Helper change:
def _sse_chunk(rid: str, model: str, content: str | None = None, role: str | None = None, extra: dict | None = None) -> bytes:
delta = {}
if role is not None:
delta["role"] = role
if content is not None:
delta["content"] = content
if extra:
delta.update(extra)
obj = {
"id": rid,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [{
"index": 0,
"delta": delta,
"finish_reason": None,
}],
}
return f"data: {json.dumps(obj, ensure_ascii=False)}\n\n".encode()
Tests
The test suite locks in correctness for sampling and logprobs. One-token parity (tests/test_sampling_parity.py) compares our sample_generate() against transformers.generate() on sshleifer/tiny-gpt2 across temperature/top-p/top-k settings, seeding both paths so they emit the same next token tensor—a tight equivalence check on sampling math. Logprobs shape (tests/test_logprobs_shape.py) drives the FastAPI endpoint in-memory and asserts the non-streaming response includes a logprobs.content list with per-token entries containing token, logprob, and a list of top_logprobs when requested. Together, these tests ensure protocol compatibility and sampler fidelity before we chase speed.
1) One‑token sampling parity vs HF
tests/test_sampling_parity.py
import torch
import pytest
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from myserve.core.sampling import SamplerCfg
from myserve.core.generate import sample_generate
MODEL = "sshleifer/tiny-gpt2"
@pytest.mark.parametrize("prompt,temperature,top_p,top_k", [
("Hello", 1.0, 1.0, None),
("Hello", 0.7, 1.0, 50),
("Hello", 1.0, 0.9, None),
])
@torch.no_grad()
def test_one_token_sampling_parity(prompt, temperature, top_p, top_k):
tok = AutoTokenizer.from_pretrained(MODEL, use_fast=True)
tok.pad_token = tok.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL)
enc = tok(prompt, return_tensors="pt", add_special_tokens=False)
inp = enc["input_ids"]
# our sample (1 token)
cfg = SamplerCfg(temperature=temperature, top_p=top_p, top_k=top_k)
gen = torch.Generator(device="cpu").manual_seed(1234)
ours, _ = sample_generate(model, inp, 1, tok.eos_token_id, cfg, gen, collect_logprobs=False)
# HF sample (1 token)
set_seed(1234)
hf = model.generate(**enc, max_new_tokens=1, do_sample=True, temperature=temperature,
top_p=top_p, top_k=(top_k or 0), pad_token_id=tok.pad_token_id, eos_token_id=tok.eos_token_id)
assert torch.equal(ours, hf)
2) Logprobs shape (non‑stream)
tests/test_logprobs_shape.py
from httpx import AsyncClient, ASGITransport
import pytest
from myserve.main import app
@pytest.mark.asyncio
async def test_logprobs_shape():
body = {
"model": "sshleifer/tiny-gpt2",
"messages": [{"role": "user", "content": "hi"}],
"max_tokens": 2,
"temperature": 1.0,
"top_p": 1.0,
"logprobs": True,
"top_logprobs": 3,
"seed": 42,
"stream": False,
}
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
r = await ac.post("/v1/chat/completions", json=body)
assert r.status_code == 200
data = r.json()
assert "logprobs" in data and "content" in data["logprobs"]
assert len(data["logprobs"]["content"]) > 0
first = data["logprobs"]["content"][0]
assert "token" in first and "logprob" in first
assert isinstance(first.get("top_logprobs", []), list)
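Run both files with pytest (the logprobs test is async, so pytest-asyncio needs to be installed):
pytest tests/test_sampling_parity.py tests/test_logprobs_shape.py -q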
Try it
Non‑stream with seed + logprobs:
curl http://localhost:8000/v1/chat/completions \
-H 'content-type: application/json' \
-d '{
"model": "sshleifer/tiny-gpt2",
"messages": [{"role": "user", "content": "Count to three:"}],
"max_tokens": 8,
"temperature": 0.7,
"top_p": 0.95,
"top_k": 50,
"seed": 1337,
"logprobs": true,
"top_logprobs": 5
}' | jq .
Streaming with logprobs:
curl http://localhost:8000/v1/chat/completions \
-H 'content-type: application/json' \
-d '{
"model": "sshleifer/tiny-gpt2",
"messages": [{"role": "user", "content": "Say a color:"}],
"max_tokens": 4,
"temperature": 1.0,
"top_p": 0.9,
"logprobs": true,
"top_logprobs": 3,
"seed": 7,
"stream": true
}'
A few design choices keep this step practical and compatible. Penalties are implemented in logits space to mirror OpenAI’s public guidance—close in spirit, though exact parity may vary at the margins. Seed control uses a per-request torch.Generator (CPU or CUDA to match the model device), so concurrent requests don’t collide. For logprobs, we return both token text and bytes; the latter is robust for clients that diff at the byte level across tokenizers/locales. And for back-compat, when logprobs=false the response shape is unchanged from Part 2, so existing clients keep working without code changes.
The complete Part 3 code is here: https://github.com/pbelevich/myserve/tree/ver3 — pull it, run the sampling/logprobs tests, and let me know where you want to push the API or performance next.
Post 4 is all about speed. We’ll introduce a KV cache to avoid recomputing attention, split generation into a clear prefill → step pipeline, and add continuous batching so new requests can join mid-flight without stalling throughput. With those pieces in place, we’ll start measuring the wins like an adult system: TTFT (time-to-first-token), tokens/sec (steady-state), and utilization under mixed workloads. This is the turning point where myserve moves from “correct” to fast.