mirror of
https://github.com/ace-step/ACE-Step-1.5.git
synced 2026-07-02 16:37:04 +00:00
Fix KV cache block allocation for paired CFG sequences
This commit is contained in:
parent
6bdf68d9db
commit
003d75c2a7
4 changed files with 70 additions and 25 deletions
|
|
@ -1957,6 +1957,13 @@ class LLMHandler:
|
|||
reset_context()
|
||||
except ImportError:
|
||||
pass
|
||||
# Also reset the LLM scheduler to release allocated KV cache blocks
|
||||
# This prevents 'deque index out of range' errors from block leaks
|
||||
try:
|
||||
if hasattr(self.llm, 'reset'):
|
||||
self.llm.reset()
|
||||
except Exception:
|
||||
pass # Ignore errors during cleanup
|
||||
# Clear CUDA cache to release any corrupted memory
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
|
|
|||
|
|
@ -84,6 +84,24 @@ class LLMEngine:
|
|||
def is_finished(self):
    """Return whatever ``self.scheduler.is_finished()`` reports."""
    # Pure delegation: the engine keeps no completion state of its own here.
    scheduler_done = self.scheduler.is_finished()
    return scheduler_done
|
||||
|
||||
def reset(self):
    """Drop every queued sequence and release the KV cache blocks it holds.

    Meant to be called after a generation was interrupted or raised, so that
    blocks still referenced by abandoned sequences go back to the block
    manager instead of leaking (the leak this guards against surfaces as
    'deque index out of range' errors on later runs).

    Both ``scheduler.running`` and ``scheduler.waiting`` are emptied. A
    sequence is handed to ``block_manager.deallocate`` only when its
    ``block_table`` is non-empty.

    Raises:
        Exception: the first error raised by ``block_manager.deallocate``,
            re-raised only after both queues have been fully drained, so a
            single bad sequence cannot leave the remaining ones allocated.
    """
    scheduler = self.scheduler
    first_error = None

    # Waiting sequences are drained too: they may still own blocks
    # (e.g. after a preemption), not just the running ones.
    for queue in (scheduler.running, scheduler.waiting):
        while queue:
            seq = queue.popleft()
            if not seq.block_table:
                # Nothing allocated for this sequence; just discard it.
                continue
            try:
                scheduler.block_manager.deallocate(seq)
            except Exception as exc:
                # Keep draining: stopping here would leave every sequence
                # still queued holding its blocks. Remember the first
                # failure and surface it once cleanup is complete.
                if first_error is None:
                    first_error = exc

    if first_error is not None:
        raise first_error
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: list[str] | list[list[int]],
|
||||
|
|
@ -91,6 +109,11 @@ class LLMEngine:
|
|||
use_tqdm: bool = True,
|
||||
unconditional_prompts: list[str] | list[list[int]] | None = None,
|
||||
) -> list[str]:
|
||||
# Clean up any residual state from previous interrupted generations
|
||||
# This prevents 'deque index out of range' errors from accumulated block leaks
|
||||
if not self.is_finished():
|
||||
self.reset()
|
||||
|
||||
if use_tqdm:
|
||||
pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
|
||||
if not isinstance(sampling_params, list):
|
||||
|
|
@ -101,24 +124,31 @@ class LLMEngine:
|
|||
self.add_request(prompt, sp, uncond_prompt)
|
||||
outputs = {}
|
||||
prefill_throughput = decode_throughput = 0.
|
||||
while not self.is_finished():
|
||||
t = perf_counter()
|
||||
output, num_tokens = self.step()
|
||||
if use_tqdm:
|
||||
if num_tokens > 0:
|
||||
prefill_throughput = num_tokens / (perf_counter() - t)
|
||||
else:
|
||||
decode_throughput = -num_tokens / (perf_counter() - t)
|
||||
pbar.set_postfix({
|
||||
"Prefill": f"{int(prefill_throughput)}tok/s",
|
||||
"Decode": f"{int(decode_throughput)}tok/s",
|
||||
})
|
||||
for seq_id, token_ids in output:
|
||||
outputs[seq_id] = token_ids
|
||||
try:
|
||||
while not self.is_finished():
|
||||
t = perf_counter()
|
||||
output, num_tokens = self.step()
|
||||
if use_tqdm:
|
||||
pbar.update(1)
|
||||
if num_tokens > 0:
|
||||
prefill_throughput = num_tokens / (perf_counter() - t)
|
||||
else:
|
||||
decode_throughput = -num_tokens / (perf_counter() - t)
|
||||
pbar.set_postfix({
|
||||
"Prefill": f"{int(prefill_throughput)}tok/s",
|
||||
"Decode": f"{int(decode_throughput)}tok/s",
|
||||
})
|
||||
for seq_id, token_ids in output:
|
||||
outputs[seq_id] = token_ids
|
||||
if use_tqdm:
|
||||
pbar.update(1)
|
||||
except Exception:
|
||||
# Clean up on exception to prevent block leaks
|
||||
self.reset()
|
||||
raise
|
||||
finally:
|
||||
if use_tqdm:
|
||||
pbar.close()
|
||||
|
||||
outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
|
||||
outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
|
||||
if use_tqdm:
|
||||
pbar.close()
|
||||
return outputs
|
||||
|
|
|
|||
|
|
@ -41,8 +41,12 @@ class Scheduler:
|
|||
|
||||
# Calculate tokens for both sequences
|
||||
total_tokens = (len(seq) - seq.num_cached_tokens) + (len(paired_seq) - paired_seq.num_cached_tokens)
|
||||
can_allocate_both = (self.block_manager.can_allocate(seq) and
|
||||
self.block_manager.can_allocate(paired_seq))
|
||||
|
||||
# FIX: Check if we have enough blocks for BOTH sequences combined
|
||||
# The old check was wrong: it checked each sequence independently,
|
||||
# but didn't account for the total blocks needed by both
|
||||
total_blocks_needed = seq.num_blocks + paired_seq.num_blocks
|
||||
can_allocate_both = len(self.block_manager.free_block_ids) >= total_blocks_needed
|
||||
|
||||
if num_batched_tokens + total_tokens > self.max_num_batched_tokens or not can_allocate_both:
|
||||
break
|
||||
|
|
@ -101,9 +105,13 @@ class Scheduler:
|
|||
# Remove paired_seq from temp_running
|
||||
temp_running.remove(paired_seq)
|
||||
|
||||
# Check if both can append
|
||||
can_append_both = (self.block_manager.can_append(seq) and
|
||||
self.block_manager.can_append(paired_seq))
|
||||
# FIX: Check if we have enough blocks for BOTH sequences to append
|
||||
# Each sequence needs 1 block when at block boundary (len % block_size == 1)
|
||||
block_size = self.block_manager.block_size
|
||||
blocks_needed_seq = 1 if len(seq) % block_size == 1 else 0
|
||||
blocks_needed_paired = 1 if len(paired_seq) % block_size == 1 else 0
|
||||
total_blocks_needed = blocks_needed_seq + blocks_needed_paired
|
||||
can_append_both = len(self.block_manager.free_block_ids) >= total_blocks_needed
|
||||
|
||||
if not can_append_both:
|
||||
# Try preempting other sequences
|
||||
|
|
@ -112,8 +120,8 @@ class Scheduler:
|
|||
other_seq = temp_running.pop(0)
|
||||
if other_seq != seq and other_seq != paired_seq:
|
||||
self.preempt(other_seq)
|
||||
can_append_both = (self.block_manager.can_append(seq) and
|
||||
self.block_manager.can_append(paired_seq))
|
||||
# Recalculate with the same correct logic
|
||||
can_append_both = len(self.block_manager.free_block_ids) >= total_blocks_needed
|
||||
preempted = True
|
||||
else:
|
||||
temp_running.append(other_seq)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ Behavior:
|
|||
EOF
|
||||
}
|
||||
|
||||
PORT="8001"
|
||||
PORT="8002"
|
||||
PID=""
|
||||
FORCE="0"
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue