mirror of
https://github.com/ace-step/ACE-Step-1.5.git
synced 2026-07-02 16:37:04 +00:00
feat: add codes for openrouter_adapter
This commit is contained in:
parent
abc0814c72
commit
d43810fa42
3 changed files with 82 additions and 89 deletions
|
|
@ -21,7 +21,6 @@ import re
|
|||
import tempfile
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
|
@ -52,7 +51,7 @@ GENERATION_TIMEOUT = int(os.environ.get("ACESTEP_GENERATION_TIMEOUT", "600"))
|
|||
|
||||
def _generate_completion_id() -> str:
|
||||
"""Generate a unique completion ID."""
|
||||
return f"chatcmpl-{uuid4().hex[:24]}"
|
||||
return f"chatcmpl-{os.urandom(8).hex()}"
|
||||
|
||||
|
||||
def _get_model_id(model_name: str) -> str:
|
||||
|
|
@ -196,10 +195,12 @@ def _is_instrumental(lyrics: str) -> bool:
|
|||
return lyrics_clean in ("[inst]", "[instrumental]")
|
||||
|
||||
|
||||
def _parse_messages(messages: List[Any]) -> Tuple[str, str, List[str], Optional[str], Optional[str]]:
|
||||
def _parse_messages(messages: List[Any]) -> Tuple[str, str, List[str], Optional[str]]:
|
||||
"""
|
||||
Parse chat messages to extract prompt, lyrics, sample_query and audio references.
|
||||
|
||||
Only processes the last user message (consistent with server behavior).
|
||||
|
||||
Supports two modes:
|
||||
1. Tagged mode: Use <prompt>...</prompt> and <lyrics>...</lyrics> tags
|
||||
2. Heuristic mode: Auto-detect based on content structure
|
||||
|
|
@ -208,65 +209,28 @@ def _parse_messages(messages: List[Any]) -> Tuple[str, str, List[str], Optional[
|
|||
The caller routes them to src_audio / reference_audio based on task_type.
|
||||
|
||||
Returns:
|
||||
(prompt, lyrics, audio_paths, system_instruction, sample_query)
|
||||
(prompt, lyrics, audio_paths, sample_query)
|
||||
"""
|
||||
prompt_parts = []
|
||||
prompt = ""
|
||||
lyrics = ""
|
||||
sample_query = None
|
||||
audio_paths: List[str] = []
|
||||
system_instruction = None
|
||||
has_tags = False
|
||||
|
||||
for msg in messages:
|
||||
role = msg.role
|
||||
# Process only the last user message
|
||||
for msg in reversed(messages):
|
||||
if msg.role != "user" or not msg.content:
|
||||
continue
|
||||
|
||||
content = msg.content
|
||||
|
||||
if role == "system":
|
||||
if isinstance(content, str):
|
||||
system_instruction = content
|
||||
continue
|
||||
|
||||
if role != "user":
|
||||
continue
|
||||
|
||||
if isinstance(content, str):
|
||||
text = content.strip()
|
||||
tagged_prompt, tagged_lyrics, remaining = _extract_tagged_content(text)
|
||||
if tagged_prompt is not None or tagged_lyrics is not None:
|
||||
has_tags = True
|
||||
if tagged_prompt:
|
||||
prompt_parts.append(tagged_prompt)
|
||||
if tagged_lyrics:
|
||||
lyrics = tagged_lyrics
|
||||
if remaining:
|
||||
prompt_parts.append(remaining)
|
||||
else:
|
||||
if _looks_like_lyrics(text):
|
||||
lyrics = text
|
||||
else:
|
||||
prompt_parts.append(text)
|
||||
|
||||
elif isinstance(content, list):
|
||||
# Handle multimodal content (list of parts)
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for part in content:
|
||||
if isinstance(part, dict):
|
||||
part_type = part.get("type", "")
|
||||
|
||||
if part_type == "text":
|
||||
text = part.get("text", "").strip()
|
||||
tagged_prompt, tagged_lyrics, remaining = _extract_tagged_content(text)
|
||||
if tagged_prompt is not None or tagged_lyrics is not None:
|
||||
has_tags = True
|
||||
if tagged_prompt:
|
||||
prompt_parts.append(tagged_prompt)
|
||||
if tagged_lyrics:
|
||||
lyrics = tagged_lyrics
|
||||
if remaining:
|
||||
prompt_parts.append(remaining)
|
||||
elif _looks_like_lyrics(text):
|
||||
lyrics = text
|
||||
else:
|
||||
prompt_parts.append(text)
|
||||
|
||||
text_parts.append(part.get("text", "").strip())
|
||||
elif part_type == "input_audio":
|
||||
audio_data = part.get("input_audio", {})
|
||||
if isinstance(audio_data, dict):
|
||||
|
|
@ -278,24 +242,9 @@ def _parse_messages(messages: List[Any]) -> Tuple[str, str, List[str], Optional[
|
|||
audio_paths.append(path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
elif hasattr(part, "type"):
|
||||
if part.type == "text":
|
||||
text = getattr(part, "text", "").strip()
|
||||
tagged_prompt, tagged_lyrics, remaining = _extract_tagged_content(text)
|
||||
if tagged_prompt is not None or tagged_lyrics is not None:
|
||||
has_tags = True
|
||||
if tagged_prompt:
|
||||
prompt_parts.append(tagged_prompt)
|
||||
if tagged_lyrics:
|
||||
lyrics = tagged_lyrics
|
||||
if remaining:
|
||||
prompt_parts.append(remaining)
|
||||
elif _looks_like_lyrics(text):
|
||||
lyrics = text
|
||||
else:
|
||||
prompt_parts.append(text)
|
||||
|
||||
text_parts.append(getattr(part, "text", "").strip())
|
||||
elif part.type == "input_audio":
|
||||
audio_data = getattr(part, "input_audio", None)
|
||||
if audio_data:
|
||||
|
|
@ -307,10 +256,30 @@ def _parse_messages(messages: List[Any]) -> Tuple[str, str, List[str], Optional[
|
|||
audio_paths.append(path)
|
||||
except Exception:
|
||||
pass
|
||||
content = "\n".join(text_parts).strip()
|
||||
else:
|
||||
content = content.strip()
|
||||
|
||||
prompt = " ".join(prompt_parts).strip()
|
||||
if not content:
|
||||
break
|
||||
|
||||
return prompt, lyrics, audio_paths, system_instruction, sample_query
|
||||
# Try to extract tagged content first
|
||||
tagged_prompt, tagged_lyrics, remaining = _extract_tagged_content(content)
|
||||
|
||||
if tagged_prompt is not None or tagged_lyrics is not None:
|
||||
prompt = tagged_prompt or ""
|
||||
lyrics = tagged_lyrics or ""
|
||||
if remaining and not prompt:
|
||||
prompt = remaining
|
||||
else:
|
||||
# No tags - use heuristic detection
|
||||
if _looks_like_lyrics(content):
|
||||
lyrics = content
|
||||
else:
|
||||
prompt = content
|
||||
break
|
||||
|
||||
return prompt, lyrics, audio_paths, sample_query
|
||||
|
||||
|
||||
def _to_generate_music_request(
|
||||
|
|
@ -320,6 +289,7 @@ def _to_generate_music_request(
|
|||
sample_query: Optional[str],
|
||||
reference_audio_path: Optional[str],
|
||||
src_audio_path: Optional[str],
|
||||
audio_codes: str = "",
|
||||
):
|
||||
"""
|
||||
Convert OpenRouter ChatCompletionRequest to api_server's GenerateMusicRequest.
|
||||
|
|
@ -328,9 +298,6 @@ def _to_generate_music_request(
|
|||
text2music: audio[0] → reference_audio
|
||||
cover/repaint/lego/…: audio[0] → src_audio, audio[1] → reference_audio
|
||||
|
||||
task_type auto-detection:
|
||||
text2music + reference_audio → music_continuation
|
||||
|
||||
Uses late import to avoid circular dependency with api_server.
|
||||
"""
|
||||
from acestep.api_server import GenerateMusicRequest
|
||||
|
|
@ -348,7 +315,7 @@ def _to_generate_music_request(
|
|||
resolved_lyrics = "[inst]"
|
||||
|
||||
# Resolve sample_mode: explicit field takes priority, then auto-detect from messages
|
||||
resolved_sample_mode = req.sample_mode or bool(sample_query)
|
||||
resolved_sample_mode = req.sample_mode
|
||||
resolved_sample_query = sample_query or ""
|
||||
|
||||
# Resolve seed: pass through as-is (int or comma-separated string)
|
||||
|
|
@ -356,13 +323,6 @@ def _to_generate_music_request(
|
|||
resolved_seed = req.seed if req.seed is not None else -1
|
||||
use_random_seed = req.seed is None
|
||||
|
||||
# Resolve task_type
|
||||
# Explicit task_type from request takes priority.
|
||||
# For text2music: auto-detect based on reference_audio.
|
||||
resolved_task_type = req.task_type
|
||||
if resolved_task_type == "text2music" and reference_audio_path:
|
||||
resolved_task_type = "music_continuation"
|
||||
|
||||
return GenerateMusicRequest(
|
||||
# Text input
|
||||
prompt=prompt,
|
||||
|
|
@ -381,17 +341,19 @@ def _to_generate_music_request(
|
|||
lm_temperature=req.temperature if req.temperature is not None else 0.85,
|
||||
lm_top_p=req.top_p if req.top_p is not None else 0.9,
|
||||
lm_top_k=req.top_k if req.top_k is not None else 0,
|
||||
lm_cfg_scale=req.lm_cfg_scale,
|
||||
thinking=req.thinking if req.thinking is not None else False,
|
||||
|
||||
# Generation parameters
|
||||
inference_steps=8,
|
||||
inference_steps=req.inference_steps,
|
||||
infer_method=req.infer_method,
|
||||
guidance_scale=req.guidance_scale if req.guidance_scale is not None else 7.0,
|
||||
seed=resolved_seed,
|
||||
use_random_seed=use_random_seed,
|
||||
batch_size=req.batch_size if req.batch_size is not None else 1,
|
||||
|
||||
# Task type
|
||||
task_type=resolved_task_type,
|
||||
task_type=req.task_type,
|
||||
|
||||
# Audio paths
|
||||
reference_audio_path=reference_audio_path or None,
|
||||
|
|
@ -444,6 +406,9 @@ def _build_openrouter_response(
|
|||
"audio_url": {"url": b64_url},
|
||||
}]
|
||||
|
||||
# Extract audio_codes from result if available
|
||||
audio_codes = result.get("audio_codes") or None
|
||||
|
||||
response_data = {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion",
|
||||
|
|
@ -455,6 +420,7 @@ def _build_openrouter_response(
|
|||
"role": "assistant",
|
||||
"content": text_content,
|
||||
"audio": audio_obj,
|
||||
"audio_codes": audio_codes,
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}],
|
||||
|
|
@ -510,7 +476,7 @@ async def _openrouter_stream_generator(
|
|||
return f"data: {json.dumps(chunk)}\n\n"
|
||||
|
||||
# Initial role chunk
|
||||
yield _make_chunk(role="assistant", content="Generating music")
|
||||
yield _make_chunk(role="assistant", content="")
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Wait for result with periodic heartbeats
|
||||
|
|
@ -555,6 +521,12 @@ async def _openrouter_stream_generator(
|
|||
yield _make_chunk(audio=audio_list)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Send audio_codes if available
|
||||
audio_codes = result.get("audio_codes")
|
||||
if audio_codes:
|
||||
yield _make_chunk(content=f"\n\n[audio_codes]{audio_codes}[/audio_codes]")
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Finish
|
||||
yield _make_chunk(finish_reason="stop")
|
||||
yield "data: [DONE]\n\n"
|
||||
|
|
@ -666,8 +638,8 @@ def create_openrouter_router(app_state_getter) -> APIRouter:
|
|||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid request format: {str(e)}")
|
||||
|
||||
# Parse messages for text, audio, and system instruction
|
||||
prompt, lyrics, audio_paths, system_instruction, sample_query = _parse_messages(req.messages)
|
||||
# Parse messages for text, audio, and sample_query
|
||||
prompt, lyrics, audio_paths, sample_query = _parse_messages(req.messages)
|
||||
|
||||
# When lyrics or sample_mode is explicitly provided, the message text role
|
||||
# is already known — skip auto-detection results.
|
||||
|
|
@ -700,7 +672,7 @@ def create_openrouter_router(app_state_getter) -> APIRouter:
|
|||
# audio[1] → reference_audio (optional: style conditioning)
|
||||
#
|
||||
# For text2music (default):
|
||||
# audio[0] → reference_audio (style conditioning → music_continuation)
|
||||
# audio[0] → reference_audio (style conditioning)
|
||||
reference_audio_path = None
|
||||
src_audio_path = None
|
||||
_SRC_AUDIO_TASK_TYPES = {"cover", "repaint", "lego", "extract", "complete"}
|
||||
|
|
@ -712,9 +684,23 @@ def create_openrouter_router(app_state_getter) -> APIRouter:
|
|||
else:
|
||||
reference_audio_path = audio_paths[0]
|
||||
|
||||
# Auto-convert src_audio to codes for cover mode when no codes provided
|
||||
resolved_audio_codes = req.audio_codes
|
||||
if (src_audio_path and not resolved_audio_codes
|
||||
and req.task_type == "cover"):
|
||||
handler = getattr(state, "handler", None)
|
||||
if handler and hasattr(handler, "convert_src_audio_to_codes"):
|
||||
try:
|
||||
codes_str = handler.convert_src_audio_to_codes(src_audio_path)
|
||||
if codes_str and not codes_str.startswith("❌"):
|
||||
resolved_audio_codes = codes_str
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Convert to GenerateMusicRequest
|
||||
gen_request = _to_generate_music_request(
|
||||
req, prompt, lyrics, sample_query, reference_audio_path, src_audio_path
|
||||
req, prompt, lyrics, sample_query, reference_audio_path, src_audio_path,
|
||||
audio_codes=resolved_audio_codes,
|
||||
)
|
||||
|
||||
# Check queue capacity
|
||||
|
|
|
|||
|
|
@ -101,6 +101,14 @@ class ChatCompletionRequest(BaseModel):
|
|||
repainting_end: Optional[float] = Field(default=None, description="Repainting region end (seconds)")
|
||||
audio_cover_strength: float = Field(default=1.0, description="Audio cover strength (0.0~1.0)")
|
||||
|
||||
# Extended fields for ComfyUI node compatibility
|
||||
audio_codes: str = Field(default="", description="Pre-computed audio codes (bypass auto-conversion)")
|
||||
cover_noise_strength: float = Field(default=0.0, description="Cover noise strength (0=pure noise, 1=closest to src)")
|
||||
inference_steps: int = Field(default=8, description="Number of diffusion inference steps")
|
||||
infer_method: str = Field(default="ode", description="Diffusion inference method: 'ode' or 'sde'")
|
||||
lm_cfg_scale: float = Field(default=2.0, description="LM classifier-free guidance scale")
|
||||
use_cot_metas: Optional[bool] = Field(default=None, description="Use CoT for metadata generation (auto if None)")
|
||||
|
||||
class Config:
|
||||
extra = "allow" # Allow additional fields for forward compatibility
|
||||
|
||||
|
|
@ -125,6 +133,7 @@ class AssistantMessage(BaseModel):
|
|||
role: Literal["assistant"] = "assistant"
|
||||
content: Optional[str] = Field(default=None, description="Text content")
|
||||
audio: Optional[List[AudioOutput]] = Field(default=None, description="Generated audio files")
|
||||
audio_codes: Optional[str] = Field(default=None, description="Generated audio codes")
|
||||
|
||||
|
||||
class Choice(BaseModel):
|
||||
|
|
|
|||
|
|
@ -5,8 +5,6 @@ Provides OpenAI Chat Completions API format for text-to-music generation.
|
|||
Endpoints:
|
||||
- GET /v1/models List available models with pricing
|
||||
- POST /v1/chat/completions Generate music from text prompt
|
||||
- POST /v1/sample LLM generates caption/lyrics/metadata from query
|
||||
- POST /v1/audio2code Convert source audio to audio code tokens
|
||||
- GET /health Health check
|
||||
|
||||
Usage:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue