mirror of
https://github.com/ace-step/ACE-Step-1.5.git
synced 2026-07-02 16:37:04 +00:00
Follow-up to #1056 — Codex review found checkpoint path resolution was inconsistent across CLI, DiT initializer, LM handler, and catalog. Resolution order per callsite: - ACESTEP_CHECKPOINTS_DIR env var (highest priority, all callsites) - Explicit project_root argument (orchestrator, catalog) - get_checkpoints_dir() fallback for CLI/LM (cwd-based) Changes: - cli.py: _get_default_checkpoint_dir() delegates to get_checkpoints_dir() - llm_inference.py: _get_checkpoint_dir() checks env var then uses get_checkpoints_dir() instead of __file__-based path - orchestrator: env var > explicit project_root > get_checkpoints_dir() - catalog: env var > self._get_project_root() fallback Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2009 lines
86 KiB
Python
2009 lines
86 KiB
Python
import argparse
|
|
import re
|
|
import ast
|
|
import os
|
|
import sys
|
|
import toml
|
|
from pathlib import Path
|
|
from typing import List, Optional, Tuple
|
|
|
|
# Load environment variables from .env or .env.example (if available)
|
|
try:
|
|
from dotenv import load_dotenv
|
|
_current_file = os.path.abspath(__file__)
|
|
_project_root = os.path.dirname(_current_file)
|
|
_env_path = os.path.join(_project_root, '.env')
|
|
_env_example_path = os.path.join(_project_root, '.env.example')
|
|
|
|
if os.path.exists(_env_path):
|
|
load_dotenv(_env_path)
|
|
print(f"Loaded configuration from {_env_path}")
|
|
elif os.path.exists(_env_example_path):
|
|
load_dotenv(_env_example_path)
|
|
print(f"Loaded configuration from {_env_example_path} (fallback)")
|
|
except ImportError:
|
|
pass
|
|
|
|
# Clear proxy settings that may affect network behavior
|
|
for _proxy_var in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY', 'ALL_PROXY']:
|
|
os.environ.pop(_proxy_var, None)
|
|
|
|
def _configure_logging(
|
|
level: Optional[str] = None,
|
|
suppress_audio_tokens: Optional[bool] = None,
|
|
) -> None:
|
|
try:
|
|
from loguru import logger
|
|
except Exception:
|
|
return
|
|
|
|
if suppress_audio_tokens is None:
|
|
suppress_audio_tokens = os.environ.get("ACE_STEP_SUPPRESS_AUDIO_TOKENS", "1") not in {"0", "false", "False"}
|
|
if level is None:
|
|
level = "INFO"
|
|
level = str(level).upper()
|
|
|
|
def _log_filter(record) -> bool:
|
|
message = record.get("message", "")
|
|
# Suppress duplicate DiT prompt logs (we print a single final prompt in cli.py)
|
|
if (
|
|
"DiT TEXT ENCODER INPUT" in message
|
|
or "text_prompt:" in message
|
|
or (message.strip() and set(message.strip()) == {"="})
|
|
):
|
|
return False
|
|
if not suppress_audio_tokens:
|
|
return True
|
|
return "<|audio_code_" not in message
|
|
|
|
logger.remove()
|
|
logger.add(sys.stderr, level=level, filter=_log_filter)
|
|
|
|
|
|
_configure_logging()
|
|
|
|
from acestep.handler import AceStepHandler
|
|
from acestep.llm_inference import LLMHandler
|
|
from acestep.inference import GenerationParams, GenerationConfig, generate_music, create_sample, format_sample
|
|
from acestep.constants import DEFAULT_DIT_INSTRUCTION, TASK_INSTRUCTIONS
|
|
from acestep.gpu_config import get_gpu_config, set_global_gpu_config, is_mps_platform
|
|
import torch
|
|
|
|
|
|
TRACK_CHOICES = [
|
|
"vocals",
|
|
"backing_vocals",
|
|
"drums",
|
|
"bass",
|
|
"guitar",
|
|
"keyboard",
|
|
"percussion",
|
|
"strings",
|
|
"synth",
|
|
"fx",
|
|
"brass",
|
|
"woodwinds",
|
|
]
|
|
|
|
|
|
def _get_project_root() -> str:
|
|
return os.path.dirname(os.path.abspath(__file__))
|
|
|
|
|
|
def _get_default_checkpoint_dir() -> str:
|
|
"""Return the default checkpoints directory via the shared resolver.
|
|
|
|
Always delegates to model_downloader.get_checkpoints_dir() so that
|
|
ACESTEP_CHECKPOINTS_DIR, ACESTEP_PROJECT_ROOT, and the cwd-based
|
|
fallback are handled in one place.
|
|
"""
|
|
from acestep.model_downloader import get_checkpoints_dir
|
|
return str(get_checkpoints_dir())
|
|
|
|
|
|
def _parse_description_hints(description: str) -> tuple[Optional[str], bool]:
|
|
import re
|
|
|
|
if not description:
|
|
return None, False
|
|
|
|
description_lower = description.lower().strip()
|
|
|
|
language_mapping = {
|
|
'english': 'en', 'en': 'en',
|
|
'chinese': 'zh', '中文': 'zh', 'zh': 'zh', 'mandarin': 'zh',
|
|
'japanese': 'ja', '日本語': 'ja', 'ja': 'ja',
|
|
'korean': 'ko', '한국어': 'ko', 'ko': 'ko',
|
|
'spanish': 'es', 'español': 'es', 'es': 'es',
|
|
'french': 'fr', 'français': 'fr', 'fr': 'fr',
|
|
'german': 'de', 'deutsch': 'de', 'de': 'de',
|
|
'italian': 'it', 'italiano': 'it', 'it': 'it',
|
|
'portuguese': 'pt', 'português': 'pt', 'pt': 'pt',
|
|
'russian': 'ru', 'русский': 'ru', 'ru': 'ru',
|
|
'bengali': 'bn', 'bn': 'bn',
|
|
'hindi': 'hi', 'hi': 'hi',
|
|
'arabic': 'ar', 'ar': 'ar',
|
|
'thai': 'th', 'th': 'th',
|
|
'vietnamese': 'vi', 'vi': 'vi',
|
|
'indonesian': 'id', 'id': 'id',
|
|
'turkish': 'tr', 'tr': 'tr',
|
|
'dutch': 'nl', 'nl': 'nl',
|
|
'polish': 'pl', 'pl': 'pl',
|
|
}
|
|
|
|
detected_language = None
|
|
for lang_name, lang_code in language_mapping.items():
|
|
if len(lang_name) <= 2:
|
|
pattern = r'(?:^|\s|[.,;:!?])' + re.escape(lang_name) + r'(?:$|\s|[.,;:!?])'
|
|
else:
|
|
pattern = r'\b' + re.escape(lang_name) + r'\b'
|
|
if re.search(pattern, description_lower):
|
|
detected_language = lang_code
|
|
break
|
|
|
|
is_instrumental = False
|
|
if 'instrumental' in description_lower:
|
|
is_instrumental = True
|
|
elif 'pure music' in description_lower or 'pure instrument' in description_lower:
|
|
is_instrumental = True
|
|
elif description_lower.endswith(' solo') or description_lower == 'solo':
|
|
is_instrumental = True
|
|
|
|
return detected_language, is_instrumental
|
|
|
|
|
|
def _prompt_non_empty(prompt: str) -> str:
|
|
value = input(prompt).strip()
|
|
while not value:
|
|
value = input(prompt).strip()
|
|
return value
|
|
|
|
|
|
def _prompt_with_default(prompt: str, default: Optional[str] = None, required: bool = False) -> str:
|
|
while True:
|
|
suffix = f" [{default}]" if default not in (None, "") else ""
|
|
value = input(f"{prompt}{suffix}: ").strip()
|
|
if value:
|
|
return value
|
|
if default not in (None, ""):
|
|
return str(default)
|
|
if not required:
|
|
return ""
|
|
print("This value is required. Please try again.")
|
|
|
|
|
|
def _prompt_bool(prompt: str, default: bool) -> bool:
|
|
default_str = "y" if default else "n"
|
|
while True:
|
|
value = input(f"{prompt} (y/n) [default: {default_str}]: ").strip().lower()
|
|
if not value:
|
|
return default
|
|
if value in {"y", "yes", "1", "true"}:
|
|
return True
|
|
if value in {"n", "no", "0", "false"}:
|
|
return False
|
|
print("Please enter 'y' or 'n'.")
|
|
|
|
|
|
def _prompt_choice_from_list(
|
|
prompt: str,
|
|
options: List[str],
|
|
default: Optional[str] = None,
|
|
allow_custom: bool = True,
|
|
custom_validator=None,
|
|
custom_error: Optional[str] = None,
|
|
) -> Optional[str]:
|
|
if not options:
|
|
return default
|
|
print("\n" + prompt)
|
|
for idx, option in enumerate(options, start=1):
|
|
print(f"{idx}. {option}")
|
|
default_display = default if default not in (None, "") else "auto"
|
|
while True:
|
|
choice = input(f"Choose a model (number or name) [default: {default_display}]: ").strip()
|
|
if not choice:
|
|
return None if default_display == "auto" else default
|
|
if choice.lower() == "auto":
|
|
return None
|
|
if choice.isdigit():
|
|
idx = int(choice)
|
|
if 1 <= idx <= len(options):
|
|
return options[idx - 1]
|
|
print("Invalid selection. Please choose a valid number.")
|
|
continue
|
|
if allow_custom:
|
|
if custom_validator and not custom_validator(choice):
|
|
print(custom_error or "Invalid selection. Please try again.")
|
|
continue
|
|
if choice not in options:
|
|
print("Unknown model. Using as-is.")
|
|
return choice
|
|
print("Please choose a valid option.")
|
|
|
|
|
|
def _edit_formatted_prompt_via_file(formatted_prompt: str, instruction_path: str) -> str:
|
|
"""Write formatted prompt to file, wait for user edits, then read back."""
|
|
try:
|
|
with open(instruction_path, "w", encoding="utf-8") as f:
|
|
f.write(formatted_prompt)
|
|
except Exception as e:
|
|
print(f"WARNING: Failed to write {instruction_path}: {e}")
|
|
return formatted_prompt
|
|
|
|
print("\n--- Final Draft Saved ---")
|
|
print(f"Saved to {instruction_path}")
|
|
print("Edit the file now. Press Enter when ready to continue.")
|
|
input()
|
|
|
|
try:
|
|
with open(instruction_path, "r", encoding="utf-8") as f:
|
|
return f.read()
|
|
except Exception as e:
|
|
print(f"WARNING: Failed to read {instruction_path}: {e}")
|
|
return formatted_prompt
|
|
|
|
|
|
def _extract_caption_lyrics_from_formatted_prompt(formatted_prompt: str) -> Tuple[Optional[str], Optional[str]]:
|
|
"""Best-effort extraction of caption/lyrics from a formatted prompt string."""
|
|
matches = list(re.finditer(r"# Caption\n(.*?)\n+# Lyric\n(.*)", formatted_prompt, re.DOTALL))
|
|
if not matches:
|
|
return None, None
|
|
|
|
caption = matches[-1].group(1).strip()
|
|
lyrics = matches[-1].group(2)
|
|
|
|
# Trim lyrics if chat-template markers appear after the user message.
|
|
cut_markers = ["<|eot_id|>", "<|start_header_id|>", "<|assistant|>", "<|user|>", "<|system|>", "<|im_end|>", "<|im_start|>"]
|
|
cut_at = len(lyrics)
|
|
for marker in cut_markers:
|
|
pos = lyrics.find(marker)
|
|
if pos != -1:
|
|
cut_at = min(cut_at, pos)
|
|
lyrics = lyrics[:cut_at].rstrip()
|
|
|
|
return caption or None, lyrics or None
|
|
|
|
|
|
def _extract_instruction_from_formatted_prompt(formatted_prompt: str) -> Optional[str]:
|
|
"""Best-effort extraction of instruction text from a formatted prompt string."""
|
|
match = re.search(r"# Instruction\n(.*?)\n\n", formatted_prompt, re.DOTALL)
|
|
if not match:
|
|
return None
|
|
instruction = match.group(1).strip()
|
|
return instruction or None
|
|
|
|
|
|
def _extract_cot_metadata_from_formatted_prompt(formatted_prompt: str) -> dict:
|
|
"""Best-effort extraction of COT metadata from a formatted prompt string,
|
|
supporting multi-line values.
|
|
"""
|
|
matches = list(re.finditer(r"<think>\n(.*?)\n</think>", formatted_prompt, re.DOTALL))
|
|
if not matches:
|
|
return {}
|
|
block = matches[-1].group(1)
|
|
metadata = {}
|
|
current_key = None
|
|
current_value_lines = []
|
|
|
|
for line in block.splitlines():
|
|
line = line.strip()
|
|
if not line:
|
|
continue
|
|
|
|
key_match = re.match(r"^(\w+):\s*(.*)", line)
|
|
if key_match:
|
|
if current_key:
|
|
metadata[current_key] = " ".join(current_value_lines).strip()
|
|
|
|
current_key = key_match.group(1).strip().lower()
|
|
current_value_lines = [key_match.group(2).strip()]
|
|
else:
|
|
if current_key:
|
|
current_value_lines.append(line)
|
|
|
|
if current_key and current_value_lines:
|
|
metadata[current_key] = " ".join(current_value_lines).strip()
|
|
|
|
return metadata
|
|
|
|
|
|
def _parse_number(value: str) -> Optional[float]:
|
|
try:
|
|
match = re.search(r"[-+]?\d*\.?\d+", value)
|
|
if not match:
|
|
return None
|
|
return float(match.group(0))
|
|
except Exception:
|
|
return None
|
|
|
|
|
|
def _parse_timesteps_input(value) -> Optional[List[float]]:
|
|
if value is None:
|
|
return None
|
|
if isinstance(value, list):
|
|
if all(isinstance(t, (int, float)) for t in value):
|
|
return [float(t) for t in value]
|
|
return None
|
|
if not isinstance(value, str):
|
|
return None
|
|
raw = value.strip()
|
|
if not raw:
|
|
return None
|
|
if raw.startswith("[") or raw.startswith("("):
|
|
try:
|
|
parsed = ast.literal_eval(raw)
|
|
except Exception:
|
|
return None
|
|
if isinstance(parsed, list) and all(isinstance(t, (int, float)) for t in parsed):
|
|
return [float(t) for t in parsed]
|
|
return None
|
|
try:
|
|
return [float(t.strip()) for t in raw.split(",") if t.strip()]
|
|
except Exception:
|
|
return None
|
|
|
|
|
|
def _install_prompt_edit_hook(
|
|
llm_handler: LLMHandler,
|
|
instruction_path: str,
|
|
preloaded_prompt: Optional[str] = None,
|
|
) -> None:
|
|
"""Intercept formatted prompt generation to allow user editing before audio tokens."""
|
|
original = llm_handler.build_formatted_prompt_with_cot
|
|
cache = {}
|
|
|
|
def wrapped(caption, lyrics, cot_text, is_negative_prompt=False, negative_prompt="NO USER INPUT"):
|
|
prompt = original(
|
|
caption,
|
|
lyrics,
|
|
cot_text,
|
|
is_negative_prompt=is_negative_prompt,
|
|
negative_prompt=negative_prompt,
|
|
)
|
|
if is_negative_prompt:
|
|
conditional_prompt = original(
|
|
caption,
|
|
lyrics,
|
|
cot_text,
|
|
is_negative_prompt=False,
|
|
negative_prompt=negative_prompt,
|
|
)
|
|
cached = cache.get(conditional_prompt)
|
|
if cached and (cached.get("edited_caption") or cached.get("edited_lyrics")):
|
|
edited_caption = cached.get("edited_caption") or caption
|
|
edited_lyrics = cached.get("edited_lyrics") or lyrics
|
|
return original(
|
|
edited_caption,
|
|
edited_lyrics,
|
|
cot_text,
|
|
is_negative_prompt=True,
|
|
negative_prompt=negative_prompt,
|
|
)
|
|
return prompt
|
|
cached = cache.get(prompt)
|
|
if cached:
|
|
return cached["edited_prompt"]
|
|
if getattr(llm_handler, "_skip_prompt_edit", False):
|
|
cache[prompt] = {
|
|
"edited_prompt": prompt,
|
|
"edited_caption": None,
|
|
"edited_lyrics": None,
|
|
}
|
|
return prompt
|
|
if preloaded_prompt is not None:
|
|
edited = preloaded_prompt
|
|
else:
|
|
edited = _edit_formatted_prompt_via_file(prompt, instruction_path)
|
|
edited_caption, edited_lyrics = _extract_caption_lyrics_from_formatted_prompt(edited)
|
|
if edited != prompt:
|
|
print("INFO: Using edited draft for audio-token prompt.")
|
|
if edited_caption or edited_lyrics:
|
|
llm_handler._edited_caption = edited_caption
|
|
llm_handler._edited_lyrics = edited_lyrics
|
|
edited_instruction = _extract_instruction_from_formatted_prompt(edited)
|
|
if edited_instruction:
|
|
llm_handler._edited_instruction = edited_instruction
|
|
edited_metas = _extract_cot_metadata_from_formatted_prompt(edited)
|
|
if edited_metas:
|
|
llm_handler._edited_metas = edited_metas
|
|
cache[prompt] = {
|
|
"edited_prompt": edited,
|
|
"edited_caption": edited_caption,
|
|
"edited_lyrics": edited_lyrics,
|
|
}
|
|
return edited
|
|
|
|
llm_handler.build_formatted_prompt_with_cot = wrapped
|
|
|
|
|
|
def _prompt_int(prompt: str, default: Optional[int] = None, min_value: Optional[int] = None,
|
|
max_value: Optional[int] = None) -> Optional[int]:
|
|
default_display = "auto" if default is None else default
|
|
while True:
|
|
value = input(f"{prompt} [{default_display}]: ").strip()
|
|
if not value:
|
|
return default
|
|
try:
|
|
parsed = int(value)
|
|
except ValueError:
|
|
print("Invalid input. Please enter an integer.")
|
|
continue
|
|
if min_value is not None and parsed < min_value:
|
|
print(f"Please enter a value >= {min_value}.")
|
|
continue
|
|
if max_value is not None and parsed > max_value:
|
|
print(f"Please enter a value <= {max_value}.")
|
|
continue
|
|
return parsed
|
|
|
|
|
|
def _prompt_float(prompt: str, default: Optional[float] = None, min_value: Optional[float] = None,
|
|
max_value: Optional[float] = None) -> Optional[float]:
|
|
default_display = "auto" if default is None else default
|
|
while True:
|
|
value = input(f"{prompt} [{default_display}]: ").strip()
|
|
if not value:
|
|
return default
|
|
try:
|
|
parsed = float(value)
|
|
except ValueError:
|
|
print("Invalid input. Please enter a number.")
|
|
continue
|
|
if min_value is not None and parsed < min_value:
|
|
print(f"Please enter a value >= {min_value}.")
|
|
continue
|
|
if max_value is not None and parsed > max_value:
|
|
print(f"Please enter a value <= {max_value}.")
|
|
continue
|
|
return parsed
|
|
|
|
|
|
def _prompt_existing_file(prompt: str, default: Optional[str] = None) -> str:
|
|
while True:
|
|
suffix = f" [{default}]" if default else ""
|
|
path = input(f"{prompt}{suffix}: ").strip()
|
|
if not path and default:
|
|
path = default
|
|
if os.path.isfile(path):
|
|
return _expand_audio_path(path)
|
|
print("Invalid file path. Please try again.")
|
|
|
|
|
|
def _expand_audio_path(path_str: Optional[str]) -> Optional[str]:
|
|
if not path_str or not isinstance(path_str, str):
|
|
return path_str
|
|
try:
|
|
return Path(path_str).expanduser().resolve(strict=False).as_posix()
|
|
except Exception:
|
|
return Path(path_str).expanduser().absolute().as_posix()
|
|
|
|
|
|
def _parse_bool(value: str) -> bool:
|
|
return str(value).lower() in {"true", "1", "yes", "y"}
|
|
|
|
|
|
def _resolve_device(device: str) -> str:
|
|
if device == "auto":
|
|
if hasattr(torch, 'xpu') and torch.xpu.is_available():
|
|
return "xpu"
|
|
if torch.cuda.is_available():
|
|
return "cuda"
|
|
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
|
return "mps"
|
|
return "cpu"
|
|
return device
|
|
|
|
|
|
def _default_instruction_for_task(task_type: str, tracks: Optional[List[str]] = None) -> str:
|
|
if task_type == "lego":
|
|
track = tracks[0] if tracks else "guitar"
|
|
return TASK_INSTRUCTIONS["lego"].format(TRACK_NAME=track.upper())
|
|
if task_type == "extract":
|
|
track = tracks[0] if tracks else "vocals"
|
|
return TASK_INSTRUCTIONS["extract"].format(TRACK_NAME=track.upper())
|
|
if task_type == "complete":
|
|
tracks_list = ", ".join(tracks) if tracks else "drums, bass, guitar"
|
|
return TASK_INSTRUCTIONS["complete"].format(TRACK_CLASSES=tracks_list)
|
|
return DEFAULT_DIT_INSTRUCTION
|
|
|
|
|
|
def _apply_optional_defaults(args, params_defaults: GenerationParams, config_defaults: GenerationConfig) -> None:
|
|
optional_defaults = {
|
|
"duration": params_defaults.duration,
|
|
"bpm": params_defaults.bpm,
|
|
"keyscale": params_defaults.keyscale,
|
|
"timesignature": params_defaults.timesignature,
|
|
"vocal_language": params_defaults.vocal_language,
|
|
"inference_steps": params_defaults.inference_steps,
|
|
"seed": params_defaults.seed,
|
|
"guidance_scale": params_defaults.guidance_scale,
|
|
"use_adg": params_defaults.use_adg,
|
|
"cfg_interval_start": params_defaults.cfg_interval_start,
|
|
"cfg_interval_end": params_defaults.cfg_interval_end,
|
|
"shift": 3.0,
|
|
"infer_method": params_defaults.infer_method,
|
|
"timesteps": None,
|
|
"repainting_start": params_defaults.repainting_start,
|
|
"repainting_end": params_defaults.repainting_end,
|
|
"audio_cover_strength": params_defaults.audio_cover_strength,
|
|
"thinking": params_defaults.thinking,
|
|
"lm_temperature": params_defaults.lm_temperature,
|
|
"lm_cfg_scale": params_defaults.lm_cfg_scale,
|
|
"lm_top_k": params_defaults.lm_top_k,
|
|
"lm_top_p": params_defaults.lm_top_p,
|
|
"lm_negative_prompt": params_defaults.lm_negative_prompt,
|
|
"use_cot_metas": params_defaults.use_cot_metas,
|
|
"use_cot_caption": params_defaults.use_cot_caption,
|
|
"use_cot_lyrics": params_defaults.use_cot_lyrics,
|
|
"use_cot_language": params_defaults.use_cot_language,
|
|
"use_constrained_decoding": params_defaults.use_constrained_decoding,
|
|
"batch_size": config_defaults.batch_size,
|
|
"allow_lm_batch": config_defaults.allow_lm_batch,
|
|
"use_random_seed": config_defaults.use_random_seed,
|
|
"seeds": config_defaults.seeds,
|
|
"lm_batch_chunk_size": config_defaults.lm_batch_chunk_size,
|
|
"constrained_decoding_debug": config_defaults.constrained_decoding_debug,
|
|
"audio_format": config_defaults.audio_format,
|
|
"sample_mode": False,
|
|
"sample_query": "",
|
|
"use_format": False,
|
|
}
|
|
|
|
for key, default_value in optional_defaults.items():
|
|
if getattr(args, key, None) is None:
|
|
setattr(args, key, default_value)
|
|
|
|
|
|
def _summarize_lyrics(lyrics: Optional[str]) -> str:
|
|
if not lyrics:
|
|
return "none"
|
|
if isinstance(lyrics, str):
|
|
stripped = lyrics.strip()
|
|
if not stripped:
|
|
return "none"
|
|
if os.path.isfile(stripped):
|
|
return f"file: {os.path.basename(stripped)}"
|
|
if len(stripped) <= 60:
|
|
return stripped.replace("\n", " ")
|
|
return f"text ({len(stripped)} chars)"
|
|
return "provided"
|
|
|
|
|
|
def _print_final_parameters(
|
|
args,
|
|
params: GenerationParams,
|
|
config: GenerationConfig,
|
|
params_defaults: GenerationParams,
|
|
config_defaults: GenerationConfig,
|
|
compact: bool,
|
|
resolved_device: Optional[str] = None,
|
|
) -> None:
|
|
if not compact:
|
|
print("\n--- Final Parameters (Args) ---")
|
|
for k in sorted(vars(args).keys()):
|
|
print(f"{k}: {getattr(args, k)}")
|
|
print("------------------------------")
|
|
print("\n--- Final Parameters (GenerationParams) ---")
|
|
for k in sorted(vars(params).keys()):
|
|
print(f"{k}: {getattr(params, k)}")
|
|
print("-------------------------------------------")
|
|
print("\n--- Final Parameters (GenerationConfig) ---")
|
|
for k in sorted(vars(config).keys()):
|
|
print(f"{k}: {getattr(config, k)}")
|
|
print("-------------------------------------------\n")
|
|
return
|
|
|
|
device_display = args.device
|
|
if resolved_device and resolved_device != args.device:
|
|
device_display = f"{args.device} -> {resolved_device}"
|
|
|
|
print("\n--- Final Parameters (Summary) ---")
|
|
print(f"task_type: {params.task_type}")
|
|
print(f"caption: {params.caption or 'none'}")
|
|
print(f"lyrics: {_summarize_lyrics(params.lyrics)}")
|
|
print(f"duration: {params.duration}s")
|
|
print(f"outputs: {config.batch_size}")
|
|
if params.bpm not in (None, params_defaults.bpm):
|
|
print(f"bpm: {params.bpm}")
|
|
if params.keyscale not in (None, params_defaults.keyscale):
|
|
print(f"keyscale: {params.keyscale}")
|
|
if params.timesignature not in (None, params_defaults.timesignature):
|
|
print(f"timesignature: {params.timesignature}")
|
|
print(f"instrumental: {params.instrumental}")
|
|
print(f"thinking: {params.thinking}")
|
|
print(f"lm_model: {args.lm_model_path or 'auto'}")
|
|
print(f"dit_model: {args.config_path or 'auto'}")
|
|
print(f"backend: {args.backend}")
|
|
print(f"device: {device_display}")
|
|
print(f"audio_format: {config.audio_format}")
|
|
print(f"save_dir: {args.save_dir}")
|
|
if config.seeds:
|
|
print(f"seeds: {config.seeds}")
|
|
else:
|
|
print(f"seed: {params.seed} (random={config.use_random_seed})")
|
|
print("-------------------------------\n")
|
|
|
|
|
|
def _build_meta_dict(params: GenerationParams) -> Optional[dict]:
|
|
meta = {}
|
|
if params.bpm is not None:
|
|
meta["bpm"] = params.bpm
|
|
if params.timesignature:
|
|
meta["timesignature"] = params.timesignature
|
|
if params.keyscale:
|
|
meta["keyscale"] = params.keyscale
|
|
if params.duration is not None:
|
|
meta["duration"] = params.duration
|
|
return meta or None
|
|
|
|
|
|
def _print_dit_prompt(dit_handler: "AceStepHandler", params: GenerationParams) -> None:
|
|
meta = _build_meta_dict(params)
|
|
caption_input, lyrics_input = dit_handler.build_dit_inputs(
|
|
task=params.task_type,
|
|
instruction=params.instruction,
|
|
caption=params.caption or "",
|
|
lyrics=params.lyrics or "",
|
|
metas=meta,
|
|
vocal_language=params.vocal_language or "unknown",
|
|
)
|
|
print("\n--- Final DiT Prompt (Caption Branch) ---")
|
|
print(caption_input)
|
|
print("\n--- Final DiT Prompt (Lyrics Branch) ---")
|
|
print(lyrics_input)
|
|
print("----------------------------------------\n")
|
|
|
|
|
|
def run_wizard(args, configure_only: bool = False, default_config_path: Optional[str] = None,
|
|
params_defaults: Optional[GenerationParams] = None,
|
|
config_defaults: Optional[GenerationConfig] = None):
|
|
"""
|
|
Runs an interactive wizard to set generation parameters.
|
|
"""
|
|
print("Welcome to the ACE-Step Music Generation Wizard!")
|
|
print("This will guide you through creating your music.")
|
|
print("Press Ctrl+C at any time to exit.")
|
|
print("Note: Required models will be auto-downloaded if missing.")
|
|
print("-" * 30)
|
|
|
|
try:
|
|
# Task selection
|
|
print("\n--- Task Type ---")
|
|
print("1. text2music - generate music from text/lyrics.")
|
|
print("2. cover - transform existing audio into a new style.")
|
|
print("3. repaint - regenerate a specific time segment of audio.")
|
|
print("4. lego - generate a specific instrument track in context.")
|
|
print("5. extract - isolate a specific instrument track from a mix.")
|
|
print("6. complete - complete/extend partial tracks with new instruments.")
|
|
task_map = {
|
|
"1": "text2music",
|
|
"2": "cover",
|
|
"3": "repaint",
|
|
"4": "lego",
|
|
"5": "extract",
|
|
"6": "complete",
|
|
}
|
|
current_task = args.task_type or "text2music"
|
|
task_default = next((k for k, v in task_map.items() if v == current_task), "1")
|
|
task_choice = input(f"Choose a task (1-6) [default: {task_default}]: ").strip()
|
|
if not task_choice:
|
|
task_choice = task_default
|
|
args.task_type = task_map.get(task_choice, "text2music")
|
|
if args.task_type in {"lego", "extract", "complete"}:
|
|
print("Note: This task requires a base DiT model (acestep-v15-base). It will be auto-downloaded if missing.")
|
|
|
|
# Model selection (DiT)
|
|
dit_handler = AceStepHandler()
|
|
available_dit_models = dit_handler.get_available_acestep_v15_models()
|
|
base_only = args.task_type in {"lego", "extract", "complete"}
|
|
if base_only and available_dit_models:
|
|
available_dit_models = [m for m in available_dit_models if "base" in m.lower()]
|
|
|
|
if base_only and args.config_path and "base" not in str(args.config_path).lower():
|
|
args.config_path = None
|
|
|
|
if base_only:
|
|
if available_dit_models:
|
|
if args.config_path in available_dit_models:
|
|
selected = args.config_path
|
|
else:
|
|
selected = available_dit_models[0]
|
|
args.config_path = selected
|
|
print(f"\nNote: This task requires a base model. Using: {selected}")
|
|
else:
|
|
print("\nNote: This task requires a base model (e.g., 'acestep-v15-base'). It will be auto-downloaded if missing.")
|
|
elif available_dit_models:
|
|
selected = _prompt_choice_from_list(
|
|
"--- Available DiT Models ---",
|
|
available_dit_models,
|
|
default=args.config_path,
|
|
allow_custom=True,
|
|
)
|
|
if selected is not None:
|
|
args.config_path = selected
|
|
else:
|
|
print("\nNote: No local DiT models found. The main model will be auto-downloaded during initialization.")
|
|
|
|
# Model selection (LM)
|
|
llm_handler = LLMHandler()
|
|
available_lm_models = llm_handler.get_available_5hz_lm_models()
|
|
if available_lm_models:
|
|
selected_lm = _prompt_choice_from_list(
|
|
"--- Available LM Models ---",
|
|
available_lm_models,
|
|
default=args.lm_model_path,
|
|
allow_custom=True,
|
|
)
|
|
if selected_lm is not None:
|
|
args.lm_model_path = selected_lm
|
|
else:
|
|
print("\nNote: No local LM models found. If LM features are enabled, a default LM will be auto-downloaded.")
|
|
|
|
# Task-specific inputs
|
|
if args.task_type in {"cover", "repaint", "lego", "extract", "complete"}:
|
|
args.src_audio = _prompt_existing_file("Enter path to source audio file", default=args.src_audio)
|
|
|
|
if args.task_type == "repaint":
|
|
args.repainting_start = _prompt_float(
|
|
"Repaint start time in seconds", args.repainting_start
|
|
)
|
|
args.repainting_end = _prompt_float(
|
|
"Repaint end time in seconds", args.repainting_end
|
|
)
|
|
|
|
if args.task_type in {"lego", "extract"}:
|
|
print("\nAvailable tracks:")
|
|
print(", ".join(TRACK_CHOICES))
|
|
track_default = args.lego_track if args.task_type == "lego" else args.extract_track
|
|
track = _prompt_with_default("Choose a track", track_default, required=True)
|
|
if track not in TRACK_CHOICES:
|
|
print("Unknown track. Using as-is.")
|
|
if args.task_type == "lego":
|
|
args.lego_track = track
|
|
else:
|
|
args.extract_track = track
|
|
if not args.instruction or args.instruction == DEFAULT_DIT_INSTRUCTION:
|
|
args.instruction = _default_instruction_for_task(args.task_type, [track])
|
|
args.instruction = _prompt_with_default("Instruction", args.instruction, required=True)
|
|
|
|
if args.task_type == "complete":
|
|
print("\nAvailable tracks:")
|
|
print(", ".join(TRACK_CHOICES))
|
|
tracks_raw = _prompt_with_default("Choose tracks (comma-separated)", args.complete_tracks, required=True)
|
|
tracks = [t.strip() for t in tracks_raw.split(",") if t.strip()]
|
|
args.complete_tracks = ",".join(tracks)
|
|
if not args.instruction or args.instruction == DEFAULT_DIT_INSTRUCTION:
|
|
args.instruction = _default_instruction_for_task(args.task_type, tracks)
|
|
args.instruction = _prompt_with_default("Instruction", args.instruction, required=True)
|
|
|
|
if args.task_type in {"cover", "repaint", "lego", "complete"}:
|
|
args.caption = _prompt_with_default(
|
|
"Enter a music description (e.g., 'upbeat electronic dance music')",
|
|
args.caption,
|
|
required=True,
|
|
)
|
|
elif args.task_type == "text2music":
|
|
args.sample_mode = _prompt_bool("Use Simple Mode (auto-generate caption/lyrics via LM)", args.sample_mode)
|
|
if args.sample_mode:
|
|
args.sample_query = _prompt_with_default(
|
|
"Describe the music you want (for auto-generation)",
|
|
args.sample_query,
|
|
required=False,
|
|
)
|
|
if not args.sample_mode:
|
|
caption = _prompt_with_default(
|
|
"Enter a music description (optional if you provide lyrics)",
|
|
args.caption,
|
|
required=False,
|
|
)
|
|
if caption:
|
|
args.caption = caption
|
|
|
|
# Lyrics
|
|
if args.task_type in {"text2music", "cover", "repaint", "lego", "complete"} and not args.sample_mode:
|
|
print("\n--- Lyrics Options ---")
|
|
print("1. Instrumental (no lyrics).")
|
|
print("2. Generate lyrics automatically.")
|
|
print("3. Provide path to a .txt file.")
|
|
print("4. Paste lyrics directly.")
|
|
|
|
if args.instrumental or args.lyrics == "[Instrumental]":
|
|
default_choice = "1"
|
|
elif args.use_cot_lyrics:
|
|
default_choice = "2"
|
|
elif args.lyrics and isinstance(args.lyrics, str) and os.path.isfile(args.lyrics):
|
|
default_choice = "3"
|
|
elif args.lyrics:
|
|
default_choice = "4"
|
|
else:
|
|
default_choice = "1"
|
|
choice = input(f"Your choice (1-4) [default: {default_choice}]: ").strip()
|
|
if not choice:
|
|
choice = default_choice
|
|
|
|
if choice == "1": # Instrumental
|
|
args.instrumental = True
|
|
args.lyrics = "[Instrumental]"
|
|
args.use_cot_lyrics = False
|
|
print("Instrumental music will be generated.")
|
|
elif choice == "2": # Generate lyrics automatically
|
|
args.use_cot_lyrics = True
|
|
args.lyrics = ""
|
|
args.instrumental = False
|
|
print("Lyrics will be generated automatically.")
|
|
elif choice == "3":
|
|
args.instrumental = False
|
|
args.use_cot_lyrics = False
|
|
default_lyrics_path = args.lyrics if isinstance(args.lyrics, str) and os.path.isfile(args.lyrics) else None
|
|
while True:
|
|
lyrics_path = _prompt_existing_file("Please enter the path to your .txt lyrics file", default_lyrics_path)
|
|
if lyrics_path.endswith('.txt'):
|
|
args.lyrics = lyrics_path
|
|
print(f"Lyrics will be loaded from: {lyrics_path}")
|
|
break
|
|
print("Invalid file path or not a .txt file. Please try again.")
|
|
elif choice == "4":
|
|
args.instrumental = False
|
|
args.use_cot_lyrics = False
|
|
default_lyrics = args.lyrics if isinstance(args.lyrics, str) and args.lyrics and not os.path.isfile(args.lyrics) else None
|
|
args.lyrics = _prompt_with_default("Paste lyrics (single line or use \\n)", default_lyrics, required=True)
|
|
|
|
if not args.instrumental:
|
|
lang = _prompt_with_default(
|
|
"Vocal language (e.g., 'en', 'zh', 'unknown')",
|
|
args.vocal_language,
|
|
required=False
|
|
).lower()
|
|
if lang:
|
|
args.vocal_language = lang
|
|
|
|
if args.use_cot_lyrics:
|
|
if not args.caption:
|
|
args.caption = _prompt_non_empty("Enter a music description for lyric generation: ")
|
|
if not args.thinking:
|
|
print("INFO: Automatic lyric generation requires the LM handler. Enabling LM 'thinking'.")
|
|
args.thinking = True
|
|
|
|
args.batch_size = _prompt_int(
|
|
"Number of outputs (audio clips) to generate",
|
|
args.batch_size if args.batch_size is not None else 2,
|
|
min_value=1,
|
|
)
|
|
|
|
advanced = input("\nConfigure advanced parameters? (y/n) [default: n]: ").lower()
|
|
if advanced == 'y':
|
|
if args.task_type == "text2music" and not args.sample_mode:
|
|
args.use_format = _prompt_bool("Use format_sample to enhance caption/lyrics", args.use_format)
|
|
print("\n--- Optional Metadata ---")
|
|
args.duration = _prompt_float("Duration in seconds (10-600)", args.duration, min_value=10, max_value=600)
|
|
args.bpm = _prompt_int("BPM (30-300, empty for auto)", args.bpm, min_value=30, max_value=300)
|
|
args.keyscale = _prompt_with_default("Keyscale (e.g., 'C Major', empty for auto)", args.keyscale)
|
|
args.timesignature = _prompt_with_default("Time signature (e.g., '4/4', empty for auto)", args.timesignature)
|
|
args.vocal_language = _prompt_with_default("Vocal language (e.g., 'en', 'zh', 'unknown')", args.vocal_language)
|
|
|
|
print("\n--- Advanced DiT Settings ---")
|
|
args.seed = _prompt_int("Random seed (-1 for random)", args.seed)
|
|
args.inference_steps = _prompt_int("Inference steps", args.inference_steps, min_value=1)
|
|
if args.config_path and 'base' in args.config_path:
|
|
args.guidance_scale = _prompt_float("Guidance scale (for base models)", args.guidance_scale)
|
|
args.use_adg = _prompt_bool("Enable Adaptive Dual Guidance (ADG)", args.use_adg)
|
|
args.cfg_interval_start = _prompt_float("CFG interval start (0.0-1.0)", args.cfg_interval_start, 0.0, 1.0)
|
|
args.cfg_interval_end = _prompt_float("CFG interval end (0.0-1.0)", args.cfg_interval_end, 0.0, 1.0)
|
|
args.shift = _prompt_float("Timestep shift (1.0-5.0)", args.shift, 1.0, 5.0)
|
|
args.infer_method = _prompt_with_default("Inference method (ode/sde)", args.infer_method)
|
|
timesteps_input = _prompt_with_default(
|
|
"Custom timesteps list (e.g., [0.97, 0.5, 0])",
|
|
args.timesteps,
|
|
required=False,
|
|
)
|
|
if timesteps_input:
|
|
args.timesteps = timesteps_input
|
|
|
|
if args.task_type == "cover":
|
|
args.audio_cover_strength = _prompt_float(
|
|
"Audio cover strength (0.0-1.0)", args.audio_cover_strength, 0.0, 1.0
|
|
)
|
|
|
|
print("\n--- Advanced LM Settings ---")
|
|
args.thinking = _prompt_bool("Enable LM 'thinking'", args.thinking)
|
|
args.lm_temperature = _prompt_float("LM temperature (0.0-2.0)", args.lm_temperature, 0.0, 2.0)
|
|
args.lm_cfg_scale = _prompt_float("LM CFG scale", args.lm_cfg_scale)
|
|
args.lm_top_k = _prompt_int("LM top-k (0 disables)", args.lm_top_k, min_value=0)
|
|
args.lm_top_p = _prompt_float("LM top-p (0.0-1.0)", args.lm_top_p, 0.0, 1.0)
|
|
args.lm_negative_prompt = _prompt_with_default("LM negative prompt", args.lm_negative_prompt)
|
|
args.use_cot_metas = _prompt_bool("Use CoT for metadata", args.use_cot_metas)
|
|
args.use_cot_caption = _prompt_bool("Use CoT for caption refinement", args.use_cot_caption)
|
|
args.use_cot_lyrics = _prompt_bool("Use CoT for lyrics generation", args.use_cot_lyrics)
|
|
args.use_cot_language = _prompt_bool("Use CoT for language detection", args.use_cot_language)
|
|
args.use_constrained_decoding = _prompt_bool("Use constrained decoding", args.use_constrained_decoding)
|
|
|
|
print("\n--- Output Settings ---")
|
|
args.save_dir = _prompt_with_default("Save directory", args.save_dir)
|
|
args.audio_format = _prompt_with_default("Audio format (mp3/wav/flac)", args.audio_format)
|
|
# Batch size already captured above.
|
|
args.use_random_seed = _prompt_bool("Use random seed per batch", args.use_random_seed)
|
|
seeds_input = _prompt_with_default(
|
|
"Custom seeds (comma/space separated, leave empty for random)",
|
|
"",
|
|
required=False,
|
|
)
|
|
if seeds_input:
|
|
seeds = [s for s in seeds_input.replace(",", " ").split() if s.strip()]
|
|
try:
|
|
args.seeds = [int(s) for s in seeds]
|
|
except ValueError:
|
|
print("Invalid seeds input. Ignoring custom seeds.")
|
|
args.allow_lm_batch = _prompt_bool("Allow LM batch processing", args.allow_lm_batch)
|
|
args.lm_batch_chunk_size = _prompt_int("LM batch chunk size", args.lm_batch_chunk_size, min_value=1)
|
|
args.constrained_decoding_debug = _prompt_bool("Constrained decoding debug", args.constrained_decoding_debug)
|
|
else:
|
|
if params_defaults and config_defaults:
|
|
_apply_optional_defaults(args, params_defaults, config_defaults)
|
|
|
|
# Ensure LM thinking is enabled when lyric generation is requested.
|
|
if args.use_cot_lyrics and not args.thinking:
|
|
print("INFO: Automatic lyric generation requires the LM handler. Enabling LM 'thinking'.")
|
|
args.thinking = True
|
|
|
|
print("\n--- Summary ---")
|
|
print(f"Task: {args.task_type}")
|
|
if args.caption:
|
|
print(f"Description: {args.caption}")
|
|
if args.task_type in {"lego", "extract", "complete"}:
|
|
print(f"Instruction: {args.instruction}")
|
|
if args.src_audio:
|
|
print(f"Source audio: {args.src_audio}")
|
|
print(f"Duration: {args.duration}s")
|
|
print(f"Outputs: {args.batch_size}")
|
|
if args.instrumental:
|
|
print("Lyrics: Instrumental")
|
|
elif args.use_cot_lyrics:
|
|
print(f"Lyrics: Auto-generated ({args.vocal_language})")
|
|
elif args.lyrics and os.path.isfile(args.lyrics):
|
|
print(f"Lyrics: Provided from file ({args.lyrics})")
|
|
elif args.lyrics:
|
|
print(f"Lyrics: Provided as text")
|
|
|
|
print("-" * 30)
|
|
if not configure_only:
|
|
confirm = input("Start generation with these settings? (y/n) [default: y]: ").lower()
|
|
if confirm == 'n':
|
|
print("Generation cancelled.")
|
|
sys.exit(0)
|
|
|
|
default_filename = default_config_path or "config.toml"
|
|
config_filename = input(f"\nEnter filename to save configuration [{default_filename}]: ")
|
|
if not config_filename:
|
|
config_filename = default_filename
|
|
if not config_filename.endswith(".toml"):
|
|
config_filename += ".toml"
|
|
|
|
try:
|
|
config_to_save = {
|
|
k: v for k, v in vars(args).items()
|
|
if k not in ['config'] and not k.startswith('_')
|
|
}
|
|
with open(config_filename, 'w') as f:
|
|
toml.dump(config_to_save, f)
|
|
print(f"Configuration saved to {config_filename}")
|
|
print(f"You can reuse it next time with: python cli.py -c {config_filename}")
|
|
except Exception as e:
|
|
print(f"Error saving configuration: {e}. Please try again.")
|
|
|
|
except (KeyboardInterrupt, EOFError):
|
|
print("\nWizard cancelled. Exiting.")
|
|
sys.exit(0)
|
|
|
|
return args, not configure_only
|
|
|
|
|
|
def main():
|
|
"""
|
|
Main function to run ACE-Step music generation from the command line.
|
|
"""
|
|
|
|
gpu_config = get_gpu_config()
|
|
set_global_gpu_config(gpu_config)
|
|
mps_available = is_mps_platform()
|
|
# Mac (Apple Silicon) uses unified memory — offloading provides no benefit
|
|
auto_offload = (not mps_available) and gpu_config.gpu_memory_gb > 0 and gpu_config.gpu_memory_gb < 16
|
|
print(f"\n{'='*60}")
|
|
print("GPU Configuration Detected:")
|
|
print(f"{'='*60}")
|
|
print(f" GPU Memory: {gpu_config.gpu_memory_gb:.2f} GiB")
|
|
print(f" Configuration Tier: {gpu_config.tier}")
|
|
print(f" Max Duration (with LM): {gpu_config.max_duration_with_lm}s ({gpu_config.max_duration_with_lm // 60} min)")
|
|
print(f" Max Duration (without LM): {gpu_config.max_duration_without_lm}s ({gpu_config.max_duration_without_lm // 60} min)")
|
|
print(f" Max Batch Size (with LM): {gpu_config.max_batch_size_with_lm}")
|
|
print(f" Max Batch Size (without LM): {gpu_config.max_batch_size_without_lm}")
|
|
print(f" Default LM Init: {gpu_config.init_lm_default}")
|
|
print(f" Available LM Models: {gpu_config.available_lm_models or 'None'}")
|
|
print(f"{'='*60}\n")
|
|
|
|
if auto_offload:
|
|
print("Auto-enabling CPU offload (GPU < 16GB)")
|
|
elif gpu_config.gpu_memory_gb > 0:
|
|
print("CPU offload disabled by default (GPU >= 16GB)")
|
|
elif mps_available:
|
|
print("MPS detected, running on Apple GPU")
|
|
else:
|
|
print("No GPU detected, running on CPU")
|
|
|
|
params_defaults = GenerationParams()
|
|
config_defaults = GenerationConfig()
|
|
|
|
parser = argparse.ArgumentParser(
|
|
description="ACE-Step 1.5: Music generation (wizard/config only).",
|
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
|
)
|
|
parser.add_argument("-c", "--config", type=str, help="Path to a TOML configuration file to load.")
|
|
parser.add_argument("--configure", action="store_true", help="Run wizard to save configuration without generating.")
|
|
parser.add_argument(
|
|
"--backend",
|
|
type=str,
|
|
default=None,
|
|
choices=["vllm", "pt", "mlx"],
|
|
help="5Hz LM backend. Auto-detected if not specified: 'mlx' on Apple Silicon, 'vllm' on CUDA, 'pt' otherwise.",
|
|
)
|
|
parser.add_argument(
|
|
"--log-level",
|
|
type=str,
|
|
default="INFO",
|
|
help="Logging level for internal modules (TRACE/DEBUG/INFO/WARNING/ERROR/CRITICAL).",
|
|
)
|
|
cli_args = parser.parse_args()
|
|
|
|
_configure_logging(level=cli_args.log_level)
|
|
|
|
default_batch_size = 1 if not cli_args.config else config_defaults.batch_size
|
|
|
|
# Auto-detect MLX on Apple Silicon, fall back to vllm
|
|
if mps_available:
|
|
try:
|
|
import mlx.core # noqa: F401
|
|
default_backend = "mlx"
|
|
print("Apple Silicon detected with MLX available. Using MLX backend.")
|
|
except ImportError:
|
|
default_backend = "vllm"
|
|
else:
|
|
default_backend = "vllm"
|
|
|
|
defaults = {
|
|
"project_root": _get_project_root(),
|
|
"config_path": None,
|
|
"checkpoint_dir": str(_get_default_checkpoint_dir()),
|
|
"lm_model_path": None,
|
|
"backend": default_backend,
|
|
"device": "auto",
|
|
"use_flash_attention": None,
|
|
"offload_to_cpu": auto_offload,
|
|
"offload_dit_to_cpu": False,
|
|
"save_dir": "output",
|
|
"audio_format": config_defaults.audio_format,
|
|
"caption": "",
|
|
"prompt": "",
|
|
"lyrics": None,
|
|
"duration": params_defaults.duration,
|
|
"instrumental": False,
|
|
"bpm": params_defaults.bpm,
|
|
"keyscale": params_defaults.keyscale,
|
|
"timesignature": params_defaults.timesignature,
|
|
"vocal_language": params_defaults.vocal_language,
|
|
"task_type": params_defaults.task_type,
|
|
"instruction": params_defaults.instruction,
|
|
"reference_audio": params_defaults.reference_audio,
|
|
"src_audio": params_defaults.src_audio,
|
|
"repainting_start": params_defaults.repainting_start,
|
|
"repainting_end": params_defaults.repainting_end,
|
|
"audio_cover_strength": params_defaults.audio_cover_strength,
|
|
"lego_track": "",
|
|
"extract_track": "",
|
|
"complete_tracks": "",
|
|
"sample_mode": False,
|
|
"sample_query": "",
|
|
"use_format": False,
|
|
"inference_steps": params_defaults.inference_steps,
|
|
"seed": params_defaults.seed,
|
|
"guidance_scale": params_defaults.guidance_scale,
|
|
"use_adg": params_defaults.use_adg,
|
|
"shift": 3.0,
|
|
"infer_method": params_defaults.infer_method,
|
|
"timesteps": None,
|
|
"thinking": gpu_config.init_lm_default,
|
|
"lm_temperature": params_defaults.lm_temperature,
|
|
"lm_cfg_scale": params_defaults.lm_cfg_scale,
|
|
"lm_top_k": params_defaults.lm_top_k,
|
|
"lm_top_p": params_defaults.lm_top_p,
|
|
"use_cot_metas": params_defaults.use_cot_metas,
|
|
"use_cot_caption": params_defaults.use_cot_caption,
|
|
"use_cot_lyrics": params_defaults.use_cot_lyrics,
|
|
"use_cot_language": params_defaults.use_cot_language,
|
|
"use_constrained_decoding": params_defaults.use_constrained_decoding,
|
|
"batch_size": default_batch_size,
|
|
"seeds": None,
|
|
"use_random_seed": config_defaults.use_random_seed,
|
|
"allow_lm_batch": config_defaults.allow_lm_batch,
|
|
"lm_batch_chunk_size": config_defaults.lm_batch_chunk_size,
|
|
"constrained_decoding_debug": config_defaults.constrained_decoding_debug,
|
|
"audio_codes": "",
|
|
"cfg_interval_start": params_defaults.cfg_interval_start,
|
|
"cfg_interval_end": params_defaults.cfg_interval_end,
|
|
"lm_negative_prompt": params_defaults.lm_negative_prompt,
|
|
"log_level": cli_args.log_level,
|
|
}
|
|
|
|
args = argparse.Namespace(**defaults)
|
|
args.config = None
|
|
if cli_args.config:
|
|
if not os.path.exists(cli_args.config):
|
|
parser.error(f"Config file not found: {cli_args.config}")
|
|
try:
|
|
with open(cli_args.config, 'r') as f:
|
|
config_from_file = toml.load(f)
|
|
print(f"Configuration loaded from {cli_args.config}")
|
|
except Exception as e:
|
|
parser.error(f"Error loading TOML config file {cli_args.config}: {e}")
|
|
for key, value in config_from_file.items():
|
|
setattr(args, key, value)
|
|
args.config = cli_args.config
|
|
|
|
# CLI --backend overrides config file and auto-detection
|
|
if cli_args.backend is not None:
|
|
args.backend = cli_args.backend
|
|
|
|
if cli_args.configure:
|
|
args, _ = run_wizard(
|
|
args,
|
|
configure_only=True,
|
|
default_config_path=cli_args.config,
|
|
params_defaults=params_defaults,
|
|
config_defaults=config_defaults,
|
|
)
|
|
print("Configuration complete. Exiting without generation.")
|
|
sys.exit(0)
|
|
|
|
if not cli_args.config:
|
|
args, should_generate = run_wizard(
|
|
args,
|
|
configure_only=False,
|
|
default_config_path=None,
|
|
params_defaults=params_defaults,
|
|
config_defaults=config_defaults,
|
|
)
|
|
if not should_generate:
|
|
print("Configuration complete. Exiting without generation.")
|
|
sys.exit(0)
|
|
|
|
# --- Post-parsing Setup ---
|
|
if args.use_cot_lyrics and not args.thinking:
|
|
print("INFO: Automatic lyric generation requires the LM handler. Forcing --thinking=True.")
|
|
args.thinking = True
|
|
|
|
if not args.project_root:
|
|
args.project_root = _get_project_root()
|
|
else:
|
|
args.project_root = os.path.abspath(os.path.expanduser(str(args.project_root)))
|
|
|
|
if args.checkpoint_dir:
|
|
args.checkpoint_dir = os.path.expanduser(str(args.checkpoint_dir))
|
|
if not os.path.isabs(args.checkpoint_dir):
|
|
args.checkpoint_dir = os.path.join(args.project_root, args.checkpoint_dir)
|
|
|
|
if args.src_audio:
|
|
args.src_audio = _expand_audio_path(args.src_audio)
|
|
if args.reference_audio:
|
|
args.reference_audio = _expand_audio_path(args.reference_audio)
|
|
|
|
device = _resolve_device(args.device)
|
|
|
|
# --- Argument Post-processing ---
|
|
try:
|
|
timesteps = _parse_timesteps_input(args.timesteps)
|
|
if args.timesteps and timesteps is None:
|
|
raise ValueError("Timesteps must be a list of numbers or a comma-separated string.")
|
|
except ValueError as e:
|
|
parser.error(f"Invalid format for timesteps. Expected a list of numbers (e.g., '[1.0, 0.5, 0.0]' or '0.97,0.5,0'). Error: {e}")
|
|
|
|
if args.seeds:
|
|
args.batch_size = len(args.seeds)
|
|
args.use_random_seed = False
|
|
args.seed = -1
|
|
|
|
if args.instrumental and not args.lyrics:
|
|
args.lyrics = "[Instrumental]"
|
|
elif isinstance(args.lyrics, str) and args.lyrics.strip().lower() in {"[inst]", "[instrumental]"}:
|
|
args.instrumental = True
|
|
|
|
# --- Task-specific validation and instruction helpers ---
|
|
if args.task_type in {"cover", "repaint", "lego", "extract", "complete"}:
|
|
if not args.src_audio:
|
|
parser.error(f"--src_audio is required for task_type '{args.task_type}'.")
|
|
|
|
if args.task_type in {"cover", "repaint", "lego", "complete"}:
|
|
if not args.caption:
|
|
parser.error(f"--caption is required for task_type '{args.task_type}'.")
|
|
|
|
if args.task_type == "text2music":
|
|
if not args.caption and not args.lyrics:
|
|
if not args.sample_mode and not args.sample_query:
|
|
parser.error("--caption or --lyrics is required for text2music.")
|
|
if args.use_cot_lyrics and not args.caption:
|
|
parser.error("--use_cot_lyrics requires --caption for lyric generation.")
|
|
if args.sample_mode or args.sample_query:
|
|
args.sample_mode = True
|
|
else:
|
|
if args.sample_mode or args.sample_query:
|
|
parser.error("--sample_mode/sample_query are only supported for task_type 'text2music'.")
|
|
|
|
if args.sample_mode and args.use_cot_lyrics:
|
|
print("INFO: sample_mode enabled. Disabling --use_cot_lyrics.")
|
|
args.use_cot_lyrics = False
|
|
|
|
# Auto-select instruction based on task_type if user didn't provide a custom instruction.
|
|
# Align with api_server behavior and TASK_INSTRUCTIONS defaults.
|
|
if args.instruction == DEFAULT_DIT_INSTRUCTION and args.task_type in TASK_INSTRUCTIONS:
|
|
if args.task_type in {"text2music", "cover", "repaint"}:
|
|
args.instruction = TASK_INSTRUCTIONS[args.task_type]
|
|
|
|
# Base-model-only task enforcement
|
|
base_only_tasks = {"lego", "extract", "complete"}
|
|
if args.task_type in base_only_tasks and args.config_path:
|
|
if "base" not in str(args.config_path).lower():
|
|
parser.error(f"task_type '{args.task_type}' requires a base model config (e.g., 'acestep-v15-base').")
|
|
|
|
if args.task_type == "repaint":
|
|
if args.repainting_end != -1 and args.repainting_end <= args.repainting_start:
|
|
parser.error("--repainting_end must be greater than --repainting_start (or -1).")
|
|
|
|
if args.task_type in {"lego", "extract", "complete"}:
|
|
has_custom_instruction = bool(args.instruction and args.instruction.strip() and args.instruction.strip() != params_defaults.instruction)
|
|
if not has_custom_instruction:
|
|
if args.task_type == "lego":
|
|
if not args.lego_track:
|
|
parser.error("--instruction or --lego_track is required for lego task.")
|
|
args.instruction = _default_instruction_for_task("lego", [args.lego_track.strip()])
|
|
elif args.task_type == "extract":
|
|
if not args.extract_track:
|
|
parser.error("--instruction or --extract_track is required for extract task.")
|
|
args.instruction = _default_instruction_for_task("extract", [args.extract_track.strip()])
|
|
elif args.task_type == "complete":
|
|
if not args.complete_tracks:
|
|
parser.error("--instruction or --complete_tracks is required for complete task.")
|
|
tracks = [t.strip() for t in args.complete_tracks.split(",") if t.strip()]
|
|
if not tracks:
|
|
parser.error("--complete_tracks must contain at least one track.")
|
|
args.instruction = _default_instruction_for_task("complete", tracks)
|
|
|
|
# Handle lyrics argument
|
|
lyrics_arg = args.lyrics
|
|
if isinstance(lyrics_arg, str) and lyrics_arg:
|
|
lyrics_arg = os.path.expanduser(lyrics_arg)
|
|
if not os.path.isabs(lyrics_arg):
|
|
# Resolve relative lyrics path against config file location first, then project_root.
|
|
resolved = None
|
|
if args.config:
|
|
config_dir = os.path.dirname(os.path.abspath(args.config))
|
|
candidate = os.path.join(config_dir, lyrics_arg)
|
|
if os.path.isfile(candidate):
|
|
resolved = candidate
|
|
if resolved is None and args.project_root:
|
|
candidate = os.path.join(os.path.abspath(args.project_root), lyrics_arg)
|
|
if os.path.isfile(candidate):
|
|
resolved = candidate
|
|
if resolved is not None:
|
|
lyrics_arg = resolved
|
|
|
|
if lyrics_arg is not None:
|
|
if lyrics_arg == "generate":
|
|
args.use_cot_lyrics = True
|
|
args.lyrics = ""
|
|
print("Lyrics generation enabled.")
|
|
elif os.path.isfile(lyrics_arg):
|
|
print(f"INFO: Attempting to load lyrics from file: {lyrics_arg}")
|
|
try:
|
|
with open(lyrics_arg, 'r', encoding='utf-8') as f:
|
|
args.lyrics = f.read()
|
|
print(f"Lyrics loaded from file: {lyrics_arg}")
|
|
except Exception as e:
|
|
parser.error(f"Could not read lyrics file {lyrics_arg}. Error: {e}")
|
|
# else: lyrics is a string, use as is.
|
|
|
|
# --- Handler Initialization ---
|
|
if args.backend == "pyTorch":
|
|
args.backend = "pt"
|
|
if args.backend not in {"vllm", "pt", "mlx"}:
|
|
args.backend = "vllm"
|
|
|
|
print("Initializing ACE-Step handlers...")
|
|
dit_handler = AceStepHandler()
|
|
llm_handler = LLMHandler()
|
|
|
|
base_only_tasks = {"lego", "extract", "complete"}
|
|
skip_lm_tasks = {"cover", "repaint"}
|
|
requires_lm = (
|
|
args.task_type not in skip_lm_tasks and (
|
|
args.thinking
|
|
or args.sample_mode
|
|
or bool(args.sample_query and str(args.sample_query).strip())
|
|
or args.use_format
|
|
or args.use_cot_metas
|
|
or args.use_cot_caption
|
|
or args.use_cot_lyrics
|
|
or args.use_cot_language
|
|
)
|
|
)
|
|
|
|
if args.config_path is None:
|
|
available_models = dit_handler.get_available_acestep_v15_models()
|
|
if args.task_type in base_only_tasks and available_models:
|
|
available_models = [m for m in available_models if "base" in m.lower()]
|
|
if not available_models:
|
|
print("No DiT models found. Downloading main model (acestep-v15-turbo + core components)...")
|
|
from acestep.model_downloader import ensure_main_model, get_checkpoints_dir
|
|
checkpoints_dir = get_checkpoints_dir()
|
|
success, msg = ensure_main_model(checkpoints_dir)
|
|
print(msg)
|
|
if not success:
|
|
parser.error(f"Failed to download main model: {msg}")
|
|
available_models = dit_handler.get_available_acestep_v15_models()
|
|
if args.task_type in base_only_tasks and available_models:
|
|
available_models = [m for m in available_models if "base" in m.lower()]
|
|
if args.task_type in base_only_tasks and not available_models:
|
|
print("Base-only task selected. Downloading base DiT model (acestep-v15-base)...")
|
|
from acestep.model_downloader import ensure_dit_model, get_checkpoints_dir
|
|
checkpoints_dir = get_checkpoints_dir()
|
|
success, msg = ensure_dit_model("acestep-v15-base", checkpoints_dir)
|
|
print(msg)
|
|
if not success:
|
|
parser.error(f"Failed to download base DiT model: {msg}")
|
|
available_models = dit_handler.get_available_acestep_v15_models()
|
|
if available_models:
|
|
available_models = [m for m in available_models if "base" in m.lower()]
|
|
if available_models:
|
|
if args.task_type in {"lego", "extract", "complete"}:
|
|
preferred = "acestep-v15-base"
|
|
else:
|
|
preferred = "acestep-v15-turbo"
|
|
args.config_path = preferred if preferred in available_models else available_models[0]
|
|
print(f"Auto-selected config_path: {args.config_path}")
|
|
else:
|
|
parser.error("No available DiT models found. Please specify --config_path.")
|
|
if args.task_type in {"lego", "extract", "complete"} and "base" not in str(args.config_path).lower():
|
|
parser.error(f"task_type '{args.task_type}' requires a base model config (e.g., 'acestep-v15-base').")
|
|
|
|
# Ensure required DiT/main models are present for the selected task/model.
|
|
from acestep.model_downloader import (
|
|
ensure_main_model,
|
|
ensure_dit_model,
|
|
get_checkpoints_dir,
|
|
check_main_model_exists,
|
|
check_model_exists,
|
|
SUBMODEL_REGISTRY,
|
|
)
|
|
checkpoints_dir = get_checkpoints_dir()
|
|
if not check_main_model_exists(checkpoints_dir):
|
|
print("Main model components not found. Downloading main model...")
|
|
success, msg = ensure_main_model(checkpoints_dir)
|
|
print(msg)
|
|
if not success:
|
|
parser.error(f"Failed to download main model: {msg}")
|
|
if args.config_path:
|
|
config_name = str(args.config_path)
|
|
known_models = {"acestep-v15-turbo"} | set(SUBMODEL_REGISTRY.keys())
|
|
if check_model_exists(config_name, checkpoints_dir):
|
|
pass
|
|
elif config_name in known_models:
|
|
success, msg = ensure_dit_model(config_name, checkpoints_dir)
|
|
if not success:
|
|
parser.error(f"Failed to download DiT model '{config_name}': {msg}")
|
|
else:
|
|
print(f"Warning: DiT model '{config_name}' not found locally and not in registry. Skipping auto-download.")
|
|
|
|
use_flash_attention = args.use_flash_attention
|
|
if use_flash_attention is None:
|
|
use_flash_attention = dit_handler.is_flash_attention_available(device)
|
|
|
|
compile_model = os.environ.get("ACESTEP_COMPILE_MODEL", "").strip().lower() in {
|
|
"1", "true", "yes", "y", "on",
|
|
}
|
|
|
|
print(f"Initializing DiT handler with model: {args.config_path}")
|
|
dit_handler.initialize_service(
|
|
project_root=args.project_root,
|
|
config_path=args.config_path,
|
|
device=device,
|
|
use_flash_attention=use_flash_attention,
|
|
compile_model=compile_model,
|
|
offload_to_cpu=args.offload_to_cpu,
|
|
offload_dit_to_cpu=args.offload_dit_to_cpu,
|
|
)
|
|
|
|
if requires_lm:
|
|
from acestep.model_downloader import ensure_lm_model
|
|
if args.lm_model_path is None:
|
|
available_lm_models = llm_handler.get_available_5hz_lm_models()
|
|
if available_lm_models:
|
|
args.lm_model_path = available_lm_models[0]
|
|
print(f"Using default LM model: {args.lm_model_path}")
|
|
else:
|
|
success, msg = ensure_lm_model(checkpoints_dir=checkpoints_dir)
|
|
print(msg)
|
|
if not success:
|
|
parser.error("No LM models available. Please specify --lm_model_path or disable --thinking.")
|
|
available_lm_models = llm_handler.get_available_5hz_lm_models()
|
|
if not available_lm_models:
|
|
parser.error("No LM models available after download. Please specify --lm_model_path or disable --thinking.")
|
|
args.lm_model_path = available_lm_models[0]
|
|
print(f"Using default LM model: {args.lm_model_path}")
|
|
else:
|
|
lm_model_path = str(args.lm_model_path)
|
|
if os.path.isabs(lm_model_path) and os.path.exists(lm_model_path):
|
|
pass
|
|
elif check_model_exists(lm_model_path, checkpoints_dir):
|
|
pass
|
|
elif lm_model_path in SUBMODEL_REGISTRY:
|
|
success, msg = ensure_lm_model(lm_model_path, checkpoints_dir=checkpoints_dir)
|
|
print(msg)
|
|
if not success:
|
|
parser.error(f"Failed to download LM model '{lm_model_path}': {msg}")
|
|
else:
|
|
parser.error(f"LM model '{lm_model_path}' not found locally and not in registry. Please provide a valid --lm_model_path.")
|
|
|
|
print(f"Initializing LM handler with model: {args.lm_model_path}")
|
|
llm_handler.initialize(
|
|
checkpoint_dir=args.checkpoint_dir,
|
|
lm_model_path=args.lm_model_path,
|
|
backend=args.backend,
|
|
device=device,
|
|
offload_to_cpu=args.offload_to_cpu,
|
|
dtype=None,
|
|
)
|
|
else:
|
|
if args.task_type in skip_lm_tasks:
|
|
print(f"LM is not required for task_type '{args.task_type}'. Skipping LM handler initialization.")
|
|
else:
|
|
print("LM 'thinking' is disabled. Skipping LM handler initialization.")
|
|
|
|
print("Handlers initialized.")
|
|
|
|
format_has_duration = False
|
|
|
|
# --- Sample Mode / Description-based Auto-Generation ---
|
|
if args.sample_mode or (args.sample_query and str(args.sample_query).strip()):
|
|
if not llm_handler.llm_initialized:
|
|
parser.error("--sample_mode/sample_query requires the LM handler, but it's not initialized.")
|
|
|
|
sample_query = args.sample_query if args.sample_query and str(args.sample_query).strip() else "NO USER INPUT"
|
|
parsed_language, parsed_instrumental = _parse_description_hints(sample_query)
|
|
|
|
if args.vocal_language and args.vocal_language not in ("en", "unknown", ""):
|
|
sample_language = args.vocal_language
|
|
else:
|
|
sample_language = parsed_language
|
|
|
|
print("\nINFO: Creating sample via 'create_sample'...")
|
|
sample_result = create_sample(
|
|
llm_handler=llm_handler,
|
|
query=sample_query,
|
|
instrumental=parsed_instrumental,
|
|
vocal_language=sample_language,
|
|
temperature=args.lm_temperature,
|
|
top_k=args.lm_top_k,
|
|
top_p=args.lm_top_p,
|
|
)
|
|
|
|
if sample_result.success:
|
|
args.caption = sample_result.caption
|
|
args.lyrics = sample_result.lyrics
|
|
args.instrumental = bool(sample_result.instrumental)
|
|
if args.bpm is None:
|
|
args.bpm = sample_result.bpm
|
|
if not args.keyscale:
|
|
args.keyscale = sample_result.keyscale
|
|
if not args.timesignature:
|
|
args.timesignature = sample_result.timesignature
|
|
if args.duration <= 0:
|
|
args.duration = sample_result.duration
|
|
if args.vocal_language in ("unknown", "", None):
|
|
args.vocal_language = sample_result.language
|
|
args.sample_mode = True
|
|
print("✓ Sample created. Using generated parameters.")
|
|
else:
|
|
parser.error(f"create_sample failed: {sample_result.error or sample_result.status_message}")
|
|
|
|
# --- Format caption/lyrics if requested ---
|
|
if args.use_format and (args.caption or args.lyrics):
|
|
if not llm_handler.llm_initialized:
|
|
parser.error("--use_format requires the LM handler, but it's not initialized.")
|
|
|
|
user_metadata_for_format = {}
|
|
if args.bpm is not None:
|
|
user_metadata_for_format["bpm"] = args.bpm
|
|
if args.duration is not None and float(args.duration) > 0:
|
|
user_metadata_for_format["duration"] = float(args.duration)
|
|
if args.keyscale:
|
|
user_metadata_for_format["keyscale"] = args.keyscale
|
|
if args.timesignature:
|
|
user_metadata_for_format["timesignature"] = args.timesignature
|
|
if args.vocal_language and args.vocal_language != "unknown":
|
|
user_metadata_for_format["language"] = args.vocal_language
|
|
|
|
print("\nINFO: Formatting caption/lyrics via 'format_sample'...")
|
|
format_result = format_sample(
|
|
llm_handler=llm_handler,
|
|
caption=args.caption or "",
|
|
lyrics=args.lyrics or "",
|
|
user_metadata=user_metadata_for_format if user_metadata_for_format else None,
|
|
temperature=args.lm_temperature,
|
|
top_k=args.lm_top_k,
|
|
top_p=args.lm_top_p,
|
|
)
|
|
|
|
if format_result.success:
|
|
args.caption = format_result.caption or args.caption
|
|
args.lyrics = format_result.lyrics or args.lyrics
|
|
if format_result.duration:
|
|
args.duration = format_result.duration
|
|
format_has_duration = True
|
|
if format_result.bpm:
|
|
args.bpm = format_result.bpm
|
|
if format_result.keyscale:
|
|
args.keyscale = format_result.keyscale
|
|
if format_result.timesignature:
|
|
args.timesignature = format_result.timesignature
|
|
print("✓ Format complete.")
|
|
else:
|
|
parser.error(f"format_sample failed: {format_result.error or format_result.status_message}")
|
|
|
|
# --- Auto-generate Lyrics if Requested ---
|
|
if args.use_cot_lyrics:
|
|
if not llm_handler.llm_initialized:
|
|
parser.error("--use_cot_lyrics requires the LM handler, but it's not initialized. Ensure --thinking is enabled.")
|
|
|
|
print("\nINFO: Generating lyrics and metadata via 'create_sample'...")
|
|
sample_result = create_sample(
|
|
llm_handler=llm_handler,
|
|
query=args.caption,
|
|
instrumental=False,
|
|
vocal_language=args.vocal_language if args.vocal_language != 'unknown' else None,
|
|
temperature=args.lm_temperature,
|
|
top_k=args.lm_top_k,
|
|
top_p=args.lm_top_p,
|
|
)
|
|
|
|
if sample_result.success:
|
|
print("✓ Automatic sample creation successful. Using generated parameters:")
|
|
# Update args with values from create_sample, respecting user-provided values
|
|
args.caption = sample_result.caption
|
|
args.lyrics = sample_result.lyrics
|
|
if args.bpm is None: args.bpm = sample_result.bpm
|
|
if not args.keyscale: args.keyscale = sample_result.keyscale
|
|
if not args.timesignature: args.timesignature = sample_result.timesignature
|
|
if args.duration <= 0: args.duration = sample_result.duration
|
|
if args.vocal_language == 'unknown': args.vocal_language = sample_result.language
|
|
|
|
print(f" - Caption: {args.caption}")
|
|
lyrics_preview = args.lyrics[:150].strip().replace("\n", " ")
|
|
print(f" - Lyrics: '{lyrics_preview}...'")
|
|
print(f" - Metadata: BPM={args.bpm}, Key='{args.keyscale}', Lang='{args.vocal_language}'")
|
|
|
|
# Disable subsequent CoT steps to avoid redundancy and save time
|
|
args.use_cot_metas = False
|
|
args.use_cot_caption = False
|
|
else:
|
|
print(f"⚠️ WARNING: Automatic lyric generation via 'create_sample' failed: {sample_result.error}")
|
|
print(" Proceeding with an instrumental track instead.")
|
|
args.lyrics = "[Instrumental]"
|
|
args.instrumental = True
|
|
|
|
# Flag has served its purpose, disable it to avoid issues with GenerationParams
|
|
args.use_cot_lyrics = False
|
|
|
|
if args.sample_mode or format_has_duration:
|
|
args.use_cot_metas = False
|
|
|
|
# --- Prompt Editing Hook for LLM Audio Tokens ---
|
|
if args.thinking and args.task_type not in skip_lm_tasks:
|
|
instruction_path = os.path.join(
|
|
os.path.abspath(args.project_root) if args.project_root else os.getcwd(),
|
|
"instruction.txt",
|
|
)
|
|
preloaded_prompt = None
|
|
use_instruction_file = False
|
|
if args.config and os.path.exists(instruction_path):
|
|
use_instruction_file = True
|
|
try:
|
|
with open(instruction_path, "r", encoding="utf-8") as f:
|
|
preloaded_prompt = f.read()
|
|
except Exception as e:
|
|
print(f"WARNING: Failed to read {instruction_path}: {e}")
|
|
preloaded_prompt = None
|
|
use_instruction_file = False
|
|
if use_instruction_file:
|
|
print(f"INFO: Found {instruction_path}. Using it without editing.")
|
|
if preloaded_prompt is not None and not preloaded_prompt.strip():
|
|
preloaded_prompt = None
|
|
_install_prompt_edit_hook(llm_handler, instruction_path, preloaded_prompt=preloaded_prompt)
|
|
|
|
# --- Configure Generation ---
|
|
params = GenerationParams(
|
|
task_type=args.task_type,
|
|
instruction=args.instruction,
|
|
reference_audio=args.reference_audio,
|
|
src_audio=args.src_audio,
|
|
audio_codes=args.audio_codes,
|
|
caption=args.caption,
|
|
lyrics=args.lyrics,
|
|
instrumental=args.instrumental,
|
|
vocal_language=args.vocal_language,
|
|
bpm=args.bpm,
|
|
keyscale=args.keyscale,
|
|
timesignature=args.timesignature,
|
|
duration=args.duration,
|
|
inference_steps=args.inference_steps,
|
|
seed=args.seed,
|
|
guidance_scale=args.guidance_scale,
|
|
use_adg=args.use_adg,
|
|
cfg_interval_start=args.cfg_interval_start,
|
|
cfg_interval_end=args.cfg_interval_end,
|
|
shift=args.shift,
|
|
infer_method=args.infer_method,
|
|
timesteps=timesteps,
|
|
repainting_start=args.repainting_start,
|
|
repainting_end=args.repainting_end,
|
|
audio_cover_strength=args.audio_cover_strength,
|
|
thinking=args.thinking,
|
|
lm_temperature=args.lm_temperature,
|
|
lm_cfg_scale=args.lm_cfg_scale,
|
|
lm_top_k=args.lm_top_k,
|
|
lm_top_p=args.lm_top_p,
|
|
lm_negative_prompt=args.lm_negative_prompt,
|
|
use_cot_metas=args.use_cot_metas,
|
|
use_cot_caption=args.use_cot_caption,
|
|
use_cot_lyrics=args.use_cot_lyrics,
|
|
use_cot_language=args.use_cot_language,
|
|
use_constrained_decoding=args.use_constrained_decoding
|
|
)
|
|
|
|
config = GenerationConfig(
|
|
batch_size=args.batch_size,
|
|
allow_lm_batch=args.allow_lm_batch,
|
|
use_random_seed=args.use_random_seed,
|
|
seeds=args.seeds,
|
|
lm_batch_chunk_size=args.lm_batch_chunk_size,
|
|
constrained_decoding_debug=args.constrained_decoding_debug,
|
|
audio_format=args.audio_format
|
|
)
|
|
|
|
# --- Generate Music ---
|
|
log_level = getattr(args, "log_level", "INFO")
|
|
log_level_upper = str(log_level).upper()
|
|
compact_logs = log_level_upper != "DEBUG"
|
|
_print_final_parameters(
|
|
args,
|
|
params,
|
|
config,
|
|
params_defaults,
|
|
config_defaults,
|
|
compact=compact_logs,
|
|
resolved_device=device,
|
|
)
|
|
|
|
print("\n--- Starting Generation ---")
|
|
print(f"Caption: \"{params.caption}\"")
|
|
print(f"Duration: {params.duration}s | Outputs: {config.batch_size}")
|
|
if config.seeds:
|
|
print(f"Custom Seeds: {config.seeds}")
|
|
print("---------------------------\n")
|
|
|
|
manual_edit_pipeline = (
|
|
args.thinking
|
|
and args.task_type not in skip_lm_tasks
|
|
and not (params.audio_codes and str(params.audio_codes).strip())
|
|
)
|
|
|
|
lm_time_costs = None
|
|
if manual_edit_pipeline:
|
|
top_k_value = None if not params.lm_top_k or params.lm_top_k == 0 else int(params.lm_top_k)
|
|
top_p_value = None if not params.lm_top_p or params.lm_top_p >= 1.0 else params.lm_top_p
|
|
|
|
actual_batch_size = config.batch_size if config.batch_size is not None else 1
|
|
seed_for_generation = ""
|
|
if config.seeds is not None:
|
|
if isinstance(config.seeds, list) and len(config.seeds) > 0:
|
|
seed_for_generation = ",".join(str(s) for s in config.seeds)
|
|
elif isinstance(config.seeds, int):
|
|
seed_for_generation = str(config.seeds)
|
|
actual_seed_list, _ = dit_handler.prepare_seeds(actual_batch_size, seed_for_generation, config.use_random_seed)
|
|
|
|
original_target_duration = params.duration
|
|
original_bpm = params.bpm
|
|
original_keyscale = params.keyscale
|
|
original_timesignature = params.timesignature
|
|
original_vocal_language = params.vocal_language
|
|
lm_result = None
|
|
lm_metadata = {}
|
|
edited_caption = None
|
|
edited_lyrics = None
|
|
edited_instruction = None
|
|
edited_metas = {}
|
|
lm_time_costs = {
|
|
"phase1_time": 0.0,
|
|
"phase2_time": 0.0,
|
|
"total_time": 0.0,
|
|
}
|
|
for attempt in range(2):
|
|
user_metadata = {}
|
|
if params.bpm is not None:
|
|
try:
|
|
bpm_value = float(params.bpm)
|
|
if bpm_value > 0:
|
|
user_metadata["bpm"] = int(bpm_value)
|
|
except (ValueError, TypeError):
|
|
pass
|
|
if params.keyscale and params.keyscale.strip() and params.keyscale.strip().lower() not in ["n/a", ""]:
|
|
user_metadata["keyscale"] = params.keyscale.strip()
|
|
if params.timesignature and params.timesignature.strip() and params.timesignature.strip().lower() not in ["n/a", ""]:
|
|
user_metadata["timesignature"] = params.timesignature.strip()
|
|
if params.duration is not None:
|
|
try:
|
|
duration_value = float(params.duration)
|
|
if duration_value > 0:
|
|
user_metadata["duration"] = int(duration_value)
|
|
except (ValueError, TypeError):
|
|
pass
|
|
# Only include caption and language in user_metadata on
|
|
# regeneration attempts. On the first attempt the LM should
|
|
# generate/expand these via CoT (matching inference.py behaviour).
|
|
if attempt > 0:
|
|
if params.caption and params.caption.strip():
|
|
user_metadata["caption"] = params.caption.strip()
|
|
if params.vocal_language and params.vocal_language not in ("", "unknown"):
|
|
user_metadata["language"] = params.vocal_language
|
|
user_metadata_to_pass = user_metadata if user_metadata else None
|
|
|
|
lm_result = llm_handler.generate_with_stop_condition(
|
|
caption=params.caption or "",
|
|
lyrics=params.lyrics or "",
|
|
infer_type="llm_dit",
|
|
temperature=params.lm_temperature,
|
|
cfg_scale=params.lm_cfg_scale,
|
|
negative_prompt=params.lm_negative_prompt,
|
|
top_k=top_k_value,
|
|
top_p=top_p_value,
|
|
target_duration=params.duration,
|
|
user_metadata=user_metadata_to_pass,
|
|
use_cot_caption=params.use_cot_caption,
|
|
use_cot_language=params.use_cot_language,
|
|
use_cot_metas=params.use_cot_metas,
|
|
use_constrained_decoding=params.use_constrained_decoding,
|
|
constrained_decoding_debug=config.constrained_decoding_debug,
|
|
batch_size=actual_batch_size,
|
|
seeds=actual_seed_list,
|
|
)
|
|
lm_extra_time = (lm_result.get("extra_outputs") or {}).get("time_costs", {})
|
|
if lm_extra_time:
|
|
lm_time_costs["phase1_time"] += float(lm_extra_time.get("phase1_time", 0.0) or 0.0)
|
|
lm_time_costs["phase2_time"] += float(lm_extra_time.get("phase2_time", 0.0) or 0.0)
|
|
lm_time_costs["total_time"] += float(
|
|
lm_extra_time.get(
|
|
"total_time",
|
|
(lm_extra_time.get("phase1_time", 0.0) or 0.0)
|
|
+ (lm_extra_time.get("phase2_time", 0.0) or 0.0),
|
|
)
|
|
or 0.0
|
|
)
|
|
|
|
if not lm_result.get("success", False):
|
|
error_msg = lm_result.get("error", "Unknown LM error")
|
|
print(f"\n❌ Generation failed: {error_msg}")
|
|
print(f" Status: {lm_result.get('error', '')}")
|
|
return
|
|
|
|
if actual_batch_size > 1:
|
|
lm_metadata = (lm_result.get("metadata") or [{}])[0]
|
|
audio_codes = lm_result.get("audio_codes", [])
|
|
else:
|
|
lm_metadata = lm_result.get("metadata", {}) or {}
|
|
audio_codes = lm_result.get("audio_codes", "")
|
|
|
|
if audio_codes:
|
|
params.audio_codes = audio_codes
|
|
else:
|
|
print("WARNING: LM did not return audio codes; proceeding without codes.")
|
|
|
|
edited_caption = getattr(llm_handler, "_edited_caption", None)
|
|
edited_lyrics = getattr(llm_handler, "_edited_lyrics", None)
|
|
edited_instruction = getattr(llm_handler, "_edited_instruction", None)
|
|
edited_metas = getattr(llm_handler, "_edited_metas", {})
|
|
|
|
parsed_duration = None
|
|
parsed_bpm = None
|
|
parsed_keyscale = None
|
|
parsed_timesignature = None
|
|
parsed_language = None
|
|
if edited_metas:
|
|
bpm_value = edited_metas.get("bpm")
|
|
if bpm_value:
|
|
parsed = _parse_number(bpm_value)
|
|
if parsed is not None and parsed > 0:
|
|
parsed_bpm = int(parsed)
|
|
duration_value = edited_metas.get("duration")
|
|
if duration_value:
|
|
parsed = _parse_number(duration_value)
|
|
if parsed is not None and parsed > 0:
|
|
parsed_duration = float(parsed)
|
|
keyscale_value = edited_metas.get("keyscale")
|
|
if keyscale_value:
|
|
parsed_keyscale = keyscale_value
|
|
timesignature_value = edited_metas.get("timesignature")
|
|
if timesignature_value:
|
|
parsed_timesignature = timesignature_value
|
|
language_value = edited_metas.get("language") or edited_metas.get("vocal_language")
|
|
if language_value:
|
|
parsed_language = language_value
|
|
|
|
if attempt == 0:
|
|
duration_changed = parsed_duration is not None and (
|
|
original_target_duration is None
|
|
or float(original_target_duration) <= 0
|
|
or abs(float(original_target_duration) - parsed_duration) > 1e-6
|
|
)
|
|
bpm_changed = parsed_bpm is not None and parsed_bpm != original_bpm
|
|
keyscale_changed = parsed_keyscale is not None and parsed_keyscale != original_keyscale
|
|
timesignature_changed = parsed_timesignature is not None and parsed_timesignature != original_timesignature
|
|
language_changed = parsed_language is not None and parsed_language != original_vocal_language
|
|
if duration_changed or bpm_changed or keyscale_changed or timesignature_changed or language_changed:
|
|
if duration_changed:
|
|
params.duration = parsed_duration
|
|
if bpm_changed:
|
|
params.bpm = parsed_bpm
|
|
if keyscale_changed:
|
|
params.keyscale = parsed_keyscale
|
|
if timesignature_changed:
|
|
params.timesignature = parsed_timesignature
|
|
if language_changed:
|
|
params.vocal_language = parsed_language
|
|
# Carry forward the expanded caption so the second
|
|
# attempt's <think> block (and user_metadata) use it
|
|
# instead of the short original caption.
|
|
edited_caption_for_regen = edited_metas.get("caption") if edited_metas else None
|
|
if edited_caption_for_regen and edited_caption_for_regen.strip():
|
|
params.caption = edited_caption_for_regen
|
|
print("INFO: Edited metadata detected. Regenerating audio codes with updated values.")
|
|
llm_handler._skip_prompt_edit = True
|
|
continue
|
|
break
|
|
|
|
edited_meta_caption = edited_metas.get("caption") if edited_metas else None
|
|
if edited_meta_caption and edited_meta_caption.strip():
|
|
params.caption = edited_meta_caption
|
|
elif edited_caption:
|
|
params.caption = edited_caption
|
|
elif params.use_cot_caption and lm_metadata.get("caption"):
|
|
params.caption = lm_metadata.get("caption")
|
|
|
|
if edited_lyrics:
|
|
params.lyrics = edited_lyrics
|
|
elif not params.lyrics and lm_metadata.get("lyrics"):
|
|
params.lyrics = lm_metadata.get("lyrics")
|
|
|
|
if edited_instruction:
|
|
params.instruction = edited_instruction
|
|
|
|
if edited_metas:
|
|
bpm_value = edited_metas.get("bpm")
|
|
if bpm_value:
|
|
parsed = _parse_number(bpm_value)
|
|
if parsed is not None:
|
|
params.bpm = int(parsed)
|
|
duration_value = edited_metas.get("duration")
|
|
if duration_value:
|
|
parsed = _parse_number(duration_value)
|
|
if parsed is not None:
|
|
params.duration = float(parsed)
|
|
keyscale_value = edited_metas.get("keyscale")
|
|
if keyscale_value:
|
|
params.keyscale = keyscale_value
|
|
timesignature_value = edited_metas.get("timesignature")
|
|
if timesignature_value:
|
|
params.timesignature = timesignature_value
|
|
language_value = edited_metas.get("language") or edited_metas.get("vocal_language")
|
|
if language_value:
|
|
params.vocal_language = language_value
|
|
else:
|
|
if params.bpm is None and lm_metadata.get("bpm") not in (None, "N/A", ""):
|
|
parsed = _parse_number(str(lm_metadata.get("bpm")))
|
|
if parsed is not None:
|
|
params.bpm = int(parsed)
|
|
if not params.keyscale and lm_metadata.get("keyscale"):
|
|
params.keyscale = lm_metadata.get("keyscale")
|
|
if not params.timesignature and lm_metadata.get("timesignature"):
|
|
params.timesignature = lm_metadata.get("timesignature")
|
|
if params.duration is None and lm_metadata.get("duration") not in (None, "N/A", ""):
|
|
parsed = _parse_number(str(lm_metadata.get("duration")))
|
|
if parsed is not None:
|
|
params.duration = float(parsed)
|
|
if params.vocal_language in (None, "", "unknown"):
|
|
language_value = lm_metadata.get("vocal_language") or lm_metadata.get("language")
|
|
if language_value:
|
|
params.vocal_language = language_value
|
|
|
|
# use_cot_language: override vocal_language with LM detection unless
|
|
# the user explicitly edited the language in the think block.
|
|
if params.use_cot_language:
|
|
edited_lang = (edited_metas.get("language") or edited_metas.get("vocal_language")) if edited_metas else None
|
|
if not edited_lang:
|
|
lm_lang = lm_metadata.get("vocal_language") or lm_metadata.get("language")
|
|
if lm_lang:
|
|
params.vocal_language = lm_lang
|
|
|
|
# Populate cot_* fields for downstream reporting (mirrors inference.py)
|
|
if lm_metadata:
|
|
if original_bpm is None:
|
|
params.cot_bpm = params.bpm
|
|
if not original_keyscale:
|
|
params.cot_keyscale = params.keyscale
|
|
if not original_timesignature:
|
|
params.cot_timesignature = params.timesignature
|
|
if original_target_duration is None or float(original_target_duration) <= 0:
|
|
params.cot_duration = params.duration
|
|
if original_vocal_language in (None, "", "unknown"):
|
|
params.cot_vocal_language = params.vocal_language
|
|
if not params.caption:
|
|
params.cot_caption = lm_metadata.get("caption", "")
|
|
if not params.lyrics:
|
|
params.cot_lyrics = lm_metadata.get("lyrics", "")
|
|
|
|
params.thinking = False
|
|
params.use_cot_caption = False
|
|
params.use_cot_language = False
|
|
params.use_cot_metas = False
|
|
if hasattr(llm_handler, "_skip_prompt_edit"):
|
|
llm_handler._skip_prompt_edit = False
|
|
|
|
if log_level_upper in {"INFO", "DEBUG"}:
|
|
_print_dit_prompt(dit_handler, params)
|
|
print("Running DiT generation with edited prompt and cached audio codes...")
|
|
result = generate_music(dit_handler, llm_handler, params, config, save_dir=args.save_dir)
|
|
else:
|
|
if log_level_upper in {"INFO", "DEBUG"}:
|
|
_print_dit_prompt(dit_handler, params)
|
|
result = generate_music(dit_handler, llm_handler, params, config, save_dir=args.save_dir)
|
|
|
|
# --- Process Results ---
|
|
if result.success:
|
|
print(f"\n✅ Generation successful! {len(result.audios)} audio(s) saved in '{args.save_dir}/'")
|
|
for i, audio in enumerate(result.audios):
|
|
print(f" [{i+1}] Path: {audio['path']} | Seed: {audio['params']['seed']}")
|
|
|
|
time_costs = result.extra_outputs.get("time_costs", {})
|
|
if manual_edit_pipeline and lm_time_costs and time_costs is not None:
|
|
if not isinstance(time_costs, dict):
|
|
time_costs = {}
|
|
result.extra_outputs["time_costs"] = time_costs
|
|
if lm_time_costs["total_time"] > 0.0:
|
|
time_costs["lm_phase1_time"] = lm_time_costs["phase1_time"]
|
|
time_costs["lm_phase2_time"] = lm_time_costs["phase2_time"]
|
|
time_costs["lm_total_time"] = lm_time_costs["total_time"]
|
|
dit_total = float(time_costs.get("dit_total_time_cost", 0.0) or 0.0)
|
|
time_costs["pipeline_total_time"] = time_costs["lm_total_time"] + dit_total
|
|
if time_costs:
|
|
print("\n--- Performance ---")
|
|
total_time = time_costs.get('pipeline_total_time', 0)
|
|
print(f"Total time: {total_time:.2f}s")
|
|
if args.thinking:
|
|
lm1_time = time_costs.get('lm_phase1_time', 0)
|
|
lm2_time = time_costs.get('lm_phase2_time', 0)
|
|
print(f" - LM time: {lm1_time + lm2_time:.2f}s")
|
|
dit_time = time_costs.get('dit_total_time_cost', 0)
|
|
print(f" - DiT time: {dit_time:.2f}s")
|
|
print("-------------------\n")
|
|
|
|
else:
|
|
print(f"\n❌ Generation failed: {result.error}")
|
|
print(f" Status: {result.status_message}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|