mirror of
https://github.com/ace-step/ACE-Step-1.5.git
synced 2026-07-02 16:37:04 +00:00
Revert "(feat) Fully customized in house vllm " (#874)
This commit is contained in:
parent
8b58da12a2
commit
89d53791dc
53 changed files with 3121 additions and 1351 deletions
|
|
@ -475,7 +475,7 @@ The `start_gradio_ui_rocm.bat` and `start_api_server_rocm.bat` scripts include a
|
|||
|
||||
```batch
|
||||
REM ==================== ROCm Configuration ====================
|
||||
REM Force PyTorch LM backend (bypasses vllm engine flash_attn dependency)
|
||||
REM Force PyTorch LM backend (bypasses nano-vllm flash_attn dependency)
|
||||
set ACESTEP_LM_BACKEND=pt
|
||||
|
||||
REM RDNA3 GPU architecture override
|
||||
|
|
|
|||
2
.github/codeql-config.yml
vendored
2
.github/codeql-config.yml
vendored
|
|
@ -23,4 +23,6 @@ paths-ignore:
|
|||
# Training modules — local-only, paths validated via safe_path()
|
||||
- acestep/training
|
||||
- acestep/ui
|
||||
# Third-party vendored code
|
||||
- acestep/third_parts
|
||||
- third_party
|
||||
|
|
|
|||
2
.github/copilot-instructions.md
vendored
2
.github/copilot-instructions.md
vendored
|
|
@ -14,7 +14,7 @@ ACE-Step 1.5 is an open-source music foundation model combining a Language Model
|
|||
- **FastAPI + Uvicorn** for REST API server
|
||||
- **uv** for dependency management
|
||||
- **MLX** (Apple Silicon native acceleration, macOS ARM64)
|
||||
- **customized_vllm** (built-in LLM inference engine)
|
||||
- **nano-vllm** (optimized LLM inference, non-macOS ARM64)
|
||||
|
||||
## Multi-Platform Support
|
||||
|
||||
|
|
|
|||
|
|
@ -1,409 +0,0 @@
|
|||
"""Customised vLLM inference engine for single-GPU Qwen3 generation.
|
||||
|
||||
Public API:
|
||||
LLM - High-level generate() interface with optional CFG
|
||||
SamplingParams - Per-request sampling configuration
|
||||
reset_context - Clear thread-local forward state (call on error recovery)
|
||||
"""
|
||||
|
||||
import os
|
||||
import atexit
|
||||
import threading as _threading
|
||||
from collections import deque
|
||||
from copy import copy
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum, auto
|
||||
from itertools import count
|
||||
from time import perf_counter
|
||||
from typing import Optional, Callable, Any
|
||||
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sampling configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class SamplingParams:
    """Per-request configuration for token sampling.

    Field ranges are checked in ``__post_init__``; an out-of-range value
    raises ``AssertionError``.
    """

    temperature: float = 1.0  # must be > 1e-10 (greedy sampling is rejected)
    max_tokens: int = 64  # must be > 0
    ignore_eos: bool = False
    cfg_scale: float = 1.0  # must be >= 1.0
    top_k: Optional[int] = None  # None disables top-k; otherwise must be > 0
    top_p: Optional[float] = None  # None disables top-p; otherwise in (0.0, 1.0]
    repetition_penalty: float = 1.0  # must be > 0.0
    logits_processor: Optional[Any] = field(default=None, repr=False)
    logits_processor_update_state: Optional[Callable[[int], None]] = field(default=None, repr=False)

    def __post_init__(self):
        """Validate field ranges.

        The checks raise ``AssertionError`` explicitly instead of using
        ``assert`` statements, so they still run under ``python -O`` while
        callers keep seeing the same exception type and messages.

        Raises:
            AssertionError: If any field is outside its permitted range.
        """

        def _require(ok, message):
            if not ok:
                raise AssertionError(message)

        _require(self.temperature > 1e-10, "greedy sampling is not permitted")
        _require(self.max_tokens > 0, "max_tokens must be > 0")
        _require(self.cfg_scale >= 1.0, "cfg_scale must be >= 1.0")
        if self.top_k is not None:
            _require(self.top_k > 0, "top_k must be > 0")
        if self.top_p is not None:
            _require(0.0 < self.top_p <= 1.0, "top_p must be in (0.0, 1.0]")
        _require(self.repetition_penalty > 0.0, "repetition_penalty must be > 0.0")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thread-local forward state (replaces separate context.py module)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class ForwardState:
    """Attention metadata shared between the pipeline and transformer layers.

    The pipeline fills a fresh instance before each forward pass: the prefill
    path sets the ``cu_seqlens_*`` / ``max_seqlen_*`` / ``slot_mapping`` fields,
    the decode path sets ``slot_mapping`` / ``context_lens`` / ``block_tables``.
    Fields that a path does not set keep their defaults.
    """
    # True for a prefill forward, False for a single-token decode forward.
    is_prefill: bool = False
    # Prefill: running totals of per-sequence lengths, starting at 0.
    cu_seqlens_q: torch.Tensor | None = None
    cu_seqlens_k: torch.Tensor | None = None
    # Prefill: longest sequence length in the batch.
    max_seqlen_q: int = 0
    max_seqlen_k: int = 0
    # Flat KV-cache positions (block id * block size + offset) to write to.
    slot_mapping: torch.Tensor | None = None
    # Decode: current length of each sequence.
    context_lens: torch.Tensor | None = None
    # Decode: per-sequence cache block ids, padded with -1.
    block_tables: torch.Tensor | None = None
|
||||
|
||||
|
||||
# Thread-local holder: each thread stores its own ForwardState on the ``_fwd`` attribute.
_TLS = _threading.local()
|
||||
|
||||
|
||||
def _get_forward_state() -> ForwardState:
    """Return the calling thread's forward state, creating a default one if absent."""
    try:
        state = _TLS._fwd
    except AttributeError:
        state = None
    if state is None:
        # First access on this thread (or state was cleared): install defaults.
        state = ForwardState()
        _TLS._fwd = state
    return state
|
||||
|
||||
|
||||
def _set_forward_state(is_prefill=False, **kw) -> ForwardState:
    """Install a new forward state for the calling thread and return it.

    Args:
        is_prefill: Whether the upcoming forward pass is a prefill.
        **kw: Remaining ``ForwardState`` fields, passed through unchanged.
    """
    fresh = ForwardState(is_prefill=is_prefill, **kw)
    _TLS._fwd = fresh
    return fresh
|
||||
|
||||
|
||||
def reset_context():
    """Restore the calling thread's forward state to defaults.

    Public entry point; used by llm_inference error recovery.
    """
    setattr(_TLS, "_fwd", ForwardState())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Generation slot (per-sequence state)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class SlotStatus(Enum):
    """Lifecycle state of a GenerationSlot."""
    # Initial state assigned when a slot is constructed.
    PENDING = auto()
    # Cache blocks have been reserved and the slot is being generated.
    ACTIVE = auto()
    # Generation ended (EOS or max_tokens) and cache blocks were released.
    DONE = auto()
|
||||
|
||||
|
||||
class GenerationSlot:
    """Tracks token state, cache blocks, and sampling config for a single sequence."""

    # Process-wide counter that hands out increasing slot ids.
    _id_gen = count()

    def __init__(self, token_ids: list[int], params=None,
                 is_unconditional: bool = False, block_size: int = 256):
        """Create a slot from prompt tokens and sampling parameters.

        Args:
            token_ids: Prompt token ids; copied, so the caller's list is not
                mutated. Must be non-empty (the last token is read).
            params: ``SamplingParams`` for this sequence. ``None`` means a
                fresh default ``SamplingParams()``; the default is built per
                call rather than once at definition time, so no instance is
                shared between slots.
            is_unconditional: Marks the unconditional half of a CFG pair.
            block_size: Number of tokens held by one KV cache block.
        """
        if params is None:
            params = SamplingParams()
        self.slot_id = next(GenerationSlot._id_gen)
        self.block_size = block_size
        self.status = SlotStatus.PENDING
        self.token_ids = copy(token_ids)
        self.last_token = token_ids[-1]
        self.num_tokens = len(token_ids)
        self.prompt_length = len(token_ids)
        self.cache_blocks: list[int] = []
        # Sampling settings are copied out so the slot does not keep `params`.
        self.temperature = params.temperature
        self.max_tokens = params.max_tokens
        self.ignore_eos = params.ignore_eos
        self.cfg_scale = params.cfg_scale
        self.top_k = params.top_k
        self.top_p = params.top_p
        self.repetition_penalty = params.repetition_penalty
        self.is_unconditional = is_unconditional
        # Set by the engine to the other half of a conditional/unconditional pair.
        self.paired_slot: Optional["GenerationSlot"] = None
        self.logits_processor: Optional[Any] = params.logits_processor
        self.logits_processor_update_state: Optional[Callable[[int], None]] = (
            params.logits_processor_update_state
        )

    def __len__(self):
        """Return the total number of tokens (prompt plus generated)."""
        return self.num_tokens

    def __getitem__(self, key):
        """Index or slice the token id list."""
        return self.token_ids[key]

    @property
    def is_finished(self):
        """True once the slot has been marked DONE."""
        return self.status == SlotStatus.DONE

    @property
    def generated_ids(self):
        """Token ids produced after the prompt."""
        return self.token_ids[self.prompt_length:]

    @property
    def required_blocks(self):
        """Number of cache blocks needed for the current tokens (ceil division)."""
        return (self.num_tokens + self.block_size - 1) // self.block_size

    @property
    def tail_block_fill(self):
        """Number of tokens occupying the last required block."""
        return self.num_tokens - (self.required_blocks - 1) * self.block_size

    def push_token(self, token_id: int):
        """Append one token and update the last-token and length bookkeeping."""
        self.token_ids.append(token_id)
        self.last_token = token_id
        self.num_tokens += 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# KV cache pool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class CachePool:
    """Simple block-based KV cache allocator (no prefix caching)."""

    def __init__(self, num_blocks: int, block_size: int):
        """Create a pool of ``num_blocks`` free blocks of ``block_size`` tokens each."""
        self.block_size = block_size
        self.available: deque[int] = deque(range(num_blocks))
        self.total = num_blocks

    def has_capacity(self, num_blocks: int) -> bool:
        """Return True if at least ``num_blocks`` blocks are free."""
        return len(self.available) >= num_blocks

    def reserve(self, slot: "GenerationSlot"):
        """Move ``slot.required_blocks`` free blocks onto ``slot.cache_blocks``.

        Does not check capacity itself; call ``has_capacity`` first.
        """
        for _ in range(slot.required_blocks):
            slot.cache_blocks.append(self.available.popleft())

    def release(self, slot: "GenerationSlot"):
        """Return all of the slot's blocks to the pool (last block first)."""
        for bid in reversed(slot.cache_blocks):
            self.available.append(bid)
        slot.cache_blocks.clear()

    def grow_if_needed(self, slot: "GenerationSlot"):
        """Give the slot one more block when its newest token starts a new block.

        The boundary test is written as ``(len - 1) % block_size == 0`` rather
        than ``len % block_size == 1``: both agree for ``block_size > 1``, but
        the latter can never be true for ``block_size == 1``, which would leave
        decode steps without a block.

        Raises:
            RuntimeError: If a block is needed but the pool is empty.
        """
        length = len(slot)
        starts_new_block = (length - 1) % self.block_size == 0
        if not (starts_new_block and length > slot.prompt_length):
            return
        if not self.available:
            raise RuntimeError(
                f"KV cache exhausted during decode: slot {slot.slot_id} needs a new block "
                f"but 0/{self.total} blocks are available"
            )
        slot.cache_blocks.append(self.available.popleft())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Engine configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class _EngineConfig:
|
||||
model: str
|
||||
max_num_batched_tokens: int = 16384
|
||||
max_num_seqs: int = 512
|
||||
max_model_len: int = 4096
|
||||
gpu_memory_utilization: float = 0.9
|
||||
enforce_eager: bool = False
|
||||
kvcache_block_size: int = 256
|
||||
|
||||
def __post_init__(self):
|
||||
assert os.path.isdir(self.model)
|
||||
self.hf_config = AutoConfig.from_pretrained(self.model)
|
||||
self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
|
||||
assert self.max_num_batched_tokens >= self.max_model_len
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM - public inference API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class LLM:
    """High-level inference engine with optional classifier-free guidance.

    Usage::

        llm = LLM(model="/path/to/model")
        outputs = llm.generate(["Hello world"], SamplingParams(max_tokens=128))
    """

    def __init__(self, model, **kwargs):
        """Build the engine config, inference pipeline, tokenizer and cache pool.

        Args:
            model: Path to a local model directory.
            **kwargs: Optional overrides ``max_num_batched_tokens``,
                ``max_num_seqs``, ``max_model_len``, ``gpu_memory_utilization``,
                ``enforce_eager``, ``kvcache_block_size`` and ``tokenizer``
                (a ready tokenizer to use instead of loading one from ``model``).
        """
        from acestep.customized_vllm.pipeline import InferencePipeline

        cfg = _EngineConfig(
            model=model,
            max_num_batched_tokens=kwargs.get("max_num_batched_tokens", 16384),
            max_num_seqs=kwargs.get("max_num_seqs", 512),
            max_model_len=kwargs.get("max_model_len", 4096),
            gpu_memory_utilization=kwargs.get("gpu_memory_utilization", 0.9),
            enforce_eager=kwargs.get("enforce_eager", False),
            kvcache_block_size=kwargs.get("kvcache_block_size", 256),
        )
        self._cfg = cfg
        # Serialises generate() calls; the slot list and cache pool are shared state.
        self._lock = _threading.Lock()
        self._pipeline = InferencePipeline(
            hf_config=cfg.hf_config, model_path=cfg.model,
            block_size=cfg.kvcache_block_size, max_num_seqs=cfg.max_num_seqs,
            max_num_batched_tokens=cfg.max_num_batched_tokens,
            max_model_len=cfg.max_model_len,
            gpu_memory_utilization=cfg.gpu_memory_utilization,
            enforce_eager=cfg.enforce_eager,
        )
        tok = kwargs.get("tokenizer", None)
        self.tokenizer = tok if tok is not None else AutoTokenizer.from_pretrained(model, use_fast=True)
        self._eos = self.tokenizer.eos_token_id
        self._block_size = cfg.kvcache_block_size
        self._cache = CachePool(self._pipeline._num_cache_blocks, cfg.kvcache_block_size)
        self._active_slots: list[GenerationSlot] = []
        atexit.register(self.exit)

    def exit(self):
        """Shut down the underlying pipeline (also registered with ``atexit``)."""
        self._pipeline.shutdown()

    def reset(self):
        """Release all KV cache blocks (call on error to prevent leaks)."""
        for slot in self._active_slots:
            if slot.cache_blocks:
                self._cache.release(slot)
        self._active_slots.clear()

    # -- Public generate --------------------------------------------------

    def generate(self, prompts, sampling_params, use_tqdm=True, unconditional_prompts=None):
        """Generate completions for a batch of prompts.

        Args:
            prompts: Strings (tokenised here) or pre-tokenised id lists.
            sampling_params: One ``SamplingParams`` for all prompts, or a list
                with one entry per prompt.
            use_tqdm: Show a progress bar.
            unconditional_prompts: Optional per-prompt unconditional inputs
                used when ``cfg_scale > 1.0``.

        Returns list of dicts with ``"text"`` and ``"token_ids"`` keys.
        """
        with self._lock:
            return self._run_generation(prompts, sampling_params, use_tqdm, unconditional_prompts)

    # -- Internal generation logic ----------------------------------------

    def _prepare_slots(self, prompts, sampling_params, unconditional_prompts):
        """Tokenise prompts, create slots, allocate KV cache.

        Raises:
            ValueError: If ``sampling_params`` or ``unconditional_prompts``
                does not have one entry per prompt.
            RuntimeError: If the cache pool cannot hold all prompts.
        """
        if not isinstance(sampling_params, list):
            sampling_params = [sampling_params] * len(prompts)
        if unconditional_prompts is None:
            unconditional_prompts = [None] * len(prompts)
        if len(sampling_params) != len(prompts):
            raise ValueError(
                f"sampling_params length ({len(sampling_params)}) != prompts length ({len(prompts)})"
            )
        if len(unconditional_prompts) != len(prompts):
            raise ValueError(
                f"unconditional_prompts length ({len(unconditional_prompts)}) != prompts length ({len(prompts)})"
            )

        # Loop-invariant, so looked up once.
        bs = getattr(self, '_block_size', 256)
        all_slots = []
        for prompt, sp, uncond in zip(prompts, sampling_params, unconditional_prompts):
            ids = self.tokenizer.encode(prompt) if isinstance(prompt, str) else prompt
            cond = GenerationSlot(ids, sp, block_size=bs)
            if sp.cfg_scale > 1.0:
                # Without an explicit unconditional prompt, reuse the conditional ids.
                u_ids = (self.tokenizer.encode(uncond) if isinstance(uncond, str)
                         else (uncond if uncond is not None else ids))
                uncond_slot = GenerationSlot(u_ids, sp, is_unconditional=True, block_size=bs)
                cond.paired_slot = uncond_slot
                uncond_slot.paired_slot = cond
                all_slots.extend([cond, uncond_slot])
            else:
                all_slots.append(cond)

        total_blocks = sum(s.required_blocks for s in all_slots)
        if not self._cache.has_capacity(total_blocks):
            raise RuntimeError(
                f"Insufficient KV cache: need {total_blocks} blocks, "
                f"have {len(self._cache.available)}/{self._cache.total}"
            )
        for slot in all_slots:
            self._cache.reserve(slot)
            slot.status = SlotStatus.ACTIVE
        self._active_slots = list(all_slots)
        return all_slots

    def _arrange_guidance_batch(self, slots):
        """Order: normal, then CFG conditional, then CFG unconditional."""
        normal = [s for s in slots if s.cfg_scale <= 1.0]
        cond = [s for s in slots if s.cfg_scale > 1.0 and not s.is_unconditional]
        uncond = [s for s in slots if s.is_unconditional]
        return normal + cond + uncond

    def _generation_steps(self, all_slots):
        """Generator yielding (ordered_batch, token_ids, elapsed, is_prefill) per step.

        The first item is the prefill step; each later item is one decode step
        over the slots that are still active.
        """
        ordered = self._arrange_guidance_batch(all_slots)
        t = perf_counter()
        tids = self._pipeline.execute_step(ordered, is_prefill=True)
        yield ordered, tids, perf_counter() - t, True

        while self._active_slots:
            for slot in self._active_slots:
                self._cache.grow_if_needed(slot)
            batch = self._arrange_guidance_batch(self._active_slots)
            t = perf_counter()
            tids = self._pipeline.execute_step(batch, is_prefill=False)
            yield batch, tids, perf_counter() - t, False

    def _finalize_step(self, batch, token_ids, outputs, pbar):
        """Append tokens, check EOS, collect finished outputs.

        A batch is treated as CFG when its first slot has ``cfg_scale > 1.0``
        and a paired slot; the first half then holds conditional slots and the
        second half their unconditional partners, and both receive the same token.
        """
        is_cfg = (len(batch) > 0 and batch[0].cfg_scale > 1.0
                  and batch[0].paired_slot is not None)
        if is_cfg:
            nc = len(batch) // 2
            groups = zip(batch[:nc], batch[nc:])
        else:
            groups = ((slot,) for slot in batch)

        for group, tid in zip(groups, token_ids):
            for s in group:
                s.push_token(tid)
            lead = group[0]
            done = ((not lead.ignore_eos and tid == self._eos) or
                    lead.num_tokens - lead.prompt_length >= lead.max_tokens)
            if not done:
                continue
            for s in group:
                s.status = SlotStatus.DONE
                self._cache.release(s)
                if s in self._active_slots:
                    self._active_slots.remove(s)
            outputs[lead.slot_id] = lead.generated_ids
            # `is not None` rather than truthiness: a tqdm bar is falsy when total == 0.
            if pbar is not None:
                pbar.update(1)

    def _run_generation(self, prompts, sampling_params, use_tqdm, unconditional_prompts):
        """Run prefill and decode to completion and return the decoded outputs.

        Returns an empty list for an empty prompt batch without invoking the
        pipeline. On any ``Exception`` during stepping, cache blocks are
        released before the exception propagates.
        """
        if self._active_slots:
            # Leftovers from an interrupted run; free them before starting.
            self.reset()

        all_slots = self._prepare_slots(prompts, sampling_params, unconditional_prompts)
        if not all_slots:
            return []

        pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True) if use_tqdm else None
        prefill_tps = decode_tps = 0.0
        outputs = {}

        try:
            for batch, token_ids, elapsed, is_pf in self._generation_steps(all_slots):
                elapsed = max(elapsed, 1e-9)  # guard the divisions below
                if is_pf:
                    prefill_tps = sum(len(s) for s in batch) / elapsed
                else:
                    n_cond = sum(1 for s in batch if not s.is_unconditional)
                    decode_tps = n_cond / elapsed

                self._finalize_step(batch, token_ids, outputs, pbar)
                if pbar is not None:
                    pbar.set_postfix(Prefill=f"{int(prefill_tps)}tok/s",
                                     Decode=f"{int(decode_tps)}tok/s")
                if not self._active_slots:
                    break
        except Exception:
            self.reset()
            raise
        finally:
            if pbar is not None:
                pbar.close()

        # Slot ids increase in creation order, so sorting restores prompt order.
        result = [outputs[sid] for sid in sorted(outputs)]
        return [{"text": self.tokenizer.decode(tids), "token_ids": tids} for tids in result]
|
||||
|
|
@ -1,426 +0,0 @@
|
|||
"""Inference pipeline: model loading, KV cache provisioning, CUDA graphs, and forward passes."""
|
||||
|
||||
import torch
|
||||
import sys
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from acestep.customized_vllm.transformer import CausalTransformer, load_weights
|
||||
from acestep.debug_utils import debug_start, debug_end
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Nucleus / top-k filtering (always receives tensors, never None)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _filter_by_top_k(logits, k):
|
||||
"""Top-k filtering without full vocabulary sort."""
|
||||
vocab_size = logits.shape[1]
|
||||
skip = (k <= 0) | (k >= vocab_size)
|
||||
k_safe = k.masked_fill(skip, 1).long()
|
||||
max_k = int(k_safe.max().clamp(max=vocab_size))
|
||||
topk_vals = logits.topk(max_k, dim=1).values
|
||||
thresh = topk_vals.gather(1, (k_safe - 1).clamp(0, max_k - 1).unsqueeze(1))
|
||||
thresh.masked_fill_(skip.unsqueeze(1), float("-inf"))
|
||||
logits.masked_fill_(logits < thresh, float("-inf"))
|
||||
return logits
|
||||
|
||||
|
||||
def _filter_by_nucleus(logits, k, p):
|
||||
"""Combined top-k and nucleus (top-p) filtering.
|
||||
|
||||
Parameters are always tensors (never None).
|
||||
k=0 means skip top-k, p=1.0 means skip top-p.
|
||||
Note: NOT compiled because .any() and int() cause graph breaks in dynamo.
|
||||
"""
|
||||
has_k = (k > 0).any()
|
||||
has_p = (p < 1.0).any()
|
||||
if not has_p and not has_k:
|
||||
return logits
|
||||
if not has_p:
|
||||
return _filter_by_top_k(logits, k)
|
||||
|
||||
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
|
||||
|
||||
if has_k:
|
||||
vocab_size = logits_sort.size(1)
|
||||
k_clamped = k.clamp(1, vocab_size).long()
|
||||
thresh = logits_sort.gather(1, (vocab_size - k_clamped).unsqueeze(1))
|
||||
logits_sort.masked_fill_(logits_sort < thresh, float("-inf"))
|
||||
|
||||
probs_sum = logits_sort.softmax(dim=-1).cumsum_(dim=-1)
|
||||
mask = probs_sum <= (1.0 - p.unsqueeze(1))
|
||||
mask[:, -1] = False
|
||||
logits_sort.masked_fill_(mask, float("-inf"))
|
||||
logits.scatter_(dim=-1, index=logits_idx, src=logits_sort)
|
||||
return logits
|
||||
|
||||
|
||||
def sample_tokens(logits, temperatures, top_ks, top_ps):
    """Sample one token per row after temperature scaling and top-k/top-p filtering.

    Sampling divides the probabilities by exponential noise and takes the
    argmax, the exponential-race form of the Gumbel-max trick.
    """
    scaled = logits.float().div_(temperatures.unsqueeze(1))
    # Filtering acts in place on `scaled`; the return value is not needed.
    _filter_by_nucleus(scaled, top_ks, top_ps)
    probs = torch.softmax(scaled, dim=-1)
    noise = torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)
    return probs.div_(noise).argmax(dim=-1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inference pipeline
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class InferencePipeline:
|
||||
"""Loads a model, provisions KV cache, captures CUDA graphs, and runs forward passes."""
|
||||
|
||||
def __init__(self, hf_config, model_path: str, block_size: int, max_num_seqs: int,
             max_num_batched_tokens: int, max_model_len: int, gpu_memory_utilization: float,
             enforce_eager: bool):
    """Load the model on CUDA device 0 and prepare it for inference.

    Steps, in order: pick the compute dtype, build the model and load its
    weights, allocate transfer buffers, run a warmup prefill, allocate the
    KV cache, and (unless ``enforce_eager``) build the execution graphs.

    Args:
        hf_config: Model config object; ``dtype`` / ``torch_dtype`` is read from it.
        model_path: Location passed to ``load_weights``.
        block_size: Tokens per KV cache block.
        max_num_seqs: Size of the per-sequence transfer buffers.
        max_num_batched_tokens: Token budget used to size the warmup batch.
        max_model_len: Maximum sequence length.
        gpu_memory_utilization: Fraction of total GPU memory targeted when
            sizing the KV cache.
        enforce_eager: Skip graph compilation and always run the model eagerly.
    """
    # NOTE(review): these are process-wide torch._dynamo settings, not per-instance.
    torch._dynamo.config.capture_scalar_outputs = True
    torch._dynamo.config.verbose = True
    self.block_size = block_size
    self.enforce_eager = enforce_eager
    self.max_num_seqs = max_num_seqs
    self.max_model_len = max_model_len
    self.max_num_batched_tokens = max_num_batched_tokens
    self.gpu_memory_utilization = gpu_memory_utilization
    self.hf_config = hf_config

    torch.cuda.set_device(0)
    # Remembered so the global default dtype can be restored in `finally`.
    saved_dtype = torch.get_default_dtype()

    gpu_props = torch.cuda.get_device_properties(0)
    # bfloat16 is only used on compute capability 8.0 or newer.
    bf16_ok = (gpu_props.major, gpu_props.minor) >= (8, 0)
    raw = getattr(hf_config, "dtype", getattr(hf_config, "torch_dtype", None))
    if isinstance(raw, str):
        # Accept names such as "bfloat16" or "torch.bfloat16"; unknown names become None.
        _map = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}
        raw = _map.get(raw.replace("torch.", ""), None)
    # Use the configured floating dtype when valid, otherwise bf16/fp16 by capability.
    self.dtype = (raw if isinstance(raw, torch.dtype) and raw.is_floating_point else
                  torch.bfloat16 if bf16_ok else torch.float16)
    if self.dtype == torch.bfloat16 and not bf16_ok:
        # Config asked for bfloat16 but the GPU lacks it: fall back to float16.
        self.dtype = torch.float16

    # Temporarily make new tensors default to the chosen dtype on the GPU.
    torch.set_default_dtype(self.dtype)
    torch.set_default_device("cuda")
    try:
        self.model = CausalTransformer(hf_config)
        _t = debug_start("load_model", prefix="tensor.vllm")
        load_weights(self.model, model_path)
        debug_end("load_model", _t, prefix="tensor.vllm")

        self._init_transfer_buffers()
        self._warmup_pipeline()
        # Sized after warmup, from the memory figures measured at that point.
        self._provision_kv_storage()
        if not enforce_eager:
            self._compile_execution_graphs()
    finally:
        # NOTE(review): the default device is reset to "cpu", not to its previous value.
        torch.set_default_device("cpu")
        torch.set_default_dtype(saved_dtype)
|
||||
|
||||
# -- Transfer buffers ------------------------------------------------
|
||||
|
||||
def _init_transfer_buffers(self):
    """Pre-allocate pinned CPU buffers used to shuttle data to GPU.

    Creates ``self._xfer``, a dict of 1-D zero tensors of length
    ``max_num_seqs`` keyed by purpose: sampling parameters (``temps``,
    ``guidance``, ``top_k``, ``top_p``, ``rep_pen``) and decode inputs
    (``token_ids``, ``positions``, ``slots``, ``ctx_lens``).
    """
    capacity = self.max_num_seqs

    def pinned(dtype):
        # One pinned host buffer with an entry per sequence.
        return torch.zeros(capacity, dtype=dtype, device="cpu", pin_memory=True)

    # Dict-based buffer storage (structurally different from per-attribute style)
    self._xfer = {
        "temps": pinned(torch.float32),
        "guidance": pinned(torch.float32),
        "top_k": pinned(torch.int32),
        "top_p": pinned(torch.float32),
        "rep_pen": pinned(torch.float32),
        "token_ids": pinned(torch.int64),
        "positions": pinned(torch.int64),
        "slots": pinned(torch.int32),
        "ctx_lens": pinned(torch.int32),
    }
|
||||
|
||||
# -- Warmup & KV storage ---------------------------------------------
|
||||
|
||||
def _warmup_pipeline(self):
    """Run one dummy prefill of maximum-length sequences, then clear state.

    The batch holds as many ``max_model_len`` sequences as fit in
    ``max_num_batched_tokens``, capped at ``max_num_seqs``. CUDA peak-memory
    statistics are reset beforehand and the allocator cache is emptied
    before and after.
    """
    from acestep.customized_vllm import GenerationSlot, reset_context

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    seq_len = self.max_model_len
    batch_size = min(self.max_num_batched_tokens // seq_len, self.max_num_seqs)
    dummies = []
    for _ in range(batch_size):
        dummies.append(GenerationSlot([0] * seq_len, block_size=self.block_size))

    self._execute_prefill(dummies)
    reset_context()
    torch.cuda.empty_cache()
|
||||
|
||||
def _provision_kv_storage(self):
    """Size and allocate the KV cache, then attach it to the model's layers.

    Sets ``self._num_cache_blocks`` (at least 1) and ``self._kv_storage``, and
    assigns per-layer views to every module that has both ``k_cache`` and
    ``v_cache`` attributes. The ``MAX_CUDA_VRAM`` environment variable (in GiB)
    can lower the total memory used in the calculation.
    """
    _t = debug_start("allocate_kv_cache", prefix="tensor.vllm")
    hf = self.hf_config
    free, total = torch.cuda.mem_get_info()
    current = torch.cuda.memory_stats()["allocated_bytes.all.current"]

    import os
    sim = os.environ.get("MAX_CUDA_VRAM")
    if sim:
        try:
            cap = float(sim) * 1024**3
            # Only ever lowers the real total; larger values are ignored.
            if cap < total:
                total = int(cap)
                free = max(0, total - torch.cuda.memory_reserved())
        except (ValueError, TypeError):
            # An unparsable value is ignored and the real device figures are used.
            pass

    num_kv_heads = hf.num_key_value_heads
    head_dim = getattr(hf, "head_dim", hf.hidden_size // hf.num_attention_heads)
    # Bytes for one block across all layers; the factor 2 covers keys and values.
    block_bytes = 2 * hf.num_hidden_layers * self.block_size * num_kv_heads * head_dim * self.dtype.itemsize

    target = total * self.gpu_memory_utilization
    # Smallest of three budgets: 90% of free memory, the utilisation target minus
    # what is already allocated, and 90% of free memory after holding back 1 GiB.
    avail = min(free * 0.9, target - current, max(0, free - 1024**3) * 0.9)
    if avail <= 0:
        # Fallback when every budget is exhausted: use half of the free memory.
        avail = free * 0.5

    self._num_cache_blocks = max(1, int(avail) // block_bytes)
    # `cap` is reused here as the token capacity, for the log line only.
    cap = self._num_cache_blocks * self.block_size
    gb = self._num_cache_blocks * block_bytes / 1024**3
    logger.info(f"[customized_vllm] KV cache: {self._num_cache_blocks} blocks, "
                f"{cap} tokens, {gb:.2f} GB")

    # No dtype/device given: relies on the torch defaults in effect at call time.
    self._kv_storage = torch.empty(
        2, hf.num_hidden_layers, self._num_cache_blocks,
        self.block_size, num_kv_heads, head_dim,
    )
    layer_id = 0
    for m in self.model.modules():
        if hasattr(m, "k_cache") and hasattr(m, "v_cache"):
            # Index 0 holds keys, index 1 holds values; one slice per such module.
            m.k_cache = self._kv_storage[0, layer_id]
            m.v_cache = self._kv_storage[1, layer_id]
            layer_id += 1
    debug_end("allocate_kv_cache", _t, prefix="tensor.vllm")
|
||||
|
||||
# -- Input preparation -----------------------------------------------
|
||||
|
||||
def _build_cache_index(self, slots):
    """Return an int32 CUDA tensor of each slot's cache block ids, padded with -1.

    Every row is padded to the length of the longest ``cache_blocks`` list.
    """
    width = max(len(s.cache_blocks) for s in slots)
    table = []
    for s in slots:
        padding = [-1] * (width - len(s.cache_blocks))
        table.append(s.cache_blocks + padding)
    host = torch.tensor(table, dtype=torch.int32, pin_memory=True)
    return host.cuda(non_blocking=True)
|
||||
|
||||
def _execute_prefill(self, slots):
    """Prepare prefill inputs and run model forward, returning logits.

    All sequences are concatenated into one flat token list; cumulative
    length offsets mark the sequence boundaries. Slots without cache blocks
    contribute no slot-mapping entries.
    """
    from acestep.customized_vllm import _set_forward_state

    flat_ids, flat_pos = [], []
    offsets_q, offsets_k = [0], [0]
    longest_q = longest_k = 0
    cache_positions = []
    block = self.block_size

    for seq in slots:
        length = len(seq)
        flat_ids.extend(seq.token_ids)
        flat_pos.extend(range(length))
        offsets_q.append(offsets_q[-1] + length)
        offsets_k.append(offsets_k[-1] + length)
        longest_q = max(length, longest_q)
        longest_k = max(length, longest_k)
        if seq.cache_blocks:
            n_blocks = seq.required_blocks
            for i in range(n_blocks):
                first = seq.cache_blocks[i] * block
                # Only the last block may be partially filled.
                span = seq.tail_block_fill if i == n_blocks - 1 else block
                cache_positions.extend(range(first, first + span))

    def to_gpu(values, dtype):
        # Stage in pinned host memory, then copy asynchronously to the GPU.
        return torch.tensor(values, dtype=dtype, pin_memory=True).cuda(non_blocking=True)

    input_ids = to_gpu(flat_ids, torch.int64)
    positions = to_gpu(flat_pos, torch.int64)
    cu_q = to_gpu(offsets_q, torch.int32)
    cu_k = to_gpu(offsets_k, torch.int32)
    slot_mapping = to_gpu(cache_positions, torch.int32)
    _set_forward_state(True, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k,
                       max_seqlen_q=longest_q, max_seqlen_k=longest_k,
                       slot_mapping=slot_mapping)
    return self._forward_pass(input_ids, positions, is_prefill=True)
|
||||
|
||||
def _execute_autoregressive(self, slots):
    """Prepare single-token decode inputs and run model forward, returning logits.

    Each slot contributes its last token, that token's position, its current
    length, and the flat cache position where the token's KV entry is written.
    """
    from acestep.customized_vllm import _set_forward_state

    n_seqs = len(slots)
    buffers = self._xfer
    tok_buf, pos_buf = buffers["token_ids"], buffers["positions"]
    len_buf, slot_buf = buffers["ctx_lens"], buffers["slots"]

    for row, seq in enumerate(slots):
        length = len(seq)
        tok_buf[row] = seq.last_token
        pos_buf[row] = length - 1
        len_buf[row] = length
        # Last block's base offset plus the zero-based index within that block.
        slot_buf[row] = seq.cache_blocks[-1] * self.block_size + seq.tail_block_fill - 1

    input_ids = tok_buf[:n_seqs].cuda(non_blocking=True)
    positions = pos_buf[:n_seqs].cuda(non_blocking=True)
    slot_mapping = slot_buf[:n_seqs].cuda(non_blocking=True)
    context_lens = len_buf[:n_seqs].cuda(non_blocking=True)
    block_tables = self._build_cache_index(slots)
    _set_forward_state(False, slot_mapping=slot_mapping, context_lens=context_lens,
                       block_tables=block_tables)
    return self._forward_pass(input_ids, positions, is_prefill=False)
|
||||
|
||||
# -- Model forward ---------------------------------------------------
|
||||
|
||||
@torch.inference_mode()
def _forward_pass(self, input_ids, positions, is_prefill):
    """Run the model and return vocabulary logits.

    Runs eagerly for prefill, when ``enforce_eager`` is set, when the batch
    has more than 512 rows, or when the forward-state tensors do not fit the
    pre-built graph buffers. Otherwise copies the inputs into the graph I/O
    buffers and replays the smallest pre-built graph that fits the batch.
    """
    from acestep.customized_vllm import _get_forward_state
    if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
        return self.model.project_to_vocab(self.model(input_ids, positions))

    bs = input_ids.size(0)
    state = _get_forward_state()
    gio = self._graph_io
    max_cols = gio["block_tables"].size(1)
    # Fall back to eager execution if the state cannot be copied into the buffers.
    if (state.block_tables.size(1) > max_cols or state.block_tables.size(0) != bs
            or state.slot_mapping.size(0) != bs or state.context_lens.size(0) != bs):
        return self.model.project_to_vocab(self.model(input_ids, positions))

    # First compiled size that is >= bs.
    # NOTE(review): presumably _compiled_sizes is sorted ascending — confirm.
    graph = self._graphs[next(x for x in self._compiled_sizes if x >= bs)]
    gio["input_ids"][:bs] = input_ids
    gio["positions"][:bs] = positions
    # Rows beyond bs are reset so stale values from earlier replays are not reused.
    gio["slot_mapping"].fill_(-1)
    gio["slot_mapping"][:bs] = state.slot_mapping
    gio["context_lens"].zero_()
    gio["context_lens"][:bs] = state.context_lens
    gio["block_tables"][:bs].fill_(-1)
    gio["block_tables"][:bs, :state.block_tables.size(1)] = state.block_tables
    graph.replay()
    return self.model.project_to_vocab(gio["outputs"][:bs])
|
||||
|
||||
# -- Sampling helpers ------------------------------------------------
|
||||
|
||||
def _gather_sampling_config(self, slots, is_cfg):
    """Pack per-slot sampling parameters into GPU tensors.

    For a CFG batch only the first half of ``slots`` (the conditional slots)
    is read. Unset values are encoded as neutral numbers (top_k 0, top_p 1.0,
    repetition penalty 1.0), so the result is always tensors, never None,
    giving `sample_tokens` a stable signature for torch.compile.

    Returns:
        Tuple ``(temperatures, cfg_scales, top_ks, top_ps, repetition_penalties)``.
    """
    chosen = slots[:len(slots) // 2] if is_cfg else slots
    buffers = self._xfer
    for row, seq in enumerate(chosen):
        buffers["temps"][row] = seq.temperature
        buffers["guidance"][row] = seq.cfg_scale
        buffers["top_k"][row] = seq.top_k or 0
        buffers["top_p"][row] = seq.top_p or 1.0
        buffers["rep_pen"][row] = seq.repetition_penalty or 1.0

    count = len(chosen)
    ordered_keys = ("temps", "guidance", "top_k", "top_p", "rep_pen")
    return tuple(buffers[key][:count].cuda(non_blocking=True) for key in ordered_keys)
|
||||
|
||||
def _constrain_logits(self, logits, slots):
|
||||
"""Apply per-slot logits processors.
|
||||
|
||||
Each slot with a non-None logits_processor is processed individually
|
||||
using its own token history. In typical ACE-Step usage all batch
|
||||
slots share the same processor instance, but this handles the general
|
||||
case correctly.
|
||||
"""
|
||||
if not slots:
|
||||
return logits
|
||||
try:
|
||||
for i, slot in enumerate(slots):
|
||||
if slot.logits_processor is None:
|
||||
continue
|
||||
ids_t = torch.tensor([slot.token_ids], device=logits.device)
|
||||
processed = slot.logits_processor(ids_t, logits[i:i+1].clone())
|
||||
logits[i] = processed[0]
|
||||
except TypeError:
|
||||
import traceback
|
||||
logger.error(f"TypeError in _constrain_logits "
|
||||
f"(n_slots={len(slots)}, slot_idx={i}, "
|
||||
f"processor_state={getattr(slot.logits_processor, 'state', '?')}):\n"
|
||||
f"{traceback.format_exc()}")
|
||||
raise
|
||||
return logits
|
||||
|
||||
def _penalize_repetitions(self, logits, slots, penalties):
|
||||
if penalties is None:
|
||||
return logits
|
||||
for i, slot in enumerate(slots):
|
||||
p = penalties[i].item()
|
||||
if p == 1.0:
|
||||
continue
|
||||
comp = torch.tensor(slot.generated_ids, device=logits.device)
|
||||
if len(comp) == 0:
|
||||
continue
|
||||
mask = torch.zeros(logits.shape[1], dtype=torch.bool, device=logits.device)
|
||||
mask[comp] = True
|
||||
penalized = torch.where(logits[i] < 0, logits[i] * p, logits[i] / p)
|
||||
logits[i] = torch.where(mask, penalized, logits[i])
|
||||
return logits
|
||||
|
||||
# -- Main step -------------------------------------------------------
|
||||
|
||||
def execute_step(self, slots, is_prefill):
    """Run one forward pass plus sampling and return the sampled token IDs.

    Classifier-free guidance is active when the first slot has
    ``cfg_scale > 1.0`` and a ``paired_slot``. In that case the first half
    of ``slots`` / logits is the conditional half and the second half the
    unconditional one; only the conditional half is penalised, constrained,
    sampled and notified, and the returned list has one id per conditional
    slot. Otherwise every slot is sampled.

    The forward context is reset after the model call whether or not it
    raised. A ``TypeError`` anywhere in the step is logged and re-raised.
    """
    from acestep.customized_vllm import reset_context
    try:
        is_cfg = slots[0].cfg_scale > 1.0 and slots[0].paired_slot is not None
        try:
            if is_prefill:
                logits = self._execute_prefill(slots)
            else:
                logits = self._execute_autoregressive(slots)
        finally:
            reset_context()
        temps, cfg_s, topk, topp, rep_pen = self._gather_sampling_config(slots, is_cfg)

        if is_cfg:
            half = len(slots) // 2
            sampled_slots = slots[:half]
            uncond = logits[half:]
            # The penalty is applied to the conditional logits only.
            cond = self._penalize_repetitions(logits[:half], sampled_slots, rep_pen)
            guided = uncond + cfg_s.unsqueeze(1) * (cond - uncond)
            guided = torch.nan_to_num(guided, nan=float('-inf'))
            final_logits = self._constrain_logits(guided, sampled_slots)
        else:
            sampled_slots = slots
            logits = self._penalize_repetitions(logits, slots, rep_pen)
            final_logits = self._constrain_logits(logits.clone(), slots)

        tids = sample_tokens(final_logits, temps, topk, topp).tolist()
        # Let stateful logits processors observe the token just sampled.
        for i, s in enumerate(sampled_slots):
            if s.logits_processor_update_state:
                s.logits_processor_update_state(tids[i])
        return tids
    except TypeError:
        import traceback
        logger.error(f"TypeError in execute_step "
                     f"(prefill={is_prefill}, n_slots={len(slots)}):\n"
                     f"{traceback.format_exc()}")
        raise
|
||||
|
||||
# -- CUDA graph capture ----------------------------------------------
|
||||
|
||||
@torch.inference_mode()
def _compile_execution_graphs(self):
    """Capture one CUDA graph of the model forward pass per supported batch size.

    Allocates static input/output buffers sized for the largest batch, then
    for each batch size runs the model once eagerly and once under
    ``torch.cuda.graph`` capture. Captured graphs are stored in
    ``self._graphs`` keyed by batch size, the sorted list of sizes in
    ``self._compiled_sizes``, and the shared buffers in ``self._graph_io``.
    """
    from acestep.customized_vllm import _set_forward_state, reset_context
    _t = debug_start("capture_cudagraph", prefix="tensor.vllm")
    # The largest captured batch size is capped at 512.
    max_bs = min(self.max_num_seqs, 512)
    # Ceiling division: block-table columns needed to cover max_model_len tokens.
    max_blocks = (self.max_model_len + self.block_size - 1) // self.block_size
    # NOTE(review): no device is passed to torch.zeros, so these buffers use
    # the process-wide default device (and default dtype for `out`) —
    # presumably the caller has set CUDA as the default; confirm.
    ids = torch.zeros(max_bs, dtype=torch.int64)
    pos = torch.zeros(max_bs, dtype=torch.int64)
    sm = torch.zeros(max_bs, dtype=torch.int32)
    cl = torch.zeros(max_bs, dtype=torch.int32)
    bt = torch.zeros(max_bs, max_blocks, dtype=torch.int32)
    out = torch.zeros(max_bs, self.hf_config.hidden_size)
    # Sizes 1, 2, 4, 8, every multiple of 16 up to max_bs, and max_bs itself.
    self._compiled_sizes = sorted(set([1, 2, 4, 8] + list(range(16, max_bs + 1, 16)) + [max_bs]))
    self._graphs = {}
    pool = None
    # Capture from the largest size down; the memory pool of the first
    # captured graph is handed to every later capture.
    for bs in reversed(self._compiled_sizes):
        g = torch.cuda.CUDAGraph()
        # First positional argument is presumably the is_prefill flag — TODO confirm.
        _set_forward_state(False, slot_mapping=sm[:bs], context_lens=cl[:bs], block_tables=bt[:bs])
        # Eager run before capture (presumably a warm-up — TODO confirm).
        out[:bs] = self.model(ids[:bs], pos[:bs])
        with torch.cuda.graph(g, pool):
            out[:bs] = self.model(ids[:bs], pos[:bs])
        if pool is None:
            pool = g.pool()
        self._graphs[bs] = g
        torch.cuda.synchronize()
        reset_context()
    # Replay reads from / writes to slices of these same tensors.
    self._graph_io = dict(input_ids=ids, positions=pos, slot_mapping=sm,
                          context_lens=cl, block_tables=bt, outputs=out)
    debug_end("capture_cudagraph", _t, prefix="tensor.vllm")
|
||||
|
||||
def shutdown(self):
    """Drop captured CUDA graphs (if any) and wait for outstanding GPU work.

    Idempotent: calling it twice, or before graph capture has ever run,
    no longer raises ``AttributeError`` and still reaches the final
    ``torch.cuda.synchronize()``.
    """
    if not self.enforce_eager:
        for attr in ("_graphs", "_graph_io"):
            try:
                delattr(self, attr)
            except AttributeError:
                # Already released, or capture never ran: nothing to free.
                pass
    torch.cuda.synchronize()
|
||||
|
|
@ -1,416 +0,0 @@
|
|||
"""Causal transformer architecture with paged KV cache attention.
|
||||
|
||||
Implements a Qwen3-compatible transformer with:
|
||||
- Fused QKV and gate/up projections for efficient weight loading
|
||||
- Paged KV cache with Flash Attention / SDPA backends
|
||||
- RoPE position encoding with compiled forward passes
|
||||
- Safetensors weight loading with shard-aware mapping
|
||||
"""
|
||||
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from glob import glob
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from safetensors import safe_open
|
||||
from transformers import Qwen3Config
|
||||
|
||||
_HAS_TRITON = False
|
||||
_HAS_FLASH_ATTN = False
|
||||
|
||||
try:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
_HAS_TRITON = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||
_HAS_FLASH_ATTN = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# KV cache write operations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if _HAS_TRITON:
    # Triton kernel that copies one token's flattened key and value vectors
    # (D elements each) into that token's slot of the KV cache. Only defined
    # when Triton imported successfully.
    @triton.jit
    def _triton_kv_write(
        key_ptr, key_stride, value_ptr, value_stride,
        k_cache_ptr, v_cache_ptr, slot_mapping_ptr, D: tl.constexpr,
    ):
        # Program id 0 selects which token this instance handles.
        idx = tl.program_id(0)
        slot = tl.load(slot_mapping_ptr + idx)
        # A slot of -1 means "do not store this token".
        if slot == -1:
            return
        offs = tl.arange(0, D)
        # Source rows are addressed with the supplied per-token stride;
        # destination rows are addressed as slot * D, i.e. the cache is
        # treated as a dense array of D-element slots.
        tl.store(k_cache_ptr + slot * D + offs, tl.load(key_ptr + idx * key_stride + offs))
        tl.store(v_cache_ptr + slot * D + offs, tl.load(value_ptr + idx * value_stride + offs))
|
||||
|
||||
|
||||
def _torch_kv_write(key, value, k_cache, v_cache, slot_mapping):
|
||||
N, num_kv_heads, head_dim = key.shape
|
||||
D = num_kv_heads * head_dim
|
||||
valid = slot_mapping != -1
|
||||
slots = slot_mapping[valid]
|
||||
k_cache.reshape(-1, D)[slots] = key.reshape(N, D)[valid]
|
||||
v_cache.reshape(-1, D)[slots] = value.reshape(N, D)[valid]
|
||||
|
||||
|
||||
def write_kv_cache(key, value, k_cache, v_cache, slot_mapping):
    """Persist key/value tensors into the paged KV cache.

    Dispatches to the Triton kernel only when Triton is importable *and*
    the tensors are on a CUDA device (a Triton launch cannot take CPU
    tensors). In every other case the PyTorch indexing fallback is used.

    Args:
        key, value: (tokens, kv_heads, head_dim) tensors for the new tokens.
        k_cache, v_cache: cache tensors to write into.
        slot_mapping: per-token destination slot, -1 to skip a token.
    """
    if _HAS_TRITON and key.is_cuda:
        n_tokens, num_heads, head_dim = key.shape
        # One kernel program per token.
        _triton_kv_write[(n_tokens,)](
            key, key.stride(0), value, value.stride(0),
            k_cache, v_cache, slot_mapping, num_heads * head_dim,
        )
    else:
        _torch_kv_write(key, value, k_cache, v_cache, slot_mapping)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SDPA fallback implementations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _sdpa_packed_prefill(q, k, v, cu_q, cu_k, scale, n_heads, n_kv):
|
||||
results = []
|
||||
gqa = n_heads != n_kv
|
||||
for i in range(cu_q.shape[0] - 1):
|
||||
qs, qe = cu_q[i].item(), cu_q[i + 1].item()
|
||||
ks, ke = cu_k[i].item(), cu_k[i + 1].item()
|
||||
qi = q[qs:qe].unsqueeze(0).transpose(1, 2)
|
||||
ki = k[ks:ke].unsqueeze(0).transpose(1, 2)
|
||||
vi = v[ks:ke].unsqueeze(0).transpose(1, 2)
|
||||
oi = F.scaled_dot_product_attention(qi, ki, vi, scale=scale, is_causal=True, enable_gqa=gqa)
|
||||
results.append(oi.transpose(1, 2).squeeze(0))
|
||||
return torch.cat(results, dim=0)
|
||||
|
||||
|
||||
def _sdpa_cached_decode(q, k_cache, v_cache, ctx_lens, block_tbl, scale, n_heads, n_kv):
|
||||
blk_sz = k_cache.shape[1]
|
||||
results = []
|
||||
gqa = n_heads != n_kv
|
||||
for i in range(q.shape[0]):
|
||||
cl = ctx_lens[i].item()
|
||||
nb = (cl + blk_sz - 1) // blk_sz
|
||||
idx = block_tbl[i, :nb]
|
||||
ki = k_cache[idx].reshape(-1, n_kv, k_cache.shape[-1])[:cl]
|
||||
vi = v_cache[idx].reshape(-1, n_kv, v_cache.shape[-1])[:cl]
|
||||
qi = q[i].unsqueeze(0).transpose(1, 2)
|
||||
ki = ki.unsqueeze(0).transpose(1, 2)
|
||||
vi = vi.unsqueeze(0).transpose(1, 2)
|
||||
oi = F.scaled_dot_product_attention(qi, ki, vi, scale=scale, is_causal=False, enable_gqa=gqa)
|
||||
results.append(oi.transpose(1, 2).squeeze(0))
|
||||
return torch.stack(results, dim=0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attention with paged KV cache
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class PagedAttention(nn.Module):
    """Attention over a paged KV cache, backed by Flash Attention or SDPA.

    ``k_cache`` / ``v_cache`` start as the same empty placeholder tensor;
    the cache write in ``forward`` is skipped while either is empty.
    """

    def __init__(self, num_heads, head_dim, num_kv_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_kv_heads = num_kv_heads
        self._scale = head_dim ** -0.5
        placeholder = torch.tensor([])
        self.k_cache = placeholder
        self.v_cache = placeholder

    def forward(self, q, k, v):
        """Store ``k``/``v`` in the cache (if allocated) and attend with ``q``."""
        from acestep.customized_vllm import _get_forward_state
        state = _get_forward_state()

        cache_allocated = self.k_cache.numel() and self.v_cache.numel()
        if cache_allocated:
            write_kv_cache(k, v, self.k_cache, self.v_cache, state.slot_mapping)

        backend = self._flash_path if _HAS_FLASH_ATTN else self._sdpa_path
        return backend(q, k, v, state)

    def _flash_path(self, q, k, v, state):
        """Flash Attention: varlen kernel for prefill, kv-cache kernel for decode."""
        if not state.is_prefill:
            return flash_attn_with_kvcache(
                q.unsqueeze(1), self.k_cache, self.v_cache,
                cache_seqlens=state.context_lens, block_table=state.block_tables,
                softmax_scale=self._scale, causal=True,
            )
        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=state.cu_seqlens_q, cu_seqlens_k=state.cu_seqlens_k,
            max_seqlen_q=state.max_seqlen_q, max_seqlen_k=state.max_seqlen_k,
            softmax_scale=self._scale, causal=True,
        )

    def _sdpa_path(self, q, k, v, state):
        """SDPA fallback: packed prefill or per-sequence cached decode."""
        if not state.is_prefill:
            return _sdpa_cached_decode(
                q.unsqueeze(1), self.k_cache, self.v_cache,
                state.context_lens, state.block_tables,
                self._scale, self.num_heads, self.num_kv_heads,
            )
        return _sdpa_packed_prefill(
            q, k, v, state.cu_seqlens_q, state.cu_seqlens_k,
            self._scale, self.num_heads, self.num_kv_heads,
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core neural-network layers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class NormLayer(nn.Module):
    """Root-mean-square layer normalisation with optional fused residual add.

    The computation runs in float32 and the result is cast back to the
    input dtype. Inputs are never modified: with float32 inputs
    ``x.float()`` returns ``x`` itself, so all arithmetic on that tensor is
    done out of place. This also keeps the returned residual independent of
    the returned normalised tensor.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    @torch.compile
    def _normalize(self, x):
        """Return ``x / rms(x) * weight`` in the dtype of ``x``."""
        dt = x.dtype
        h = x.float()
        # Out-of-place multiply: `h` may alias the caller's tensor.
        h = h * torch.rsqrt(h.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        # `h` is freshly allocated here, so the in-place weight multiply
        # is safe and keeps the result in dtype `dt`.
        return h.to(dt).mul_(self.weight)

    @torch.compile
    def _fused_add_normalize(self, x, residual):
        """Return ``(norm(x + residual), x + residual)`` in the dtype of ``x``."""
        dt = x.dtype
        h = x.float() + residual.float()
        residual = h.to(dt)
        # Rebinding `h` (not mutating it) keeps `residual` intact even
        # when `h.to(dt)` returned `h` itself.
        h = h * torch.rsqrt(h.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return h.to(dt).mul_(self.weight), residual

    def forward(self, x, residual=None):
        """Normalise ``x``; with ``residual`` also return the updated residual.

        Returns a single tensor when ``residual`` is ``None``, otherwise a
        ``(normalised, new_residual)`` tuple.
        """
        if residual is None:
            return self._normalize(x)
        return self._fused_add_normalize(x, residual)
|
||||
|
||||
|
||||
def _rotate(x, cos, sin):
|
||||
x1, x2 = torch.chunk(x.float(), 2, dim=-1)
|
||||
return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1).to(x.dtype)
|
||||
|
||||
|
||||
class PositionEncoding(nn.Module):
    """Rotary position embedding (RoPE) backed by a precomputed cos/sin table.

    The non-persistent buffer ``cos_sin_cache`` has shape
    (max_position, 1, head_size): the first half of the last dimension
    holds cosines, the second half sines.
    """

    def __init__(self, head_size: int, max_position: int, base: float):
        super().__init__()
        exponents = torch.arange(0, head_size, 2, dtype=torch.float) / head_size
        inv_freq = 1.0 / (base ** exponents)
        steps = torch.arange(max_position, dtype=torch.float)
        angles = torch.outer(steps, inv_freq)
        table = torch.cat((angles.cos(), angles.sin()), dim=-1).unsqueeze(1)
        self.register_buffer("cos_sin_cache", table, persistent=False)

    @torch.compile
    def forward(self, positions, query, key):
        """Return ``(query, key)`` rotated according to ``positions``."""
        cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)
        rotated_query = _rotate(query, cos, sin)
        rotated_key = _rotate(key, cos, sin)
        return rotated_query, rotated_key
|
||||
|
||||
|
||||
@lru_cache(1)
def get_position_encoder(head_size: int, max_position: int, base: float):
    """Return a ``PositionEncoding`` for these settings.

    The single-entry cache means repeated calls with the same arguments
    receive the same module instance.
    """
    encoder = PositionEncoding(head_size, max_position, base)
    return encoder
|
||||
|
||||
|
||||
class GatedActivation(nn.Module):
    """SiLU gate: split the last dimension in two and return ``silu(first) * second``."""

    @torch.compile
    def forward(self, x):
        gate, value = torch.chunk(x, 2, dim=-1)
        activated = F.silu(gate)
        return activated * value
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fused projections with shard-aware weight loaders
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _FusedQKVProjection(hidden_size, num_heads, num_kv_heads, head_dim, bias):
|
||||
q_size = num_heads * head_dim
|
||||
kv_size = num_kv_heads * head_dim
|
||||
proj = nn.Linear(hidden_size, q_size + 2 * kv_size, bias=bias)
|
||||
|
||||
def _load_shard(param, weight, shard_id):
|
||||
offsets = {"q": 0, "k": q_size, "v": q_size + kv_size}
|
||||
sizes = {"q": q_size, "k": kv_size, "v": kv_size}
|
||||
param.data.narrow(0, offsets[shard_id], sizes[shard_id]).copy_(weight)
|
||||
|
||||
proj.weight.weight_loader = _load_shard
|
||||
if bias:
|
||||
proj.bias.weight_loader = _load_shard
|
||||
return proj
|
||||
|
||||
|
||||
def _FusedGateUpProjection(hidden_size, intermediate_size):
|
||||
proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
|
||||
|
||||
def _load_shard(param, weight, shard_id):
|
||||
param.data.narrow(0, shard_id * intermediate_size, intermediate_size).copy_(weight)
|
||||
|
||||
proj.weight.weight_loader = _load_shard
|
||||
return proj
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transformer blocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _SelfAttention(nn.Module):
    """Self-attention block: fused QKV projection, RoPE, paged attention, output projection.

    Per-head Q/K normalisation layers are created and applied only when the
    config's ``attention_bias`` (default ``True``) is false.
    """

    def __init__(self, config: Qwen3Config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        use_bias = getattr(config, "attention_bias", True)

        self.qkv_proj = _FusedQKVProjection(
            config.hidden_size, self.num_heads, self.num_kv_heads, self.head_dim, use_bias,
        )
        self.o_proj = nn.Linear(self.q_size, config.hidden_size, bias=False)
        rope_base = getattr(config, "rope_theta", 1000000)
        self.rope = get_position_encoder(self.head_dim, config.max_position_embeddings, rope_base)
        self.attn = PagedAttention(self.num_heads, self.head_dim, self.num_kv_heads)
        if not use_bias:
            self.q_norm = NormLayer(self.head_dim, eps=config.rms_norm_eps)
            self.k_norm = NormLayer(self.head_dim, eps=config.rms_norm_eps)
        self._has_bias = use_bias

    def forward(self, positions, hidden_states):
        """Return the attention output for ``hidden_states`` at ``positions``."""
        fused = self.qkv_proj(hidden_states)
        q, k, v = fused.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Unflatten the head axis: (tokens, heads, head_dim).
        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        if not self._has_bias:
            q = self.q_norm(q)
            k = self.k_norm(k)
        q, k = self.rope(positions, q, k)
        attended = self.attn(q, k, v)
        return self.o_proj(attended.flatten(1, -1))
|
||||
|
||||
|
||||
class _FeedForward(nn.Module):
    """Feed-forward block: fused gate/up projection, SiLU gating, down projection."""

    def __init__(self, config: Qwen3Config):
        super().__init__()
        hidden, inner = config.hidden_size, config.intermediate_size
        self.gate_up_proj = _FusedGateUpProjection(hidden, inner)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act = GatedActivation()

    def forward(self, x):
        gated = self.act(self.gate_up_proj(x))
        return self.down_proj(gated)
|
||||
|
||||
|
||||
class _TransformerBlock(nn.Module):
    """One decoder layer: pre-norm self-attention followed by a pre-norm MLP.

    The residual stream is carried separately from ``hidden_states`` and is
    updated by the norm layers.
    """

    def __init__(self, config: Qwen3Config):
        super().__init__()
        self.self_attn = _SelfAttention(config)
        self.mlp = _FeedForward(config)
        self.input_layernorm = NormLayer(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = NormLayer(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, positions, hidden_states, residual):
        """Return ``(mlp_output, residual)``; pass ``residual=None`` for the first layer."""
        if residual is None:
            # First layer: the raw input becomes the residual stream.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        attn_out = self.self_attn(positions, hidden_states)
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual
|
||||
|
||||
|
||||
class CausalTransformer(nn.Module):
    """Causal language model for inference (Qwen3-compatible).

    ``forward`` returns final hidden states; ``project_to_vocab`` turns
    hidden states into vocabulary logits.
    """

    # checkpoint name fragment -> (fused parameter fragment, shard id)
    WEIGHT_SHARD_MAP = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(self, config: Qwen3Config):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        blocks = [_TransformerBlock(config) for _ in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(blocks)
        self.norm = NormLayer(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            # Share one weight tensor between the embedding and the output head.
            self.lm_head.weight = self.embed_tokens.weight

    def forward(self, input_ids, positions):
        """Embed ``input_ids``, run every layer and return the normalised hidden states."""
        hidden = self.embed_tokens(input_ids)
        residual = None
        for block in self.layers:
            hidden, residual = block(positions, hidden, residual)
        hidden, _ = self.norm(hidden, residual)
        return hidden

    def project_to_vocab(self, hidden_states):
        """Project hidden states to vocabulary logits.

        During prefill only the rows at ``cu_seqlens_q[1:] - 1`` (the last
        token of each packed sequence) are projected.
        """
        from acestep.customized_vllm import _get_forward_state
        state = _get_forward_state()
        if state.is_prefill:
            last_rows = state.cu_seqlens_q[1:] - 1
            hidden_states = hidden_states[last_rows].contiguous()
        return self.lm_head(hidden_states)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Weight loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _copy_weight(param, loaded_weight):
|
||||
param.data.copy_(loaded_weight)
|
||||
|
||||
|
||||
def _resolve_parameter(model, name):
|
||||
for candidate in (name, f"model.{name}", name.removeprefix("model.")):
|
||||
try:
|
||||
return model.get_parameter(candidate)
|
||||
except AttributeError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _load_fused_shard(model, shard_map, handle, weight_name):
    """Try to load ``weight_name`` into a fused projection; return True on success."""
    for src_key, (dst_key, shard_id) in shard_map.items():
        if src_key not in weight_name:
            continue
        param = _resolve_parameter(model, weight_name.replace(src_key, dst_key))
        if param is None:
            # This mapping did not resolve; keep trying the remaining ones.
            continue
        param.weight_loader(param, handle.get_tensor(weight_name), shard_id)
        return True
    return False


def load_weights(model: nn.Module, path: str):
    """Load every ``*.safetensors`` file under ``path`` into ``model``.

    Checkpoint tensors whose name contains a key of the model's
    ``WEIGHT_SHARD_MAP`` are routed to the matching fused parameter's
    ``weight_loader`` with the mapped shard id. All other tensors are loaded
    by name using the parameter's ``weight_loader`` if it has one, else a
    plain copy. Tensors with no matching parameter are skipped.

    Raises:
        FileNotFoundError: if ``path`` contains no ``.safetensors`` file.
    """
    shard_map = getattr(model, "WEIGHT_SHARD_MAP", {})
    files = glob(os.path.join(path, "*.safetensors"))
    if not files:
        raise FileNotFoundError(f"No .safetensors files found in {path}")

    for filepath in files:
        with safe_open(filepath, "pt", "cpu") as f:
            for weight_name in f.keys():
                if _load_fused_shard(model, shard_map, f, weight_name):
                    continue
                param = _resolve_parameter(model, weight_name)
                if param is not None:
                    loader = getattr(param, "weight_loader", _copy_weight)
                    loader(param, f.get_tensor(weight_name))
|
||||
|
|
@ -877,7 +877,7 @@ def get_lm_gpu_memory_ratio(
|
|||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# The ratio is relative to total GPU memory (customized_vllm convention),
|
||||
# The ratio is relative to total GPU memory (nano-vllm convention),
|
||||
# but we compute it so that the LM only claims what's actually free
|
||||
# minus a safety margin for DiT inference activations.
|
||||
# Reserve at least 1.5 GB for DiT inference activations
|
||||
|
|
@ -888,7 +888,7 @@ def get_lm_gpu_memory_ratio(
|
|||
usable_for_lm = min(usable_for_lm, total_target_gb)
|
||||
|
||||
# Convert to ratio of total GPU memory
|
||||
# customized_vllm uses: target_total_usage = total * gpu_memory_utilization
|
||||
# nano-vllm uses: target_total_usage = total * gpu_memory_utilization
|
||||
# We want: (total * ratio) = current_usage + usable_for_lm
|
||||
current_usage_gb = actual_total_gb - free_gb
|
||||
desired_total_usage = current_usage_gb + usable_for_lm
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import sys
|
|||
|
||||
|
||||
def _has_working_triton_installation() -> bool:
|
||||
"""Return whether the Triton modules required by customized_vllm import cleanly."""
|
||||
"""Return whether the Triton modules required by nano-vllm import cleanly."""
|
||||
try:
|
||||
importlib.import_module("triton")
|
||||
importlib.import_module("triton.language")
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ class LlmInitializeBackendCompatTests(unittest.TestCase):
|
|||
_mock_empty_cache: MagicMock,
|
||||
_mock_synchronize: MagicMock,
|
||||
) -> None:
|
||||
"""Initialization should avoid customized_vllm when the Windows Triton preflight fails."""
|
||||
"""Initialization should avoid nano-vllm when the Windows Triton preflight fails."""
|
||||
handler = LLMHandler()
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
mock_gpu_config.return_value = SimpleNamespace(max_duration_with_lm=600, tier="tier6")
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ def _warn_if_prerelease_python():
|
|||
if getattr(v, "releaselevel", "final") != "final" and sys.platform.startswith("linux"):
|
||||
warnings.warn(
|
||||
f"Detected pre-release Python {sys.version.split()[0]} ({getattr(v, 'releaselevel', '')}). "
|
||||
"This is known to cause segmentation faults with the vLLM engine on Linux. "
|
||||
"This is known to cause segmentation faults with vLLM/nano-vllm on Linux. "
|
||||
"Please install a stable Python release (e.g. 3.11.12+), or use --backend pt as a workaround.",
|
||||
RuntimeWarning,
|
||||
stacklevel=2,
|
||||
|
|
@ -591,7 +591,7 @@ class LLMHandler:
|
|||
# Disable CUDA/HIP graph capture on ROCm (unverified on RDNA3 Windows),
|
||||
# on Jetson (SDPA paged-cache decode calls .item() during capture),
|
||||
# and when flash_attn is not installed (same .item() incompatibility on all CUDA hardware).
|
||||
# When flash_attn is unavailable, customized_vllm falls back to _sdpa_cached_decode
|
||||
# When flash_attn is unavailable, nano-vllm falls back to _sdpa_decode_with_paged_cache
|
||||
# which contains a Python loop with .item() calls. These force CPU-GPU
|
||||
# synchronisation that is forbidden inside torch.cuda.CUDAGraph capture,
|
||||
# corrupting the CUDA context and causing downstream errors such as:
|
||||
|
|
@ -603,7 +603,7 @@ class LLMHandler:
|
|||
dev_name = torch.cuda.get_device_name(0).lower()
|
||||
is_jetson = any(k in dev_name for k in ("orin", "xavier", "tegra"))
|
||||
if is_jetson:
|
||||
logger.info(f"Jetson GPU detected ({dev_name}): disabling CUDA graph capture for customized_vllm")
|
||||
logger.info(f"Jetson GPU detected ({dev_name}): disabling CUDA graph capture for nano-vllm")
|
||||
except Exception:
|
||||
pass
|
||||
_has_flash_attn = False
|
||||
|
|
@ -614,7 +614,7 @@ class LLMHandler:
|
|||
pass
|
||||
if not _has_flash_attn:
|
||||
logger.info(
|
||||
"flash_attn not installed: disabling CUDA graph capture for customized_vllm "
|
||||
"flash_attn not installed: disabling CUDA graph capture for nano-vllm "
|
||||
"(SDPA fallback uses .item() calls in paged-cache decode that are "
|
||||
"incompatible with CUDA graph capture)"
|
||||
)
|
||||
|
|
@ -626,7 +626,7 @@ class LLMHandler:
|
|||
pass
|
||||
if not _has_triton:
|
||||
logger.info(
|
||||
"Triton not available: disabling CUDA graph capture for customized_vllm "
|
||||
"Triton not available: disabling CUDA graph capture for nano-vllm "
|
||||
"(CUDA graphs require torch.compile which depends on Triton)"
|
||||
)
|
||||
enforce_eager_for_vllm = bool(is_rocm or is_jetson or not _has_flash_attn or not _has_triton)
|
||||
|
|
@ -751,11 +751,11 @@ class LLMHandler:
|
|||
logger.error("CUDA/ROCm is not available. Please check your GPU setup.")
|
||||
return "❌ CUDA/ROCm is not available. Please check your GPU setup."
|
||||
try:
|
||||
from acestep.customized_vllm import LLM, SamplingParams
|
||||
except ImportError as exc:
|
||||
from nanovllm import LLM, SamplingParams
|
||||
except ImportError:
|
||||
self.llm_initialized = False
|
||||
logger.error(f"Failed to import customized_vllm engine: {exc}")
|
||||
return f"❌ Failed to import customized_vllm engine: {exc}"
|
||||
logger.error("nano-vllm is not installed. Please install it using 'cd acestep/third_parts/nano-vllm && pip install .'")
|
||||
return "❌ nano-vllm is not installed. Please install it using 'cd acestep/third_parts/nano-vllm && pip install .'"
|
||||
|
||||
try:
|
||||
current_device = torch.cuda.current_device()
|
||||
|
|
@ -857,7 +857,7 @@ class LLMHandler:
|
|||
Accepts either a single formatted prompt (str) or a list of formatted prompts (List[str]).
|
||||
Returns a single string for single mode, or a list of strings for batch mode.
|
||||
"""
|
||||
from acestep.customized_vllm import SamplingParams
|
||||
from nanovllm import SamplingParams
|
||||
|
||||
# Determine if batch mode
|
||||
formatted_prompt_list, is_batch = self._normalize_batch_input(formatted_prompts)
|
||||
|
|
@ -1490,11 +1490,8 @@ class LLMHandler:
|
|||
seeds=seeds,
|
||||
)
|
||||
except Exception as e:
|
||||
error_detail = traceback.format_exc()
|
||||
logger.error(
|
||||
f"Error in batch codes generation: {type(e).__name__}: {e}\n{error_detail}"
|
||||
)
|
||||
error_msg = f"Error in batch codes generation: {type(e).__name__}: {e}"
|
||||
error_msg = f"Error in batch codes generation: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return {
|
||||
"metadata": [],
|
||||
"audio_codes": [],
|
||||
|
|
@ -2417,11 +2414,11 @@ class LLMHandler:
|
|||
import traceback
|
||||
error_detail = traceback.format_exc()
|
||||
logger.error(f"Error in generate_from_formatted_prompt: {type(e).__name__}: {e}\n{error_detail}")
|
||||
# Reset vllm engine state on error to prevent stale context from causing
|
||||
# Reset nano-vllm state on error to prevent stale context from causing
|
||||
# subsequent CUDA illegal memory access errors
|
||||
if self.llm_backend == "vllm":
|
||||
try:
|
||||
from acestep.customized_vllm import reset_context
|
||||
from nanovllm.utils.context import reset_context
|
||||
reset_context()
|
||||
except ImportError:
|
||||
pass
|
||||
|
|
@ -4016,7 +4013,7 @@ class LLMHandler:
|
|||
yield
|
||||
return
|
||||
|
||||
# If using vllm engine or MLX, do not offload (managed differently)
|
||||
# If using nanovllm or MLX, do not offload (managed differently)
|
||||
if self.llm_backend in ("vllm", "mlx"):
|
||||
yield
|
||||
return
|
||||
|
|
|
|||
|
|
@ -80,8 +80,6 @@ class TestCotCfgScaleFixed(unittest.TestCase):
|
|||
handler = LLMHandler()
|
||||
handler.llm_initialized = True
|
||||
handler.llm_backend = "pt"
|
||||
handler.llm = MagicMock()
|
||||
handler.llm_tokenizer = MagicMock()
|
||||
|
||||
captured_cfg = {}
|
||||
|
||||
|
|
@ -127,8 +125,6 @@ class TestCotCfgScaleFixed(unittest.TestCase):
|
|||
handler = LLMHandler()
|
||||
handler.llm_initialized = True
|
||||
handler.llm_backend = "pt"
|
||||
handler.llm = MagicMock()
|
||||
handler.llm_tokenizer = MagicMock()
|
||||
|
||||
# Capture cfg passed to generate_from_formatted_prompt
|
||||
captured_cfgs = []
|
||||
|
|
@ -138,24 +134,26 @@ class TestCotCfgScaleFixed(unittest.TestCase):
|
|||
# Return a minimal CoT response so Phase 1 succeeds
|
||||
return "<think>metadata</think>", "ok"
|
||||
|
||||
with patch.object(handler, "generate_from_formatted_prompt", side_effect=capturing_gen), \
|
||||
patch.object(handler, "build_formatted_prompt", return_value="P"), \
|
||||
patch.object(handler, "parse_lm_output", return_value=({}, "")), \
|
||||
patch.object(handler, "_format_metadata_as_cot", return_value=""), \
|
||||
patch.object(handler, "build_formatted_prompt_with_cot", return_value="P2"):
|
||||
# Invoke with cfg_scale=2.0 (typical UI default)
|
||||
handler.generate_with_stop_condition(
|
||||
caption="test caption",
|
||||
lyrics="test lyrics",
|
||||
cfg_scale=2.0,
|
||||
temperature=0.6,
|
||||
negative_prompt="",
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
repetition_penalty=1.0,
|
||||
infer_type="dit", # Phase 1 only – avoids Phase 2 setup
|
||||
progress=lambda *a, **kw: None,
|
||||
)
|
||||
with patch.object(handler, "generate_from_formatted_prompt", side_effect=capturing_gen):
|
||||
with patch.object(handler, "build_formatted_prompt", return_value="P"):
|
||||
with patch.object(handler, "_parse_metadata_from_cot", return_value={}):
|
||||
with patch.object(handler, "_format_metadata_as_cot", return_value=""):
|
||||
with patch.object(
|
||||
handler, "build_formatted_prompt_with_cot", return_value="P2"
|
||||
):
|
||||
# Invoke with cfg_scale=2.0 (typical UI default)
|
||||
handler.generate_with_stop_condition(
|
||||
caption="test caption",
|
||||
lyrics="test lyrics",
|
||||
cfg_scale=2.0,
|
||||
temperature=0.6,
|
||||
negative_prompt="",
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
repetition_penalty=1.0,
|
||||
infer_type="dit", # Phase 1 only – avoids Phase 2 setup
|
||||
progress=lambda *a, **kw: None,
|
||||
)
|
||||
|
||||
# At least one generate_from_formatted_prompt call must exist
|
||||
self.assertTrue(len(captured_cfgs) >= 1, "generate_from_formatted_prompt was not called")
|
||||
|
|
|
|||
|
|
@ -4,17 +4,13 @@ Regression test for the bug where CUDA graph capture would fail when
|
|||
``flash_attn`` is not installed because the SDPA paged-cache decode path
|
||||
calls ``.item()`` inside the capture region (a forbidden CPU-GPU sync).
|
||||
|
||||
When flash_attn is absent, customized_vllm must run in ``enforce_eager=True``
|
||||
When flash_attn is absent, nano-vllm must run in ``enforce_eager=True``
|
||||
(eager mode, no CUDA graph capture) to avoid corrupting the CUDA context.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
from types import ModuleType
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
_SENTINEL = object()
|
||||
|
||||
try:
|
||||
from acestep.llm_inference import LLMHandler
|
||||
_IMPORT_ERROR = None
|
||||
|
|
@ -40,16 +36,13 @@ def _mock_gpu_config():
|
|||
class TestEnforceEagerWhenFlashAttnMissing(unittest.TestCase):
|
||||
"""Verify that enforce_eager=True is set when flash_attn is not installed."""
|
||||
|
||||
def _run_initialize_with_mocks(self, flash_attn_available: bool,
|
||||
device_name: str = "NVIDIA GeForce RTX 4090",
|
||||
triton_available: bool = True):
|
||||
def _run_initialize_with_mocks(self, flash_attn_available: bool, device_name: str = "NVIDIA GeForce RTX 4090"):
|
||||
"""Call handler.initialize() with all heavy operations mocked.
|
||||
|
||||
Args:
|
||||
flash_attn_available: Whether flash_attn should appear detectable
|
||||
via ``importlib.util.find_spec``.
|
||||
device_name: Simulated CUDA device name.
|
||||
triton_available: Whether ``import triton`` should succeed.
|
||||
|
||||
Returns:
|
||||
The ``enforce_eager`` value that was passed to ``_initialize_5hz_lm_vllm``.
|
||||
|
|
@ -61,45 +54,28 @@ class TestEnforceEagerWhenFlashAttnMissing(unittest.TestCase):
|
|||
|
||||
captured = {}
|
||||
|
||||
def fake_init_vllm(model_path: str, enforce_eager: bool = False, **_kw) -> str:
|
||||
def fake_init_vllm(model_path: str, enforce_eager: bool = False) -> str:
|
||||
captured["enforce_eager"] = enforce_eager
|
||||
return "✅ ok"
|
||||
|
||||
# Inject a fake triton module so ``import triton`` inside initialize()
|
||||
# succeeds (or fails) as requested.
|
||||
fake_triton = ModuleType("triton") if triton_available else None
|
||||
saved_triton = sys.modules.get("triton", _SENTINEL)
|
||||
try:
|
||||
if triton_available:
|
||||
sys.modules["triton"] = fake_triton
|
||||
else:
|
||||
sys.modules.pop("triton", None)
|
||||
with patch("importlib.util.find_spec", return_value=find_spec_return), \
|
||||
patch("os.path.exists", return_value=True), \
|
||||
patch("acestep.llm_inference.AutoTokenizer") as mock_tok, \
|
||||
patch("acestep.llm_inference.get_global_gpu_config", return_value=_mock_gpu_config()), \
|
||||
patch("acestep.llm_inference.MetadataConstrainedLogitsProcessor"), \
|
||||
patch("torch.cuda.is_available", return_value=True), \
|
||||
patch("torch.cuda.empty_cache"), \
|
||||
patch("torch.cuda.synchronize"), \
|
||||
patch("torch.cuda.get_device_name", return_value=device_name), \
|
||||
patch.object(handler, "_initialize_5hz_lm_vllm", side_effect=fake_init_vllm):
|
||||
|
||||
with patch("importlib.util.find_spec", return_value=find_spec_return), \
|
||||
patch("os.path.exists", return_value=True), \
|
||||
patch("acestep.llm_inference.AutoTokenizer") as mock_tok, \
|
||||
patch("acestep.llm_inference.get_global_gpu_config", return_value=_mock_gpu_config()), \
|
||||
patch("acestep.llm_inference.get_gpu_memory_gb", return_value=16.0), \
|
||||
patch("acestep.llm_inference.MetadataConstrainedLogitsProcessor"), \
|
||||
patch("torch.cuda.is_available", return_value=True), \
|
||||
patch("torch.cuda.empty_cache"), \
|
||||
patch("torch.cuda.synchronize"), \
|
||||
patch("torch.cuda.get_device_name", return_value=device_name), \
|
||||
patch("torch.cuda.mem_get_info", return_value=(16 * 1024**3, 16 * 1024**3)), \
|
||||
patch.object(handler, "_initialize_5hz_lm_vllm", side_effect=fake_init_vllm):
|
||||
|
||||
mock_tok.from_pretrained.return_value = MagicMock()
|
||||
handler.initialize(
|
||||
checkpoint_dir="/tmp/fake_ckpt",
|
||||
lm_model_path="model",
|
||||
backend="vllm",
|
||||
device="cuda",
|
||||
)
|
||||
finally:
|
||||
if saved_triton is _SENTINEL:
|
||||
sys.modules.pop("triton", None)
|
||||
else:
|
||||
sys.modules["triton"] = saved_triton
|
||||
mock_tok.from_pretrained.return_value = MagicMock()
|
||||
handler.initialize(
|
||||
checkpoint_dir="/tmp/fake_ckpt",
|
||||
lm_model_path="model",
|
||||
backend="vllm",
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
return captured.get("enforce_eager")
|
||||
|
||||
|
|
|
|||
21
acestep/third_parts/nano-vllm/LICENSE
Normal file
21
acestep/third_parts/nano-vllm/LICENSE
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2025 Xingkai Yu
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
66
acestep/third_parts/nano-vllm/README.md
Normal file
66
acestep/third_parts/nano-vllm/README.md
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
<p align="center">
|
||||
<img width="300" src="assets/logo.png">
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://trendshift.io/repositories/15323" target="_blank"><img src="https://trendshift.io/api/badge/repositories/15323" alt="GeeeekExplorer%2Fnano-vllm | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
</p>
|
||||
|
||||
# Nano-vLLM
|
||||
|
||||
A lightweight vLLM implementation built from scratch.
|
||||
|
||||
## Key Features
|
||||
|
||||
* 🚀 **Fast offline inference** - Comparable inference speeds to vLLM
|
||||
* 📖 **Readable codebase** - Clean implementation in ~1,200 lines of Python code
|
||||
* ⚡ **Optimization Suite** - Prefix caching, Tensor Parallelism, Torch compilation, CUDA graph, etc.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install git+https://github.com/GeeeekExplorer/nano-vllm.git
|
||||
```
|
||||
|
||||
## Model Download
|
||||
|
||||
To download the model weights manually, use the following command:
|
||||
```bash
|
||||
huggingface-cli download --resume-download Qwen/Qwen3-0.6B \
|
||||
--local-dir ~/huggingface/Qwen3-0.6B/ \
|
||||
--local-dir-use-symlinks False
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
See `example.py` for usage. The API mirrors vLLM's interface with minor differences in the `LLM.generate` method:
|
||||
```python
|
||||
from nanovllm import LLM, SamplingParams
|
||||
llm = LLM("/YOUR/MODEL/PATH", enforce_eager=True, tensor_parallel_size=1)
|
||||
sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
|
||||
prompts = ["Hello, Nano-vLLM."]
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
outputs[0]["text"]
|
||||
```
|
||||
|
||||
## Benchmark
|
||||
|
||||
See `bench.py` for the benchmark script.
|
||||
|
||||
**Test Configuration:**
|
||||
- Hardware: RTX 4070 Laptop (8GB)
|
||||
- Model: Qwen3-0.6B
|
||||
- Total Requests: 256 sequences
|
||||
- Input Length: Randomly sampled between 100–1024 tokens
|
||||
- Output Length: Randomly sampled between 100–1024 tokens
|
||||
|
||||
**Performance Results:**
|
||||
| Inference Engine | Output Tokens | Time (s) | Throughput (tokens/s) |
|
||||
|----------------|-------------|----------|-----------------------|
|
||||
| vLLM | 133,966 | 98.37 | 1361.84 |
|
||||
| Nano-vLLM | 133,966 | 93.41 | 1434.13 |
|
||||
|
||||
|
||||
## Star History
|
||||
|
||||
[![Star History Chart](https://api.star-history.com/svg?repos=GeeeekExplorer/nano-vllm&type=Date)](https://www.star-history.com/#GeeeekExplorer/nano-vllm&Date)
|
||||
BIN
acestep/third_parts/nano-vllm/assets/logo.png
Normal file
BIN
acestep/third_parts/nano-vllm/assets/logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 387 KiB |
32
acestep/third_parts/nano-vllm/bench.py
Normal file
32
acestep/third_parts/nano-vllm/bench.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
import os
|
||||
import time
|
||||
from random import randint, seed
|
||||
from nanovllm import LLM, SamplingParams
|
||||
# from vllm import LLM, SamplingParams
|
||||
|
||||
|
||||
def main():
    """Benchmark batched generation throughput on randomly generated prompts.

    The RNG is seeded so every run draws the same prompt lengths, token ids
    and per-request output budgets. One short warm-up request is issued before
    the timed batch.
    """
    seed(0)
    num_seqs = 256
    max_input_len = 1024
    max_output_len = 1024

    model_dir = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
    llm = LLM(model_dir, enforce_eager=False, max_model_len=4096)

    # Draw each prompt's length first, then its token ids, so the sequence of
    # RNG calls (and therefore the workload) is reproducible.
    prompt_token_ids = []
    for _ in range(num_seqs):
        prompt_len = randint(100, max_input_len)
        prompt_token_ids.append([randint(0, 10000) for _ in range(prompt_len)])

    sampling_params = []
    for _ in range(num_seqs):
        budget = randint(100, max_output_len)
        sampling_params.append(
            SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=budget)
        )
    # uncomment the following line for vllm
    # prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]

    # Warm-up request, excluded from the measurement.
    llm.generate(["Benchmark: "], SamplingParams())

    started = time.time()
    llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
    t = time.time() - started

    # ignore_eos=True means every request runs to its max_tokens budget.
    total_tokens = 0
    for sp in sampling_params:
        total_tokens += sp.max_tokens
    throughput = total_tokens / t
    print(f"Total: {total_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")


if __name__ == "__main__":
    main()
|
||||
33
acestep/third_parts/nano-vllm/example.py
Normal file
33
acestep/third_parts/nano-vllm/example.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
import os
|
||||
from nanovllm import LLM, SamplingParams
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def main():
    """Generate chat completions for two sample questions and print them."""
    model_dir = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    llm = LLM(model_dir, enforce_eager=True, tensor_parallel_size=1)

    sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
    questions = [
        "introduce yourself",
        "list all prime numbers within 100",
    ]

    # Wrap each question in the model's chat template (as text, not token ids)
    # with the generation prompt appended.
    prompts = []
    for question in questions:
        conversation = [{"role": "user", "content": question}]
        prompts.append(
            tokenizer.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=True,
            )
        )

    outputs = llm.generate(prompts, sampling_params)

    for index, prompt in enumerate(prompts):
        # zip() in the original stops at the shorter list; mirror that here.
        if index >= len(outputs):
            break
        output = outputs[index]
        print("\n")
        print(f"Prompt: {prompt!r}")
        print(f"Completion: {output['text']!r}")


if __name__ == "__main__":
    main()
|
||||
2
acestep/third_parts/nano-vllm/nanovllm/__init__.py
Normal file
2
acestep/third_parts/nano-vllm/nanovllm/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
from nanovllm.llm import LLM
|
||||
from nanovllm.sampling_params import SamplingParams
|
||||
26
acestep/third_parts/nano-vllm/nanovllm/config.py
Normal file
26
acestep/third_parts/nano-vllm/nanovllm/config.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
import os
|
||||
from dataclasses import dataclass
|
||||
from transformers import AutoConfig
|
||||
|
||||
|
||||
@dataclass
class Config:
    """Engine settings, validated and completed in ``__post_init__``.

    ``model`` must be an existing directory; the Hugging Face config found
    there is loaded into ``hf_config`` and used to cap ``max_model_len``.
    """

    model: str
    max_num_batched_tokens: int = 16384
    max_num_seqs: int = 512
    # Capped to hf_config.max_position_embeddings during __post_init__.
    max_model_len: int = 4096
    gpu_memory_utilization: float = 0.9
    tensor_parallel_size: int = 1
    enforce_eager: bool = False
    # Filled in by __post_init__ from the model directory.
    hf_config: AutoConfig | None = None
    eos: int = -1
    kvcache_block_size: int = 256
    num_kvcache_blocks: int = -1

    def __post_init__(self):
        """Check the settings and derive the values that depend on the model.

        Raises:
            AssertionError: if ``model`` is not a directory, the KV-cache block
                size is not a multiple of 256, the tensor-parallel size is
                outside 1..8, or the batched-token budget is smaller than the
                (capped) model length.
        """
        assert os.path.isdir(self.model)
        assert not self.kvcache_block_size % 256
        tp_size = self.tensor_parallel_size
        assert tp_size >= 1 and tp_size <= 8

        self.hf_config = AutoConfig.from_pretrained(self.model)

        # Never allow a context longer than the model's positional range.
        position_limit = self.hf_config.max_position_embeddings
        if position_limit < self.max_model_len:
            self.max_model_len = position_limit

        assert self.max_num_batched_tokens >= self.max_model_len
|
||||
99
acestep/third_parts/nano-vllm/nanovllm/distributed.py
Normal file
99
acestep/third_parts/nano-vllm/nanovllm/distributed.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
"""
|
||||
Distributed utilities for nano-vllm.
|
||||
|
||||
This module provides wrapper functions for torch.distributed that gracefully
|
||||
handle single-GPU mode (world_size == 1) without requiring distributed initialization.
|
||||
"""
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
# Module-level record of whether initialize_distributed() created a process group.
_distributed_initialized = False


def initialize_distributed(backend: str, init_method: str, world_size: int, rank: int) -> bool:
    """Create the process group, unless running with a single process.

    Args:
        backend: Distributed backend (e.g., "nccl" or "gloo")
        init_method: Initialization method (e.g., "tcp://127.0.0.1:2333")
        world_size: Total number of processes
        rank: Rank of current process

    Returns:
        True if a process group was created, False when ``world_size == 1``.
    """
    global _distributed_initialized

    multi_process = world_size != 1
    if multi_process:
        dist.init_process_group(backend, init_method, world_size=world_size, rank=rank)
    # Only reached when init_process_group did not raise.
    _distributed_initialized = multi_process
    return multi_process


def is_initialized() -> bool:
    """Report whether this module initialized a process group."""
    return _distributed_initialized


def get_rank() -> int:
    """Return this process's rank, or 0 when no process group is active."""
    return dist.get_rank() if _distributed_initialized else 0


def get_world_size() -> int:
    """Return the world size, or 1 when no process group is active."""
    return dist.get_world_size() if _distributed_initialized else 1


def barrier():
    """Block until all processes arrive; does nothing without a process group."""
    if not _distributed_initialized:
        return
    dist.barrier()


def all_reduce(tensor, op=None):
    """All-reduce ``tensor`` in place; does nothing without a process group.

    Args:
        tensor: Tensor to reduce
        op: Reduce operation; ``None`` selects ``dist.ReduceOp.SUM``
    """
    if not _distributed_initialized:
        return
    reduce_op = dist.ReduceOp.SUM if op is None else op
    dist.all_reduce(tensor, reduce_op)


def gather(tensor, gather_list=None, dst=0):
    """Gather ``tensor`` from every process; does nothing without a process group.

    Args:
        tensor: Tensor to gather
        gather_list: List to gather into (only used on dst rank)
        dst: Destination rank
    """
    if not _distributed_initialized:
        return
    dist.gather(tensor, gather_list, dst)


def destroy_process_group():
    """Tear down the process group if one was created, and clear the flag."""
    global _distributed_initialized

    if not _distributed_initialized:
        return
    dist.destroy_process_group()
    _distributed_initialized = False
|
||||
136
acestep/third_parts/nano-vllm/nanovllm/engine/block_manager.py
Normal file
136
acestep/third_parts/nano-vllm/nanovllm/engine/block_manager.py
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
import os
|
||||
from collections import deque
|
||||
import xxhash
|
||||
import numpy as np
|
||||
|
||||
from nanovllm.engine.sequence import Sequence
|
||||
|
||||
# Set NANOVLLM_DEBUG=1 in the environment to turn on block-manager tracing.
_DEBUG = os.getenv("NANOVLLM_DEBUG", "0") == "1"


def _debug_log(msg: str):
    """Emit ``msg`` on stdout (flushed immediately) when debug tracing is on."""
    if not _DEBUG:
        return
    print(f"[nanovllm block_mgr DEBUG] {msg}", flush=True)
|
||||
|
||||
|
||||
class Block:
    """Bookkeeping for one KV-cache block: id, reference count, hash and tokens."""

    def __init__(self, block_id):
        """Create an unreferenced block with no recorded content."""
        self.block_id = block_id
        self.ref_count = 0
        self.hash, self.token_ids = -1, []

    def update(self, hash: int, token_ids: list[int]):
        """Record the content hash and the token ids stored in this block."""
        self.hash, self.token_ids = hash, token_ids

    def reset(self):
        """Forget any recorded content and set the reference count to 1."""
        self.ref_count = 1
        self.hash, self.token_ids = -1, []
|
||||
|
||||
|
||||
class BlockManager:
    """Hand out fixed-size KV-cache blocks to sequences, sharing full blocks by hash.

    Free block ids are kept in a deque and in-use ids in a set. Every
    completely filled block is registered in ``hash_to_block_id`` under a hash
    chained over its own tokens and the previous block's hash, so a later
    sequence with the same prefix can reuse the block instead of taking a new
    one.
    """

    def __init__(self, num_blocks: int, block_size: int):
        """Create ``num_blocks`` free blocks of ``block_size`` tokens each."""
        self.block_size = block_size
        self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
        self.hash_to_block_id: dict[int, int] = dict()
        self.free_block_ids: deque[int] = deque(range(num_blocks))
        self.used_block_ids: set[int] = set()

    @classmethod
    def compute_hash(cls, token_ids: list[int], prefix: int = -1):
        """Hash ``token_ids``, chained onto ``prefix`` unless ``prefix`` is -1."""
        h = xxhash.xxh64()
        if prefix != -1:
            h.update(prefix.to_bytes(8, "little"))
        h.update(np.array(token_ids).tobytes())
        return h.intdigest()

    def _next_free_block_id(self) -> int:
        """Return the id at the head of the free list without removing it.

        Raises:
            IndexError: if every block is in use. The exception type (and the
                leading phrase of the message) match what the bare
                ``deque[0]`` lookup used to raise, in case callers rely on it.
        """
        if not self.free_block_ids:
            raise IndexError(
                "deque index out of range: no free KV cache blocks "
                f"(all {len(self.blocks)} blocks are in use)"
            )
        return self.free_block_ids[0]

    def _allocate_block(self, block_id: int) -> Block:
        """Move an unreferenced block from the free list to the used set."""
        block = self.blocks[block_id]
        assert block.ref_count == 0
        block.reset()  # sets ref_count to 1 and clears hash/token_ids
        self.free_block_ids.remove(block_id)
        self.used_block_ids.add(block_id)
        return block

    def _deallocate_block(self, block_id: int) -> None:
        """Return an unreferenced block to the tail of the free list."""
        assert self.blocks[block_id].ref_count == 0
        self.used_block_ids.remove(block_id)
        self.free_block_ids.append(block_id)

    def can_allocate(self, seq: Sequence) -> bool:
        """Return True if enough free blocks exist for all of ``seq``'s blocks."""
        return len(self.free_block_ids) >= seq.num_blocks

    def allocate(self, seq: Sequence):
        """Build ``seq.block_table``, reusing cached full blocks where possible.

        Blocks are matched front to back; after the first miss every remaining
        block is freshly allocated. Each reused block adds ``block_size`` to
        ``seq.num_cached_tokens``.

        Raises:
            IndexError: if a fresh block is needed and none is free.
        """
        _debug_log(f"allocate: seq_id={seq.seq_id}, len={len(seq)}, num_blocks={seq.num_blocks}, "
                   f"free_blocks={len(self.free_block_ids)}")
        assert not seq.block_table
        h = -1
        cache_miss = False
        for i in range(seq.num_blocks):
            token_ids = seq.block(i)
            # Only completely filled blocks get a hash; a partial block is never shared.
            h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
            block_id = self.hash_to_block_id.get(h, -1)
            # Compare the tokens as well, to guard against hash collisions.
            if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
                cache_miss = True
            if cache_miss:
                if not self.free_block_ids:
                    _debug_log(" ERROR: no free blocks available!")
                block_id = self._next_free_block_id()
                block = self._allocate_block(block_id)
            else:
                seq.num_cached_tokens += self.block_size
                if block_id in self.used_block_ids:
                    block = self.blocks[block_id]
                    block.ref_count += 1
                else:
                    block = self._allocate_block(block_id)
            if h != -1:
                block.update(h, token_ids)
                self.hash_to_block_id[h] = block_id
            seq.block_table.append(block_id)
        _debug_log(f" allocated block_table: {seq.block_table}")

    def deallocate(self, seq: Sequence):
        """Drop ``seq``'s reference on each of its blocks and clear its table.

        A block whose reference count reaches zero is returned to the free list.
        """
        _debug_log(f"deallocate: seq_id={seq.seq_id}, block_table={seq.block_table}")
        for block_id in reversed(seq.block_table):
            block = self.blocks[block_id]
            block.ref_count -= 1
            _debug_log(f" block_id={block_id}, ref_count after decrement={block.ref_count}")
            if block.ref_count == 0:
                # Fix: Clean up hash_to_block_id mapping to prevent stale references
                # This prevents CUDA illegal memory access when prefix cache tries to
                # reuse a block_id that has already been freed
                if block.hash != -1:
                    cached_id = self.hash_to_block_id.get(block.hash)
                    if cached_id == block_id:
                        del self.hash_to_block_id[block.hash]
                self._deallocate_block(block_id)
        seq.num_cached_tokens = 0
        seq.block_table.clear()
        _debug_log(f" deallocated, free_blocks={len(self.free_block_ids)}")

    def can_append(self, seq: Sequence) -> bool:
        """Return True if ``may_append`` can run without exhausting the pool.

        A new block is only needed when the sequence length is one past a
        block boundary (``len(seq) % block_size == 1``).
        """
        needs_new_block = len(seq) % self.block_size == 1
        return not needs_new_block or len(self.free_block_ids) >= 1

    def may_append(self, seq: Sequence):
        """Update the block table for ``seq``'s current length.

        * One token past a boundary: allocate a fresh block.
        * Exactly on a boundary: the last block is now full, so hash and
          register it.
        * Otherwise: nothing to do; the last block stays unhashed.

        Raises:
            IndexError: if a fresh block is needed and none is free.
        """
        block_table = seq.block_table
        last_block = self.blocks[block_table[-1]]
        position_in_block = len(seq) % self.block_size
        if position_in_block == 1:
            assert last_block.hash != -1
            block_id = self._next_free_block_id()
            self._allocate_block(block_id)
            block_table.append(block_id)
        elif position_in_block == 0:
            assert last_block.hash == -1
            token_ids = seq.block(seq.num_blocks - 1)
            prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
            h = self.compute_hash(token_ids, prefix)
            last_block.update(h, token_ids)
            self.hash_to_block_id[h] = last_block.block_id
        else:
            assert last_block.hash == -1
|
||||
178
acestep/third_parts/nano-vllm/nanovllm/engine/llm_engine.py
Normal file
178
acestep/third_parts/nano-vllm/nanovllm/engine/llm_engine.py
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
import atexit
|
||||
import threading
|
||||
from dataclasses import fields
|
||||
from time import perf_counter
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoTokenizer
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from nanovllm.config import Config
|
||||
from nanovllm.sampling_params import SamplingParams
|
||||
from nanovllm.engine.sequence import Sequence
|
||||
from nanovllm.engine.scheduler import Scheduler
|
||||
from nanovllm.engine.model_runner import ModelRunner
|
||||
|
||||
|
||||
class LLMEngine:
    """Single entry point tying together tokenizer, scheduler and model runner.

    ``generate()`` queues prompts, repeatedly calls ``step()`` until the
    scheduler reports completion, and returns the decoded completions. When a
    request's ``cfg_scale`` is above 1.0, a conditional and an unconditional
    sequence are scheduled as a pair and only the conditional one is returned.
    """

    def __init__(self, model, **kwargs):
        """Build the engine for the model directory ``model``.

        Keyword arguments that name a ``Config`` field are forwarded to
        ``Config``. An optional ``tokenizer`` keyword supplies a ready
        tokenizer; otherwise one is loaded from the model directory. Other
        keywords are ignored. One extra ``ModelRunner`` process is spawned per
        tensor-parallel rank above 0.
        """
        config_fields = {field.name for field in fields(Config)}
        config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
        config = Config(model, **config_kwargs)
        self.ps = []
        self.events = []
        # Thread-safety lock for generate().
        # The scheduler, block manager, model runner, and CUDA graph buffers are all
        # shared mutable state that is NOT thread-safe. In concurrent serving scenarios
        # (API server with ThreadPoolExecutor, multiple queue workers, Gradio with
        # concurrent requests), multiple threads can call generate() simultaneously.
        # Without this lock, concurrent access corrupts scheduler state, block tables,
        # and CUDA graph input buffers, leading to intermittent CUDA device-side
        # assertion failures (illegal memory access in KV cache).
        self._generate_lock = threading.Lock()
        ctx = mp.get_context("spawn")
        for i in range(1, config.tensor_parallel_size):
            event = ctx.Event()
            process = ctx.Process(target=ModelRunner, args=(config, i, event))
            process.start()
            self.ps.append(process)
            self.events.append(event)
        self.model_runner = ModelRunner(config, 0, self.events)
        tokenizer = kwargs.get("tokenizer", None)
        if tokenizer is not None:
            self.tokenizer = tokenizer
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
        config.eos = self.tokenizer.eos_token_id
        self.scheduler = Scheduler(config)
        atexit.register(self.exit)

    def exit(self):
        """Shut down the model runner and wait for the worker processes.

        Safe to call more than once: the method is also registered with
        ``atexit``, so an explicit call followed by interpreter shutdown must
        not fail on the already-deleted ``model_runner`` attribute.
        """
        if not hasattr(self, "model_runner"):
            return
        self.model_runner.call("exit")
        del self.model_runner
        for p in self.ps:
            p.join()

    def add_request(self, prompt: str | list[int], sampling_params: SamplingParams, unconditional_prompt: str | list[int] | None = None):
        """Queue one prompt, plus its unconditional twin when CFG is active.

        Args:
            prompt: Text (tokenized here) or a list of token ids.
            sampling_params: Per-request sampling settings; ``cfg_scale > 1.0``
                enables the paired conditional/unconditional sequences.
            unconditional_prompt: Prompt for the unconditional sequence. When
                omitted, the conditional prompt itself is reused.
        """
        if isinstance(prompt, str):
            prompt = self.tokenizer.encode(prompt)
        # For CFG: if cfg_scale > 1.0, create both conditional and unconditional sequences
        if sampling_params.cfg_scale > 1.0:
            if unconditional_prompt is None:
                # Fallback: reuse the conditional prompt. Ideally callers pass
                # an explicit unconditional_prompt.
                # TODO: Implement automatic "NO USER INPUT" replacement if possible
                unconditional_prompt = prompt
            if isinstance(unconditional_prompt, str):
                unconditional_prompt = self.tokenizer.encode(unconditional_prompt)
            # Create unconditional sequence first (so we can reference it from conditional)
            uncond_seq = Sequence(unconditional_prompt, sampling_params, is_unconditional=True)
            # Create conditional sequence with reference to unconditional
            cond_seq = Sequence(prompt, sampling_params, is_unconditional=False, conditional_seq=uncond_seq)
            uncond_seq.paired_seq = cond_seq  # Link them bidirectionally
            # Add both sequences to scheduler
            self.scheduler.add(cond_seq)
            self.scheduler.add(uncond_seq)
        else:
            seq = Sequence(prompt, sampling_params)
            self.scheduler.add(seq)

    def step(self):
        """Run one scheduling + model iteration.

        Returns:
            ``(outputs, num_tokens)``. ``outputs`` is a list of
            ``(seq_id, completion_token_ids)`` for sequences that finished in
            this step. ``num_tokens`` is the number of scheduled tokens for a
            prefill step, or the negated count of non-unconditional sequences
            for a decode step.
        """
        seqs, is_prefill = self.scheduler.schedule()
        token_ids = self.model_runner.call("run", seqs, is_prefill)
        self.scheduler.postprocess(seqs, token_ids)
        # Only output conditional sequences (unconditional sequences are just for CFG computation)
        output_seqs = [seq for seq in seqs if seq.is_finished and (seq.cfg_scale <= 1.0 or not seq.is_unconditional)]
        outputs = [(seq.seq_id, seq.completion_token_ids) for seq in output_seqs]
        if is_prefill:
            num_tokens = sum(len(seq) for seq in seqs)
        else:
            num_tokens = -sum(1 for seq in seqs if not seq.is_unconditional)
        return outputs, num_tokens

    def is_finished(self):
        """Return whether the scheduler has no more work."""
        return self.scheduler.is_finished()

    def reset(self):
        """
        Reset the scheduler state and release all allocated blocks.
        This should be called when an exception occurs during generation to prevent
        KV cache block leaks that can cause 'deque index out of range' errors.
        """
        # Running sequences first, then waiting ones (the latter might have
        # blocks from preemption).
        for queue in (self.scheduler.running, self.scheduler.waiting):
            while queue:
                seq = queue.popleft()
                if seq.block_table:  # Only deallocate if blocks are allocated
                    self.scheduler.block_manager.deallocate(seq)

    def generate(
        self,
        prompts: list[str] | list[list[int]],
        sampling_params: SamplingParams | list[SamplingParams],
        use_tqdm: bool = True,
        unconditional_prompts: list[str] | list[list[int]] | None = None,
    ) -> list[dict]:
        """Generate completions for ``prompts``, one call at a time.

        Returns:
            One ``{"text": str, "token_ids": list[int]}`` dict per finished
            request, ordered by sequence id.
        """
        # Serialize access to the engine to prevent concurrent corruption of
        # scheduler state, block manager, CUDA graph buffers, and KV cache.
        # This is the primary defense against the intermittent CUDA device-side
        # assertion error that occurs in concurrent serving scenarios.
        with self._generate_lock:
            return self._generate_impl(prompts, sampling_params, use_tqdm, unconditional_prompts)

    def _generate_impl(
        self,
        prompts: list[str] | list[list[int]],
        sampling_params: SamplingParams | list[SamplingParams],
        use_tqdm: bool = True,
        unconditional_prompts: list[str] | list[list[int]] | None = None,
    ) -> list[dict]:
        """Unlocked body of ``generate()``; see there for the contract."""
        # Clean up any residual state from previous interrupted generations
        # This prevents 'deque index out of range' errors from accumulated block leaks
        if not self.is_finished():
            self.reset()

        if not isinstance(sampling_params, list):
            sampling_params = [sampling_params] * len(prompts)
        if unconditional_prompts is None:
            unconditional_prompts = [None] * len(prompts)

        pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True) if use_tqdm else None
        outputs = {}
        prefill_throughput = decode_throughput = 0.
        try:
            # Queue requests inside the try block so that a failure while
            # tokenizing/queuing also closes the progress bar and releases
            # whatever was already queued.
            for prompt, sp, uncond_prompt in zip(prompts, sampling_params, unconditional_prompts):
                self.add_request(prompt, sp, uncond_prompt)
            while not self.is_finished():
                t = perf_counter()
                output, num_tokens = self.step()
                if pbar is not None:
                    # step() signals a prefill with a positive count and a
                    # decode with a negated one.
                    if num_tokens > 0:
                        prefill_throughput = num_tokens / (perf_counter() - t)
                    else:
                        decode_throughput = -num_tokens / (perf_counter() - t)
                    pbar.set_postfix({
                        "Prefill": f"{int(prefill_throughput)}tok/s",
                        "Decode": f"{int(decode_throughput)}tok/s",
                    })
                for seq_id, token_ids in output:
                    outputs[seq_id] = token_ids
                    if pbar is not None:
                        pbar.update(1)
        except Exception:
            # Clean up on exception to prevent block leaks
            self.reset()
            raise
        finally:
            if pbar is not None:
                pbar.close()

        ordered = [outputs[seq_id] for seq_id in sorted(outputs)]
        return [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in ordered]
|
||||
691
acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py
Normal file
691
acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py
Normal file
|
|
@ -0,0 +1,691 @@
|
|||
import os
|
||||
import pickle
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from multiprocessing.synchronize import Event
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
import sys
|
||||
|
||||
from nanovllm.config import Config
|
||||
from acestep.debug_utils import debug_start, debug_end
|
||||
from nanovllm import distributed as dist_utils
|
||||
|
||||
# Set NANOVLLM_DEBUG=1 in the environment to turn on model-runner tracing.
_DEBUG = os.getenv("NANOVLLM_DEBUG", "0") == "1"


def _debug_log(msg: str):
    """Emit ``msg`` on stdout (flushed immediately) when debug tracing is on."""
    if not _DEBUG:
        return
    print(f"[nanovllm DEBUG] {msg}", flush=True)
|
||||
from nanovllm.engine.sequence import Sequence
|
||||
from nanovllm.models.qwen3 import Qwen3ForCausalLM
|
||||
from nanovllm.layers.sampler import Sampler
|
||||
from nanovllm.utils.context import set_context, get_context, reset_context
|
||||
from nanovllm.utils.loader import load_model
|
||||
|
||||
import socket
|
||||
|
||||
|
||||
def find_available_port(start_port: int = 2333, max_attempts: int = 100) -> int:
    """Find an available port starting from start_port.

    Ports ``start_port``, ``start_port + 1``, ... are probed by briefly binding
    a TCP socket on localhost. The probe socket is closed before returning, so
    another process can still take the port before the caller binds it.

    Args:
        start_port: The starting port number to check
        max_attempts: Maximum number of ports to try

    Returns:
        An available port number

    Raises:
        RuntimeError: If no available port is found within max_attempts
    """
    for port in range(start_port, start_port + max_attempts):
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                if sys.platform == "win32":
                    # On Windows SO_REUSEADDR lets a socket bind a port that
                    # another socket is already using, so the probe would
                    # report busy ports as free. Request exclusive use instead.
                    s.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1)
                else:
                    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                s.bind(('localhost', port))
                return port
        except (OSError, OverflowError):
            # OSError: port is in use (or otherwise not bindable).
            # OverflowError: port number is outside 0-65535.
            # Either way, try the next one.
            continue
    raise RuntimeError(f"Could not find an available port starting from {start_port} after {max_attempts} attempts")
|
||||
|
||||
|
||||
class ModelRunner:
|
||||
|
||||
def __init__(self, config: Config, rank: int, event: Event | list[Event]):
    """Build the model, sampling buffers, KV cache and CUDA graphs for one rank.

    Args:
        config: Engine configuration. ``hf_config``, ``kvcache_block_size``,
            ``enforce_eager``, ``tensor_parallel_size`` and ``model`` are
            read here; other fields are read by the helpers this calls.
        rank: Index of this worker. It is also used as the CUDA device index.
        event: On rank 0 an iterable of events (each is set in
            ``write_shm``); on other ranks the single event that
            ``read_shm`` waits on.

    While the model is being built the process-wide default device/dtype are
    switched to CUDA and the resolved model dtype. They are restored in a
    ``finally`` block so a failure (for example an out-of-memory error while
    loading or warming up) does not leave the whole process defaulting to CUDA.

    With ``tensor_parallel_size > 1`` ranks other than 0 do not return from
    the constructor until ``loop()`` ends, i.e. until "exit" is dispatched.
    """
    # Enable capturing scalar outputs to avoid graph breaks from Tensor.item() calls
    torch._dynamo.config.capture_scalar_outputs = True

    self.config = config
    hf_config = config.hf_config
    self.block_size = config.kvcache_block_size
    self.enforce_eager = config.enforce_eager
    self.world_size = config.tensor_parallel_size
    self.rank = rank
    self.event = event

    # Only initialize distributed if world_size > 1
    if self.world_size > 1:
        # NOTE(review): every rank probes for a free port on its own here;
        # presumably all ranks are expected to end up with the same port -
        # confirm against how the worker processes are launched.
        dist_port = find_available_port()
        print(f"[debug]dist_port: {dist_port}")
        # Use gloo backend on Windows, nccl on Linux/other platforms
        backend = "gloo" if sys.platform == "win32" else "nccl"
        dist_utils.initialize_distributed(backend, f"tcp://127.0.0.1:{dist_port}", world_size=self.world_size, rank=rank)

    torch.cuda.set_device(rank)
    default_dtype = torch.get_default_dtype()

    # Detect GPU compute capability to determine bfloat16 support
    # Bfloat16 requires Ampere (compute capability >= 8.0) or newer
    gpu_props = torch.cuda.get_device_properties(rank)
    # Use tuple comparison to handle compute capability correctly
    # (e.g., 7.5 < 8.0, 8.0 >= 8.0, 8.6 >= 8.0, etc.)
    supports_bfloat16 = (gpu_props.major, gpu_props.minor) >= (8, 0)

    # Use dtype instead of deprecated torch_dtype
    config_dtype = getattr(hf_config, 'dtype', getattr(hf_config, 'torch_dtype', torch.bfloat16))

    # Validate and convert config_dtype to a valid torch floating-point dtype
    # Default to bfloat16 for CUDA (required for Flash Attention 2) if GPU supports it
    if config_dtype is None:
        config_dtype = torch.bfloat16 if supports_bfloat16 else torch.float16
    elif isinstance(config_dtype, str):
        # Convert string dtype to torch dtype
        dtype_map = {
            'float32': torch.float32,
            'float16': torch.float16,
            'bfloat16': torch.bfloat16,
            'float64': torch.float64,
            'torch.float32': torch.float32,
            'torch.float16': torch.float16,
            'torch.bfloat16': torch.bfloat16,
            'torch.float64': torch.float64,
        }
        config_dtype = dtype_map.get(config_dtype.lower(), torch.bfloat16 if supports_bfloat16 else torch.float16)
    elif not isinstance(config_dtype, torch.dtype) or not config_dtype.is_floating_point:
        # If not a valid floating-point torch dtype, default based on GPU capability
        config_dtype = torch.bfloat16 if supports_bfloat16 else torch.float16

    # Override to float16 if config requested bfloat16 but GPU doesn't support it
    if config_dtype == torch.bfloat16 and not supports_bfloat16:
        print(f"[nanovllm] GPU {gpu_props.name} (compute capability {gpu_props.major}.{gpu_props.minor}) does not support bfloat16. Using float16 instead.", flush=True)
        config_dtype = torch.float16

    self.dtype = config_dtype  # Save for later use

    try:
        # Everything created below without an explicit device/dtype
        # (model weights, KV cache, graph buffers) lands on CUDA in config_dtype.
        torch.set_default_dtype(config_dtype)
        torch.set_default_device("cuda")
        self.model = Qwen3ForCausalLM(hf_config)
        _t0 = debug_start("load_model", prefix="tensor.vllm")
        load_model(self.model, config.model)
        debug_end("load_model", _t0, prefix="tensor.vllm")
        self.sampler = Sampler()

        # Pre-allocate buffers for sampling (optimization: avoid repeated tensor creation)
        # Must be called before warmup_model() since it uses these buffers
        self._allocate_sample_buffers()

        self.warmup_model()
        self.allocate_kv_cache()
        if not self.enforce_eager:
            self.capture_cudagraph()
    finally:
        # Restore the process-wide defaults even when construction fails.
        torch.set_default_device("cpu")
        torch.set_default_dtype(default_dtype)

    if self.world_size > 1:
        if rank == 0:
            # Rank 0 creates the segment; the barrier guarantees it exists
            # before the other ranks try to attach to it.
            self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
            dist_utils.barrier()
        else:
            dist_utils.barrier()
            self.shm = SharedMemory(name="nanovllm")
            self.loop()
|
||||
|
||||
def _allocate_sample_buffers(self):
    """Create the pinned CPU staging tensors stored on ``self._cpu_*``.

    Sizes come from ``config.max_num_seqs``, ``config.max_num_batched_tokens``
    and ``config.max_model_len``; every buffer is zero-initialised.
    """
    _t0 = debug_start("_allocate_sample_buffers", prefix="tensor.vllm")
    cfg = self.config
    batch_cap = cfg.max_num_seqs
    token_cap = cfg.max_num_batched_tokens
    blocks_per_seq = (cfg.max_model_len + self.block_size - 1) // self.block_size

    def pinned(shape, dtype):
        # device="cpu" is spelled out because the default device may be "cuda"
        # at the time this runs.
        return torch.zeros(*shape, dtype=dtype, device="cpu", pin_memory=True)

    layout = (
        # per-sequence sampling parameters
        ("_cpu_temperatures", (batch_cap,), torch.float32),
        ("_cpu_cfg_scales", (batch_cap,), torch.float32),
        ("_cpu_top_ks", (batch_cap,), torch.int32),
        ("_cpu_top_ps", (batch_cap,), torch.float32),
        ("_cpu_repetition_penalties", (batch_cap,), torch.float32),
        # decode inputs (one entry per sequence)
        ("_cpu_input_ids", (batch_cap,), torch.int64),
        ("_cpu_positions", (batch_cap,), torch.int64),
        ("_cpu_slot_mapping", (batch_cap,), torch.int32),
        ("_cpu_context_lens", (batch_cap,), torch.int32),
        # prefill inputs (one entry per batched token)
        ("_cpu_prefill_input_ids", (token_cap,), torch.int64),
        ("_cpu_prefill_positions", (token_cap,), torch.int64),
        ("_cpu_prefill_cu_seqlens", (batch_cap + 1,), torch.int32),
        ("_cpu_prefill_slot_mapping", (token_cap,), torch.int32),
        # block tables
        ("_cpu_block_tables", (batch_cap, blocks_per_seq), torch.int32),
        # full token ids per sequence, up to max_model_len each
        ("_seq_token_ids_buffer", (batch_cap, cfg.max_model_len), torch.int64),
    )
    for attr_name, shape, dtype in layout:
        setattr(self, attr_name, pinned(shape, dtype))

    debug_end("_allocate_sample_buffers", _t0, prefix="tensor.vllm")
|
||||
|
||||
def exit(self):
    """Release the shared-memory handle, CUDA graphs and process group.

    ``loop`` stops after dispatching a method named "exit", so on worker
    ranks this is the last call handled.
    """
    if self.world_size > 1:
        # Every rank closes its own handle; after the barrier rank 0 (which
        # created the segment in __init__) unlinks it.
        self.shm.close()
        dist_utils.barrier()
        if self.rank == 0:
            self.shm.unlink()
    if not self.enforce_eager:
        # These attributes are only set by capture_cudagraph().
        del self.graphs, self.graph_pool
    torch.cuda.synchronize()
    dist_utils.destroy_process_group()
|
||||
|
||||
def loop(self):
    """Serve calls read from shared memory until "exit" has been handled.

    Each iteration blocks in ``read_shm`` for the next ``(method_name, args)``
    pair and dispatches it through ``call``; the "exit" call itself is
    dispatched before the loop ends.
    """
    method_name = None
    while method_name != "exit":
        method_name, args = self.read_shm()
        self.call(method_name, *args)
|
||||
|
||||
def read_shm(self):
    """Wait for the next call written by rank 0 and decode it.

    The shared buffer holds a 4-byte little-endian payload length followed by
    a pickled ``[method_name, *args]`` list.

    Returns:
        ``(method_name, args)`` where ``args`` is a list.
    """
    assert self.world_size > 1 and self.rank > 0
    self.event.wait()
    buf = self.shm.buf
    payload_end = 4 + int.from_bytes(buf[:4], "little")
    # The payload is produced by write_shm in the rank-0 process of the same
    # engine; pickle must never be pointed at data from an untrusted source.
    message = pickle.loads(buf[4:payload_end])
    method_name, *args = message
    # Re-arm the event so the next wait() blocks until the next write.
    self.event.clear()
    return method_name, args
|
||||
|
||||
def write_shm(self, method_name, *args):
    """Publish a call to the worker ranks through shared memory.

    The buffer layout is a 4-byte little-endian payload length followed by
    the pickled ``[method_name, *args]`` list. After writing, every event in
    ``self.event`` is set.

    Args:
        method_name: Name of the method the workers should invoke.
        *args: Positional arguments for that method; they must be picklable.

    Raises:
        ValueError: If the pickled call does not fit in the shared buffer.
            Nothing is written and no event is set in that case.
    """
    assert self.world_size > 1 and self.rank == 0
    data = pickle.dumps([method_name, *args])
    n = len(data)
    capacity = len(self.shm.buf) - 4  # 4 bytes are reserved for the length header
    if n > capacity:
        # Check before touching the buffer: a partial write would leave a
        # length header that does not match the payload behind it.
        raise ValueError(
            f"Pickled call to {method_name!r} is {n} bytes but the shared "
            f"memory buffer only holds {capacity} bytes"
        )
    self.shm.buf[0:4] = n.to_bytes(4, "little")
    self.shm.buf[4:n+4] = data
    for event in self.event:
        event.set()
|
||||
|
||||
def call(self, method_name, *args):
    """Invoke ``self.<method_name>(*args)`` and return its result.

    When this is rank 0 of a multi-rank setup, the call is first published
    through ``write_shm``.
    """
    is_broadcasting_rank = self.world_size > 1 and self.rank == 0
    if is_broadcasting_rank:
        self.write_shm(method_name, *args)
    handler = getattr(self, method_name, None)
    return handler(*args)
|
||||
|
||||
def warmup_model(self):
    """Run one dummy prefill pass at the configured maximum size.

    The CUDA cache is emptied and the peak-memory statistics are reset first;
    the batch consists of all-zero sequences of ``max_model_len`` tokens, as
    many as fit in ``max_num_batched_tokens`` (capped at ``max_num_seqs``).
    The cache is emptied again afterwards.
    """
    _t0 = debug_start("warmup_model", prefix="tensor.vllm")
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    cfg = self.config
    seq_len = cfg.max_model_len
    batch = min(cfg.max_num_batched_tokens // seq_len, cfg.max_num_seqs)
    dummy_seqs = []
    for _ in range(batch):
        dummy_seqs.append(Sequence([0] * seq_len))
    self.run(dummy_seqs, True)

    torch.cuda.empty_cache()
    debug_end("warmup_model", _t0, prefix="tensor.vllm")
|
||||
|
||||
def allocate_kv_cache(self):
    """Size and allocate the paged KV cache and attach it to the model.

    The number of blocks is derived from the memory that is currently free,
    capped by ``config.gpu_memory_utilization`` and by a 1 GB reserve, and
    is written to ``config.num_kvcache_blocks``. Every module that exposes
    both ``k_cache`` and ``v_cache`` then receives its slice of the cache.
    """
    _t0 = debug_start("allocate_kv_cache", prefix="tensor.vllm")
    config = self.config
    hf_config = config.hf_config
    free, total = torch.cuda.mem_get_info()
    current = torch.cuda.memory_stats()["allocated_bytes.all.current"]

    # Account for per-process memory fraction (set via MAX_CUDA_VRAM simulation)
    import os as _os
    _debug_vram = _os.environ.get("MAX_CUDA_VRAM")
    if _debug_vram is not None:
        try:
            _simulated_gb = float(_debug_vram)
            _total_gb = total / (1024 ** 3)
            if _simulated_gb < _total_gb:
                # Effective total and free are capped by simulation
                reserved = torch.cuda.memory_reserved()
                total = int(_simulated_gb * (1024 ** 3))
                free = max(0, total - reserved)
        except (ValueError, TypeError):
            # A MAX_CUDA_VRAM value that is not a number is ignored.
            pass

    num_kv_heads = hf_config.num_key_value_heads // self.world_size
    head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
    # Bytes of one cache block across all layers; the leading 2 is keys + values.
    block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * self.dtype.itemsize

    # Calculate available memory for KV cache
    # After warmup_model, empty_cache has been called, so current represents model memory only
    # Use free memory but respect the gpu_memory_utilization limit
    target_total_usage = total * config.gpu_memory_utilization
    available_for_kv_cache = min(free * 0.9, target_total_usage - current)

    # Safety check: ensure we leave at least ~1 GB free for DiT inference
    # activations that will run after LM generation. Without this, the KV
    # cache can consume all free VRAM and cause OOM during DiT forward pass.
    MIN_RESERVE_BYTES = int(1.0 * 1024**3)  # 1 GB reserved for other models
    max_kv_from_free = max(0, free - MIN_RESERVE_BYTES) * 0.9
    available_for_kv_cache = min(available_for_kv_cache, max_kv_from_free)

    # Ensure we have positive memory available
    if available_for_kv_cache <= 0:
        available_for_kv_cache = free * 0.5  # Fallback to 50% of free memory

    config.num_kvcache_blocks = max(1, int(available_for_kv_cache) // block_bytes)
    # NOTE(review): max(1, ...) above means the count is never <= 0, so the
    # RuntimeError below cannot be raised as written.
    if config.num_kvcache_blocks <= 0:
        raise RuntimeError(
            f"Insufficient GPU memory for KV cache. "
            f"Free: {free / 1024**3:.2f} GB, Current: {current / 1024**3:.2f} GB, "
            f"Available for KV: {available_for_kv_cache / 1024**3:.2f} GB, "
            f"Block size: {block_bytes / 1024**2:.2f} MB"
        )
    max_tokens_capacity = config.num_kvcache_blocks * self.block_size
    kv_cache_size_gb = config.num_kvcache_blocks * block_bytes / 1024**3

    # If KV cache would leave less than 1 GB free, warn and suggest reducing max_model_len
    post_kv_free = (free - config.num_kvcache_blocks * block_bytes) / 1024**3
    if post_kv_free < 1.0:
        print(
            f"[nanovllm] WARNING: After KV cache allocation, only {post_kv_free:.2f} GB free. "
            f"DiT inference may OOM. Consider reducing max_model_len or using CPU offload."
        )

    print(
        f"[nanovllm] KV cache allocated: {config.num_kvcache_blocks} blocks × {self.block_size} tokens = "
        f"{max_tokens_capacity} tokens capacity, {kv_cache_size_gb:.2f} GB "
        f"(free: {free / 1024**3:.2f} GB, used: {current / 1024**3:.2f} GB, "
        f"target: {target_total_usage / 1024**3:.2f} GB, block: {block_bytes / 1024**2:.2f} MB, "
        f"post_kv_free: {post_kv_free:.2f} GB)"
    )
    # No dtype/device is passed: __init__ calls this while the default device
    # is "cuda" and the default dtype is the model dtype.
    self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
    layer_id = 0
    for module in self.model.modules():
        if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
            # Index 0 holds keys, index 1 holds values; one slice per matching module.
            module.k_cache = self.kv_cache[0, layer_id]
            module.v_cache = self.kv_cache[1, layer_id]
            layer_id += 1
    debug_end("allocate_kv_cache", _t0, prefix="tensor.vllm")
|
||||
|
||||
def prepare_block_tables(self, seqs: list[Sequence]):
    """Return the sequences' block tables as one int32 CUDA tensor.

    Rows shorter than the longest block table are right-padded with -1.
    """
    _t0 = debug_start("prepare_block_tables", prefix="tensor.vllm")
    widest = max(len(seq.block_table) for seq in seqs)
    padded_rows = []
    for seq in seqs:
        padding = [-1] * (widest - len(seq.block_table))
        padded_rows.append(seq.block_table + padding)
    host_tables = torch.tensor(padded_rows, dtype=torch.int32, pin_memory=True)
    block_tables = host_tables.cuda(non_blocking=True)
    debug_end("prepare_block_tables", _t0, prefix="tensor.vllm")
    return block_tables
|
||||
|
||||
def prepare_prefill(self, seqs: list[Sequence]):
    """Flatten a prefill batch into model inputs and publish the attention context.

    Only the tokens that are not already cached (``seq.num_cached_tokens``
    onwards) become query tokens; key lengths cover the whole sequence.

    Returns:
        ``(input_ids, positions)`` as int64 CUDA tensors, concatenated over
        all sequences. The cumulative sequence lengths, slot mapping and
        (when a prefix is cached) block tables are passed to ``set_context``.
    """
    _t0 = debug_start("prepare_prefill", prefix="tensor.vllm")
    input_ids = []
    positions = []
    cu_seqlens_q = [0]
    cu_seqlens_k = [0]
    max_seqlen_q = 0
    max_seqlen_k = 0
    slot_mapping = []
    block_tables = None
    for seq in seqs:
        seqlen = len(seq)
        input_ids.extend(seq[seq.num_cached_tokens:])
        positions.extend(list(range(seq.num_cached_tokens, seqlen)))
        seqlen_q = seqlen - seq.num_cached_tokens
        seqlen_k = seqlen
        cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
        cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
        max_seqlen_q = max(seqlen_q, max_seqlen_q)
        max_seqlen_k = max(seqlen_k, max_seqlen_k)
        if not seq.block_table:    # warmup
            # Sequences without allocated blocks contribute no cache slots.
            continue
        for i in range(seq.num_cached_blocks, seq.num_blocks):
            start = seq.block_table[i] * self.block_size
            if i != seq.num_blocks - 1:
                end = start + self.block_size
            else:
                # The last block may be only partially filled.
                end = start + seq.last_block_num_tokens
            slot_mapping.extend(list(range(start, end)))
    if cu_seqlens_k[-1] > cu_seqlens_q[-1]:    # prefix cache
        # More key tokens than query tokens means some tokens were skipped as
        # cached, so block tables are built for this batch.
        block_tables = self.prepare_block_tables(seqs)
    input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
    positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
    cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
    cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
    set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
    debug_end("prepare_prefill", _t0, prefix="tensor.vllm")
    return input_ids, positions
|
||||
|
||||
def prepare_decode(self, seqs: list[Sequence]):
    """Build single-token decode inputs using the pre-allocated pinned buffers.

    For each sequence the last token is the input, its position is
    ``len(seq) - 1`` and its cache slot is the last used slot of the
    sequence's final block.

    Returns:
        ``(input_ids, positions)`` as CUDA tensors with one entry per
        sequence. Slot mapping, context lengths and block tables are passed
        to ``set_context``.
    """
    _t0 = debug_start("prepare_decode", prefix="tensor.vllm")
    bs = len(seqs)

    # Use pre-allocated CPU buffers
    for i, seq in enumerate(seqs):
        self._cpu_input_ids[i] = seq.last_token
        self._cpu_positions[i] = len(seq) - 1
        self._cpu_context_lens[i] = len(seq)
        self._cpu_slot_mapping[i] = seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1

    # Transfer to GPU using sliced views
    input_ids = self._cpu_input_ids[:bs].cuda(non_blocking=True)
    positions = self._cpu_positions[:bs].cuda(non_blocking=True)
    slot_mapping = self._cpu_slot_mapping[:bs].cuda(non_blocking=True)
    context_lens = self._cpu_context_lens[:bs].cuda(non_blocking=True)
    block_tables = self.prepare_block_tables(seqs)
    set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
    debug_end("prepare_decode", _t0, prefix="tensor.vllm")
    return input_ids, positions
|
||||
|
||||
def prepare_sample(self, seqs: list[Sequence], is_cfg_batch: bool = False):
    """Stage per-sequence sampling parameters and copy them to the GPU.

    Args:
        seqs: Scheduled sequences. For a CFG batch only the first half (the
            conditional sequences) contributes parameters.
        is_cfg_batch: Whether ``seqs`` is laid out as ``[cond..., uncond...]``.

    Returns:
        ``(temperatures, cfg_scales, top_ks, top_ps, repetition_penalties)``.
        The last three are ``None`` when no sequence needs them, i.e. when
        every ``top_k`` is unset or <= 0, every ``top_p`` is unset or 1.0,
        and every ``repetition_penalty`` is unset or 1.0 respectively.
    """
    _t0 = debug_start("prepare_sample", prefix="tensor.vllm")
    target_seqs = seqs[:len(seqs) // 2] if is_cfg_batch else seqs
    num_seqs = len(target_seqs)

    # Fill pre-allocated CPU buffers and note which optional filters are
    # actually requested by at least one sequence.
    uses_top_k = False
    uses_top_p = False
    uses_repetition_penalty = False
    for i, seq in enumerate(target_seqs):
        top_k = seq.top_k if seq.top_k is not None else 0
        top_p = seq.top_p if seq.top_p is not None else 1.0
        penalty = seq.repetition_penalty if seq.repetition_penalty is not None else 1.0

        self._cpu_temperatures[i] = seq.temperature
        self._cpu_cfg_scales[i] = seq.cfg_scale
        self._cpu_top_ks[i] = top_k
        self._cpu_top_ps[i] = top_p
        self._cpu_repetition_penalties[i] = penalty

        uses_top_k = uses_top_k or top_k > 0
        # A value of exactly 1.0 is the neutral setting; anything else must
        # reach the sampler / penalty step.
        uses_top_p = uses_top_p or top_p != 1.0
        uses_repetition_penalty = uses_repetition_penalty or penalty != 1.0

    # Transfer to GPU using sliced views
    temperatures = self._cpu_temperatures[:num_seqs].cuda(non_blocking=True)
    cfg_scales = self._cpu_cfg_scales[:num_seqs].cuda(non_blocking=True)
    top_ks = self._cpu_top_ks[:num_seqs].cuda(non_blocking=True) if uses_top_k else None
    top_ps = self._cpu_top_ps[:num_seqs].cuda(non_blocking=True) if uses_top_p else None
    repetition_penalties = (
        self._cpu_repetition_penalties[:num_seqs].cuda(non_blocking=True)
        if uses_repetition_penalty else None
    )

    debug_end("prepare_sample", _t0, prefix="tensor.vllm")
    return temperatures, cfg_scales, top_ks, top_ps, repetition_penalties


@torch.inference_mode()
def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
    """Run one forward pass and return the logits.

    Decode batches are replayed through a captured CUDA graph when possible.
    The model is called eagerly instead when any of these hold: the batch is
    a prefill, ``enforce_eager`` is set, the batch has more than 512 rows,
    the context tensors do not match the batch size, the block tables are
    wider than the graph buffer, or no captured graph is large enough.

    Args:
        input_ids: Token ids for this step.
        positions: Position of each token in ``input_ids``.
        is_prefill: Whether this is a prefill (True) or decode (False) step.
    """
    _t0 = debug_start("run_model", prefix="tensor.vllm")
    bs = input_ids.size(0)
    context = None
    graph = None

    if is_prefill or self.enforce_eager or bs > 512:
        _debug_log(f"run_model: eager mode, is_prefill={is_prefill}, bs={bs}")
    else:
        context = get_context()

        if _DEBUG:
            # Guarded explicitly: the f-strings below call .tolist() on GPU
            # tensors, which would otherwise force a device sync on every
            # decode step even with debugging switched off.
            _debug_log(f"run_model: decode mode, bs={bs}")
            _debug_log(f" context.block_tables.shape={context.block_tables.shape}")
            _debug_log(f" context.slot_mapping.shape={context.slot_mapping.shape}")
            _debug_log(f" context.context_lens.shape={context.context_lens.shape}")
            _debug_log(f" context.slot_mapping={context.slot_mapping.tolist()}")
            _debug_log(f" context.context_lens={context.context_lens.tolist()}")

        max_num_blocks = self.graph_vars["block_tables"].size(1)
        # Smallest captured batch size that can hold this batch, if any.
        graph_bs = next((x for x in self.graph_bs if x >= bs), None)

        if context.block_tables.size(1) > max_num_blocks:
            # Can happen when conditional and unconditional sequences have
            # different lengths in CFG mode.
            _debug_log(f" fallback: block_tables cols {context.block_tables.size(1)} > max {max_num_blocks}")
        elif context.block_tables.size(0) != bs:
            # A row-count mismatch can cause illegal memory access during replay.
            _debug_log(f" fallback: block_tables rows {context.block_tables.size(0)} != bs {bs}")
        elif context.slot_mapping.size(0) != bs or context.context_lens.size(0) != bs:
            _debug_log(f" fallback: slot_mapping/context_lens size mismatch")
        elif graph_bs is None:
            # The captured sizes do not necessarily reach max_num_seqs.
            _debug_log(f" fallback: no captured CUDA graph fits bs {bs}")
        else:
            graph = self.graphs[graph_bs]

    if graph is None:
        out = self.model.compute_logits(self.model(input_ids, positions))
        debug_end("run_model", _t0, prefix="tensor.vllm")
        return out

    # Validate block_tables values
    if _DEBUG:
        max_block_id = context.block_tables.max().item()
        min_block_id = context.block_tables[context.block_tables >= 0].min().item() if (context.block_tables >= 0).any() else -1
        _debug_log(f" block_tables range: [{min_block_id}, {max_block_id}]")
        _debug_log(f" num_kvcache_blocks: {self.config.num_kvcache_blocks}")
        if max_block_id >= self.config.num_kvcache_blocks:
            _debug_log(f" WARNING: block_table contains invalid block_id {max_block_id} >= {self.config.num_kvcache_blocks}")

    # Copy this step's inputs into the static buffers the graph was captured on.
    graph_vars = self.graph_vars
    graph_vars["input_ids"][:bs] = input_ids
    graph_vars["positions"][:bs] = positions
    graph_vars["slot_mapping"].fill_(-1)
    graph_vars["slot_mapping"][:bs] = context.slot_mapping
    graph_vars["context_lens"].zero_()
    graph_vars["context_lens"][:bs] = context.context_lens
    # Clear block_tables first to ensure no stale data from previous runs
    graph_vars["block_tables"][:bs].fill_(-1)
    graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables

    _debug_log(f" executing CUDA graph replay for bs={bs}")
    graph.replay()
    out = self.model.compute_logits(graph_vars["outputs"][:bs])
    debug_end("run_model", _t0, prefix="tensor.vllm")
    return out


def _apply_repetition_penalty(self, logits: torch.Tensor, seqs: list[Sequence], repetition_penalties):
    """Penalise, in place, the tokens each sequence has already generated.

    Uses the usual formula: for a token that appears in the completion, a
    negative score is multiplied by the penalty and a non-negative score is
    divided by it. Prompt tokens are not penalised.

    Args:
        logits: Writable ``[len(seqs), vocab]`` tensor. It must not be an
            inference-mode tensor, because rows are assigned in place.
        seqs: Sequences matching the rows of ``logits``.
        repetition_penalties: Per-row penalties, or ``None`` to do nothing.
    """
    if repetition_penalties is None:
        return
    vocab_size = logits.shape[1]
    # One device-to-host copy for the whole batch instead of one per row.
    for i, (seq, penalty) in enumerate(zip(seqs, repetition_penalties.tolist())):
        if penalty == 1.0:
            continue
        # Only penalize completion tokens (not prompt tokens)
        completion_tokens = torch.tensor(seq.completion_token_ids, device=logits.device)
        if len(completion_tokens) == 0:
            continue
        token_mask = torch.zeros(vocab_size, dtype=torch.bool, device=logits.device)
        token_mask[completion_tokens] = True
        row = logits[i]
        penalty_scores = torch.where(row < 0, row * penalty, row / penalty)
        logits[i] = torch.where(token_mask, penalty_scores, row)


def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
    """Run model forward and sampling for one scheduled batch.

    For CFG sequences, the batch is structured as
    ``[cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]`` where
    ``uncond_seqi`` is the paired unconditional sequence of ``cond_seqi``.
    Only the conditional half is sampled; the caller applies the resulting
    tokens to both halves.

    Args:
        seqs: Sequences to process; must not be empty.
        is_prefill: Whether this is a prefill (True) or decode (False) step.

    Returns:
        The sampled token ids (one per non-CFG sequence, or one per
        conditional sequence for a CFG batch) on rank 0, ``None`` on every
        other rank.
    """
    if _DEBUG:
        # Guarded so the per-sequence formatting is skipped when debugging is off.
        _debug_log(f"run: num_seqs={len(seqs)}, is_prefill={is_prefill}")
        for i, seq in enumerate(seqs):
            _debug_log(f" seq[{i}]: len={len(seq)}, num_blocks={seq.num_blocks}, "
                       f"cfg_scale={seq.cfg_scale}, is_uncond={seq.is_unconditional}, "
                       f"block_table={seq.block_table}")

    # Check if this is a CFG batch (contains paired conditional and unconditional sequences)
    is_cfg_batch = seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None
    _debug_log(f" is_cfg_batch={is_cfg_batch}")

    # The forward pass always covers the whole batch (cond + uncond for CFG).
    input_ids, positions = (self.prepare_prefill(seqs) if is_prefill
                            else self.prepare_decode(seqs))
    sample_params = self.prepare_sample(seqs, is_cfg_batch=is_cfg_batch) if self.rank == 0 else None
    logits_all = self.run_model(input_ids, positions, is_prefill)
    reset_context()

    if self.rank != 0:
        return None

    temperatures, cfg_scales, top_ks, top_ps, repetition_penalties = sample_params

    if is_cfg_batch:
        num_cond = len(seqs) // 2
        target_seqs = seqs[:num_cond]
        # run_model produces inference-mode tensors, which reject in-place
        # writes outside inference mode; clone the half that gets penalised.
        logits_cond = logits_all[:num_cond].clone()
        logits_uncond = logits_all[num_cond:]
        # Repetition penalty is applied to the conditional logits before CFG.
        self._apply_repetition_penalty(logits_cond, target_seqs, repetition_penalties)
        # CFG formula: logits_cfg = logits_uncond + cfg_scale * (logits_cond - logits_uncond)
        cfg_scales_tensor = cfg_scales.unsqueeze(1)  # [num_cond, 1]
        logits = logits_uncond + cfg_scales_tensor * (logits_cond - logits_uncond)
    else:
        target_seqs = seqs
        # Clone first for the same inference-mode reason, then penalise.
        logits = logits_all.clone()
        self._apply_repetition_penalty(logits, target_seqs, repetition_penalties)

    # Apply logits processor for constrained decoding (if any sequence has one)
    for i, seq in enumerate(target_seqs):
        if seq.logits_processor is not None:
            seq_input_ids = torch.tensor([seq.token_ids], device=logits.device)
            processed = seq.logits_processor(seq_input_ids, logits[i:i+1].clone())
            logits[i] = processed[0]

    token_ids = self.sampler(
        logits,
        temperatures,
        top_ks=top_ks,
        top_ps=top_ps,
        repetition_penalties=None,  # Already applied above
    ).tolist()

    # Update logits processor state after sampling
    # NOTE: Only update for the first sequence since all sequences may share the same processor
    # (when using a single SamplingParams for batch generation)
    # Updating multiple times would cause duplicate state updates (e.g., codes_count += N instead of += 1)
    if target_seqs and target_seqs[0].logits_processor_update_state is not None:
        target_seqs[0].logits_processor_update_state(token_ids[0])

    return token_ids
|
||||
|
||||
@torch.inference_mode()
def capture_cudagraph(self):
    """Capture one decode CUDA graph per batch size in ``self.graph_bs``.

    All graphs are captured on the same static input/output tensors, which
    are stored in ``self.graph_vars`` so ``run_model`` can fill them before
    replaying a graph. The graphs are stored in ``self.graphs`` keyed by
    batch size.
    """
    _t0 = debug_start("capture_cudagraph", prefix="tensor.vllm")
    config = self.config
    hf_config = config.hf_config
    max_bs = min(self.config.max_num_seqs, 512)
    max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
    # No device is passed: __init__ calls this while the default device is "cuda".
    input_ids = torch.zeros(max_bs, dtype=torch.int64)
    positions = torch.zeros(max_bs, dtype=torch.int64)
    slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
    context_lens = torch.zeros(max_bs, dtype=torch.int32)
    block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
    outputs = torch.zeros(max_bs, hf_config.hidden_size)
    # NOTE(review): the largest entry here can be smaller than max_bs (for
    # example when max_bs is not a multiple of 16) and the fixed sizes
    # 1/2/4/8 can exceed it - confirm this is intended.
    self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
    self.graphs = {}
    self.graph_pool = None

    # Largest batch size first; the pool of the first captured graph is
    # passed to every later capture.
    for bs in reversed(self.graph_bs):
        graph = torch.cuda.CUDAGraph()
        set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
        outputs[:bs] = self.model(input_ids[:bs], positions[:bs])    # warmup
        with torch.cuda.graph(graph, self.graph_pool):
            outputs[:bs] = self.model(input_ids[:bs], positions[:bs])    # capture
        if self.graph_pool is None:
            self.graph_pool = graph.pool()
        self.graphs[bs] = graph
        torch.cuda.synchronize()
        reset_context()

    self.graph_vars = dict(
        input_ids=input_ids,
        positions=positions,
        slot_mapping=slot_mapping,
        context_lens=context_lens,
        block_tables=block_tables,
        outputs=outputs,
    )
    debug_end("capture_cudagraph", _t0, prefix="tensor.vllm")
|
||||
272
acestep/third_parts/nano-vllm/nanovllm/engine/scheduler.py
Normal file
272
acestep/third_parts/nano-vllm/nanovllm/engine/scheduler.py
Normal file
|
|
@ -0,0 +1,272 @@
|
|||
import os
|
||||
from collections import deque
|
||||
|
||||
from nanovllm.config import Config
|
||||
from nanovllm.engine.sequence import Sequence, SequenceStatus
|
||||
from nanovllm.engine.block_manager import BlockManager
|
||||
|
||||
# Scheduler tracing is opt-in: export NANOVLLM_DEBUG=1 before the process starts.
_DEBUG = os.getenv("NANOVLLM_DEBUG", "0") == "1"


def _debug_log(msg: str):
    """Print ``msg`` with the scheduler debug prefix; do nothing unless tracing is on."""
    if not _DEBUG:
        return
    print(f"[nanovllm scheduler DEBUG] {msg}", flush=True)
|
||||
|
||||
|
||||
class Scheduler:
    """Choose which sequences run in the next model step.

    Sequences live in one of two queues: ``waiting`` (needs a prefill pass)
    and ``running`` (being decoded token by token). Classifier-free-guidance
    (CFG) requests arrive as a conditional/unconditional pair linked through
    ``paired_seq``; both halves of a pair are admitted, preempted and finished
    together.
    """

    def __init__(self, config: Config):
        """Copy the batching limits from ``config`` and create the block manager."""
        self.max_num_seqs = config.max_num_seqs
        self.max_num_batched_tokens = config.max_num_batched_tokens
        self.eos = config.eos
        self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
        self.waiting: deque[Sequence] = deque()
        self.running: deque[Sequence] = deque()

    def is_finished(self):
        """Return True when neither queue holds a sequence."""
        return not self.waiting and not self.running

    def add(self, seq: Sequence):
        """Append ``seq`` to the back of the waiting queue."""
        self.waiting.append(seq)

    def schedule(self) -> tuple[list[Sequence], bool]:
        """Build the batch for the next step.

        Prefill has priority: if any waiting sequence can be admitted, only
        prefill work is returned. Otherwise a decode batch is built from the
        running queue, preempting other running sequences when blocks run out.

        Returns:
            ``(sequences, is_prefill)``. The batch is ordered non-CFG first,
            then CFG conditional, then CFG unconditional sequences.

        Raises:
            RuntimeError: if nothing can be scheduled.
        """
        _debug_log(f"schedule: waiting={len(self.waiting)}, running={len(self.running)}, "
                   f"free_blocks={len(self.block_manager.free_block_ids)}")

        batch = self._schedule_prefill()
        if batch:
            return self._order_for_cfg(batch), True

        batch = self._schedule_decode()
        if not batch:
            self._raise_unschedulable()

        batch = self._order_for_cfg(batch)
        # Scheduled sequences were taken out of ``running`` while the batch was
        # built; put them back at the front, in batch order.
        self.running.extendleft(reversed(batch))
        return batch, False

    @staticmethod
    def _is_cfg_conditional(seq) -> bool:
        """True for the conditional half of a CFG pair (the half that drives pairing)."""
        return seq.cfg_scale > 1.0 and seq.paired_seq is not None and not seq.is_unconditional

    @staticmethod
    def _order_for_cfg(seqs):
        """Reorder a batch as: non-CFG, CFG conditional, CFG unconditional."""
        non_cfg = [s for s in seqs if s.cfg_scale <= 1.0]
        conditional = [s for s in seqs if s.cfg_scale > 1.0 and not s.is_unconditional]
        unconditional = [s for s in seqs if s.is_unconditional]
        return non_cfg + conditional + unconditional

    def _has_free_blocks(self, count: int) -> bool:
        """True when at least ``count`` KV-cache blocks are free."""
        return len(self.block_manager.free_block_ids) >= count

    def _schedule_prefill(self):
        """Admit sequences from the head of ``waiting`` until a limit is hit."""
        batch = []
        num_batched_tokens = 0
        handled_ids = set()  # ids already admitted as part of a CFG pair

        # NOTE(review): the limit is checked before a pair is admitted, so a
        # CFG pair can push the batch one past max_num_seqs.
        while self.waiting and len(batch) < self.max_num_seqs:
            seq = self.waiting[0]

            if self._is_cfg_conditional(seq):
                partner = seq.paired_seq
                if partner.status != SequenceStatus.WAITING:
                    # The unconditional half is not waiting: the pair cannot start yet.
                    break

                pair_tokens = (len(seq) - seq.num_cached_tokens) + (len(partner) - partner.num_cached_tokens)
                # Both halves must fit at once, so compare against their combined block count.
                pair_blocks = seq.num_blocks + partner.num_blocks
                over_token_budget = num_batched_tokens + pair_tokens > self.max_num_batched_tokens
                if over_token_budget or not self._has_free_blocks(pair_blocks):
                    break

                # Conditional first, then unconditional.
                for member in (seq, partner):
                    self.block_manager.allocate(member)
                    num_batched_tokens += len(member) - member.num_cached_tokens
                    member.status = SequenceStatus.RUNNING
                    self.waiting.remove(member)
                    self.running.append(member)
                    batch.append(member)
                    handled_ids.add(member.seq_id)
                continue

            # Plain sequence, or an unconditional half reached on its own.
            if seq.seq_id in handled_ids:
                self.waiting.popleft()
                continue

            if (num_batched_tokens + len(seq) > self.max_num_batched_tokens
                    or not self.block_manager.can_allocate(seq)):
                break

            self.block_manager.allocate(seq)
            num_batched_tokens += len(seq) - seq.num_cached_tokens
            seq.status = SequenceStatus.RUNNING
            self.waiting.popleft()
            self.running.append(seq)
            batch.append(seq)

        return batch

    def _schedule_decode(self):
        """Pick running sequences for one decode step, preempting others if blocks run out."""
        batch = []
        handled_ids = set()
        pending = deque(self.running)  # snapshot; self.running is mutated below

        while pending and len(batch) < self.max_num_seqs:
            seq = pending.popleft()

            if self._is_cfg_conditional(seq):
                partner = seq.paired_seq
                if partner not in pending:
                    # Unconditional half is not decodable this step; skip the pair.
                    continue
                pending.remove(partner)

                # A sequence needs a fresh block only when its length is one
                # past a block boundary (len % block_size == 1).
                block_size = self.block_manager.block_size
                blocks_needed = sum(1 for member in (seq, partner) if len(member) % block_size == 1)

                # Evict not-yet-scheduled sequences until the pair fits.
                while pending and not self._has_free_blocks(blocks_needed):
                    self._preempt_with_partner(pending.popleft(), pending)

                if not self._has_free_blocks(blocks_needed):
                    # Nothing left to evict: send the pair itself back to
                    # ``waiting`` (unconditional first so the conditional ends
                    # up at the head). Re-queuing the pair into ``pending``
                    # here would loop forever.
                    self.preempt(partner)
                    self.preempt(seq)
                    continue

                members = (seq, partner)
            else:
                if seq.seq_id in handled_ids:
                    continue

                evicted_self = False
                while not self.block_manager.can_append(seq):
                    if not pending:
                        self.preempt(seq)
                        evicted_self = True
                        break
                    self._preempt_with_partner(pending.popleft(), pending)
                if evicted_self:
                    continue

                members = (seq,)

            for member in members:
                self.block_manager.may_append(member)
                batch.append(member)
                handled_ids.add(member.seq_id)
                # Re-inserted at the front of ``running`` by schedule().
                if member in self.running:
                    self.running.remove(member)

        return batch

    def _preempt_with_partner(self, victim, pending):
        """Preempt ``victim`` and, for a CFG pair, its other half as well.

        A pair must never be split: a lone half left in ``running`` would be
        decoded without its partner. The partner is also dropped from
        ``pending`` so it is not considered again this step.
        """
        group = [victim]
        partner = victim.paired_seq
        if victim.cfg_scale > 1.0 and partner is not None and partner in pending:
            pending.remove(partner)
            group.append(partner)

        # preempt() pushes to the FRONT of ``waiting``; evicting the
        # unconditional half first leaves the conditional half at the head,
        # which is what the prefill loop needs to re-admit the pair together.
        for member in sorted(group, key=lambda s: s.is_unconditional, reverse=True):
            self.preempt(member)

    def _raise_unschedulable(self):
        """Raise a RuntimeError describing why no sequence could be scheduled."""
        waiting_count = len(self.waiting)
        running_count = len(self.running)
        free_blocks = len(self.block_manager.free_block_ids)
        total_blocks = len(self.block_manager.blocks)

        if waiting_count > 0:
            seq = self.waiting[0]
            blocks_needed = seq.num_blocks
            prompt_tokens = len(seq)
            if seq.cfg_scale > 1.0 and seq.paired_seq is not None:
                blocks_needed += seq.paired_seq.num_blocks
                prompt_tokens = f"{len(seq)}+{len(seq.paired_seq)}"
            raise RuntimeError(
                f"Insufficient KV cache to schedule sequence. "
                f"Free blocks: {free_blocks}/{total_blocks}, blocks needed: {blocks_needed}, "
                f"prompt tokens: {prompt_tokens}, block size: {self.block_manager.block_size}. "
                f"The prompt may be too long for available GPU memory, or gpu_memory_utilization is too low."
            )

        raise RuntimeError(
            f"No schedulable sequences found. "
            f"Waiting: {waiting_count}, Running: {running_count}, "
            f"Free blocks: {free_blocks}/{total_blocks}"
        )

    def preempt(self, seq: Sequence):
        """Free ``seq``'s blocks and move it to the front of ``waiting``.

        The sequence is also removed from ``running``; leaving it there would
        let a later prefill append it a second time.
        """
        seq.status = SequenceStatus.WAITING
        self.block_manager.deallocate(seq)
        if seq in self.running:
            self.running.remove(seq)
        self.waiting.appendleft(seq)

    def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> None:
        """Append the sampled tokens and retire sequences that are done.

        A sequence is done when it produced EOS (unless ``ignore_eos``) or has
        generated ``max_tokens`` completion tokens. In a CFG batch, laid out as
        ``[cond_1..cond_n, uncond_1..uncond_n]``, ``token_ids`` holds one token
        per conditional sequence; the same token is appended to both halves and
        the pair finishes together.

        Returns:
            None. State is updated in place.
        """
        _debug_log(f"postprocess: num_seqs={len(seqs)}, num_token_ids={len(token_ids) if token_ids else 0}")
        if token_ids:
            _debug_log(f" token_ids: {token_ids[:10]}..." if len(token_ids) > 10 else f" token_ids: {token_ids}")

        # A CFG batch starts with a conditional sequence and has an
        # unconditional one at the midpoint.
        num_cond = len(seqs) // 2
        is_cfg_batch = False
        if len(seqs) > 0 and seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None:
            is_cfg_batch = (num_cond > 0 and
                            not seqs[0].is_unconditional and
                            seqs[num_cond].is_unconditional)
        _debug_log(f" is_cfg_batch={is_cfg_batch}")

        if not is_cfg_batch:
            for seq, token_id in zip(seqs, token_ids):
                seq.append_token(token_id)
                if self._should_finish(seq, token_id):
                    seq.status = SequenceStatus.FINISHED
                    self.block_manager.deallocate(seq)
                    self.running.remove(seq)
            return

        for cond_seq, uncond_seq, token_id in zip(seqs[:num_cond], seqs[num_cond:], token_ids):
            cond_seq.append_token(token_id)
            uncond_seq.append_token(token_id)  # both halves share the sampled token

            cond_finished = self._should_finish(cond_seq, token_id)
            uncond_finished = self._should_finish(uncond_seq, token_id)
            if not (cond_finished or uncond_finished):
                continue

            # Either half finishing retires the whole pair.
            cond_seq.status = SequenceStatus.FINISHED
            uncond_seq.status = SequenceStatus.FINISHED
            self.block_manager.deallocate(cond_seq)
            self.block_manager.deallocate(uncond_seq)
            if cond_seq in self.running:
                self.running.remove(cond_seq)
            if uncond_seq in self.running:
                self.running.remove(uncond_seq)

    def _should_finish(self, seq, token_id) -> bool:
        """True when ``token_id`` ends ``seq`` (EOS honoured, or max_tokens reached)."""
        hit_eos = not seq.ignore_eos and token_id == self.eos
        return hit_eos or seq.num_completion_tokens == seq.max_tokens
|
||||
96
acestep/third_parts/nano-vllm/nanovllm/engine/sequence.py
Normal file
96
acestep/third_parts/nano-vllm/nanovllm/engine/sequence.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
from copy import copy
|
||||
from enum import Enum, auto
|
||||
from itertools import count
|
||||
from typing import Optional, Callable, Any
|
||||
|
||||
from nanovllm.sampling_params import SamplingParams
|
||||
|
||||
|
||||
class SequenceStatus(Enum):
    """Lifecycle stage of a :class:`Sequence`."""

    WAITING = auto()
    RUNNING = auto()
    FINISHED = auto()


class Sequence:
    """One generation request: prompt tokens, generated tokens and sampling settings.

    ``len(seq)`` is the total token count (prompt + completion) and
    ``seq[i]`` indexes into the token list.
    """

    # Number of tokens per block; class-level, shared by all instances.
    block_size = 256
    # Source of unique, increasing ``seq_id`` values.
    counter = count()

    def __init__(
        self,
        token_ids: list[int],
        sampling_params: "Optional[SamplingParams]" = None,
        is_unconditional: bool = False,
        conditional_seq: "Optional[Sequence]" = None,
    ):
        """Create a waiting sequence from a non-empty prompt.

        Args:
            token_ids: prompt tokens; copied, the caller's list is not modified.
            sampling_params: sampling settings. ``None`` means a fresh
                ``SamplingParams()`` built at call time (previously one
                instance created at import time was shared by every call).
            is_unconditional: True for the unconditional half of a CFG pair.
            conditional_seq: stored as ``paired_seq``, the other half of the
                CFG pair.

        Raises:
            IndexError: if ``token_ids`` is empty.
        """
        if sampling_params is None:
            sampling_params = SamplingParams()

        self.seq_id = next(Sequence.counter)
        self.status = SequenceStatus.WAITING
        self.token_ids = copy(token_ids)
        self.last_token = token_ids[-1]
        self.num_tokens = len(self.token_ids)
        self.num_prompt_tokens = len(token_ids)
        self.num_cached_tokens = 0
        self.block_table = []
        self.temperature = sampling_params.temperature
        self.max_tokens = sampling_params.max_tokens
        self.ignore_eos = sampling_params.ignore_eos
        self.cfg_scale = sampling_params.cfg_scale
        self.top_k = sampling_params.top_k
        self.top_p = sampling_params.top_p
        self.repetition_penalty = sampling_params.repetition_penalty
        # CFG: whether this is the unconditional half of a pair.
        self.is_unconditional = is_unconditional
        # CFG: the other half of the pair (cond -> uncond, uncond -> cond).
        self.paired_seq = conditional_seq
        # Constrained decoding: logits processor and its state-update callback.
        self.logits_processor: Optional[Any] = sampling_params.logits_processor
        self.logits_processor_update_state: Optional[Callable[[int], None]] = sampling_params.logits_processor_update_state

    def __len__(self):
        return self.num_tokens

    def __getitem__(self, key):
        return self.token_ids[key]

    @property
    def is_finished(self):
        """True once the status is ``FINISHED``."""
        return self.status == SequenceStatus.FINISHED

    @property
    def num_completion_tokens(self):
        """Number of tokens generated after the prompt."""
        return self.num_tokens - self.num_prompt_tokens

    @property
    def prompt_token_ids(self):
        return self.token_ids[:self.num_prompt_tokens]

    @property
    def completion_token_ids(self):
        return self.token_ids[self.num_prompt_tokens:]

    @property
    def num_cached_blocks(self):
        """Number of whole blocks covered by ``num_cached_tokens``."""
        return self.num_cached_tokens // self.block_size

    @property
    def num_blocks(self):
        """Blocks needed to hold all tokens (ceil division by ``block_size``)."""
        return (self.num_tokens + self.block_size - 1) // self.block_size

    @property
    def last_block_num_tokens(self):
        """Number of tokens that fall in the final block."""
        return self.num_tokens - (self.num_blocks - 1) * self.block_size

    def block(self, i):
        """Return the tokens of block ``i`` (0-based)."""
        assert 0 <= i < self.num_blocks
        return self.token_ids[i*self.block_size: (i+1)*self.block_size]

    def append_token(self, token_id: int):
        """Record one newly generated token."""
        self.token_ids.append(token_id)
        self.last_token = token_id
        self.num_tokens += 1

    def __getstate__(self):
        # Compact pickle form: counters and block table, plus either the full
        # token list (nothing generated yet) or only the last token.
        return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,
                self.token_ids if self.num_completion_tokens == 0 else self.last_token)

    def __setstate__(self, state):
        # Restores only the fields written by __getstate__; other attributes
        # are not set on the unpickled object.
        self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]
        if self.num_completion_tokens == 0:
            self.token_ids = state[-1]
        else:
            self.last_token = state[-1]
|
||||
16
acestep/third_parts/nano-vllm/nanovllm/layers/activation.py
Executable file
16
acestep/third_parts/nano-vllm/nanovllm/layers/activation.py
Executable file
|
|
@ -0,0 +1,16 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from nanovllm.utils.compat import maybe_compile
|
||||
|
||||
|
||||
class SiluAndMul(nn.Module):
    """Gated activation: split the last dimension in two and return ``silu(a) * b``."""

    def __init__(self):
        super().__init__()

    @maybe_compile
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gate; the output's last dimension is half the input's."""
        gate, up = torch.chunk(x, 2, dim=-1)
        return F.silu(gate) * up
|
||||
387
acestep/third_parts/nano-vllm/nanovllm/layers/attention.py
Normal file
387
acestep/third_parts/nano-vllm/nanovllm/layers/attention.py
Normal file
|
|
@ -0,0 +1,387 @@
|
|||
import os
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from nanovllm.utils.context import get_context
|
||||
|
||||
# Attention tracing is opt-in: export NANOVLLM_DEBUG=1 before the process starts.
_DEBUG = os.getenv("NANOVLLM_DEBUG", "0") == "1"


def _debug_log(msg: str):
    """Print ``msg`` with the attention debug prefix; do nothing unless tracing is on."""
    if not _DEBUG:
        return
    print(f"[nanovllm attention DEBUG] {msg}", flush=True)
|
||||
|
||||
# Optional dependencies: Triton (for KV cache kernel) and Flash Attention
|
||||
_HAS_TRITON = False
|
||||
_HAS_FLASH_ATTN = False
|
||||
|
||||
try:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
_HAS_TRITON = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||
_HAS_FLASH_ATTN = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Triton KV cache store kernel (original, used when available)
|
||||
# ============================================================
|
||||
|
||||
if _HAS_TRITON:
    @triton.jit
    def store_kvcache_kernel(
        key_ptr,
        key_stride,
        value_ptr,
        value_stride,
        k_cache_ptr,
        v_cache_ptr,
        slot_mapping_ptr,
        D: tl.constexpr,
    ):
        # One program instance per input row: copy D contiguous key elements
        # and D value elements into the cache row selected by the slot mapping.
        row = tl.program_id(0)
        slot = tl.load(slot_mapping_ptr + row)
        # A slot of -1 marks a row that must not be written.
        if slot == -1:
            return
        lane = tl.arange(0, D)
        key_row = tl.load(key_ptr + row * key_stride + lane)
        value_row = tl.load(value_ptr + row * value_stride + lane)
        # Both caches are addressed as flat arrays of D-element rows.
        dst = slot * D + lane
        tl.store(k_cache_ptr + dst, key_row)
        tl.store(v_cache_ptr + dst, value_row)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Pure PyTorch KV cache store (fallback when Triton unavailable)
|
||||
# ============================================================
|
||||
|
||||
def _store_kvcache_pytorch(
    key: torch.Tensor,
    value: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
):
    """Store key/value rows into the paged KV cache using plain PyTorch ops.

    The caches are updated in place; nothing is returned.

    Args:
        key: [N, num_kv_heads, head_dim]
        value: [N, num_kv_heads, head_dim]
        k_cache: [num_blocks, block_size, num_kv_heads, head_dim] (per-layer view)
        v_cache: [num_blocks, block_size, num_kv_heads, head_dim]
        slot_mapping: [N] flat slot indices into the cache; -1 marks a row to skip

    Raises:
        RuntimeError: if a cache cannot be viewed as ``[total_slots, D]``
            without copying (unsupported memory layout).
    """
    N, num_kv_heads, head_dim = key.shape
    D = num_kv_heads * head_dim

    # ``view`` (not ``reshape``) on purpose: reshape silently falls back to a
    # copy when the layout is not viewable, and writes to a copy never reach
    # the real cache. view raises instead.
    k_flat = k_cache.view(-1, D)
    v_flat = v_cache.view(-1, D)

    # Sources are only read, so a copy here is harmless.
    key_flat = key.reshape(N, D)
    value_flat = value.reshape(N, D)

    # Drop padding rows (slot == -1) before scattering.
    valid_mask = slot_mapping != -1
    valid_slots = slot_mapping[valid_mask]
    k_flat[valid_slots] = key_flat[valid_mask]
    v_flat[valid_slots] = value_flat[valid_mask]
|
||||
|
||||
|
||||
def store_kvcache(
    key: torch.Tensor,
    value: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
):
    """Write key/value rows into the paged KV cache.

    Launches the Triton kernel (one program per input row) when Triton was
    imported successfully, otherwise uses the pure-PyTorch implementation.
    """
    if not _HAS_TRITON:
        _store_kvcache_pytorch(key, value, k_cache, v_cache, slot_mapping)
        return

    num_rows, num_heads, head_dim = key.shape
    row_width = num_heads * head_dim
    # The kernel does flat pointer arithmetic, so the layouts must be exactly
    # what it assumes: packed heads in key/value and row_width-strided slots.
    assert key.stride(-1) == 1 and value.stride(-1) == 1
    assert key.stride(1) == head_dim and value.stride(1) == head_dim
    assert k_cache.stride(1) == row_width and v_cache.stride(1) == row_width
    assert slot_mapping.numel() == num_rows
    store_kvcache_kernel[(num_rows,)](
        key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, row_width
    )
|
||||
|
||||
|
||||
# ============================================================
|
||||
# SDPA-based attention (fallback when Flash Attention unavailable)
|
||||
# ============================================================
|
||||
|
||||
def _sdpa_varlen_prefill(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
cu_seqlens_q: torch.Tensor,
|
||||
cu_seqlens_k: torch.Tensor,
|
||||
scale: float,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
) -> torch.Tensor:
|
||||
"""SDPA replacement for flash_attn_varlen_func during prefill.
|
||||
|
||||
Splits packed sequences, runs SDPA per sequence with causal masking,
|
||||
then re-packs. Handles GQA via enable_gqa when heads differ.
|
||||
|
||||
Args:
|
||||
q: [total_q_tokens, num_heads, head_dim]
|
||||
k: [total_k_tokens, num_kv_heads, head_dim]
|
||||
v: [total_k_tokens, num_kv_heads, head_dim]
|
||||
cu_seqlens_q: [num_seqs + 1] cumulative sequence lengths for queries
|
||||
cu_seqlens_k: [num_seqs + 1] cumulative sequence lengths for keys
|
||||
scale: attention scale factor
|
||||
num_heads: number of query heads
|
||||
num_kv_heads: number of KV heads
|
||||
|
||||
Returns:
|
||||
output: [total_q_tokens, num_heads, head_dim]
|
||||
"""
|
||||
num_seqs = cu_seqlens_q.shape[0] - 1
|
||||
outputs = []
|
||||
enable_gqa = num_heads != num_kv_heads
|
||||
|
||||
for i in range(num_seqs):
|
||||
q_start = cu_seqlens_q[i].item()
|
||||
q_end = cu_seqlens_q[i + 1].item()
|
||||
k_start = cu_seqlens_k[i].item()
|
||||
k_end = cu_seqlens_k[i + 1].item()
|
||||
|
||||
# [seq_len, heads, dim] -> [1, heads, seq_len, dim]
|
||||
qi = q[q_start:q_end].unsqueeze(0).transpose(1, 2)
|
||||
ki = k[k_start:k_end].unsqueeze(0).transpose(1, 2)
|
||||
vi = v[k_start:k_end].unsqueeze(0).transpose(1, 2)
|
||||
|
||||
oi = F.scaled_dot_product_attention(
|
||||
qi, ki, vi, scale=scale, is_causal=True, enable_gqa=enable_gqa
|
||||
)
|
||||
|
||||
# [1, heads, seq_len, dim] -> [seq_len, heads, dim]
|
||||
outputs.append(oi.transpose(1, 2).squeeze(0))
|
||||
|
||||
return torch.cat(outputs, dim=0)
|
||||
|
||||
|
||||
def _sdpa_prefill_with_paged_cache(
|
||||
q: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
cu_seqlens_q: torch.Tensor,
|
||||
cu_seqlens_k: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
scale: float,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
) -> torch.Tensor:
|
||||
"""SDPA prefill with paged KV cache (prefix caching case).
|
||||
|
||||
Args:
|
||||
q: [total_q_tokens, num_heads, head_dim]
|
||||
k_cache: [num_blocks, block_size, num_kv_heads, head_dim]
|
||||
v_cache: [num_blocks, block_size, num_kv_heads, head_dim]
|
||||
cu_seqlens_q: [num_seqs + 1]
|
||||
cu_seqlens_k: [num_seqs + 1]
|
||||
block_tables: [num_seqs, max_blocks_per_seq]
|
||||
scale: attention scale factor
|
||||
num_heads: number of query heads
|
||||
num_kv_heads: number of KV heads
|
||||
|
||||
Returns:
|
||||
output: [total_q_tokens, num_heads, head_dim]
|
||||
"""
|
||||
block_size = k_cache.shape[1]
|
||||
num_seqs = cu_seqlens_q.shape[0] - 1
|
||||
outputs = []
|
||||
enable_gqa = num_heads != num_kv_heads
|
||||
|
||||
for i in range(num_seqs):
|
||||
q_start = cu_seqlens_q[i].item()
|
||||
q_end = cu_seqlens_q[i + 1].item()
|
||||
k_len = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
|
||||
|
||||
# Gather k/v from paged cache
|
||||
num_blocks_needed = (k_len + block_size - 1) // block_size
|
||||
block_indices = block_tables[i, :num_blocks_needed]
|
||||
ki = k_cache[block_indices].reshape(-1, num_kv_heads, k_cache.shape[-1])[:k_len]
|
||||
vi = v_cache[block_indices].reshape(-1, num_kv_heads, v_cache.shape[-1])[:k_len]
|
||||
|
||||
# [seq, heads, dim] -> [1, heads, seq, dim]
|
||||
qi = q[q_start:q_end].unsqueeze(0).transpose(1, 2)
|
||||
ki = ki.unsqueeze(0).transpose(1, 2)
|
||||
vi = vi.unsqueeze(0).transpose(1, 2)
|
||||
|
||||
oi = F.scaled_dot_product_attention(
|
||||
qi, ki, vi, scale=scale, is_causal=True, enable_gqa=enable_gqa
|
||||
)
|
||||
outputs.append(oi.transpose(1, 2).squeeze(0))
|
||||
|
||||
return torch.cat(outputs, dim=0)
|
||||
|
||||
|
||||
def _sdpa_decode_with_paged_cache(
|
||||
q: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
scale: float,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
) -> torch.Tensor:
|
||||
"""SDPA replacement for flash_attn_with_kvcache during decode.
|
||||
|
||||
For each sequence, gathers KV from paged cache and runs SDPA
|
||||
for the single new query token against the full context.
|
||||
|
||||
Args:
|
||||
q: [batch, 1, num_heads, head_dim] (already unsqueezed)
|
||||
k_cache: [num_blocks, block_size, num_kv_heads, head_dim]
|
||||
v_cache: [num_blocks, block_size, num_kv_heads, head_dim]
|
||||
context_lens: [batch] - number of tokens in context for each sequence
|
||||
block_tables: [batch, max_blocks_per_seq]
|
||||
scale: attention scale factor
|
||||
num_heads: number of query heads
|
||||
num_kv_heads: number of KV heads
|
||||
|
||||
Returns:
|
||||
output: [batch, 1, num_heads, head_dim]
|
||||
"""
|
||||
batch_size = q.shape[0]
|
||||
block_size = k_cache.shape[1]
|
||||
outputs = []
|
||||
enable_gqa = num_heads != num_kv_heads
|
||||
|
||||
for i in range(batch_size):
|
||||
ctx_len = context_lens[i].item()
|
||||
num_blocks_needed = (ctx_len + block_size - 1) // block_size
|
||||
block_indices = block_tables[i, :num_blocks_needed]
|
||||
|
||||
# Gather and trim KV: [ctx_len, num_kv_heads, head_dim]
|
||||
ki = k_cache[block_indices].reshape(-1, num_kv_heads, k_cache.shape[-1])[:ctx_len]
|
||||
vi = v_cache[block_indices].reshape(-1, num_kv_heads, v_cache.shape[-1])[:ctx_len]
|
||||
|
||||
# q[i]: [1, num_heads, head_dim] -> [1, num_heads, 1, head_dim]
|
||||
qi = q[i].unsqueeze(0).transpose(1, 2) # [1, num_heads, 1, head_dim]
|
||||
ki = ki.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, ctx_len, head_dim]
|
||||
vi = vi.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, ctx_len, head_dim]
|
||||
|
||||
oi = F.scaled_dot_product_attention(
|
||||
qi, ki, vi, scale=scale, is_causal=False, enable_gqa=enable_gqa
|
||||
)
|
||||
outputs.append(oi.transpose(1, 2).squeeze(0)) # [1, num_heads, head_dim]
|
||||
|
||||
return torch.stack(outputs, dim=0) # [batch, 1, num_heads, head_dim]
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Attention module
|
||||
# ============================================================
|
||||
|
||||
class Attention(nn.Module):
    """Attention over a paged KV cache.

    Uses the flash-attn kernels when the package was imported successfully and
    the SDPA-based fallbacks otherwise. Per-step metadata (prefill/decode flag,
    slot mapping, block tables, sequence lengths) is read from ``get_context()``.
    """

    def __init__(
        self,
        num_heads,
        head_dim,
        scale,
        num_kv_heads,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        # Empty placeholder shared by both attributes until real caches are assigned.
        # NOTE(review): presumably the model runner installs the cache views - confirm.
        placeholder = torch.tensor([])
        self.k_cache = placeholder
        self.v_cache = placeholder

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Store the new k/v in the cache (when one is installed) and attend."""
        context = get_context()
        k_cache, v_cache = self.k_cache, self.v_cache

        if _DEBUG:
            self._trace_inputs(q, k, v, k_cache, context)

        # Skip the store while the caches are still the empty placeholder.
        if k_cache.numel() and v_cache.numel():
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)

        attend = self._forward_flash_attn if _HAS_FLASH_ATTN else self._forward_sdpa
        return attend(q, k, v, k_cache, v_cache, context)

    @staticmethod
    def _trace_inputs(q, k, v, k_cache, context):
        """Dump tensor shapes and index ranges of one forward call to the debug log."""
        _debug_log(f"forward: q.shape={q.shape}, k.shape={k.shape}, v.shape={v.shape}")
        _debug_log(f" is_prefill={context.is_prefill}")
        _debug_log(f" k_cache.shape={k_cache.shape if k_cache.numel() else 'empty'}")

        slots = context.slot_mapping
        if slots is not None:
            _debug_log(f" slot_mapping.shape={slots.shape}, range=[{slots.min().item()}, {slots.max().item()}]")

        tables = context.block_tables
        if tables is not None:
            # Negative entries are excluded from the reported range.
            valid_blocks = tables[tables >= 0]
            has_valid = valid_blocks.numel()
            lowest = valid_blocks.min().item() if has_valid else -1
            highest = valid_blocks.max().item() if has_valid else -1
            _debug_log(f" block_tables.shape={tables.shape}, range=[{lowest}, {highest}]")

        if context.context_lens is not None:
            _debug_log(f" context_lens={context.context_lens.tolist()}")

    def _forward_flash_attn(self, q, k, v, k_cache, v_cache, context):
        """Flash-attention path."""
        if not context.is_prefill:
            # Decode: one new token per sequence against the cached context.
            _debug_log(f" calling flash_attn_with_kvcache")
            return flash_attn_with_kvcache(
                q.unsqueeze(1), k_cache, v_cache,
                cache_seqlens=context.context_lens,
                block_table=context.block_tables,
                softmax_scale=self.scale,
                causal=True,
            )

        if context.block_tables is not None:
            # Prefix cache: attend against the paged cache instead of the new k/v.
            k, v = k_cache, v_cache
        _debug_log(f" calling flash_attn_varlen_func")
        return flash_attn_varlen_func(
            q, k, v,
            max_seqlen_q=context.max_seqlen_q,
            cu_seqlens_q=context.cu_seqlens_q,
            max_seqlen_k=context.max_seqlen_k,
            cu_seqlens_k=context.cu_seqlens_k,
            softmax_scale=self.scale,
            causal=True,
            block_table=context.block_tables,
        )

    def _forward_sdpa(self, q, k, v, k_cache, v_cache, context):
        """SDPA fallback path (no flash_attn dependency)."""
        head_args = (self.scale, self.num_heads, self.num_kv_heads)

        if not context.is_prefill:
            # Decode: single token per sequence against the full KV cache.
            _debug_log(f" calling _sdpa_decode_with_paged_cache")
            return _sdpa_decode_with_paged_cache(
                q.unsqueeze(1), k_cache, v_cache,
                context.context_lens, context.block_tables,
                *head_args,
            )

        if context.block_tables is not None:
            # Prefix cache: keys/values are gathered from the paged cache.
            _debug_log(f" calling _sdpa_prefill_with_paged_cache")
            return _sdpa_prefill_with_paged_cache(
                q, k_cache, v_cache,
                context.cu_seqlens_q, context.cu_seqlens_k,
                context.block_tables,
                *head_args,
            )

        # Standard prefill: k and v are the packed new tokens.
        _debug_log(f" calling _sdpa_varlen_prefill")
        return _sdpa_varlen_prefill(
            q, k, v,
            context.cu_seqlens_q, context.cu_seqlens_k,
            *head_args,
        )
|
||||
69
acestep/third_parts/nano-vllm/nanovllm/layers/embed_head.py
Normal file
69
acestep/third_parts/nano-vllm/nanovllm/layers/embed_head.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
|
||||
from nanovllm.utils.context import get_context
|
||||
from nanovllm import distributed as dist_utils
|
||||
|
||||
|
||||
class VocabParallelEmbedding(nn.Module):
    """Embedding table split row-wise (by vocabulary id) across tensor-parallel ranks.

    Each rank stores ``num_embeddings // tp_size`` consecutive rows. With more
    than one rank, ids outside the local range contribute zeros and the
    partial results are combined with an all-reduce.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
    ):
        super().__init__()
        self.tp_rank = dist_utils.get_rank()
        self.tp_size = dist_utils.get_world_size()
        assert num_embeddings % self.tp_size == 0
        self.num_embeddings = num_embeddings
        rows_per_rank = self.num_embeddings // self.tp_size
        self.num_embeddings_per_partition = rows_per_rank
        # Half-open range [vocab_start_idx, vocab_end_idx) of ids owned by this rank.
        self.vocab_start_idx = rows_per_rank * self.tp_rank
        self.vocab_end_idx = self.vocab_start_idx + rows_per_rank
        self.weight = nn.Parameter(torch.empty(rows_per_rank, embedding_dim))
        # Hook attached to the parameter so a checkpoint loader can find it.
        self.weight.weight_loader = self.weight_loader

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's row slice of the full ``loaded_weight`` into ``param``."""
        rows = param.data.size(0)
        local_shard = loaded_weight.narrow(0, self.tp_rank * rows, rows)
        param.data.copy_(local_shard)

    def forward(self, x: torch.Tensor):
        """Look up embeddings for the token ids in ``x``."""
        if self.tp_size <= 1:
            return F.embedding(x, self.weight)

        # Map ids to local row numbers; ids owned by other ranks are sent to
        # row 0 and zeroed out again after the lookup.
        owned = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
        local_ids = owned * (x - self.vocab_start_idx)
        y = owned.unsqueeze(1) * F.embedding(local_ids, self.weight)
        dist_utils.all_reduce(y)
        return y
|
||||
|
||||
|
||||
class ParallelLMHead(VocabParallelEmbedding):
    """Output projection that reuses the vocabulary-sharded weight layout."""

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        bias: bool = False,
    ):
        """Build the head; a bias term is not supported."""
        assert not bias
        super().__init__(num_embeddings, embedding_dim)

    def forward(self, x: torch.Tensor):
        """Project hidden states to vocabulary logits.

        During prefill only the row at the end of each sequence (derived
        from ``context.cu_seqlens_q``) is projected. With several ranks the
        partial logits are gathered onto rank 0 and concatenated on the last
        axis; every other rank returns ``None``.
        """
        ctx = get_context()
        if ctx.is_prefill:
            # cu_seqlens_q[1:] - 1 indexes the final token of every sequence.
            tail_rows = ctx.cu_seqlens_q[1:] - 1
            x = x[tail_rows].contiguous()
        logits = F.linear(x, self.weight)
        if self.tp_size <= 1:
            # Single rank: nothing to gather.
            return logits
        is_root = self.tp_rank == 0
        shards = [torch.empty_like(logits) for _ in range(self.tp_size)] if is_root else None
        dist_utils.gather(logits, shards, 0)
        return torch.cat(shards, -1) if is_root else None
|
||||
52
acestep/third_parts/nano-vllm/nanovllm/layers/layernorm.py
Executable file
52
acestep/third_parts/nano-vllm/nanovllm/layers/layernorm.py
Executable file
|
|
@ -0,0 +1,52 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
from nanovllm.utils.compat import maybe_compile
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with an optional fused residual add.

    The statistics are computed in float32 and the result is cast back to
    the input dtype before the learned per-channel ``weight`` is applied.

    Neither forward path mutates its inputs. For float32 inputs
    ``Tensor.float()`` and ``Tensor.to(same_dtype)`` return the tensor itself,
    so in-place arithmetic on those results would overwrite the caller's
    tensor and, in the fused path, make the returned residual alias the
    normalised output. All intermediate results are therefore fresh tensors.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        """Create the layer.

        Args:
            hidden_size: Size of the last (normalised) dimension.
            eps: Constant added to the mean square before ``rsqrt``.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    @maybe_compile
    def rms_forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``x`` RMS-normalised over its last dimension, in ``x.dtype``."""
        orig_dtype = x.dtype
        x_fp32 = x.float()  # may be ``x`` itself when x is already float32
        var = x_fp32.pow(2).mean(dim=-1, keepdim=True)
        # Out-of-place multiply: always yields a new tensor, so the in-place
        # weight scaling below can never touch the caller's ``x``.
        normed = x_fp32 * torch.rsqrt(var + self.eps)
        return normed.to(orig_dtype).mul_(self.weight)

    @maybe_compile
    def add_rms_forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Add ``residual`` to ``x``, then RMS-normalise the sum.

        Returns:
            ``(normalised, new_residual)`` where ``new_residual`` is the
            un-normalised sum ``x + residual`` cast to ``x.dtype``.
        """
        orig_dtype = x.dtype
        summed = x.float() + residual.float()  # fresh float32 tensor
        # May alias ``summed`` for float32 inputs; ``summed`` is never
        # modified in place afterwards, so the residual stays intact.
        new_residual = summed.to(orig_dtype)
        var = summed.pow(2).mean(dim=-1, keepdim=True)
        normed = summed * torch.rsqrt(var + self.eps)
        normed = normed.to(orig_dtype).mul_(self.weight)
        return normed, new_residual

    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Dispatch to the plain or the fused-residual normalisation.

        Returns a single tensor when ``residual`` is ``None``, otherwise the
        ``(normalised, new_residual)`` pair from ``add_rms_forward``.
        """
        if residual is None:
            return self.rms_forward(x)
        else:
            return self.add_rms_forward(x, residual)
|
||||
155
acestep/third_parts/nano-vllm/nanovllm/layers/linear.py
Executable file
155
acestep/third_parts/nano-vllm/nanovllm/layers/linear.py
Executable file
|
|
@ -0,0 +1,155 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
|
||||
from nanovllm import distributed as dist_utils
|
||||
|
||||
|
||||
def divide(numerator, denominator):
    """Return ``numerator // denominator``, asserting the division is exact."""
    quotient, remainder = divmod(numerator, denominator)
    assert remainder == 0
    return quotient
|
||||
|
||||
|
||||
class LinearBase(nn.Module):
    """Shared parameter setup for the (optionally sharded) linear layers.

    Subclasses provide ``weight_loader`` and ``forward``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
        tp_dim: int | None = None,
    ):
        """Allocate ``weight`` (``output_size x input_size``) and optional bias.

        Args:
            input_size: Number of input features held by this rank.
            output_size: Number of output features held by this rank.
            bias: Whether to allocate a bias vector.
            tp_dim: Weight dimension that subclasses shard across ranks.
        """
        super().__init__()
        self.tp_dim = tp_dim
        self.tp_rank = dist_utils.get_rank()
        self.tp_size = dist_utils.get_world_size()

        weight = nn.Parameter(torch.empty(output_size, input_size))
        # The checkpoint loader looks this attribute up on the parameter.
        weight.weight_loader = self.weight_loader
        self.weight = weight

        if not bias:
            self.register_parameter("bias", None)
            return
        bias_param = nn.Parameter(torch.empty(output_size))
        bias_param.weight_loader = self.weight_loader
        self.bias = bias_param

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Abstract; concrete layers override this."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class ReplicatedLinear(LinearBase):
    """Linear layer whose full weight is held by every rank (no sharding)."""

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        """Create the layer with an unsharded weight (``tp_dim`` stays ``None``)."""
        super().__init__(input_size, output_size, bias=bias)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy the checkpoint tensor into ``param`` unchanged."""
        destination = param.data
        destination.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``x @ weight.T + bias``."""
        output = F.linear(x, self.weight, self.bias)
        return output
|
||||
|
||||
|
||||
class ColumnParallelLinear(LinearBase):
    """Linear layer sharded along the output dimension (weight dim 0)."""

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        """Allocate this rank's ``output_size / world_size`` output rows."""
        local_output_size = divide(output_size, dist_utils.get_world_size())
        super().__init__(input_size, local_output_size, bias, 0)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice (along ``tp_dim``) of the full tensor."""
        local = param.data
        width = local.size(self.tp_dim)
        offset = self.tp_rank * width
        local.copy_(loaded_weight.narrow(self.tp_dim, offset, width))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return this rank's partition of the output features."""
        partial = F.linear(x, self.weight, self.bias)
        return partial
|
||||
|
||||
|
||||
class MergedColumnParallelLinear(ColumnParallelLinear):
    """Several column-parallel projections fused into one weight matrix."""

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = False,
    ):
        """Allocate one weight covering ``sum(output_sizes)`` output features."""
        # Set before super().__init__ so weight_loader can use it afterwards.
        self.output_sizes = output_sizes
        total_output = sum(output_sizes)
        super().__init__(input_size, total_output, bias)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
        """Load sub-projection ``loaded_shard_id`` into its slot of ``param``.

        The destination slot starts after the local shares of all preceding
        sub-projections. The source is this rank's chunk of the checkpoint
        tensor along ``tp_dim``.
        """
        preceding = sum(self.output_sizes[:loaded_shard_id])
        slot_start = preceding // self.tp_size
        slot_width = self.output_sizes[loaded_shard_id] // self.tp_size
        slot = param.data.narrow(self.tp_dim, slot_start, slot_width)
        rank_chunks = loaded_weight.chunk(self.tp_size, self.tp_dim)
        slot.copy_(rank_chunks[self.tp_rank])
|
||||
|
||||
|
||||
class QKVParallelLinear(ColumnParallelLinear):
    """Fused query/key/value projection, sharded by attention head."""

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = False,
    ):
        """Create the fused projection.

        Args:
            hidden_size: Input feature size.
            head_size: Width of one attention head.
            total_num_heads: Query heads across all ranks.
            total_num_kv_heads: Key/value heads across all ranks; falls back
                to ``total_num_heads`` when falsy.
            bias: Whether to allocate a bias vector.
        """
        world_size = dist_utils.get_world_size()
        if not total_num_kv_heads:
            total_num_kv_heads = total_num_heads
        self.head_size = head_size
        self.num_heads = divide(total_num_heads, world_size)
        self.num_kv_heads = divide(total_num_kv_heads, world_size)
        fused_heads = total_num_heads + 2 * total_num_kv_heads
        super().__init__(hidden_size, fused_heads * self.head_size, bias)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
        """Load the ``"q"``, ``"k"`` or ``"v"`` weight into its slot of ``param``.

        Local layout along ``tp_dim`` is ``[q | k | v]``. The source is this
        rank's chunk of the checkpoint tensor.
        """
        assert loaded_shard_id in ["q", "k", "v"]
        q_width = self.num_heads * self.head_size
        kv_width = self.num_kv_heads * self.head_size
        # shard id -> (offset, width) inside the fused local weight
        layout = {
            "q": (0, q_width),
            "k": (q_width, kv_width),
            "v": (q_width + kv_width, kv_width),
        }
        slot_start, slot_width = layout[loaded_shard_id]
        slot = param.data.narrow(self.tp_dim, slot_start, slot_width)
        rank_chunks = loaded_weight.chunk(self.tp_size, self.tp_dim)
        slot.copy_(rank_chunks[self.tp_rank])
|
||||
|
||||
|
||||
class RowParallelLinear(LinearBase):
    """Linear layer sharded along the input dimension (weight dim 1)."""

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        """Allocate this rank's ``input_size / world_size`` input columns."""
        local_input_size = divide(input_size, dist_utils.get_world_size())
        super().__init__(local_input_size, output_size, bias, 1)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice (along ``tp_dim``) of the full tensor."""
        local = param.data
        width = local.size(self.tp_dim)
        offset = self.tp_rank * width
        local.copy_(loaded_weight.narrow(self.tp_dim, offset, width))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the partial product and combine it across ranks.

        Only rank 0 passes the bias to ``F.linear``. With more than one rank
        the partial results are merged via ``dist_utils.all_reduce``.
        """
        rank_bias = self.bias if self.tp_rank == 0 else None
        y = F.linear(x, self.weight, rank_bias)
        if self.tp_size > 1:
            dist_utils.all_reduce(y)
        return y
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
from functools import lru_cache
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from nanovllm.utils.compat import maybe_compile
|
||||
|
||||
|
||||
def apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """Rotate the two halves of ``x``'s last dimension by ``cos``/``sin``.

    The arithmetic is done in float32; the result is cast back to ``x.dtype``.
    """
    first_half, second_half = x.float().chunk(2, dim=-1)
    rotated_first = first_half * cos - second_half * sin
    rotated_second = second_half * cos + first_half * sin
    rotated = torch.cat((rotated_first, rotated_second), dim=-1)
    return rotated.to(x.dtype)
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
    """Rotary position embedding backed by a precomputed cos/sin table."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
    ) -> None:
        """Precompute the table for positions ``0 .. max_position_embeddings - 1``.

        Args:
            head_size: Attention head width.
            rotary_dim: Number of rotated channels; must equal ``head_size``.
            max_position_embeddings: Number of positions to tabulate.
            base: Base of the geometric frequency progression.
        """
        super().__init__()
        self.head_size = head_size
        assert rotary_dim == head_size
        exponents = torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim
        inv_freq = 1.0 / (base**exponents)
        positions = torch.arange(max_position_embeddings, dtype=torch.float)
        angles = torch.einsum("i,j -> ij", positions, inv_freq)
        # Row layout is [cos | sin]; the extra axis at dim 1 lets the table
        # broadcast against the head axis of query/key.
        table = torch.cat((angles.cos(), angles.sin()), dim=-1).unsqueeze_(1)
        self.register_buffer("cos_sin_cache", table, persistent=False)

    @maybe_compile
    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(query, key)`` rotated according to ``positions``."""
        cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)
        rotated_query = apply_rotary_emb(query, cos, sin)
        rotated_key = apply_rotary_emb(key, cos, sin)
        return rotated_query, rotated_key
|
||||
|
||||
|
||||
@lru_cache(1)
def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: float,
    rope_scaling: dict | None = None,
):
    """Return a ``RotaryEmbedding`` for the given geometry.

    The result is memoised with a cache of size one, so consecutive calls
    with identical arguments receive the same module instance. RoPE scaling
    is not supported: ``rope_scaling`` must be ``None``.
    """
    assert rope_scaling is None
    return RotaryEmbedding(head_size, rotary_dim, max_position, base)
|
||||
116
acestep/third_parts/nano-vllm/nanovllm/layers/sampler.py
Normal file
116
acestep/third_parts/nano-vllm/nanovllm/layers/sampler.py
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
from typing import Optional
|
||||
|
||||
from nanovllm.utils.compat import maybe_compile
|
||||
|
||||
|
||||
def apply_top_k_top_p(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits (vLLM style).

    The logits tensor is updated in-place and also returned.

    Args:
        logits: 2-D tensor of per-row vocabulary scores (indexed as
            ``[row, token]``).
        k: Per-row top-k values, or ``None`` to skip top-k. A row value
            ``<= 0`` or ``>= vocab_size`` disables top-k for that row, which
            matches ``apply_top_k_only``.
        p: Per-row nucleus thresholds, or ``None`` to skip top-p.

    Returns:
        ``logits`` with every filtered-out entry set to ``-inf``.
    """
    if p is None:
        if k is None:
            return logits
        # Avoid sorting the vocab for the top-k-only case.
        return apply_top_k_only(logits, k)

    # Top-p needs the full sort (ascending, so the best token is last).
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    if k is not None:
        vocab_size = logits_sort.size(1)
        # Map the "disabled" sentinel (k <= 0) to vocab_size so nothing is
        # masked for that row. Clamping it to 1 would silently turn the row
        # into greedy top-1 sampling.
        k_eff = k.masked_fill(k <= 0, vocab_size).clamp(1, vocab_size).long()
        top_k_mask_idx = vocab_size - k_eff  # shape: [B]
        # Threshold = k-th largest value of each row.
        top_k_thresh = logits_sort.gather(1, top_k_mask_idx.unsqueeze(1))
        logits_sort.masked_fill_(logits_sort < top_k_thresh, float('-inf'))

    # Top-p: drop the low-probability tail whose cumulative mass is <= 1 - p.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)  # reuse buffer
    top_p_mask = probs_sum <= (1.0 - p.unsqueeze(1))
    # Ensure at least one token (the most likely one) is kept.
    top_p_mask[:, -1] = False
    logits_sort.masked_fill_(top_p_mask, float('-inf'))

    # Scatter the filtered values back to their original vocabulary positions.
    logits.scatter_(dim=-1, index=logits_idx, src=logits_sort)
    return logits
|
||||
|
||||
|
||||
def apply_top_k_only(
    logits: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """Mask everything below each row's k-th largest logit (vLLM style).

    Uses ``topk`` instead of a full vocabulary sort. ``logits`` is modified
    in-place and returned. Rows with ``k <= 0`` or ``k >= vocab_size`` are
    left unfiltered.
    """
    num_tokens = logits.shape[1]
    disabled = (k <= 0) | (k >= num_tokens)
    # Disabled rows get a harmless k of 1 so the gather below stays valid.
    k_eff = torch.where(disabled, torch.ones_like(k), k).long()
    # NOTE: int() forces a CPU-GPU sync, but topk needs a Python int.
    widest = int(k_eff.max().clamp(max=num_tokens))

    # [batch, widest] largest values per row, in descending order.
    top_values, _ = torch.topk(logits, widest, dim=1)

    # The k-th largest value sits at index k-1.
    kth_index = (k_eff - 1).clamp(0, widest - 1).unsqueeze(1)  # shape: [B, 1]
    threshold = torch.gather(top_values, 1, kth_index)

    # A threshold of -inf means nothing in a disabled row is masked.
    threshold.masked_fill_(disabled.unsqueeze(1), float('-inf'))

    below = logits < threshold
    logits.masked_fill_(below, float('-inf'))
    return logits
|
||||
|
||||
|
||||
class Sampler(nn.Module):
    """Draws one token id per row from temperature-scaled, filtered logits."""

    def __init__(self):
        super().__init__()

    @maybe_compile
    def forward(
        self,
        logits: torch.Tensor,
        temperatures: torch.Tensor,
        top_ks: Optional[torch.Tensor] = None,
        top_ps: Optional[torch.Tensor] = None,
        repetition_penalties: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
    ):
        """Sample tokens from logits with optional top-k / top-p filtering.

        ``repetition_penalties`` and ``input_ids`` are accepted but not used
        by this method.

        Returns:
            Tensor with one sampled token index per row of ``logits``.
        """
        # Temperature scaling, done in float32.
        scaled = logits.float().div_(temperatures.unsqueeze(dim=1))
        filtered = apply_top_k_top_p(scaled, top_ks, top_ps)
        probs = torch.softmax(filtered, dim=-1)
        # Divide by exponential noise and take the argmax to draw the sample;
        # the clamp guards against division by zero.
        noise = torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)
        return probs.div_(noise).argmax(dim=-1)
|
||||
5
acestep/third_parts/nano-vllm/nanovllm/llm.py
Normal file
5
acestep/third_parts/nano-vllm/nanovllm/llm.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
from nanovllm.engine.llm_engine import LLMEngine
|
||||
|
||||
|
||||
class LLM(LLMEngine):
    """Public alias of ``LLMEngine``; adds no behaviour of its own."""
    pass
|
||||
230
acestep/third_parts/nano-vllm/nanovllm/models/qwen3.py
Executable file
230
acestep/third_parts/nano-vllm/nanovllm/models/qwen3.py
Executable file
|
|
@ -0,0 +1,230 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
import torch.distributed as dist
|
||||
from transformers import Qwen3Config
|
||||
|
||||
from nanovllm.layers.activation import SiluAndMul
|
||||
from nanovllm.layers.attention import Attention
|
||||
from nanovllm.layers.layernorm import RMSNorm
|
||||
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
|
||||
from nanovllm.layers.rotary_embedding import get_rope
|
||||
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
|
||||
from nanovllm import distributed as dist_utils
|
||||
|
||||
|
||||
class Qwen3Attention(nn.Module):
    """Qwen3 self-attention: fused QKV projection, RoPE, attention, output projection.

    When ``qkv_bias`` is false, per-head RMSNorm is applied to the queries and
    keys before the rotary embedding.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        rope_theta: float = 10000,
        rope_scaling: tuple | None = None,
    ) -> None:
        """Build the projections and derive this rank's head counts.

        Args:
            hidden_size: Model width.
            num_heads: Query heads across all ranks.
            num_kv_heads: Key/value heads across all ranks.
            max_position: Number of positions tabulated for RoPE.
            head_dim: Head width; defaults to ``hidden_size // num_heads``.
            rms_norm_eps: Epsilon for the q/k norms.
            qkv_bias: Adds a bias to the QKV projection and disables q/k norm.
            rope_theta: RoPE frequency base.
            rope_scaling: Forwarded to ``get_rope``.
        """
        super().__init__()
        world_size = dist_utils.get_world_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        self.total_num_kv_heads = num_kv_heads
        assert self.total_num_kv_heads % world_size == 0
        self.num_kv_heads = self.total_num_kv_heads // world_size

        self.head_dim = head_dim or hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim ** -0.5
        self.qkv_bias = qkv_bias

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            self.num_kv_heads,
        )
        if not self.qkv_bias:
            self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
            self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Return the attention output for ``hidden_states`` at ``positions``."""
        fused = self.qkv_proj(hidden_states)
        q, k, v = fused.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Reshape to (tokens, heads, head_dim).
        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        if not self.qkv_bias:
            q = self.q_norm(q)
            k = self.k_norm(k)
        q, k = self.rotary_emb(positions, q, k)
        attended = self.attn(q, k, v)
        # Merge the head axes back into one feature axis.
        return self.o_proj(attended.flatten(1, -1))
|
||||
|
||||
|
||||
class Qwen3MLP(nn.Module):
    """Gated feed-forward block: fused gate/up projection, SiLU-and-mul, down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ) -> None:
        """Build the projections; only ``hidden_act == "silu"`` is supported."""
        super().__init__()
        # Gate and up projections share one fused weight.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
        )
        assert hidden_act == "silu"
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply the feed-forward block to ``x``."""
        activated = self.act_fn(self.gate_up_proj(x))
        return self.down_proj(activated)
|
||||
|
||||
|
||||
class Qwen3DecoderLayer(nn.Module):
    """One pre-norm transformer layer: attention and MLP with a carried residual."""

    def __init__(
        self,
        config: Qwen3Config,
    ) -> None:
        """Build the layer from a ``Qwen3Config``."""
        super().__init__()
        attention_kwargs = dict(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position=config.max_position_embeddings,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, 'attention_bias', True),
            head_dim=getattr(config, 'head_dim', None),
            rope_theta=getattr(config, "rope_theta", 1000000),
            rope_scaling=getattr(config, "rope_scaling", None),
        )
        self.self_attn = Qwen3Attention(**attention_kwargs)
        self.mlp = Qwen3MLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer.

        Args:
            positions: Token positions forwarded to the attention block.
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Running residual, or ``None`` for the first layer.

        Returns:
            ``(mlp_output, residual)``. The MLP output is not yet added to
            the residual; the next norm call performs that addition.
        """
        if residual is None:
            normed = self.input_layernorm(hidden_states)
            # The residual is the un-normalised input tensor itself (no copy).
            residual = hidden_states
        else:
            normed, residual = self.input_layernorm(hidden_states, residual)
        attn_out = self.self_attn(positions, normed)
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual
|
||||
|
||||
|
||||
class Qwen3Model(nn.Module):
    """Qwen3 backbone: token embedding, decoder stack and final norm."""

    def __init__(
        self,
        config: Qwen3Config,
    ) -> None:
        """Create ``config.num_hidden_layers`` decoder layers plus embedding and norm."""
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
        decoder_layers = [Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Return the final normalised hidden states for ``input_ids``."""
        hidden, residual = self.embed_tokens(input_ids), None
        for decoder_layer in self.layers:
            hidden, residual = decoder_layer(positions, hidden, residual)
        # The final norm also folds in the last pending residual.
        normed, _ = self.norm(hidden, residual)
        return normed
|
||||
|
||||
|
||||
class Qwen3ForCausalLM(nn.Module):
    """Qwen3 backbone plus LM head."""

    # checkpoint sub-module name -> (fused module name, shard id)
    packed_modules_mapping = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: Qwen3Config
    ) -> None:
        """Build backbone and head; share embedding data when the config ties them."""
        super().__init__()
        self.model = Qwen3Model(config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        if config.tie_word_embeddings:
            # Head and embedding then use the same underlying storage.
            self.lm_head.weight.data = self.model.embed_tokens.weight.data

    # Read-only proxies so checkpoints that name weights without the
    # "model." prefix (e.g. "embed_tokens") still resolve.
    @property
    def embed_tokens(self):
        """The backbone's embedding layer."""
        return self.model.embed_tokens

    @property
    def layers(self):
        """The backbone's decoder layers."""
        return self.model.layers

    @property
    def norm(self):
        """The backbone's final norm."""
        return self.model.norm

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Return backbone hidden states (logits come from ``compute_logits``)."""
        hidden_states = self.model(input_ids, positions)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project hidden states through the LM head."""
        logits = self.lm_head(hidden_states)
        return logits
|
||||
28
acestep/third_parts/nano-vllm/nanovllm/sampling_params.py
Normal file
28
acestep/third_parts/nano-vllm/nanovllm/sampling_params.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Callable, Any
|
||||
|
||||
|
||||
@dataclass
class SamplingParams:
    """Per-request sampling configuration, validated on construction."""

    temperature: float = 1.0
    max_tokens: int = 64
    ignore_eos: bool = False
    # Classifier-free guidance scale; guidance is applied when > 1.0.
    cfg_scale: float = 1.0
    # Top-k sampling: only the k best tokens are considered.
    top_k: Optional[int] = None
    # Top-p (nucleus) sampling: tokens with cumulative probability <= top_p.
    top_p: Optional[float] = None
    # > 1.0 reduces repetition, < 1.0 increases it.
    repetition_penalty: float = 1.0
    # Optional constrained-decoding hook with signature
    # (input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor.
    logits_processor: Optional[Any] = field(default=None, repr=False)
    # Optional hook called after each token with signature
    # (token_id: int) -> None, to advance the processor's state.
    logits_processor_update_state: Optional[Callable[[int], None]] = field(default=None, repr=False)

    def __post_init__(self):
        """Reject out-of-range values with an ``AssertionError``."""
        assert self.temperature > 1e-10, "greedy sampling is not permitted"
        assert self.cfg_scale >= 1.0, "cfg_scale must be >= 1.0"
        assert self.top_k is None or self.top_k > 0, "top_k must be > 0"
        assert self.top_p is None or 0.0 < self.top_p <= 1.0, "top_p must be in (0.0, 1.0]"
        assert self.repetition_penalty > 0.0, "repetition_penalty must be > 0.0"
|
||||
70
acestep/third_parts/nano-vllm/nanovllm/utils/compat.py
Normal file
70
acestep/third_parts/nano-vllm/nanovllm/utils/compat.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
"""Compatibility utilities for optional dependencies.
|
||||
|
||||
Provides graceful fallbacks when torch.compile's backend (Triton) is
|
||||
unavailable — e.g. on Windows or on GPU architectures where Triton
|
||||
has not yet added support (Blackwell / SM 120 as of early 2026).
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
|
||||
from loguru import logger
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
# Module-level Triton availability check — runs once at import time
|
||||
# rather than repeating the import probe at every decoration site.
|
||||
try:
    import triton  # noqa: F401
except ImportError:
    # Triton missing: torch.compile is skipped by maybe_compile below.
    _HAS_TRITON = False
else:
    _HAS_TRITON = True
|
||||
|
||||
|
||||
def maybe_compile(fn: Optional[F] = None, **compile_kwargs: Any) -> Any:
    """Apply ``torch.compile`` only when its backend (Triton) is available.

    Drop-in replacement for the ``@torch.compile`` decorator. When Triton
    is importable the function is compiled as usual; otherwise the original
    function is returned unmodified so inference still works (just without
    the kernel-fusion speed-up).

    Args:
        fn: The function to compile. When used as ``@maybe_compile`` (without
            parentheses) the decorated function is passed directly. When used
            as ``@maybe_compile(...)`` this is ``None`` and a decorator is
            returned instead.
        **compile_kwargs: Keyword arguments forwarded to ``torch.compile``
            (e.g. ``dynamic=True``, ``fullgraph=True``).

    Returns:
        The compiled function when Triton is available, or the original
        unmodified function as a fallback. In the ``@maybe_compile(...)``
        form, a decorator that does the same.

    Usage::

        @maybe_compile
        def forward(self, x):
            ...

        # or with keyword arguments:
        @maybe_compile(dynamic=True)
        def forward(self, x):
            ...
    """
    def decorator(func: F) -> F:
        """Inner decorator that performs the actual compile-or-skip logic."""
        if _HAS_TRITON:
            # Imported lazily so this module can be imported without torch.
            import torch
            return torch.compile(func, **compile_kwargs)
        # loguru interpolates positional arguments with str.format-style
        # "{}" placeholders; a "%s" placeholder would be logged verbatim
        # and the function name silently dropped.
        logger.info(
            "Triton not available — skipping torch.compile for {} "
            "(inference will use native PyTorch kernels)",
            func.__qualname__,
        )
        return func

    # Support both @maybe_compile and @maybe_compile(...) syntax
    if fn is not None:
        return decorator(fn)
    return decorator
|
||||
45
acestep/third_parts/nano-vllm/nanovllm/utils/compat_test.py
Normal file
45
acestep/third_parts/nano-vllm/nanovllm/utils/compat_test.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
"""Unit tests for the ``maybe_compile`` conditional compilation decorator."""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
def _sample_fn(x):
    """Minimal decoration target for the tests: returns ``x + 1``."""
    incremented = x + 1
    return incremented
|
||||
|
||||
|
||||
class MaybeCompileTests(unittest.TestCase):
    """Check that maybe_compile compiles or passes through depending on Triton."""

    def test_compiles_when_triton_available(self):
        """With Triton present the target must be handed to torch.compile."""
        compiled_stub = MagicMock(name="compiled_fn")
        with patch("nanovllm.utils.compat._HAS_TRITON", True):
            with patch("torch.compile", return_value=compiled_stub) as compile_mock:
                from nanovllm.utils.compat import maybe_compile
                outcome = maybe_compile(_sample_fn)
                compile_mock.assert_called_once_with(_sample_fn)
                self.assertEqual(outcome, compiled_stub)

    def test_returns_original_when_triton_absent(self):
        """Without Triton the very same function object must come back."""
        with patch("nanovllm.utils.compat._HAS_TRITON", False):
            from nanovllm.utils.compat import maybe_compile
            outcome = maybe_compile(_sample_fn)
            self.assertIs(outcome, _sample_fn)

    def test_kwargs_syntax_forwards_compile_args(self):
        """The @maybe_compile(dynamic=True) form must pass kwargs to torch.compile."""
        compiled_stub = MagicMock(name="compiled_fn")
        with patch("nanovllm.utils.compat._HAS_TRITON", True):
            with patch("torch.compile", return_value=compiled_stub) as compile_mock:
                from nanovllm.utils.compat import maybe_compile
                configured = maybe_compile(dynamic=True)
                outcome = configured(_sample_fn)
                compile_mock.assert_called_once_with(_sample_fn, dynamic=True)
                self.assertEqual(outcome, compiled_stub)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
47
acestep/third_parts/nano-vllm/nanovllm/utils/context.py
Normal file
47
acestep/third_parts/nano-vllm/nanovllm/utils/context.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
from dataclasses import dataclass
|
||||
import threading
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
class Context:
    """Per-forward-pass metadata published via ``set_context`` and read by layers.

    A default-constructed instance represents "no active context".
    """

    # True for a prefill step; the LM head and attention layers branch on it.
    is_prefill: bool = False
    # Query sequence boundaries; the LM head uses ``cu_seqlens_q[1:] - 1`` to
    # pick the last token of each sequence (presumably cumulative lengths
    # with a leading 0 - confirm against the model runner).
    cu_seqlens_q: torch.Tensor | None = None
    # Key-side counterpart of ``cu_seqlens_q``, passed to prefill attention.
    cu_seqlens_k: torch.Tensor | None = None
    # NOTE(review): presumably the longest query / key sequence in the batch;
    # not read in the code visible here - confirm.
    max_seqlen_q: int = 0
    max_seqlen_k: int = 0
    # NOTE(review): presumably KV-cache slot per token - confirm in attention.
    slot_mapping: torch.Tensor | None = None
    # Passed, with ``block_tables``, to the paged-cache decode attention.
    context_lens: torch.Tensor | None = None
    block_tables: torch.Tensor | None = None
|
||||
|
||||
|
||||
# Thread-local storage for context.
|
||||
#
|
||||
# ROOT CAUSE FIX: The original implementation used a plain module-level global
|
||||
# `_CONTEXT` variable. In concurrent serving scenarios (API server with
|
||||
# ThreadPoolExecutor, multiple queue workers, or Gradio with concurrent requests),
|
||||
# multiple threads can call set_context() / get_context() / reset_context()
|
||||
# concurrently. This creates a race condition:
|
||||
#
|
||||
# Thread A: set_context(...) # sets slot_mapping, block_tables for request A
|
||||
# Thread B: set_context(...) # OVERWRITES with request B's data
|
||||
# Thread A: run_model(...) # reads Thread B's context → WRONG KV cache addresses
|
||||
# # → CUDA illegal memory access / device-side assertion
|
||||
#
|
||||
# By using threading.local(), each thread gets its own independent Context,
|
||||
# eliminating the race condition entirely.
|
||||
_THREAD_LOCAL = threading.local()
|
||||
|
||||
|
||||
def get_context():
    """Return the calling thread's ``Context``, creating a default one if absent."""
    current = getattr(_THREAD_LOCAL, 'context', None)
    if current is not None:
        return current
    # First access on this thread: install and return an empty context.
    fresh = Context()
    _THREAD_LOCAL.context = fresh
    return fresh
|
||||
|
||||
def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
    """Replace the calling thread's ``Context`` with one built from the arguments."""
    _THREAD_LOCAL.context = Context(
        is_prefill=is_prefill,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        slot_mapping=slot_mapping,
        context_lens=context_lens,
        block_tables=block_tables,
    )
|
||||
|
||||
def reset_context():
    """Reset the calling thread's ``Context`` to a default-constructed one."""
    empty = Context()
    _THREAD_LOCAL.context = empty
|
||||
70
acestep/third_parts/nano-vllm/nanovllm/utils/loader.py
Normal file
70
acestep/third_parts/nano-vllm/nanovllm/utils/loader.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
import os
|
||||
from glob import glob
|
||||
import torch
|
||||
from torch import nn
|
||||
from safetensors import safe_open
|
||||
|
||||
|
||||
def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
    """Copy ``loaded_weight`` into ``param`` in place.

    The copy is written through ``param.data``. ``load_model`` uses this
    function when a parameter has no ``weight_loader`` attribute of its own.
    """
    destination = param.data
    destination.copy_(loaded_weight)
|
||||
|
||||
|
||||
def _get_parameter_safe(model: nn.Module, weight_name: str):
    """Look up a parameter by name, tolerating a ``model.`` prefix mismatch.

    Checkpoint names and module paths may disagree about a leading
    ``model.`` component (e.g. ``embed_tokens.weight`` in the file versus
    ``model.embed_tokens.weight`` on the module). The following names are
    tried in order and the first one that resolves is returned:

    1. ``weight_name`` unchanged,
    2. ``"model." + weight_name``,
    3. ``weight_name`` without its leading ``"model."`` (only if present).

    Returns:
        The matching parameter, or ``None`` when no candidate resolves.
    """
    prefix = "model."
    candidate_names = [weight_name, f"{prefix}{weight_name}"]
    if weight_name.startswith(prefix):
        candidate_names.append(weight_name[len(prefix):])

    for candidate in candidate_names:
        try:
            return model.get_parameter(candidate)
        except AttributeError:
            # This spelling does not resolve to a parameter; try the next one.
            continue
    return None
|
||||
|
||||
|
||||
def load_model(model: nn.Module, path: str):
    """Load every ``*.safetensors`` file in ``path`` into ``model``.

    If ``model`` has a ``packed_modules_mapping`` attribute, each key found
    as a substring of a tensor name maps to ``(replacement, shard_id)``: the
    key is replaced in the name and the tensor is passed, with ``shard_id``,
    to the target parameter's ``weight_loader``. Other tensors are loaded via
    the parameter's ``weight_loader`` attribute, or ``default_weight_loader``
    when it has none. Names that resolve to no parameter produce a printed
    warning and are skipped.

    Raises:
        FileNotFoundError: if ``path`` contains no ``.safetensors`` files.
    """
    packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
    shard_paths = glob(os.path.join(path, "*.safetensors"))
    if len(shard_paths) == 0:
        raise FileNotFoundError(f"No .safetensors files found in {path}")

    for shard_path in shard_paths:
        with safe_open(shard_path, "pt", "cpu") as shard:
            for tensor_name in shard.keys():
                loaded_as_packed = False
                for packed_key in packed_modules_mapping:
                    if packed_key not in tensor_name:
                        continue
                    merged_key, shard_id = packed_modules_mapping[packed_key]
                    merged_name = tensor_name.replace(packed_key, merged_key)
                    target = _get_parameter_safe(model, merged_name)
                    if target is None:
                        print(f"[loader] Warning: Parameter not found: {merged_name}")
                        # Keep scanning the remaining keys; if none loads the
                        # tensor, the plain-name path below is still tried.
                        continue
                    target.weight_loader(target, shard.get_tensor(tensor_name), shard_id)
                    loaded_as_packed = True
                    break
                if loaded_as_packed:
                    continue

                # Not loaded through any packed mapping: use the name as stored.
                target = _get_parameter_safe(model, tensor_name)
                if target is None:
                    print(f"[loader] Warning: Parameter not found: {tensor_name}")
                    continue
                loader = getattr(target, "weight_loader", default_weight_loader)
                loader(target, shard.get_tensor(tensor_name))
|
||||
29
acestep/third_parts/nano-vllm/pyproject.toml
Normal file
29
acestep/third_parts/nano-vllm/pyproject.toml
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
[build-system]
|
||||
requires = ["setuptools>=61"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "nano-vllm"
|
||||
version = "0.2.0"
|
||||
authors = [{ name = "Xingkai Yu" }]
|
||||
license = {text = "MIT"}
|
||||
readme = "README.md"
|
||||
description = "a lightweight vLLM implementation built from scratch"
|
||||
requires-python = ">=3.10,<3.13"
|
||||
dependencies = [
|
||||
"torch>=2.4.0",
|
||||
# Triton and Flash Attention are optional on ROCm (Python 3.12) - SDPA fallback used instead
|
||||
"triton-windows>=3.0.0,<3.4; sys_platform == 'win32' and python_version == '3.11'",
|
||||
"triton>=3.0.0; sys_platform == 'linux' and python_version == '3.11'",
|
||||
"transformers>=4.51.0",
|
||||
# Prebuilt flash-attn wheels are pinned only for CPython 3.11 on Windows/AMD64 and Linux/x86_64;
# on any other platform these markers are false and no flash-attn requirement is added.
# (A marker that already requires win32 or linux does not also need "!= 'darwin'".)
"flash-attn @ https://github.com/sdbds/flash-attention-for-windows/releases/download/2.8.2/flash_attn-2.8.2+cu128torch2.7.1cxx11abiFALSEfullbackward-cp311-cp311-win_amd64.whl ; sys_platform == 'win32' and python_version == '3.11' and platform_machine == 'AMD64'",
"flash-attn @ https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.7.12/flash_attn-2.8.3+cu128torch2.10-cp311-cp311-linux_x86_64.whl ; sys_platform == 'linux' and python_version == '3.11' and platform_machine == 'x86_64'",
|
||||
"xxhash",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage="https://github.com/GeeeekExplorer/nano-vllm"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["."]
|
||||
include = ["nanovllm*"]
|
||||
|
|
@ -485,7 +485,7 @@ See [ACE-Step1.5-Rocm-Manual-Linux.md](ACE-Step1.5-Rocm-Manual-Linux.md) for a d
|
|||
| Offload | Disabled by default |
|
||||
| Compile & Quantization | Enabled by default |
|
||||
| LLM Inference | Supported (tested with `acestep-5Hz-lm-0.6B`) |
|
||||
| vllm engine acceleration | NOT supported on Intel GPUs |
|
||||
| nanovllm acceleration | NOT supported on Intel GPUs |
|
||||
| Test Environment | PyTorch 2.8.0 from [Intel Extension for PyTorch](https://pytorch-extension.intel.com/?request=platform) |
|
||||
|
||||
> **Note:** LLM inference speed may decrease when generating audio longer than 2 minutes. Intel discrete GPUs are expected to work but have not yet been tested.
|
||||
|
|
|
|||
|
|
@ -375,7 +375,7 @@ python -m acestep.acestep_v15_pipeline --port 7680
|
|||
| オフロード | デフォルトで無効 |
|
||||
| コンパイル & 量子化 | デフォルトで有効 |
|
||||
| LLM 推論 | サポート(`acestep-5Hz-lm-0.6B` でテスト済み) |
|
||||
| vllm エンジンアクセラレーション | Intel GPU では未サポート |
|
||||
| nanovllm アクセラレーション | Intel GPU では未サポート |
|
||||
| テスト環境 | PyTorch 2.8.0([Intel Extension for PyTorch](https://pytorch-extension.intel.com/?request=platform)) |
|
||||
|
||||
> 注意:2分以上の音声生成時、LLM 推論速度が低下する場合があります。Intel ディスクリート GPU は動作が期待されますが、まだテストされていません。
|
||||
|
|
|
|||
|
|
@ -375,7 +375,7 @@ python -m acestep.acestep_v15_pipeline --port 7680
|
|||
| 卸载 | 默认禁用 |
|
||||
| 编译与量化 | 默认启用 |
|
||||
| LLM 推理 | 支持(已测试 `acestep-5Hz-lm-0.6B`) |
|
||||
| vllm 引擎加速 | Intel GPU 暂不支持 |
|
||||
| nanovllm 加速 | Intel GPU 暂不支持 |
|
||||
| 测试环境 | PyTorch 2.8.0([Intel Extension for PyTorch](https://pytorch-extension.intel.com/?request=platform)) |
|
||||
|
||||
> 注意:生成超过 2 分钟的音频时,LLM 推理速度可能下降。Intel 独立显卡预计可用但尚未测试。
|
||||
|
|
|
|||
|
|
@ -49,6 +49,9 @@ dependencies = [
|
|||
"lycoris-lora",
|
||||
"lightning>=2.0.0",
|
||||
"tensorboard>=2.0.0",
|
||||
# Local third-party packages
|
||||
# nano-vllm source is configured in [tool.uv.sources]
|
||||
"nano-vllm; sys_platform != 'darwin' or platform_machine != 'arm64'",
|
||||
"modelscope",
|
||||
"tensorboard>=2.20.0",
|
||||
"typer-slim>=0.21.1",
|
||||
|
|
@ -73,6 +76,7 @@ required-environments = [
|
|||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
nano-vllm = { path = "acestep/third_parts/nano-vllm" }
|
||||
torch = [
|
||||
{ index = "pytorch-cu128", marker = "sys_platform == 'win32' or (sys_platform == 'linux' and platform_machine == 'x86_64')" },
|
||||
{ index = "pytorch-cu130", marker = "sys_platform == 'linux' and platform_machine == 'aarch64'" },
|
||||
|
|
|
|||
|
|
@ -22,3 +22,7 @@ modelscope
|
|||
peft>=0.18.0
|
||||
lightning>=2.0.0
|
||||
tensorboard>=2.0.0
|
||||
xxhash
|
||||
|
||||
# Local package - install with: pip install -e acestep/third_parts/nano-vllm
|
||||
# nano-vllm
|
||||
|
|
|
|||
|
|
@ -47,3 +47,5 @@ peft>=0.18.0
|
|||
lightning>=2.0.0
|
||||
tensorboard>=2.0.0
|
||||
|
||||
# nano-vllm tokenizer dependency (needed even with pt backend)
|
||||
xxhash
|
||||
|
|
|
|||
|
|
@ -48,8 +48,12 @@ tensorboard>=2.0.0
|
|||
mlx>=0.25.2; sys_platform == 'darwin' and platform_machine == 'arm64'
|
||||
mlx-lm>=0.20.0; sys_platform == 'darwin' and platform_machine == 'arm64'
|
||||
|
||||
# Optional accelerators for LLM inference engine (SDPA fallback if unavailable)
|
||||
# nano-vllm dependencies
|
||||
triton-windows>=3.0.0,<3.4; sys_platform == 'win32'
|
||||
triton>=3.0.0; sys_platform == 'linux'
|
||||
flash-attn @ https://github.com/sdbds/flash-attention-for-windows/releases/download/2.8.2/flash_attn-2.8.2+cu128torch2.7.1cxx11abiFALSEfullbackward-cp311-cp311-win_amd64.whl ; sys_platform == 'win32' and python_version == '3.11' and platform_machine == 'AMD64'
|
||||
flash-attn; sys_platform == 'linux' and platform_machine == 'x86_64'
|
||||
xxhash
|
||||
|
||||
# Local package - install with: pip install -e acestep/third_parts/nano-vllm
|
||||
# nano-vllm
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ REM IMPORTANT: Requires Python 3.12 (AMD ROCm 7.2 only provides Python 3.12 whee
|
|||
REM Requires: ROCm PyTorch from repo.radeon.com
|
||||
|
||||
REM ==================== ROCm Configuration ====================
|
||||
REM Force PyTorch LM backend (bypasses vllm engine flash_attn dependency)
|
||||
REM Force PyTorch LM backend (bypasses nano-vllm flash_attn dependency)
|
||||
set ACESTEP_LM_BACKEND=pt
|
||||
|
||||
REM RDNA3 GPU architecture override (RX 7900 XT/XTX, RX 7800 XT, etc.)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ set -euo pipefail
|
|||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
|
||||
# ==================== ROCm Configuration ====================
|
||||
# Force PyTorch LM backend (bypasses vllm engine flash_attn dependency)
|
||||
# Force PyTorch LM backend (bypasses nano-vllm flash_attn dependency)
|
||||
export ACESTEP_LM_BACKEND="pt"
|
||||
|
||||
# RDNA3 GPU architecture override (RX 7900 XT/XTX, RX 7800 XT, etc.)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ REM Load settings from .env file if it exists
|
|||
call :LoadEnvFile
|
||||
|
||||
REM ==================== ROCm Configuration ====================
|
||||
REM Force PyTorch LM backend (bypasses vllm engine flash_attn dependency)
|
||||
REM Force PyTorch LM backend (bypasses nano-vllm flash_attn dependency)
|
||||
set ACESTEP_LM_BACKEND=pt
|
||||
|
||||
REM RDNA3 GPU architecture override (RX 7900 XT/XTX, RX 7800 XT, etc.)
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ _load_env_file() {
|
|||
_load_env_file
|
||||
|
||||
# ==================== ROCm Configuration ====================
|
||||
# Force PyTorch LM backend (bypasses vllm engine flash_attn dependency)
|
||||
# Force PyTorch LM backend (bypasses nano-vllm flash_attn dependency)
|
||||
export ACESTEP_LM_BACKEND="pt"
|
||||
|
||||
# RDNA3 GPU architecture override (RX 7900 XT/XTX, RX 7800 XT, etc.)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ REM Ask for manual settings
|
|||
call :LoadManual
|
||||
|
||||
REM ==================== ROCm Configuration ====================
|
||||
REM Force PyTorch LM backend (bypasses vllm engine flash_attn dependency)
|
||||
REM Force PyTorch LM backend (bypasses nano-vllm flash_attn dependency)
|
||||
set ACESTEP_LM_BACKEND=pt
|
||||
|
||||
REM RDNA3 GPU architecture override (RX 7900 XT/XTX, RX 7800 XT, etc.)
|
||||
|
|
|
|||
|
|
@ -192,7 +192,7 @@ _load_manual() {
|
|||
_load_manual
|
||||
|
||||
# ==================== ROCm Configuration ====================
|
||||
# Force PyTorch LM backend (bypasses vllm engine flash_attn dependency)
|
||||
# Force PyTorch LM backend (bypasses nano-vllm flash_attn dependency)
|
||||
export ACESTEP_LM_BACKEND="pt"
|
||||
|
||||
# RDNA3 GPU architecture override (RX 7900 XT/XTX, RX 7800 XT, etc.)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue