feat: add audio codes support to openrouter_adapter

This commit is contained in:
Sayo 2026-03-01 19:34:55 +09:00
parent abc0814c72
commit d43810fa42
3 changed files with 82 additions and 89 deletions

View file

@ -21,7 +21,6 @@ import re
import tempfile
import time
from typing import Any, Dict, List, Optional, Tuple
from uuid import uuid4
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
@ -52,7 +51,7 @@ GENERATION_TIMEOUT = int(os.environ.get("ACESTEP_GENERATION_TIMEOUT", "600"))
def _generate_completion_id() -> str:
"""Generate a unique completion ID."""
return f"chatcmpl-{uuid4().hex[:24]}"
return f"chatcmpl-{os.urandom(8).hex()}"
def _get_model_id(model_name: str) -> str:
@ -196,10 +195,12 @@ def _is_instrumental(lyrics: str) -> bool:
return lyrics_clean in ("[inst]", "[instrumental]")
def _parse_messages(messages: List[Any]) -> Tuple[str, str, List[str], Optional[str], Optional[str]]:
def _parse_messages(messages: List[Any]) -> Tuple[str, str, List[str], Optional[str]]:
"""
Parse chat messages to extract prompt, lyrics, sample_query and audio references.
Only processes the last user message (consistent with server behavior).
Supports two modes:
1. Tagged mode: Use <prompt>...</prompt> and <lyrics>...</lyrics> tags
2. Heuristic mode: Auto-detect based on content structure
@ -208,65 +209,28 @@ def _parse_messages(messages: List[Any]) -> Tuple[str, str, List[str], Optional[
The caller routes them to src_audio / reference_audio based on task_type.
Returns:
(prompt, lyrics, audio_paths, system_instruction, sample_query)
(prompt, lyrics, audio_paths, sample_query)
"""
prompt_parts = []
prompt = ""
lyrics = ""
sample_query = None
audio_paths: List[str] = []
system_instruction = None
has_tags = False
for msg in messages:
role = msg.role
# Process only the last user message
for msg in reversed(messages):
if msg.role != "user" or not msg.content:
continue
content = msg.content
if role == "system":
if isinstance(content, str):
system_instruction = content
continue
if role != "user":
continue
if isinstance(content, str):
text = content.strip()
tagged_prompt, tagged_lyrics, remaining = _extract_tagged_content(text)
if tagged_prompt is not None or tagged_lyrics is not None:
has_tags = True
if tagged_prompt:
prompt_parts.append(tagged_prompt)
if tagged_lyrics:
lyrics = tagged_lyrics
if remaining:
prompt_parts.append(remaining)
else:
if _looks_like_lyrics(text):
lyrics = text
else:
prompt_parts.append(text)
elif isinstance(content, list):
# Handle multimodal content (list of parts)
if isinstance(content, list):
text_parts = []
for part in content:
if isinstance(part, dict):
part_type = part.get("type", "")
if part_type == "text":
text = part.get("text", "").strip()
tagged_prompt, tagged_lyrics, remaining = _extract_tagged_content(text)
if tagged_prompt is not None or tagged_lyrics is not None:
has_tags = True
if tagged_prompt:
prompt_parts.append(tagged_prompt)
if tagged_lyrics:
lyrics = tagged_lyrics
if remaining:
prompt_parts.append(remaining)
elif _looks_like_lyrics(text):
lyrics = text
else:
prompt_parts.append(text)
text_parts.append(part.get("text", "").strip())
elif part_type == "input_audio":
audio_data = part.get("input_audio", {})
if isinstance(audio_data, dict):
@ -278,24 +242,9 @@ def _parse_messages(messages: List[Any]) -> Tuple[str, str, List[str], Optional[
audio_paths.append(path)
except Exception:
pass
elif hasattr(part, "type"):
if part.type == "text":
text = getattr(part, "text", "").strip()
tagged_prompt, tagged_lyrics, remaining = _extract_tagged_content(text)
if tagged_prompt is not None or tagged_lyrics is not None:
has_tags = True
if tagged_prompt:
prompt_parts.append(tagged_prompt)
if tagged_lyrics:
lyrics = tagged_lyrics
if remaining:
prompt_parts.append(remaining)
elif _looks_like_lyrics(text):
lyrics = text
else:
prompt_parts.append(text)
text_parts.append(getattr(part, "text", "").strip())
elif part.type == "input_audio":
audio_data = getattr(part, "input_audio", None)
if audio_data:
@ -307,10 +256,30 @@ def _parse_messages(messages: List[Any]) -> Tuple[str, str, List[str], Optional[
audio_paths.append(path)
except Exception:
pass
content = "\n".join(text_parts).strip()
else:
content = content.strip()
prompt = " ".join(prompt_parts).strip()
if not content:
break
return prompt, lyrics, audio_paths, system_instruction, sample_query
# Try to extract tagged content first
tagged_prompt, tagged_lyrics, remaining = _extract_tagged_content(content)
if tagged_prompt is not None or tagged_lyrics is not None:
prompt = tagged_prompt or ""
lyrics = tagged_lyrics or ""
if remaining and not prompt:
prompt = remaining
else:
# No tags - use heuristic detection
if _looks_like_lyrics(content):
lyrics = content
else:
prompt = content
break
return prompt, lyrics, audio_paths, sample_query
def _to_generate_music_request(
@ -320,6 +289,7 @@ def _to_generate_music_request(
sample_query: Optional[str],
reference_audio_path: Optional[str],
src_audio_path: Optional[str],
audio_codes: str = "",
):
"""
Convert OpenRouter ChatCompletionRequest to api_server's GenerateMusicRequest.
@ -328,9 +298,6 @@ def _to_generate_music_request(
text2music: audio[0] → reference_audio
cover/repaint/lego/extract/complete: audio[0] → src_audio, audio[1] → reference_audio
task_type auto-detection:
text2music + reference_audio → music_continuation
Uses late import to avoid circular dependency with api_server.
"""
from acestep.api_server import GenerateMusicRequest
@ -348,7 +315,7 @@ def _to_generate_music_request(
resolved_lyrics = "[inst]"
# Resolve sample_mode: taken from the explicit request field only (no auto-detection from messages)
resolved_sample_mode = req.sample_mode or bool(sample_query)
resolved_sample_mode = req.sample_mode
resolved_sample_query = sample_query or ""
# Resolve seed: pass through as-is (int or comma-separated string)
@ -356,13 +323,6 @@ def _to_generate_music_request(
resolved_seed = req.seed if req.seed is not None else -1
use_random_seed = req.seed is None
# Resolve task_type
# Explicit task_type from request takes priority.
# For text2music: auto-detect based on reference_audio.
resolved_task_type = req.task_type
if resolved_task_type == "text2music" and reference_audio_path:
resolved_task_type = "music_continuation"
return GenerateMusicRequest(
# Text input
prompt=prompt,
@ -381,17 +341,19 @@ def _to_generate_music_request(
lm_temperature=req.temperature if req.temperature is not None else 0.85,
lm_top_p=req.top_p if req.top_p is not None else 0.9,
lm_top_k=req.top_k if req.top_k is not None else 0,
lm_cfg_scale=req.lm_cfg_scale,
thinking=req.thinking if req.thinking is not None else False,
# Generation parameters
inference_steps=8,
inference_steps=req.inference_steps,
infer_method=req.infer_method,
guidance_scale=req.guidance_scale if req.guidance_scale is not None else 7.0,
seed=resolved_seed,
use_random_seed=use_random_seed,
batch_size=req.batch_size if req.batch_size is not None else 1,
# Task type
task_type=resolved_task_type,
task_type=req.task_type,
# Audio paths
reference_audio_path=reference_audio_path or None,
@ -444,6 +406,9 @@ def _build_openrouter_response(
"audio_url": {"url": b64_url},
}]
# Extract audio_codes from result if available
audio_codes = result.get("audio_codes") or None
response_data = {
"id": completion_id,
"object": "chat.completion",
@ -455,6 +420,7 @@ def _build_openrouter_response(
"role": "assistant",
"content": text_content,
"audio": audio_obj,
"audio_codes": audio_codes,
},
"finish_reason": "stop",
}],
@ -510,7 +476,7 @@ async def _openrouter_stream_generator(
return f"data: {json.dumps(chunk)}\n\n"
# Initial role chunk
yield _make_chunk(role="assistant", content="Generating music")
yield _make_chunk(role="assistant", content="")
await asyncio.sleep(0)
# Wait for result with periodic heartbeats
@ -555,6 +521,12 @@ async def _openrouter_stream_generator(
yield _make_chunk(audio=audio_list)
await asyncio.sleep(0)
# Send audio_codes if available
audio_codes = result.get("audio_codes")
if audio_codes:
yield _make_chunk(content=f"\n\n[audio_codes]{audio_codes}[/audio_codes]")
await asyncio.sleep(0)
# Finish
yield _make_chunk(finish_reason="stop")
yield "data: [DONE]\n\n"
@ -666,8 +638,8 @@ def create_openrouter_router(app_state_getter) -> APIRouter:
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid request format: {str(e)}")
# Parse messages for text, audio, and system instruction
prompt, lyrics, audio_paths, system_instruction, sample_query = _parse_messages(req.messages)
# Parse messages for text, audio, and sample_query
prompt, lyrics, audio_paths, sample_query = _parse_messages(req.messages)
# When lyrics or sample_mode is explicitly provided, the message text role
# is already known — skip auto-detection results.
@ -700,7 +672,7 @@ def create_openrouter_router(app_state_getter) -> APIRouter:
# audio[1] → reference_audio (optional: style conditioning)
#
# For text2music (default):
# audio[0] → reference_audio (style conditioning → music_continuation)
# audio[0] → reference_audio (style conditioning)
reference_audio_path = None
src_audio_path = None
_SRC_AUDIO_TASK_TYPES = {"cover", "repaint", "lego", "extract", "complete"}
@ -712,9 +684,23 @@ def create_openrouter_router(app_state_getter) -> APIRouter:
else:
reference_audio_path = audio_paths[0]
# Auto-convert src_audio to audio codes for cover mode when the request did
# not supply any. Codes given explicitly in req.audio_codes always win; the
# conversion is best-effort and leaves the request value untouched on failure.
resolved_audio_codes = req.audio_codes
if (src_audio_path and not resolved_audio_codes
        and req.task_type == "cover"):
    handler = getattr(state, "handler", None)
    if handler and hasattr(handler, "convert_src_audio_to_codes"):
        try:
            codes_str = handler.convert_src_audio_to_codes(src_audio_path)
            # The previous check used startswith(""), which is True for every
            # string, so the converted codes were never accepted.
            # NOTE(review): assumes the handler reports failure by returning
            # a message prefixed with "❌" instead of raising — confirm
            # against handler.convert_src_audio_to_codes.
            if codes_str and not codes_str.startswith("❌"):
                resolved_audio_codes = codes_str
        except Exception:
            # Deliberate best-effort: a failed conversion must not fail the
            # request; generation proceeds from src_audio without codes.
            pass
# Convert to GenerateMusicRequest
gen_request = _to_generate_music_request(
req, prompt, lyrics, sample_query, reference_audio_path, src_audio_path
req, prompt, lyrics, sample_query, reference_audio_path, src_audio_path,
audio_codes=resolved_audio_codes,
)
# Check queue capacity

View file

@ -101,6 +101,14 @@ class ChatCompletionRequest(BaseModel):
repainting_end: Optional[float] = Field(default=None, description="Repainting region end (seconds)")
audio_cover_strength: float = Field(default=1.0, description="Audio cover strength (0.0~1.0)")
# Extended fields for ComfyUI node compatibility
audio_codes: str = Field(default="", description="Pre-computed audio codes (bypass auto-conversion)")
cover_noise_strength: float = Field(default=0.0, description="Cover noise strength (0=pure noise, 1=closest to src)")
inference_steps: int = Field(default=8, description="Number of diffusion inference steps")
infer_method: str = Field(default="ode", description="Diffusion inference method: 'ode' or 'sde'")
lm_cfg_scale: float = Field(default=2.0, description="LM classifier-free guidance scale")
use_cot_metas: Optional[bool] = Field(default=None, description="Use CoT for metadata generation (auto if None)")
class Config:
extra = "allow" # Allow additional fields for forward compatibility
@ -125,6 +133,7 @@ class AssistantMessage(BaseModel):
role: Literal["assistant"] = "assistant"
content: Optional[str] = Field(default=None, description="Text content")
audio: Optional[List[AudioOutput]] = Field(default=None, description="Generated audio files")
audio_codes: Optional[str] = Field(default=None, description="Generated audio codes")
class Choice(BaseModel):

View file

@ -5,8 +5,6 @@ Provides OpenAI Chat Completions API format for text-to-music generation.
Endpoints:
- GET /v1/models List available models with pricing
- POST /v1/chat/completions Generate music from text prompt
- POST /v1/sample LLM generates caption/lyrics/metadata from query
- POST /v1/audio2code Convert source audio to audio code tokens
- GET /health Health check
Usage: