Revert "(feat) Fully customized in house vllm " (#874)

This commit is contained in:
Gong Junmin 2026-03-19 23:07:51 +08:00 committed by GitHub
parent 8b58da12a2
commit 89d53791dc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
53 changed files with 3121 additions and 1351 deletions

View file

@ -475,7 +475,7 @@ The `start_gradio_ui_rocm.bat` and `start_api_server_rocm.bat` scripts include a
```batch
REM ==================== ROCm Configuration ====================
REM Force PyTorch LM backend (bypasses vllm engine flash_attn dependency)
REM Force PyTorch LM backend (bypasses nano-vllm flash_attn dependency)
set ACESTEP_LM_BACKEND=pt
REM RDNA3 GPU architecture override

View file

@ -23,4 +23,6 @@ paths-ignore:
# Training modules — local-only, paths validated via safe_path()
- acestep/training
- acestep/ui
# Third-party vendored code
- acestep/third_parts
- third_party

View file

@ -14,7 +14,7 @@ ACE-Step 1.5 is an open-source music foundation model combining a Language Model
- **FastAPI + Uvicorn** for REST API server
- **uv** for dependency management
- **MLX** (Apple Silicon native acceleration, macOS ARM64)
- **customized_vllm** (built-in LLM inference engine)
- **nano-vllm** (optimized LLM inference, non-macOS ARM64)
## Multi-Platform Support

View file

@ -1,409 +0,0 @@
"""Customised vLLM inference engine for single-GPU Qwen3 generation.
Public API:
LLM - High-level generate() interface with optional CFG
SamplingParams - Per-request sampling configuration
reset_context - Clear thread-local forward state (call on error recovery)
"""
import os
import atexit
import threading as _threading
from collections import deque
from copy import copy
from dataclasses import dataclass, field
from enum import Enum, auto
from itertools import count
from time import perf_counter
from typing import Optional, Callable, Any
import torch
from tqdm.auto import tqdm
from transformers import AutoConfig, AutoTokenizer
# ---------------------------------------------------------------------------
# Sampling configuration
# ---------------------------------------------------------------------------
@dataclass
class SamplingParams:
    """Sampling options attached to a single generation request.

    All fields are validated in ``__post_init__`` with ``assert`` statements,
    so an out-of-range value raises ``AssertionError`` at construction time.
    """

    # Softmax temperature; must be strictly positive (greedy decoding is rejected).
    temperature: float = 1.0
    # Upper bound on the number of generated tokens.
    max_tokens: int = 64
    # When True, the EOS token does not terminate generation.
    ignore_eos: bool = False
    # Classifier-free guidance scale; values above 1.0 enable guidance.
    cfg_scale: float = 1.0
    # Optional top-k / top-p limits; None leaves the corresponding filter off.
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    # Repetition penalty factor; 1.0 applies no penalty.
    repetition_penalty: float = 1.0
    # Optional callable applied to logits, plus a hook fed each sampled token id.
    logits_processor: Optional[Any] = field(default=None, repr=False)
    logits_processor_update_state: Optional[Callable[[int], None]] = field(default=None, repr=False)

    def __post_init__(self):
        """Check every field against its allowed range."""
        assert self.temperature > 1e-10, "greedy sampling is not permitted"
        assert self.max_tokens > 0, "max_tokens must be > 0"
        assert self.cfg_scale >= 1.0, "cfg_scale must be >= 1.0"
        assert self.top_k is None or self.top_k > 0, "top_k must be > 0"
        assert self.top_p is None or 0.0 < self.top_p <= 1.0, "top_p must be in (0.0, 1.0]"
        assert self.repetition_penalty > 0.0, "repetition_penalty must be > 0.0"
# ---------------------------------------------------------------------------
# Thread-local forward state (replaces separate context.py module)
# ---------------------------------------------------------------------------
@dataclass
class ForwardState:
"""Attention metadata shared between the pipeline and transformer layers."""
is_prefill: bool = False
cu_seqlens_q: torch.Tensor | None = None
cu_seqlens_k: torch.Tensor | None = None
max_seqlen_q: int = 0
max_seqlen_k: int = 0
slot_mapping: torch.Tensor | None = None
context_lens: torch.Tensor | None = None
block_tables: torch.Tensor | None = None
_TLS = _threading.local()


def _get_forward_state() -> ForwardState:
    """Return the calling thread's forward state, creating a default one on first use."""
    state = getattr(_TLS, "_fwd", None)
    if state is not None:
        return state
    # Nothing stored yet for this thread: install and hand back a fresh default.
    state = ForwardState()
    _TLS._fwd = state
    return state
def _set_forward_state(is_prefill=False, **kw) -> ForwardState:
    """Install a newly built forward state for the calling thread and return it.

    ``kw`` is forwarded unchanged to the ``ForwardState`` constructor.
    """
    state = ForwardState(is_prefill=is_prefill, **kw)
    _TLS._fwd = state
    return state
def reset_context():
    """Drop the calling thread's forward state by replacing it with defaults.

    Public entry point; used by llm_inference error recovery.
    """
    pristine = ForwardState()
    _TLS._fwd = pristine
# ---------------------------------------------------------------------------
# Generation slot (per-sequence state)
# ---------------------------------------------------------------------------
class SlotStatus(Enum):
    """Lifecycle stage of a generation slot."""
    PENDING = auto()
    ACTIVE = auto()
    DONE = auto()


class GenerationSlot:
    """Tracks token state, cache blocks, and sampling config for a single sequence."""

    # Process-wide counter; each new slot takes the next id.
    _id_gen = count()

    def __init__(self, token_ids: list[int], params=None,
                 is_unconditional: bool = False, block_size: int = 256):
        """Create a slot for one prompt.

        Args:
            token_ids: Prompt token ids. The list is shallow-copied, so later
                appends do not touch the caller's list. Must be non-empty
                (the last element is read immediately).
            params: Sampling configuration (a ``SamplingParams``-like object).
                ``None`` means "use a default ``SamplingParams()``"; the default
                is built per call rather than once at definition time.
            is_unconditional: True for the unconditional twin of a CFG pair.
            block_size: Number of tokens per KV cache block.
        """
        if params is None:
            params = SamplingParams()
        self.slot_id = next(GenerationSlot._id_gen)
        self.block_size = block_size
        self.status = SlotStatus.PENDING
        self.token_ids = copy(token_ids)
        self.last_token = token_ids[-1]
        self.num_tokens = len(token_ids)
        self.prompt_length = len(token_ids)
        # Filled in by the cache pool when blocks are reserved.
        self.cache_blocks: list[int] = []
        # Sampling settings are copied onto the slot so later reads do not
        # depend on the params object.
        self.temperature = params.temperature
        self.max_tokens = params.max_tokens
        self.ignore_eos = params.ignore_eos
        self.cfg_scale = params.cfg_scale
        self.top_k = params.top_k
        self.top_p = params.top_p
        self.repetition_penalty = params.repetition_penalty
        self.is_unconditional = is_unconditional
        # Set by the engine to link the conditional and unconditional slots.
        self.paired_slot: Optional["GenerationSlot"] = None
        self.logits_processor: Optional[Any] = params.logits_processor
        self.logits_processor_update_state: Optional[Callable[[int], None]] = (
            params.logits_processor_update_state
        )

    def __len__(self):
        """Total number of tokens held (prompt plus generated)."""
        return self.num_tokens

    def __getitem__(self, key):
        """Index or slice into the token id list."""
        return self.token_ids[key]

    @property
    def is_finished(self):
        """True once the slot has been marked DONE."""
        return self.status == SlotStatus.DONE

    @property
    def generated_ids(self):
        """Token ids appended after the prompt."""
        return self.token_ids[self.prompt_length:]

    @property
    def required_blocks(self):
        """Number of cache blocks needed to hold all current tokens (ceil division)."""
        return (self.num_tokens + self.block_size - 1) // self.block_size

    @property
    def tail_block_fill(self):
        """Number of tokens occupying the last required block."""
        return self.num_tokens - (self.required_blocks - 1) * self.block_size

    def push_token(self, token_id: int):
        """Append one generated token and update the running counters."""
        self.token_ids.append(token_id)
        self.last_token = token_id
        self.num_tokens += 1
# ---------------------------------------------------------------------------
# KV cache pool
# ---------------------------------------------------------------------------
class CachePool:
    """Simple block-based KV cache allocator (no prefix caching)."""

    def __init__(self, num_blocks: int, block_size: int):
        """Start with blocks ``0 .. num_blocks - 1`` all free.

        Args:
            num_blocks: Total number of cache blocks managed by the pool.
            block_size: Number of tokens stored per block.
        """
        self.block_size = block_size
        self.available: deque[int] = deque(range(num_blocks))
        self.total = num_blocks

    def has_capacity(self, num_blocks: int) -> bool:
        """Return True when at least ``num_blocks`` blocks are free."""
        return len(self.available) >= num_blocks

    def reserve(self, slot: "GenerationSlot"):
        """Give ``slot`` every block it currently requires.

        Callers are expected to check ``has_capacity`` first; taking a block
        from an empty pool raises ``IndexError``.
        """
        for _ in range(slot.required_blocks):
            slot.cache_blocks.append(self.available.popleft())

    def release(self, slot: "GenerationSlot"):
        """Return all of ``slot``'s blocks to the pool and clear its block list."""
        for bid in reversed(slot.cache_blocks):
            self.available.append(bid)
        slot.cache_blocks.clear()

    def grow_if_needed(self, slot: "GenerationSlot"):
        """Allocate one more block when the newest generated token opened a new block.

        Raises:
            RuntimeError: If a block is needed but none are free.
        """
        # The newest token sits at index len(slot) - 1; it is the first token of
        # a block exactly when that index is a multiple of block_size. Written
        # this way the test also holds for block_size == 1, where the equivalent
        # "len % block_size == 1" form can never be true.
        opens_new_block = (len(slot) - 1) % self.block_size == 0
        if opens_new_block and len(slot) > slot.prompt_length:
            if not self.available:
                raise RuntimeError(
                    f"KV cache exhausted during decode: slot {slot.slot_id} needs a new block "
                    f"but 0/{self.total} blocks are available"
                )
            slot.cache_blocks.append(self.available.popleft())
# ---------------------------------------------------------------------------
# Engine configuration
# ---------------------------------------------------------------------------
@dataclass
class _EngineConfig:
model: str
max_num_batched_tokens: int = 16384
max_num_seqs: int = 512
max_model_len: int = 4096
gpu_memory_utilization: float = 0.9
enforce_eager: bool = False
kvcache_block_size: int = 256
def __post_init__(self):
assert os.path.isdir(self.model)
self.hf_config = AutoConfig.from_pretrained(self.model)
self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
assert self.max_num_batched_tokens >= self.max_model_len
# ---------------------------------------------------------------------------
# LLM - public inference API
# ---------------------------------------------------------------------------
class LLM:
    """High-level inference engine with optional classifier-free guidance.

    Usage::

        llm = LLM(model="/path/to/model")
        outputs = llm.generate(["Hello world"], SamplingParams(max_tokens=128))
    """

    def __init__(self, model, **kwargs):
        """Build the inference pipeline, tokenizer and KV cache pool.

        Args:
            model: Path to a local model directory.
            **kwargs: Optional overrides for ``max_num_batched_tokens``,
                ``max_num_seqs``, ``max_model_len``, ``gpu_memory_utilization``,
                ``enforce_eager`` and ``kvcache_block_size``, plus an optional
                pre-built ``tokenizer``.
        """
        from acestep.customized_vllm.pipeline import InferencePipeline
        cfg = _EngineConfig(
            model=model,
            max_num_batched_tokens=kwargs.get("max_num_batched_tokens", 16384),
            max_num_seqs=kwargs.get("max_num_seqs", 512),
            max_model_len=kwargs.get("max_model_len", 4096),
            gpu_memory_utilization=kwargs.get("gpu_memory_utilization", 0.9),
            enforce_eager=kwargs.get("enforce_eager", False),
            kvcache_block_size=kwargs.get("kvcache_block_size", 256),
        )
        self._cfg = cfg
        # Serialises generate() calls; the engine state below is not shareable.
        self._lock = _threading.Lock()
        self._pipeline = InferencePipeline(
            hf_config=cfg.hf_config, model_path=cfg.model,
            block_size=cfg.kvcache_block_size, max_num_seqs=cfg.max_num_seqs,
            max_num_batched_tokens=cfg.max_num_batched_tokens,
            max_model_len=cfg.max_model_len,
            gpu_memory_utilization=cfg.gpu_memory_utilization,
            enforce_eager=cfg.enforce_eager,
        )
        tok = kwargs.get("tokenizer", None)
        self.tokenizer = tok if tok is not None else AutoTokenizer.from_pretrained(model, use_fast=True)
        self._eos = self.tokenizer.eos_token_id
        self._block_size = cfg.kvcache_block_size
        self._cache = CachePool(self._pipeline._num_cache_blocks, cfg.kvcache_block_size)
        self._active_slots: list[GenerationSlot] = []
        self._closed = False
        atexit.register(self.exit)

    def exit(self):
        """Shut the pipeline down; safe to call more than once.

        The first call also removes the ``atexit`` hook registered in
        ``__init__``, so the interpreter no longer holds a reference to this
        engine and the hook does not run a second shutdown at exit.
        """
        if getattr(self, "_closed", False):
            return
        self._closed = True
        atexit.unregister(self.exit)
        self._pipeline.shutdown()

    def reset(self):
        """Release all KV cache blocks (call on error to prevent leaks)."""
        for slot in self._active_slots:
            if slot.cache_blocks:
                self._cache.release(slot)
        self._active_slots.clear()

    # -- Public generate --------------------------------------------------
    def generate(self, prompts, sampling_params, use_tqdm=True, unconditional_prompts=None):
        """Generate completions for a batch of prompts.

        Args:
            prompts: Sequence of prompts, each a string or a list of token ids.
            sampling_params: One params object shared by all prompts, or a list
                with one entry per prompt.
            use_tqdm: Show a progress bar while generating.
            unconditional_prompts: Optional per-prompt unconditional inputs,
                used only for prompts whose ``cfg_scale`` exceeds 1.0.

        Returns:
            List of dicts with ``"text"`` and ``"token_ids"`` keys. An empty
            ``prompts`` sequence yields an empty list.

        Raises:
            ValueError: If a list argument's length differs from ``prompts``.
            RuntimeError: If the KV cache cannot hold the batch.
        """
        with self._lock:
            return self._run_generation(prompts, sampling_params, use_tqdm, unconditional_prompts)

    # -- Internal generation logic ----------------------------------------
    def _prepare_slots(self, prompts, sampling_params, unconditional_prompts):
        """Tokenise prompts, create slots, allocate KV cache."""
        if not isinstance(sampling_params, list):
            sampling_params = [sampling_params] * len(prompts)
        if unconditional_prompts is None:
            unconditional_prompts = [None] * len(prompts)
        if len(sampling_params) != len(prompts):
            raise ValueError(
                f"sampling_params length ({len(sampling_params)}) != prompts length ({len(prompts)})"
            )
        if len(unconditional_prompts) != len(prompts):
            raise ValueError(
                f"unconditional_prompts length ({len(unconditional_prompts)}) != prompts length ({len(prompts)})"
            )
        # Loop-invariant; the getattr fallback covers instances built without __init__.
        bs = getattr(self, '_block_size', 256)
        all_slots = []
        for prompt, sp, uncond in zip(prompts, sampling_params, unconditional_prompts):
            ids = self.tokenizer.encode(prompt) if isinstance(prompt, str) else prompt
            cond = GenerationSlot(ids, sp, block_size=bs)
            if sp.cfg_scale > 1.0:
                # Guidance needs an unconditional twin; without an explicit
                # unconditional prompt the conditional ids are reused.
                u_ids = (self.tokenizer.encode(uncond) if isinstance(uncond, str)
                         else (uncond if uncond is not None else ids))
                uncond_slot = GenerationSlot(u_ids, sp, is_unconditional=True, block_size=bs)
                cond.paired_slot = uncond_slot
                uncond_slot.paired_slot = cond
                all_slots.extend([cond, uncond_slot])
            else:
                all_slots.append(cond)
        # Check capacity for the whole batch before reserving anything, so a
        # failure leaves the pool untouched.
        total_blocks = sum(s.required_blocks for s in all_slots)
        if not self._cache.has_capacity(total_blocks):
            raise RuntimeError(
                f"Insufficient KV cache: need {total_blocks} blocks, "
                f"have {len(self._cache.available)}/{self._cache.total}"
            )
        for slot in all_slots:
            self._cache.reserve(slot)
            slot.status = SlotStatus.ACTIVE
        self._active_slots = list(all_slots)
        return all_slots

    def _arrange_guidance_batch(self, slots):
        """Order: normal, then CFG conditional, then CFG unconditional."""
        normal = [s for s in slots if s.cfg_scale <= 1.0]
        cond = [s for s in slots if s.cfg_scale > 1.0 and not s.is_unconditional]
        uncond = [s for s in slots if s.is_unconditional]
        return normal + cond + uncond

    def _generation_steps(self, all_slots):
        """Generator yielding (ordered_batch, token_ids, elapsed, is_prefill) per step.

        The first item is the prefill step over every slot; each later item is
        one decode step over the slots still active. The consumer is expected
        to remove finished slots from ``self._active_slots`` between steps.
        """
        ordered = self._arrange_guidance_batch(all_slots)
        t = perf_counter()
        tids = self._pipeline.execute_step(ordered, is_prefill=True)
        yield ordered, tids, perf_counter() - t, True
        while self._active_slots:
            for slot in self._active_slots:
                self._cache.grow_if_needed(slot)
            batch = self._arrange_guidance_batch(self._active_slots)
            t = perf_counter()
            tids = self._pipeline.execute_step(batch, is_prefill=False)
            yield batch, tids, perf_counter() - t, False

    def _finalize_step(self, batch, token_ids, outputs, pbar):
        """Append tokens, check EOS, collect finished outputs."""
        # The first slot decides whether the whole batch is treated as CFG pairs.
        is_cfg = (len(batch) > 0 and batch[0].cfg_scale > 1.0
                  and batch[0].paired_slot is not None)
        if is_cfg:
            # First half conditional, second half unconditional; each pair
            # shares one sampled token.
            nc = len(batch) // 2
            for cond, uncond, tid in zip(batch[:nc], batch[nc:], token_ids):
                cond.push_token(tid)
                uncond.push_token(tid)
                done = ((not cond.ignore_eos and tid == self._eos) or
                        cond.num_tokens - cond.prompt_length >= cond.max_tokens)
                if done:
                    for s in (cond, uncond):
                        s.status = SlotStatus.DONE
                        self._cache.release(s)
                        if s in self._active_slots:
                            self._active_slots.remove(s)
                    outputs[cond.slot_id] = cond.generated_ids
                    if pbar:
                        pbar.update(1)
        else:
            for slot, tid in zip(batch, token_ids):
                slot.push_token(tid)
                done = ((not slot.ignore_eos and tid == self._eos) or
                        slot.num_tokens - slot.prompt_length >= slot.max_tokens)
                if done:
                    slot.status = SlotStatus.DONE
                    self._cache.release(slot)
                    if slot in self._active_slots:
                        self._active_slots.remove(slot)
                    outputs[slot.slot_id] = slot.generated_ids
                    if pbar:
                        pbar.update(1)

    def _run_generation(self, prompts, sampling_params, use_tqdm, unconditional_prompts):
        """Run prefill and decode to completion and decode the results to text."""
        # Slots left over from an interrupted call would leak cache blocks.
        if self._active_slots:
            self.reset()
        # Nothing to do for an empty batch; the pipeline cannot run with zero slots.
        if len(prompts) == 0:
            return []
        all_slots = self._prepare_slots(prompts, sampling_params, unconditional_prompts)
        pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True) if use_tqdm else None
        prefill_tps = decode_tps = 0.0
        outputs = {}
        try:
            for batch, token_ids, elapsed, is_pf in self._generation_steps(all_slots):
                elapsed = max(elapsed, 1e-9)
                if is_pf:
                    prefill_tps = sum(len(s) for s in batch) / elapsed
                else:
                    # One new token per conditional slot per decode step.
                    n_cond = sum(1 for s in batch if not s.is_unconditional)
                    decode_tps = n_cond / elapsed
                self._finalize_step(batch, token_ids, outputs, pbar)
                if pbar:
                    pbar.set_postfix(Prefill=f"{int(prefill_tps)}tok/s",
                                     Decode=f"{int(decode_tps)}tok/s")
                if not self._active_slots:
                    break
        except Exception:
            self.reset()
            raise
        finally:
            if pbar:
                pbar.close()
        # Slot ids increase in creation order, so sorting restores prompt order.
        result = [outputs[sid] for sid in sorted(outputs)]
        return [{"text": self.tokenizer.decode(tids), "token_ids": tids} for tids in result]

View file

@ -1,426 +0,0 @@
"""Inference pipeline: model loading, KV cache provisioning, CUDA graphs, and forward passes."""
import torch
import sys
from loguru import logger
from acestep.customized_vllm.transformer import CausalTransformer, load_weights
from acestep.debug_utils import debug_start, debug_end
# ---------------------------------------------------------------------------
# Nucleus / top-k filtering (always receives tensors, never None)
# ---------------------------------------------------------------------------
def _filter_by_top_k(logits, k):
"""Top-k filtering without full vocabulary sort."""
vocab_size = logits.shape[1]
skip = (k <= 0) | (k >= vocab_size)
k_safe = k.masked_fill(skip, 1).long()
max_k = int(k_safe.max().clamp(max=vocab_size))
topk_vals = logits.topk(max_k, dim=1).values
thresh = topk_vals.gather(1, (k_safe - 1).clamp(0, max_k - 1).unsqueeze(1))
thresh.masked_fill_(skip.unsqueeze(1), float("-inf"))
logits.masked_fill_(logits < thresh, float("-inf"))
return logits
def _filter_by_nucleus(logits, k, p):
"""Combined top-k and nucleus (top-p) filtering.
Parameters are always tensors (never None).
k=0 means skip top-k, p=1.0 means skip top-p.
Note: NOT compiled because .any() and int() cause graph breaks in dynamo.
"""
has_k = (k > 0).any()
has_p = (p < 1.0).any()
if not has_p and not has_k:
return logits
if not has_p:
return _filter_by_top_k(logits, k)
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
if has_k:
vocab_size = logits_sort.size(1)
k_clamped = k.clamp(1, vocab_size).long()
thresh = logits_sort.gather(1, (vocab_size - k_clamped).unsqueeze(1))
logits_sort.masked_fill_(logits_sort < thresh, float("-inf"))
probs_sum = logits_sort.softmax(dim=-1).cumsum_(dim=-1)
mask = probs_sum <= (1.0 - p.unsqueeze(1))
mask[:, -1] = False
logits_sort.masked_fill_(mask, float("-inf"))
logits.scatter_(dim=-1, index=logits_idx, src=logits_sort)
return logits
def sample_tokens(logits, temperatures, top_ks, top_ps):
    """Draw one token id per row after temperature scaling and top-k/top-p filtering.

    Sampling divides the softmax probabilities by Exponential(1) noise and
    takes the argmax, which avoids an explicit multinomial draw (the
    exponential-noise form of the Gumbel-max trick).
    """
    scaled = logits.float()
    scaled.div_(temperatures.unsqueeze(1))
    # Filtering happens in place on ``scaled``; the return value is not needed.
    _filter_by_nucleus(scaled, top_ks, top_ps)
    probs = torch.softmax(scaled, dim=-1)
    # Clamp the noise away from zero so the division cannot produce inf/NaN.
    noise = torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)
    return probs.div_(noise).argmax(dim=-1)
# ---------------------------------------------------------------------------
# Inference pipeline
# ---------------------------------------------------------------------------
class InferencePipeline:
    """Loads a model, provisions KV cache, captures CUDA graphs, and runs forward passes."""

    def __init__(self, hf_config, model_path: str, block_size: int, max_num_seqs: int,
                 max_num_batched_tokens: int, max_model_len: int, gpu_memory_utilization: float,
                 enforce_eager: bool):
        """Load weights onto GPU 0, warm up, size the KV cache and capture graphs.

        Args:
            hf_config: Hugging Face model config object.
            model_path: Directory holding the model weights.
            block_size: Tokens per KV cache block.
            max_num_seqs: Largest batch the transfer buffers are sized for.
            max_num_batched_tokens: Token budget used to size the warmup batch.
            max_model_len: Longest supported sequence length.
            gpu_memory_utilization: Fraction of total GPU memory to target.
            enforce_eager: Skip CUDA graph capture and always run eagerly.
        """
        torch._dynamo.config.capture_scalar_outputs = True
        torch._dynamo.config.verbose = True
        self.block_size = block_size
        self.enforce_eager = enforce_eager
        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.gpu_memory_utilization = gpu_memory_utilization
        self.hf_config = hf_config
        torch.cuda.set_device(0)
        saved_dtype = torch.get_default_dtype()
        gpu_props = torch.cuda.get_device_properties(0)
        # bfloat16 is only selected on compute capability 8.0 or newer.
        bf16_ok = (gpu_props.major, gpu_props.minor) >= (8, 0)
        raw = getattr(hf_config, "dtype", getattr(hf_config, "torch_dtype", None))
        if isinstance(raw, str):
            _map = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}
            raw = _map.get(raw.replace("torch.", ""), None)
        # Use the config's floating dtype when it names one, else pick by hardware.
        self.dtype = (raw if isinstance(raw, torch.dtype) and raw.is_floating_point else
                      torch.bfloat16 if bf16_ok else torch.float16)
        if self.dtype == torch.bfloat16 and not bf16_ok:
            self.dtype = torch.float16
        # Everything created inside the try block defaults to CUDA and self.dtype.
        torch.set_default_dtype(self.dtype)
        torch.set_default_device("cuda")
        try:
            self.model = CausalTransformer(hf_config)
            _t = debug_start("load_model", prefix="tensor.vllm")
            load_weights(self.model, model_path)
            debug_end("load_model", _t, prefix="tensor.vllm")
            self._init_transfer_buffers()
            self._warmup_pipeline()
            self._provision_kv_storage()
            if not enforce_eager:
                self._compile_execution_graphs()
        finally:
            # Restore process-wide defaults even when loading fails.
            torch.set_default_device("cpu")
            torch.set_default_dtype(saved_dtype)

    # -- Transfer buffers ------------------------------------------------
    def _init_transfer_buffers(self):
        """Pre-allocate pinned CPU buffers used to shuttle data to GPU."""
        bs = self.max_num_seqs
        pin = dict(dtype=torch.float32, device="cpu", pin_memory=True)
        pin_i32 = dict(dtype=torch.int32, device="cpu", pin_memory=True)
        pin_i64 = dict(dtype=torch.int64, device="cpu", pin_memory=True)
        # Dict-based buffer storage (structurally different from per-attribute style)
        self._xfer = {
            "temps": torch.zeros(bs, **pin),
            "guidance": torch.zeros(bs, **pin),
            "top_k": torch.zeros(bs, **pin_i32),
            "top_p": torch.zeros(bs, **pin),
            "rep_pen": torch.zeros(bs, **pin),
            "token_ids": torch.zeros(bs, **pin_i64),
            "positions": torch.zeros(bs, **pin_i64),
            "slots": torch.zeros(bs, **pin_i32),
            "ctx_lens": torch.zeros(bs, **pin_i32),
        }

    # -- Warmup & KV storage ---------------------------------------------
    def _warmup_pipeline(self):
        """Run one maximum-length dummy prefill so later memory readings include it."""
        from acestep.customized_vllm import GenerationSlot, reset_context
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        n = min(self.max_num_batched_tokens // self.max_model_len, self.max_num_seqs)
        # Dummy slots own no cache blocks, so the prefill builds no slot mapping for them.
        dummy_slots = [GenerationSlot([0] * self.max_model_len, block_size=self.block_size) for _ in range(n)]
        self._execute_prefill(dummy_slots)
        reset_context()
        torch.cuda.empty_cache()

    def _provision_kv_storage(self):
        """Size and allocate the KV cache, then attach per-layer views to the model.

        Sets ``self._num_cache_blocks`` and ``self._kv_storage``. The
        ``MAX_CUDA_VRAM`` environment variable (in GiB) can lower the memory
        total used for sizing; unparsable values are ignored.
        """
        _t = debug_start("allocate_kv_cache", prefix="tensor.vllm")
        hf = self.hf_config
        free, total = torch.cuda.mem_get_info()
        current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
        import os
        sim = os.environ.get("MAX_CUDA_VRAM")
        if sim:
            try:
                cap = float(sim) * 1024**3
                if cap < total:
                    total = int(cap)
                    free = max(0, total - torch.cuda.memory_reserved())
            except (ValueError, TypeError):
                pass
        num_kv_heads = hf.num_key_value_heads
        head_dim = getattr(hf, "head_dim", hf.hidden_size // hf.num_attention_heads)
        # Bytes for one block across all layers; the factor 2 covers keys and values.
        block_bytes = 2 * hf.num_hidden_layers * self.block_size * num_kv_heads * head_dim * self.dtype.itemsize
        target = total * self.gpu_memory_utilization
        # Take the most conservative of three budgets, each leaving headroom.
        avail = min(free * 0.9, target - current, max(0, free - 1024**3) * 0.9)
        if avail <= 0:
            avail = free * 0.5
        self._num_cache_blocks = max(1, int(avail) // block_bytes)
        cap = self._num_cache_blocks * self.block_size
        gb = self._num_cache_blocks * block_bytes / 1024**3
        logger.info(f"[customized_vllm] KV cache: {self._num_cache_blocks} blocks, "
                    f"{cap} tokens, {gb:.2f} GB")
        # No dtype/device given: __init__ has set the defaults to self.dtype on CUDA.
        self._kv_storage = torch.empty(
            2, hf.num_hidden_layers, self._num_cache_blocks,
            self.block_size, num_kv_heads, head_dim,
        )
        # Hand each module exposing k_cache/v_cache its slice, in module order.
        layer_id = 0
        for m in self.model.modules():
            if hasattr(m, "k_cache") and hasattr(m, "v_cache"):
                m.k_cache = self._kv_storage[0, layer_id]
                m.v_cache = self._kv_storage[1, layer_id]
                layer_id += 1
        debug_end("allocate_kv_cache", _t, prefix="tensor.vllm")

    # -- Input preparation -----------------------------------------------
    def _build_cache_index(self, slots):
        """Return an int32 GPU block table, one row per slot, right-padded with -1."""
        max_len = max(len(s.cache_blocks) for s in slots)
        rows = [s.cache_blocks + [-1] * (max_len - len(s.cache_blocks)) for s in slots]
        return torch.tensor(rows, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)

    def _execute_prefill(self, slots):
        """Prepare prefill inputs and run model forward, returning logits."""
        from acestep.customized_vllm import _set_forward_state
        ids, pos, cu_q, cu_k = [], [], [0], [0]
        max_sq = max_sk = 0
        slot_map = []
        for s in slots:
            n = len(s)
            # All sequences are packed back to back; cu_q/cu_k record the boundaries.
            ids.extend(s.token_ids)
            pos.extend(range(n))
            cu_q.append(cu_q[-1] + n)
            cu_k.append(cu_k[-1] + n)
            max_sq = max(n, max_sq)
            max_sk = max(n, max_sk)
            for i in range(s.required_blocks):
                # Slots without cache blocks (warmup dummies) write nothing.
                if not s.cache_blocks:
                    continue
                start = s.cache_blocks[i] * self.block_size
                # Only the last block may be partially filled.
                end = start + (s.tail_block_fill if i == s.required_blocks - 1 else self.block_size)
                slot_map.extend(range(start, end))
        ids = torch.tensor(ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        pos = torch.tensor(pos, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        cu_q = torch.tensor(cu_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        cu_k = torch.tensor(cu_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        sm = torch.tensor(slot_map, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        _set_forward_state(True, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k,
                           max_seqlen_q=max_sq, max_seqlen_k=max_sk, slot_mapping=sm)
        return self._forward_pass(ids, pos, is_prefill=True)

    def _execute_autoregressive(self, slots):
        """Prepare single-token decode inputs and run model forward, returning logits."""
        from acestep.customized_vllm import _set_forward_state
        bs = len(slots)
        xfer = self._xfer
        for i, s in enumerate(slots):
            xfer["token_ids"][i] = s.last_token
            xfer["positions"][i] = len(s) - 1
            xfer["ctx_lens"][i] = len(s)
            # Flat cache index of the newest token: last block start plus offset within it.
            xfer["slots"][i] = s.cache_blocks[-1] * self.block_size + s.tail_block_fill - 1
        ids = xfer["token_ids"][:bs].cuda(non_blocking=True)
        pos = xfer["positions"][:bs].cuda(non_blocking=True)
        sm = xfer["slots"][:bs].cuda(non_blocking=True)
        cl = xfer["ctx_lens"][:bs].cuda(non_blocking=True)
        bt = self._build_cache_index(slots)
        _set_forward_state(False, slot_mapping=sm, context_lens=cl, block_tables=bt)
        return self._forward_pass(ids, pos, is_prefill=False)

    # -- Model forward ---------------------------------------------------
    @torch.inference_mode()
    def _forward_pass(self, input_ids, positions, is_prefill):
        """Run the model eagerly or by replaying a captured CUDA graph; return logits.

        Eager execution is used for prefill, when ``enforce_eager`` is set, for
        batches above 512 rows, and whenever the current forward state does not
        fit the captured graph buffers.
        """
        from acestep.customized_vllm import _get_forward_state
        if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
            return self.model.project_to_vocab(self.model(input_ids, positions))
        bs = input_ids.size(0)
        state = _get_forward_state()
        gio = self._graph_io
        max_cols = gio["block_tables"].size(1)
        if (state.block_tables.size(1) > max_cols or state.block_tables.size(0) != bs
                or state.slot_mapping.size(0) != bs or state.context_lens.size(0) != bs):
            return self.model.project_to_vocab(self.model(input_ids, positions))
        # Smallest captured batch size that can hold this batch.
        graph = self._graphs[next(x for x in self._compiled_sizes if x >= bs)]
        # Copy inputs into the static buffers; unused rows are reset to padding values.
        gio["input_ids"][:bs] = input_ids
        gio["positions"][:bs] = positions
        gio["slot_mapping"].fill_(-1)
        gio["slot_mapping"][:bs] = state.slot_mapping
        gio["context_lens"].zero_()
        gio["context_lens"][:bs] = state.context_lens
        gio["block_tables"][:bs].fill_(-1)
        gio["block_tables"][:bs, :state.block_tables.size(1)] = state.block_tables
        graph.replay()
        return self.model.project_to_vocab(gio["outputs"][:bs])

    # -- Sampling helpers ------------------------------------------------
    def _gather_sampling_config(self, slots, is_cfg):
        """Pack per-slot sampling parameters into GPU tensors.

        Always returns tensors (never None) so that `sample_tokens` receives
        a stable signature for torch.compile.

        For a CFG batch only the first half of ``slots`` (the conditional
        slots) is read. Returns
        ``(temperatures, cfg_scales, top_k, top_p, repetition_penalties)``.
        """
        targets = slots[:len(slots) // 2] if is_cfg else slots
        n = len(targets)
        xfer = self._xfer
        for i, s in enumerate(targets):
            xfer["temps"][i] = s.temperature
            xfer["guidance"][i] = s.cfg_scale
            # Falsy values (None or 0) map to the "filter disabled" sentinels.
            xfer["top_k"][i] = s.top_k if s.top_k else 0
            xfer["top_p"][i] = s.top_p if s.top_p else 1.0
            xfer["rep_pen"][i] = s.repetition_penalty if s.repetition_penalty else 1.0
        return (
            xfer["temps"][:n].cuda(non_blocking=True),
            xfer["guidance"][:n].cuda(non_blocking=True),
            xfer["top_k"][:n].cuda(non_blocking=True),  # always a tensor
            xfer["top_p"][:n].cuda(non_blocking=True),  # always a tensor
            xfer["rep_pen"][:n].cuda(non_blocking=True),
        )

    def _constrain_logits(self, logits, slots):
        """Apply per-slot logits processors.

        Each slot with a non-None logits_processor is processed individually
        using its own token history. In typical ACE-Step usage all batch
        slots share the same processor instance, but this handles the general
        case correctly.

        ``logits`` is modified in place and returned. A ``TypeError`` raised by
        a processor is logged with context and re-raised.
        """
        if not slots:
            return logits
        try:
            for i, slot in enumerate(slots):
                if slot.logits_processor is None:
                    continue
                ids_t = torch.tensor([slot.token_ids], device=logits.device)
                # The processor gets a copy of its row so it cannot alias the batch tensor.
                processed = slot.logits_processor(ids_t, logits[i:i+1].clone())
                logits[i] = processed[0]
        except TypeError:
            import traceback
            logger.error(f"TypeError in _constrain_logits "
                         f"(n_slots={len(slots)}, slot_idx={i}, "
                         f"processor_state={getattr(slot.logits_processor, 'state', '?')}):\n"
                         f"{traceback.format_exc()}")
            raise
        return logits

    def _penalize_repetitions(self, logits, slots, penalties):
        """Apply each slot's repetition penalty to tokens it has already generated.

        Negative logits are multiplied by the penalty and non-negative ones are
        divided by it. Rows with a penalty of exactly 1.0 or with no generated
        tokens are skipped. ``logits`` is modified in place and returned.
        """
        if penalties is None:
            return logits
        for i, slot in enumerate(slots):
            p = penalties[i].item()
            if p == 1.0:
                continue
            comp = torch.tensor(slot.generated_ids, device=logits.device)
            if len(comp) == 0:
                continue
            # Boolean mask over the vocabulary marking previously generated tokens.
            mask = torch.zeros(logits.shape[1], dtype=torch.bool, device=logits.device)
            mask[comp] = True
            penalized = torch.where(logits[i] < 0, logits[i] * p, logits[i] / p)
            logits[i] = torch.where(mask, penalized, logits[i])
        return logits

    # -- Main step -------------------------------------------------------
    def execute_step(self, slots, is_prefill):
        """Full forward + sampling step. Returns list of sampled token IDs.

        ``slots`` must be non-empty. The first slot decides whether the batch
        is treated as CFG, in which case the first half of ``slots`` is
        conditional and the second half unconditional, and one token is
        returned per pair.
        """
        from acestep.customized_vllm import reset_context
        try:
            is_cfg = slots[0].cfg_scale > 1.0 and slots[0].paired_slot is not None
            try:
                logits = (self._execute_prefill(slots) if is_prefill
                          else self._execute_autoregressive(slots))
            finally:
                # The forward state is only needed during the model call.
                reset_context()
            temps, cfg_s, topk, topp, rep_pen = self._gather_sampling_config(slots, is_cfg)
            if is_cfg:
                nc = len(slots) // 2
                cond, uncond = logits[:nc], logits[nc:]
                cond = self._penalize_repetitions(cond, slots[:nc], rep_pen)
                # Guidance: move from the unconditional logits towards the conditional ones.
                cfg_logits = uncond + cfg_s.unsqueeze(1) * (cond - uncond)
                cfg_logits = torch.nan_to_num(cfg_logits, nan=float('-inf'))
                cfg_logits = self._constrain_logits(cfg_logits, slots[:nc])
                tids = sample_tokens(cfg_logits, temps, topk, topp).tolist()
                for i, s in enumerate(slots[:nc]):
                    if s.logits_processor_update_state:
                        s.logits_processor_update_state(tids[i])
                return tids
            logits = self._penalize_repetitions(logits, slots, rep_pen)
            logits = self._constrain_logits(logits.clone(), slots)
            tids = sample_tokens(logits, temps, topk, topp).tolist()
            for i, s in enumerate(slots):
                if s.logits_processor_update_state:
                    s.logits_processor_update_state(tids[i])
            return tids
        except TypeError:
            import traceback
            logger.error(f"TypeError in execute_step "
                         f"(prefill={is_prefill}, n_slots={len(slots)}):\n"
                         f"{traceback.format_exc()}")
            raise

    # -- CUDA graph capture ----------------------------------------------
    @torch.inference_mode()
    def _compile_execution_graphs(self):
        """Capture one decode CUDA graph per supported batch size.

        Populates ``self._compiled_sizes``, ``self._graphs`` and the static
        input/output buffers in ``self._graph_io``.
        """
        from acestep.customized_vllm import _set_forward_state, reset_context
        _t = debug_start("capture_cudagraph", prefix="tensor.vllm")
        max_bs = min(self.max_num_seqs, 512)
        max_blocks = (self.max_model_len + self.block_size - 1) // self.block_size
        # Static buffers the captured graphs read from and write to.
        ids = torch.zeros(max_bs, dtype=torch.int64)
        pos = torch.zeros(max_bs, dtype=torch.int64)
        sm = torch.zeros(max_bs, dtype=torch.int32)
        cl = torch.zeros(max_bs, dtype=torch.int32)
        bt = torch.zeros(max_bs, max_blocks, dtype=torch.int32)
        out = torch.zeros(max_bs, self.hf_config.hidden_size)
        self._compiled_sizes = sorted(set([1, 2, 4, 8] + list(range(16, max_bs + 1, 16)) + [max_bs]))
        self._graphs = {}
        pool = None
        # Largest size first; its memory pool is then reused by the smaller graphs.
        for bs in reversed(self._compiled_sizes):
            g = torch.cuda.CUDAGraph()
            _set_forward_state(False, slot_mapping=sm[:bs], context_lens=cl[:bs], block_tables=bt[:bs])
            # One eager run before capture, then the same call recorded into the graph.
            out[:bs] = self.model(ids[:bs], pos[:bs])
            with torch.cuda.graph(g, pool):
                out[:bs] = self.model(ids[:bs], pos[:bs])
            if pool is None:
                pool = g.pool()
            self._graphs[bs] = g
            torch.cuda.synchronize()
            reset_context()
        self._graph_io = dict(input_ids=ids, positions=pos, slot_mapping=sm,
                              context_lens=cl, block_tables=bt, outputs=out)
        debug_end("capture_cudagraph", _t, prefix="tensor.vllm")

    def shutdown(self):
        """Drop captured graphs and wait for outstanding GPU work.

        Safe to call repeatedly: missing graph attributes are ignored, so an
        explicit shutdown followed by an exit hook does not raise.
        """
        if not self.enforce_eager:
            # pop() rather than del: the attributes may already be gone, or may
            # never have been created if construction failed part-way.
            self.__dict__.pop("_graphs", None)
            self.__dict__.pop("_graph_io", None)
        if torch.cuda.is_available():
            torch.cuda.synchronize()

View file

@ -1,416 +0,0 @@
"""Causal transformer architecture with paged KV cache attention.
Implements a Qwen3-compatible transformer with:
- Fused QKV and gate/up projections for efficient weight loading
- Paged KV cache with Flash Attention / SDPA backends
- RoPE position encoding with compiled forward passes
- Safetensors weight loading with shard-aware mapping
"""
import os
from functools import lru_cache
from glob import glob
import torch
from torch import nn
import torch.nn.functional as F
from safetensors import safe_open
from transformers import Qwen3Config
# Optional accelerator detection: each flag starts False and flips to True only
# when the corresponding import succeeds, so the module still imports without them.
_HAS_TRITON = False
_HAS_FLASH_ATTN = False
try:
    import triton
    import triton.language as tl
    _HAS_TRITON = True
except ImportError:
    # Triton missing: write_kv_cache uses the pure-PyTorch writer instead.
    pass
try:
    from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
    _HAS_FLASH_ATTN = True
except ImportError:
    # flash_attn missing: _HAS_FLASH_ATTN stays False.
    pass
# ---------------------------------------------------------------------------
# KV cache write operations
# ---------------------------------------------------------------------------
if _HAS_TRITON:
    # Triton kernel that copies one token's key and value rows into the cache.
    # It is launched with one program per token (see write_kv_cache).
    @triton.jit
    def _triton_kv_write(
        key_ptr, key_stride, value_ptr, value_stride,
        k_cache_ptr, v_cache_ptr, slot_mapping_ptr, D: tl.constexpr,
    ):
        # The program id selects which token this instance handles.
        idx = tl.program_id(0)
        slot = tl.load(slot_mapping_ptr + idx)
        # A slot of -1 marks a padding row: nothing to write.
        if slot == -1:
            return
        # D elements are copied per token; the cache is addressed as rows of D
        # elements, i.e. row `slot` starts at offset slot * D.
        offs = tl.arange(0, D)
        tl.store(k_cache_ptr + slot * D + offs, tl.load(key_ptr + idx * key_stride + offs))
        tl.store(v_cache_ptr + slot * D + offs, tl.load(value_ptr + idx * value_stride + offs))
def _torch_kv_write(key, value, k_cache, v_cache, slot_mapping):
N, num_kv_heads, head_dim = key.shape
D = num_kv_heads * head_dim
valid = slot_mapping != -1
slots = slot_mapping[valid]
k_cache.reshape(-1, D)[slots] = key.reshape(N, D)[valid]
v_cache.reshape(-1, D)[slots] = value.reshape(N, D)[valid]
def write_kv_cache(key, value, k_cache, v_cache, slot_mapping):
    """Persist key/value tensors into the paged KV cache.

    Uses the Triton kernel when Triton is importable *and* the tensors live
    on a CUDA device; otherwise falls back to the pure-PyTorch writer.

    Args:
        key: Tensor unpacked as ``(N, num_heads, head_dim)``.
        value: Tensor written alongside ``key`` (same row layout expected).
        k_cache: Destination cache for keys.
        v_cache: Destination cache for values.
        slot_mapping: Per-row destination slot; -1 skips the row.
    """
    # Triton kernels can only dereference device pointers. Having Triton
    # installed does not imply the tensors are on the GPU, so guard on the
    # tensor's device as well and let the torch path handle everything else.
    if _HAS_TRITON and key.is_cuda:
        num_rows, num_heads, head_dim = key.shape
        row_width = num_heads * head_dim
        _triton_kv_write[(num_rows,)](
            key, key.stride(0), value, value.stride(0),
            k_cache, v_cache, slot_mapping, row_width,
        )
    else:
        _torch_kv_write(key, value, k_cache, v_cache, slot_mapping)
# ---------------------------------------------------------------------------
# SDPA fallback implementations
# ---------------------------------------------------------------------------
def _sdpa_packed_prefill(q, k, v, cu_q, cu_k, scale, n_heads, n_kv):
results = []
gqa = n_heads != n_kv
for i in range(cu_q.shape[0] - 1):
qs, qe = cu_q[i].item(), cu_q[i + 1].item()
ks, ke = cu_k[i].item(), cu_k[i + 1].item()
qi = q[qs:qe].unsqueeze(0).transpose(1, 2)
ki = k[ks:ke].unsqueeze(0).transpose(1, 2)
vi = v[ks:ke].unsqueeze(0).transpose(1, 2)
oi = F.scaled_dot_product_attention(qi, ki, vi, scale=scale, is_causal=True, enable_gqa=gqa)
results.append(oi.transpose(1, 2).squeeze(0))
return torch.cat(results, dim=0)
def _sdpa_cached_decode(q, k_cache, v_cache, ctx_lens, block_tbl, scale, n_heads, n_kv):
blk_sz = k_cache.shape[1]
results = []
gqa = n_heads != n_kv
for i in range(q.shape[0]):
cl = ctx_lens[i].item()
nb = (cl + blk_sz - 1) // blk_sz
idx = block_tbl[i, :nb]
ki = k_cache[idx].reshape(-1, n_kv, k_cache.shape[-1])[:cl]
vi = v_cache[idx].reshape(-1, n_kv, v_cache.shape[-1])[:cl]
qi = q[i].unsqueeze(0).transpose(1, 2)
ki = ki.unsqueeze(0).transpose(1, 2)
vi = vi.unsqueeze(0).transpose(1, 2)
oi = F.scaled_dot_product_attention(qi, ki, vi, scale=scale, is_causal=False, enable_gqa=gqa)
results.append(oi.transpose(1, 2).squeeze(0))
return torch.stack(results, dim=0)
# ---------------------------------------------------------------------------
# Attention with paged KV cache
# ---------------------------------------------------------------------------
class PagedAttention(nn.Module):
    """Attention over a paged KV cache, dispatching to Flash Attention or SDPA."""

    def __init__(self, num_heads, head_dim, num_kv_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_kv_heads = num_kv_heads
        self._scale = head_dim ** -0.5
        # Both caches start out as the same empty placeholder tensor.
        placeholder = torch.tensor([])
        self.k_cache = placeholder
        self.v_cache = placeholder

    def forward(self, q, k, v):
        """Write k/v into the cache (when both caches are non-empty) and attend."""
        from acestep.customized_vllm import _get_forward_state

        state = _get_forward_state()
        caches_allocated = self.k_cache.numel() and self.v_cache.numel()
        if caches_allocated:
            write_kv_cache(k, v, self.k_cache, self.v_cache, state.slot_mapping)
        backend = self._flash_path if _HAS_FLASH_ATTN else self._sdpa_path
        return backend(q, k, v, state)

    def _flash_path(self, q, k, v, state):
        """Flash Attention backend: varlen kernel for prefill, kv-cache kernel otherwise."""
        if not state.is_prefill:
            return flash_attn_with_kvcache(
                q.unsqueeze(1), self.k_cache, self.v_cache,
                cache_seqlens=state.context_lens, block_table=state.block_tables,
                softmax_scale=self._scale, causal=True,
            )
        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=state.cu_seqlens_q, cu_seqlens_k=state.cu_seqlens_k,
            max_seqlen_q=state.max_seqlen_q, max_seqlen_k=state.max_seqlen_k,
            softmax_scale=self._scale, causal=True,
        )

    def _sdpa_path(self, q, k, v, state):
        """SDPA backend: packed prefill helper or cached decode helper."""
        if not state.is_prefill:
            return _sdpa_cached_decode(
                q.unsqueeze(1), self.k_cache, self.v_cache,
                state.context_lens, state.block_tables,
                self._scale, self.num_heads, self.num_kv_heads,
            )
        return _sdpa_packed_prefill(
            q, k, v, state.cu_seqlens_q, state.cu_seqlens_k,
            self._scale, self.num_heads, self.num_kv_heads,
        )
# ---------------------------------------------------------------------------
# Core neural-network layers
# ---------------------------------------------------------------------------
class NormLayer(nn.Module):
    """Root-mean-square layer normalisation with optional fused residual add.

    ``forward(x)`` returns the normalised tensor. ``forward(x, residual)``
    returns ``(normalised(x + residual), x + residual)``. Statistics are
    computed in float32 and the result is cast back to the input dtype.
    The caller's ``x`` is never modified, and the returned residual never
    shares storage with the normalised output.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    @torch.compile
    def _normalize(self, x):
        dt = x.dtype
        x = x.float()
        # Out-of-place multiply: `x.float()` returns the *same* tensor when the
        # input is already float32, so an in-place `mul_` here would overwrite
        # the caller's tensor (which the first transformer block keeps as its
        # residual).
        x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        # `x` is now a private tensor, so the in-place weight scaling is safe.
        return x.to(dt).mul_(self.weight)

    @torch.compile
    def _fused_add_normalize(self, x, residual):
        dt = x.dtype
        # Out-of-place add keeps the caller's `x` intact for float32 inputs.
        x = x.float() + residual.float()
        # May alias `x` when dt is float32; that is fine because `x` is
        # rebound to a fresh tensor below instead of being mutated.
        residual = x.to(dt)
        x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x.to(dt).mul_(self.weight), residual

    def forward(self, x, residual=None):
        if residual is None:
            return self._normalize(x)
        return self._fused_add_normalize(x, residual)
def _rotate(x, cos, sin):
x1, x2 = torch.chunk(x.float(), 2, dim=-1)
return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1).to(x.dtype)
class PositionEncoding(nn.Module):
    """Rotary position embedding (RoPE) backed by a precomputed cos/sin table."""

    def __init__(self, head_size: int, max_position: int, base: float):
        super().__init__()
        exponents = torch.arange(0, head_size, 2, dtype=torch.float) / head_size
        inv_freq = 1.0 / (base ** exponents)
        position_ids = torch.arange(max_position, dtype=torch.float)
        angles = torch.einsum("i,j->ij", position_ids, inv_freq)
        # Table layout: [cos | sin] on the last dim, with a singleton axis
        # inserted at dim 1. Registered as non-persistent, so it is not saved
        # in the state dict.
        table = torch.cat((angles.cos(), angles.sin()), dim=-1).unsqueeze_(1)
        self.register_buffer("cos_sin_cache", table, persistent=False)

    @torch.compile
    def forward(self, positions, query, key):
        """Look up cos/sin for ``positions`` and rotate ``query`` and ``key``."""
        cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)
        rotated_query = _rotate(query, cos, sin)
        rotated_key = _rotate(key, cos, sin)
        return rotated_query, rotated_key
@lru_cache(maxsize=1)
def get_position_encoder(head_size: int, max_position: int, base: float):
    """Return a ``PositionEncoding`` for the given settings.

    The most recent result is memoised, so repeated calls with the same
    arguments hand back the same module instance.
    """
    encoder = PositionEncoding(head_size, max_position, base)
    return encoder
class GatedActivation(nn.Module):
    """SiLU gating: split the last dim in two and return ``SiLU(first) * second``."""

    @torch.compile
    def forward(self, x):
        gate, value = torch.chunk(x, 2, dim=-1)
        return F.silu(gate) * value
# ---------------------------------------------------------------------------
# Fused projections with shard-aware weight loaders
# ---------------------------------------------------------------------------
def _FusedQKVProjection(hidden_size, num_heads, num_kv_heads, head_dim, bias):
q_size = num_heads * head_dim
kv_size = num_kv_heads * head_dim
proj = nn.Linear(hidden_size, q_size + 2 * kv_size, bias=bias)
def _load_shard(param, weight, shard_id):
offsets = {"q": 0, "k": q_size, "v": q_size + kv_size}
sizes = {"q": q_size, "k": kv_size, "v": kv_size}
param.data.narrow(0, offsets[shard_id], sizes[shard_id]).copy_(weight)
proj.weight.weight_loader = _load_shard
if bias:
proj.bias.weight_loader = _load_shard
return proj
def _FusedGateUpProjection(hidden_size, intermediate_size):
proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
def _load_shard(param, weight, shard_id):
param.data.narrow(0, shard_id * intermediate_size, intermediate_size).copy_(weight)
proj.weight.weight_loader = _load_shard
return proj
# ---------------------------------------------------------------------------
# Transformer blocks
# ---------------------------------------------------------------------------
class _SelfAttention(nn.Module):
    """Self-attention block: fused QKV projection, RoPE and paged attention.

    Per-head Q/K normalisation layers are created (and applied) only when the
    config reports no attention bias.
    """

    def __init__(self, config: Qwen3Config):
        super().__init__()
        heads = config.num_attention_heads
        kv_heads = config.num_key_value_heads
        head_dim = getattr(config, "head_dim", config.hidden_size // heads)
        use_bias = getattr(config, "attention_bias", True)

        self.num_heads = heads
        self.num_kv_heads = kv_heads
        self.head_dim = head_dim
        self.q_size = heads * head_dim
        self.kv_size = kv_heads * head_dim

        self.qkv_proj = _FusedQKVProjection(
            config.hidden_size, heads, kv_heads, head_dim, use_bias,
        )
        self.o_proj = nn.Linear(heads * head_dim, config.hidden_size, bias=False)
        rope_base = getattr(config, "rope_theta", 1000000)
        self.rope = get_position_encoder(head_dim, config.max_position_embeddings, rope_base)
        self.attn = PagedAttention(heads, head_dim, kv_heads)
        if not use_bias:
            self.q_norm = NormLayer(head_dim, eps=config.rms_norm_eps)
            self.k_norm = NormLayer(head_dim, eps=config.rms_norm_eps)
        self._has_bias = use_bias

    def forward(self, positions, hidden_states):
        """Project, split into heads, (optionally) normalise, rotate, attend, project out."""
        fused = self.qkv_proj(hidden_states)
        q, k, v = fused.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        if not self._has_bias:
            q = self.q_norm(q)
            k = self.k_norm(k)
        q, k = self.rope(positions, q, k)
        attended = self.attn(q, k, v)
        return self.o_proj(attended.flatten(1, -1))
class _FeedForward(nn.Module):
    """Feed-forward block: fused gate/up projection, gated activation, down projection."""

    def __init__(self, config: Qwen3Config):
        super().__init__()
        hidden = config.hidden_size
        inner = config.intermediate_size
        self.gate_up_proj = _FusedGateUpProjection(hidden, inner)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act = GatedActivation()

    def forward(self, x):
        fused = self.gate_up_proj(x)
        gated = self.act(fused)
        return self.down_proj(gated)
class _TransformerBlock(nn.Module):
    """One decoder layer: pre-norm self-attention followed by pre-norm MLP.

    ``forward`` returns ``(mlp_output, residual)``; the residual is threaded
    through the norm layers rather than added here.
    """

    def __init__(self, config: Qwen3Config):
        super().__init__()
        self.self_attn = _SelfAttention(config)
        self.mlp = _FeedForward(config)
        self.input_layernorm = NormLayer(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = NormLayer(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, positions, hidden_states, residual):
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            # No residual supplied: the un-normalised input becomes the
            # residual. The norm call is evaluated before the rebinding.
            normed = self.input_layernorm(hidden_states)
            residual = hidden_states
            hidden_states = normed
        attn_out = self.self_attn(positions, hidden_states)
        hidden_states, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(hidden_states), residual
class CausalTransformer(nn.Module):
    """Causal language model for inference (Qwen3-compatible)."""

    # checkpoint sub-key -> (fused parameter sub-key, shard id) used by the
    # weight loader to route per-projection tensors into fused layers.
    WEIGHT_SHARD_MAP = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(self, config: Qwen3Config):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        blocks = [_TransformerBlock(config) for _ in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(blocks)
        self.norm = NormLayer(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            # Share one parameter between the input embedding and output head.
            self.lm_head.weight = self.embed_tokens.weight

    def forward(self, input_ids, positions):
        """Embed the ids, run every block, and apply the final norm."""
        hidden = self.embed_tokens(input_ids)
        residual = None
        for block in self.layers:
            hidden, residual = block(positions, hidden, residual)
        hidden, _ = self.norm(hidden, residual)
        return hidden

    def project_to_vocab(self, hidden_states):
        """Extract last-token hidden states (during prefill) and project to vocabulary logits."""
        from acestep.customized_vllm import _get_forward_state

        state = _get_forward_state()
        if state.is_prefill:
            # The last row of each packed sequence sits one before the next
            # sequence's cumulative start offset.
            last_rows = state.cu_seqlens_q[1:] - 1
            hidden_states = hidden_states[last_rows].contiguous()
        return self.lm_head(hidden_states)
# ---------------------------------------------------------------------------
# Weight loading
# ---------------------------------------------------------------------------
def _copy_weight(param, loaded_weight):
param.data.copy_(loaded_weight)
def _resolve_parameter(model, name):
for candidate in (name, f"model.{name}", name.removeprefix("model.")):
try:
return model.get_parameter(candidate)
except AttributeError:
continue
return None
def load_weights(model: nn.Module, path: str):
    """Load safetensors weights into model with shard-aware mapping for fused projections.

    Every ``*.safetensors`` file directly under ``path`` is read on CPU. A
    tensor whose name contains a key of ``model.WEIGHT_SHARD_MAP`` is routed to
    the fused parameter's ``weight_loader`` with the mapped shard id; any other
    tensor is loaded via the parameter's ``weight_loader`` when it has one, or
    by plain copy. Tensors with no matching parameter are skipped.

    Raises:
        FileNotFoundError: if ``path`` contains no ``.safetensors`` file.
    """
    shard_map = getattr(model, "WEIGHT_SHARD_MAP", {})
    checkpoint_files = glob(os.path.join(path, "*.safetensors"))
    if not checkpoint_files:
        raise FileNotFoundError(f"No .safetensors files found in {path}")

    for checkpoint in checkpoint_files:
        with safe_open(checkpoint, "pt", "cpu") as reader:
            for tensor_name in reader.keys():
                routed = False
                for hf_key, (fused_key, shard_id) in shard_map.items():
                    if hf_key not in tensor_name:
                        continue
                    target = _resolve_parameter(model, tensor_name.replace(hf_key, fused_key))
                    if target is None:
                        # Keep trying the remaining shard keys.
                        continue
                    target.weight_loader(target, reader.get_tensor(tensor_name), shard_id)
                    routed = True
                    break
                if routed:
                    continue
                # No fused destination handled this tensor: load it directly.
                target = _resolve_parameter(model, tensor_name)
                if target is None:
                    continue
                loader = getattr(target, "weight_loader", _copy_weight)
                loader(target, reader.get_tensor(tensor_name))

View file

@ -877,7 +877,7 @@ def get_lm_gpu_memory_ratio(
except (ValueError, TypeError):
pass
# The ratio is relative to total GPU memory (customized_vllm convention),
# The ratio is relative to total GPU memory (nano-vllm convention),
# but we compute it so that the LM only claims what's actually free
# minus a safety margin for DiT inference activations.
# Reserve at least 1.5 GB for DiT inference activations
@ -888,7 +888,7 @@ def get_lm_gpu_memory_ratio(
usable_for_lm = min(usable_for_lm, total_target_gb)
# Convert to ratio of total GPU memory
# customized_vllm uses: target_total_usage = total * gpu_memory_utilization
# nano-vllm uses: target_total_usage = total * gpu_memory_utilization
# We want: (total * ratio) = current_usage + usable_for_lm
current_usage_gb = actual_total_gb - free_gb
desired_total_usage = current_usage_gb + usable_for_lm

View file

@ -5,7 +5,7 @@ import sys
def _has_working_triton_installation() -> bool:
"""Return whether the Triton modules required by customized_vllm import cleanly."""
"""Return whether the Triton modules required by nano-vllm import cleanly."""
try:
importlib.import_module("triton")
importlib.import_module("triton.language")

View file

@ -72,7 +72,7 @@ class LlmInitializeBackendCompatTests(unittest.TestCase):
_mock_empty_cache: MagicMock,
_mock_synchronize: MagicMock,
) -> None:
"""Initialization should avoid customized_vllm when the Windows Triton preflight fails."""
"""Initialization should avoid nano-vllm when the Windows Triton preflight fails."""
handler = LLMHandler()
mock_tokenizer.return_value = MagicMock()
mock_gpu_config.return_value = SimpleNamespace(max_duration_with_lm=600, tier="tier6")

View file

@ -37,7 +37,7 @@ def _warn_if_prerelease_python():
if getattr(v, "releaselevel", "final") != "final" and sys.platform.startswith("linux"):
warnings.warn(
f"Detected pre-release Python {sys.version.split()[0]} ({getattr(v, 'releaselevel', '')}). "
"This is known to cause segmentation faults with the vLLM engine on Linux. "
"This is known to cause segmentation faults with vLLM/nano-vllm on Linux. "
"Please install a stable Python release (e.g. 3.11.12+), or use --backend pt as a workaround.",
RuntimeWarning,
stacklevel=2,
@ -591,7 +591,7 @@ class LLMHandler:
# Disable CUDA/HIP graph capture on ROCm (unverified on RDNA3 Windows),
# on Jetson (SDPA paged-cache decode calls .item() during capture),
# and when flash_attn is not installed (same .item() incompatibility on all CUDA hardware).
# When flash_attn is unavailable, customized_vllm falls back to _sdpa_cached_decode
# When flash_attn is unavailable, nano-vllm falls back to _sdpa_decode_with_paged_cache
# which contains a Python loop with .item() calls. These force CPU-GPU
# synchronisation that is forbidden inside torch.cuda.CUDAGraph capture,
# corrupting the CUDA context and causing downstream errors such as:
@ -603,7 +603,7 @@ class LLMHandler:
dev_name = torch.cuda.get_device_name(0).lower()
is_jetson = any(k in dev_name for k in ("orin", "xavier", "tegra"))
if is_jetson:
logger.info(f"Jetson GPU detected ({dev_name}): disabling CUDA graph capture for customized_vllm")
logger.info(f"Jetson GPU detected ({dev_name}): disabling CUDA graph capture for nano-vllm")
except Exception:
pass
_has_flash_attn = False
@ -614,7 +614,7 @@ class LLMHandler:
pass
if not _has_flash_attn:
logger.info(
"flash_attn not installed: disabling CUDA graph capture for customized_vllm "
"flash_attn not installed: disabling CUDA graph capture for nano-vllm "
"(SDPA fallback uses .item() calls in paged-cache decode that are "
"incompatible with CUDA graph capture)"
)
@ -626,7 +626,7 @@ class LLMHandler:
pass
if not _has_triton:
logger.info(
"Triton not available: disabling CUDA graph capture for customized_vllm "
"Triton not available: disabling CUDA graph capture for nano-vllm "
"(CUDA graphs require torch.compile which depends on Triton)"
)
enforce_eager_for_vllm = bool(is_rocm or is_jetson or not _has_flash_attn or not _has_triton)
@ -751,11 +751,11 @@ class LLMHandler:
logger.error("CUDA/ROCm is not available. Please check your GPU setup.")
return "❌ CUDA/ROCm is not available. Please check your GPU setup."
try:
from acestep.customized_vllm import LLM, SamplingParams
except ImportError as exc:
from nanovllm import LLM, SamplingParams
except ImportError:
self.llm_initialized = False
logger.error(f"Failed to import customized_vllm engine: {exc}")
return f"❌ Failed to import customized_vllm engine: {exc}"
logger.error("nano-vllm is not installed. Please install it using 'cd acestep/third_parts/nano-vllm && pip install .'")
return "❌ nano-vllm is not installed. Please install it using 'cd acestep/third_parts/nano-vllm && pip install .'"
try:
current_device = torch.cuda.current_device()
@ -857,7 +857,7 @@ class LLMHandler:
Accepts either a single formatted prompt (str) or a list of formatted prompts (List[str]).
Returns a single string for single mode, or a list of strings for batch mode.
"""
from acestep.customized_vllm import SamplingParams
from nanovllm import SamplingParams
# Determine if batch mode
formatted_prompt_list, is_batch = self._normalize_batch_input(formatted_prompts)
@ -1490,11 +1490,8 @@ class LLMHandler:
seeds=seeds,
)
except Exception as e:
error_detail = traceback.format_exc()
logger.error(
f"Error in batch codes generation: {type(e).__name__}: {e}\n{error_detail}"
)
error_msg = f"Error in batch codes generation: {type(e).__name__}: {e}"
error_msg = f"Error in batch codes generation: {str(e)}"
logger.error(error_msg)
return {
"metadata": [],
"audio_codes": [],
@ -2417,11 +2414,11 @@ class LLMHandler:
import traceback
error_detail = traceback.format_exc()
logger.error(f"Error in generate_from_formatted_prompt: {type(e).__name__}: {e}\n{error_detail}")
# Reset vllm engine state on error to prevent stale context from causing
# Reset nano-vllm state on error to prevent stale context from causing
# subsequent CUDA illegal memory access errors
if self.llm_backend == "vllm":
try:
from acestep.customized_vllm import reset_context
from nanovllm.utils.context import reset_context
reset_context()
except ImportError:
pass
@ -4016,7 +4013,7 @@ class LLMHandler:
yield
return
# If using vllm engine or MLX, do not offload (managed differently)
# If using nanovllm or MLX, do not offload (managed differently)
if self.llm_backend in ("vllm", "mlx"):
yield
return

View file

@ -80,8 +80,6 @@ class TestCotCfgScaleFixed(unittest.TestCase):
handler = LLMHandler()
handler.llm_initialized = True
handler.llm_backend = "pt"
handler.llm = MagicMock()
handler.llm_tokenizer = MagicMock()
captured_cfg = {}
@ -127,8 +125,6 @@ class TestCotCfgScaleFixed(unittest.TestCase):
handler = LLMHandler()
handler.llm_initialized = True
handler.llm_backend = "pt"
handler.llm = MagicMock()
handler.llm_tokenizer = MagicMock()
# Capture cfg passed to generate_from_formatted_prompt
captured_cfgs = []
@ -138,24 +134,26 @@ class TestCotCfgScaleFixed(unittest.TestCase):
# Return a minimal CoT response so Phase 1 succeeds
return "<think>metadata</think>", "ok"
with patch.object(handler, "generate_from_formatted_prompt", side_effect=capturing_gen), \
patch.object(handler, "build_formatted_prompt", return_value="P"), \
patch.object(handler, "parse_lm_output", return_value=({}, "")), \
patch.object(handler, "_format_metadata_as_cot", return_value=""), \
patch.object(handler, "build_formatted_prompt_with_cot", return_value="P2"):
# Invoke with cfg_scale=2.0 (typical UI default)
handler.generate_with_stop_condition(
caption="test caption",
lyrics="test lyrics",
cfg_scale=2.0,
temperature=0.6,
negative_prompt="",
top_k=None,
top_p=None,
repetition_penalty=1.0,
infer_type="dit", # Phase 1 only avoids Phase 2 setup
progress=lambda *a, **kw: None,
)
with patch.object(handler, "generate_from_formatted_prompt", side_effect=capturing_gen):
with patch.object(handler, "build_formatted_prompt", return_value="P"):
with patch.object(handler, "_parse_metadata_from_cot", return_value={}):
with patch.object(handler, "_format_metadata_as_cot", return_value=""):
with patch.object(
handler, "build_formatted_prompt_with_cot", return_value="P2"
):
# Invoke with cfg_scale=2.0 (typical UI default)
handler.generate_with_stop_condition(
caption="test caption",
lyrics="test lyrics",
cfg_scale=2.0,
temperature=0.6,
negative_prompt="",
top_k=None,
top_p=None,
repetition_penalty=1.0,
infer_type="dit", # Phase 1 only avoids Phase 2 setup
progress=lambda *a, **kw: None,
)
# At least one generate_from_formatted_prompt call must exist
self.assertTrue(len(captured_cfgs) >= 1, "generate_from_formatted_prompt was not called")

View file

@ -4,17 +4,13 @@ Regression test for the bug where CUDA graph capture would fail when
``flash_attn`` is not installed because the SDPA paged-cache decode path
calls ``.item()`` inside the capture region (a forbidden CPU-GPU sync).
When flash_attn is absent, customized_vllm must run in ``enforce_eager=True``
When flash_attn is absent, nano-vllm must run in ``enforce_eager=True``
(eager mode, no CUDA graph capture) to avoid corrupting the CUDA context.
"""
import sys
import unittest
from types import ModuleType
from unittest.mock import MagicMock, patch
_SENTINEL = object()
try:
from acestep.llm_inference import LLMHandler
_IMPORT_ERROR = None
@ -40,16 +36,13 @@ def _mock_gpu_config():
class TestEnforceEagerWhenFlashAttnMissing(unittest.TestCase):
"""Verify that enforce_eager=True is set when flash_attn is not installed."""
def _run_initialize_with_mocks(self, flash_attn_available: bool,
device_name: str = "NVIDIA GeForce RTX 4090",
triton_available: bool = True):
def _run_initialize_with_mocks(self, flash_attn_available: bool, device_name: str = "NVIDIA GeForce RTX 4090"):
"""Call handler.initialize() with all heavy operations mocked.
Args:
flash_attn_available: Whether flash_attn should appear detectable
via ``importlib.util.find_spec``.
device_name: Simulated CUDA device name.
triton_available: Whether ``import triton`` should succeed.
Returns:
The ``enforce_eager`` value that was passed to ``_initialize_5hz_lm_vllm``.
@ -61,45 +54,28 @@ class TestEnforceEagerWhenFlashAttnMissing(unittest.TestCase):
captured = {}
def fake_init_vllm(model_path: str, enforce_eager: bool = False, **_kw) -> str:
def fake_init_vllm(model_path: str, enforce_eager: bool = False) -> str:
captured["enforce_eager"] = enforce_eager
return "✅ ok"
# Inject a fake triton module so ``import triton`` inside initialize()
# succeeds (or fails) as requested.
fake_triton = ModuleType("triton") if triton_available else None
saved_triton = sys.modules.get("triton", _SENTINEL)
try:
if triton_available:
sys.modules["triton"] = fake_triton
else:
sys.modules.pop("triton", None)
with patch("importlib.util.find_spec", return_value=find_spec_return), \
patch("os.path.exists", return_value=True), \
patch("acestep.llm_inference.AutoTokenizer") as mock_tok, \
patch("acestep.llm_inference.get_global_gpu_config", return_value=_mock_gpu_config()), \
patch("acestep.llm_inference.MetadataConstrainedLogitsProcessor"), \
patch("torch.cuda.is_available", return_value=True), \
patch("torch.cuda.empty_cache"), \
patch("torch.cuda.synchronize"), \
patch("torch.cuda.get_device_name", return_value=device_name), \
patch.object(handler, "_initialize_5hz_lm_vllm", side_effect=fake_init_vllm):
with patch("importlib.util.find_spec", return_value=find_spec_return), \
patch("os.path.exists", return_value=True), \
patch("acestep.llm_inference.AutoTokenizer") as mock_tok, \
patch("acestep.llm_inference.get_global_gpu_config", return_value=_mock_gpu_config()), \
patch("acestep.llm_inference.get_gpu_memory_gb", return_value=16.0), \
patch("acestep.llm_inference.MetadataConstrainedLogitsProcessor"), \
patch("torch.cuda.is_available", return_value=True), \
patch("torch.cuda.empty_cache"), \
patch("torch.cuda.synchronize"), \
patch("torch.cuda.get_device_name", return_value=device_name), \
patch("torch.cuda.mem_get_info", return_value=(16 * 1024**3, 16 * 1024**3)), \
patch.object(handler, "_initialize_5hz_lm_vllm", side_effect=fake_init_vllm):
mock_tok.from_pretrained.return_value = MagicMock()
handler.initialize(
checkpoint_dir="/tmp/fake_ckpt",
lm_model_path="model",
backend="vllm",
device="cuda",
)
finally:
if saved_triton is _SENTINEL:
sys.modules.pop("triton", None)
else:
sys.modules["triton"] = saved_triton
mock_tok.from_pretrained.return_value = MagicMock()
handler.initialize(
checkpoint_dir="/tmp/fake_ckpt",
lm_model_path="model",
backend="vllm",
device="cuda",
)
return captured.get("enforce_eager")

View file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 Xingkai Yu
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -0,0 +1,66 @@
<p align="center">
<img width="300" src="assets/logo.png">
</p>
<p align="center">
<a href="https://trendshift.io/repositories/15323" target="_blank"><img src="https://trendshift.io/api/badge/repositories/15323" alt="GeeeekExplorer%2Fnano-vllm | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</p>
# Nano-vLLM
A lightweight vLLM implementation built from scratch.
## Key Features
* 🚀 **Fast offline inference** - Comparable inference speeds to vLLM
* 📖 **Readable codebase** - Clean implementation in ~ 1,200 lines of Python code
* ⚡ **Optimization Suite** - Prefix caching, Tensor Parallelism, Torch compilation, CUDA graph, etc.
## Installation
```bash
pip install git+https://github.com/GeeeekExplorer/nano-vllm.git
```
## Model Download
To download the model weights manually, use the following command:
```bash
huggingface-cli download --resume-download Qwen/Qwen3-0.6B \
--local-dir ~/huggingface/Qwen3-0.6B/ \
--local-dir-use-symlinks False
```
## Quick Start
See `example.py` for usage. The API mirrors vLLM's interface with minor differences in the `LLM.generate` method:
```python
from nanovllm import LLM, SamplingParams
llm = LLM("/YOUR/MODEL/PATH", enforce_eager=True, tensor_parallel_size=1)
sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
prompts = ["Hello, Nano-vLLM."]
outputs = llm.generate(prompts, sampling_params)
outputs[0]["text"]
```
## Benchmark
See `bench.py` for benchmark.
**Test Configuration:**
- Hardware: RTX 4070 Laptop (8GB)
- Model: Qwen3-0.6B
- Total Requests: 256 sequences
- Input Length: Randomly sampled between 100–1024 tokens
- Output Length: Randomly sampled between 100–1024 tokens
**Performance Results:**
| Inference Engine | Output Tokens | Time (s) | Throughput (tokens/s) |
|----------------|-------------|----------|-----------------------|
| vLLM | 133,966 | 98.37 | 1361.84 |
| Nano-vLLM | 133,966 | 93.41 | 1434.13 |
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=GeeeekExplorer/nano-vllm&type=Date)](https://www.star-history.com/#GeeeekExplorer/nano-vllm&Date)

Binary file not shown.

After

Width:  |  Height:  |  Size: 387 KiB

View file

@ -0,0 +1,32 @@
import os
import time
from random import randint, seed
from nanovllm import LLM, SamplingParams
# from vllm import LLM, SamplingParams
def main():
    """Time one batched generation over random prompts and print the throughput."""
    seed(0)
    num_seqs = 256
    max_input_len = 1024
    max_output_len = 1024

    model_dir = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
    llm = LLM(model_dir, enforce_eager=False, max_model_len=4096)

    # Random token ids per prompt; the prompt length is drawn before its tokens.
    prompt_token_ids = [
        [randint(0, 10000) for _ in range(randint(100, max_input_len))]
        for _ in range(num_seqs)
    ]
    sampling_params = [
        SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_output_len))
        for _ in range(num_seqs)
    ]
    # For vllm, wrap each prompt instead:
    # prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]

    # Untimed warm-up call before the measured run.
    llm.generate(["Benchmark: "], SamplingParams())

    started = time.time()
    llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
    t = time.time() - started

    total_tokens = sum(params.max_tokens for params in sampling_params)
    throughput = total_tokens / t
    print(f"Total: {total_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
if __name__ == "__main__":
main()

View file

@ -0,0 +1,33 @@
import os
from nanovllm import LLM, SamplingParams
from transformers import AutoTokenizer
def main():
    """Generate completions for two chat-formatted prompts and print them."""
    model_dir = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    llm = LLM(model_dir, enforce_eager=True, tensor_parallel_size=1)

    sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
    questions = [
        "introduce yourself",
        "list all prime numbers within 100",
    ]
    # Wrap each question in the model's chat template as a single user turn.
    prompts = []
    for question in questions:
        chat = [{"role": "user", "content": question}]
        prompts.append(
            tokenizer.apply_chat_template(
                chat,
                tokenize=False,
                add_generation_prompt=True,
            )
        )
    outputs = llm.generate(prompts, sampling_params)

    for prompt, output in zip(prompts, outputs):
        print("\n")
        print(f"Prompt: {prompt!r}")
        print(f"Completion: {output['text']!r}")
if __name__ == "__main__":
main()

View file

@ -0,0 +1,2 @@
from nanovllm.llm import LLM
from nanovllm.sampling_params import SamplingParams

View file

@ -0,0 +1,26 @@
import os
from dataclasses import dataclass
from transformers import AutoConfig
@dataclass
class Config:
    """Engine configuration; validated and completed in ``__post_init__``."""

    # Path to a local model directory (must exist).
    model: str
    max_num_batched_tokens: int = 16384
    max_num_seqs: int = 512
    # Clamped to the model's max_position_embeddings after loading hf_config.
    max_model_len: int = 4096
    gpu_memory_utilization: float = 0.9
    tensor_parallel_size: int = 1
    enforce_eager: bool = False
    # Filled in from ``model`` by __post_init__.
    hf_config: AutoConfig | None = None
    eos: int = -1
    # Must be a multiple of 256.
    kvcache_block_size: int = 256
    num_kvcache_blocks: int = -1

    def __post_init__(self):
        model_dir = self.model
        assert os.path.isdir(model_dir)
        assert self.kvcache_block_size % 256 == 0
        assert 1 <= self.tensor_parallel_size <= 8
        hf_config = AutoConfig.from_pretrained(model_dir)
        self.hf_config = hf_config
        self.max_model_len = min(self.max_model_len, hf_config.max_position_embeddings)
        assert self.max_num_batched_tokens >= self.max_model_len

View file

@ -0,0 +1,99 @@
"""
Distributed utilities for nano-vllm.
This module provides wrapper functions for torch.distributed that gracefully
handle single-GPU mode (world_size == 1) without requiring distributed initialization.
"""
import torch.distributed as dist
# Global flag to track if distributed is actually initialized
_distributed_initialized = False
def initialize_distributed(backend: str, init_method: str, world_size: int, rank: int) -> bool:
    """
    Set up the torch.distributed process group unless running with one process.

    Args:
        backend: Distributed backend (e.g., "nccl" or "gloo")
        init_method: Initialization method (e.g., "tcp://127.0.0.1:2333")
        world_size: Total number of processes
        rank: Rank of current process

    Returns:
        True if a process group was initialized, False for world_size == 1
    """
    global _distributed_initialized

    multi_process = world_size != 1
    if multi_process:
        # Only a real multi-process run needs a process group.
        dist.init_process_group(backend, init_method, world_size=world_size, rank=rank)
    _distributed_initialized = multi_process
    return multi_process
def is_initialized() -> bool:
    """Report the module-level flag that tracks process-group initialization."""
    initialized = _distributed_initialized
    return initialized
def get_rank() -> int:
    """Return this process's rank, or 0 when no process group was initialized."""
    if not _distributed_initialized:
        return 0
    return dist.get_rank()
def get_world_size() -> int:
    """Return the world size, or 1 when no process group was initialized."""
    if not _distributed_initialized:
        return 1
    return dist.get_world_size()
def barrier():
    """Block until every process reaches this point; no-op without a process group."""
    if not _distributed_initialized:
        return
    dist.barrier()
def all_reduce(tensor, op=None):
    """All-reduce ``tensor`` in place across processes; no-op without a process group.

    Args:
        tensor: Tensor to reduce.
        op: Reduce operation; ``dist.ReduceOp.SUM`` is used when None.
    """
    if not _distributed_initialized:
        return
    reduce_op = dist.ReduceOp.SUM if op is None else op
    dist.all_reduce(tensor, reduce_op)
def gather(tensor, gather_list=None, dst=0):
    """Gather ``tensor`` from all processes onto ``dst``; no-op without a process group.

    Args:
        tensor: Tensor contributed by the calling process.
        gather_list: Output list, only used on the destination rank.
        dst: Destination rank.
    """
    if not _distributed_initialized:
        # Single-process mode: gather_list is left untouched.
        return
    dist.gather(tensor, gather_list, dst)
def destroy_process_group():
    """Tear down the process group if one was created and clear the module flag."""
    global _distributed_initialized
    if not _distributed_initialized:
        return
    dist.destroy_process_group()
    _distributed_initialized = False

View file

@ -0,0 +1,136 @@
import os
from collections import deque
import xxhash
import numpy as np
from nanovllm.engine.sequence import Sequence
# Verbose block-manager tracing; switch on by exporting NANOVLLM_DEBUG=1.
_DEBUG = os.environ.get("NANOVLLM_DEBUG", "0") == "1"


def _debug_log(msg: str):
    """Emit ``msg`` on stdout (flushed) when NANOVLLM_DEBUG tracing is enabled."""
    if not _DEBUG:
        return
    print(f"[nanovllm block_mgr DEBUG] {msg}", flush=True)
class Block:
    """Bookkeeping record for one KV-cache block.

    Attributes are plain fields: ``block_id`` (index given at construction),
    ``ref_count`` (number of holders), ``hash`` (-1 while no hash is recorded)
    and ``token_ids`` (tokens recorded together with the hash).
    """

    def __init__(self, block_id):
        self.block_id = block_id
        self.ref_count = 0
        self.hash = -1
        self.token_ids = []

    def update(self, hash: int, token_ids: list[int]):
        """Record the content hash and the token ids it was computed from."""
        self.hash, self.token_ids = hash, token_ids

    def reset(self):
        """Prepare the block for a new owner: one reference, no hash, no tokens."""
        self.token_ids = []
        self.hash = -1
        self.ref_count = 1
class BlockManager:
def __init__(self, num_blocks: int, block_size: int):
self.block_size = block_size
self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
self.hash_to_block_id: dict[int, int] = dict()
self.free_block_ids: deque[int] = deque(range(num_blocks))
self.used_block_ids: set[int] = set()
@classmethod
def compute_hash(cls, token_ids: list[int], prefix: int = -1):
h = xxhash.xxh64()
if prefix != -1:
h.update(prefix.to_bytes(8, "little"))
h.update(np.array(token_ids).tobytes())
return h.intdigest()
def _allocate_block(self, block_id: int) -> Block:
block = self.blocks[block_id]
assert block.ref_count == 0
block.reset()
self.free_block_ids.remove(block_id)
self.used_block_ids.add(block_id)
return self.blocks[block_id]
def _deallocate_block(self, block_id: int) -> Block:
assert self.blocks[block_id].ref_count == 0
self.used_block_ids.remove(block_id)
self.free_block_ids.append(block_id)
def can_allocate(self, seq: Sequence) -> bool:
return len(self.free_block_ids) >= seq.num_blocks
def allocate(self, seq: Sequence):
_debug_log(f"allocate: seq_id={seq.seq_id}, len={len(seq)}, num_blocks={seq.num_blocks}, "
f"free_blocks={len(self.free_block_ids)}")
assert not seq.block_table
h = -1
cache_miss = False
for i in range(seq.num_blocks):
token_ids = seq.block(i)
h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
block_id = self.hash_to_block_id.get(h, -1)
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
cache_miss = True
if cache_miss:
if len(self.free_block_ids) == 0:
_debug_log(f" ERROR: no free blocks available!")
block_id = self.free_block_ids[0]
block = self._allocate_block(block_id)
else:
seq.num_cached_tokens += self.block_size
if block_id in self.used_block_ids:
block = self.blocks[block_id]
block.ref_count += 1
else:
block = self._allocate_block(block_id)
if h != -1:
block.update(h, token_ids)
self.hash_to_block_id[h] = block_id
seq.block_table.append(block_id)
_debug_log(f" allocated block_table: {seq.block_table}")
def deallocate(self, seq: Sequence):
_debug_log(f"deallocate: seq_id={seq.seq_id}, block_table={seq.block_table}")
for block_id in reversed(seq.block_table):
block = self.blocks[block_id]
block.ref_count -= 1
_debug_log(f" block_id={block_id}, ref_count after decrement={block.ref_count}")
if block.ref_count == 0:
# Fix: Clean up hash_to_block_id mapping to prevent stale references
# This prevents CUDA illegal memory access when prefix cache tries to
# reuse a block_id that has already been freed
if block.hash != -1:
cached_id = self.hash_to_block_id.get(block.hash)
if cached_id == block_id:
del self.hash_to_block_id[block.hash]
self._deallocate_block(block_id)
seq.num_cached_tokens = 0
seq.block_table.clear()
_debug_log(f" deallocated, free_blocks={len(self.free_block_ids)}")
def can_append(self, seq: Sequence) -> bool:
return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
def may_append(self, seq: Sequence):
block_table = seq.block_table
last_block = self.blocks[block_table[-1]]
if len(seq) % self.block_size == 1:
assert last_block.hash != -1
block_id = self.free_block_ids[0]
self._allocate_block(block_id)
block_table.append(block_id)
elif len(seq) % self.block_size == 0:
assert last_block.hash == -1
token_ids = seq.block(seq.num_blocks-1)
prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
h = self.compute_hash(token_ids, prefix)
last_block.update(h, token_ids)
self.hash_to_block_id[h] = last_block.block_id
else:
assert last_block.hash == -1

View file

@ -0,0 +1,178 @@
import atexit
import threading
from dataclasses import fields
from time import perf_counter
from tqdm.auto import tqdm
from transformers import AutoTokenizer
import torch.multiprocessing as mp
from nanovllm.config import Config
from nanovllm.sampling_params import SamplingParams
from nanovllm.engine.sequence import Sequence
from nanovllm.engine.scheduler import Scheduler
from nanovllm.engine.model_runner import ModelRunner
class LLMEngine:
    """Blocking text-generation engine built on a Scheduler and a ModelRunner.

    ``generate()`` is serialized with a lock because the scheduler, block
    manager, model runner and CUDA graph buffers are shared mutable state.
    """

    def __init__(self, model, **kwargs):
        """Build the config, start tensor-parallel workers and set up tokenizer and scheduler.

        Args:
            model: Passed as the first argument to ``Config``.
            **kwargs: Entries whose key is a ``Config`` field are forwarded to
                ``Config``. An optional ``tokenizer`` entry is used as-is
                instead of loading one with ``AutoTokenizer``.
        """
        config_fields = {field.name for field in fields(Config)}
        config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
        config = Config(model, **config_kwargs)
        self.ps = []
        self.events = []
        # Thread-safety lock for generate().
        # In concurrent serving scenarios (API server with ThreadPoolExecutor,
        # multiple queue workers, Gradio with concurrent requests) several
        # threads can call generate() at once. Without this lock concurrent
        # access corrupts scheduler state, block tables and CUDA graph input
        # buffers, leading to intermittent CUDA device-side assertion failures
        # (illegal memory access in KV cache).
        self._generate_lock = threading.Lock()
        ctx = mp.get_context("spawn")
        # Ranks 1..N-1 run in spawned processes; rank 0 runs in this process.
        for i in range(1, config.tensor_parallel_size):
            event = ctx.Event()
            process = ctx.Process(target=ModelRunner, args=(config, i, event))
            process.start()
            self.ps.append(process)
            self.events.append(event)
        self.model_runner = ModelRunner(config, 0, self.events)
        tokenizer = kwargs.get("tokenizer", None)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
        self.tokenizer = tokenizer
        config.eos = self.tokenizer.eos_token_id
        self.scheduler = Scheduler(config)
        atexit.register(self.exit)

    def exit(self):
        """Shut down the model runner and join worker processes.

        Safe to call more than once: the atexit hook registered in
        ``__init__`` would otherwise hit a missing ``model_runner`` attribute
        after an explicit ``exit()``.
        """
        if getattr(self, "model_runner", None) is None:
            return
        self.model_runner.call("exit")
        del self.model_runner
        for p in self.ps:
            p.join()

    def add_request(self, prompt: str | list[int], sampling_params: SamplingParams, unconditional_prompt: str | list[int] | None = None):
        """Queue one prompt; with ``cfg_scale > 1.0`` also queue its unconditional twin.

        Args:
            prompt: Text (encoded with the tokenizer) or token ids.
            sampling_params: Sampling settings; ``cfg_scale`` selects CFG mode.
            unconditional_prompt: Prompt for the unconditional branch in CFG
                mode. When omitted, the conditional prompt is reused.
        """
        if isinstance(prompt, str):
            prompt = self.tokenizer.encode(prompt)
        if sampling_params.cfg_scale <= 1.0:
            self.scheduler.add(Sequence(prompt, sampling_params))
            return
        # CFG: create both conditional and unconditional sequences.
        if unconditional_prompt is None:
            # Fallback only - callers should supply a real unconditional prompt.
            # TODO: Implement automatic "NO USER INPUT" replacement if possible
            unconditional_prompt = prompt
        if isinstance(unconditional_prompt, str):
            unconditional_prompt = self.tokenizer.encode(unconditional_prompt)
        # Unconditional sequence is built first so the conditional one can reference it.
        uncond_seq = Sequence(unconditional_prompt, sampling_params, is_unconditional=True)
        cond_seq = Sequence(prompt, sampling_params, is_unconditional=False, conditional_seq=uncond_seq)
        uncond_seq.paired_seq = cond_seq  # Link them bidirectionally
        self.scheduler.add(cond_seq)
        self.scheduler.add(uncond_seq)

    def step(self):
        """Run one scheduling + forward + postprocess iteration.

        Returns:
            ``(outputs, num_tokens)`` where ``outputs`` is a list of
            ``(seq_id, completion_token_ids)`` for sequences that finished in
            this step (unconditional CFG sequences excluded), and
            ``num_tokens`` is the positive token count for a prefill step or
            the negated number of non-unconditional sequences for a decode step.
        """
        seqs, is_prefill = self.scheduler.schedule()
        token_ids = self.model_runner.call("run", seqs, is_prefill)
        self.scheduler.postprocess(seqs, token_ids)
        # Unconditional sequences exist only for the CFG computation.
        output_seqs = [seq for seq in seqs if seq.is_finished and (seq.cfg_scale <= 1.0 or not seq.is_unconditional)]
        outputs = [(seq.seq_id, seq.completion_token_ids) for seq in output_seqs]
        num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len([s for s in seqs if not s.is_unconditional])
        return outputs, num_tokens

    def is_finished(self):
        """Return the scheduler's finished state."""
        return self.scheduler.is_finished()

    def reset(self):
        """Drop all queued sequences and release their KV-cache blocks.

        Called after a failed or interrupted generation so leaked blocks do
        not later cause 'deque index out of range' errors.
        """
        # Waiting sequences may hold blocks too (e.g. after preemption).
        for queue in (self.scheduler.running, self.scheduler.waiting):
            while queue:
                seq = queue.popleft()
                if seq.block_table:  # Only deallocate if blocks are allocated
                    self.scheduler.block_manager.deallocate(seq)

    def generate(
        self,
        prompts: list[str] | list[list[int]],
        sampling_params: SamplingParams | list[SamplingParams],
        use_tqdm: bool = True,
        unconditional_prompts: list[str] | list[list[int]] | None = None,
    ) -> list[dict]:
        """Generate completions for ``prompts``; thread-safe via an internal lock.

        Returns:
            One ``{"text": str, "token_ids": list}`` dict per finished
            sequence, ordered by sequence id.
        """
        # Serialize access to prevent concurrent corruption of scheduler
        # state, block manager, CUDA graph buffers and KV cache.
        with self._generate_lock:
            return self._generate_impl(prompts, sampling_params, use_tqdm, unconditional_prompts)

    def _generate_impl(
        self,
        prompts: list[str] | list[list[int]],
        sampling_params: SamplingParams | list[SamplingParams],
        use_tqdm: bool = True,
        unconditional_prompts: list[str] | list[list[int]] | None = None,
    ) -> list[dict]:
        """Unlocked body of ``generate()``; see there for the contract."""
        # Clean up residual state from a previously interrupted generation.
        if not self.is_finished():
            self.reset()
        if not isinstance(sampling_params, list):
            sampling_params = [sampling_params] * len(prompts)
        if unconditional_prompts is None:
            unconditional_prompts = [None] * len(prompts)
        pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True) if use_tqdm else None
        outputs = {}
        prefill_throughput = decode_throughput = 0.
        try:
            # Enqueueing happens inside the try so a failing add_request also
            # resets the scheduler and closes the progress bar.
            for prompt, sp, uncond_prompt in zip(prompts, sampling_params, unconditional_prompts):
                self.add_request(prompt, sp, uncond_prompt)
            while not self.is_finished():
                t = perf_counter()
                output, num_tokens = self.step()
                if pbar is not None:
                    # step() reports prefill as a positive count, decode as negative.
                    if num_tokens > 0:
                        prefill_throughput = num_tokens / (perf_counter() - t)
                    else:
                        decode_throughput = -num_tokens / (perf_counter() - t)
                    pbar.set_postfix({
                        "Prefill": f"{int(prefill_throughput)}tok/s",
                        "Decode": f"{int(decode_throughput)}tok/s",
                    })
                for seq_id, token_ids in output:
                    outputs[seq_id] = token_ids
                    if pbar is not None:
                        pbar.update(1)
        except Exception:
            # Clean up on exception to prevent block leaks
            self.reset()
            raise
        finally:
            if pbar is not None:
                pbar.close()
        ordered = [outputs[seq_id] for seq_id in sorted(outputs)]
        return [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in ordered]

View file

@ -0,0 +1,691 @@
import os
import pickle
import torch
import torch.distributed as dist
from multiprocessing.synchronize import Event
from multiprocessing.shared_memory import SharedMemory
import sys
from nanovllm.config import Config
from acestep.debug_utils import debug_start, debug_end
from nanovllm import distributed as dist_utils
# Verbose model-runner tracing; switch on by exporting NANOVLLM_DEBUG=1.
_DEBUG = os.environ.get("NANOVLLM_DEBUG", "0") == "1"


def _debug_log(msg: str):
    """Emit ``msg`` on stdout (flushed) when NANOVLLM_DEBUG tracing is enabled."""
    if not _DEBUG:
        return
    print(f"[nanovllm DEBUG] {msg}", flush=True)
from nanovllm.engine.sequence import Sequence
from nanovllm.models.qwen3 import Qwen3ForCausalLM
from nanovllm.layers.sampler import Sampler
from nanovllm.utils.context import set_context, get_context, reset_context
from nanovllm.utils.loader import load_model
import socket
def find_available_port(start_port: int = 2333, max_attempts: int = 100) -> int:
    """Return the first TCP port at or after ``start_port`` that can be bound on localhost.

    Each candidate is probed by binding a throw-away socket, which is closed
    again before returning, so the port is only known to have been free at
    probe time.

    Args:
        start_port: First port number to try.
        max_attempts: How many consecutive ports to try.

    Returns:
        A port number that could be bound.

    Raises:
        RuntimeError: If none of the probed ports could be bound.
    """
    for candidate in range(start_port, start_port + max_attempts):
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                probe.bind(('localhost', candidate))
        except OSError:
            # Candidate is taken (or unusable); move on to the next one.
            continue
        return candidate
    raise RuntimeError(f"Could not find an available port starting from {start_port} after {max_attempts} attempts")
class ModelRunner:
def __init__(self, config: Config, rank: int, event: Event | list[Event]):
# Build the runner for GPU `rank`: optional distributed init, dtype selection,
# model load, sample-buffer allocation, warmup, KV-cache allocation and
# (unless enforce_eager) CUDA graph capture.
# `event` is stored as-is: write_shm() iterates it (rank 0), read_shm() waits on it (rank > 0).
# NOTE(review): the global torch default dtype/device are changed below and only
# restored on the success path - confirm whether a failure mid-way matters.
# Enable capturing scalar outputs to avoid graph breaks from Tensor.item() calls
torch._dynamo.config.capture_scalar_outputs = True
self.config = config
hf_config = config.hf_config
self.block_size = config.kvcache_block_size
self.enforce_eager = config.enforce_eager
self.world_size = config.tensor_parallel_size
self.rank = rank
self.event = event
# Only initialize distributed if world_size > 1
if self.world_size > 1:
# NOTE(review): every rank probes for a port on its own here; presumably all
# ranks must end up with the same port for the rendezvous - confirm this cannot diverge.
dist_port = find_available_port()
print(f"[debug]dist_port: {dist_port}")
# Use gloo backend on Windows, nccl on Linux/other platforms
backend = "gloo" if sys.platform == "win32" else "nccl"
dist_utils.initialize_distributed(backend, f"tcp://127.0.0.1:{dist_port}", world_size=self.world_size, rank=rank)
torch.cuda.set_device(rank)
# Remember the process-wide default dtype so it can be restored after setup.
default_dtype = torch.get_default_dtype()
# Detect GPU compute capability to determine bfloat16 support
# Bfloat16 requires Ampere (compute capability >= 8.0) or newer
gpu_props = torch.cuda.get_device_properties(rank)
# Use tuple comparison to handle compute capability correctly
# (e.g., 7.5 < 8.0, 8.0 >= 8.0, 8.6 >= 8.0, etc.)
supports_bfloat16 = (gpu_props.major, gpu_props.minor) >= (8, 0)
# Use dtype instead of deprecated torch_dtype
config_dtype = getattr(hf_config, 'dtype', getattr(hf_config, 'torch_dtype', torch.bfloat16))
# Validate and convert config_dtype to a valid torch floating-point dtype
# Default to bfloat16 for CUDA (required for Flash Attention 2) if GPU supports it
if config_dtype is None:
config_dtype = torch.bfloat16 if supports_bfloat16 else torch.float16
elif isinstance(config_dtype, str):
# Convert string dtype to torch dtype
dtype_map = {
'float32': torch.float32,
'float16': torch.float16,
'bfloat16': torch.bfloat16,
'float64': torch.float64,
'torch.float32': torch.float32,
'torch.float16': torch.float16,
'torch.bfloat16': torch.bfloat16,
'torch.float64': torch.float64,
}
# Unknown strings fall back to the capability-based default.
config_dtype = dtype_map.get(config_dtype.lower(), torch.bfloat16 if supports_bfloat16 else torch.float16)
elif not isinstance(config_dtype, torch.dtype) or not config_dtype.is_floating_point:
# If not a valid floating-point torch dtype, default based on GPU capability
config_dtype = torch.bfloat16 if supports_bfloat16 else torch.float16
# Override to float16 if config requested bfloat16 but GPU doesn't support it
if config_dtype == torch.bfloat16 and not supports_bfloat16:
print(f"[nanovllm] GPU {gpu_props.name} (compute capability {gpu_props.major}.{gpu_props.minor}) does not support bfloat16. Using float16 instead.", flush=True)
config_dtype = torch.float16
self.dtype = config_dtype # Save for later use
# Temporarily make the chosen dtype and "cuda" the torch defaults for the setup steps below.
torch.set_default_dtype(config_dtype)
torch.set_default_device("cuda")
self.model = Qwen3ForCausalLM(hf_config)
_t0 = debug_start("load_model", prefix="tensor.vllm")
load_model(self.model, config.model)
debug_end("load_model", _t0, prefix="tensor.vllm")
self.sampler = Sampler()
# Pre-allocate buffers for sampling (optimization: avoid repeated tensor creation)
# Must be called before warmup_model() since it uses these buffers
self._allocate_sample_buffers()
self.warmup_model()
self.allocate_kv_cache()
if not self.enforce_eager:
self.capture_cudagraph()
# Restore the torch defaults: device back to "cpu", dtype back to the saved value.
torch.set_default_device("cpu")
torch.set_default_dtype(default_dtype)
if self.world_size > 1:
# Rank 0 creates the 1 MiB shared-memory segment used by write_shm(); the
# other ranks attach to it after the barrier.
if rank == 0:
self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
dist_utils.barrier()
else:
dist_utils.barrier()
self.shm = SharedMemory(name="nanovllm")
# loop() runs until an "exit" message arrives, so this call keeps the
# constructor from returning until then (read_shm asserts rank > 0).
self.loop()
def _allocate_sample_buffers(self):
    """Create the reusable pinned CPU staging buffers for sampling, decode and prefill.

    Sizes come from ``config.max_num_seqs``, ``config.max_num_batched_tokens``
    and ``config.max_model_len``. Every buffer is zero-filled and explicitly
    placed on the CPU because the torch default device may be "cuda" while
    this runs.
    """
    _t0 = debug_start("_allocate_sample_buffers", prefix="tensor.vllm")
    cfg = self.config
    max_bs = cfg.max_num_seqs
    max_tokens = cfg.max_num_batched_tokens
    # Ceil-divide the longest sequence into blocks.
    max_num_blocks = (cfg.max_model_len + self.block_size - 1) // self.block_size

    def pinned(*shape, dtype):
        return torch.zeros(*shape, dtype=dtype, device="cpu", pin_memory=True)

    # Per-sequence sampling parameters.
    self._cpu_temperatures = pinned(max_bs, dtype=torch.float32)
    self._cpu_cfg_scales = pinned(max_bs, dtype=torch.float32)
    self._cpu_top_ks = pinned(max_bs, dtype=torch.int32)
    self._cpu_top_ps = pinned(max_bs, dtype=torch.float32)
    self._cpu_repetition_penalties = pinned(max_bs, dtype=torch.float32)
    # Decode inputs: one entry per sequence.
    self._cpu_input_ids = pinned(max_bs, dtype=torch.int64)
    self._cpu_positions = pinned(max_bs, dtype=torch.int64)
    self._cpu_slot_mapping = pinned(max_bs, dtype=torch.int32)
    self._cpu_context_lens = pinned(max_bs, dtype=torch.int32)
    # Prefill inputs: one entry per batched token.
    self._cpu_prefill_input_ids = pinned(max_tokens, dtype=torch.int64)
    self._cpu_prefill_positions = pinned(max_tokens, dtype=torch.int64)
    self._cpu_prefill_cu_seqlens = pinned(max_bs + 1, dtype=torch.int32)
    self._cpu_prefill_slot_mapping = pinned(max_tokens, dtype=torch.int32)
    # Block tables, shared by decode and prefill.
    self._cpu_block_tables = pinned(max_bs, max_num_blocks, dtype=torch.int32)
    # Token ids per sequence (up to max_model_len each), for logits processing / sampling.
    self._seq_token_ids_buffer = pinned(max_bs, cfg.max_model_len, dtype=torch.int64)
    debug_end("_allocate_sample_buffers", _t0, prefix="tensor.vllm")
def exit(self):
    """Release shared memory, captured CUDA graphs and the process group."""
    multi_process = self.world_size > 1
    if multi_process:
        self.shm.close()
        # Every rank must have closed its mapping before rank 0 unlinks it.
        dist_utils.barrier()
        if self.rank == 0:
            self.shm.unlink()
    if not self.enforce_eager:
        # Graph state only exists when CUDA graphs were captured.
        del self.graphs
        del self.graph_pool
    torch.cuda.synchronize()
    dist_utils.destroy_process_group()
def loop(self):
    """Execute messages read from shared memory until an "exit" message was handled."""
    method_name = None
    while method_name != "exit":
        method_name, args = self.read_shm()
        self.call(method_name, *args)
def read_shm(self):
    """Wait for the event, then decode one ``[method_name, *args]`` message from shared memory.

    The first four bytes hold the little-endian payload length; the pickled
    payload follows. The event is cleared once the message has been read.

    Returns:
        ``(method_name, args)`` with ``args`` as a list.

    Raises:
        AssertionError: unless called on a non-zero rank with world_size > 1.
    """
    assert self.world_size > 1 and self.rank > 0
    self.event.wait()
    payload_len = int.from_bytes(self.shm.buf[0:4], "little")
    # NOTE: pickle is only acceptable here because the segment is written by rank 0's write_shm().
    message = pickle.loads(self.shm.buf[4:payload_len + 4])
    method_name, *args = message
    self.event.clear()
    return method_name, args
def write_shm(self, method_name, *args):
    """Publish ``[method_name, *args]`` to shared memory and signal every event.

    Layout: four bytes of little-endian payload length followed by the
    pickled payload.

    Raises:
        AssertionError: unless called on rank 0 with world_size > 1.
    """
    assert self.world_size > 1 and self.rank == 0
    payload = pickle.dumps([method_name, *args])
    size = len(payload)
    self.shm.buf[0:4] = size.to_bytes(4, "little")
    self.shm.buf[4:size + 4] = payload
    # Signal only after the payload is fully written.
    for worker_event in self.event:
        worker_event.set()
def call(self, method_name, *args):
    """Invoke ``method_name`` on this runner, first broadcasting it when rank 0 has workers.

    Returns:
        Whatever the invoked method returns.
    """
    is_driver = self.world_size > 1 and self.rank == 0
    if is_driver:
        # The other ranks pick this up in loop() and run the same method.
        self.write_shm(method_name, *args)
    target = getattr(self, method_name, None)
    return target(*args)
def warmup_model(self):
    """Run one prefill over dummy full-length sequences, then clear the CUDA cache.

    The batch holds as many all-zero sequences of ``max_model_len`` tokens as
    fit into ``max_num_batched_tokens``, capped at ``max_num_seqs``.
    """
    _t0 = debug_start("warmup_model", prefix="tensor.vllm")
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    cfg = self.config
    max_model_len = cfg.max_model_len
    num_seqs = min(cfg.max_num_batched_tokens // max_model_len, cfg.max_num_seqs)
    dummy_seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
    self.run(dummy_seqs, True)
    torch.cuda.empty_cache()
    debug_end("warmup_model", _t0, prefix="tensor.vllm")
def allocate_kv_cache(self):
# Size and allocate the KV cache tensor, store the block count in
# config.num_kvcache_blocks, and attach per-layer views to every module that
# has both `k_cache` and `v_cache` attributes.
_t0 = debug_start("allocate_kv_cache", prefix="tensor.vllm")
config = self.config
hf_config = config.hf_config
free, total = torch.cuda.mem_get_info()
current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
# Account for per-process memory fraction (set via MAX_CUDA_VRAM simulation)
import os as _os
_debug_vram = _os.environ.get("MAX_CUDA_VRAM")
if _debug_vram is not None:
try:
# MAX_CUDA_VRAM is read as a number of GiB; it only takes effect when
# smaller than the device's real total.
_simulated_gb = float(_debug_vram)
_total_gb = total / (1024 ** 3)
if _simulated_gb < _total_gb:
# Effective total and free are capped by simulation
reserved = torch.cuda.memory_reserved()
total = int(_simulated_gb * (1024 ** 3))
free = max(0, total - reserved)
except (ValueError, TypeError):
# Unparseable value: ignore the simulation and keep the real numbers.
pass
num_kv_heads = hf_config.num_key_value_heads // self.world_size
head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
# Bytes per block: 2 (K and V) x layers x tokens per block x heads x head_dim x element size.
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * self.dtype.itemsize
# Calculate available memory for KV cache
# After warmup_model, empty_cache has been called, so current represents model memory only
# Use free memory but respect the gpu_memory_utilization limit
target_total_usage = total * config.gpu_memory_utilization
available_for_kv_cache = min(free * 0.9, target_total_usage - current)
# Safety check: ensure we leave at least ~1 GB free for DiT inference
# activations that will run after LM generation. Without this, the KV
# cache can consume all free VRAM and cause OOM during DiT forward pass.
MIN_RESERVE_BYTES = int(1.0 * 1024**3) # 1 GB reserved for other models
max_kv_from_free = max(0, free - MIN_RESERVE_BYTES) * 0.9
available_for_kv_cache = min(available_for_kv_cache, max_kv_from_free)
# Ensure we have positive memory available
if available_for_kv_cache <= 0:
available_for_kv_cache = free * 0.5 # Fallback to 50% of free memory
config.num_kvcache_blocks = max(1, int(available_for_kv_cache) // block_bytes)
# NOTE(review): max(1, ...) above guarantees at least one block, so the check
# below can never raise as written - confirm which behavior is intended.
if config.num_kvcache_blocks <= 0:
raise RuntimeError(
f"Insufficient GPU memory for KV cache. "
f"Free: {free / 1024**3:.2f} GB, Current: {current / 1024**3:.2f} GB, "
f"Available for KV: {available_for_kv_cache / 1024**3:.2f} GB, "
f"Block size: {block_bytes / 1024**2:.2f} MB"
)
max_tokens_capacity = config.num_kvcache_blocks * self.block_size
kv_cache_size_gb = config.num_kvcache_blocks * block_bytes / 1024**3
# If KV cache would leave less than 1 GB free, warn and suggest reducing max_model_len
post_kv_free = (free - config.num_kvcache_blocks * block_bytes) / 1024**3
if post_kv_free < 1.0:
print(
f"[nanovllm] WARNING: After KV cache allocation, only {post_kv_free:.2f} GB free. "
f"DiT inference may OOM. Consider reducing max_model_len or using CPU offload."
)
print(
f"[nanovllm] KV cache allocated: {config.num_kvcache_blocks} blocks × {self.block_size} tokens = "
f"{max_tokens_capacity} tokens capacity, {kv_cache_size_gb:.2f} GB "
f"(free: {free / 1024**3:.2f} GB, used: {current / 1024**3:.2f} GB, "
f"target: {target_total_usage / 1024**3:.2f} GB, block: {block_bytes / 1024**2:.2f} MB, "
f"post_kv_free: {post_kv_free:.2f} GB)"
)
# Layout: (K/V, layer, block, token-in-block, kv head, head_dim). No dtype/device is
# passed, so the tensor uses the torch defaults in effect when this runs.
self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
layer_id = 0
# Modules exposing k_cache/v_cache get consecutive layer slices in modules() iteration order.
for module in self.model.modules():
if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
module.k_cache = self.kv_cache[0, layer_id]
module.v_cache = self.kv_cache[1, layer_id]
layer_id += 1
debug_end("allocate_kv_cache", _t0, prefix="tensor.vllm")
def prepare_block_tables(self, seqs: list[Sequence]):
    """Return an int32 CUDA tensor of the sequences' block tables, right-padded with -1.

    All rows are padded to the length of the longest block table in ``seqs``.
    """
    _t0 = debug_start("prepare_block_tables", prefix="tensor.vllm")
    widest = max(len(seq.block_table) for seq in seqs)
    rows = []
    for seq in seqs:
        table = seq.block_table
        rows.append(table + [-1] * (widest - len(table)))
    host_tables = torch.tensor(rows, dtype=torch.int32, pin_memory=True)
    block_tables = host_tables.cuda(non_blocking=True)
    debug_end("prepare_block_tables", _t0, prefix="tensor.vllm")
    return block_tables
def prepare_prefill(self, seqs: list[Sequence]):
    """Flatten the uncached tokens of ``seqs`` into prefill inputs and set the forward context.

    For every sequence only the tokens after ``num_cached_tokens`` are fed to
    the model. Cumulative query/key lengths and the KV-cache slot of each new
    token are collected and stored via ``set_context``. Block tables are only
    built when some tokens are served from the prefix cache.

    Returns:
        ``(input_ids, positions)`` as int64 CUDA tensors.
    """
    _t0 = debug_start("prepare_prefill", prefix="tensor.vllm")
    token_buf, position_buf, slot_buf = [], [], []
    q_offsets, k_offsets = [0], [0]
    longest_q = longest_k = 0
    for seq in seqs:
        total_len = len(seq)
        cached = seq.num_cached_tokens
        token_buf.extend(seq[cached:])
        position_buf.extend(range(cached, total_len))
        new_len = total_len - cached
        q_offsets.append(q_offsets[-1] + new_len)
        k_offsets.append(k_offsets[-1] + total_len)
        longest_q = max(new_len, longest_q)
        longest_k = max(total_len, longest_k)
        if not seq.block_table:  # warmup
            continue
        final_idx = seq.num_blocks - 1
        for blk_idx in range(seq.num_cached_blocks, seq.num_blocks):
            first_slot = seq.block_table[blk_idx] * self.block_size
            # Only the final block may be partially filled.
            span = self.block_size if blk_idx != final_idx else seq.last_block_num_tokens
            slot_buf.extend(range(first_slot, first_slot + span))
    block_tables = None
    if k_offsets[-1] > q_offsets[-1]:  # prefix cache
        block_tables = self.prepare_block_tables(seqs)

    def to_gpu(values, dtype):
        return torch.tensor(values, dtype=dtype, pin_memory=True).cuda(non_blocking=True)

    input_ids = to_gpu(token_buf, torch.int64)
    positions = to_gpu(position_buf, torch.int64)
    cu_seqlens_q = to_gpu(q_offsets, torch.int32)
    cu_seqlens_k = to_gpu(k_offsets, torch.int32)
    slot_mapping = to_gpu(slot_buf, torch.int32)
    set_context(True, cu_seqlens_q, cu_seqlens_k, longest_q, longest_k, slot_mapping, None, block_tables)
    debug_end("prepare_prefill", _t0, prefix="tensor.vllm")
    return input_ids, positions
def prepare_decode(self, seqs: list[Sequence]):
    """Stage one decode step in the pre-allocated pinned buffers and set the forward context.

    Per sequence this writes the last token, its position, the context length
    and the KV-cache slot of the last token, then copies the first
    ``len(seqs)`` entries of each buffer to the GPU.

    Returns:
        ``(input_ids, positions)`` CUDA tensors with one entry per sequence.
    """
    _t0 = debug_start("prepare_decode", prefix="tensor.vllm")
    bs = len(seqs)
    for row, seq in enumerate(seqs):
        seq_len = len(seq)
        self._cpu_input_ids[row] = seq.last_token
        self._cpu_positions[row] = seq_len - 1
        self._cpu_context_lens[row] = seq_len
        # Slot of the newest token inside the sequence's last block.
        self._cpu_slot_mapping[row] = seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1
    # Upload only the slice that was filled for this batch.
    input_ids = self._cpu_input_ids[:bs].cuda(non_blocking=True)
    positions = self._cpu_positions[:bs].cuda(non_blocking=True)
    slot_mapping = self._cpu_slot_mapping[:bs].cuda(non_blocking=True)
    context_lens = self._cpu_context_lens[:bs].cuda(non_blocking=True)
    block_tables = self.prepare_block_tables(seqs)
    set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
    debug_end("prepare_decode", _t0, prefix="tensor.vllm")
    return input_ids, positions
def prepare_sample(self, seqs: list[Sequence], is_cfg_batch: bool = False):
    """Stage per-sequence sampling parameters and copy them to the GPU.

    Args:
        seqs: Sequences of the current batch.
        is_cfg_batch: When True only the first half of ``seqs`` is used (the
            batch is laid out as conditional sequences followed by their
            unconditional partners).

    Returns:
        ``(temperatures, cfg_scales, top_ks, top_ps, repetition_penalties)``.
        ``top_ks`` is None when no sequence has ``top_k > 0``; ``top_ps`` and
        ``repetition_penalties`` are None when every sequence uses the neutral
        value 1.0 (or None).
    """
    _t0 = debug_start("prepare_sample", prefix="tensor.vllm")
    if is_cfg_batch:
        num_seqs = len(seqs) // 2
        target_seqs = seqs[:num_seqs]
    else:
        num_seqs = len(seqs)
        target_seqs = seqs
    # Track whether any sequence deviates from the neutral setting so that
    # all-neutral batches skip the upload.
    use_top_k = use_top_p = use_repetition_penalty = False
    for i, seq in enumerate(target_seqs):
        top_k = seq.top_k if seq.top_k is not None else 0
        top_p = seq.top_p if seq.top_p is not None else 1.0
        repetition_penalty = seq.repetition_penalty if seq.repetition_penalty is not None else 1.0
        self._cpu_temperatures[i] = seq.temperature
        self._cpu_cfg_scales[i] = seq.cfg_scale
        self._cpu_top_ks[i] = top_k
        self._cpu_top_ps[i] = top_p
        self._cpu_repetition_penalties[i] = repetition_penalty
        if top_k > 0:
            use_top_k = True
        # The previous code tested `== 1.0` here, which dropped every
        # non-default top_p / repetition_penalty and uploaded the neutral ones.
        if top_p != 1.0:
            use_top_p = True
        if repetition_penalty != 1.0:
            use_repetition_penalty = True
    # Transfer to GPU using sliced views of the pinned buffers.
    temperatures = self._cpu_temperatures[:num_seqs].cuda(non_blocking=True)
    cfg_scales = self._cpu_cfg_scales[:num_seqs].cuda(non_blocking=True)
    top_ks = self._cpu_top_ks[:num_seqs].cuda(non_blocking=True) if use_top_k else None
    top_ps = self._cpu_top_ps[:num_seqs].cuda(non_blocking=True) if use_top_p else None
    repetition_penalties = (
        self._cpu_repetition_penalties[:num_seqs].cuda(non_blocking=True)
        if use_repetition_penalty
        else None
    )
    debug_end("prepare_sample", _t0, prefix="tensor.vllm")
    return temperatures, cfg_scales, top_ks, top_ps, repetition_penalties
@torch.inference_mode()
def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
    """Run the forward pass and return logits, replaying a CUDA graph when possible.

    The eager path is used for prefill, when ``enforce_eager`` is set, for
    batches larger than 512, and whenever the decode context does not fit the
    captured graph buffers (too many block-table columns, or row counts that
    differ from the batch size). Otherwise the inputs are copied into the
    graph's static buffers and the smallest captured graph whose batch size is
    at least ``bs`` is replayed.

    Args:
        input_ids: Token ids for this step.
        positions: Position index of each token.
        is_prefill: True for a prefill step, False for decode.

    Returns:
        Logits from ``self.model.compute_logits``.
    """
    _t0 = debug_start("run_model", prefix="tensor.vllm")
    bs = input_ids.size(0)
    context = None
    if is_prefill or self.enforce_eager or bs > 512:
        eager_reason = f"run_model: eager mode, is_prefill={is_prefill}, bs={bs}"
    else:
        context = get_context()
        if _DEBUG:
            # Guarded explicitly: f-string arguments are evaluated before
            # _debug_log is called, and .tolist() on a CUDA tensor forces a
            # host sync, so this must not run on every decode step.
            _debug_log(f"run_model: decode mode, bs={bs}")
            _debug_log(f"  context.block_tables.shape={context.block_tables.shape}")
            _debug_log(f"  context.slot_mapping.shape={context.slot_mapping.shape}")
            _debug_log(f"  context.context_lens.shape={context.context_lens.shape}")
            _debug_log(f"  context.slot_mapping={context.slot_mapping.tolist()}")
            _debug_log(f"  context.context_lens={context.context_lens.tolist()}")
        # The graph's block_tables buffer has a fixed width. In CFG mode the
        # conditional and unconditional sequences can differ in length, so the
        # live block table may be wider than what was captured.
        max_num_blocks = self.graph_vars["block_tables"].size(1)
        if context.block_tables.size(1) > max_num_blocks:
            eager_reason = f"  fallback: block_tables cols {context.block_tables.size(1)} > max {max_num_blocks}"
        elif context.block_tables.size(0) != bs:
            # A row-count mismatch can cause CUDA illegal memory access during replay.
            eager_reason = f"  fallback: block_tables rows {context.block_tables.size(0)} != bs {bs}"
        elif context.slot_mapping.size(0) != bs or context.context_lens.size(0) != bs:
            eager_reason = f"  fallback: slot_mapping/context_lens size mismatch"
        else:
            eager_reason = None
    if eager_reason is not None:
        _debug_log(eager_reason)
        out = self.model.compute_logits(self.model(input_ids, positions))
        debug_end("run_model", _t0, prefix="tensor.vllm")
        return out
    # Validate block_tables values (debug only; .item() syncs with the GPU).
    if _DEBUG:
        max_block_id = context.block_tables.max().item()
        min_block_id = context.block_tables[context.block_tables >= 0].min().item() if (context.block_tables >= 0).any() else -1
        _debug_log(f"  block_tables range: [{min_block_id}, {max_block_id}]")
        _debug_log(f"  num_kvcache_blocks: {self.config.num_kvcache_blocks}")
        if max_block_id >= self.config.num_kvcache_blocks:
            _debug_log(f"  WARNING: block_table contains invalid block_id {max_block_id} >= {self.config.num_kvcache_blocks}")
    # Smallest captured batch size that can hold this batch.
    graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
    graph_vars = self.graph_vars
    graph_vars["input_ids"][:bs] = input_ids
    graph_vars["positions"][:bs] = positions
    # Reset the padding rows so they carry -1 slots and zero context length.
    graph_vars["slot_mapping"].fill_(-1)
    graph_vars["slot_mapping"][:bs] = context.slot_mapping
    graph_vars["context_lens"].zero_()
    graph_vars["context_lens"][:bs] = context.context_lens
    # Clear block_tables first to ensure no stale data from previous runs
    graph_vars["block_tables"][:bs].fill_(-1)
    graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
    _debug_log(f"  executing CUDA graph replay for bs={bs}")
    graph.replay()
    out = self.model.compute_logits(graph_vars["outputs"][:bs])
    debug_end("run_model", _t0, prefix="tensor.vllm")
    return out
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
    """Run one forward pass over ``seqs`` and sample one token per sequence.

    A batch is handled as a CFG (classifier-free guidance) batch when its
    first sequence has ``cfg_scale > 1.0`` and a ``paired_seq``. Such a batch
    is expected to be laid out as
    ``[cond_1, ..., cond_n, uncond_1, ..., uncond_n]`` where ``uncond_i`` is the
    unconditional partner of ``cond_i``; the two halves of the logits are
    combined as ``uncond + cfg_scale * (cond - uncond)`` before sampling.

    Args:
        seqs: Sequences scheduled for this step.
        is_prefill: Selects ``prepare_prefill`` (True) or ``prepare_decode``.

    Returns:
        On rank 0, the sampled token ids: one per conditional sequence for a
        CFG batch, otherwise one per sequence. On every other rank, ``None``.
    """
    _debug_log(f"run: num_seqs={len(seqs)}, is_prefill={is_prefill}")
    for i, seq in enumerate(seqs):
        _debug_log(f" seq[{i}]: len={len(seq)}, num_blocks={seq.num_blocks}, "
                   f"cfg_scale={seq.cfg_scale}, is_uncond={seq.is_unconditional}, "
                   f"block_table={seq.block_table}")
    # A CFG batch contains paired conditional and unconditional sequences.
    is_cfg_batch = seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None
    _debug_log(f" is_cfg_batch={is_cfg_batch}")

    # Input preparation and the forward pass are identical for both batch
    # kinds: a CFG batch already holds its cond and uncond halves.
    input_ids, positions = (self.prepare_prefill(seqs) if is_prefill
                            else self.prepare_decode(seqs))
    sample_params = self.prepare_sample(seqs, is_cfg_batch=is_cfg_batch) if self.rank == 0 else None
    if sample_params is not None:
        temperatures, cfg_scales, top_ks, top_ps, repetition_penalties = sample_params
    else:
        temperatures = cfg_scales = top_ks = top_ps = repetition_penalties = None

    logits = self.run_model(input_ids, positions, is_prefill)
    reset_context()
    if self.rank != 0:
        # Only rank 0 samples.
        return None

    if is_cfg_batch:
        # First half of the logits is conditional, second half unconditional.
        num_cond = len(seqs) // 2
        sampled_seqs = seqs[:num_cond]
        logits_cond = logits[:num_cond]
        logits_uncond = logits[num_cond:]
        # The repetition penalty is applied to the conditional logits only,
        # before the CFG combination.
        self._penalize_repeated_tokens(logits_cond, sampled_seqs, repetition_penalties)
        # logits_cfg = logits_uncond + cfg_scale * (logits_cond - logits_uncond)
        logits = logits_uncond + cfg_scales.unsqueeze(1) * (logits_cond - logits_uncond)
        # Constrained decoding: each processor sees its own row.
        for i, seq in enumerate(sampled_seqs):
            if seq.logits_processor is not None:
                seq_input_ids = torch.tensor([seq.token_ids], device=logits.device)
                logits[i:i+1] = seq.logits_processor(seq_input_ids, logits[i:i+1])
    else:
        sampled_seqs = seqs
        self._penalize_repeated_tokens(logits, seqs, repetition_penalties)
        # Clone logits to avoid in-place update issues in inference mode.
        logits = logits.clone()
        for i, seq in enumerate(seqs):
            if seq.logits_processor is not None:
                seq_input_ids = torch.tensor([seq.token_ids], device=logits.device)
                processed = seq.logits_processor(seq_input_ids, logits[i:i+1].clone())
                logits[i] = processed[0]

    token_ids = self.sampler(
        logits,
        temperatures,
        top_ks=top_ks,
        top_ps=top_ps,
        repetition_penalties=None,  # Already applied above
    ).tolist()

    # NOTE: Only the first sequence's callback is invoked, because all
    # sequences of a batch may share one processor (single SamplingParams for
    # batch generation); calling it per sequence would advance the shared
    # state N times instead of once.
    if sampled_seqs and sampled_seqs[0].logits_processor_update_state is not None:
        sampled_seqs[0].logits_processor_update_state(token_ids[0])
    return token_ids

@staticmethod
def _penalize_repeated_tokens(logits, seqs, repetition_penalties):
    """Apply the repetition penalty in place to ``logits[i]`` for each ``seqs[i]``.

    Only tokens that already appear in a sequence's completion (not its
    prompt) are penalised, using the same rule as the transformers
    implementation: negative scores are multiplied by the penalty, the
    others are divided by it. Does nothing when ``repetition_penalties`` is
    ``None``.
    """
    if repetition_penalties is None:
        return
    for i, seq in enumerate(seqs):
        penalty = repetition_penalties[i].item()
        if penalty == 1.0:
            continue
        completion_tokens = torch.tensor(seq.completion_token_ids, device=logits.device)
        if len(completion_tokens) == 0:
            continue
        # Mark every vocabulary entry that was already generated.
        token_mask = torch.zeros(logits.shape[1], dtype=torch.bool, device=logits.device)
        token_mask[completion_tokens] = True
        row = logits[i]
        penalty_scores = torch.where(row < 0, row * penalty, row / penalty)
        logits[i] = torch.where(token_mask, penalty_scores, row)
@torch.inference_mode()
def capture_cudagraph(self):
    """Capture one CUDA graph of the decode forward pass per batch size.

    Allocates the static input/output buffers that graph replay reads from
    and writes to, captures a graph for every size in ``self.graph_bs``
    (largest first, so later captures can share the first graph's memory
    pool), and stores the buffers in ``self.graph_vars``.

    ``self.graph_bs`` always contains ``max_bs`` and never exceeds it, so a
    lookup of the smallest captured size ``>= bs`` succeeds for every
    ``bs <= max_bs``.
    """
    _t0 = debug_start("capture_cudagraph", prefix="tensor.vllm")
    config = self.config
    hf_config = config.hf_config
    max_bs = min(self.config.max_num_seqs, 512)
    max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
    # NOTE(review): no device is passed here, so these buffers land on the
    # current default device - presumably set to the GPU by the caller; confirm.
    input_ids = torch.zeros(max_bs, dtype=torch.int64)
    positions = torch.zeros(max_bs, dtype=torch.int64)
    slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
    context_lens = torch.zeros(max_bs, dtype=torch.int32)
    block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
    outputs = torch.zeros(max_bs, hf_config.hidden_size)
    # Standard sizes, clipped to max_bs, plus max_bs itself. Without the
    # explicit max_bs entry a value such as max_num_seqs=20 would leave batch
    # sizes 17..20 without any graph to replay.
    standard_sizes = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
    self.graph_bs = sorted({bs for bs in standard_sizes if bs <= max_bs} | {max_bs})
    self.graphs = {}
    self.graph_pool = None
    for bs in reversed(self.graph_bs):
        graph = torch.cuda.CUDAGraph()
        set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
        outputs[:bs] = self.model(input_ids[:bs], positions[:bs])  # warmup
        with torch.cuda.graph(graph, self.graph_pool):
            outputs[:bs] = self.model(input_ids[:bs], positions[:bs])  # capture
        if self.graph_pool is None:
            self.graph_pool = graph.pool()
        self.graphs[bs] = graph
        torch.cuda.synchronize()
        reset_context()
    self.graph_vars = dict(
        input_ids=input_ids,
        positions=positions,
        slot_mapping=slot_mapping,
        context_lens=context_lens,
        block_tables=block_tables,
        outputs=outputs,
    )
    debug_end("capture_cudagraph", _t0, prefix="tensor.vllm")

View file

@ -0,0 +1,272 @@
import os
from collections import deque
from nanovllm.config import Config
from nanovllm.engine.sequence import Sequence, SequenceStatus
from nanovllm.engine.block_manager import BlockManager
# Scheduler debug tracing is opt-in: set NANOVLLM_DEBUG=1 in the environment.
_DEBUG = os.getenv("NANOVLLM_DEBUG", "0") == "1"


def _debug_log(msg: str):
    """Print ``msg`` with the scheduler debug prefix; no-op unless NANOVLLM_DEBUG=1."""
    if not _DEBUG:
        return
    print(f"[nanovllm scheduler DEBUG] {msg}", flush=True)
class Scheduler:
    """Decides which sequences run in each engine step.

    Sequences wait in ``self.waiting`` until they can be prefilled, then live
    in ``self.running`` while they decode. KV-cache blocks are obtained from
    ``self.block_manager``. Sequences that form a CFG pair (a conditional
    sequence and its unconditional ``paired_seq``) are always scheduled and
    preempted together.
    """

    def __init__(self, config: Config):
        self.max_num_seqs = config.max_num_seqs
        self.max_num_batched_tokens = config.max_num_batched_tokens
        self.eos = config.eos
        self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
        self.waiting: deque[Sequence] = deque()
        self.running: deque[Sequence] = deque()

    def is_finished(self):
        """Return True when no sequence is waiting or running."""
        return not self.waiting and not self.running

    def add(self, seq: Sequence):
        """Queue ``seq`` for prefill."""
        self.waiting.append(seq)

    def schedule(self) -> tuple[list[Sequence], bool]:
        """Pick the sequences for the next step.

        Prefill has priority: if any waiting sequence can be admitted, only
        prefill work is returned. Otherwise running sequences are scheduled
        for one decode step, preempting other running sequences when the KV
        cache is full.

        Returns:
            ``(scheduled_seqs, is_prefill)``. The list is ordered non-CFG
            sequences first, then CFG conditional, then CFG unconditional.

        Raises:
            RuntimeError: If nothing could be scheduled.
        """
        _debug_log(f"schedule: waiting={len(self.waiting)}, running={len(self.running)}, "
                   f"free_blocks={len(self.block_manager.free_block_ids)}")
        # prefill
        scheduled_seqs = []
        num_seqs = 0
        num_batched_tokens = 0
        processed_seqs = set()  # Track processed sequences to handle CFG pairs
        while self.waiting and num_seqs < self.max_num_seqs:
            seq = self.waiting[0]
            # For CFG sequences, conditional and unconditional are scheduled together
            if seq.cfg_scale > 1.0 and seq.paired_seq is not None and not seq.is_unconditional:
                paired_seq = seq.paired_seq
                if paired_seq.status != SequenceStatus.WAITING:
                    # Paired sequence not in waiting, stop admitting for now
                    break
                total_tokens = (len(seq) - seq.num_cached_tokens) + (len(paired_seq) - paired_seq.num_cached_tokens)
                # Both sequences must fit at once, so check the combined block demand
                total_blocks_needed = seq.num_blocks + paired_seq.num_blocks
                can_allocate_both = len(self.block_manager.free_block_ids) >= total_blocks_needed
                if num_batched_tokens + total_tokens > self.max_num_batched_tokens or not can_allocate_both:
                    break
                # Schedule both sequences: conditional first, then unconditional
                for s in [seq, paired_seq]:
                    num_seqs += 1
                    self.block_manager.allocate(s)
                    num_batched_tokens += len(s) - s.num_cached_tokens
                    s.status = SequenceStatus.RUNNING
                    self.waiting.remove(s)
                    self.running.append(s)
                    scheduled_seqs.append(s)
                    processed_seqs.add(s.seq_id)
            else:
                # Normal sequence, or an unconditional one already handled with its conditional
                if seq.seq_id in processed_seqs:
                    self.waiting.popleft()
                    continue
                if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
                    break
                num_seqs += 1
                self.block_manager.allocate(seq)
                num_batched_tokens += len(seq) - seq.num_cached_tokens
                seq.status = SequenceStatus.RUNNING
                self.waiting.popleft()
                self.running.append(seq)
                scheduled_seqs.append(seq)
        if scheduled_seqs:
            # Reorder: non-CFG, then CFG conditional, then CFG unconditional
            cfg_cond_seqs = [s for s in scheduled_seqs if s.cfg_scale > 1.0 and not s.is_unconditional]
            cfg_uncond_seqs = [s for s in scheduled_seqs if s.is_unconditional]
            non_cfg_seqs = [s for s in scheduled_seqs if s.cfg_scale <= 1.0]
            scheduled_seqs = non_cfg_seqs + cfg_cond_seqs + cfg_uncond_seqs
            return scheduled_seqs, True

        # decode
        processed_seqs = set()
        # Work on a copy; self.running is only edited through preempt() and
        # the explicit removals below.
        temp_running = deque(self.running)
        while temp_running and num_seqs < self.max_num_seqs:
            seq = temp_running.popleft()
            if seq.cfg_scale > 1.0 and seq.paired_seq is not None and not seq.is_unconditional:
                paired_seq = seq.paired_seq
                if paired_seq not in temp_running:
                    # Paired sequence not available, skip for now
                    continue
                temp_running.remove(paired_seq)
                # Each sequence needs 1 new block when it has just crossed a
                # block boundary (len % block_size == 1)
                block_size = self.block_manager.block_size
                blocks_needed_seq = 1 if len(seq) % block_size == 1 else 0
                blocks_needed_paired = 1 if len(paired_seq) % block_size == 1 else 0
                total_blocks_needed = blocks_needed_seq + blocks_needed_paired
                # Evict other running sequences until the pair fits
                while len(self.block_manager.free_block_ids) < total_blocks_needed and temp_running:
                    self._preempt_with_partner(temp_running.popleft(), temp_running)
                if len(self.block_manager.free_block_ids) < total_blocks_needed:
                    # Nothing left to evict: preempt the pair itself. Putting it
                    # back on the work list would retry the same failing check
                    # forever. The unconditional half is pushed first so the
                    # conditional one ends up at the head of the waiting queue.
                    self.preempt(paired_seq)
                    self.preempt(seq)
                    continue
                for s in [seq, paired_seq]:
                    num_seqs += 1
                    self.block_manager.may_append(s)
                    scheduled_seqs.append(s)
                    processed_seqs.add(s.seq_id)
                    # Re-inserted at the front of self.running after the loop
                    if s in self.running:
                        self.running.remove(s)
            else:
                # Normal sequence or unconditional (already processed)
                if seq.seq_id in processed_seqs:
                    continue
                while not self.block_manager.can_append(seq):
                    if temp_running:
                        self._preempt_with_partner(temp_running.popleft(), temp_running)
                    else:
                        # No other victim: give up this sequence's own blocks
                        self.preempt(seq)
                        break
                else:
                    num_seqs += 1
                    self.block_manager.may_append(seq)
                    scheduled_seqs.append(seq)
                    if seq in self.running:
                        self.running.remove(seq)

        if not scheduled_seqs:
            # No sequences could be scheduled - provide informative error
            waiting_count = len(self.waiting)
            running_count = len(self.running)
            free_blocks = len(self.block_manager.free_block_ids)
            total_blocks = len(self.block_manager.blocks)
            if waiting_count > 0:
                seq = self.waiting[0]
                blocks_needed = seq.num_blocks
                prompt_tokens = len(seq)
                if seq.cfg_scale > 1.0 and seq.paired_seq is not None:
                    blocks_needed += seq.paired_seq.num_blocks
                    prompt_tokens = f"{len(seq)}+{len(seq.paired_seq)}"
                raise RuntimeError(
                    f"Insufficient KV cache to schedule sequence. "
                    f"Free blocks: {free_blocks}/{total_blocks}, blocks needed: {blocks_needed}, "
                    f"prompt tokens: {prompt_tokens}, block size: {self.block_manager.block_size}. "
                    f"The prompt may be too long for available GPU memory, or gpu_memory_utilization is too low."
                )
            else:
                raise RuntimeError(
                    f"No schedulable sequences found. "
                    f"Waiting: {waiting_count}, Running: {running_count}, "
                    f"Free blocks: {free_blocks}/{total_blocks}"
                )

        # For CFG batches in decode, conditional sequences come before unconditional
        cfg_cond_seqs = [s for s in scheduled_seqs if s.cfg_scale > 1.0 and not s.is_unconditional]
        cfg_uncond_seqs = [s for s in scheduled_seqs if s.is_unconditional]
        non_cfg_seqs = [s for s in scheduled_seqs if s.cfg_scale <= 1.0]
        scheduled_seqs = non_cfg_seqs + cfg_cond_seqs + cfg_uncond_seqs
        self.running.extendleft(reversed(scheduled_seqs))
        return scheduled_seqs, False

    def _preempt_with_partner(self, seq: Sequence, pending):
        """Preempt ``seq`` and, if it is still in ``pending``, its CFG partner.

        Evicting only one half of a CFG pair would leave the other half to be
        scheduled alone. ``pending`` is the decode work list; the partner is
        removed from it.
        """
        group = [seq]
        partner = seq.paired_seq
        if partner is not None and partner in pending:
            pending.remove(partner)
            group.append(partner)
        # preempt() pushes to the FRONT of the waiting queue, so push the
        # unconditional half first: the conditional one then sits at the head.
        group.sort(key=lambda s: s.is_unconditional, reverse=True)
        for s in group:
            self.preempt(s)

    def preempt(self, seq: Sequence):
        """Move ``seq`` back to the front of the waiting queue and free its blocks."""
        seq.status = SequenceStatus.WAITING
        self.block_manager.deallocate(seq)
        # The decode pass iterates a copy of self.running, so the sequence must
        # be dropped here; otherwise it stays in the running queue and shows up
        # a second time after it is prefilled again.
        if seq in self.running:
            self.running.remove(seq)
        self.waiting.appendleft(seq)

    def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
        """Append the sampled tokens and retire sequences that are finished.

        For a CFG batch ``token_ids`` holds one token per conditional
        sequence; the same token is appended to the unconditional partner and
        both halves finish together. A sequence finishes on EOS (unless
        ``ignore_eos``) or when it has produced ``max_tokens`` tokens.

        NOTE(review): annotated as returning ``list[bool]`` but no value is
        returned on any path.
        """
        _debug_log(f"postprocess: num_seqs={len(seqs)}, num_token_ids={len(token_ids) if token_ids else 0}")
        if token_ids:
            _debug_log(f" token_ids: {token_ids[:10]}..." if len(token_ids) > 10 else f" token_ids: {token_ids}")
        # Check if this is a CFG batch
        is_cfg_batch = False
        if len(seqs) > 0 and seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None:
            num_cond = len(seqs) // 2
            is_cfg_batch = (num_cond > 0 and
                            not seqs[0].is_unconditional and
                            seqs[num_cond].is_unconditional)
        _debug_log(f" is_cfg_batch={is_cfg_batch}")
        if is_cfg_batch:
            # CFG batch: seqs = [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
            num_cond = len(seqs) // 2
            cond_seqs = seqs[:num_cond]
            uncond_seqs = seqs[num_cond:]
            for cond_seq, uncond_seq, token_id in zip(cond_seqs, uncond_seqs, token_ids):
                cond_seq.append_token(token_id)
                uncond_seq.append_token(token_id)  # Same token for unconditional
                cond_finished = ((not cond_seq.ignore_eos and token_id == self.eos) or
                                 cond_seq.num_completion_tokens == cond_seq.max_tokens)
                uncond_finished = ((not uncond_seq.ignore_eos and token_id == self.eos) or
                                   uncond_seq.num_completion_tokens == uncond_seq.max_tokens)
                if cond_finished or uncond_finished:
                    # Mark both as finished
                    cond_seq.status = SequenceStatus.FINISHED
                    uncond_seq.status = SequenceStatus.FINISHED
                    self.block_manager.deallocate(cond_seq)
                    self.block_manager.deallocate(uncond_seq)
                    if cond_seq in self.running:
                        self.running.remove(cond_seq)
                    if uncond_seq in self.running:
                        self.running.remove(uncond_seq)
        else:
            # Normal batch
            for seq, token_id in zip(seqs, token_ids):
                seq.append_token(token_id)
                if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
                    seq.status = SequenceStatus.FINISHED
                    self.block_manager.deallocate(seq)
                    self.running.remove(seq)

View file

@ -0,0 +1,96 @@
from copy import copy
from enum import Enum, auto
from itertools import count
from typing import Optional, Callable, Any
from nanovllm.sampling_params import SamplingParams
class SequenceStatus(Enum):
    """Lifecycle state of a sequence."""

    # Initial state of a new sequence.
    WAITING = auto()
    # Set by the scheduler once a sequence has been admitted.
    RUNNING = auto()
    # Terminal state, reported by Sequence.is_finished.
    FINISHED = auto()
class Sequence:
    """One generation request: its tokens, KV-cache bookkeeping and sampling settings.

    Class attributes:
        block_size: Number of tokens per KV-cache block used by the block
            arithmetic below.
        counter: Process-wide source of unique ``seq_id`` values.
    """

    block_size = 256
    counter = count()

    def __init__(
        self,
        token_ids: list[int],
        sampling_params: Optional["SamplingParams"] = None,
        is_unconditional: bool = False,
        conditional_seq: Optional["Sequence"] = None,
    ):
        """Create a waiting sequence from prompt ``token_ids``.

        Args:
            token_ids: Prompt tokens (copied; must be non-empty).
            sampling_params: Sampling configuration. ``None`` means a fresh
                default ``SamplingParams()`` - created per call rather than
                once at definition time, so instances never share a default.
            is_unconditional: True for the unconditional half of a CFG pair.
            conditional_seq: The partner sequence of a CFG pair, stored as
                ``paired_seq``.
        """
        if sampling_params is None:
            sampling_params = SamplingParams()
        self.seq_id = next(Sequence.counter)
        self.status = SequenceStatus.WAITING
        self.token_ids = copy(token_ids)
        self.last_token = token_ids[-1]
        self.num_tokens = len(self.token_ids)
        self.num_prompt_tokens = len(token_ids)
        self.num_cached_tokens = 0
        self.block_table = []
        self.temperature = sampling_params.temperature
        self.max_tokens = sampling_params.max_tokens
        self.ignore_eos = sampling_params.ignore_eos
        self.cfg_scale = sampling_params.cfg_scale
        self.top_k = sampling_params.top_k
        self.top_p = sampling_params.top_p
        self.repetition_penalty = sampling_params.repetition_penalty
        # For CFG: mark if this is an unconditional sequence
        self.is_unconditional = is_unconditional
        # For CFG: for a conditional seq this points to its unconditional
        # partner; for an unconditional seq it points to the conditional one
        self.paired_seq = conditional_seq
        # For constrained decoding: logits processor and state update callback
        self.logits_processor: Optional[Any] = sampling_params.logits_processor
        self.logits_processor_update_state: Optional[Callable[[int], None]] = sampling_params.logits_processor_update_state

    def __len__(self):
        return self.num_tokens

    def __getitem__(self, key):
        return self.token_ids[key]

    @property
    def is_finished(self):
        return self.status == SequenceStatus.FINISHED

    @property
    def num_completion_tokens(self):
        return self.num_tokens - self.num_prompt_tokens

    @property
    def prompt_token_ids(self):
        return self.token_ids[:self.num_prompt_tokens]

    @property
    def completion_token_ids(self):
        return self.token_ids[self.num_prompt_tokens:]

    @property
    def num_cached_blocks(self):
        return self.num_cached_tokens // self.block_size

    @property
    def num_blocks(self):
        # Ceiling division: a partially filled last block still counts.
        return (self.num_tokens + self.block_size - 1) // self.block_size

    @property
    def last_block_num_tokens(self):
        return self.num_tokens - (self.num_blocks - 1) * self.block_size

    def block(self, i):
        """Return the tokens stored in the ``i``-th block."""
        assert 0 <= i < self.num_blocks
        return self.token_ids[i*self.block_size: (i+1)*self.block_size]

    def append_token(self, token_id: int):
        """Record one newly generated token."""
        self.token_ids.append(token_id)
        self.last_token = token_id
        self.num_tokens += 1

    def __getstate__(self):
        # Compact pickle form: the full token list is only shipped while no
        # completion token exists; afterwards only the last token is sent.
        # NOTE(review): sampling settings, status and paired_seq are not part
        # of the state, so an unpickled copy lacks those attributes - confirm
        # the receiving side never reads them.
        return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,
                self.token_ids if self.num_completion_tokens == 0 else self.last_token)

    def __setstate__(self, state):
        self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]
        if self.num_completion_tokens == 0:
            self.token_ids = state[-1]
        else:
            self.last_token = state[-1]

View file

@ -0,0 +1,16 @@
import torch
from torch import nn
import torch.nn.functional as F
from nanovllm.utils.compat import maybe_compile
class SiluAndMul(nn.Module):
    """Gated activation: split the last dimension in two halves and return ``silu(first) * second``."""

    def __init__(self):
        super().__init__()

    @maybe_compile
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # First half is the gate, second half is the value it scales.
        gate, up = torch.chunk(x, 2, dim=-1)
        return F.silu(gate) * up

View file

@ -0,0 +1,387 @@
import os
import torch
from torch import nn
import torch.nn.functional as F
from nanovllm.utils.context import get_context
# Attention debug tracing is opt-in: set NANOVLLM_DEBUG=1 in the environment.
_DEBUG = os.getenv("NANOVLLM_DEBUG", "0") == "1"


def _debug_log(msg: str):
    """Print ``msg`` with the attention debug prefix; no-op unless NANOVLLM_DEBUG=1."""
    if not _DEBUG:
        return
    print(f"[nanovllm attention DEBUG] {msg}", flush=True)
# Optional dependencies. Each flag records whether the import succeeded:
#   _HAS_TRITON     - Triton, used for the KV cache store kernel
#   _HAS_FLASH_ATTN - Flash Attention, used for the attention kernels
try:
    import triton
    import triton.language as tl
except ImportError:
    _HAS_TRITON = False
else:
    _HAS_TRITON = True

try:
    from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
except ImportError:
    _HAS_FLASH_ATTN = False
else:
    _HAS_FLASH_ATTN = True
# ============================================================
# Triton KV cache store kernel (original, used when available)
# ============================================================
if _HAS_TRITON:

    @triton.jit
    def store_kvcache_kernel(
        key_ptr,
        key_stride,
        value_ptr,
        value_stride,
        k_cache_ptr,
        v_cache_ptr,
        slot_mapping_ptr,
        D: tl.constexpr,
    ):
        # One program instance per token: copy the token's D key elements and
        # D value elements into cache row `slot`.
        token = tl.program_id(0)
        slot = tl.load(slot_mapping_ptr + token)
        # Slot -1 marks a padding entry: nothing to store.
        if slot == -1:
            return
        lanes = tl.arange(0, D)
        key = tl.load(key_ptr + token * key_stride + lanes)
        value = tl.load(value_ptr + token * value_stride + lanes)
        # Both caches are addressed as rows of D elements.
        dst = slot * D + lanes
        tl.store(k_cache_ptr + dst, key)
        tl.store(v_cache_ptr + dst, value)
# ============================================================
# Pure PyTorch KV cache store (fallback when Triton unavailable)
# ============================================================
def _store_kvcache_pytorch(
    key: torch.Tensor,
    value: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
):
    """Store key/value into paged KV cache using pure PyTorch ops.

    The caches are updated in place. Entries of ``slot_mapping`` equal to -1
    are padding and are skipped.

    Args:
        key: [N, num_kv_heads, head_dim]
        value: [N, num_kv_heads, head_dim]
        k_cache: [num_blocks, block_size, num_kv_heads, head_dim] (per-layer view)
        v_cache: [num_blocks, block_size, num_kv_heads, head_dim]
        slot_mapping: [N] - flat slot indices into cache

    Raises:
        RuntimeError: If a cache cannot be viewed as ``[total_slots, D]``
            without copying (e.g. a transposed tensor).
    """
    N, num_kv_heads, head_dim = key.shape
    D = num_kv_heads * head_dim
    # View the caches as flat [total_slots, D]. This must be `view`, not
    # `reshape`: for an incompatible layout `reshape` silently returns a copy
    # and the writes below would never reach the real cache.
    k_flat = k_cache.view(-1, D)
    v_flat = v_cache.view(-1, D)
    # The sources are only read, so a copying reshape is fine here.
    key_flat = key.reshape(N, D)
    value_flat = value.reshape(N, D)
    # Filter out padding slots (slot == -1)
    valid_mask = slot_mapping != -1
    valid_slots = slot_mapping[valid_mask]
    k_flat[valid_slots] = key_flat[valid_mask]
    v_flat[valid_slots] = value_flat[valid_mask]
def store_kvcache(
    key: torch.Tensor,
    value: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
):
    """Store key/value into paged KV cache.

    Uses the Triton kernel when Triton is importable and the tensors are not
    on the CPU; otherwise falls back to the pure PyTorch implementation.
    The caches are updated in place.

    Args:
        key: [N, num_heads, head_dim] new keys.
        value: [N, num_heads, head_dim] new values.
        k_cache: key cache whose dim-1 stride equals ``num_heads * head_dim``.
        v_cache: value cache with the same layout.
        slot_mapping: [N] destination slot per token, -1 for padding.
    """
    # A Triton launch cannot read host memory, so CPU tensors always take the
    # PyTorch path even when Triton happens to be installed.
    if _HAS_TRITON and key.device.type != "cpu":
        N, num_heads, head_dim = key.shape
        D = num_heads * head_dim
        # The kernel reads and writes each token as one contiguous run of D
        # elements, so the layout below is required.
        assert key.stride(-1) == 1 and value.stride(-1) == 1
        assert key.stride(1) == head_dim and value.stride(1) == head_dim
        assert k_cache.stride(1) == D and v_cache.stride(1) == D
        assert slot_mapping.numel() == N
        store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
    else:
        _store_kvcache_pytorch(key, value, k_cache, v_cache, slot_mapping)
# ============================================================
# SDPA-based attention (fallback when Flash Attention unavailable)
# ============================================================
def _sdpa_varlen_prefill(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
scale: float,
num_heads: int,
num_kv_heads: int,
) -> torch.Tensor:
"""SDPA replacement for flash_attn_varlen_func during prefill.
Splits packed sequences, runs SDPA per sequence with causal masking,
then re-packs. Handles GQA via enable_gqa when heads differ.
Args:
q: [total_q_tokens, num_heads, head_dim]
k: [total_k_tokens, num_kv_heads, head_dim]
v: [total_k_tokens, num_kv_heads, head_dim]
cu_seqlens_q: [num_seqs + 1] cumulative sequence lengths for queries
cu_seqlens_k: [num_seqs + 1] cumulative sequence lengths for keys
scale: attention scale factor
num_heads: number of query heads
num_kv_heads: number of KV heads
Returns:
output: [total_q_tokens, num_heads, head_dim]
"""
num_seqs = cu_seqlens_q.shape[0] - 1
outputs = []
enable_gqa = num_heads != num_kv_heads
for i in range(num_seqs):
q_start = cu_seqlens_q[i].item()
q_end = cu_seqlens_q[i + 1].item()
k_start = cu_seqlens_k[i].item()
k_end = cu_seqlens_k[i + 1].item()
# [seq_len, heads, dim] -> [1, heads, seq_len, dim]
qi = q[q_start:q_end].unsqueeze(0).transpose(1, 2)
ki = k[k_start:k_end].unsqueeze(0).transpose(1, 2)
vi = v[k_start:k_end].unsqueeze(0).transpose(1, 2)
oi = F.scaled_dot_product_attention(
qi, ki, vi, scale=scale, is_causal=True, enable_gqa=enable_gqa
)
# [1, heads, seq_len, dim] -> [seq_len, heads, dim]
outputs.append(oi.transpose(1, 2).squeeze(0))
return torch.cat(outputs, dim=0)
def _sdpa_prefill_with_paged_cache(
q: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
block_tables: torch.Tensor,
scale: float,
num_heads: int,
num_kv_heads: int,
) -> torch.Tensor:
"""SDPA prefill with paged KV cache (prefix caching case).
Args:
q: [total_q_tokens, num_heads, head_dim]
k_cache: [num_blocks, block_size, num_kv_heads, head_dim]
v_cache: [num_blocks, block_size, num_kv_heads, head_dim]
cu_seqlens_q: [num_seqs + 1]
cu_seqlens_k: [num_seqs + 1]
block_tables: [num_seqs, max_blocks_per_seq]
scale: attention scale factor
num_heads: number of query heads
num_kv_heads: number of KV heads
Returns:
output: [total_q_tokens, num_heads, head_dim]
"""
block_size = k_cache.shape[1]
num_seqs = cu_seqlens_q.shape[0] - 1
outputs = []
enable_gqa = num_heads != num_kv_heads
for i in range(num_seqs):
q_start = cu_seqlens_q[i].item()
q_end = cu_seqlens_q[i + 1].item()
k_len = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
# Gather k/v from paged cache
num_blocks_needed = (k_len + block_size - 1) // block_size
block_indices = block_tables[i, :num_blocks_needed]
ki = k_cache[block_indices].reshape(-1, num_kv_heads, k_cache.shape[-1])[:k_len]
vi = v_cache[block_indices].reshape(-1, num_kv_heads, v_cache.shape[-1])[:k_len]
# [seq, heads, dim] -> [1, heads, seq, dim]
qi = q[q_start:q_end].unsqueeze(0).transpose(1, 2)
ki = ki.unsqueeze(0).transpose(1, 2)
vi = vi.unsqueeze(0).transpose(1, 2)
oi = F.scaled_dot_product_attention(
qi, ki, vi, scale=scale, is_causal=True, enable_gqa=enable_gqa
)
outputs.append(oi.transpose(1, 2).squeeze(0))
return torch.cat(outputs, dim=0)
def _sdpa_decode_with_paged_cache(
q: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
context_lens: torch.Tensor,
block_tables: torch.Tensor,
scale: float,
num_heads: int,
num_kv_heads: int,
) -> torch.Tensor:
"""SDPA replacement for flash_attn_with_kvcache during decode.
For each sequence, gathers KV from paged cache and runs SDPA
for the single new query token against the full context.
Args:
q: [batch, 1, num_heads, head_dim] (already unsqueezed)
k_cache: [num_blocks, block_size, num_kv_heads, head_dim]
v_cache: [num_blocks, block_size, num_kv_heads, head_dim]
context_lens: [batch] - number of tokens in context for each sequence
block_tables: [batch, max_blocks_per_seq]
scale: attention scale factor
num_heads: number of query heads
num_kv_heads: number of KV heads
Returns:
output: [batch, 1, num_heads, head_dim]
"""
batch_size = q.shape[0]
block_size = k_cache.shape[1]
outputs = []
enable_gqa = num_heads != num_kv_heads
for i in range(batch_size):
ctx_len = context_lens[i].item()
num_blocks_needed = (ctx_len + block_size - 1) // block_size
block_indices = block_tables[i, :num_blocks_needed]
# Gather and trim KV: [ctx_len, num_kv_heads, head_dim]
ki = k_cache[block_indices].reshape(-1, num_kv_heads, k_cache.shape[-1])[:ctx_len]
vi = v_cache[block_indices].reshape(-1, num_kv_heads, v_cache.shape[-1])[:ctx_len]
# q[i]: [1, num_heads, head_dim] -> [1, num_heads, 1, head_dim]
qi = q[i].unsqueeze(0).transpose(1, 2) # [1, num_heads, 1, head_dim]
ki = ki.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, ctx_len, head_dim]
vi = vi.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, ctx_len, head_dim]
oi = F.scaled_dot_product_attention(
qi, ki, vi, scale=scale, is_causal=False, enable_gqa=enable_gqa
)
outputs.append(oi.transpose(1, 2).squeeze(0)) # [1, num_heads, head_dim]
return torch.stack(outputs, dim=0) # [batch, 1, num_heads, head_dim]
# ============================================================
# Attention module
# ============================================================
class Attention(nn.Module):
def __init__(
self,
num_heads,
head_dim,
scale,
num_kv_heads,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.scale = scale
self.num_kv_heads = num_kv_heads
self.k_cache = self.v_cache = torch.tensor([])
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
context = get_context()
k_cache, v_cache = self.k_cache, self.v_cache
if _DEBUG:
_debug_log(f"forward: q.shape={q.shape}, k.shape={k.shape}, v.shape={v.shape}")
_debug_log(f" is_prefill={context.is_prefill}")
_debug_log(f" k_cache.shape={k_cache.shape if k_cache.numel() else 'empty'}")
if context.slot_mapping is not None:
_debug_log(f" slot_mapping.shape={context.slot_mapping.shape}, range=[{context.slot_mapping.min().item()}, {context.slot_mapping.max().item()}]")
if context.block_tables is not None:
valid_blocks = context.block_tables[context.block_tables >= 0]
_debug_log(f" block_tables.shape={context.block_tables.shape}, range=[{valid_blocks.min().item() if valid_blocks.numel() else -1}, {valid_blocks.max().item() if valid_blocks.numel() else -1}]")
if context.context_lens is not None:
_debug_log(f" context_lens={context.context_lens.tolist()}")
if k_cache.numel() and v_cache.numel():
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
if _HAS_FLASH_ATTN:
return self._forward_flash_attn(q, k, v, k_cache, v_cache, context)
else:
return self._forward_sdpa(q, k, v, k_cache, v_cache, context)
def _forward_flash_attn(self, q, k, v, k_cache, v_cache, context):
"""Original flash attention path."""
if context.is_prefill:
if context.block_tables is not None: # prefix cache
k, v = k_cache, v_cache
_debug_log(f" calling flash_attn_varlen_func")
o = flash_attn_varlen_func(
q, k, v,
max_seqlen_q=context.max_seqlen_q,
cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k,
cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale,
causal=True,
block_table=context.block_tables,
)
else: # decode
_debug_log(f" calling flash_attn_with_kvcache")
o = flash_attn_with_kvcache(
q.unsqueeze(1), k_cache, v_cache,
cache_seqlens=context.context_lens,
block_table=context.block_tables,
softmax_scale=self.scale,
causal=True,
)
return o
def _forward_sdpa(self, q, k, v, k_cache, v_cache, context):
    """Attention via the SDPA helper functions (no flash_attn dependency).

    Three cases are dispatched on the forward context: decode, prefill that
    reads from the paged cache (block tables present), and plain prefill on
    the packed ``k``/``v`` tokens.
    """
    head_args = (self.scale, self.num_heads, self.num_kv_heads)

    if not context.is_prefill:
        # Decode: single token per sequence against full KV cache.
        _debug_log(f" calling _sdpa_decode_with_paged_cache")
        return _sdpa_decode_with_paged_cache(
            q.unsqueeze(1),
            k_cache,
            v_cache,
            context.context_lens,
            context.block_tables,
            *head_args,
        )

    if context.block_tables is not None:
        # Prefix cache: keys/values are gathered from the paged cache.
        _debug_log(f" calling _sdpa_prefill_with_paged_cache")
        return _sdpa_prefill_with_paged_cache(
            q,
            k_cache,
            v_cache,
            context.cu_seqlens_q,
            context.cu_seqlens_k,
            context.block_tables,
            *head_args,
        )

    # Standard prefill: k and v are the packed tokens of this step.
    _debug_log(f" calling _sdpa_varlen_prefill")
    return _sdpa_varlen_prefill(
        q,
        k,
        v,
        context.cu_seqlens_q,
        context.cu_seqlens_k,
        *head_args,
    )

View file

@ -0,0 +1,69 @@
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
from nanovllm.utils.context import get_context
from nanovllm import distributed as dist_utils
class VocabParallelEmbedding(nn.Module):
    """Embedding table whose vocabulary rows are split evenly across ranks.

    Each rank owns ``num_embeddings // tp_size`` consecutive rows starting at
    ``vocab_start_idx``. With a single rank this is a plain embedding lookup.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
    ):
        super().__init__()
        rank = dist_utils.get_rank()
        world = dist_utils.get_world_size()
        self.tp_rank = rank
        self.tp_size = world
        # The vocabulary must split into equal shards.
        assert num_embeddings % world == 0
        rows_per_rank = num_embeddings // world
        self.num_embeddings = num_embeddings
        self.num_embeddings_per_partition = rows_per_rank
        self.vocab_start_idx = rows_per_rank * rank
        self.vocab_end_idx = self.vocab_start_idx + rows_per_rank
        self.weight = nn.Parameter(torch.empty(rows_per_rank, embedding_dim))
        # The checkpoint loader looks up this attribute on the parameter.
        self.weight.weight_loader = self.weight_loader

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's row slice of ``loaded_weight`` into ``param``."""
        destination = param.data
        rows = destination.size(0)
        destination.copy_(loaded_weight.narrow(0, self.tp_rank * rows, rows))

    def forward(self, x: torch.Tensor):
        """Embed token ids ``x``; with several ranks, partial results are all-reduced."""
        if self.tp_size <= 1:
            return F.embedding(x, self.weight)
        # Ids outside this rank's shard are remapped to row 0 and zeroed afterwards.
        owned = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
        local_ids = owned * (x - self.vocab_start_idx)
        y = owned.unsqueeze(1) * F.embedding(local_ids, self.weight)
        dist_utils.all_reduce(y)
        return y
class ParallelLMHead(VocabParallelEmbedding):
    """Output projection that reuses the vocab-sharded weight layout.

    Bias is not supported (asserted in the constructor).
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        bias: bool = False,
    ):
        assert not bias
        super().__init__(num_embeddings, embedding_dim)

    def forward(self, x: torch.Tensor):
        """Project hidden states to logits.

        During prefill only the rows at ``cu_seqlens_q[1:] - 1`` are kept
        before projecting. With more than one rank the per-rank logits are
        gathered to rank 0 and concatenated on the last axis; other ranks
        return ``None``.
        """
        ctx = get_context()
        if ctx.is_prefill:
            final_positions = ctx.cu_seqlens_q[1:] - 1
            x = x[final_positions].contiguous()
        logits = F.linear(x, self.weight)

        # Single-GPU mode (tp_size=1): nothing to gather.
        if self.tp_size <= 1:
            return logits

        is_root = self.tp_rank == 0
        gathered = [torch.empty_like(logits) for _ in range(self.tp_size)] if is_root else None
        dist_utils.gather(logits, gathered, 0)
        return torch.cat(gathered, -1) if is_root else None

View file

@ -0,0 +1,52 @@
import torch
from torch import nn
from nanovllm.utils.compat import maybe_compile
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with an optional fused residual add.

    The statistics are computed in float32 and the result is cast back to
    the input dtype before the learned ``weight`` is applied.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    @maybe_compile
    def rms_forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``x`` RMS-normalised over its last dimension and scaled by ``weight``.

        The caller's tensor is never modified.
        """
        orig_dtype = x.dtype
        x = x.float()
        var = x.pow(2).mean(dim=-1, keepdim=True)
        # Out-of-place on purpose: for float32 inputs ``x.float()`` returns the
        # caller's own tensor, and an in-place multiply would overwrite it
        # (callers keep that tensor as the residual).
        x = x * torch.rsqrt(var + self.eps)
        # ``x`` is a fresh tensor here, so the in-place scale is safe.
        x = x.to(orig_dtype).mul_(self.weight)
        return x

    @maybe_compile
    def add_rms_forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Add ``residual`` to ``x``, then RMS-normalise the sum.

        Returns:
            ``(normalised, new_residual)`` where ``new_residual`` is the
            un-normalised sum in the input dtype.
        """
        orig_dtype = x.dtype
        # NOTE: for float32 inputs this accumulates into the caller's ``x`` buffer.
        x = x.float().add_(residual.float())
        residual = x.to(orig_dtype)
        var = x.pow(2).mean(dim=-1, keepdim=True)
        # Out-of-place on purpose: for float32 inputs ``residual`` is the same
        # tensor as ``x``, so an in-place multiply would normalise the
        # residual as well.
        x = x * torch.rsqrt(var + self.eps)
        x = x.to(orig_dtype).mul_(self.weight)
        return x, residual

    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Dispatch to the plain or the fused-residual variant.

        Returns a single tensor when ``residual`` is ``None``; otherwise the
        ``(normalised, new_residual)`` pair.
        """
        if residual is None:
            return self.rms_forward(x)
        else:
            return self.add_rms_forward(x, residual)

View file

@ -0,0 +1,155 @@
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
from nanovllm import distributed as dist_utils
def divide(numerator, denominator):
    """Return ``numerator // denominator``, asserting the division is exact."""
    quotient, remainder = divmod(numerator, denominator)
    assert remainder == 0
    return quotient
class LinearBase(nn.Module):
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = False,
tp_dim: int | None = None,
):
super().__init__()
self.tp_dim = tp_dim
self.tp_rank = dist_utils.get_rank()
self.tp_size = dist_utils.get_world_size()
self.weight = nn.Parameter(torch.empty(output_size, input_size))
self.weight.weight_loader = self.weight_loader
if bias:
self.bias = nn.Parameter(torch.empty(output_size))
self.bias.weight_loader = self.weight_loader
else:
self.register_parameter("bias", None)
def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
class ReplicatedLinear(LinearBase):
    """Linear layer whose full weight is held by every rank."""

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        super().__init__(input_size, output_size, bias)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy the checkpoint tensor into ``param`` unchanged."""
        destination = param.data
        destination.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``x @ weight.T + bias``."""
        projected = F.linear(x, self.weight, self.bias)
        return projected
class ColumnParallelLinear(LinearBase):
    """Linear layer whose output features (weight dim 0) are split across ranks."""

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        local_outputs = divide(output_size, dist_utils.get_world_size())
        super().__init__(input_size, local_outputs, bias, 0)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` along ``tp_dim`` into ``param``."""
        destination = param.data
        width = destination.size(self.tp_dim)
        local_slice = loaded_weight.narrow(self.tp_dim, self.tp_rank * width, width)
        destination.copy_(local_slice)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the local shard; no cross-rank communication happens here."""
        projected = F.linear(x, self.weight, self.bias)
        return projected
class MergedColumnParallelLinear(ColumnParallelLinear):
    """Several column-parallel projections stored back to back in one weight.

    ``output_sizes`` lists the full (unsharded) width of each projection.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = False,
    ):
        # Stored before the base constructor runs, as in the original ordering.
        self.output_sizes = output_sizes
        super().__init__(input_size, sum(output_sizes), bias)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
        """Load projection number ``loaded_shard_id`` into its region of ``param``."""
        world = self.tp_size
        # Offsets and widths are per rank, hence the division by tp_size.
        start = sum(self.output_sizes[:loaded_shard_id]) // world
        width = self.output_sizes[loaded_shard_id] // world
        region = param.data.narrow(self.tp_dim, start, width)
        local_piece = loaded_weight.chunk(world, self.tp_dim)[self.tp_rank]
        region.copy_(local_piece)
class QKVParallelLinear(ColumnParallelLinear):
    """Fused query/key/value projection, sharded by attention head.

    The local weight stores the q rows first, then k, then v.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = False,
    ):
        world = dist_utils.get_world_size()
        # Without an explicit KV head count, use as many KV heads as query heads.
        total_num_kv_heads = total_num_kv_heads or total_num_heads
        self.head_size = head_size
        self.num_heads = divide(total_num_heads, world)
        self.num_kv_heads = divide(total_num_kv_heads, world)
        fused_width = (total_num_heads + 2 * total_num_kv_heads) * self.head_size
        super().__init__(hidden_size, fused_width, bias)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
        """Load the ``"q"``, ``"k"`` or ``"v"`` checkpoint tensor into its region of ``param``."""
        assert loaded_shard_id in ["q", "k", "v"]
        q_width = self.num_heads * self.head_size
        kv_width = self.num_kv_heads * self.head_size
        # shard id -> (offset, width) inside the local fused weight
        layout = {
            "q": (0, q_width),
            "k": (q_width, kv_width),
            "v": (q_width + kv_width, kv_width),
        }
        start, width = layout[loaded_shard_id]
        region = param.data.narrow(self.tp_dim, start, width)
        local_piece = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
        region.copy_(local_piece)
class RowParallelLinear(LinearBase):
    """Linear layer whose input features (weight dim 1) are split across ranks.

    Each rank produces a partial sum; the partial outputs are all-reduced
    when more than one rank is present.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        local_inputs = divide(input_size, dist_utils.get_world_size())
        super().__init__(local_inputs, output_size, bias, 1)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` along ``tp_dim`` into ``param``."""
        destination = param.data
        width = destination.size(self.tp_dim)
        local_slice = loaded_weight.narrow(self.tp_dim, self.tp_rank * width, width)
        destination.copy_(local_slice)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the local shard and combine partial outputs across ranks."""
        # Only rank 0 adds the bias so it is counted once after the reduction.
        local_bias = self.bias if self.tp_rank == 0 else None
        y = F.linear(x, self.weight, local_bias)
        if self.tp_size > 1:
            dist_utils.all_reduce(y)
        return y

View file

@ -0,0 +1,63 @@
from functools import lru_cache
import torch
from torch import nn
from nanovllm.utils.compat import maybe_compile
def apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """Rotate the two halves of ``x``'s last dimension by the given angles.

    The arithmetic is done in float32 and the result is cast back to
    ``x.dtype``.
    """
    first_half, second_half = x.float().chunk(2, dim=-1)
    rotated = torch.cat(
        (
            first_half * cos - second_half * sin,
            second_half * cos + first_half * sin,
        ),
        dim=-1,
    )
    return rotated.to(x.dtype)
class RotaryEmbedding(nn.Module):
    """Rotary position embedding backed by a precomputed cos/sin table.

    The table has one row per position up to ``max_position_embeddings``
    and is registered as the non-persistent buffer ``cos_sin_cache``.
    Only full-head rotation is supported (``rotary_dim == head_size``).
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        assert rotary_dim == head_size
        exponents = torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim
        inv_freq = 1.0 / (base**exponents)
        position_ids = torch.arange(max_position_embeddings, dtype=torch.float)
        # angles[p, j] = position p times inverse frequency j
        angles = torch.einsum("i,j -> ij", position_ids, inv_freq)
        # cos values occupy the first half of the last axis, sin the second;
        # the singleton axis inserted at dim 1 broadcasts against the head axis.
        table = torch.cat((angles.cos(), angles.sin()), dim=-1).unsqueeze_(1)
        self.register_buffer("cos_sin_cache", table, persistent=False)

    @maybe_compile
    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(query, key)`` rotated for the given ``positions``."""
        cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)
        rotated_query = apply_rotary_emb(query, cos, sin)
        rotated_key = apply_rotary_emb(key, cos, sin)
        return rotated_query, rotated_key
@lru_cache(1)
def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: float,
    rope_scaling: dict | None = None,
):
    """Build a ``RotaryEmbedding``, reusing the one from the previous identical call.

    ``lru_cache(1)`` memoises the most recent argument combination, so
    repeated calls with the same arguments return the same module instance.
    RoPE scaling is not supported: ``rope_scaling`` must be ``None``.
    """
    # NOTE(review): lru_cache hashes its arguments, so a dict passed as
    # rope_scaling would presumably fail on hashing before this assert runs.
    assert rope_scaling is None
    return RotaryEmbedding(head_size, rotary_dim, max_position, base)

View file

@ -0,0 +1,116 @@
import torch
from torch import nn
from typing import Optional
from nanovllm.utils.compat import maybe_compile
def apply_top_k_top_p(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Mask ``logits`` in place with per-row top-k and/or top-p filters.

    Args:
        logits: 2-D tensor of scores (rows are filtered independently).
        k: Per-row top-k values, or ``None`` to skip top-k.
        p: Per-row top-p values, or ``None`` to skip top-p.

    Returns:
        The same ``logits`` tensor, with filtered entries set to ``-inf``.
    """
    if p is None:
        # Top-k alone does not need a full sort of the vocabulary.
        return logits if k is None else apply_top_k_only(logits, k)

    # Ascending sort: the most likely token ends up in the last column.
    ascending, source_idx = logits.sort(dim=-1, descending=False)

    if k is not None:
        width = ascending.size(1)
        # Column of the k-th largest value, with k clamped into [1, width].
        keep_from = width - k.clamp(1, width).long()
        kth_largest = ascending.gather(1, keep_from.unsqueeze(1))
        ascending.masked_fill_(ascending < kth_largest, float('-inf'))

    # Running probability mass from the least likely token upwards;
    # the softmax buffer is reused for the cumulative sum.
    cumulative = ascending.softmax(dim=-1)
    torch.cumsum(cumulative, dim=-1, out=cumulative)
    outside_nucleus = cumulative <= (1.0 - p.unsqueeze(1))
    # The most likely token always survives.
    outside_nucleus[:, -1] = False
    ascending.masked_fill_(outside_nucleus, float('-inf'))

    # Undo the sort by scattering values back to their original columns.
    logits.scatter_(dim=-1, index=source_idx, src=ascending)
    return logits
def apply_top_k_only(
    logits: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """Mask ``logits`` in place so each row keeps only its top-k values.

    Uses ``topk`` with the largest requested k instead of sorting the whole
    vocabulary. Rows whose k is ``<= 0`` or ``>= vocab size`` are left
    unfiltered.

    Returns:
        The same ``logits`` tensor, with filtered entries set to ``-inf``.
    """
    num_tokens = logits.shape[1]
    # Rows where top-k is effectively disabled.
    unfiltered = (k <= 0) | (k >= num_tokens)
    # Give those rows a dummy k of 1 so the gather below stays in range.
    effective_k = k.masked_fill(unfiltered, 1).long()
    # NOTE: int() forces a CPU-GPU sync, but torch.topk needs a Python int.
    widest_k = int(effective_k.max().clamp(max=num_tokens))

    # Shape [batch_size, widest_k], values in descending order.
    best_values = logits.topk(widest_k, dim=1).values
    # The k-th largest value sits at 0-based column k - 1.
    kth_column = (effective_k - 1).clamp(0, widest_k - 1).unsqueeze(1)
    cutoff = best_values.gather(1, kth_column)
    # A -inf cutoff masks nothing, which disables filtering for those rows.
    cutoff.masked_fill_(unfiltered.unsqueeze(1), float('-inf'))

    logits.masked_fill_(logits < cutoff, float('-inf'))
    return logits
class Sampler(nn.Module):
    """Draws one token id per row from temperature-scaled, filtered logits."""

    def __init__(self):
        super().__init__()

    @maybe_compile
    def forward(
        self,
        logits: torch.Tensor,
        temperatures: torch.Tensor,
        top_ks: Optional[torch.Tensor] = None,
        top_ps: Optional[torch.Tensor] = None,
        repetition_penalties: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
    ):
        """
        Sample tokens from logits with optional top-k and top-p filtering.

        Condition checking is done OUTSIDE the compiled function to avoid
        graph breaks from .any() calls.

        ``repetition_penalties`` and ``input_ids`` are accepted but not read
        by this method.

        Returns:
            A tensor holding one sampled token index per row.
        """
        # Temperature scaling in float32.
        scaled = logits.float().div_(temperatures.unsqueeze(dim=1))
        filtered = apply_top_k_top_p(scaled, top_ks, top_ps)
        probs = torch.softmax(filtered, dim=-1)
        # Divide by exponential noise and take the argmax to draw the sample;
        # the clamp guards against division by zero.
        noise = torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)
        return probs.div_(noise).argmax(dim=-1)

View file

@ -0,0 +1,5 @@
from nanovllm.engine.llm_engine import LLMEngine
class LLM(LLMEngine):
    """Public alias of ``LLMEngine``; adds no behaviour of its own."""

View file

@ -0,0 +1,230 @@
import torch
from torch import nn
import torch.distributed as dist
from transformers import Qwen3Config
from nanovllm.layers.activation import SiluAndMul
from nanovllm.layers.attention import Attention
from nanovllm.layers.layernorm import RMSNorm
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
from nanovllm.layers.rotary_embedding import get_rope
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
from nanovllm import distributed as dist_utils
class Qwen3Attention(nn.Module):
    """Self-attention block: fused QKV projection, rotary embedding, attention, output projection.

    Head counts are divided by the world size. When ``qkv_bias`` is false,
    q and k additionally pass through per-head RMSNorm layers before the
    rotary embedding is applied.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        rope_theta: float = 10000,
        rope_scaling: tuple | None = None,
    ) -> None:
        super().__init__()
        world = dist_utils.get_world_size()

        # Query heads: total and per-rank counts.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world == 0
        self.num_heads = self.total_num_heads // world

        # Key/value heads: total and per-rank counts.
        self.total_num_kv_heads = num_kv_heads
        assert self.total_num_kv_heads % world == 0
        self.num_kv_heads = self.total_num_kv_heads // world

        # Fall back to an even split of the hidden size when no head_dim is given.
        self.head_dim = head_dim or hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim ** -0.5
        self.qkv_bias = qkv_bias

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=qkv_bias,
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=False,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            self.num_kv_heads,
        )
        if not self.qkv_bias:
            self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
            self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Return the attention output for ``hidden_states`` at ``positions``."""
        fused = self.qkv_proj(hidden_states)
        q, k, v = fused.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # Reshape each projection to (tokens, heads, head_dim).
        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        v = v.view(-1, self.num_kv_heads, self.head_dim)

        if not self.qkv_bias:
            q = self.q_norm(q)
            k = self.k_norm(k)

        q, k = self.rotary_emb(positions, q, k)
        attended = self.attn(q, k, v)
        # Merge the head axes back into one feature axis before projecting.
        return self.o_proj(attended.flatten(1, -1))
class Qwen3MLP(nn.Module):
    """Gated feed-forward block: fused gate/up projection, ``SiluAndMul``, down projection.

    Only the ``"silu"`` activation is supported.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ) -> None:
        super().__init__()
        # Gate and up projections share one weight, each intermediate_size wide.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=False,
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=False,
        )
        assert hidden_act == "silu"
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply the feed-forward block to ``x``."""
        fused = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        return self.down_proj(activated)
class Qwen3DecoderLayer(nn.Module):
    """One transformer layer: norm, attention, norm, MLP.

    The residual stream is carried as a separate tensor between layers so
    the residual add can be fused into each ``RMSNorm`` call.
    """

    def __init__(
        self,
        config: Qwen3Config,
    ) -> None:
        super().__init__()
        norm_eps = config.rms_norm_eps
        self.self_attn = Qwen3Attention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position=config.max_position_embeddings,
            rms_norm_eps=norm_eps,
            # Optional config fields fall back to the defaults given here.
            qkv_bias=getattr(config, 'attention_bias', True),
            head_dim=getattr(config, 'head_dim', None),
            rope_theta=getattr(config, "rope_theta", 1000000),
            rope_scaling=getattr(config, "rope_scaling", None),
        )
        self.mlp = Qwen3MLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        ``residual`` is ``None`` for the first layer, in which case the
        incoming hidden states become the residual.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        attn_out = self.self_attn(positions, hidden_states)
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual
class Qwen3Model(nn.Module):
    """Token embedding, a stack of ``Qwen3DecoderLayer`` modules, and a final norm."""

    def __init__(
        self,
        config: Qwen3Config,
    ) -> None:
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
        decoder_layers = [Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Return the final normalised hidden states for ``input_ids``."""
        hidden_states = self.embed_tokens(input_ids)
        # No residual exists before the first layer.
        residual = None
        for decoder_layer in self.layers:
            hidden_states, residual = decoder_layer(positions, hidden_states, residual)
        # The final norm also folds in the last pending residual; only the
        # normalised output is kept.
        normed, _ = self.norm(hidden_states, residual)
        return normed
class Qwen3ForCausalLM(nn.Module):
    """Qwen3 backbone plus LM head.

    ``forward`` returns hidden states only; call ``compute_logits`` to
    project them through the LM head.
    """

    # Checkpoint sub-module name -> (fused module name, shard id), read by
    # the weight loader to route separate checkpoint tensors into fused weights.
    packed_modules_mapping = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: Qwen3Config
    ) -> None:
        super().__init__()
        self.model = Qwen3Model(config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        if config.tie_word_embeddings:
            # Tied embeddings: the head shares storage with the input embedding.
            embedding_weight = self.model.embed_tokens.weight
            self.lm_head.weight.data = embedding_weight.data

    # Proxy attributes for weight loading compatibility:
    # some model weights use "embed_tokens" instead of "model.embed_tokens".
    @property
    def embed_tokens(self):
        """Shortcut to ``self.model.embed_tokens``."""
        backbone = self.model
        return backbone.embed_tokens

    @property
    def layers(self):
        """Shortcut to ``self.model.layers``."""
        backbone = self.model
        return backbone.layers

    @property
    def norm(self):
        """Shortcut to ``self.model.norm``."""
        backbone = self.model
        return backbone.norm

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Return the backbone's hidden states for ``input_ids``."""
        hidden_states = self.model(input_ids, positions)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project ``hidden_states`` through the LM head."""
        logits = self.lm_head(hidden_states)
        return logits

View file

@ -0,0 +1,28 @@
from dataclasses import dataclass, field
from typing import Optional, Callable, Any
@dataclass
class SamplingParams:
    """Per-request sampling configuration, validated on construction."""

    temperature: float = 1.0
    max_tokens: int = 64
    ignore_eos: bool = False
    # CFG guidance scale. When > 1.0, applies classifier-free guidance.
    cfg_scale: float = 1.0
    # Top-k sampling: consider only top k tokens.
    top_k: Optional[int] = None
    # Top-p (nucleus) sampling: consider tokens with cumulative probability <= top_p.
    top_p: Optional[float] = None
    # Repetition penalty: >1.0 reduces repetition, <1.0 increases it.
    repetition_penalty: float = 1.0
    # Optional logits processor for constrained decoding. Should be a callable
    # with signature (input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor.
    logits_processor: Optional[Any] = field(default=None, repr=False)
    # Optional callback to update processor state after each token.
    # Should be a callable with signature (token_id: int) -> None.
    logits_processor_update_state: Optional[Callable[[int], None]] = field(default=None, repr=False)

    def __post_init__(self):
        """Reject out-of-range values; each check raises ``AssertionError``."""
        assert self.temperature > 1e-10, "greedy sampling is not permitted"
        assert self.cfg_scale >= 1.0, "cfg_scale must be >= 1.0"
        # None means the corresponding filter is disabled.
        assert self.top_k is None or self.top_k > 0, "top_k must be > 0"
        assert self.top_p is None or 0.0 < self.top_p <= 1.0, "top_p must be in (0.0, 1.0]"
        assert self.repetition_penalty > 0.0, "repetition_penalty must be > 0.0"

View file

@ -0,0 +1,70 @@
"""Compatibility utilities for optional dependencies.
Provides graceful fallbacks when torch.compile's backend (Triton) is
unavailable e.g. on Windows or on GPU architectures where Triton
has not yet added support (Blackwell / SM 120 as of early 2026).
"""
from typing import Any, Callable, Optional, TypeVar
from loguru import logger
F = TypeVar("F", bound=Callable[..., Any])
# Module-level Triton availability check — runs once at import time
# rather than repeating the import probe at every decoration site.
_HAS_TRITON = False
try:
import triton # noqa: F401
_HAS_TRITON = True
except ImportError:
pass
def maybe_compile(fn: Optional[F] = None, **compile_kwargs: Any) -> Any:
    """Apply ``torch.compile`` only when its backend (Triton) is available.

    Drop-in replacement for the ``@torch.compile`` decorator. When Triton
    is importable the function is compiled as usual; otherwise the original
    function is returned unmodified so inference still works (just without
    the kernel-fusion speed-up).

    Args:
        fn: The function to compile. When used as ``@maybe_compile`` (without
            parentheses) the decorated function is passed directly. When used
            as ``@maybe_compile(...)`` this is ``None`` and a decorator is
            returned instead.
        **compile_kwargs: Keyword arguments forwarded to ``torch.compile``
            (e.g. ``dynamic=True``, ``fullgraph=True``).

    Returns:
        The compiled function when Triton is available, or the original
        unmodified function as a fallback.

    Usage::

        @maybe_compile
        def forward(self, x):
            ...

        # or with keyword arguments:
        @maybe_compile(dynamic=True)
        def forward(self, x):
            ...
    """
    def decorator(func: F) -> F:
        """Inner decorator that performs the actual compile-or-skip logic."""
        if _HAS_TRITON:
            # Imported lazily so this module can be imported without torch.
            import torch
            return torch.compile(func, **compile_kwargs)
        # loguru interpolates positional arguments with str.format-style "{}"
        # placeholders; a printf-style "%s" would be logged literally.
        logger.info(
            "Triton not available — skipping torch.compile for {} "
            "(inference will use native PyTorch kernels)",
            func.__qualname__,
        )
        return func

    # Support both @maybe_compile and @maybe_compile(...) syntax
    if fn is not None:
        return decorator(fn)
    return decorator

View file

@ -0,0 +1,45 @@
"""Unit tests for the ``maybe_compile`` conditional compilation decorator."""
import unittest
from unittest.mock import patch, MagicMock
def _sample_fn(x):
"""Trivial function used as decoration target in tests."""
return x + 1
class MaybeCompileTests(unittest.TestCase):
    """Check that maybe_compile compiles or passes through depending on Triton."""

    def test_compiles_when_triton_available(self):
        """With Triton flagged available, torch.compile receives the function."""
        compiled_stub = MagicMock(name="compiled_fn")
        with patch("nanovllm.utils.compat._HAS_TRITON", True):
            with patch("torch.compile", return_value=compiled_stub) as compile_mock:
                from nanovllm.utils.compat import maybe_compile

                outcome = maybe_compile(_sample_fn)
                compile_mock.assert_called_once_with(_sample_fn)
                self.assertEqual(outcome, compiled_stub)

    def test_returns_original_when_triton_absent(self):
        """With Triton flagged absent, the very same function object comes back."""
        with patch("nanovllm.utils.compat._HAS_TRITON", False):
            from nanovllm.utils.compat import maybe_compile

            outcome = maybe_compile(_sample_fn)
            self.assertIs(outcome, _sample_fn)

    def test_kwargs_syntax_forwards_compile_args(self):
        """The @maybe_compile(dynamic=True) form forwards kwargs to torch.compile."""
        compiled_stub = MagicMock(name="compiled_fn")
        with patch("nanovllm.utils.compat._HAS_TRITON", True):
            with patch("torch.compile", return_value=compiled_stub) as compile_mock:
                from nanovllm.utils.compat import maybe_compile

                # Calling with only keyword arguments yields a decorator.
                wrap = maybe_compile(dynamic=True)
                outcome = wrap(_sample_fn)
                compile_mock.assert_called_once_with(_sample_fn, dynamic=True)
                self.assertEqual(outcome, compiled_stub)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

View file

@ -0,0 +1,47 @@
from dataclasses import dataclass
import threading
import torch
@dataclass
class Context:
    """Per-step forward state that the attention layer reads via ``get_context()``.

    The field order matters: ``set_context`` builds instances positionally.
    """

    # True for prefill steps, False for decode steps (selects the attention path).
    is_prefill: bool = False
    # Cumulative sequence lengths forwarded to the prefill attention kernels.
    cu_seqlens_q: torch.Tensor | None = None
    cu_seqlens_k: torch.Tensor | None = None
    # Maximum sequence lengths forwarded to flash_attn_varlen_func.
    max_seqlen_q: int = 0
    max_seqlen_k: int = 0
    # Passed to store_kvcache together with the new keys/values.
    slot_mapping: torch.Tensor | None = None
    # Passed as the cached sequence lengths on the decode path.
    context_lens: torch.Tensor | None = None
    # Block tables for the paged KV cache; during prefill, a non-None value
    # makes attention read keys/values from the cache (prefix cache).
    block_tables: torch.Tensor | None = None
# Thread-local storage for context.
#
# ROOT CAUSE FIX: The original implementation used a plain module-level global
# `_CONTEXT` variable. In concurrent serving scenarios (API server with
# ThreadPoolExecutor, multiple queue workers, or Gradio with concurrent requests),
# multiple threads can call set_context() / get_context() / reset_context()
# concurrently. This creates a race condition:
#
#   Thread A: set_context(...)  # sets slot_mapping, block_tables for request A
#   Thread B: set_context(...)  # OVERWRITES with request B's data
#   Thread A: run_model(...)    # reads Thread B's context → WRONG KV cache addresses
#                               # → CUDA illegal memory access / device-side assertion
#
# By using threading.local(), each thread gets its own independent Context,
# eliminating the race condition entirely.
#
# The Context itself lives in the `context` attribute of this object; a thread
# that has never set one gets a default Context from get_context().
_THREAD_LOCAL = threading.local()
def get_context():
    """Return the calling thread's ``Context``, creating a default one if absent."""
    existing = getattr(_THREAD_LOCAL, 'context', None)
    if existing is not None:
        return existing
    # First access from this thread: install and return a default Context.
    fresh = Context()
    _THREAD_LOCAL.context = fresh
    return fresh
def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
    """Replace the calling thread's ``Context`` with one built from the arguments."""
    _THREAD_LOCAL.context = Context(
        is_prefill=is_prefill,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        slot_mapping=slot_mapping,
        context_lens=context_lens,
        block_tables=block_tables,
    )
def reset_context():
    """Replace the calling thread's ``Context`` with a default-initialised one."""
    blank = Context()
    _THREAD_LOCAL.context = blank

View file

@ -0,0 +1,70 @@
import os
from glob import glob
import torch
from torch import nn
from safetensors import safe_open
def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
    """Copy ``loaded_weight`` into ``param`` in place; used when a parameter has no loader of its own."""
    destination = param.data
    destination.copy_(loaded_weight)
def _get_parameter_safe(model: nn.Module, weight_name: str):
"""
Try to get parameter from model, handling name mismatches.
Some models have nested structure (e.g., Qwen3ForCausalLM has model.embed_tokens)
but weight files may have flat names (embed_tokens.weight).
"""
# Try direct access first
try:
return model.get_parameter(weight_name)
except AttributeError:
pass
# Try with 'model.' prefix (for nested model structure)
try:
prefixed_name = f"model.{weight_name}"
return model.get_parameter(prefixed_name)
except AttributeError:
pass
# Try removing 'model.' prefix
if weight_name.startswith("model."):
try:
unprefixed_name = weight_name[6:] # Remove 'model.' prefix
return model.get_parameter(unprefixed_name)
except AttributeError:
pass
return None
def load_model(model: nn.Module, path: str):
    """Load every ``*.safetensors`` file under ``path`` into ``model``.

    Weights whose name contains a key of ``model.packed_modules_mapping``
    are routed to the mapped fused parameter and loaded with its
    ``weight_loader`` and shard id. All other weights are loaded by name,
    using the parameter's ``weight_loader`` when it has one and
    ``default_weight_loader`` otherwise. Weights without a matching
    parameter are reported with a warning and skipped.

    Args:
        model: Module to populate.
        path: Directory containing the ``.safetensors`` files.

    Raises:
        FileNotFoundError: If ``path`` contains no ``.safetensors`` file.
    """
    packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
    safetensor_files = glob(os.path.join(path, "*.safetensors"))
    if not safetensor_files:
        raise FileNotFoundError(f"No .safetensors files found in {path}")
    for file in safetensor_files:
        with safe_open(file, "pt", "cpu") as f:
            for weight_name in f.keys():
                for k in packed_modules_mapping:
                    if k in weight_name:
                        v, shard_id = packed_modules_mapping[k]
                        param_name = weight_name.replace(k, v)
                        param = _get_parameter_safe(model, param_name)
                        if param is None:
                            print(f"[loader] Warning: Parameter not found: {param_name}")
                            # break, not continue: continue would only advance this
                            # inner loop, which then falls into the for-else branch,
                            # retries the raw name and warns a second time.
                            break
                        weight_loader = getattr(param, "weight_loader")
                        weight_loader(param, f.get_tensor(weight_name), shard_id)
                        break
                else:
                    # No packed-module key matched: load the weight under its own name.
                    param = _get_parameter_safe(model, weight_name)
                    if param is None:
                        print(f"[loader] Warning: Parameter not found: {weight_name}")
                        continue
                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
                    weight_loader(param, f.get_tensor(weight_name))

View file

@ -0,0 +1,29 @@
[build-system]
requires = ["setuptools>=61"]
build-backend = "setuptools.build_meta"
[project]
name = "nano-vllm"
version = "0.2.0"
authors = [{ name = "Xingkai Yu" }]
license = {text = "MIT"}
readme = "README.md"
description = "a lightweight vLLM implementation built from scratch"
requires-python = ">=3.10,<3.13"
dependencies = [
"torch>=2.4.0",
# Triton and Flash Attention are optional on ROCm (Python 3.12) - SDPA fallback used instead
"triton-windows>=3.0.0,<3.4; sys_platform == 'win32' and python_version == '3.11'",
"triton>=3.0.0; sys_platform == 'linux' and python_version == '3.11'",
"transformers>=4.51.0",
"flash-attn @ https://github.com/sdbds/flash-attention-for-windows/releases/download/2.8.2/flash_attn-2.8.2+cu128torch2.7.1cxx11abiFALSEfullbackward-cp311-cp311-win_amd64.whl ; sys_platform == 'win32' and sys_platform != 'darwin' and python_version == '3.11' and platform_machine == 'AMD64'",
"flash-attn @ https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.7.12/flash_attn-2.8.3+cu128torch2.10-cp311-cp311-linux_x86_64.whl ; sys_platform == 'linux' and sys_platform != 'darwin' and python_version == '3.11' and platform_machine == 'x86_64'",
"xxhash",
]
[project.urls]
Homepage="https://github.com/GeeeekExplorer/nano-vllm"
[tool.setuptools.packages.find]
where = ["."]
include = ["nanovllm*"]

View file

@ -485,7 +485,7 @@ See [ACE-Step1.5-Rocm-Manual-Linux.md](ACE-Step1.5-Rocm-Manual-Linux.md) for a d
| Offload | Disabled by default |
| Compile & Quantization | Enabled by default |
| LLM Inference | Supported (tested with `acestep-5Hz-lm-0.6B`) |
| vllm engine acceleration | NOT supported on Intel GPUs |
| nanovllm acceleration | NOT supported on Intel GPUs |
| Test Environment | PyTorch 2.8.0 from [Intel Extension for PyTorch](https://pytorch-extension.intel.com/?request=platform) |
> **Note:** LLM inference speed may decrease when generating audio longer than 2 minutes. Intel discrete GPUs are expected to work but not yet tested.

View file

@ -375,7 +375,7 @@ python -m acestep.acestep_v15_pipeline --port 7680
| オフロード | デフォルトで無効 |
| コンパイル & 量子化 | デフォルトで有効 |
| LLM 推論 | サポート(`acestep-5Hz-lm-0.6B` でテスト済み) |
| vllm エンジンアクセラレーション | Intel GPU では未サポート |
| nanovllm アクセラレーション | Intel GPU では未サポート |
| テスト環境 | PyTorch 2.8.0[Intel Extension for PyTorch](https://pytorch-extension.intel.com/?request=platform) |
> 注意:2分以上の音声生成時、LLM 推論速度が低下する場合があります。Intel ディスクリート GPU は動作が期待されますが、まだテストされていません。

View file

@ -375,7 +375,7 @@ python -m acestep.acestep_v15_pipeline --port 7680
| 卸载 | 默认禁用 |
| 编译与量化 | 默认启用 |
| LLM 推理 | 支持(已测试 `acestep-5Hz-lm-0.6B`) |
| vllm 引擎加速 | Intel GPU 暂不支持 |
| nanovllm 加速 | Intel GPU 暂不支持 |
| 测试环境 | PyTorch 2.8.0[Intel Extension for PyTorch](https://pytorch-extension.intel.com/?request=platform) |
> 注意:生成超过 2 分钟的音频时,LLM 推理速度可能下降。Intel 独立显卡预计可用但尚未测试。

View file

@ -49,6 +49,9 @@ dependencies = [
"lycoris-lora",
"lightning>=2.0.0",
"tensorboard>=2.0.0",
# Local third-party packages
# nano-vllm source is configured in [tool.uv.sources]
"nano-vllm; sys_platform != 'darwin' or platform_machine != 'arm64'",
"modelscope",
"tensorboard>=2.20.0",
"typer-slim>=0.21.1",
@ -73,6 +76,7 @@ required-environments = [
]
[tool.uv.sources]
nano-vllm = { path = "acestep/third_parts/nano-vllm" }
torch = [
{ index = "pytorch-cu128", marker = "sys_platform == 'win32' or (sys_platform == 'linux' and platform_machine == 'x86_64')" },
{ index = "pytorch-cu130", marker = "sys_platform == 'linux' and platform_machine == 'aarch64'" },

View file

@ -22,3 +22,7 @@ modelscope
peft>=0.18.0
lightning>=2.0.0
tensorboard>=2.0.0
xxhash
# Local package - install with: pip install -e acestep/third_parts/nano-vllm
# nano-vllm

View file

@ -47,3 +47,5 @@ peft>=0.18.0
lightning>=2.0.0
tensorboard>=2.0.0
# nano-vllm tokenizer dependency (needed even with pt backend)
xxhash

View file

@ -48,8 +48,12 @@ tensorboard>=2.0.0
mlx>=0.25.2; sys_platform == 'darwin' and platform_machine == 'arm64'
mlx-lm>=0.20.0; sys_platform == 'darwin' and platform_machine == 'arm64'
# Optional accelerators for LLM inference engine (SDPA fallback if unavailable)
# nano-vllm dependencies
triton-windows>=3.0.0,<3.4; sys_platform == 'win32'
triton>=3.0.0; sys_platform == 'linux'
flash-attn @ https://github.com/sdbds/flash-attention-for-windows/releases/download/2.8.2/flash_attn-2.8.2+cu128torch2.7.1cxx11abiFALSEfullbackward-cp311-cp311-win_amd64.whl ; sys_platform == 'win32' and python_version == '3.11' and platform_machine == 'AMD64'
flash-attn; sys_platform == 'linux' and platform_machine == 'x86_64'
xxhash
# Local package - install with: pip install -e acestep/third_parts/nano-vllm
# nano-vllm

View file

@ -6,7 +6,7 @@ REM IMPORTANT: Requires Python 3.12 (AMD ROCm 7.2 only provides Python 3.12 whee
REM Requires: ROCm PyTorch from repo.radeon.com
REM ==================== ROCm Configuration ====================
REM Force PyTorch LM backend (bypasses vllm engine flash_attn dependency)
REM Force PyTorch LM backend (bypasses nano-vllm flash_attn dependency)
set ACESTEP_LM_BACKEND=pt
REM RDNA3 GPU architecture override (RX 7900 XT/XTX, RX 7800 XT, etc.)

View file

@ -17,7 +17,7 @@ set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# ==================== ROCm Configuration ====================
# Force PyTorch LM backend (bypasses vllm engine flash_attn dependency)
# Force PyTorch LM backend (bypasses nano-vllm flash_attn dependency)
export ACESTEP_LM_BACKEND="pt"
# RDNA3 GPU architecture override (RX 7900 XT/XTX, RX 7800 XT, etc.)

View file

@ -10,7 +10,7 @@ REM Load settings from .env file if it exists
call :LoadEnvFile
REM ==================== ROCm Configuration ====================
REM Force PyTorch LM backend (bypasses vllm engine flash_attn dependency)
REM Force PyTorch LM backend (bypasses nano-vllm flash_attn dependency)
set ACESTEP_LM_BACKEND=pt
REM RDNA3 GPU architecture override (RX 7900 XT/XTX, RX 7800 XT, etc.)

View file

@ -71,7 +71,7 @@ _load_env_file() {
_load_env_file
# ==================== ROCm Configuration ====================
# Force PyTorch LM backend (bypasses vllm engine flash_attn dependency)
# Force PyTorch LM backend (bypasses nano-vllm flash_attn dependency)
export ACESTEP_LM_BACKEND="pt"
# RDNA3 GPU architecture override (RX 7900 XT/XTX, RX 7800 XT, etc.)

View file

@ -14,7 +14,7 @@ REM Ask for manual settings
call :LoadManual
REM ==================== ROCm Configuration ====================
REM Force PyTorch LM backend (bypasses vllm engine flash_attn dependency)
REM Force PyTorch LM backend (bypasses nano-vllm flash_attn dependency)
set ACESTEP_LM_BACKEND=pt
REM RDNA3 GPU architecture override (RX 7900 XT/XTX, RX 7800 XT, etc.)

View file

@ -192,7 +192,7 @@ _load_manual() {
_load_manual
# ==================== ROCm Configuration ====================
# Force PyTorch LM backend (bypasses vllm engine flash_attn dependency)
# Force PyTorch LM backend (bypasses nano-vllm flash_attn dependency)
export ACESTEP_LM_BACKEND="pt"
# RDNA3 GPU architecture override (RX 7900 XT/XTX, RX 7800 XT, etc.)