Merge branch 'main' into fix_xpu_req

xu shengyuan 2026-02-13 22:56:23 +08:00 committed by GitHub
commit cf5b27e8f0
GPG key ID: B5690EEEBB952194
166 changed files with 40083 additions and 7035 deletions


@@ -0,0 +1,173 @@
---
name: acestep-lyrics-transcription
description: Transcribe audio to timestamped lyrics using OpenAI Whisper or ElevenLabs Scribe API. Outputs LRC, SRT, or JSON with word-level timestamps. Use when users want to transcribe songs, generate LRC files, or extract lyrics with timestamps from audio.
allowed-tools: Read, Write, Bash
---
# Lyrics Transcription Skill
Transcribe audio files to timestamped lyrics (LRC/SRT/JSON) via OpenAI Whisper or ElevenLabs Scribe API.
## API Key Setup Guide
**Before transcribing, you MUST check whether the user's API key is configured.** Run the following command to check:
```bash
cd "{project_root}/{.claude or .codex}/skills/acestep-lyrics-transcription/" && bash ./scripts/acestep-lyrics-transcription.sh config --check-key
```
This command only reports whether the active provider's API key is set or empty — it does NOT print the actual key value. **NEVER read or display the user's API key content.** Do not use `config --get` on key fields or read `config.json` directly. The `config --list` command is safe — it automatically masks API keys as `***` in output.
**If the command reports the key is empty**, you MUST stop and guide the user to configure it before proceeding. Do NOT attempt transcription without a valid key — it will fail.
Use `AskUserQuestion` to ask the user to provide their API key, with the following options and guidance:
1. Tell the user which provider is currently active (openai or elevenlabs) and that its API key is not configured. Explain that transcription cannot proceed without it.
2. Provide clear instructions on where to obtain a key:
- **OpenAI**: Get an API key at https://platform.openai.com/api-keys — requires an OpenAI account with billing enabled. The Whisper API costs ~$0.006/min.
- **ElevenLabs**: Get an API key at https://elevenlabs.io/app/settings/api-keys — requires an ElevenLabs account. Free tier includes limited credits.
3. Also offer the option to switch to the other provider if they already have a key for it.
4. Once the user provides the key, configure it using:
```bash
cd "{project_root}/{.claude or .codex}/skills/acestep-lyrics-transcription/" && bash ./scripts/acestep-lyrics-transcription.sh config --set <provider>.api_key <KEY>
```
5. If the user wants to switch providers, also run:
```bash
cd "{project_root}/{.claude or .codex}/skills/acestep-lyrics-transcription/" && bash ./scripts/acestep-lyrics-transcription.sh config --set provider <provider_name>
```
6. After configuring, re-run `config --check-key` to verify the key is set before proceeding.
**If the API key is already configured**, proceed directly to transcription without asking.
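The decision described above can be sketched in Python. This is a minimal sketch, assuming `config --check-key` prints `provider: <name>` and `api_key: configured` (or `api_key: empty`) on separate lines, as the script in this skill does. The helper name `parse_check_key` is hypothetical and not part of the skill.

```python
def parse_check_key(output: str) -> dict:
    """Parse the 'name: value' lines printed by `config --check-key`."""
    fields = {}
    for line in output.splitlines():
        name, sep, value = line.partition(":")
        if sep:
            fields[name.strip()] = value.strip()
    return {
        "provider": fields.get("provider", ""),
        # Only the literal word "configured" counts as a usable key.
        "key_configured": fields.get("api_key") == "configured",
    }


# Sample text in the shape the script prints (not captured from a real run).
status = parse_check_key("provider: openai\napi_key: empty\n")
if not status["key_configured"]:
    print(f"Ask the user for a {status['provider']} API key before transcribing.")
```

The helper only ever sees the words `configured` or `empty`, so the key value itself never enters the conversation.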
## Quick Start
```bash
# 1. cd to this skill's directory
cd {project_root}/{.claude or .codex}/skills/acestep-lyrics-transcription/
# 2. Configure API key (choose one)
./scripts/acestep-lyrics-transcription.sh config --set openai.api_key sk-...
# or
./scripts/acestep-lyrics-transcription.sh config --set elevenlabs.api_key ...
./scripts/acestep-lyrics-transcription.sh config --set provider elevenlabs
# 3. Transcribe
./scripts/acestep-lyrics-transcription.sh transcribe --audio /path/to/song.mp3 --language zh
# 4. Output saved to: {project_root}/acestep_output/<filename>.lrc
```
## Prerequisites
- curl, jq, python3 (or python)
- An API key for OpenAI or ElevenLabs
## Script Usage
```bash
./scripts/acestep-lyrics-transcription.sh transcribe --audio <file> [options]
Options:
-a, --audio Audio file path (required)
-l, --language Language code (zh, en, ja, etc.)
-f, --format Output format: lrc, srt, json (default: lrc)
-p, --provider API provider: openai, elevenlabs (overrides config)
-o, --output Output file path (default: acestep_output/<filename>.<format>)
```
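The `lrc` and `srt` formats selected by `--format` differ mainly in timestamp notation: LRC uses `[MM:SS.cc]`, SRT uses `HH:MM:SS,mmm`. The sketch below shows one way to produce both from a time in seconds. It is an illustration, not the script's own code. The function names are hypothetical, and it rounds to whole centiseconds or milliseconds before splitting, so a value such as 59.996 rolls over to the next minute instead of printing `60.00`.

```python
def lrc_timestamp(seconds: float) -> str:
    """Format seconds as an LRC tag such as [00:46.96]."""
    centis = round(seconds * 100)           # round once, in integer centiseconds
    minutes, centis = divmod(centis, 6000)  # 6000 centiseconds per minute
    secs, centis = divmod(centis, 100)
    return f"[{minutes:02d}:{secs:02d}.{centis:02d}]"


def srt_timestamp(seconds: float) -> str:
    """Format seconds as an SRT time such as 00:00:46,960."""
    millis = round(seconds * 1000)
    hours, millis = divmod(millis, 3_600_000)
    minutes, millis = divmod(millis, 60_000)
    secs, millis = divmod(millis, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"


print(lrc_timestamp(46.96))   # [00:46.96]
print(srt_timestamp(3661.5))  # 01:01:01,500
```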
## Post-Transcription Lyrics Correction (MANDATORY)
**CRITICAL**: After transcription, you MUST manually correct the LRC file before using it for MV rendering. Transcription models frequently produce errors on sung lyrics:
- Proper nouns: "ACE-Step" → "AC step", "Spotify" → "spot a fly"
- Similar-sounding words: "arrives" → "eyes", "open source" → "open sores"
- Merged/split words: "lighting up" → "lightin' nup"
### Correction Workflow
1. **Read the transcribed LRC file** using the Read tool
2. **Read the original lyrics** from the ACE-Step output JSON file
3. **Use original lyrics as a whole reference**: Do NOT attempt line-by-line alignment — transcription often splits, merges, or reorders lines differently from the original. Instead, read the original lyrics in full to understand the correct wording, then scan each LRC line and fix any misrecognized words based on your knowledge of what the original lyrics say.
4. **Fix transcription errors**: Replace misrecognized words with the correct original words, keeping the timestamps intact
5. **Write the corrected LRC** back using the Write tool
### What to Correct
- Replace misrecognized words with their correct original versions
- Keep all `[MM:SS.cc]` timestamps exactly as-is (timestamps from transcription are accurate)
- Do NOT add structure tags like `[Verse]` or `[Chorus]` — the LRC should only have timestamped text lines
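A quick structural check can catch stray tags before the file goes to MV rendering. This is a minimal sketch, assuming every non-empty line must have the form `[MM:SS.cc]text`. The helper name `find_bad_lrc_lines` is hypothetical and not part of the skill.

```python
import re

# Two or more minute digits, then SS.cc, then at least one character of text.
LRC_LINE = re.compile(r"^\[\d{2,}:\d{2}\.\d{2}\].+$")


def find_bad_lrc_lines(lrc_text: str) -> list[str]:
    """Return non-empty lines that are not '[MM:SS.cc]text'."""
    return [
        line
        for line in lrc_text.splitlines()
        if line.strip() and not LRC_LINE.match(line)
    ]


sample = "[00:46.96]ACE-Step alive,\n[Verse]\n[00:50.80]One point five arrives.\n"
print(find_bad_lrc_lines(sample))  # ['[Verse]']
```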
### Example
**Transcribed (wrong):**
```
[00:46.96]AC step alive,
[00:50.80]one point five eyes.
```
**Original lyrics reference:**
```
ACE-Step alive
One point five arrives
```
**Corrected (right):**
```
[00:46.96]ACE-Step alive,
[00:50.80]One point five arrives.
```
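Since the timestamps must survive correction unchanged, it can help to compare them before and after editing. The sketch below uses the two LRC snippets from this example. The helper names are hypothetical and not part of the skill.

```python
import re

TIMESTAMP = re.compile(r"^\[\d{2,}:\d{2}\.\d{2}\]")


def timestamps(lrc_text: str) -> list[str]:
    """Collect the leading [MM:SS.cc] tag of every timestamped line, in order."""
    return [m.group(0) for line in lrc_text.splitlines() if (m := TIMESTAMP.match(line))]


def timestamps_preserved(before: str, after: str) -> bool:
    """True when both versions carry the same tags in the same order."""
    return timestamps(before) == timestamps(after)


transcribed = "[00:46.96]AC step alive,\n[00:50.80]one point five eyes.\n"
corrected = "[00:46.96]ACE-Step alive,\n[00:50.80]One point five arrives.\n"
print(timestamps_preserved(transcribed, corrected))  # True
```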
## Configuration
Config file: `scripts/config.json`
```bash
# Switch provider
./scripts/acestep-lyrics-transcription.sh config --set provider openai
./scripts/acestep-lyrics-transcription.sh config --set provider elevenlabs
# Set API keys
./scripts/acestep-lyrics-transcription.sh config --set openai.api_key sk-...
./scripts/acestep-lyrics-transcription.sh config --set elevenlabs.api_key ...
# View config
./scripts/acestep-lyrics-transcription.sh config --list
```
| Option | Default | Description |
|--------|---------|-------------|
| `provider` | `openai` | Active provider: `openai` or `elevenlabs` |
| `output_format` | `lrc` | Default output: `lrc`, `srt`, or `json` |
| `openai.api_key` | `""` | OpenAI API key |
| `openai.api_url` | `https://api.openai.com/v1` | OpenAI API base URL |
| `openai.model` | `whisper-1` | OpenAI model (whisper-1 for word timestamps) |
| `elevenlabs.api_key` | `""` | ElevenLabs API key |
| `elevenlabs.api_url` | `https://api.elevenlabs.io/v1` | ElevenLabs API base URL |
| `elevenlabs.model` | `scribe_v2` | ElevenLabs model |
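
The masking that `config --list` applies to the key fields above can be mirrored in Python. This is a rough equivalent for illustration, assuming the rule is "replace every non-empty `api_key` with `***`". The function name `mask_api_keys` and the key `sk-test` are placeholders.

```python
import copy


def mask_api_keys(config: dict) -> dict:
    """Return a copy of config with every non-empty 'api_key' replaced by '***'."""
    masked = copy.deepcopy(config)  # never modify the caller's config

    def walk(node):
        if isinstance(node, dict):
            if node.get("api_key"):
                node["api_key"] = "***"
            for child in node.values():
                walk(child)
        elif isinstance(node, list):
            for child in node:
                walk(child)

    walk(masked)
    return masked


config = {
    "provider": "openai",
    "openai": {"api_key": "sk-test", "model": "whisper-1"},  # placeholder key
    "elevenlabs": {"api_key": "", "model": "scribe_v2"},
}
print(mask_api_keys(config)["openai"]["api_key"])  # ***
```

Empty keys stay empty, so the masked output still shows which provider lacks a key.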
## Provider Notes
| Provider | Model | Word Timestamps | Pricing |
|----------|-------|-----------------|---------|
| OpenAI | whisper-1 | Yes (segment + word) | $0.006/min |
| ElevenLabs | scribe_v2 | Yes (word-level) | Varies by plan |
- OpenAI `whisper-1` is the only OpenAI model supporting word-level timestamps
- ElevenLabs `scribe_v2` returns word-level timestamps with type filtering
- Both support multilingual transcription
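
Both providers are reduced to one shape, `[{word, start, end}]`, before format conversion. The sketch below mirrors that normalization in Python. It assumes OpenAI returns a `words` array of `{word, start, end}` and ElevenLabs returns `words` entries of `{text, start, end, type}`. The sample payloads are hand-written illustrations, not captured API responses.

```python
def normalize_openai(response: dict) -> list[dict]:
    """Keep only word/start/end from an OpenAI verbose_json response."""
    return [
        {"word": w["word"], "start": w["start"], "end": w["end"]}
        for w in response.get("words") or []
    ]


def normalize_elevenlabs(response: dict) -> list[dict]:
    """Keep entries of type 'word' and rename 'text' to 'word'."""
    return [
        {"word": w["text"], "start": w["start"], "end": w["end"]}
        for w in response.get("words") or []
        if w.get("type") == "word"
    ]


# Hand-written sample; a real response carries more fields.
sample = {
    "text": "hi there",
    "words": [
        {"text": "hi", "start": 0.0, "end": 0.4, "type": "word"},
        {"text": " ", "start": 0.4, "end": 0.5, "type": "spacing"},
        {"text": "there", "start": 0.5, "end": 0.9, "type": "word"},
    ],
}
print(normalize_elevenlabs(sample))
```

A missing or null `words` field yields an empty list here, which lets the caller report "no word-level timestamps" instead of failing.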
## Examples
```bash
# Basic transcription (uses config defaults)
./scripts/acestep-lyrics-transcription.sh transcribe --audio song.mp3
# Chinese song to LRC
./scripts/acestep-lyrics-transcription.sh transcribe --audio song.mp3 --language zh
# Use ElevenLabs, output SRT
./scripts/acestep-lyrics-transcription.sh transcribe --audio song.mp3 --provider elevenlabs --format srt
# Custom output path
./scripts/acestep-lyrics-transcription.sh transcribe --audio song.mp3 --output ./my_lyrics.lrc
```


@@ -0,0 +1,584 @@
#!/bin/bash
#
# acestep-lyrics-transcription.sh - Transcribe audio to timestamped lyrics (LRC/SRT/JSON)
#
# Requirements: curl, jq, python3 (or python) for LRC/SRT conversion
#
# Usage:
# ./acestep-lyrics-transcription.sh transcribe --audio <file> [options]
# ./acestep-lyrics-transcription.sh config [--list|--get|--set|--reset|--check-key]
#
# Output:
# - LRC/SRT/JSON files saved to output directory
set -e
export LANG="${LANG:-en_US.UTF-8}"
export LC_ALL="${LC_ALL:-en_US.UTF-8}"
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
CONFIG_FILE="${SCRIPT_DIR}/config.json"
OUTPUT_DIR="$(cd "${SCRIPT_DIR}/../../../.." && pwd)/acestep_output"
# Colors
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
CYAN='\033[0;36m'
NC='\033[0m'
# Convert MSYS2/Cygwin paths to Windows-native paths for Python
to_python_path() {
if command -v cygpath &> /dev/null; then
cygpath -m "$1"
else
echo "$1"
fi
}
# Detect python executable (python3 or python)
PYTHON_CMD=""
find_python() {
if [ -n "$PYTHON_CMD" ]; then return; fi
# Test actual execution, not just existence (Windows Store python3 shim returns exit 49)
if python3 -c "pass" &> /dev/null; then
PYTHON_CMD="python3"
elif python -c "pass" &> /dev/null; then
PYTHON_CMD="python"
else
echo -e "${RED}Error: python3 or python is required but not found.${NC}"
exit 1
fi
}
# ─── Dependencies ───
check_deps() {
if ! command -v curl &> /dev/null; then
echo -e "${RED}Error: curl is required but not installed.${NC}"
exit 1
fi
if ! command -v jq &> /dev/null; then
echo -e "${RED}Error: jq is required but not installed.${NC}"
echo "Install: apt install jq / brew install jq / choco install jq"
exit 1
fi
}
# ─── Config ───
DEFAULT_CONFIG='{
"provider": "openai",
"output_format": "lrc",
"openai": {
"api_key": "",
"api_url": "https://api.openai.com/v1",
"model": "whisper-1"
},
"elevenlabs": {
"api_key": "",
"api_url": "https://api.elevenlabs.io/v1",
"model": "scribe_v2"
}
}'
ensure_config() {
if [ ! -f "$CONFIG_FILE" ]; then
local example="${SCRIPT_DIR}/config.example.json"
if [ -f "$example" ]; then
cp "$example" "$CONFIG_FILE"
echo -e "${YELLOW}Config file created from config.example.json. Please configure your settings:${NC}"
echo -e " ${CYAN}./scripts/acestep-lyrics-transcription.sh config --set provider <openai|elevenlabs>${NC}"
echo -e " ${CYAN}./scripts/acestep-lyrics-transcription.sh config --set <provider>.api_key <key>${NC}"
else
echo "$DEFAULT_CONFIG" > "$CONFIG_FILE"
fi
fi
}
get_config() {
local key="$1"
ensure_config
local jq_path=".${key}"
local value
value=$(jq -r "$jq_path" "$CONFIG_FILE" 2>/dev/null)
if [ "$value" = "null" ]; then
echo ""
else
echo "$value" | tr -d '\r\n'
fi
}
set_config() {
local key="$1"
local value="$2"
ensure_config
local tmp_file="${CONFIG_FILE}.tmp"
local jq_path=".${key}"
if [ "$value" = "true" ] || [ "$value" = "false" ]; then
jq "$jq_path = $value" "$CONFIG_FILE" > "$tmp_file"
elif [[ "$value" =~ ^-?[0-9]+$ ]] || [[ "$value" =~ ^-?[0-9]+\.[0-9]+$ ]]; then
jq "$jq_path = $value" "$CONFIG_FILE" > "$tmp_file"
else
# Pass the value via --arg so quotes or backslashes in it cannot break the jq filter
jq --arg v "$value" "$jq_path = \$v" "$CONFIG_FILE" > "$tmp_file"
fi
mv "$tmp_file" "$CONFIG_FILE"
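# Never echo API key values back to the terminal
if [[ "$key" == *api_key ]]; then
echo "Set $key = ***"
else
echo "Set $key = $value"
fi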
}
ensure_output_dir() {
mkdir -p "$OUTPUT_DIR"
}
# ─── Format Conversion ───
# Convert word-level timestamps to LRC format
# Args: $1 = JSON file containing an array of {word, start, end}
#       $2 = output LRC file, $3 = silence gap in seconds that starts a new line (default 1.5)
words_to_lrc() {
local json_file="$(to_python_path "$1")"
local output_file="$(to_python_path "$2")"
local line_gap="${3:-1.5}"
find_python
$PYTHON_CMD -c "
import json, sys

def is_cjk(ch):
    # An empty string (from an empty word) is never CJK
    if not ch:
        return False
    cp = ord(ch)
    return (0x4E00 <= cp <= 0x9FFF or 0x3400 <= cp <= 0x4DBF or
            0x20000 <= cp <= 0x2A6DF or 0x2A700 <= cp <= 0x2B73F or
            0x2B740 <= cp <= 0x2B81F or 0x2B820 <= cp <= 0x2CEAF or
            0xF900 <= cp <= 0xFAFF or 0x2F800 <= cp <= 0x2FA1F or
            0x3000 <= cp <= 0x303F or 0x3040 <= cp <= 0x309F or
            0x30A0 <= cp <= 0x30FF or 0xFF00 <= cp <= 0xFFEF)

def smart_join(word_list):
    # Join words with spaces, except next to CJK characters
    if not word_list:
        return ''
    result = word_list[0]
    for j in range(1, len(word_list)):
        prev_w = word_list[j-1]
        curr_w = word_list[j]
        prev_last = prev_w[-1] if prev_w else ''
        curr_first = curr_w[0] if curr_w else ''
        if is_cjk(prev_last) or is_cjk(curr_first):
            result += curr_w
        else:
            result += ' ' + curr_w
    return result.strip()

with open('$json_file', 'r', encoding='utf-8') as f:
    words = json.load(f)

if not words:
    sys.exit(0)

lines = []
current_line = []
current_start = words[0]['start']

for i, w in enumerate(words):
    current_line.append(w['word'])
    is_last = (i == len(words) - 1)
    # ASCII and full-width punctuation both end a line
    has_punct = w['word'].rstrip().endswith(('.', '!', '?', '。', '!', '?', ',', ','))
    has_gap = (not is_last and words[i+1]['start'] - w['end'] > $line_gap)
    if is_last or has_punct or has_gap:
        text = smart_join(current_line)
        text = text.rstrip(',。,.')
        if text:
            mins = int(current_start) // 60
            secs = current_start - mins * 60
            lines.append(f'[{mins:02d}:{secs:05.2f}]{text}')
        current_line = []
        if not is_last:
            current_start = words[i+1]['start']

with open('$output_file', 'w', encoding='utf-8') as f:
    for line in lines:
        f.write(line + '\n')
"
}
# Convert word-level timestamps to SRT format
words_to_srt() {
local json_file="$(to_python_path "$1")"
local output_file="$(to_python_path "$2")"
local line_gap="${3:-1.5}"
find_python
$PYTHON_CMD -c "
import json, sys

def is_cjk(ch):
    # An empty string (from an empty word) is never CJK
    if not ch:
        return False
    cp = ord(ch)
    return (0x4E00 <= cp <= 0x9FFF or 0x3400 <= cp <= 0x4DBF or
            0x20000 <= cp <= 0x2A6DF or 0x2A700 <= cp <= 0x2B73F or
            0x2B740 <= cp <= 0x2B81F or 0x2B820 <= cp <= 0x2CEAF or
            0xF900 <= cp <= 0xFAFF or 0x2F800 <= cp <= 0x2FA1F or
            0x3000 <= cp <= 0x303F or 0x3040 <= cp <= 0x309F or
            0x30A0 <= cp <= 0x30FF or 0xFF00 <= cp <= 0xFFEF)

def smart_join(word_list):
    # Join words with spaces, except next to CJK characters
    if not word_list:
        return ''
    result = word_list[0]
    for j in range(1, len(word_list)):
        prev_w = word_list[j-1]
        curr_w = word_list[j]
        prev_last = prev_w[-1] if prev_w else ''
        curr_first = curr_w[0] if curr_w else ''
        if is_cjk(prev_last) or is_cjk(curr_first):
            result += curr_w
        else:
            result += ' ' + curr_w
    return result.strip()

with open('$json_file', 'r', encoding='utf-8') as f:
    words = json.load(f)

if not words:
    sys.exit(0)

def fmt(t):
    # Seconds -> HH:MM:SS,mmm as required by SRT
    h = int(t) // 3600
    m = (int(t) % 3600) // 60
    s = t - h*3600 - m*60
    return f'{h:02d}:{m:02d}:{s:06.3f}'.replace('.', ',')

lines = []
current_line = []
current_start = words[0]['start']
current_end = words[0]['end']

for i, w in enumerate(words):
    current_line.append(w['word'])
    current_end = w['end']
    is_last = (i == len(words) - 1)
    # ASCII and full-width punctuation both end a cue
    has_punct = w['word'].rstrip().endswith(('.', '!', '?', '。', '!', '?', ',', ','))
    has_gap = (not is_last and words[i+1]['start'] - w['end'] > $line_gap)
    if is_last or has_punct or has_gap:
        text = smart_join(current_line)
        text = text.rstrip(',。,.')
        if text:
            lines.append((current_start, current_end, text))
        current_line = []
        if not is_last:
            current_start = words[i+1]['start']

with open('$output_file', 'w', encoding='utf-8') as f:
    for idx, (s, e, text) in enumerate(lines, 1):
        f.write(f'{idx}\n')
        f.write(f'{fmt(s)} --> {fmt(e)}\n')
        f.write(f'{text}\n')
        f.write('\n')
"
}
# ─── OpenAI Whisper ───
transcribe_openai() {
local audio_file="$1"
local language="$2"
local words_file="$3"
local api_key=$(get_config "openai.api_key")
local api_url=$(get_config "openai.api_url")
local model=$(get_config "openai.model")
[ -z "$api_key" ] && { echo -e "${RED}Error: OpenAI API key not configured.${NC}"; echo "Run: ./acestep-lyrics-transcription.sh config --set openai.api_key YOUR_KEY"; exit 1; }
[ -z "$api_url" ] && api_url="https://api.openai.com/v1"
[ -z "$model" ] && model="whisper-1"
echo -e " Provider: OpenAI (${model})"
local resp_file=$(mktemp)
# Build curl command
local curl_args=(
-s -w "%{http_code}"
-o "$resp_file"
-X POST "${api_url}/audio/transcriptions"
-H "Authorization: Bearer ${api_key}"
-F "file=@${audio_file}"
-F "model=${model}"
-F "response_format=verbose_json"
-F "timestamp_granularities[]=word"
-F "timestamp_granularities[]=segment"
)
[ -n "$language" ] && curl_args+=(-F "language=${language}")
local http_code
http_code=$(curl "${curl_args[@]}")
if [ "$http_code" != "200" ]; then
local err
err=$(jq -r '.error.message // .detail // "Unknown error"' "$resp_file" 2>/dev/null)
echo -e "${RED}Error: HTTP $http_code - $err${NC}"
rm -f "$resp_file"
return 1
fi
# Extract word-level timestamps into normalized format [{word, start, end}]
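# Fall back to an empty array when .words is missing, so the caller can report "no word-level timestamps"
jq '[(.words // [])[] | {word: .word, start: .start, end: .end}]' "$resp_file" > "$words_file" 2>/dev/null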
# Show transcription text
local text
text=$(jq -r '.text // empty' "$resp_file" 2>/dev/null)
echo -e " ${GREEN}Transcription complete${NC}"
echo ""
echo "$text"
rm -f "$resp_file"
}
# ─── ElevenLabs Scribe ───
transcribe_elevenlabs() {
local audio_file="$1"
local language="$2"
local words_file="$3"
local api_key=$(get_config "elevenlabs.api_key")
local api_url=$(get_config "elevenlabs.api_url")
local model=$(get_config "elevenlabs.model")
[ -z "$api_key" ] && { echo -e "${RED}Error: ElevenLabs API key not configured.${NC}"; echo "Run: ./acestep-lyrics-transcription.sh config --set elevenlabs.api_key YOUR_KEY"; exit 1; }
[ -z "$api_url" ] && api_url="https://api.elevenlabs.io/v1"
[ -z "$model" ] && model="scribe_v2"
echo -e " Provider: ElevenLabs (${model})"
local resp_file=$(mktemp)
local curl_args=(
-s -w "%{http_code}"
-o "$resp_file"
-X POST "${api_url}/speech-to-text"
-H "xi-api-key: ${api_key}"
-F "file=@${audio_file}"
-F "model_id=${model}"
)
[ -n "$language" ] && curl_args+=(-F "language_code=${language}")
local http_code
http_code=$(curl "${curl_args[@]}")
if [ "$http_code" != "200" ]; then
local err
err=$(jq -r '.detail.message // .detail // "Unknown error"' "$resp_file" 2>/dev/null)
echo -e "${RED}Error: HTTP $http_code - $err${NC}"
rm -f "$resp_file"
return 1
fi
# ElevenLabs response: { text, words: [{text, start, end, type}...] }
# Normalize to [{word, start, end}], timestamps already in seconds, filter only "word" type
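# Fall back to an empty array when .words is missing, so the caller can report "no word-level timestamps"
jq '[(.words // [])[] | select(.type == "word") | {word: .text, start: .start, end: .end}]' "$resp_file" > "$words_file" 2>/dev/null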
local text
text=$(jq -r '.text // empty' "$resp_file" 2>/dev/null)
echo -e " ${GREEN}Transcription complete${NC}"
echo ""
echo "$text"
rm -f "$resp_file"
}
# ─── Commands ───
cmd_transcribe() {
check_deps
ensure_config
local audio="" language="" output="" format="" provider=""
while [[ $# -gt 0 ]]; do
case $1 in
--audio|-a) audio="$2"; shift 2 ;;
--language|-l) language="$2"; shift 2 ;;
--output|-o) output="$2"; shift 2 ;;
--format|-f) format="$2"; shift 2 ;;
--provider|-p) provider="$2"; shift 2 ;;
*) [ -z "$audio" ] && audio="$1"; shift ;;
esac
done
[ -z "$audio" ] && { echo -e "${RED}Error: --audio is required${NC}"; echo "Usage: $0 transcribe --audio <file> [options]"; exit 1; }
[ ! -f "$audio" ] && { echo -e "${RED}Error: audio file not found: $audio${NC}"; exit 1; }
# Resolve absolute path
audio="$(cd "$(dirname "$audio")" && pwd)/$(basename "$audio")"
[ -z "$provider" ] && provider=$(get_config "provider")
[ -z "$provider" ] && provider="openai"
[ -z "$format" ] && format=$(get_config "output_format")
[ -z "$format" ] && format="lrc"
# Default output path
if [ -z "$output" ]; then
ensure_output_dir
local basename="$(basename "${audio%.*}")"
output="${OUTPUT_DIR}/${basename}.${format}"
fi
echo "Transcribing..."
echo " Audio: $(basename "$audio")"
echo " Format: $format"
# Transcribe to normalized word timestamps
local words_file=$(mktemp)
case "$provider" in
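# Remove the temp file if the provider call fails (set -e would otherwise exit without cleanup)
openai) transcribe_openai "$audio" "$language" "$words_file" || { rm -f "$words_file"; exit 1; } ;;
elevenlabs) transcribe_elevenlabs "$audio" "$language" "$words_file" || { rm -f "$words_file"; exit 1; } ;;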
*) echo -e "${RED}Error: unknown provider: $provider${NC}"; echo "Supported: openai, elevenlabs"; rm -f "$words_file"; exit 1 ;;
esac
# Check if we got words
local word_count
word_count=$(jq 'length' "$words_file" 2>/dev/null)
if [ -z "$word_count" ] || [ "$word_count" = "0" ]; then
echo -e "${YELLOW}Warning: no word-level timestamps returned${NC}"
rm -f "$words_file"
return 1
fi
echo ""
echo " Words detected: $word_count"
# Convert to output format
mkdir -p "$(dirname "$output")"
case "$format" in
lrc)
words_to_lrc "$words_file" "$output"
;;
srt)
words_to_srt "$words_file" "$output"
;;
json)
cp "$words_file" "$output"
;;
*)
echo -e "${RED}Error: unknown format: $format (supported: lrc, srt, json)${NC}"
rm -f "$words_file"
exit 1
;;
esac
rm -f "$words_file"
echo -e " ${GREEN}Saved: $output${NC}"
echo ""
echo -e "${GREEN}Done!${NC}"
}
cmd_config() {
check_deps
ensure_config
local action="" key="" value=""
while [[ $# -gt 0 ]]; do
case $1 in
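# Clamp the shift so a missing KEY or VALUE reaches the error message below instead of aborting under set -e
--get) action="get"; key="${2:-}"; shift $(( $# < 2 ? $# : 2 )) ;;
--set) action="set"; key="${2:-}"; value="${3:-}"; shift $(( $# < 3 ? $# : 3 )) ;;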
--reset) action="reset"; shift ;;
--list) action="list"; shift ;;
--check-key) action="check-key"; shift ;;
*) shift ;;
esac
done
case "$action" in
"check-key")
local provider=$(get_config "provider")
[ -z "$provider" ] && provider="openai"
local api_key=$(get_config "${provider}.api_key")
echo "provider: $provider"
if [ -n "$api_key" ]; then
echo "api_key: configured"
else
echo "api_key: empty"
fi
;;
"get")
[ -z "$key" ] && { echo -e "${RED}Error: --get requires KEY${NC}"; exit 1; }
local result=$(get_config "$key")
[ -n "$result" ] && echo "$key = $result" || echo "Key not found: $key"
;;
"set")
[ -z "$key" ] || [ -z "$value" ] && { echo -e "${RED}Error: --set requires KEY VALUE${NC}"; exit 1; }
set_config "$key" "$value"
;;
"reset")
echo "$DEFAULT_CONFIG" > "$CONFIG_FILE"
echo -e "${GREEN}Configuration reset to defaults.${NC}"
jq 'walk(if type == "object" and has("api_key") and (.api_key | length) > 0 then .api_key = "***" else . end)' "$CONFIG_FILE"
;;
"list")
echo "Current configuration:"
jq 'walk(if type == "object" and has("api_key") and (.api_key | length) > 0 then .api_key = "***" else . end)' "$CONFIG_FILE"
;;
*)
echo "Config file: $CONFIG_FILE"
echo "----------------------------------------"
jq 'walk(if type == "object" and has("api_key") and (.api_key | length) > 0 then .api_key = "***" else . end)' "$CONFIG_FILE"
echo ""
echo "----------------------------------------"
echo ""
echo "Usage:"
echo " config --list Show config"
echo " config --get <key> Get value"
echo " config --set <key> <val> Set value"
echo " config --reset Reset to defaults"
echo " config --check-key Report whether the active provider's API key is set"
echo ""
echo "Examples:"
echo " config --set provider elevenlabs"
echo " config --set openai.api_key sk-..."
echo " config --set elevenlabs.api_key ..."
echo " config --set output_format srt"
;;
esac
}
show_help() {
echo "Lyrics Transcription CLI"
echo ""
echo "Requirements: curl, jq, python3"
echo ""
echo "Usage: $0 <command> [options]"
echo ""
echo "Commands:"
echo " transcribe Transcribe audio to timestamped lyrics"
echo " config Manage configuration"
echo ""
echo "Transcribe Options:"
echo " -a, --audio Audio file path (required)"
echo " -l, --language Language code (e.g. zh, en, ja)"
echo " -f, --format Output format: lrc, srt, json (default: lrc)"
echo " -p, --provider API provider: openai, elevenlabs"
echo " -o, --output Output file path"
echo ""
echo "Examples:"
echo " $0 transcribe --audio song.mp3"
echo " $0 transcribe --audio song.mp3 --language zh --format lrc"
echo " $0 config --set provider openai"
}
# ─── Main ───
case "$1" in
transcribe) shift; cmd_transcribe "$@" ;;
config) shift; cmd_config "$@" ;;
help|--help|-h) show_help ;;
*) show_help; exit 1 ;;
esac


@@ -0,0 +1,14 @@
{
"provider": "elevenlabs",
"output_format": "lrc",
"openai": {
"api_key": "",
"api_url": "https://api.openai.com/v1",
"model": "whisper-1"
},
"elevenlabs": {
"api_key": "",
"api_url": "https://api.elevenlabs.io/v1",
"model": "scribe_v2"
}
}


@@ -0,0 +1,133 @@
---
name: acestep-simplemv
description: Render music videos from audio files and lyrics using Remotion. Accepts audio + LRC/JSON lyrics + title to produce MP4 videos with waveform visualization and synced lyrics display. Use when users mention MV generation, music video rendering, creating video from audio/lyrics, or visualizing songs.
---
# MV Render
Render music videos with waveform visualization and synced lyrics from audio + lyrics input.
## Prerequisites
- Remotion project at `scripts/` directory within this skill
- Node.js + npm dependencies installed
- ffprobe available (for audio duration detection)
### First-Time Setup
Before first use, check and install dependencies:
```bash
# 1. Check Node.js
node --version
# 2. Install npm dependencies
cd {project_root}/{.claude or .codex}/skills/acestep-simplemv/scripts && npm install
# 3. Check ffprobe
ffprobe -version
```
If ffprobe is not available, install ffmpeg (which includes ffprobe):
- **Windows**: `choco install ffmpeg` or download from https://ffmpeg.org/download.html and add to PATH
- **macOS**: `brew install ffmpeg`
- **Linux**: `sudo apt-get install ffmpeg` (Debian/Ubuntu) or `sudo dnf install ffmpeg` (Fedora)
## Quick Start
```bash
cd {project_root}/{.claude or .codex}/skills/acestep-simplemv/
./scripts/render-mv.sh --audio /path/to/song.mp3 --lyrics /path/to/song.lrc --title "Song Title"
```
Output: MP4 file at `scripts/out/<audio_basename>.mp4`, i.e. the `out/` directory next to `render-mv.sh` (or a custom `--output` path).
## Script Usage
```bash
./scripts/render-mv.sh --audio <file> --lyrics <lrc_file> --title "Title" [options]
Options:
--audio Audio file path (absolute paths supported)
--lyrics LRC format lyrics file (timestamped)
--lyrics-json JSON lyrics file [{start, end, text}] (alternative to --lyrics)
--title Video title (default: "Music Video")
--subtitle Subtitle text
--credit Bottom credit text
--offset Lyric timing offset in seconds (default: -0.5)
--output Output file path (default: out/<audio_basename>.mp4)
--codec h264|h265|vp8|vp9 (default: h264)
--browser Custom browser executable path (Chrome/Edge/Chromium)
Environment variables:
BROWSER_EXECUTABLE Path to browser executable (overrides auto-detection)
```
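The `--lyrics-json` option expects a list of `{start, end, text}` objects, and `--offset` shifts lyric timing. How `render.mjs` parses LRC and applies the offset is not shown here, so the sketch below is only one plausible reading. It assumes each cue ends where the next begins, gives the last cue a hypothetical 5-second duration, adds the offset to both ends, and clamps at zero. The function name `lrc_to_cues` is hypothetical.

```python
import json
import re

LRC_LINE = re.compile(r"^\[(\d{2,}):(\d{2}(?:\.\d+)?)\](.*)$")


def lrc_to_cues(lrc_text: str, offset: float = -0.5, last_duration: float = 5.0) -> list[dict]:
    """Convert LRC text to [{start, end, text}] with a timing offset applied."""
    entries = []
    for line in lrc_text.splitlines():
        m = LRC_LINE.match(line.strip())
        if m and m.group(3).strip():
            start = int(m.group(1)) * 60 + float(m.group(2))
            entries.append((start, m.group(3).strip()))

    cues = []
    for i, (start, text) in enumerate(entries):
        # Assumption: a cue lasts until the next one starts.
        end = entries[i + 1][0] if i + 1 < len(entries) else start + last_duration
        cues.append({
            "start": round(max(0.0, start + offset), 3),
            "end": round(max(0.0, end + offset), 3),
            "text": text,
        })
    return cues


lrc = "[00:46.96]ACE-Step alive,\n[00:50.80]One point five arrives.\n"
print(json.dumps(lrc_to_cues(lrc), ensure_ascii=False, indent=2))
```

With the default offset of -0.5, each line appears half a second before it is sung, which tends to read more naturally on screen.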
## Browser Detection
Remotion requires a Chromium-based browser for rendering. The script auto-detects browsers in this priority order:
1. `BROWSER_EXECUTABLE` environment variable
2. `--browser` CLI argument
3. Remotion cache (`chrome-headless-shell`, downloaded by Remotion)
4. System Chrome (auto-uses `--chrome-mode=chrome-for-testing`)
5. **System Edge** (pre-installed on Windows 10/11, auto-uses `--chrome-mode=chrome-for-testing`)
6. System Chromium (auto-uses `--chrome-mode=chrome-for-testing`)
**Important**: New versions of Chrome/Edge removed the old headless mode. When using regular Chrome/Edge/Chromium, the script automatically sets `--chrome-mode=chrome-for-testing` (which uses `--headless=new`). When using `chrome-headless-shell`, it uses the default `headless-shell` mode (which uses `--headless=old`). This is handled transparently.
If no browser is found, Remotion will attempt to download `chrome-headless-shell` from Google servers. **This will fail if Google servers are inaccessible from your network.**
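The priority order above can be sketched as a small selection function. This is an illustration, not the logic in `render.mjs`. The names `pick_browser` and `chrome_mode_for` are hypothetical, and for an explicitly supplied path the sketch assumes the mode is chosen by checking whether the filename contains `headless-shell`.

```python
import os

# Auto-detection order after the explicit overrides.
PRIORITY = ["chrome-headless-shell", "chrome", "edge", "chromium"]


def chrome_mode_for(path: str) -> str:
    """Assumed rule: headless-shell binaries use the old mode, everything else the new one."""
    name = os.path.basename(path.replace("\\", "/")).lower()
    return "headless-shell" if "headless-shell" in name else "chrome-for-testing"


def pick_browser(env: dict, cli_browser, detected: dict):
    """Return (path, chrome_mode), or (None, None) when nothing is available.

    detected maps a browser kind from PRIORITY to the path found on this machine.
    """
    explicit = env.get("BROWSER_EXECUTABLE") or cli_browser
    if explicit:
        return explicit, chrome_mode_for(explicit)
    for kind in PRIORITY:
        if kind in detected:
            return detected[kind], chrome_mode_for(kind if kind == "chrome-headless-shell" else detected[kind])
    return None, None


# Example paths are placeholders.
print(pick_browser({}, None, {"edge": "C:/Program Files/Edge/msedge.exe", "chromium": "/usr/bin/chromium"}))
```

When the function returns `(None, None)`, the situation matches the download fallback described above.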
### Workarounds for restricted networks
Since **Edge is pre-installed on Windows 10/11**, it should be auto-detected without any manual configuration. The script automatically detects Chrome/Edge and uses the correct headless mode. If auto-detection fails:
```bash
# Option 1: Set environment variable
export BROWSER_EXECUTABLE="/path/to/msedge.exe"
# Option 2: Pass as CLI argument
./scripts/render-mv.sh --audio song.mp3 --lyrics song.lrc --title "Song" --browser "/path/to/msedge.exe"
# Option 3: Enable proxy and let Remotion download chrome-headless-shell
```
## Examples
```bash
# Basic render
./scripts/render-mv.sh --audio /tmp/abc123_1.mp3 --lyrics /tmp/abc123.lrc --title "夜桜"
# Custom output path
./scripts/render-mv.sh --audio song.mp3 --lyrics song.lrc --title "My Song" --output /tmp/my_mv.mp4
# With subtitle and credit
./scripts/render-mv.sh --audio song.mp3 --lyrics song.lrc --title "Song" --subtitle "Artist Name" --credit "Generated by ACE-Step"
```
## File Naming
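**IMPORTANT**: Name the output after the audio file's job ID so that renders do not overwrite each other. Do NOT use arbitrary names like `--output my_song.mp4`.
When `--output` is omitted, the script writes `out/<audio_basename>.mp4` under the scripts directory and keeps any `_1`/`_2` suffix from the audio filename. For ACE-Step outputs, keep the files together and pass the video path explicitly: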
- Audio: `acestep_output/{job_id}_1.mp3`
- Lyrics: `acestep_output/{job_id}_1.lrc`
- Video: Pass `--output acestep_output/{job_id}.mp4` (use the job ID from the audio file)
Example: if audio is `chatcmpl-abc123_1.mp3`, pass `--output acestep_output/chatcmpl-abc123.mp4`
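The naming rule can be expressed as a small helper. This is a sketch only; the function name `video_output_for` is hypothetical, and it assumes the job ID is the audio filename without its extension and without a trailing `_<number>` suffix.

```python
import re
from pathlib import PurePosixPath


def video_output_for(audio_path: str, output_dir: str = "acestep_output") -> str:
    """Derive the --output value from an ACE-Step audio path."""
    stem = PurePosixPath(audio_path).stem  # e.g. chatcmpl-abc123_1
    job_id = re.sub(r"_\d+$", "", stem)    # drop a trailing _1, _2, ...
    return f"{output_dir}/{job_id}.mp4"


print(video_output_for("acestep_output/chatcmpl-abc123_1.mp3"))
# acestep_output/chatcmpl-abc123.mp4
```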
## Title Guidelines
- Keep `--title` short and single-line (max ~50 chars, auto-truncated)
- Use `--subtitle` for additional info
- Do NOT put newlines in `--title`
Good: `--title "Open Source" --subtitle "ACE-Step v1.5"`
Bad: `--title "Open Source - ACE-Step v1.5\nCelebrating Music AI"`
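A title can be cleaned up before it is passed to `--title`. The renderer's own truncation is not shown here, so this sketch is independent of it. It assumes a 50-character limit, treats both real newlines and a literal `\n` sequence as spaces, and uses the hypothetical name `clean_title`.

```python
def clean_title(title: str, max_chars: int = 50) -> str:
    """Collapse a title to one line and cap its length."""
    # Treat a literal backslash-n (as typed in a shell string) like a newline.
    single_line = " ".join(title.replace("\\n", " ").split())
    if len(single_line) <= max_chars:
        return single_line
    return single_line[: max_chars - 3].rstrip() + "..."


print(clean_title("Open Source - ACE-Step v1.5\nCelebrating Music AI"))
# Open Source - ACE-Step v1.5 Celebrating Music AI
```

Even after cleaning, moving the secondary text into `--subtitle` gives a better layout than one long title.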
## Notes
- Audio files with absolute paths are auto-copied to `public/` by render.mjs
- Duration is auto-detected via ffprobe
- Typical render time: ~1-2 minutes for a 90s song
- Output resolution: 1920x1080, 30fps

File diff suppressed because it is too large.


@@ -0,0 +1,27 @@
{
"name": "acestep-video",
"version": "1.0.0",
"description": "",
"main": "index.js",
"scripts": {
"start": "remotion preview",
"build": "remotion render MusicVideo out/video.mp4",
"render": "node render.mjs",
"upgrade": "remotion upgrade"
},
"keywords": [],
"author": "",
"license": "ISC",
"type": "commonjs",
"dependencies": {
"@remotion/cli": "^4.0.417",
"@remotion/media-utils": "^4.0.417",
"react": "^18.3.1",
"react-dom": "^18.3.1",
"remotion": "^4.0.417"
},
"devDependencies": {
"@types/react": "^19.2.13",
"typescript": "^5.9.3"
}
}


@@ -0,0 +1,4 @@
import {Config} from '@remotion/cli/config';
Config.setVideoImageFormat('jpeg');
Config.setOverwriteOutput(true);


@@ -0,0 +1,123 @@
#!/bin/bash
# render-mv.sh - Render a music video from audio + lyrics
#
# Usage:
# ./render-mv.sh --audio <file> --lyrics <lrc_file> --title "Title" [options]
#
# Options:
# --audio Audio file path (absolute or relative)
# --lyrics LRC format lyrics file
# --lyrics-json JSON lyrics file [{start, end, text}]
# --title Video title (default: "Music Video")
# --subtitle Subtitle text
# --credit Bottom credit text
# --offset Lyric timing offset in seconds (default: -0.5)
# --output Output file path (default: out/<audio_basename>.mp4 next to this script)
# --codec h264|h265|vp8|vp9 (default: h264)
# --browser Custom browser executable path (Chrome/Edge/Chromium)
#
# Environment variables:
# BROWSER_EXECUTABLE Path to browser executable (overrides auto-detection)
#
# Examples:
# ./render-mv.sh --audio song.mp3 --lyrics song.lrc --title "My Song"
# ./render-mv.sh --audio /path/to/abc123_1.mp3 --lyrics /path/to/abc123.lrc --title "夜桜"
set -euo pipefail
RENDER_DIR="$(cd "$(dirname "$0")" && pwd)"
# Ensure output directory exists
mkdir -p "${RENDER_DIR}/out"
# Cross-platform realpath alternative (works on macOS/Linux/Windows MSYS2)
resolve_path() {
local dir base
dir="$(cd "$(dirname "$1")" && pwd)"
base="$(basename "$1")"
echo "${dir}/${base}"
}
AUDIO=""
LYRICS=""
LYRICS_JSON=""
TITLE="Music Video"
SUBTITLE=""
CREDIT=""
OFFSET="-0.5"
OUTPUT=""
CODEC="h264"
BROWSER=""
# Parse args
while [[ $# -gt 0 ]]; do
case "$1" in
--audio) AUDIO="$2"; shift 2 ;;
--lyrics) LYRICS="$2"; shift 2 ;;
--lyrics-json) LYRICS_JSON="$2"; shift 2 ;;
--title) TITLE="$2"; shift 2 ;;
--subtitle) SUBTITLE="$2"; shift 2 ;;
--credit) CREDIT="$2"; shift 2 ;;
--offset) OFFSET="$2"; shift 2 ;;
--output) OUTPUT="$2"; shift 2 ;;
--codec) CODEC="$2"; shift 2 ;;
--browser) BROWSER="$2"; shift 2 ;;
-h|--help)
head -20 "$0" | tail -18
exit 0
;;
*)
echo "Error: unknown argument: $1" >&2
exit 1
;;
esac
done
if [[ -z "$AUDIO" ]]; then
echo "Error: --audio is required" >&2
exit 1
fi
if [[ ! -f "$AUDIO" ]]; then
echo "Error: audio file not found: $AUDIO" >&2
exit 1
fi
# Resolve absolute path for audio
AUDIO="$(resolve_path "$AUDIO")"
# Default output: out/<audio_basename>.mp4 next to this script
if [[ -z "$OUTPUT" ]]; then
BASENAME="$(basename "${AUDIO%.*}")"
# The basename is used as-is; a trailing _1, _2, ... from the audio filename is kept
OUTPUT="${RENDER_DIR}/out/${BASENAME}.mp4"
fi
# Ensure output directory exists
mkdir -p "$(dirname "$OUTPUT")"
# Build node args array (safe quoting, no eval)
NODE_ARGS=(render.mjs --audio "$AUDIO" --title "$TITLE" --offset "$OFFSET" --output "$OUTPUT" --codec "$CODEC")
if [[ -n "$LYRICS" ]]; then
LYRICS="$(resolve_path "$LYRICS")"
NODE_ARGS+=(--lyrics "$LYRICS")
elif [[ -n "$LYRICS_JSON" ]]; then
LYRICS_JSON="$(resolve_path "$LYRICS_JSON")"
NODE_ARGS+=(--lyrics-json "$LYRICS_JSON")
fi
[[ -n "$SUBTITLE" ]] && NODE_ARGS+=(--subtitle "$SUBTITLE")
[[ -n "$CREDIT" ]] && NODE_ARGS+=(--credit "$CREDIT")
[[ -n "$BROWSER" ]] && NODE_ARGS+=(--browser "$BROWSER")
echo "Rendering MV..."
echo " Audio: $(basename "$AUDIO")"
echo " Title: $TITLE"
echo " Output: $OUTPUT"
cd "$RENDER_DIR"
node "${NODE_ARGS[@]}"
echo ""
echo "Output: $OUTPUT"


@ -0,0 +1,345 @@
#!/usr/bin/env node
/**
* CLI entry point for rendering music videos.
*
* Usage:
* node render.mjs --audio music.mp3 --lyrics lyrics.lrc --title "Song Name" --output out/video.mp4
* node render.mjs --audio music.mp3 --lyrics-json lyrics.json --title "Song Name"
*
* Options:
* --audio Audio file path (absolute paths auto-copied to public/) or filename in public/
* --lyrics Path to LRC format lyrics file
* --lyrics-json Path to JSON lyrics file [{start, end, text}]
* --title Main title (default: "Music Video")
* --subtitle Subtitle (default: "")
* --credit Bottom credit text (default: "")
* --duration Audio duration in seconds (auto-detected if omitted)
* --offset Lyric timing offset in seconds (default: -0.5)
* --output Output file path (default: out/video.mp4)
* --codec Video codec: h264, h265, vp8, vp9 (default: h264)
*/
import {execSync} from 'child_process';
import {readFileSync, readdirSync, existsSync, copyFileSync, mkdirSync} from 'fs';
import {resolve, basename, isAbsolute, join} from 'path';
import {homedir} from 'os';
/**
* Resolve a file path that may be a MSYS2/Cygwin-style path on Windows.
* Converts paths like /e/foo/bar to E:/foo/bar for Node.js compatibility.
*/
function resolveFilePath(p) {
if (process.platform === 'win32' && /^\/[a-zA-Z]\//.test(p)) {
// Convert MSYS2 path /x/... to X:/...
return p[1].toUpperCase() + ':' + p.slice(2);
}
return resolve(p);
}
/**
* Find a usable browser executable for Remotion rendering.
*
* Search priority:
* 1. Environment variable BROWSER_EXECUTABLE
* 2. CLI argument --browser
* 3. Project-local Remotion cache in node_modules/.remotion (chrome-headless-shell)
* 4. Per-user Remotion cache in the home directory (chrome-headless-shell)
* 5. System Chrome, Edge, or Chromium (requires --chrome-mode=chrome-for-testing)
*
* Returns {path, chromeMode} or {path: null, chromeMode: 'headless-shell'} if not found.
*
* chromeMode:
* - 'headless-shell': for chrome-headless-shell binary (uses --headless=old)
* - 'chrome-for-testing': for regular Chrome/Edge/Chromium (uses --headless=new)
*/
function findBrowserExecutable(cliOverride) {
// 1. Environment variable — highest priority
const envExe = process.env.BROWSER_EXECUTABLE;
if (envExe && existsSync(envExe)) {
const mode = isHeadlessShell(envExe) ? 'headless-shell' : 'chrome-for-testing';
return {path: envExe, chromeMode: mode};
}
// 2. CLI argument
if (cliOverride && existsSync(cliOverride)) {
const mode = isHeadlessShell(cliOverride) ? 'headless-shell' : 'chrome-for-testing';
return {path: cliOverride, chromeMode: mode};
}
const platform = process.platform;
const home = homedir();
// 3. Local node_modules/.remotion (chrome-headless-shell) — uses --headless=old
const localCacheDir = join(process.cwd(), 'node_modules', '.remotion', 'chrome-headless-shell');
if (existsSync(localCacheDir)) {
try {
// Structure: chrome-headless-shell/linux64/chrome-headless-shell-linux64/chrome-headless-shell
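// Intel Macs use the mac-x64 build directory; Apple Silicon uses mac-arm64
const macDir = process.arch === 'arm64' ? 'mac-arm64' : 'mac-x64';
const platformDir = platform === 'win32' ? 'win64' : platform === 'darwin' ? macDir : 'linux64';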
const exeName = platform === 'win32' ? 'chrome-headless-shell.exe' : 'chrome-headless-shell';
const platformPath = join(localCacheDir, platformDir);
if (existsSync(platformPath)) {
const subdirs = readdirSync(platformPath);
for (const subdir of subdirs) {
const exe = join(platformPath, subdir, exeName);
if (existsSync(exe)) return {path: exe, chromeMode: 'headless-shell'};
}
}
} catch {}
}
// 4. User home Remotion cache (chrome-headless-shell) — uses --headless=old
let cacheDir;
if (platform === 'win32') {
cacheDir = join(home, 'AppData', 'Local', 'remotion', 'chrome-headless-shell');
} else if (platform === 'darwin') {
cacheDir = join(home, 'Library', 'Caches', 'remotion', 'chrome-headless-shell');
} else {
cacheDir = join(home, '.cache', 'remotion', 'chrome-headless-shell');
}
if (existsSync(cacheDir)) {
try {
const versions = readdirSync(cacheDir).sort().reverse();
const exeName = platform === 'win32' ? 'chrome-headless-shell.exe' : 'chrome-headless-shell';
for (const ver of versions) {
const exe = join(cacheDir, ver, exeName);
if (existsSync(exe)) return {path: exe, chromeMode: 'headless-shell'};
}
} catch {}
}
// 5. System browsers (Chrome, Edge, Chromium); these require --chrome-mode=chrome-for-testing
const browserPaths = platform === 'win32' ? [
// Chrome
'C:\\Program Files\\Google\\Chrome\\Application\\chrome.exe',
'C:\\Program Files (x86)\\Google\\Chrome\\Application\\chrome.exe',
// Edge (pre-installed on Windows 10/11)
'C:\\Program Files (x86)\\Microsoft\\Edge\\Application\\msedge.exe',
'C:\\Program Files\\Microsoft\\Edge\\Application\\msedge.exe',
] : platform === 'darwin' ? [
'/Applications/Google Chrome.app/Contents/MacOS/Google Chrome',
'/Applications/Microsoft Edge.app/Contents/MacOS/Microsoft Edge',
'/Applications/Chromium.app/Contents/MacOS/Chromium',
] : [
'/usr/bin/google-chrome',
'/usr/bin/google-chrome-stable',
'/usr/bin/chromium',
'/usr/bin/chromium-browser',
'/usr/bin/microsoft-edge',
'/usr/bin/microsoft-edge-stable',
];
for (const p of browserPaths) {
if (existsSync(p)) return {path: p, chromeMode: 'chrome-for-testing'};
}
return {path: null, chromeMode: 'headless-shell'};
}
/**
* Check if the given executable path is a chrome-headless-shell binary.
*/
function isHeadlessShell(exePath) {
const name = exePath.toLowerCase().replace(/\\/g, '/');
return name.includes('chrome-headless-shell');
}
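/**
* Parse LRC lyrics ("[mm:ss.xx] text" per line) into [{start, end, text}].
* Each line ends where the next timestamp begins; the last line lasts 5 seconds.
* Lines without a valid timestamp are skipped.
*/
function parseLrc(content) {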
const lines = content.split(/\r?\n/).filter(l => l.trim());
const parsed = [];
for (const line of lines) {
const match = line.match(/^\[(\d{2}):(\d{2})(?:\.(\d{2,3}))?\]\s*(.*)$/);
if (match) {
const minutes = parseInt(match[1], 10);
const seconds = parseInt(match[2], 10);
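// Fractional seconds: pad "99" to "990" so 2- and 3-digit fractions both read as milliseconds
const fraction = match[3] ? parseInt(match[3].padEnd(3, '0'), 10) / 1000 : 0;
const time = minutes * 60 + seconds + fraction;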
const text = match[4].trim();
parsed.push({time, text});
}
}
const result = [];
for (let i = 0; i < parsed.length; i++) {
const start = parsed[i].time;
const end = i < parsed.length - 1 ? parsed[i + 1].time : start + 5;
if (parsed[i].text) {
result.push({start, end, text: parsed[i].text});
}
}
return result;
}
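/**
* Read the audio duration in seconds with ffprobe (must be on PATH).
* Returns null if ffprobe is missing or fails.
*/
function getAudioDuration(filePath) {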
try {
const result = execSync(
`ffprobe -v error -show_entries format=duration -of csv=p=0 "${filePath}"`,
{encoding: 'utf-8'}
).trim();
return parseFloat(result);
} catch {
return null;
}
}
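/**
* Parse "--name value" pairs from argv into an object.
* A trailing flag with no value after it is ignored.
*/
function parseArgs(argv) {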
const args = {};
for (let i = 2; i < argv.length; i++) {
const key = argv[i];
if (key.startsWith('--') && i + 1 < argv.length) {
const name = key.slice(2);
args[name] = argv[i + 1];
i++;
}
}
return args;
}
const args = parseArgs(process.argv);
// Validate required args
if (!args.audio) {
console.error('Error: --audio is required');
console.error('Usage: node render.mjs --audio music.mp3 --lyrics lyrics.lrc --title "Song"');
process.exit(1);
}
// If audio is an absolute path, copy it into public/ and use the filename
let audioFileName = args.audio;
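const resolvedAudio = resolveFilePath(args.audio);
// resolveFilePath() always returns an absolute path, so test the original argument;
// otherwise a bare filename could never reach the public/ branch below
if (isAbsolute(args.audio)) {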
if (!existsSync(resolvedAudio)) {
console.error(`Error: Audio file not found: ${resolvedAudio}`);
process.exit(1);
}
const pubDir = resolve('public');
mkdirSync(pubDir, {recursive: true});
const fname = basename(resolvedAudio);
const dest = resolve(pubDir, fname);
if (resolve(resolvedAudio) !== dest) {
copyFileSync(resolvedAudio, dest);
console.log(`Copied audio to public/${fname}`);
}
audioFileName = fname;
} else {
// Relative name — must exist in public/
const audioPath = resolve('public', args.audio);
if (!existsSync(audioPath)) {
console.error(`Error: Audio file not found in public/: ${args.audio}`);
process.exit(1);
}
}
// Parse lyrics
let lyrics = [];
if (args.lyrics) {
const lrcPath = resolveFilePath(args.lyrics);
if (!existsSync(lrcPath)) {
console.error(`Error: LRC file not found: ${lrcPath}`);
process.exit(1);
}
lyrics = parseLrc(readFileSync(lrcPath, 'utf-8'));
console.log(`Parsed ${lyrics.length} lyric lines from LRC file`);
} else if (args['lyrics-json']) {
const jsonPath = resolveFilePath(args['lyrics-json']);
if (!existsSync(jsonPath)) {
console.error(`Error: JSON lyrics file not found: ${jsonPath}`);
process.exit(1);
}
lyrics = JSON.parse(readFileSync(jsonPath, 'utf-8'));
console.log(`Loaded ${lyrics.length} lyric lines from JSON file`);
}
// Determine audio duration
let duration = args.duration ? parseFloat(args.duration) : null;
if (!duration) {
const audioPath = resolve('public', audioFileName);
if (existsSync(audioPath)) {
duration = getAudioDuration(audioPath);
if (duration) {
console.log(`Auto-detected audio duration: ${duration.toFixed(2)}s`);
}
}
}
if (!duration) {
console.error('Error: Could not detect audio duration. Please provide --duration');
process.exit(1);
}
// Build input props
// Sanitize title: single-line, max 50 chars
const rawTitle = (args.title || 'Music Video').replace(/[\r\n]+/g, ' ').trim();
const title = rawTitle.length > 50 ? rawTitle.slice(0, 47) + '...' : rawTitle;
const inputProps = {
audioFileName: audioFileName,
lyrics,
title,
subtitle: (args.subtitle || '').replace(/[\r\n]+/g, ' ').trim(),
creditText: args.credit || '',
durationInSeconds: duration,
lyricOffset: args.offset ? parseFloat(args.offset) : -0.5,
};
const output = args.output ? resolveFilePath(args.output) : 'out/video.mp4';
const codec = args.codec || 'h264';
// Write props to temp file to avoid shell escaping issues
const propsFile = resolve('.render-props.json');
const {writeFileSync} = await import('fs');
writeFileSync(propsFile, JSON.stringify(inputProps));
// Find browser executable to avoid re-downloading
const {path: browserExe, chromeMode} = findBrowserExecutable(args.browser);
if (!browserExe) {
console.warn('⚠️ No browser found. Remotion will attempt to download chrome-headless-shell from Google servers.');
console.warn(' If download fails (e.g. Google servers inaccessible), try one of these:');
console.warn(' 1. Set environment variable: BROWSER_EXECUTABLE=/path/to/chrome-or-edge');
console.warn(' 2. Pass CLI argument: --browser /path/to/chrome-or-edge');
console.warn(' 3. Enable proxy and retry');
console.warn('');
}
const cmd = [
'npx remotion render',
'MusicVideo',
`"${output}"`,
`--props="${propsFile}"`,
`--codec=${codec}`,
'--log=error',
browserExe ? `--browser-executable="${browserExe}"` : '',
chromeMode !== 'headless-shell' ? `--chrome-mode=${chromeMode}` : '',
].filter(Boolean).join(' ');
console.log(`\nRendering video...`);
console.log(` Audio: ${args.audio}`);
console.log(` Title: ${inputProps.title}`);
console.log(` Duration: ${duration.toFixed(1)}s`);
console.log(` Lyrics: ${lyrics.length} lines`);
console.log(` Output: ${output}`);
console.log(` Codec: ${codec}`);
if (browserExe) console.log(` Browser: ${browserExe}`);
if (chromeMode !== 'headless-shell') console.log(` Chrome mode: ${chromeMode}`);
console.log('');
try {
const result = execSync(cmd, {encoding: 'utf-8', stdio: ['pipe', 'pipe', 'pipe']});
// Only show the final output file line (starts with '+') and size info
const outputLines = result.split(/\r?\n/).filter(l => l.includes(output) || /^\+/.test(l.replace(/\x1b\[[0-9;]*m/g, '').trim()));
if (outputLines.length) console.log(outputLines.join('\n'));
console.log(`\n✅ Video rendered successfully: ${output}`);
} catch (e) {
// Show stderr on failure for debugging
if (e.stderr) console.error(e.stderr.toString());
console.error('\n❌ Render failed');
process.exit(1);
} finally {
// Clean up temp props file
try {
const {unlinkSync} = await import('fs');
unlinkSync(propsFile);
} catch {}
}


@ -0,0 +1,12 @@
#!/bin/bash
# render.sh - Convenience wrapper for rendering music videos
#
# Usage:
# ./render.sh --audio music.mp3 --lyrics lyrics.lrc --title "Song Name"
# ./render.sh --audio music.mp3 --lyrics-json lyrics.json --title "Song" --output out/mv.mp4
#
# All options are passed through to render.mjs. See render.mjs for full options list.
set -e
cd "$(dirname "$0")"
node render.mjs "$@"


@ -0,0 +1,314 @@
import React from 'react';
import {
AbsoluteFill,
Audio,
useCurrentFrame,
useVideoConfig,
interpolate,
Easing,
staticFile,
} from 'remotion';
import {useAudioData, visualizeAudio} from '@remotion/media-utils';
import {MVInputProps} from './types';
export const AudioVisualization: React.FC<MVInputProps> = ({
audioFileName,
lyrics,
title,
subtitle,
creditText,
lyricOffset,
}) => {
const frame = useCurrentFrame();
const {fps, durationInFrames} = useVideoConfig();
const audioSrc = audioFileName.startsWith('http')
? audioFileName
: staticFile(audioFileName);
const audioData = useAudioData(audioSrc);
if (!audioData) {
return null;
}
const visualization = visualizeAudio({
fps,
frame,
audioData,
numberOfSamples: 128,
optimizeFor: 'speed',
});
const currentTime = frame / fps + lyricOffset;
const currentLyric = lyrics.find(
(lyric) => currentTime >= lyric.start && currentTime < lyric.end
);
const lyricProgress = currentLyric
? interpolate(
currentTime,
[currentLyric.start, currentLyric.start + 0.3],
[0, 1],
{extrapolateRight: 'clamp'}
)
: 0;
const titleOpacity = interpolate(frame, [0, 30], [0, 1], {
extrapolateRight: 'clamp',
});
const titleY = interpolate(frame, [0, 30], [-50, 0], {
extrapolateRight: 'clamp',
easing: Easing.out(Easing.ease),
});
const hue = interpolate(frame, [0, durationInFrames], [200, 320], {
extrapolateRight: 'wrap',
});
const avgAmplitude =
visualization.reduce((sum, val) => sum + val, 0) / visualization.length;
return (
<AbsoluteFill>
{/* Animated gradient background */}
<AbsoluteFill
style={{
background: `linear-gradient(135deg, hsl(${hue}, 80%, 12%) 0%, hsl(${hue + 80}, 70%, 8%) 100%)`,
}}
/>
{/* Radial glow effect */}
<AbsoluteFill
style={{
background: `radial-gradient(circle at 50% 50%, hsla(${hue}, 100%, 50%, ${avgAmplitude * 0.3}) 0%, transparent 50%)`,
}}
/>
{/* Audio source */}
<Audio src={audioSrc} />
{/* Bottom frequency bars */}
<AbsoluteFill
style={{
justifyContent: 'flex-end',
alignItems: 'center',
}}
>
<div
style={{
display: 'flex',
alignItems: 'flex-end',
justifyContent: 'center',
gap: 4,
height: 350,
width: '90%',
marginBottom: 180,
}}
>
{visualization.map((value, index) => {
const scaledValue = Math.pow(value, 0.6);
const barHeight = Math.max(scaledValue * 800, 20);
const colorIndex = (index / visualization.length) * 360;
return (
<div
key={index}
style={{
width: `${100 / visualization.length}%`,
height: barHeight,
background: `linear-gradient(to top,
hsl(${(colorIndex + hue) % 360}, 90%, 60%),
hsl(${(colorIndex + hue + 40) % 360}, 90%, 70%))`,
borderRadius: '4px 4px 0 0',
boxShadow: `0 0 ${10 + scaledValue * 30}px hsla(${(colorIndex + hue) % 360}, 100%, 60%, ${scaledValue})`,
transition: 'height 0.05s ease-out',
}}
/>
);
})}
</div>
</AbsoluteFill>
{/* Symmetrical side bars */}
<AbsoluteFill
style={{
justifyContent: 'center',
alignItems: 'center',
}}
>
{/* Left bars */}
<div
style={{
position: 'absolute',
left: 40,
display: 'flex',
flexDirection: 'column',
gap: 8,
height: '80%',
justifyContent: 'space-around',
}}
>
{visualization.slice(0, 20).map((value, index) => {
const scaledValue = Math.pow(value, 0.6);
const barWidth = Math.max(scaledValue * 300, 10);
const colorIndex = (index / 20) * 360;
return (
<div
key={index}
style={{
width: barWidth,
height: 12,
background: `linear-gradient(to right,
hsl(${(colorIndex + hue) % 360}, 90%, 60%),
hsl(${(colorIndex + hue + 40) % 360}, 90%, 70%))`,
borderRadius: '0 6px 6px 0',
boxShadow: `0 0 ${10 + scaledValue * 20}px hsla(${(colorIndex + hue) % 360}, 100%, 60%, ${scaledValue})`,
}}
/>
);
})}
</div>
{/* Right bars */}
<div
style={{
position: 'absolute',
right: 40,
display: 'flex',
flexDirection: 'column',
gap: 8,
height: '80%',
justifyContent: 'space-around',
alignItems: 'flex-end',
}}
>
{visualization.slice(0, 20).map((value, index) => {
const scaledValue = Math.pow(value, 0.6);
const barWidth = Math.max(scaledValue * 300, 10);
const colorIndex = (index / 20) * 360;
return (
<div
key={index}
style={{
width: barWidth,
height: 12,
background: `linear-gradient(to left,
hsl(${(colorIndex + hue + 180) % 360}, 90%, 60%),
hsl(${(colorIndex + hue + 220) % 360}, 90%, 70%))`,
borderRadius: '6px 0 0 6px',
boxShadow: `0 0 ${10 + scaledValue * 20}px hsla(${(colorIndex + hue + 180) % 360}, 100%, 60%, ${scaledValue})`,
}}
/>
);
})}
</div>
</AbsoluteFill>
{/* Center title area */}
<AbsoluteFill
style={{
justifyContent: 'flex-start',
alignItems: 'center',
paddingTop: 60,
}}
>
<div
style={{
textAlign: 'center',
transform: `scale(${1 + avgAmplitude * 0.1})`,
transition: 'transform 0.1s ease-out',
}}
>
<div
style={{
fontSize: 96,
fontWeight: 'bold',
color: 'white',
opacity: titleOpacity,
transform: `translateY(${titleY}px)`,
textShadow: `0 0 40px hsla(${hue}, 100%, 70%, 0.8), 0 4px 20px rgba(0,0,0,0.5)`,
fontFamily: '"Noto Sans CJK JP", "Noto Sans CJK SC", Arial, sans-serif',
marginBottom: 10,
}}
>
{title}
</div>
<div
style={{
fontSize: 56,
fontWeight: '600',
color: 'rgba(255,255,255,0.95)',
opacity: titleOpacity,
transform: `translateY(${titleY}px)`,
textShadow: `0 0 30px hsla(${hue + 60}, 100%, 70%, 0.6), 0 2px 10px rgba(0,0,0,0.5)`,
fontFamily: '"Noto Sans CJK JP", "Noto Sans CJK SC", Arial, sans-serif',
letterSpacing: '4px',
}}
>
{subtitle}
</div>
</div>
</AbsoluteFill>
{/* Lyrics display */}
{currentLyric && currentLyric.text && (
<AbsoluteFill
style={{
justifyContent: 'center',
alignItems: 'center',
paddingTop: 100,
}}
>
<div
style={{
fontSize: 48,
fontWeight: '600',
color: 'white',
textAlign: 'center',
maxWidth: '85%',
opacity: lyricProgress,
transform: `translateY(${(1 - lyricProgress) * 30}px)`,
textShadow: `0 0 40px hsla(${hue}, 100%, 70%, 0.8), 0 4px 30px rgba(0,0,0,0.9)`,
fontFamily: '"Noto Sans CJK JP", "Noto Sans CJK SC", Arial, sans-serif',
lineHeight: 1.5,
padding: '25px 50px',
background: `linear-gradient(135deg, rgba(0,0,0,0.4), rgba(0,0,0,0.2))`,
backdropFilter: 'blur(15px)',
borderRadius: '20px',
border: `2px solid hsla(${hue}, 80%, 60%, 0.3)`,
boxShadow: `0 8px 32px rgba(0,0,0,0.5), inset 0 0 40px hsla(${hue}, 100%, 50%, 0.1)`,
}}
>
{currentLyric.text}
</div>
</AbsoluteFill>
)}
{/* Bottom credit text */}
<AbsoluteFill
style={{
justifyContent: 'flex-end',
alignItems: 'center',
padding: 50,
}}
>
<div
style={{
fontSize: 32,
fontWeight: '500',
color: 'white',
opacity: 0.8,
textAlign: 'center',
textShadow: `0 0 20px hsla(${hue}, 100%, 70%, 0.6), 0 2px 10px rgba(0,0,0,0.7)`,
fontFamily: '"Noto Sans CJK JP", "Noto Sans CJK SC", Arial, sans-serif',
}}
>
{creditText}
</div>
</AbsoluteFill>
</AbsoluteFill>
);
};


@ -0,0 +1,31 @@
import React from 'react';
import {Composition, CalculateMetadataFunction} from 'remotion';
import {AudioVisualization} from './AudioVisualization';
import {MVInputProps, defaultProps} from './types';
const calculateMetadata: CalculateMetadataFunction<MVInputProps> = ({props}) => {
const fps = 30;
const durationInFrames = Math.ceil(props.durationInSeconds * fps);
return {
durationInFrames,
fps,
width: 1920,
height: 1080,
};
};
export const RemotionRoot: React.FC = () => {
return (
<>
<Composition
id="MusicVideo"
component={AudioVisualization}
fps={30}
width={1920}
height={1080}
defaultProps={defaultProps}
calculateMetadata={calculateMetadata}
/>
</>
);
};


@ -0,0 +1,4 @@
import {registerRoot} from 'remotion';
import {RemotionRoot} from './Root';
registerRoot(RemotionRoot);


@ -0,0 +1,40 @@
import {LyricLine} from './types';
/**
* Parse LRC format lyrics into LyricLine array.
* LRC format: [mm:ss.xx] lyrics text
*
* Example:
* [00:02.99] Version one point five is here today
* [00:07.00] ACE-Step's rising, leading the way
*/
export function parseLrc(lrcContent: string): LyricLine[] {
// Split on LF or CRLF; a trailing \r would stop the timestamp pattern below from matching
const lines = lrcContent.split(/\r?\n/).filter((line) => line.trim());
const parsed: {time: number; text: string}[] = [];
for (const line of lines) {
// Match [mm:ss.xx] or [mm:ss] format
const match = line.match(/^\[(\d{2}):(\d{2})(?:\.(\d{2,3}))?\]\s*(.*)$/);
if (match) {
const minutes = parseInt(match[1], 10);
const seconds = parseInt(match[2], 10);
// Fractional seconds: pad "99" to "990" so 2- and 3-digit fractions both read as milliseconds
const fraction = match[3] ? parseInt(match[3].padEnd(3, '0'), 10) / 1000 : 0;
const time = minutes * 60 + seconds + fraction;
const text = match[4].trim();
parsed.push({time, text});
}
}
// Convert to LyricLine with start/end
const result: LyricLine[] = [];
for (let i = 0; i < parsed.length; i++) {
const start = parsed[i].time;
const end = i < parsed.length - 1 ? parsed[i + 1].time : start + 5;
const text = parsed[i].text;
if (text) {
result.push({start, end, text});
}
}
return result;
}
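
The timestamp arithmetic in `parseLrc` can be illustrated in isolation. The sketch below is a standalone JavaScript example, not part of the committed file; the function name `lrcTimestampToSeconds` is hypothetical and it handles only the bracketed tag, not the lyric text that follows it.

```javascript
// Standalone sketch: mirrors the timestamp arithmetic only and is not a
// replacement for parseLrc. The function name is hypothetical.
function lrcTimestampToSeconds(tag) {
  const match = tag.match(/^\[(\d{2}):(\d{2})(?:\.(\d{2,3}))?\]$/);
  if (!match) return null;
  const minutes = parseInt(match[1], 10);
  const seconds = parseInt(match[2], 10);
  // Pad "25" to "250" so two- and three-digit fractions are both milliseconds.
  const fraction = match[3] ? parseInt(match[3].padEnd(3, '0'), 10) / 1000 : 0;
  return minutes * 60 + seconds + fraction;
}

console.log(lrcTimestampToSeconds('[00:07.00]'));  // 7
console.log(lrcTimestampToSeconds('[01:07.25]'));  // 67.25
console.log(lrcTimestampToSeconds('[02:30.500]')); // 150.5
console.log(lrcTimestampToSeconds('not a tag'));   // null
```

Padding the fraction to three digits is what lets `.25` and `.250` produce the same value.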


@ -0,0 +1,32 @@
export interface LyricLine {
start: number;
end: number;
text: string;
}
export interface MVInputProps extends Record<string, unknown> {
/** Path to audio file (relative to public/ or absolute URL) */
audioFileName: string;
/** Lyrics as JSON array [{start, end, text}] */
lyrics: LyricLine[];
/** Main title displayed at top */
title: string;
/** Subtitle displayed below title */
subtitle: string;
/** Bottom credit text */
creditText: string;
/** Audio duration in seconds (used to calculate total frames) */
durationInSeconds: number;
/** Lyric timing offset in seconds, added to playback time before lookup (positive = lyrics appear earlier, negative = later) */
lyricOffset: number;
}
export const defaultProps: MVInputProps = {
audioFileName: 'celebration.mp3',
lyrics: [],
title: 'ACE-Step',
subtitle: 'v1.5',
creditText: 'Powered by Claude Code + ACE-Step',
durationInSeconds: 150,
lyricOffset: -0.5,
};
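
How `lyricOffset` shifts lyric display can be checked with a small standalone sketch. It applies the same lookup rule as `AudioVisualization` (a line is active while `start <= playbackTime + lyricOffset < end`); the helper name `activeLyric` and the sample lyric array are made up for illustration.

```javascript
// Hypothetical helper reproducing the lookup rule used by the component:
// a line is active while start <= playbackSeconds + lyricOffset < end.
function activeLyric(lyrics, playbackSeconds, lyricOffset) {
  const lyricTime = playbackSeconds + lyricOffset;
  return lyrics.find((line) => lyricTime >= line.start && lyricTime < line.end) ?? null;
}

// Sample data invented for this example.
const sampleLyrics = [
  {start: 10, end: 14, text: 'first line'},
  {start: 14, end: 18, text: 'second line'},
];

console.log(activeLyric(sampleLyrics, 10, 0)?.text);       // first line
console.log(activeLyric(sampleLyrics, 10, -0.5));          // null
console.log(activeLyric(sampleLyrics, 10.5, -0.5)?.text);  // first line
console.log(activeLyric(sampleLyrics, 9.5, 0.5)?.text);    // first line
```

With the default offset of -0.5, a line stamped at 10 s first appears at 10.5 s of playback; a positive offset shows it before its timestamp.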


@ -0,0 +1,18 @@
{
"compilerOptions": {
"target": "ES2022",
"module": "ES2022",
"moduleResolution": "Bundler",
"lib": ["DOM", "ES2022"],
"jsx": "react-jsx",
"skipLibCheck": true,
"strict": true,
"esModuleInterop": true,
"allowSyntheticDefaultImports": true,
"forceConsistentCasingInFileNames": true,
"resolveJsonModule": true,
"isolatedModules": true,
"noEmit": true
},
"include": ["src/**/*"]
}


@ -0,0 +1,194 @@
---
name: acestep-songwriting
description: Music songwriting guide for ACE-Step. Provides professional knowledge on writing captions, lyrics, choosing BPM/key/duration, and structuring songs. Use this skill when users want to create, write, or plan a song before generating it with ACE-Step.
allowed-tools: Read
---
# ACE-Step Songwriting Guide
Professional music creation knowledge for writing captions, lyrics, and choosing music parameters for ACE-Step.
## Output Format
After using this guide, produce three things for the acestep skill:
1. **Caption** (`-c`): Style/genre/instruments/emotion description
2. **Lyrics** (`-l`): Complete structured lyrics with tags
3. **Parameters**: `--duration`, `--bpm`, `--key`, `--time-signature`, `--language`
---
## Caption: The Most Important Input
**Caption is the most important factor affecting generated music.**
It supports multiple formats: simple style words, comma-separated tags, and complex natural-language descriptions.
### Common Dimensions
| Dimension | Examples |
|-----------|----------|
| **Style/Genre** | pop, rock, jazz, electronic, hip-hop, R&B, folk, classical, lo-fi, synthwave |
| **Emotion/Atmosphere** | melancholic, uplifting, energetic, dreamy, dark, nostalgic, euphoric, intimate |
| **Instruments** | acoustic guitar, piano, synth pads, 808 drums, strings, brass, electric bass |
| **Timbre Texture** | warm, bright, crisp, muddy, airy, punchy, lush, raw, polished |
| **Era Reference** | 80s synth-pop, 90s grunge, 2010s EDM, vintage soul, modern trap |
| **Production Style** | lo-fi, high-fidelity, live recording, studio-polished, bedroom pop |
| **Vocal Characteristics** | female vocal, male vocal, breathy, powerful, falsetto, raspy, choir |
| **Speed/Rhythm** | slow tempo, mid-tempo, fast-paced, groovy, driving, laid-back |
| **Structure Hints** | building intro, catchy chorus, dramatic bridge, fade-out ending |
### Caption Writing Principles
1. **Specific beats vague** — "sad piano ballad with female breathy vocal" > "a sad song"
2. **Combine multiple dimensions** — style+emotion+instruments+timbre anchors direction precisely
3. **Use references well** — "in the style of 80s synthwave" conveys complex aesthetic quickly
4. **Texture words are useful** — warm, crisp, airy, punchy influence mixing and timbre
5. **Don't pursue perfection** — Caption is a starting point, iterate based on results
6. **Granularity determines freedom** — Less detail = more model creativity; more detail = more control
7. **Avoid conflicting words** — "classical strings" + "hardcore metal" degrades output
- **Fix: Repetition reinforcement** — Repeat the elements you want more
- **Fix: Conflict to evolution** — "Start with soft strings, middle becomes metal rock, end turns to hip-hop"
8. **Don't put BPM/key/tempo in Caption** — Use dedicated parameters instead
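
As a rough illustration of principles 2 and 8, the sketch below joins several dimensions into a comma-separated caption and flags tempo or key wording that belongs in the dedicated parameters. The function name `buildCaption`, its field names, and the regular expression are assumptions made for this example; none of them are part of ACE-Step.

```javascript
// Hypothetical helper: not part of ACE-Step. Field names are illustrative.
function buildCaption({style, emotion, instruments = [], timbre = [], vocal}) {
  // Principle 2: combine multiple dimensions into one comma-separated caption.
  const parts = [style, emotion, ...instruments, ...timbre, vocal].filter(Boolean);
  const caption = parts.join(', ');

  // Principle 8: tempo and key belong in dedicated parameters, not the caption.
  // This pattern is a rough heuristic and will miss many phrasings.
  const metadataPattern = /\b\d{2,3}\s*bpm\b|\b[A-G][#b]?\s+(major|minor)\b/i;
  const warnings = metadataPattern.test(caption)
    ? ['Caption mentions BPM or key; use the dedicated parameters instead.']
    : [];

  return {caption, warnings};
}

const good = buildCaption({
  style: 'piano ballad',
  emotion: 'melancholic',
  instruments: ['piano', 'strings'],
  timbre: ['warm'],
  vocal: 'female breathy vocal',
});
console.log(good.caption);
// piano ballad, melancholic, piano, strings, warm, female breathy vocal

const flagged = buildCaption({style: 'synthwave, 120 bpm', emotion: 'nostalgic'});
console.log(flagged.warnings.length); // 1
```

The check is deliberately loose: it is meant to catch obvious slips such as "120 bpm", not to validate a caption.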
---
## Lyrics: The Temporal Script
Lyrics controls how music unfolds over time. It carries:
- Lyric text itself
- **Structure tags** ([Verse], [Chorus], [Bridge]...)
- **Vocal style hints** ([raspy vocal], [whispered]...)
- **Instrumental sections** ([guitar solo], [drum break]...)
- **Energy changes** ([building energy], [explosive drop]...)
### Structure Tags
| Category | Tag | Description |
|----------|-----|-------------|
| **Basic Structure** | `[Intro]` | Opening, establish atmosphere |
| | `[Verse]` / `[Verse 1]` | Verse, narrative progression |
| | `[Pre-Chorus]` | Pre-chorus, build energy |
| | `[Chorus]` | Chorus, emotional climax |
| | `[Bridge]` | Bridge, transition or elevation |
| | `[Outro]` | Ending, conclusion |
| **Dynamic Sections** | `[Build]` | Energy gradually rising |
| | `[Drop]` | Electronic music energy release |
| | `[Breakdown]` | Reduced instrumentation, space |
| **Instrumental** | `[Instrumental]` | Pure instrumental, no vocals |
| | `[Guitar Solo]` | Guitar solo |
| | `[Piano Interlude]` | Piano interlude |
| **Special** | `[Fade Out]` | Fade out ending |
| | `[Silence]` | Silence |
### Combining Tags
Use `-` for finer control, but keep it concise:
```
✅ [Chorus - anthemic]
❌ [Chorus - anthemic - stacked harmonies - high energy - powerful - epic]
```
Put complex style descriptions in Caption, not in tags.
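
The "keep it concise" advice can be checked mechanically. The following sketch scans lyrics for bracketed tags and reports those carrying more than one ` - ` modifier. The helper name `findOverloadedTags` and the one-modifier limit are assumptions chosen for this example, not rules enforced by ACE-Step.

```javascript
// Hypothetical lint helper. The default limit of one modifier is an assumption
// chosen to match the "keep it concise" advice, not an ACE-Step rule.
function findOverloadedTags(lyrics, maxModifiers = 1) {
  const tagPattern = /\[([^\]]+)\]/g;
  const overloaded = [];
  for (const match of lyrics.matchAll(tagPattern)) {
    // "[Chorus - anthemic]" -> section "Chorus", modifiers ["anthemic"].
    // Splitting on " - " (with spaces) leaves "[Pre-Chorus]" intact.
    const [, ...modifiers] = match[1].split(' - ').map((part) => part.trim());
    if (modifiers.length > maxModifiers) {
      overloaded.push({tag: match[0], modifiers});
    }
  }
  return overloaded;
}

const sampleLyricsText = [
  '[Verse 1]',
  'Walking through the quiet town',
  '[Pre-Chorus]',
  '[Chorus - anthemic]',
  'We rise together (together)',
  '[Chorus - anthemic - stacked harmonies - high energy]',
].join('\n');

const overloadedTags = findOverloadedTags(sampleLyricsText);
console.log(overloadedTags.map((entry) => entry.tag));
// [ '[Chorus - anthemic - stacked harmonies - high energy]' ]
```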
### Caption-Lyrics Consistency
**Models are not good at resolving conflicts.** Checklist:
- Instruments in Caption ↔ Instrumental section tags in Lyrics
- Emotion in Caption ↔ Energy tags in Lyrics
- Vocal description in Caption ↔ Vocal control tags in Lyrics
### Vocal Control Tags
| Tag | Effect |
|-----|--------|
| `[raspy vocal]` | Raspy, textured vocals |
| `[whispered]` | Whispered |
| `[falsetto]` | Falsetto |
| `[powerful belting]` | Powerful, high-pitched singing |
| `[spoken word]` | Rap/recitation |
| `[harmonies]` | Layered harmonies |
| `[call and response]` | Call and response |
| `[ad-lib]` | Improvised embellishments |
### Energy and Emotion Tags
| Tag | Effect |
|-----|--------|
| `[high energy]` | High energy, passionate |
| `[low energy]` | Low energy, restrained |
| `[building energy]` | Increasing energy |
| `[explosive]` | Explosive energy |
| `[melancholic]` | Melancholic |
| `[euphoric]` | Euphoric |
| `[dreamy]` | Dreamy |
| `[aggressive]` | Aggressive |
### Lyric Writing Tips
1. **6-10 syllables per line** — Model aligns syllables to beats; keep similar counts for lines in same position (±1-2)
2. **Uppercase = stronger intensity**: `WE ARE THE CHAMPIONS!` (shouting) vs `walking through the streets` (normal)
3. **Parentheses = background vocals**: `We rise together (together)`
4. **Extend vowels**: `Feeeling so aliiive` (use cautiously; the effect is unstable)
5. **Clear section separation** — Blank lines between sections
### Avoiding "AI-flavored" Lyrics
| Red Flag | Description |
|----------|-------------|
| **Adjective stacking** | "neon skies, electric hearts, endless dreams" — vague imagery filler |
| **Rhyme chaos** | Inconsistent patterns or forced rhymes breaking meaning |
| **Blurred boundaries** | Lyric content crosses structure tags |
| **No breathing room** | Lines too long to sing in one breath |
| **Mixed metaphors** | Water → fire → flying — listeners can't anchor |
**Metaphor discipline**: One core metaphor per song, explore its multiple aspects.
---
## Music Metadata
**Most of the time, let LM auto-infer.** Only set manually when you have clear requirements.
| Parameter | Range | Description |
|-----------|-------|-------------|
| `bpm` | 30-300 | Slow 60-80, mid 90-120, fast 130-180 |
| `keyscale` | Key | e.g. `C Major`, `Am`. Common keys (C, G, D, Am, Em) most stable |
| `timesignature` | Time sig | `4/4` (most common), `3/4` (waltz), `6/8` (swing) |
| `vocal_language` | Language | Usually auto-detected from lyrics |
| `duration` | Seconds | See duration calculation below |
### When to Set Manually
| Scenario | Set |
|----------|-----|
| Daily generation | Let LM auto-infer |
| Clear tempo requirement | `bpm` |
| Specific style (waltz) | `timesignature=3/4` |
| Match other material | `bpm` + `duration` |
| Specific key color | `keyscale` |
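
One way to apply "only set manually when you have clear requirements" is to pass along just the fields that were given and sanity-check them first. The sketch below assumes a BPM range of 30 to 300; the function name `checkMetadata`, the object shape, and the messages are hypothetical and do not come from the ACE-Step scripts.

```javascript
// Hypothetical helper: the function name, object shape, and messages are
// illustrative. The 30-300 BPM range is the one assumed in the lead-in.
function checkMetadata({bpm, timesignature, keyscale, duration} = {}) {
  const problems = [];
  if (bpm !== undefined && (!Number.isFinite(bpm) || bpm < 30 || bpm > 300)) {
    problems.push(`bpm ${bpm} is outside 30-300`);
  }
  if (timesignature !== undefined && !/^\d+\/\d+$/.test(timesignature)) {
    problems.push(`timesignature "${timesignature}" should look like 4/4`);
  }
  if (duration !== undefined && !(duration > 0)) {
    problems.push('duration must be a positive number of seconds');
  }
  // Keep only the fields that were set, so everything else stays auto-inferred.
  const explicit = Object.fromEntries(
    Object.entries({bpm, timesignature, keyscale, duration}).filter(
      ([, value]) => value !== undefined
    )
  );
  return {explicit, problems};
}

const waltz = checkMetadata({timesignature: '3/4'});
console.log(waltz.explicit); // { timesignature: '3/4' }

const tooFast = checkMetadata({bpm: 420});
console.log(tooFast.problems); // [ 'bpm 420 is outside 30-300' ]
```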
---
## Duration Calculation
### Estimation Method
- **Intro/Outro**: 5-10 seconds each
- **Instrumental sections**: 5-15 seconds each
- **Typical structures**:
- 2 verses + 2 choruses: 120-150s minimum
- 2 verses + 2 choruses + bridge: 180-240s minimum
- Full song with intro/outro: 210-270s (3.5-4.5 min)
### BPM and Duration Relationship
- **Slower BPM (60-80)**: Need MORE duration for same lyrics
- **Medium BPM (100-130)**: Standard duration
- **Faster BPM (150-180)**: Can fit more lyrics, but still need breathing room
**Rule of thumb**: When in doubt, estimate longer. A song too short feels rushed.
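
The estimation method can be turned into a back-of-the-envelope calculation. The sketch below makes several assumptions that are not part of this guide: 4/4 time, two bars per lyric line, and the midpoints of the ranges above for intro, outro, and instrumental sections. Treat the result as a starting value for `--duration`, not as what ACE-Step computes.

```javascript
// Rough estimate only. ASSUMPTIONS (not from ACE-Step): 4/4 time, two bars per
// lyric line, and range midpoints for the non-vocal sections.
function estimateDuration({bpm, lyricLines, instrumentalSections = 0, intro = true, outro = true}) {
  const secondsPerBar = (4 * 60) / bpm; // four beats per bar
  const vocalSeconds = lyricLines * 2 * secondsPerBar;
  const introOutroSeconds = (intro ? 7.5 : 0) + (outro ? 7.5 : 0); // midpoint of 5-10s
  const instrumentalSeconds = instrumentalSections * 10; // midpoint of 5-15s
  const total = vocalSeconds + introOutroSeconds + instrumentalSeconds;
  // Round up to the next 5 seconds: when in doubt, estimate longer.
  return Math.ceil(total / 5) * 5;
}

// 2 verses + 2 choruses at 8 lines each = 32 lyric lines
console.log(estimateDuration({bpm: 120, lyricLines: 32})); // 145
console.log(estimateDuration({bpm: 70, lyricLines: 32}));  // 235
```

The same 32 lines need about 90 seconds more at 70 BPM than at 120 BPM, which is why slower songs need more duration for the same lyrics.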
---
Note: Lyrics tags (piano, powerful, whispered) are consistent with Caption (piano ballad, building to powerful chorus, intimate).


@ -6,154 +6,74 @@ allowed-tools: Read, Write, Bash, Skill
# ACE-Step Music Generation Skill
Use ACE-Step V1.5 API for music generation. Script: `scripts/acestep.sh` (requires curl + jq).
Use ACE-Step V1.5 API for music generation. **Always use `scripts/acestep.sh` script** — do NOT call API endpoints directly.
## Prerequisites - ACE-Step API Service
**IMPORTANT**: This skill requires the ACE-Step API server to be running.
### Required Dependencies
The `scripts/acestep.sh` script requires the following tools:
**1. curl** - For making HTTP requests to the API
**2. jq** - For parsing JSON responses
#### Check Dependencies
Before using this skill, verify that the required tools are installed:
## Quick Start
```bash
# Check curl
curl --version
# 1. cd to this skill's directory
cd {project_root}/{.claude or .codex}/skills/acestep/
# Check jq
jq --version
# 2. Check API service health
./scripts/acestep.sh health
# 3. Generate with lyrics (recommended)
./scripts/acestep.sh generate -c "pop, female vocal, piano" -l "[Verse] Your lyrics here..." --duration 120 --language zh
# 4. Output saved to: {project_root}/acestep_output/
```
#### Installing jq
## Workflow
If jq is not installed, the script will attempt to install it automatically. If automatic installation fails, install manually:
**Windows:**
```bash
# Using Chocolatey
choco install jq
# Or download from: https://jqlang.github.io/jq/download/
# Extract jq.exe and add to PATH
```
**macOS:**
```bash
# Using Homebrew
brew install jq
# Using MacPorts
port install jq
```
**Linux:**
```bash
# Debian/Ubuntu
sudo apt-get install jq
# Fedora/RHEL/CentOS
sudo yum install jq
# or
sudo dnf install jq
# Arch Linux
sudo pacman -S jq
```
**Verification:**
```bash
jq --version
# Should output: jq-1.x
```
If user reports jq installation issues, guide them through manual installation for their platform.
### Before First Use
**Ask the user about their setup:**
1. **"Do you have ACE-Step API service configured and running?"**
If **YES**:
- Verify the API endpoint: `curl -s http://127.0.0.1:8001/health`
- If using remote service, ask for the API URL and update `scripts/config.json`
- Proceed with music generation
If **NO** or **NOT SURE**:
- Ask: "Do you have ACE-Step installed?"
**If installed but not running**:
- Use the acestep-docs skill to help them start the service
- Guide them through startup process
**If not installed**:
- Offer to help download and install ACE-Step
- Ask: "Would you like to use the Windows portable package or install from source?"
- Use acestep-docs skill to guide through installation
### Service Configuration
**Local Service (Default):**
```json
{
"api_url": "http://127.0.0.1:8001",
"api_key": ""
}
```
**Remote Service:**
```json
{
"api_url": "http://your-server-ip:8001",
"api_key": "your-api-key-if-needed"
}
```
To configure remote service, update `scripts/config.json` or use:
```bash
cd {skill_directory}/scripts/
./acestep.sh config --set api_url "http://remote-server:8001"
./acestep.sh config --set api_key "your-key"
```
### Using acestep-docs Skill for Setup Help
**IMPORTANT**: For installation and startup, always use the acestep-docs skill to get complete and accurate guidance.
When user needs help with installation or startup, invoke the acestep-docs skill:
```
Use the Skill tool to invoke: acestep-docs
```
**DO NOT provide simplified startup commands** - each user's environment may be different. Always guide them to use acestep-docs for proper setup.
### Health Check
**To verify if service is running:**
```bash
curl http://127.0.0.1:8001/health
# Should return: {"status":"ok",...}
```
If health check fails, use acestep-docs skill to help user start the service properly.
---
**WORKFLOW**: For user requests requiring vocals, you should:
1. Consult [Music Creation Guide](./music-creation-guide.md) for lyrics writing, caption creation, duration/BPM/key selection
2. Write complete, well-structured lyrics yourself based on the guide
For user requests requiring vocals:
1. Use the **acestep-songwriting** skill for lyrics writing, caption creation, duration/BPM/key selection
2. Write complete, well-structured lyrics yourself based on the songwriting guide
3. Generate using Caption mode with `-c` and `-l` parameters
Only use Simple/Random mode (`-d` or `random`) for quick inspiration or instrumental exploration.
If the user needs a simple music video, use the **acestep-simplemv** skill to render one with waveform visualization and synced lyrics.
**MV Production Requirements**: Making a simple MV requires three additional skills to be installed:
- **acestep-songwriting** — for writing lyrics and planning song structure
- **acestep-lyrics-transcription** — for transcribing audio to timestamped lyrics (LRC)
- **acestep-simplemv** — for rendering the final music video
## Script Commands
**CRITICAL - Complete Lyrics Input**: When providing lyrics via the `-l` parameter, you MUST pass ALL lyrics content WITHOUT any omission:
- If user provides lyrics, pass the ENTIRE text they give you
- If you generate lyrics yourself, pass the COMPLETE lyrics you created
- NEVER truncate, shorten, or pass only partial lyrics
- Missing lyrics will result in incomplete or incoherent songs
**Music Parameters**: Use the **acestep-songwriting** skill for guidance on duration, BPM, key scale, and time signature.
```bash
# need to cd to this skill's directory first
cd {project_root}/{.claude or .codex}/skills/acestep/
# Caption mode - RECOMMENDED: Write lyrics first, then generate
./scripts/acestep.sh generate -c "Electronic pop, energetic synths" -l "[Verse] Your complete lyrics
[Chorus] Full chorus here..." --duration 120 --bpm 128
# Instrumental only
./scripts/acestep.sh generate "Jazz with saxophone"
# Quick exploration (Simple/Random mode)
./scripts/acestep.sh generate -d "A cheerful song about spring"
./scripts/acestep.sh random
# Options
./scripts/acestep.sh generate "Rock" --duration 60 --batch 2
./scripts/acestep.sh generate "EDM" --no-thinking # Faster
# Other commands
./scripts/acestep.sh status <job_id>
./scripts/acestep.sh health
./scripts/acestep.sh models
```
## Output Files
After generation, the script automatically saves results to the `acestep_output` folder in the project root (same level as `.claude`):
@ -190,41 +110,6 @@ project_root/
To get the actual synthesized lyrics, parse the JSON and read the top-level `lyrics` field, not `metas.lyrics`.
## Script Commands
**CRITICAL - Complete Lyrics Input**: When providing lyrics via the `-l` parameter, you MUST pass ALL lyrics content WITHOUT any omission:
- If user provides lyrics, pass the ENTIRE text they give you
- If you generate lyrics yourself, pass the COMPLETE lyrics you created
- NEVER truncate, shorten, or pass only partial lyrics
- Missing lyrics will result in incomplete or incoherent songs
**Music Parameters**: Refer to [Music Creation Guide](./music-creation-guide.md) for how to calculate duration, choose BPM, key scale, and time signature.
```bash
# need to cd skills path
cd {project_root}/{.claude or .codex}/skills/acestep/
# Caption mode - RECOMMENDED: Write lyrics first, then generate
./scripts/acestep.sh generate -c "Electronic pop, energetic synths" -l "[Verse] Your complete lyrics
[Chorus] Full chorus here..." --duration 120 --bpm 128
# Instrumental only
./scripts/acestep.sh generate "Jazz with saxophone"
# Quick exploration (Simple/Random mode)
./scripts/acestep.sh generate -d "A cheerful song about spring"
./scripts/acestep.sh random
# Options
./scripts/acestep.sh generate "Rock" --duration 60 --batch 2
./scripts/acestep.sh generate "EDM" --no-thinking # Faster
# Other commands
./scripts/acestep.sh status <job_id>
./scripts/acestep.sh health
./scripts/acestep.sh models
```
## Configuration
**Important**: Configuration follows this priority (high to low):
@ -239,6 +124,7 @@ cd {project_root}/{.claude or .codex}/skills/acestep/
{
"api_url": "http://127.0.0.1:8001",
"api_key": "",
"api_mode": "completion",
"generation": {
"thinking": true,
"use_format": false,
@ -255,102 +141,113 @@ cd {project_root}/{.claude or .codex}/skills/acestep/
|--------|---------|-------------|
| `api_url` | `http://127.0.0.1:8001` | API server address |
| `api_key` | `""` | API authentication key (optional) |
| `api_mode` | `completion` | API mode: `completion` (OpenRouter, default) or `native` (polling) |
| `generation.thinking` | `true` | Enable 5Hz LM (higher quality, slower) |
| `generation.audio_format` | `mp3` | Output format (mp3/wav/flac) |
| `generation.vocal_language` | `en` | Vocal language |
## API Reference
## Prerequisites - ACE-Step API Service
All responses wrapped: `{"data": <payload>, "code": 200, "error": null, "timestamp": ...}`
**IMPORTANT**: This skill requires the ACE-Step API server to be running.
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/health` | GET | Health check |
| `/release_task` | POST | Create generation task |
| `/query_result` | POST | Query task status, body: `{"task_id_list": ["id"]}` |
| `/v1/models` | GET | List available models |
| `/v1/audio?path={path}` | GET | Download audio file |
### Required Dependencies
### Query Result Response
The `scripts/acestep.sh` script requires: **curl** and **jq**.
```json
{
"data": [{
"task_id": "xxx",
"status": 1,
"result": "[{\"file\":\"/v1/audio?path=...\",\"metas\":{\"bpm\":120,\"duration\":60,\"keyscale\":\"C Major\"}}]"
}]
}
```bash
# Check dependencies
curl --version
jq --version
```
Status codes: `0` = processing, `1` = success, `2` = failed
If jq is not installed, the script will attempt to install it automatically. If automatic installation fails:
- **Windows**: `choco install jq` or download from https://jqlang.github.io/jq/download/
- **macOS**: `brew install jq`
- **Linux**: `sudo apt-get install jq` (Debian/Ubuntu) or `sudo dnf install jq` (Fedora)
## Request Parameters (`/release_task`)
### Before First Use
Parameters can be placed in `param_obj` object.
**You MUST check the API key and URL status before proceeding.** Run:
### Generation Modes
| Mode | Usage | When to Use |
|------|-------|-------------|
| **Caption** (Recommended) | `generate -c "style" -l "lyrics"` | For vocal songs - write lyrics yourself first |
| **Simple** | `generate -d "description"` | Quick exploration, LM generates everything |
| **Random** | `random` | Random generation for inspiration |
### Core Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `prompt` | string | "" | Music style description (Caption mode) |
| `lyrics` | string | "" | **Full lyrics content** - Pass ALL lyrics without omission. Use `[inst]` for instrumental. Partial/truncated lyrics = incomplete songs |
| `sample_mode` | bool | false | Enable Simple/Random mode |
| `sample_query` | string | "" | Description for Simple mode |
| `thinking` | bool | false | Enable 5Hz LM for audio code generation |
| `use_format` | bool | false | Use LM to enhance caption/lyrics |
| `model` | string | - | DiT model name |
| `batch_size` | int | 1 | Number of audio files to generate |
### Music Attributes
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `audio_duration` | float | - | Duration in seconds |
| `bpm` | int | - | Tempo (beats per minute) |
| `key_scale` | string | "" | Key (e.g. "C Major") |
| `time_signature` | string | "" | Time signature (e.g. "4/4") |
| `vocal_language` | string | "en" | Language code (en, zh, ja, etc.) |
| `audio_format` | string | "mp3" | Output format (mp3/wav/flac) |
### Generation Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `inference_steps` | int | 8 | Diffusion steps |
| `guidance_scale` | float | 7.0 | CFG scale |
| `seed` | int | -1 | Random seed (-1 for random) |
| `infer_method` | string | "ode" | Diffusion method (ode/sde) |
### Audio Task Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `task_type` | string | "text2music" | text2music / continuation / repainting |
| `src_audio_path` | string | - | Source audio for continuation |
| `repainting_start` | float | 0.0 | Repainting start position (seconds) |
| `repainting_end` | float | - | Repainting end position (seconds) |
### Example Request (Simple Mode)
```json
{
"sample_mode": true,
"sample_query": "A cheerful pop song about spring",
"thinking": true,
"param_obj": {
"duration": 60,
"bpm": 120,
"language": "en"
},
"batch_size": 2
}
```bash
cd "{project_root}/{.claude or .codex}/skills/acestep/" && bash ./scripts/acestep.sh config --check-key
cd "{project_root}/{.claude or .codex}/skills/acestep/" && bash ./scripts/acestep.sh config --get api_url
```
#### Case 1: Using Official Cloud API (`https://api.acemusic.ai`) without API key
If `api_url` is `https://api.acemusic.ai` and `api_key` is `empty`, you MUST stop and guide the user to configure their key:
1. Tell the user: "You're using the ACE-Step official cloud API, but no API key is configured. An API key is required to use this service."
2. Explain how to get a key: API keys are currently available through the official ACE-Step Discord community (https://discord.gg/bGVxwUyD). Additional distribution methods will be added in the future.
3. Use `AskUserQuestion` to ask the user to provide their API key.
4. Once provided, configure it:
```bash
cd "{project_root}/{.claude or .codex}/skills/acestep/" && bash ./scripts/acestep.sh config --set api_key <KEY>
```
5. Additionally, inform the user: "If you also want to render music videos (MV), it's recommended to configure a lyrics transcription API key as well (OpenAI Whisper or ElevenLabs Scribe), so that lyrics can be automatically transcribed with accurate timestamps. You can configure it later via the `acestep-lyrics-transcription` skill."
#### Case 2: API key is configured
Verify the API endpoint: `./scripts/acestep.sh health` and proceed with music generation.
#### Case 3: Using local/custom API without key
Local services (`http://127.0.0.1:*`) typically don't require a key. Verify with `./scripts/acestep.sh health` and proceed.
If health check fails:
- Ask: "Do you have ACE-Step installed?"
- **If installed but not running**: Use the acestep-docs skill to help them start the service
- **If not installed**: Use acestep-docs skill to guide through installation
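The three cases above can be sketched as a small decision helper. This is a minimal sketch, not code from `acestep.sh`: the function name `setup_action` and its output labels are hypothetical, and it assumes the key status is the `configured`/`empty` word reported by `config --check-key` and the URL is the value returned by `config --get api_url`.

```bash
# Hypothetical helper (not part of acestep.sh): pick the next step before first use.
# $1 = api_url, as returned by `config --get api_url`
# $2 = key status, assumed to be "configured" or "empty" (from `config --check-key`)
setup_action() {
  local api_url="$1" key_status="$2"
  if [ "$key_status" = "configured" ]; then
    echo "health-check"   # Case 2: key present, verify the endpoint
  elif [ "$api_url" = "https://api.acemusic.ai" ]; then
    echo "ask-for-key"    # Case 1: official cloud API without a key
  else
    echo "health-check"   # Case 3: local/custom API, key usually optional
  fi
}

setup_action "https://api.acemusic.ai" "empty"
setup_action "http://127.0.0.1:8001" "empty"
```

Only the cloud-without-key combination stops the flow; every other combination falls through to the health check.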
### Service Configuration
**Official Cloud API:** ACE-Step provides an official API endpoint at `https://api.acemusic.ai`. To use it:
```bash
./scripts/acestep.sh config --set api_url "https://api.acemusic.ai"
./scripts/acestep.sh config --set api_key "your-key"
./scripts/acestep.sh config --set api_mode completion
```
API keys are currently available through the official ACE-Step Discord community. Additional distribution methods will be added in the future.
**Local Service (Default):** No configuration needed — connects to `http://127.0.0.1:8001`.
**Custom Remote Service:** Update `scripts/config.json` or use:
```bash
./scripts/acestep.sh config --set api_url "http://remote-server:8001"
./scripts/acestep.sh config --set api_key "your-key"
```
**API Key Handling**: When checking whether an API key is configured, use `config --check-key` which only reports `configured` or `empty` without printing the actual key. **NEVER use `config --get api_key`** or read `config.json` directly — these would expose the user's API key. The `config --list` command is safe — it automatically masks API keys as `***` in output.
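The reporting behaviour described above could look roughly like the following. This is only a sketch under assumptions: `key_status` and `mask_key` are hypothetical names, and the real `acestep.sh` may implement `config --check-key` and `config --list` differently.

```bash
# Hypothetical helpers, not taken from acestep.sh.

# Report whether a key is set without ever printing its value.
key_status() {
  if [ -n "$1" ]; then echo "configured"; else echo "empty"; fi
}

# Mask a key for display, in the spirit of `config --list`.
mask_key() {
  if [ -n "$1" ]; then echo "***"; else echo ""; fi
}

key_status "demo-not-a-real-key"
mask_key "demo-not-a-real-key"
```

Neither helper echoes its argument, so the key value never reaches the terminal or a log.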
### API Mode
The skill supports two API modes. Switch via `api_mode` in `scripts/config.json`:
| Mode | Endpoint | Description |
|------|----------|-------------|
| `completion` (default) | `/v1/chat/completions` | OpenRouter-compatible, sync request, audio returned as base64 |
| `native` | `/release_task` + `/query_result` | Async polling mode, supports all parameters |
**Switch mode:**
```bash
./scripts/acestep.sh config --set api_mode completion
./scripts/acestep.sh config --set api_mode native
```
**Completion mode notes:**
- No polling needed — single request returns result directly
- Audio is base64-encoded inline in the response (auto-decoded and saved)
- `inference_steps`, `infer_method`, `shift` are not configurable (server defaults)
- `--no-wait` and `status` commands are not applicable in completion mode
- Requires `model` field — auto-detected from `/v1/models` if not specified
### Using acestep-docs Skill for Setup Help
**IMPORTANT**: For installation and startup, always use the acestep-docs skill to get complete and accurate guidance.
**DO NOT provide simplified startup commands** - each user's environment may be different. Always guide them to use acestep-docs for proper setup.
---
For API debugging, see [API Reference](./api-reference.md).


@ -0,0 +1,149 @@
# ACE-Step API Reference
> For debugging and advanced usage only. Normal operations should use `scripts/acestep.sh`.
## Native Mode Endpoints
All responses are wrapped: `{"data": <payload>, "code": 200, "error": null, "timestamp": ...}`
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/health` | GET | Health check |
| `/release_task` | POST | Create generation task |
| `/query_result` | POST | Query task status, body: `{"task_id_list": ["id"]}` |
| `/v1/models` | GET | List available models |
| `/v1/audio?path={path}` | GET | Download audio file |
## Completion Mode Endpoints
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/v1/chat/completions` | POST | Generate music (OpenRouter-compatible) |
| `/v1/models` | GET | List available models (OpenRouter format) |
## Query Result Response
```json
{
"data": [{
"task_id": "xxx",
"status": 1,
"result": "[{\"file\":\"/v1/audio?path=...\",\"metas\":{\"bpm\":120,\"duration\":60,\"keyscale\":\"C Major\"}}]"
}]
}
```
Status codes: `0` = processing, `1` = success, `2` = failed
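When debugging a polling loop, the numeric status can be mapped to a readable label. A minimal sketch, assuming a hypothetical helper name `status_label`; the three codes come from the line above, and anything else is reported as `unknown`.

```bash
# Hypothetical helper: translate a /query_result status code into a label.
status_label() {
  case "$1" in
    0) echo "processing" ;;
    1) echo "success" ;;
    2) echo "failed" ;;
    *) echo "unknown" ;;   # any code not listed in the reference
  esac
}

status_label 1
```

Note that `result` in the response is itself a JSON-encoded string, so it needs a second parse (for example with jq's `fromjson`) before fields such as `file` can be read.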
## Completion Mode Request (`/v1/chat/completions`)
**Caption mode** — prompt and lyrics wrapped in XML tags inside message content:
```json
{
"model": "acestep/ACE-Step-v1.5",
"messages": [{"role": "user", "content": "<prompt>Jazz with saxophone</prompt><lyrics>[Verse] Hello...</lyrics>"}],
"stream": false,
"thinking": true,
"use_format": false,
"audio_config": {"duration": 90, "bpm": 110, "format": "mp3", "vocal_language": "en"}
}
```
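The message content for Caption mode can be assembled from the two inputs as shown below. This is a sketch only: `build_caption_content` is a hypothetical name, and the function produces the raw string, not the JSON request.

```bash
# Hypothetical helper: wrap a caption and lyrics in the XML-style tags
# that Caption mode expects inside the message content.
# $1 = caption / style prompt, $2 = complete lyrics
build_caption_content() {
  printf '<prompt>%s</prompt><lyrics>%s</lyrics>' "$1" "$2"
}

build_caption_content "Jazz with saxophone" "[Verse] Hello..."
```

The resulting string still has to be JSON-escaped before it is embedded in the request body; passing it through `jq -n --arg` is one way to do that, since multi-line lyrics contain newlines and quotes.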
**Simple mode** — plain text message, set `sample_mode: true`:
```json
{
"model": "acestep/ACE-Step-v1.5",
"messages": [{"role": "user", "content": "A cheerful pop song about spring"}],
"stream": false,
"sample_mode": true,
"thinking": true
}
```
## Completion Mode Response
```json
{
"id": "chatcmpl-abc123",
"choices": [{
"message": {
"role": "assistant",
"content": "## Metadata\n**Caption:** ...\n**BPM:** 128\n\n## Lyrics\n...",
"audio": [{"type": "audio_url", "audio_url": {"url": "data:audio/mpeg;base64,..."}}]
},
"finish_reason": "stop"
}]
}
```
Audio is base64-encoded inline — the script auto-decodes and saves to `acestep_output/`.
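For manual debugging, the decode step can be reproduced by hand. A minimal sketch, assuming a hypothetical helper `save_data_url` and a `base64` binary that accepts `-d` (GNU coreutils does); the actual script may decode differently.

```bash
# Hypothetical helper: write the payload of a base64 data URL to a file.
# $1 = data URL such as "data:audio/mpeg;base64,...."
# $2 = output path
save_data_url() {
  local url="$1" out="$2"
  local payload="${url#*;base64,}"   # drop everything up to and including ";base64,"
  printf '%s' "$payload" | base64 -d > "$out"
}

# Usage with a tiny stand-in payload ("hello"), not real audio
demo_file="$(mktemp)"
save_data_url "data:audio/mpeg;base64,aGVsbG8=" "$demo_file"
cat "$demo_file"
```

In practice the URL would be taken from `choices[0].message.audio[0].audio_url.url` in the response shown above.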
## Request Parameters (`/release_task`)
Parameters can be placed in a `param_obj` object.
### Generation Modes
| Mode | Usage | When to Use |
|------|-------|-------------|
| **Caption** (Recommended) | `generate -c "style" -l "lyrics"` | For vocal songs - write lyrics yourself first |
| **Simple** | `generate -d "description"` | Quick exploration, LM generates everything |
| **Random** | `random` | Random generation for inspiration |
### Core Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `prompt` | string | "" | Music style description (Caption mode) |
| `lyrics` | string | "" | **Full lyrics content** - Pass ALL lyrics without omission. Use `[inst]` for instrumental. Partial/truncated lyrics = incomplete songs |
| `sample_mode` | bool | false | Enable Simple/Random mode |
| `sample_query` | string | "" | Description for Simple mode |
| `thinking` | bool | false | Enable 5Hz LM for audio code generation |
| `use_format` | bool | false | Use LM to enhance caption/lyrics |
| `model` | string | - | DiT model name |
| `batch_size` | int | 1 | Number of audio files to generate |
### Music Attributes
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `audio_duration` | float | - | Duration in seconds |
| `bpm` | int | - | Tempo (beats per minute) |
| `key_scale` | string | "" | Key (e.g. "C Major") |
| `time_signature` | string | "" | Time signature (e.g. "4/4") |
| `vocal_language` | string | "en" | Language code (en, zh, ja, etc.) |
| `audio_format` | string | "mp3" | Output format (mp3/wav/flac) |
### Generation Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `inference_steps` | int | 8 | Diffusion steps |
| `guidance_scale` | float | 7.0 | CFG scale |
| `seed` | int | -1 | Random seed (-1 for random) |
| `infer_method` | string | "ode" | Diffusion method (ode/sde) |
### Audio Task Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `task_type` | string | "text2music" | text2music / continuation / repainting |
| `src_audio_path` | string | - | Source audio for continuation |
| `repainting_start` | float | 0.0 | Repainting start position (seconds) |
| `repainting_end` | float | - | Repainting end position (seconds) |
### Example Request (Simple Mode)
```json
{
"sample_mode": true,
"sample_query": "A cheerful pop song about spring",
"thinking": true,
"param_obj": {
"duration": 60,
"bpm": 120,
"language": "en"
},
"batch_size": 2
}
```


@ -1,350 +0,0 @@
# ACE-Step Music Creation Guide
> This guide contains professional music creation knowledge extracted from ACE-Step Tutorial. Use this as reference when creating music with ACE-Step.
---
## Input Control: What Do You Want?
This is the part where you communicate "creative intent" with the model—what kind of music you want to generate.
| Category | Parameter | Function |
|----------|-----------|----------|
| **Task Type** | `task_type` | Determines generation mode: text2music, cover, repaint, lego, extract, complete |
| **Text Input** | `caption` | Description of overall music elements: style, instruments, emotion, atmosphere, timbre, vocal gender, progression, etc. |
| | `lyrics` | Temporal element description: lyric content, music structure evolution, vocal changes, vocal/instrument performance style, start/end style, articulation, etc. (use `[Instrumental]` for instrumental music) |
| **Music Metadata** | `bpm` | Tempo (30-300) |
| | `keyscale` | Key (e.g., C Major, Am) |
| | `timesignature` | Time signature (4/4, 3/4, 6/8) |
| | `vocal_language` | Vocal language |
| | `duration` | Target duration (seconds) |
| **Audio Reference** | `reference_audio` | Global reference for timbre or style (for cover, style transfer) |
| | `src_audio` | Source audio for non-text2music tasks (text2music defaults to silence, no input needed) |
| | `audio_codes` | Semantic codes input to model in Cover mode (advanced: reuse codes for variants, convert songs to codes for extension, combine like DJ mixing) |
| **Interval Control** | `repainting_start/end` | Time interval for operations (repaint redraw area / lego new track area) |
---
## About Caption: The Most Important Input
**Caption is the most important factor affecting generated music.**
It supports multiple input formats: simple style words, comma-separated tags, complex natural language descriptions. The model has been trained to be compatible with various formats, so text format doesn't significantly affect model performance.
### Common Dimensions for Caption Writing
| Dimension | Examples |
|-----------|----------|
| **Style/Genre** | pop, rock, jazz, electronic, hip-hop, R&B, folk, classical, lo-fi, synthwave |
| **Emotion/Atmosphere** | melancholic, uplifting, energetic, dreamy, dark, nostalgic, euphoric, intimate |
| **Instruments** | acoustic guitar, piano, synth pads, 808 drums, strings, brass, electric bass |
| **Timbre Texture** | warm, bright, crisp, muddy, airy, punchy, lush, raw, polished |
| **Era Reference** | 80s synth-pop, 90s grunge, 2010s EDM, vintage soul, modern trap |
| **Production Style** | lo-fi, high-fidelity, live recording, studio-polished, bedroom pop |
| **Vocal Characteristics** | female vocal, male vocal, breathy, powerful, falsetto, raspy, choir |
| **Speed/Rhythm** | slow tempo, mid-tempo, fast-paced, groovy, driving, laid-back |
| **Structure Hints** | building intro, catchy chorus, dramatic bridge, fade-out ending |
### Practical Principles for Caption Writing
1. **Specific beats vague** — "sad piano ballad with female breathy vocal" works better than "a sad song."
2. **Combine multiple dimensions** — Single-dimension descriptions give the model too much room to play; combining style+emotion+instruments+timbre can more precisely anchor your desired direction.
3. **Use references well** — "in the style of 80s synthwave" or "reminiscent of Bon Iver" can quickly convey complex aesthetic preferences.
4. **Texture words are useful** — Adjectives like warm, crisp, airy, punchy can influence mixing and timbre tendencies.
5. **Don't pursue perfect descriptions** — Caption is a starting point, not an endpoint. Write a general direction first, then iterate based on results.
6. **Description granularity determines freedom** — More omitted descriptions give the model more room to play, more random factor influence; more detailed descriptions constrain the model more. Decide specificity based on your needs—want surprises? Write less. Want control? Write more details.
7. **Avoid conflicting words** — Conflicting style combinations easily lead to degraded output. For example, wanting both "classical strings" and "hardcore metal" simultaneously—the model will try to fuse but usually not ideal.
**Ways to resolve conflicts:**
- **Repetition reinforcement** — Strengthen the elements you want more in mixed styles by repeating certain words
- **Conflict to evolution** — Transform style conflicts into temporal style evolution. For example: "Start with soft strings, middle becomes noisy dynamic metal rock, end turns to hip-hop"—this gives the model clear guidance on how to handle different styles, rather than mixing them into a mess
---
## About Lyrics: The Temporal Script
If Caption describes the music's "overall portrait"—style, atmosphere, timbre—then **Lyrics is the music's "temporal script"**, controlling how music unfolds over time.
Lyrics is not just lyric content. It carries:
- The lyric text itself
- **Structure tags** ([Verse], [Chorus], [Bridge]...)
- **Vocal style hints** ([raspy vocal], [whispered]...)
- **Instrumental sections** ([guitar solo], [drum break]...)
- **Energy changes** ([building energy], [explosive drop]...)
### Common Structure Tags
| Category | Tag | Description |
|----------|-----|-------------|
| **Basic Structure** | `[Intro]` | Opening, establish atmosphere |
| | `[Verse]` / `[Verse 1]` | Verse, narrative progression |
| | `[Pre-Chorus]` | Pre-chorus, build energy |
| | `[Chorus]` | Chorus, emotional climax |
| | `[Bridge]` | Bridge, transition or elevation |
| | `[Outro]` | Ending, conclusion |
| **Dynamic Sections** | `[Build]` | Energy gradually rising |
| | `[Drop]` | Electronic music energy release |
| | `[Breakdown]` | Reduced instrumentation, space |
| **Instrumental Sections** | `[Instrumental]` | Pure instrumental, no vocals |
| | `[Guitar Solo]` | Guitar solo |
| | `[Piano Interlude]` | Piano interlude |
| **Special Tags** | `[Fade Out]` | Fade out ending |
| | `[Silence]` | Silence |
### Combining Tags: Use Moderately
Structure tags can be combined with `-` for finer control:
```
[Chorus - anthemic]
This is the chorus lyrics
Dreams are burning
[Bridge - whispered]
Whisper those words softly
```
⚠️ **Note: Don't stack too many tags.**
```
❌ Not recommended:
[Chorus - anthemic - stacked harmonies - high energy - powerful - epic]
✅ Recommended:
[Chorus - anthemic]
```
**Principle**: Keep structure tags concise; put complex style descriptions in Caption.
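One way to spot over-stacked tags before submitting lyrics is to count the ` - ` separated modifiers in each tag. This is a sketch under assumptions: `tag_modifier_count` is a hypothetical helper, and the guide gives no numeric limit, so any threshold applied to the count is the caller's choice.

```bash
# Hypothetical helper: count " - " separated modifiers in a structure tag.
# "[Chorus]" has 0, "[Chorus - anthemic]" has 1, and so on.
# Hyphens without surrounding spaces (e.g. "Pre-Chorus") are not counted.
tag_modifier_count() {
  local rest="$1" count=0
  while [[ "$rest" == *" - "* ]]; do
    rest="${rest#* - }"        # drop text up to and including the first " - "
    count=$((count + 1))
  done
  echo "$count"
}

tag_modifier_count "[Chorus - anthemic]"
tag_modifier_count "[Chorus - anthemic - stacked harmonies - high energy - powerful - epic]"
```

A tag with a high count is a candidate for moving its style words into the Caption instead.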
### ⚠️ Key: Maintain Consistency Between Caption and Lyrics
**Models are not good at resolving conflicts.** If descriptions in Caption and Lyrics contradict, the model gets confused and output quality decreases.
**Checklist:**
- Instruments in Caption ↔ Instrumental section tags in Lyrics
- Emotion in Caption ↔ Energy tags in Lyrics
- Vocal description in Caption ↔ Vocal control tags in Lyrics
Think of Caption as "overall setting" and Lyrics as "shot script"—they should tell the same story.
### Vocal Control Tags
| Tag | Effect |
|-----|--------|
| `[raspy vocal]` | Raspy, textured vocals |
| `[whispered]` | Whispered |
| `[falsetto]` | Falsetto |
| `[powerful belting]` | Powerful, high-pitched singing |
| `[spoken word]` | Rap/recitation |
| `[harmonies]` | Layered harmonies |
| `[call and response]` | Call and response |
| `[ad-lib]` | Improvised embellishments |
### Energy and Emotion Tags
| Tag | Effect |
|-----|--------|
| `[high energy]` | High energy, passionate |
| `[low energy]` | Low energy, restrained |
| `[building energy]` | Increasing energy |
| `[explosive]` | Explosive energy |
| `[melancholic]` | Melancholic |
| `[euphoric]` | Euphoric |
| `[dreamy]` | Dreamy |
| `[aggressive]` | Aggressive |
### Lyric Text Writing Tips
**1. Control Syllable Count**
**6-10 syllables per line** usually works best. The model aligns syllables to beats—if one line has 6 syllables and the next has 14, rhythm becomes strange.
**Tip**: Keep syllable counts similar (within ±1-2) for lines in the same position, e.g., the first line of each verse.
**2. Use Case to Control Intensity**
Uppercase indicates stronger vocal intensity:
```
[Verse]
walking through the empty streets (normal intensity)
[Chorus]
WE ARE THE CHAMPIONS! (high intensity, shouting)
```
**3. Use Parentheses for Background Vocals**
```
[Chorus]
We rise together (together)
Into the light (into the light)
```
Content in parentheses is processed as background vocals or harmonies.
**4. Extend Vowels**
You can extend sounds by repeating vowels:
```
Feeeling so aliiive
```
But use cautiously—effects are unstable, sometimes ignored or mispronounced.
**5. Clear Section Separation**
Separate each section with blank lines:
```
[Verse 1]
First verse lyrics
Continue first verse
[Chorus]
Chorus lyrics
Chorus continues
```
### Avoiding "AI-flavored" Lyrics
These characteristics make lyrics seem mechanical and lack human touch:
| Red Flag 🚩 | Description |
|-------------|-------------|
| **Adjective stacking** | "neon skies, electric hearts, endless dreams"—filling a section with vague imagery |
| **Rhyme chaos** | Inconsistent rhyme patterns, or forced rhymes causing semantic breaks |
| **Blurred section boundaries** | Lyric content crosses structure tags, Verse content "flows" into Chorus |
| **No breathing room** | Each line too long, can't sing in one breath |
| **Mixed metaphors** | First verse uses water imagery, second suddenly becomes fire, third is flying—listeners can't anchor |
**Metaphor discipline**: Stick to one core metaphor per song, exploring its multiple aspects.
---
## About Music Metadata: Optional Fine Control
**Most of the time, you don't need to manually set metadata.**
When you enable `thinking` mode (or enable `use_cot_metas`), LM automatically infers appropriate BPM, key, time signature, etc. based on your Caption and Lyrics. This is usually good enough.
But if you have clear ideas, you can also manually control them:
| Parameter | Control Range | Description |
|-----------|--------------|-------------|
| `bpm` | 30-300 | Tempo. Common distribution: slow songs 60-80, mid-tempo 90-120, fast songs 130-180 |
| `keyscale` | Key | e.g., `C Major`, `Am`, `F# Minor`. Affects overall pitch and emotional color |
| `timesignature` | Time signature | `4/4` (most common), `3/4` (waltz), `6/8` (swing feel) |
| `vocal_language` | Language | Vocal language. LM usually auto-detects from lyrics |
| `duration` | Seconds | Target duration. Actual generation may vary slightly |
### Understanding Control Boundaries
These parameters are **guidance** rather than **precise commands**:
- **BPM**: Common range (60-180) works well; extreme values (like 30 or 280) have less training data, may be unstable
- **Key**: Common keys (C, G, D, Am, Em) are stable; rare keys may be ignored or shifted
- **Time signature**: `4/4` is most reliable; `3/4`, `6/8` usually OK; complex signatures (like `5/4`, `7/8`) are advanced, effects vary by style
- **Duration**: Short songs (30-60s) and medium length (2-4min) are stable; very long generation may have repetition or structure issues
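The BPM boundaries can be turned into a quick pre-flight check. A minimal sketch, assuming a hypothetical helper `bpm_zone`, an integer argument, a supported range of 30 to 300, and a common range of 60 to 180 as read from this guide.

```bash
# Hypothetical helper: classify an integer BPM against the guide's ranges.
bpm_zone() {
  local bpm="$1"
  if [ "$bpm" -lt 30 ] || [ "$bpm" -gt 300 ]; then
    echo "out-of-range"   # outside the supported 30-300 window
  elif [ "$bpm" -ge 60 ] && [ "$bpm" -le 180 ]; then
    echo "common"         # well covered, expected to be stable
  else
    echo "extreme"        # allowed, but may be unstable
  fi
}

bpm_zone 128
bpm_zone 280
```

Because these parameters are guidance, not exact commands, an `extreme` result is a warning to expect drift, not an error.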
### When Do You Need Manual Settings?
| Scenario | Suggestion |
|----------|------------|
| Daily generation | Don't worry, let LM auto-infer |
| Clear tempo requirement | Manually set `bpm` |
| Specific style (e.g., waltz) | Manually set `timesignature=3/4` |
| Need to match other material | Manually set `bpm` and `duration` |
| Pursue specific key color | Manually set `keyscale` |
**Tip**: If you manually set metadata but generation results clearly don't match—check if there's conflict with Caption/Lyrics. For example, Caption says "slow ballad" but `bpm=160`, the model gets confused.
**Recommended Practice**: Don't write tempo, BPM, key, and other metadata information in Caption. These should be set through dedicated metadata parameters (`bpm`, `keyscale`, `timesignature`, etc.), not described in Caption. Caption should focus on style, emotion, instruments, timbre, and other musical characteristics, while metadata information is handled by corresponding parameters.
---
## Duration Calculation Guidelines
When creating music, you MUST calculate appropriate duration based on lyrics content and song structure:
### Estimation Method
- **Per line of lyrics**: 3-5 seconds
- **Intro/Outro**: 5-10 seconds each
- **Instrumental sections**: 5-15 seconds each
- **Typical song structures**:
- 2 verses + 2 choruses: 120-150 seconds minimum
- 2 verses + 2 choruses + bridge: 180-240 seconds minimum
- Full song with intro/outro: 210-270 seconds (3.5-4.5 minutes)
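The per-line and per-section figures above can be combined into a rough range. This is a sketch only: `estimate_duration` is a hypothetical helper, it assumes the song has both an intro and an outro, and it applies the low and high ends of each estimate to produce a minimum and maximum in seconds.

```bash
# Hypothetical helper: estimate a duration range in seconds.
# $1 = number of lyric lines
# $2 = number of instrumental sections (optional, default 0)
estimate_duration() {
  local lines="$1" sections="${2:-0}"
  # low end: 3 s per line, 5 s each for intro and outro, 5 s per instrumental section
  local min=$(( lines * 3 + 2 * 5 + sections * 5 ))
  # high end: 5 s per line, 10 s each for intro and outro, 15 s per instrumental section
  local max=$(( lines * 5 + 2 * 10 + sections * 15 ))
  echo "$min $max"
}

estimate_duration 10   # prints: 40 70
```

Following the rule of erring on the longer side, a value in the upper half of the range is the safer choice for `--duration`.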
### Common Pitfalls
**DON'T**: Set duration too short for the lyrics amount
- Example: 10 lines of lyrics with 120 seconds → rushed, compressed
**DO**: Calculate realistic duration
- Example: 10 lines of lyrics → ~40 seconds of vocals + 20 seconds intro/outro = 60 seconds minimum
### BPM and Duration Relationship
The BPM affects how quickly lyrics are sung:
- **Slower BPM (60-80)**: Need MORE duration for same lyrics
- **Medium BPM (100-130)**: Standard duration
- **Faster BPM (150-180)**: Can fit more lyrics in less time, but still need breathing room
**Rule of thumb**: When in doubt, estimate longer rather than shorter. A song that's too short will feel rushed and incomplete.
---
## Complete Example
Assuming Caption is: `female vocal, piano ballad, emotional, intimate atmosphere, strings, building to powerful chorus`
```
[Intro - piano]
[Verse 1]
月光洒在窗台上
我听见你的呼吸
城市在远处沉睡
只有我们还醒着
[Pre-Chorus]
这一刻如此安静
却藏着汹涌的心
[Chorus - powerful]
让我们燃烧吧
像夜空中的烟火
短暂却绚烂
这就是我们的时刻
[Verse 2]
时间在指尖流过
我们抓不住什么
但至少此刻拥有
彼此眼中的火焰
[Bridge - whispered]
如果明天一切消散
至少我们曾经闪耀
[Final Chorus]
让我们燃烧吧
像夜空中的烟火
短暂却绚烂
THIS IS OUR MOMENT!
[Outro - fade out]
```
Note: In this example, Lyrics tags (piano, powerful, whispered) are consistent with Caption descriptions (piano ballad, building to powerful chorus, intimate), with no conflicts.
---


@ -72,6 +72,7 @@ ensure_output_dir() {
DEFAULT_CONFIG='{
"api_url": "http://127.0.0.1:8001",
"api_key": "",
"api_mode": "native",
"generation": {
"thinking": true,
"use_format": true,
@ -85,7 +86,15 @@ DEFAULT_CONFIG='{
# Ensure config file exists
ensure_config() {
if [ ! -f "$CONFIG_FILE" ]; then
echo "$DEFAULT_CONFIG" > "$CONFIG_FILE"
local example="${SCRIPT_DIR}/config.example.json"
if [ -f "$example" ]; then
cp "$example" "$CONFIG_FILE"
echo -e "${YELLOW}Config file created from config.example.json. Please configure your settings:${NC}"
echo -e " ${CYAN}./scripts/acestep.sh config --set api_url <url>${NC}"
echo -e " ${CYAN}./scripts/acestep.sh config --set api_key <key>${NC}"
else
echo "$DEFAULT_CONFIG" > "$CONFIG_FILE"
fi
fi
}
@ -244,11 +253,20 @@ cmd_config() {
--set) action="set"; key="$2"; value="$3"; shift 3 ;;
--reset) action="reset"; shift ;;
--list) action="list"; shift ;;
--check-key) action="check-key"; shift ;;
*) shift ;;
esac
done
case "$action" in
"check-key")
local api_key=$(get_config "api_key")
if [ -n "$api_key" ]; then
echo "api_key: configured"
else
echo "api_key: empty"
fi
;;
"get")
[ -z "$key" ] && { echo -e "${RED}Error: --get requires KEY${NC}"; exit 1; }
local result=$(get_config "$key")
@ -261,11 +279,11 @@ cmd_config() {
"reset")
echo "$DEFAULT_CONFIG" > "$CONFIG_FILE"
echo -e "${GREEN}Configuration reset to defaults.${NC}"
cat "$CONFIG_FILE"
jq 'walk(if type == "object" and has("api_key") and (.api_key | length) > 0 then .api_key = "***" else . end)' "$CONFIG_FILE"
;;
"list")
echo "Current configuration:"
cat "$CONFIG_FILE"
jq 'walk(if type == "object" and has("api_key") and (.api_key | length) > 0 then .api_key = "***" else . end)' "$CONFIG_FILE"
;;
*)
echo "Config file: $CONFIG_FILE"
@ -488,6 +506,197 @@ download_audios() {
done <<< "$audio_paths"
}
# =============================================================================
# Completion Mode (OpenRouter /v1/chat/completions)
# =============================================================================
# Load api_mode from config (default: native)
load_api_mode() {
local mode=$(get_config "api_mode")
echo "${mode:-native}"
}
# Get model ID from /v1/models endpoint for completion mode
get_completion_model() {
local api_url="$1"
local user_model="$2"
local api_key=$(load_api_key)
# If user specified a model, prefix with acemusic/ if needed
if [ -n "$user_model" ]; then
if [[ "$user_model" == */* ]]; then
echo "$user_model"
else
echo "acemusic/${user_model}"
fi
return
fi
# Query /v1/models for the first available model
local response
if [ -n "$api_key" ]; then
response=$(curl -s -H "Authorization: Bearer ${api_key}" "${api_url}/v1/models" 2>/dev/null)
else
response=$(curl -s "${api_url}/v1/models" 2>/dev/null)
fi
local model_id
model_id=$(echo "$response" | jq -r '.data[0].id // empty' 2>/dev/null)
echo "${model_id:-acemusic/acestep-v15-turbo}"
}
# Decode base64 audio data URL and save to file
# Handles cross-platform compatibility (Linux/macOS/Windows MSYS)
decode_base64_audio() {
local data_url="$1"
local output_file="$2"
# Strip data URL prefix: data:audio/mpeg;base64,...
local b64_data="${data_url#data:*;base64,}"
local tmp_b64=$(mktemp)
printf '%s' "$b64_data" > "$tmp_b64"
if command -v base64 &> /dev/null; then
# Linux / macOS / MSYS2
base64 -d < "$tmp_b64" > "$output_file" 2>/dev/null || \
base64 -D < "$tmp_b64" > "$output_file" 2>/dev/null || \
python3 -c "import base64,sys; sys.stdout.buffer.write(base64.b64decode(sys.stdin.read()))" < "$tmp_b64" > "$output_file" 2>/dev/null || \
python -c "import base64,sys; sys.stdout.buffer.write(base64.b64decode(sys.stdin.read()))" < "$tmp_b64" > "$output_file" 2>/dev/null
else
# Fallback to python
python3 -c "import base64,sys; sys.stdout.buffer.write(base64.b64decode(sys.stdin.read()))" < "$tmp_b64" > "$output_file" 2>/dev/null || \
python -c "import base64,sys; sys.stdout.buffer.write(base64.b64decode(sys.stdin.read()))" < "$tmp_b64" > "$output_file" 2>/dev/null
fi
local decode_ok=$?
rm -f "$tmp_b64"
return $decode_ok
}
# Parse completion response: extract metadata, save audio files
# Usage: parse_completion_response <response_file> <job_id>
parse_completion_response() {
local resp_file="$1"
local job_id="$2"
ensure_output_dir
local audio_format=$(get_config "generation.audio_format")
[ -z "$audio_format" ] && audio_format="mp3"
# Check for error
local finish_reason
finish_reason=$(jq -r '.choices[0].finish_reason // "stop"' "$resp_file" 2>/dev/null)
if [ "$finish_reason" = "error" ]; then
local err_content
err_content=$(jq -r '.choices[0].message.content // "Unknown error"' "$resp_file" 2>/dev/null)
echo -e "${RED}Generation failed: $err_content${NC}"
return 1
fi
# Extract and display text content (metadata + lyrics)
local content
content=$(jq -r '.choices[0].message.content // empty' "$resp_file" 2>/dev/null)
if [ -n "$content" ]; then
echo "$content"
echo ""
fi
# Extract and save audio files
local audio_count
audio_count=$(jq -r '.choices[0].message.audio | length // 0' "$resp_file" 2>/dev/null)
if [ "$audio_count" -gt 0 ] 2>/dev/null; then
local i=0
while [ "$i" -lt "$audio_count" ]; do
local audio_url
audio_url=$(jq -r ".choices[0].message.audio[$i].audio_url.url // empty" "$resp_file" 2>/dev/null)
if [ -n "$audio_url" ]; then
local output_file="${OUTPUT_DIR}/${job_id}_$((i+1)).${audio_format}"
echo -e " ${CYAN}Decoding audio $((i+1))...${NC}"
if decode_base64_audio "$audio_url" "$output_file"; then
if [ -f "$output_file" ] && [ -s "$output_file" ]; then
echo -e " ${GREEN}Saved: $output_file${NC}"
else
echo -e " ${RED}Failed to decode audio $((i+1))${NC}"
rm -f "$output_file" 2>/dev/null
fi
else
echo -e " ${RED}Failed to decode audio $((i+1))${NC}"
rm -f "$output_file" 2>/dev/null
fi
fi
i=$((i+1))
done
else
echo -e " ${YELLOW}No audio files in response${NC}"
fi
# Save full response JSON (strip base64 audio to keep file small)
local clean_resp
clean_resp=$(jq 'del(.choices[].message.audio[].audio_url.url)' "$resp_file" 2>/dev/null)
if [ -n "$clean_resp" ]; then
save_result "$job_id" "$clean_resp"
else
save_result "$job_id" "$(cat "$resp_file")"
fi
}
# Send request to /v1/chat/completions and handle response
# Usage: send_completion_request <api_url> <payload_file>
send_completion_request() {
local api_url="$1"
local payload_file="$2"
local api_key=$(load_api_key)
local resp_file=$(mktemp)
local http_code
if [ -n "$api_key" ]; then
http_code=$(curl -s -w "%{http_code}" --connect-timeout 10 --max-time 660 \
-o "$resp_file" \
-X POST "${api_url}/v1/chat/completions" \
-H "Content-Type: application/json; charset=utf-8" \
-H "Authorization: Bearer ${api_key}" \
--data-binary "@${payload_file}")
else
http_code=$(curl -s -w "%{http_code}" --connect-timeout 10 --max-time 660 \
-o "$resp_file" \
-X POST "${api_url}/v1/chat/completions" \
-H "Content-Type: application/json; charset=utf-8" \
--data-binary "@${payload_file}")
fi
rm -f "$payload_file"
if [ "$http_code" != "200" ]; then
local err_detail
err_detail=$(jq -r '.detail // .error.message // empty' "$resp_file" 2>/dev/null)
echo -e "${RED}Error: HTTP $http_code${NC}"
[ -n "$err_detail" ] && echo -e "${RED}$err_detail${NC}"
rm -f "$resp_file"
return 1
fi
# Generate a job_id from the completion id
local job_id
job_id=$(jq -r '.id // empty' "$resp_file" 2>/dev/null)
[ -z "$job_id" ] && job_id="completion-$(date +%s)"
echo ""
echo -e "${GREEN}Generation completed!${NC}"
echo ""
parse_completion_response "$resp_file" "$job_id"
rm -f "$resp_file"
echo ""
echo -e "${GREEN}Done! Files saved to: $OUTPUT_DIR${NC}"
}
# Wait for job and download results
wait_for_job() {
local api_url="$1"
@ -646,6 +855,8 @@ cmd_generate() {
[ -n "$bpm" ] && payload=$(echo "$payload" | jq --argjson v "$bpm" '. + {bpm: $v}')
[ -n "$batch" ] && payload=$(echo "$payload" | jq --argjson v "$batch" '. + {batch_size: $v}')
local api_mode=$(load_api_mode)
echo "Generating music..."
if [ -n "$description" ]; then
echo " Mode: Simple (description)"
@ -655,38 +866,91 @@ cmd_generate() {
echo " Caption: ${caption:0:50}..."
fi
echo " Thinking: $thinking, Format: $use_format"
echo " API: $api_mode"
echo " Output: $OUTPUT_DIR"
echo ""
# Write payload to temp file to ensure proper UTF-8 encoding
local temp_payload=$(mktemp)
printf '%s' "$payload" > "$temp_payload"
if [ "$api_mode" = "completion" ]; then
# --- Completion mode: /v1/chat/completions ---
local model_id=$(get_completion_model "$api_url" "$model")
local api_key=$(load_api_key)
local response
if [ -n "$api_key" ]; then
response=$(curl -s -X POST "${api_url}/release_task" \
-H "Content-Type: application/json; charset=utf-8" \
-H "Authorization: Bearer ${api_key}" \
--data-binary "@${temp_payload}")
# Build message content
local message_content=""
local sample_mode=false
if [ -n "$description" ]; then
message_content="$description"
sample_mode=true
else
message_content="<prompt>${caption}</prompt>"
[ -n "$lyrics" ] && message_content="${message_content}<lyrics>${lyrics}</lyrics>"
fi
# Build completion payload
local payload_c=$(jq -n \
--arg model "$model_id" \
--arg content "$message_content" \
--argjson thinking "$thinking" \
--argjson use_format "$use_format" \
--argjson sample_mode "$sample_mode" \
--argjson use_cot_caption "$cot_caption" \
--argjson use_cot_language "$cot_language" \
--arg vocal_language "$language" \
--arg format "${def_audio_format:-mp3}" \
'{
model: $model,
messages: [{"role": "user", "content": $content}],
stream: false,
thinking: $thinking,
use_format: $use_format,
sample_mode: $sample_mode,
use_cot_caption: $use_cot_caption,
use_cot_language: $use_cot_language,
audio_config: {
format: $format,
vocal_language: $vocal_language
}
}')
# Add optional parameters to completion payload
[ -n "$guidance" ] && payload_c=$(echo "$payload_c" | jq --argjson v "$guidance" '. + {guidance_scale: $v}')
[ -n "$seed" ] && payload_c=$(echo "$payload_c" | jq --argjson v "$seed" '. + {seed: $v}')
[ -n "$batch" ] && payload_c=$(echo "$payload_c" | jq --argjson v "$batch" '. + {batch_size: $v}')
[ -n "$duration" ] && payload_c=$(echo "$payload_c" | jq --argjson v "$duration" '.audio_config.duration = $v')
[ -n "$bpm" ] && payload_c=$(echo "$payload_c" | jq --argjson v "$bpm" '.audio_config.bpm = $v')
local temp_payload=$(mktemp)
printf '%s' "$payload_c" > "$temp_payload"
send_completion_request "$api_url" "$temp_payload"
else
response=$(curl -s -X POST "${api_url}/release_task" \
-H "Content-Type: application/json; charset=utf-8" \
--data-binary "@${temp_payload}")
fi
# --- Native mode: /release_task + polling ---
local temp_payload=$(mktemp)
printf '%s' "$payload" > "$temp_payload"
rm -f "$temp_payload"
local api_key=$(load_api_key)
local response
if [ -n "$api_key" ]; then
response=$(curl -s -X POST "${api_url}/release_task" \
-H "Content-Type: application/json; charset=utf-8" \
-H "Authorization: Bearer ${api_key}" \
--data-binary "@${temp_payload}")
else
response=$(curl -s -X POST "${api_url}/release_task" \
-H "Content-Type: application/json; charset=utf-8" \
--data-binary "@${temp_payload}")
fi
# Response is wrapped: {"data": {"task_id": ...}, "code": 200, ...}
local job_id=$(echo "$response" | jq -r '.data.task_id // .task_id // empty')
rm -f "$temp_payload"
[ -z "$job_id" ] && { echo -e "${RED}Error: Failed to create job${NC}"; echo "$response"; exit 1; }
local job_id=$(echo "$response" | jq -r '.data.task_id // .task_id // empty')
[ -z "$job_id" ] && { echo -e "${RED}Error: Failed to create job${NC}"; echo "$response"; exit 1; }
if [ "$no_wait" = true ]; then
echo "Job ID: $job_id"
echo "Use '$0 status $job_id' to check progress and download"
else
wait_for_job "$api_url" "$job_id"
if [ "$no_wait" = true ]; then
echo "Job ID: $job_id"
echo "Use '$0 status $job_id' to check progress and download"
else
wait_for_job "$api_url" "$job_id"
fi
fi
}
@ -715,42 +979,67 @@ cmd_random() {
# Normalize boolean for jq --argjson
thinking=$(normalize_bool "$thinking" "true")
local api_mode=$(load_api_mode)
echo "Generating random music..."
echo " Thinking: $thinking"
echo " API: $api_mode"
echo " Output: $OUTPUT_DIR"
echo ""
local payload=$(jq -n --argjson thinking "$thinking" '{sample_mode: true, thinking: $thinking}')
if [ "$api_mode" = "completion" ]; then
# --- Completion mode ---
local model_id=$(get_completion_model "$api_url" "")
local def_audio_format=$(get_config "generation.audio_format")
# Write payload to temp file
local temp_payload=$(mktemp)
printf '%s' "$payload" > "$temp_payload"
local payload_c=$(jq -n \
--arg model "$model_id" \
--argjson thinking "$thinking" \
--arg format "${def_audio_format:-mp3}" \
'{
model: $model,
messages: [{"role": "user", "content": "Generate a random song"}],
stream: false,
sample_mode: true,
thinking: $thinking,
audio_config: { format: $format }
}')
local api_key=$(load_api_key)
local response
if [ -n "$api_key" ]; then
response=$(curl -s -X POST "${api_url}/release_task" \
-H "Content-Type: application/json; charset=utf-8" \
-H "Authorization: Bearer ${api_key}" \
--data-binary "@${temp_payload}")
local temp_payload=$(mktemp)
printf '%s' "$payload_c" > "$temp_payload"
send_completion_request "$api_url" "$temp_payload"
else
response=$(curl -s -X POST "${api_url}/release_task" \
-H "Content-Type: application/json; charset=utf-8" \
--data-binary "@${temp_payload}")
fi
# --- Native mode ---
local payload=$(jq -n --argjson thinking "$thinking" '{sample_mode: true, thinking: $thinking}')
rm -f "$temp_payload"
local temp_payload=$(mktemp)
printf '%s' "$payload" > "$temp_payload"
# Response is wrapped: {"data": {"task_id": ...}, "code": 200, ...}
local job_id=$(echo "$response" | jq -r '.data.task_id // .task_id // empty')
local api_key=$(load_api_key)
local response
if [ -n "$api_key" ]; then
response=$(curl -s -X POST "${api_url}/release_task" \
-H "Content-Type: application/json; charset=utf-8" \
-H "Authorization: Bearer ${api_key}" \
--data-binary "@${temp_payload}")
else
response=$(curl -s -X POST "${api_url}/release_task" \
-H "Content-Type: application/json; charset=utf-8" \
--data-binary "@${temp_payload}")
fi
[ -z "$job_id" ] && { echo -e "${RED}Error: Failed to create job${NC}"; echo "$response"; exit 1; }
rm -f "$temp_payload"
if [ "$no_wait" = true ]; then
echo "Job ID: $job_id"
echo "Use '$0 status $job_id' to check progress and download"
else
wait_for_job "$api_url" "$job_id"
local job_id=$(echo "$response" | jq -r '.data.task_id // .task_id // empty')
[ -z "$job_id" ] && { echo -e "${RED}Error: Failed to create job${NC}"; echo "$response"; exit 1; }
if [ "$no_wait" = true ]; then
echo "Job ID: $job_id"
echo "Use '$0 status $job_id' to check progress and download"
else
wait_for_job "$api_url" "$job_id"
fi
fi
}
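For readers who prefer Python over shell, the completion-mode response handling above can be sketched as follows. This is an illustration only, assuming the response shape the script reads with `jq` (`choices[0].message.audio[i].audio_url.url` holding a base64 data URL); the function names and the sample response are hypothetical.

```python
import base64


def decode_audio_data_url(data_url: str) -> bytes:
    """Strip a `data:<mime>;base64,` prefix if present and decode the payload."""
    marker = ";base64,"
    if data_url.startswith("data:") and marker in data_url:
        data_url = data_url.split(marker, 1)[1]
    return base64.b64decode(data_url)


def extract_audio_clips(response: dict) -> list[bytes]:
    """Collect decoded audio from choices[0].message.audio[*].audio_url.url."""
    choices = response.get("choices") or [{}]
    message = choices[0].get("message") or {}
    clips = []
    for item in message.get("audio") or []:
        url = (item.get("audio_url") or {}).get("url")
        if url:
            clips.append(decode_audio_data_url(url))
    return clips


# Fabricated stand-in response for illustration; real payloads carry encoded audio.
fake_audio = base64.b64encode(b"fake-bytes").decode()
sample = {
    "id": "example",
    "choices": [
        {
            "finish_reason": "stop",
            "message": {
                "content": "metadata and lyrics text",
                "audio": [
                    {"audio_url": {"url": "data:audio/mpeg;base64," + fake_audio}}
                ],
            },
        }
    ],
}

print(len(extract_audio_clips(sample)))
```

Treating a missing or null `audio` field as an empty list mirrors the script's "No audio files in response" branch.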


@ -1,6 +1,7 @@
{
"api_url": "http://127.0.0.1:8001",
"api_url": "https://api.acemusic.ai",
"api_key": "",
"api_mode": "completion",
"generation": {
"thinking": true,
"use_format": false,

67
.github/copilot-instructions.md vendored Normal file

@ -0,0 +1,67 @@
# ACE-Step 1.5 - GitHub Copilot Instructions
## Project Overview
ACE-Step 1.5 is an open-source music foundation model combining a Language Model (LM) as a planner with a Diffusion Transformer (DiT) for audio synthesis. It generates commercial-grade music on consumer hardware (< 4GB VRAM).
## Tech Stack
- **Python 3.11** (strict requirement)
- **PyTorch 2.7+** with CUDA 12.8 (Windows/Linux), MPS (macOS ARM64)
- **Transformers 4.51.0-4.57.x** for LLM inference
- **Diffusers** for diffusion models
- **Gradio 6.2.0** for web UI
- **FastAPI + Uvicorn** for REST API server
- **uv** for dependency management
- **MLX** (Apple Silicon native acceleration, macOS ARM64)
- **nano-vllm** (optimized LLM inference, non-macOS ARM64)
## Multi-Platform Support
**CRITICAL**: Supports CUDA, ROCm, Intel XPU, MPS, MLX, and CPU. When fixing bugs or adding features:
- **DO NOT alter non-target platform paths** unless explicitly required
- Changes to CUDA code should not affect MPS/XPU/CPU paths
- Use `gpu_config.py` for hardware detection and configuration
## Code Organization
### Main Entry Points
- `acestep/acestep_v15_pipeline.py` - Gradio UI pipeline
- `acestep/api_server.py` - REST API server
- `cli.py` - Command-line interface
- `acestep/model_downloader.py` - Model downloader
### Core Modules
- `acestep/handler.py` - Audio generation handler (AceStepHandler)
- `acestep/llm_inference.py` - LLM handler for text processing
- `acestep/inference.py` - Generation logic and parameters
- `acestep/gpu_config.py` - Hardware detection and GPU configuration
- `acestep/audio_utils.py` - Audio processing utilities
- `acestep/constants.py` - Global constants
### UI & Internationalization
- `acestep/gradio_ui/` - Gradio interface components
- `acestep/gradio_ui/i18n.py` - i18n system (50+ languages)
- All user-facing strings must use i18n translation keys
### Training
- `acestep/training/` - LoRA training pipeline
- `acestep/dataset/` - Dataset handling
## Key Conventions
- **Python style**: PEP 8, 4 spaces, double quotes for strings
- **Naming**: `snake_case` functions/variables, `PascalCase` classes, `UPPER_SNAKE_CASE` constants
- **Logging**: Use `loguru` logger (not `print()` except CLI output)
- **Dependencies**: Use `uv add <package>` to add to `pyproject.toml`
## Performance
- Target: 4GB VRAM - minimize memory allocations
- Lazy load models when needed
- Batch operations supported (up to 8 songs)
## Additional Resources
- **AGENTS.md**: Detailed guidance for AI coding agents
- **CONTRIBUTING.md**: Contribution workflow and guidelines

99
.github/workflows/codeql.yml vendored Normal file

@ -0,0 +1,99 @@
# For most projects, this workflow file will not need changing; you simply need
# to commit it to your repository.
#
# You may wish to alter this file to override the set of languages analyzed,
# or to provide custom queries or build logic.
#
# ******** NOTE ********
# We have attempted to detect the languages in your repository. Please check
# the `language` matrix defined below to confirm you have the correct set of
# supported CodeQL languages.
#
name: "CodeQL Advanced"
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
schedule:
- cron: '26 2 * * 5'
jobs:
analyze:
name: Analyze (${{ matrix.language }})
# Runner size impacts CodeQL analysis time. To learn more, please see:
# - https://gh.io/recommended-hardware-resources-for-running-codeql
# - https://gh.io/supported-runners-and-hardware-resources
# - https://gh.io/using-larger-runners (GitHub.com only)
# Consider using larger runners or machines with greater resources for possible analysis time improvements.
runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }}
permissions:
# required for all workflows
security-events: write
# required to fetch internal or private CodeQL packs
packages: read
# only required for workflows in private repositories
actions: read
contents: read
strategy:
fail-fast: false
matrix:
include:
- language: python
build-mode: none
# CodeQL supports the following values for 'language': 'actions', 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'rust', 'swift'
# Use `c-cpp` to analyze code written in C, C++ or both
# Use 'java-kotlin' to analyze code written in Java, Kotlin or both
# Use 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both
# To learn more about changing the languages that are analyzed or customizing the build mode for your analysis,
# see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/customizing-your-advanced-setup-for-code-scanning.
# If you are analyzing a compiled language, you can modify the 'build-mode' for that language to customize how
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
steps:
- name: Checkout repository
uses: actions/checkout@v4
# Add any setup steps before running the `github/codeql-action/init` action.
# This includes steps like installing compilers or runtimes (`actions/setup-node`
# or others). This is typically only required for manual builds.
# - name: Setup runtime (example)
# uses: actions/setup-example@v1
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v4
with:
languages: ${{ matrix.language }}
build-mode: ${{ matrix.build-mode }}
# If you wish to specify custom queries, you can do so here or in a config file.
# By default, queries listed here will override any specified in a config file.
# Prefix the list here with "+" to use these queries and those in the config file.
# For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
# queries: security-extended,security-and-quality
# If the analyze step fails for one of the languages you are analyzing with
# "We were unable to automatically build your code", modify the matrix above
# to set the build mode to "manual" for that language. Then modify this step
# to build your code.
# Command-line programs to run using the OS shell.
# 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
- name: Run manual build steps
if: matrix.build-mode == 'manual'
shell: bash
run: |
echo 'If you are using a "manual" build mode for one or more of the' \
'languages you are analyzing, replace this with the commands to build' \
'your code, for example:'
echo ' make bootstrap'
echo ' make release'
exit 1
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v4
with:
category: "/language:${{matrix.language}}"

96
AGENTS.md Normal file

@ -0,0 +1,96 @@
# AGENTS.md
Guidance for AI coding agents working in `ace-step/ACE-Step-1.5`.
This document is aligned with the intent from:
- Discussion #408: functional decomposition to reduce risk from large mixed-responsibility files.
- Discussion #365: low-risk contribution workflow, minimal scope, and review rigor.
## Primary Objectives
1. Keep changes safe and reviewable.
2. Prefer small, maintainable, decomposed modules.
3. Preserve behavior outside the target fix.
4. Validate with focused Python unit tests.
## Scope and Change Control (Required)
- Solve one problem per task/PR.
- Keep edits minimal: touch only files/functions required for the requested change.
- Do not make drive-by refactors, formatting sweeps, or opportunistic cleanups.
- Do not alter non-target hardware/runtime paths (CPU/CUDA/MPS/XPU) unless required by the task.
- If any cross-path change is necessary, isolate it and justify it in the PR notes.
- Preserve existing public interfaces unless the task explicitly requires an interface change.
## Decomposition and Module Size Policy
- Prefer single-responsibility modules with clear boundaries.
- Target module size:
- Optimal: `<= 150` LOC (@coderabbit: exceeding this limit is a nitpick, not a hard cap or a major issue)
- Hard cap: `200` LOC
- Function decomposition rules:
- Do one thing at a time; if a function description naturally contains "and", split it.
- Split by responsibility, not by convenience.
- Keep data flow explicit (`data in, data out`); side effects must be obvious and deliberate.
- Push decisions up and push work down (orchestration at higher layers, execution details in lower layers).
- The call graph should read clearly from top-level orchestration to leaf operations.
- If a module would exceed `200` LOC:
- Split by responsibility before merging, or
- Add a short justification in PR notes and include a concrete follow-up split plan.
- Keep orchestrator/facade modules thin. Move logic into focused helpers/services.
- Preserve stable facade imports when splitting large files so external callers are not broken.
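The size policy above could be checked mechanically. The sketch below is one possible approach, not a tool that exists in the repository: the names are hypothetical, and since this document does not define how LOC is counted, the sketch assumes non-blank lines that are not pure `#` comments.

```python
from pathlib import Path

OPTIMAL_LOC = 150   # target from the policy above
HARD_CAP_LOC = 200  # hard cap from the policy above


def count_loc(source: str) -> int:
    """Count non-blank lines that are not pure `#` comments (assumed LOC definition)."""
    return sum(
        1
        for line in source.splitlines()
        if line.strip() and not line.strip().startswith("#")
    )


def classify_module(loc: int) -> str:
    """Return "ok", "nit" (above target, within cap), or "over-cap" for a LOC count."""
    if loc <= OPTIMAL_LOC:
        return "ok"
    if loc <= HARD_CAP_LOC:
        return "nit"
    return "over-cap"


def report(root: str) -> dict[str, str]:
    """Map each .py file under root to its classification."""
    return {
        str(path): classify_module(count_loc(path.read_text(encoding="utf-8")))
        for path in sorted(Path(root).rglob("*.py"))
    }
```

Calling `report("acestep")` from the repository root would list each module with its classification; modules reported as `over-cap` would need a split or a justification in the PR notes.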
## Python Unit Testing Expectations
- Add or update tests for every behavior change and bug fix.
- Match repository conventions:
- Use `unittest`-style tests.
- Name test files as `*_test.py` or `test_*.py`.
- Keep tests deterministic, fast, and scoped to changed behavior.
- Use mocks/fakes for GPU, filesystem, network, and external services where possible.
- If a change requires mocking a large portion of the system to test one unit, treat that as a decomposition smell and refactor boundaries.
- Include at least:
- One success-path test.
- One regression/edge-case test for the bug being fixed.
- One non-target behavior check when relevant.
- Run targeted tests locally before submitting.
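A test module that follows these expectations might look like the sketch below. The unit under test, `parse_bool_flag`, is a hypothetical helper modeled on the `x.lower() in ['true', '1', 'yes']` parsing used for the pipeline's command-line flags; it is not an existing function in the repository.

```python
import unittest


def parse_bool_flag(text: str) -> bool:
    """Return True for "true", "1", or "yes" (case-insensitive), else False."""
    return text.lower() in ("true", "1", "yes")


class ParseBoolFlagTest(unittest.TestCase):
    """Covers the success path, an edge case, and a non-target check."""

    def test_true_string_is_accepted(self):
        """Success path: the canonical spelling parses as True."""
        self.assertTrue(parse_bool_flag("true"))

    def test_mixed_case_yes_is_accepted(self):
        """Edge case: casing must not change the result."""
        self.assertTrue(parse_bool_flag("YES"))

    def test_other_values_stay_false(self):
        """Non-target check: unrelated inputs keep their existing behavior."""
        self.assertFalse(parse_bool_flag("0"))
        self.assertFalse(parse_bool_flag(""))
```

Saved as, for example, `parse_bool_flag_test.py`, the module can be run with `python -m unittest parse_bool_flag_test`.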
## Feature Gating and WIP Safety
- Do not expose unfinished or non-functional user-facing flows by default.
- Gate WIP or unstable UI/API paths behind explicit feature/release flags.
- Keep default behavior stable; "coming soon" paths must not appear as usable functionality unless they are operational and tested.
## Python Coding Best Practices
- Use explicit, readable code over clever shortcuts.
- Docstrings are mandatory for all new or modified Python modules, classes, and functions.
- Docstrings must be concise and include purpose plus key inputs/outputs (and raised exceptions when relevant).
- Add type hints for new/modified functions when practical.
- Keep functions focused and short; extract helpers instead of nesting complexity.
- Use clear names that describe behavior, not implementation trivia.
- Prefer pure functions for logic-heavy paths where possible.
- Avoid duplicated logic, but do not introduce broad abstractions too early; prefer simple local duplication over unstable premature abstraction.
- Handle errors explicitly; avoid bare `except`.
- Keep logging actionable; avoid noisy logs and `print` debugging in committed code.
- Avoid hidden state and unintended side effects.
- Write comments only where intent is non-obvious; keep comments concise and technical.
## AI-Agent Workflow (Recommended)
1. Understand the task and define explicit in-scope/out-of-scope boundaries.
2. Propose a minimal patch plan before editing.
3. Implement the smallest viable change.
4. Add/update focused tests.
5. Self-review only changed hunks for regressions and scope creep.
6. Summarize risk, validation, and non-target impact in PR notes.
## PR Readiness Checklist
- [ ] Change is tightly scoped to one problem.
- [ ] Non-target paths are unchanged, or changes are explicitly justified.
- [ ] New/updated tests cover changed behavior and edge cases.
- [ ] No unrelated refactor/formatting churn.
- [ ] Required docstrings are present for all new/modified modules, classes, and functions.
- [ ] WIP/unstable functionality is feature-flagged and not exposed as default-ready behavior.
- [ ] Module LOC policy is met (`<=150` target, `<=200` hard cap or justified exception).


@ -36,7 +36,7 @@ try:
from .llm_inference import LLMHandler
from .dataset_handler import DatasetHandler
from .gradio_ui import create_gradio_interface
from .gpu_config import get_gpu_config, get_gpu_memory_gb, print_gpu_config_info, set_global_gpu_config, VRAM_16GB_MIN_GB, VRAM_AUTO_OFFLOAD_THRESHOLD_GB
from .gpu_config import get_gpu_config, get_gpu_memory_gb, print_gpu_config_info, set_global_gpu_config, VRAM_16GB_MIN_GB, VRAM_AUTO_OFFLOAD_THRESHOLD_GB, is_mps_platform
from .model_downloader import ensure_lm_model
except ImportError:
# When executed as a script: `python acestep/acestep_v15_pipeline.py`
@ -47,7 +47,7 @@ except ImportError:
from acestep.llm_inference import LLMHandler
from acestep.dataset_handler import DatasetHandler
from acestep.gradio_ui import create_gradio_interface
from acestep.gpu_config import get_gpu_config, get_gpu_memory_gb, print_gpu_config_info, set_global_gpu_config, VRAM_16GB_MIN_GB, VRAM_AUTO_OFFLOAD_THRESHOLD_GB
from acestep.gpu_config import get_gpu_config, get_gpu_memory_gb, print_gpu_config_info, set_global_gpu_config, VRAM_16GB_MIN_GB, VRAM_AUTO_OFFLOAD_THRESHOLD_GB, is_mps_platform
from acestep.model_downloader import ensure_lm_model
@ -93,11 +93,14 @@ def main():
set_global_gpu_config(gpu_config) # Set global config for use across modules
gpu_memory_gb = gpu_config.gpu_memory_gb
_is_mac = is_mps_platform()
# Enable auto-offload for GPUs below 20 GB. 16 GB GPUs cannot hold all
# models simultaneously (DiT ~4.7 + VAE ~0.3 + text_enc ~1.2 + LM ≥1.2 +
# activations) so they *must* offload. The old threshold of 16 GB caused
# 16 GB GPUs to never offload, leading to OOM.
auto_offload = gpu_memory_gb > 0 and gpu_memory_gb < VRAM_AUTO_OFFLOAD_THRESHOLD_GB
# Mac (Apple Silicon) uses unified memory — offloading provides no benefit.
auto_offload = (not _is_mac) and gpu_memory_gb > 0 and gpu_memory_gb < VRAM_AUTO_OFFLOAD_THRESHOLD_GB
_default_backend = "mlx" if _is_mac else "vllm"
# Print GPU configuration info
print(f"\n{'='*60}")
@ -113,7 +116,9 @@ def main():
print(f" Available LM Models: {gpu_config.available_lm_models or 'None'}")
print(f"{'='*60}\n")
if auto_offload:
if _is_mac:
print(f"Apple Silicon (MPS) detected — unified memory {gpu_memory_gb:.1f}GB, no CPU offload needed, backend={_default_backend}")
elif auto_offload:
print(f"Auto-enabling CPU offload (GPU {gpu_memory_gb:.1f}GB < {VRAM_AUTO_OFFLOAD_THRESHOLD_GB}GB threshold)")
elif gpu_memory_gb > 0:
print(f"CPU offload disabled by default (GPU {gpu_memory_gb:.1f}GB >= {VRAM_AUTO_OFFLOAD_THRESHOLD_GB}GB threshold)")
@ -152,7 +157,7 @@ def main():
parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "mps", "xpu", "cpu"], help="Processing device (default: auto)")
parser.add_argument("--init_llm", type=lambda x: x.lower() in ['true', '1', 'yes'], default=None, help="Initialize 5Hz LM (default: auto based on GPU memory)")
parser.add_argument("--lm_model_path", type=str, default=None, help="5Hz LM model path (e.g., 'acestep-5Hz-lm-0.6B')")
parser.add_argument("--backend", type=str, default="vllm", choices=["vllm", "pt", "mlx"], help="5Hz LM backend (default: vllm, use 'mlx' for native Apple Silicon acceleration)")
parser.add_argument("--backend", type=str, default=_default_backend, choices=["vllm", "pt", "mlx"], help=f"5Hz LM backend (default: {_default_backend}, use 'mlx' for native Apple Silicon acceleration)")
parser.add_argument("--use_flash_attention", type=lambda x: x.lower() in ['true', '1', 'yes'], default=None, help="Use flash attention (default: auto-detect)")
parser.add_argument("--offload_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=auto_offload, help=f"Offload models to CPU (default: {'True' if auto_offload else 'False'}, auto-detected based on GPU VRAM)")
parser.add_argument("--offload_dit_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Offload DiT to CPU (default: False)")
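The default-selection logic changed in this hunk can be summarized in isolation. The following is a minimal sketch, not the pipeline's code: the function name is hypothetical, and the 20 GB default is taken from the code comment rather than from the actual value of `VRAM_AUTO_OFFLOAD_THRESHOLD_GB`.

```python
def choose_runtime_defaults(
    gpu_memory_gb: float,
    is_mac: bool,
    offload_threshold_gb: float = 20.0,  # assumed from the "below 20 GB" comment
) -> tuple[bool, str]:
    """Return (auto_offload, default_backend).

    Apple Silicon never auto-offloads because memory is unified, and it
    defaults to the "mlx" backend. Other platforms offload when some GPU
    memory is detected but it is below the threshold, and default to "vllm".
    """
    auto_offload = (not is_mac) and 0 < gpu_memory_gb < offload_threshold_gb
    default_backend = "mlx" if is_mac else "vllm"
    return auto_offload, default_backend


for memory_gb, mac in [(16.0, False), (24.0, False), (16.0, True)]:
    print(memory_gb, mac, choose_runtime_defaults(memory_gb, mac))
```

Both values are only defaults: `--offload_to_cpu` and `--backend` can still override them on the command line.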


@ -41,6 +41,7 @@ except ImportError: # Optional dependency
load_dotenv = None # type: ignore
from fastapi import FastAPI, HTTPException, Request, Depends, Header
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from starlette.datastructures import UploadFile as StarletteUploadFile
@ -374,6 +375,8 @@ PARAM_ALIASES = {
"use_cot_language": ["use_cot_language", "cot_language", "cot-language"],
"is_format_caption": ["is_format_caption", "isFormatCaption"],
"allow_lm_batch": ["allow_lm_batch", "allowLmBatch", "parallel_thinking"],
"track_name": ["track_name", "trackName"],
"track_classes": ["track_classes", "trackClasses", "instruments"],
}
@ -519,6 +522,8 @@ class GenerateMusicRequest(BaseModel):
use_cot_language: bool = True
is_format_caption: bool = False
allow_lm_batch: bool = True
track_name: Optional[str] = None
track_classes: Optional[List[str]] = None
lm_temperature: float = 0.85
lm_cfg_scale: float = 2.5
@ -893,13 +898,33 @@ class RequestParser:
def _validate_audio_path(path: Optional[str]) -> Optional[str]:
"""Validate a user-supplied audio file path to prevent path traversal attacks.
Rejects absolute paths and paths containing '..' traversal sequences.
Returns the validated path or None if the input is None/empty.
Accepts absolute paths only if they resolve to a location inside the system temporary directory.
Otherwise, rejects absolute paths and paths containing '..' traversal sequences.
Returns the validated, normalized path or None if the input is None/empty.
Raises HTTPException 400 if the path is unsafe.
"""
if not path:
return None
# Reject absolute paths (Unix and Windows)
# Resolve requested path and system temp path to normalized absolute forms
import tempfile
system_temp = os.path.realpath(tempfile.gettempdir())
requested_path = os.path.realpath(path)
# SECURE CHECK: Use os.path.commonpath to verify directory boundary integrity.
# This prevents prefix bypasses (e.g., /tmp_evil when /tmp is allowed).
try:
is_in_temp = os.path.commonpath([system_temp, requested_path]) == system_temp
except ValueError:
# Occurs on Windows if paths are on different drives
is_in_temp = False
if is_in_temp:
# Accept server-generated files in temp
return requested_path
# Reject manual absolute paths outside of temp
if os.path.isabs(path):
raise HTTPException(status_code=400, detail="absolute audio file paths are not allowed")
# Reject path traversal via '..' components
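
The directory-boundary check in the hunk above can be tried in isolation. The sketch below assumes a standalone helper named `is_within_directory`, which is not part of the commit. It shows why comparing `os.path.commonpath` against the base directory rejects sibling directories that merely share a string prefix.

```python
import os
import tempfile


def is_within_directory(base_dir: str, candidate: str) -> bool:
    """Return True if candidate resolves to base_dir or a path beneath it."""
    base = os.path.realpath(base_dir)
    target = os.path.realpath(candidate)
    try:
        # commonpath compares whole path components, not raw string prefixes.
        return os.path.commonpath([base, target]) == base
    except ValueError:
        # Raised on Windows when the two paths are on different drives.
        return False


temp_root = tempfile.gettempdir()
inside = os.path.join(temp_root, "upload.wav")
sibling = temp_root.rstrip(os.sep) + "_evil" + os.sep + "upload.wav"
escaped = os.path.join(temp_root, "..", "etc", "passwd")

print(is_within_directory(temp_root, inside))
print(is_within_directory(temp_root, sibling))
print(is_within_directory(temp_root, escaped))
```

A plain `target.startswith(base)` test would accept the sibling path, which is the prefix bypass that the comment in the diff refers to.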
@ -1492,7 +1517,25 @@ def create_app() -> FastAPI:
# This matches gradio behavior which uses TASK_INSTRUCTIONS for each task type
instruction_to_use = req.instruction
if instruction_to_use == DEFAULT_DIT_INSTRUCTION and req.task_type in TASK_INSTRUCTIONS:
instruction_to_use = TASK_INSTRUCTIONS[req.task_type]
raw_instruction = TASK_INSTRUCTIONS[req.task_type]
if req.task_type == "complete":
# Use track_classes joined by pipes
if req.track_classes:
# Join list items: ["Drums", "Bass"] -> "DRUMS | BASS"
classes_str = " | ".join([str(t).upper() for t in req.track_classes])
# Use the raw instruction template from constants
# Format: "Complete the track with {TRACK_CLASSES}:"
instruction_to_use = raw_instruction.format(TRACK_CLASSES=classes_str)
else:
# Fallback if no classes provided
instruction_to_use = TASK_INSTRUCTIONS.get("complete_default", raw_instruction)
elif "{TRACK_NAME}" in raw_instruction and req.track_name:
# Logic for extract/lego
instruction_to_use = raw_instruction.format(TRACK_NAME=req.track_name.upper())
else:
instruction_to_use = raw_instruction
# Build GenerationParams using unified interface
# Note: thinking controls LM code generation, sample_mode only affects CoT metas
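
The instruction selection above fills a template placeholder depending on the task type. The sketch below reproduces that branching as a standalone function. The `complete` template is taken from the comment in the diff; the `extract` template and the function name `build_instruction` are hypothetical, since the real strings live in the project's constants module.

```python
from typing import List, Optional

# Illustrative templates only. The "extract" string is an assumption.
TASK_INSTRUCTIONS = {
    "complete": "Complete the track with {TRACK_CLASSES}:",
    "extract": "Extract the {TRACK_NAME} track:",
}


def build_instruction(
    task_type: str,
    track_name: Optional[str] = None,
    track_classes: Optional[List[str]] = None,
) -> str:
    """Hypothetical helper mirroring the branching shown in the hunk."""
    raw = TASK_INSTRUCTIONS[task_type]
    if task_type == "complete":
        if track_classes:
            # ["Drums", "Bass"] -> "DRUMS | BASS"
            joined = " | ".join(str(t).upper() for t in track_classes)
            return raw.format(TRACK_CLASSES=joined)
        # No classes given: use a dedicated default if one is defined.
        return TASK_INSTRUCTIONS.get("complete_default", raw)
    if "{TRACK_NAME}" in raw and track_name:
        return raw.format(TRACK_NAME=track_name.upper())
    return raw


print(build_instruction("complete", track_classes=["Drums", "Bass"]))
print(build_instruction("extract", track_name="vocals"))
```

When no `complete_default` entry exists and no classes are supplied, the raw template is returned with its placeholder unfilled, so callers may want to define that default.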
@ -1542,11 +1585,30 @@ def create_app() -> FastAPI:
# Build GenerationConfig - default to 2 audios like gradio_ui
batch_size = req.batch_size if req.batch_size is not None else 2
# Resolve seed(s) from req.seed into List[int] for GenerationConfig.seeds
resolved_seeds = None
if not req.use_random_seed and req.seed is not None:
if isinstance(req.seed, int):
if req.seed >= 0:
resolved_seeds = [req.seed]
elif isinstance(req.seed, str):
resolved_seeds = []
for s in req.seed.split(","):
s = s.strip()
if s and s != "-1":
try:
resolved_seeds.append(int(float(s)))
except (ValueError, TypeError):
pass
if not resolved_seeds:
resolved_seeds = None
config = GenerationConfig(
batch_size=batch_size,
allow_lm_batch=req.allow_lm_batch,
use_random_seed=req.use_random_seed,
seeds=None, # Let unified logic handle seed generation
seeds=resolved_seeds,
audio_format=req.audio_format,
constrained_decoding_debug=req.constrained_decoding_debug,
)
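
The seed resolution above accepts either a single integer or a comma-separated string. A minimal standalone sketch follows, assuming a helper named `resolve_seeds` that is not part of the commit. Unlike the hunk, it also catches `OverflowError`, which `int(float("inf"))` raises; that extra guard is an addition for illustration.

```python
from typing import List, Optional, Union


def resolve_seeds(seed: Union[int, str, None], use_random_seed: bool) -> Optional[List[int]]:
    """Return a list of explicit seeds, or None to let the caller pick random ones."""
    if use_random_seed or seed is None:
        return None
    if isinstance(seed, int):
        # A negative integer is treated as "no explicit seed".
        return [seed] if seed >= 0 else None
    seeds: List[int] = []
    if isinstance(seed, str):
        for part in seed.split(","):
            part = part.strip()
            if not part or part == "-1":
                continue  # empty entries and the "-1" sentinel are skipped
            try:
                # Going through float accepts inputs such as "7.0".
                seeds.append(int(float(part)))
            except (ValueError, TypeError, OverflowError):
                continue  # unparsable entries are ignored
    return seeds or None


print(resolve_seeds("42, -1, 7.0, abc", use_random_seed=False))
```

Returning `None` rather than an empty list keeps a single meaning for "no explicit seeds", which matches the final check in the hunk.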
@ -2155,6 +2217,16 @@ def create_app() -> FastAPI:
app = FastAPI(title="ACE-Step API", version="1.0", lifespan=lifespan)
# Enable CORS for browser-based frontends (e.g. studio.html opened via file://)
# Restricted to localhost origins and the "null" origin (file:// protocol)
app.add_middleware(
CORSMiddleware,
allow_origins=["null", "http://localhost", "http://127.0.0.1"],
allow_origin_regex=r"^https?://(localhost|127\.0\.0\.1)(:\d+)?$",
allow_methods=["GET", "POST", "OPTIONS"],
allow_headers=["Content-Type", "Authorization"],
)
# Mount OpenRouter-compatible endpoints (/v1/chat/completions, /v1/models)
from acestep.openrouter_adapter import create_openrouter_router
openrouter_router = create_openrouter_router(lambda: app.state)
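
To see which origins the `allow_origin_regex` pattern above admits, it can be exercised with the standard `re` module. The sketch assumes the middleware applies the pattern as a full match against the `Origin` header; `is_local_origin` is a hypothetical helper for illustration.

```python
import re

# Same pattern as in the hunk: http or https, localhost or 127.0.0.1, optional port.
ORIGIN_PATTERN = re.compile(r"^https?://(localhost|127\.0\.0\.1)(:\d+)?$")


def is_local_origin(origin: str) -> bool:
    """Return True if the origin string matches the localhost-only pattern."""
    return ORIGIN_PATTERN.fullmatch(origin) is not None


for candidate in (
    "http://localhost:8000",
    "https://127.0.0.1",
    "http://localhost.evil.com",
    "http://127x0x0x1",
):
    print(candidate, is_local_origin(candidate))
```

The escaped dots matter: without them, `127x0x0x1` would match, because an unescaped `.` matches any character.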
@ -2185,6 +2257,11 @@ def create_app() -> FastAPI:
# when callers (multipart/form, url-encoded, raw body) pass them explicitly.
ref_audio = kwargs.pop("reference_audio_path", None) or p.str("reference_audio_path") or None
src_audio = kwargs.pop("src_audio_path", None) or p.str("src_audio_path") or None
t_classes = p.get("track_classes")
if t_classes is not None and isinstance(t_classes, str):
t_classes = [t_classes]
return GenerateMusicRequest(
prompt=p.str("prompt"),
lyrics=p.str("lyrics"),
@ -2209,8 +2286,8 @@ def create_app() -> FastAPI:
repainting_end=p.float("repainting_end"),
instruction=p.str("instruction", DEFAULT_DIT_INSTRUCTION),
audio_cover_strength=p.float("audio_cover_strength", 1.0),
reference_audio_path=_validate_audio_path(ref_audio),
src_audio_path=_validate_audio_path(src_audio),
reference_audio_path=ref_audio,
src_audio_path=src_audio,
task_type=p.str("task_type", "text2music"),
use_adg=p.bool("use_adg"),
cfg_interval_start=p.float("cfg_interval_start", 0.0),
@ -2233,6 +2310,8 @@ def create_app() -> FastAPI:
use_cot_language=p.bool("use_cot_language", True),
is_format_caption=p.bool("is_format_caption"),
allow_lm_batch=p.bool("allow_lm_batch", True),
track_name=p.str("track_name"),
track_classes=t_classes,
**kwargs,
)
@ -2241,18 +2320,40 @@ def create_app() -> FastAPI:
if not isinstance(body, dict):
raise HTTPException(status_code=400, detail="JSON payload must be an object")
verify_token_from_request(body, authorization)
req = _build_request(RequestParser(body))
# Explicitly validate manual string paths from JSON input
p = RequestParser(body)
req = _build_request(
p,
reference_audio_path=_validate_audio_path(p.str("reference_audio_path") or None),
src_audio_path=_validate_audio_path(p.str("src_audio_path") or None)
)
elif content_type.endswith("+json"):
body = await request.json()
if not isinstance(body, dict):
raise HTTPException(status_code=400, detail="JSON payload must be an object")
verify_token_from_request(body, authorization)
req = _build_request(RequestParser(body))
p = RequestParser(body)
req = _build_request(
p,
reference_audio_path=_validate_audio_path(p.str("reference_audio_path") or None),
src_audio_path=_validate_audio_path(p.str("src_audio_path") or None)
)
elif content_type.startswith("multipart/form-data"):
form = await request.form()
form_dict = {k: v for k, v in form.items() if not hasattr(v, 'read')}
# Parse form data so that repeated keys are kept as lists
form_dict = {}
for k in form.keys():
vals = [v for v in form.getlist(k) if not hasattr(v, 'read')]
if len(vals) == 1:
form_dict[k] = vals[0]
elif len(vals) > 1:
form_dict[k] = vals
verify_token_from_request(form_dict, authorization)
# Support both naming conventions: ref_audio/reference_audio, ctx_audio/src_audio
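
The multipart handling above collapses repeated form keys into lists while skipping file uploads. The sketch below shows the same idea on a plain list of key/value pairs. `collapse_form_items` is a hypothetical name, and the input is a stand-in for the framework's form object.

```python
import io
from typing import Any, Dict, Iterable, List, Tuple


def collapse_form_items(items: Iterable[Tuple[str, Any]]) -> Dict[str, Any]:
    """Group values by key; single values stay scalar, repeats become lists."""
    grouped: Dict[str, List[Any]] = {}
    for key, value in items:
        if hasattr(value, "read"):
            continue  # file-like uploads are handled separately
        grouped.setdefault(key, []).append(value)
    return {k: (v[0] if len(v) == 1 else v) for k, v in grouped.items()}


form_items = [
    ("task_type", "complete"),
    ("track_classes", "drums"),
    ("track_classes", "bass"),
    ("src_audio", io.BytesIO(b"fake audio bytes")),
]
print(collapse_form_items(form_items))
```

A key that appears once still yields a plain string, which is why the request builder in the diff wraps a lone `track_classes` string in a list.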
@ -2275,7 +2376,7 @@ def create_app() -> FastAPI:
src_audio_path = _validate_audio_path(str(form.get("ctx_audio_path") or form.get("src_audio_path") or "").strip() or None)
req = _build_request(
RequestParser(dict(form)),
RequestParser(dict(form_dict)),
reference_audio_path=reference_audio_path,
src_audio_path=src_audio_path,
)
@ -2301,7 +2402,12 @@ def create_app() -> FastAPI:
body = json.loads(raw.decode("utf-8"))
if isinstance(body, dict):
verify_token_from_request(body, authorization)
req = _build_request(RequestParser(body))
p = RequestParser(body)
req = _build_request(
p,
reference_audio_path=_validate_audio_path(p.str("reference_audio_path") or None),
src_audio_path=_validate_audio_path(p.str("src_audio_path") or None)
)
else:
raise HTTPException(status_code=400, detail="JSON payload must be an object")
except HTTPException:

View file

@ -1,354 +1,406 @@
"""
Audio saving and transcoding utility module
Independent audio file operations outside of handler, supporting:
- Save audio tensor/numpy to files (default FLAC format, fast)
- Format conversion (FLAC/WAV/MP3)
- Batch processing
"""
import os
import hashlib
import json
from pathlib import Path
from typing import Union, Optional, List, Tuple
import torch
import numpy as np
import torchaudio
from loguru import logger
class AudioSaver:
"""Audio saving and transcoding utility class"""
def __init__(self, default_format: str = "flac"):
"""
Initialize audio saver
Args:
default_format: Default save format ('flac', 'wav', 'mp3')
"""
self.default_format = default_format.lower()
if self.default_format not in ["flac", "wav", "mp3"]:
logger.warning(f"Unsupported format {default_format}, using 'flac'")
self.default_format = "flac"
def save_audio(
self,
audio_data: Union[torch.Tensor, np.ndarray],
output_path: Union[str, Path],
sample_rate: int = 48000,
format: Optional[str] = None,
channels_first: bool = True,
) -> str:
"""
Save audio data to file
Args:
audio_data: Audio data, torch.Tensor [channels, samples] or numpy.ndarray
output_path: Output file path (extension can be omitted)
sample_rate: Sample rate
format: Audio format ('flac', 'wav', 'mp3'), defaults to default_format
channels_first: If True, tensor format is [channels, samples], else [samples, channels]
Returns:
Actual saved file path
"""
format = (format or self.default_format).lower()
if format not in ["flac", "wav", "mp3"]:
logger.warning(f"Unsupported format {format}, using {self.default_format}")
format = self.default_format
# Ensure output path has correct extension
output_path = Path(output_path)
if output_path.suffix.lower() not in ['.flac', '.wav', '.mp3']:
output_path = output_path.with_suffix(f'.{format}')
# Convert to torch tensor
if isinstance(audio_data, np.ndarray):
if channels_first:
# numpy [samples, channels] -> tensor [channels, samples]
audio_tensor = torch.from_numpy(audio_data.T).float()
else:
# numpy [samples, channels] -> tensor [samples, channels] -> [channels, samples]
audio_tensor = torch.from_numpy(audio_data).float()
if audio_tensor.dim() == 2 and audio_tensor.shape[0] < audio_tensor.shape[1]:
audio_tensor = audio_tensor.T
else:
# torch tensor
audio_tensor = audio_data.cpu().float()
if not channels_first and audio_tensor.dim() == 2:
# [samples, channels] -> [channels, samples]
if audio_tensor.shape[0] > audio_tensor.shape[1]:
audio_tensor = audio_tensor.T
# Ensure memory is contiguous
audio_tensor = audio_tensor.contiguous()
# Select backend and save
try:
if format == "mp3":
# MP3 uses ffmpeg backend
torchaudio.save(
str(output_path),
audio_tensor,
sample_rate,
channels_first=True,
backend='ffmpeg',
)
elif format in ["flac", "wav"]:
# FLAC and WAV use soundfile backend (fastest)
torchaudio.save(
str(output_path),
audio_tensor,
sample_rate,
channels_first=True,
backend='soundfile',
)
else:
# Other formats use default backend
torchaudio.save(
str(output_path),
audio_tensor,
sample_rate,
channels_first=True,
)
logger.debug(f"[AudioSaver] Saved audio to {output_path} ({format}, {sample_rate}Hz)")
return str(output_path)
except Exception as e:
try:
import soundfile as sf
audio_np = audio_tensor.transpose(0, 1).numpy() # -> [samples, channels]
sf.write(str(output_path), audio_np, sample_rate, format=format.upper())
logger.debug(f"[AudioSaver] Fallback soundfile Saved audio to {output_path} ({format}, {sample_rate}Hz)")
return str(output_path)
except Exception as e:
logger.error(f"[AudioSaver] Failed to save audio: {e}")
raise
def convert_audio(
self,
input_path: Union[str, Path],
output_path: Union[str, Path],
output_format: str,
remove_input: bool = False,
) -> str:
"""
Convert audio format
Args:
input_path: Input audio file path
output_path: Output audio file path
output_format: Target format ('flac', 'wav', 'mp3')
remove_input: Whether to delete input file
Returns:
Output file path
"""
input_path = Path(input_path)
output_path = Path(output_path)
if not input_path.exists():
raise FileNotFoundError(f"Input file not found: {input_path}")
# Load audio
audio_tensor, sample_rate = torchaudio.load(str(input_path))
# Save as new format
output_path = self.save_audio(
audio_tensor,
output_path,
sample_rate=sample_rate,
format=output_format,
channels_first=True
)
# Delete input file if needed
if remove_input:
input_path.unlink()
logger.debug(f"[AudioSaver] Removed input file: {input_path}")
return output_path
def save_batch(
self,
audio_batch: Union[List[torch.Tensor], torch.Tensor],
output_dir: Union[str, Path],
file_prefix: str = "audio",
sample_rate: int = 48000,
format: Optional[str] = None,
channels_first: bool = True,
) -> List[str]:
"""
Save audio batch
Args:
audio_batch: Audio batch, List[tensor] or tensor [batch, channels, samples]
output_dir: Output directory
file_prefix: File prefix
sample_rate: Sample rate
format: Audio format
channels_first: Tensor format flag
Returns:
List of saved file paths
"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Process batch
if isinstance(audio_batch, torch.Tensor) and audio_batch.dim() == 3:
# [batch, channels, samples]
audio_list = [audio_batch[i] for i in range(audio_batch.shape[0])]
elif isinstance(audio_batch, list):
audio_list = audio_batch
else:
audio_list = [audio_batch]
saved_paths = []
for i, audio in enumerate(audio_list):
output_path = output_dir / f"{file_prefix}_{i:04d}"
saved_path = self.save_audio(
audio,
output_path,
sample_rate=sample_rate,
format=format,
channels_first=channels_first
)
saved_paths.append(saved_path)
return saved_paths
def get_audio_file_hash(audio_file) -> str:
"""
Get hash identifier for an audio file.
Args:
audio_file: Path to audio file (str) or file-like object
Returns:
Hash string or empty string
"""
if audio_file is None:
return ""
try:
if isinstance(audio_file, str):
if os.path.exists(audio_file):
with open(audio_file, 'rb') as f:
return hashlib.md5(f.read()).hexdigest()
return hashlib.md5(audio_file.encode('utf-8')).hexdigest()
elif hasattr(audio_file, 'name'):
return hashlib.md5(str(audio_file.name).encode('utf-8')).hexdigest()
return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
except Exception:
return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
def generate_uuid_from_params(params_dict) -> str:
"""
Generate deterministic UUID from generation parameters.
Same parameters will always generate the same UUID.
Args:
params_dict: Dictionary of parameters
Returns:
UUID string
"""
params_json = json.dumps(params_dict, sort_keys=True, ensure_ascii=False)
hash_obj = hashlib.sha256(params_json.encode('utf-8'))
hash_hex = hash_obj.hexdigest()
uuid_str = f"{hash_hex[0:8]}-{hash_hex[8:12]}-{hash_hex[12:16]}-{hash_hex[16:20]}-{hash_hex[20:32]}"
return uuid_str
def generate_uuid_from_audio_data(
audio_data: Union[torch.Tensor, np.ndarray],
seed: Optional[int] = None
) -> str:
"""
Generate UUID from audio data (for caching/deduplication)
Args:
audio_data: Audio data
seed: Optional seed value
Returns:
UUID string
"""
if isinstance(audio_data, torch.Tensor):
# Convert to numpy and calculate hash
audio_np = audio_data.cpu().numpy()
else:
audio_np = audio_data
# Calculate data hash
data_hash = hashlib.md5(audio_np.tobytes()).hexdigest()
if seed is not None:
combined = f"{data_hash}_{seed}"
return hashlib.md5(combined.encode()).hexdigest()
return data_hash
# Global default instance
_default_saver = AudioSaver(default_format="flac")
SILENT_RMS_THRESHOLD = 1e-5
SILENT_PEAK_THRESHOLD = 1e-5
def is_audio_silent(
audio_data: Union[torch.Tensor, np.ndarray],
rms_threshold: float = SILENT_RMS_THRESHOLD,
peak_threshold: float = SILENT_PEAK_THRESHOLD,
channels_first: bool = True,
) -> Tuple[bool, float, float]:
"""
Check if audio is silent or near-silent (e.g. zeroed conditioning output).
Returns (is_silent, rms, peak) where rms/peak are computed over the full signal.
"""
if audio_data is None:
return True, 0.0, 0.0
if isinstance(audio_data, np.ndarray):
x = np.asarray(audio_data, dtype=np.float64).ravel()
else:
x = audio_data.cpu().float().numpy().ravel()
if x.size == 0:
return True, 0.0, 0.0
rms = float(np.sqrt(np.mean(x * x)))
peak = float(np.max(np.abs(x)))
is_silent = rms <= rms_threshold and peak <= peak_threshold
return is_silent, rms, peak
def save_audio(
audio_data: Union[torch.Tensor, np.ndarray],
output_path: Union[str, Path],
sample_rate: int = 48000,
format: Optional[str] = None,
channels_first: bool = True,
) -> str:
"""
Convenience function: save audio (using default configuration)
Args:
audio_data: Audio data
output_path: Output path
sample_rate: Sample rate
format: Format (default flac)
channels_first: Tensor format flag
Returns:
Saved file path
"""
return _default_saver.save_audio(
audio_data, output_path, sample_rate, format, channels_first
)
"""
Audio saving and transcoding utility module
Independent audio file operations outside of handler, supporting:
- Save audio tensor/numpy to files (default FLAC format, fast)
- Format conversion (FLAC/WAV/MP3)
- Batch processing
"""
import io
import json
import os
import subprocess
import hashlib
from pathlib import Path
from typing import Union, Optional, List, Tuple
import torch
import numpy as np
import torchaudio
from loguru import logger
def normalize_audio(audio_data: Union[torch.Tensor, np.ndarray], target_db: float = -1.0) -> Union[torch.Tensor, np.ndarray]:
"""
Apply peak normalization to audio data.
Args:
audio_data: Audio data as torch.Tensor or numpy.ndarray
target_db: Target peak level in dB (default: -1.0)
Returns:
Normalized audio data in the same format as input
"""
# Create a copy to avoid modifying original in-place
if isinstance(audio_data, torch.Tensor):
audio = audio_data.clone()
is_tensor = True
else:
audio = audio_data.copy()
is_tensor = False
# Calculate current peak
if is_tensor:
peak = torch.max(torch.abs(audio))
else:
peak = np.max(np.abs(audio))
# Handle silence/near-silence to avoid division by zero or extreme gain
if peak < 1e-6:
return audio_data
# Convert target dB to linear amplitude
target_amp = 10 ** (target_db / 20.0)
# Calculate needed gain
gain = target_amp / peak
# Apply gain
audio = audio * gain
return audio
class AudioSaver:
"""Audio saving and transcoding utility class"""
def __init__(self, default_format: str = "flac"):
"""
Initialize audio saver
Args:
default_format: Default save format ('flac', 'wav', 'mp3', 'wav32')
"""
self.default_format = default_format.lower()
if self.default_format not in ["flac", "wav", "mp3", "wav32"]:
logger.warning(f"Unsupported format {default_format}, using 'flac'")
self.default_format = "flac"
def save_audio(
self,
audio_data: Union[torch.Tensor, np.ndarray],
output_path: Union[str, Path],
sample_rate: int = 48000,
format: Optional[str] = None,
channels_first: bool = True,
) -> str:
"""
Save audio data to file
Args:
audio_data: Audio data, torch.Tensor [channels, samples] or numpy.ndarray
output_path: Output file path (extension can be omitted)
sample_rate: Sample rate
format: Audio format ('flac', 'wav', 'mp3', 'wav32'), defaults to default_format
channels_first: If True, tensor format is [channels, samples], else [samples, channels]
Returns:
Actual saved file path
"""
format = (format or self.default_format).lower()
if format not in ["flac", "wav", "mp3", "wav32"]:
logger.warning(f"Unsupported format {format}, using {self.default_format}")
format = self.default_format
# Ensure output path has correct extension
output_path = Path(output_path)
# Determine extension based on format
ext = ".wav" if format == "wav32" else f".{format}"
if output_path.suffix.lower() not in ['.flac', '.wav', '.mp3']:
output_path = output_path.with_suffix(ext)
elif format == "wav32" and output_path.suffix.lower() == ".wav32":
# Explicitly fix .wav32 extension if present
output_path = output_path.with_suffix(".wav")
# Convert to torch tensor
if isinstance(audio_data, np.ndarray):
if channels_first:
# numpy already [channels, samples]
audio_tensor = torch.from_numpy(audio_data).float()
else:
# numpy [samples, channels] -> tensor, transposed to [channels, samples] when needed
audio_tensor = torch.from_numpy(audio_data).float()
if audio_tensor.dim() == 2 and audio_tensor.shape[0] > audio_tensor.shape[1]:
# Assume [samples, channels] if dim0 > dim1 (heuristic)
audio_tensor = audio_tensor.T
else:
# torch tensor
audio_tensor = audio_data.cpu().float()
if not channels_first and audio_tensor.dim() == 2:
# [samples, channels] -> [channels, samples]
if audio_tensor.shape[0] > audio_tensor.shape[1]:
audio_tensor = audio_tensor.T
# Ensure memory is contiguous
audio_tensor = audio_tensor.contiguous()
# Select backend and save
try:
if format == "mp3":
# MP3 uses ffmpeg backend
torchaudio.save(
str(output_path),
audio_tensor,
sample_rate,
channels_first=True,
backend='ffmpeg',
)
elif format in ["flac", "wav", "wav32"]:
# FLAC and WAV use soundfile backend (fastest)
# handle 32-bit float wav
if format == "wav32":
try:
import soundfile as sf
# Use soundfile directly for 32-bit float
audio_np = audio_tensor.transpose(0, 1).numpy() # [channels, samples] -> [samples, channels]
# Explicitly specify format as WAV to avoid issues with extension detection or custom extensions
sf.write(str(output_path), audio_np, sample_rate, subtype='FLOAT', format='WAV')
logger.debug(f"[AudioSaver] Saved audio to {output_path} (wav32, {sample_rate}Hz)")
return str(output_path)
except Exception as e:
logger.error(f"Failed to save wav32: {e}, falling back to standard wav")
format = "wav"
# Fallthrough to standard wav saving
torchaudio.save(
str(output_path),
audio_tensor,
sample_rate,
channels_first=True,
backend='soundfile',
)
else:
# Other formats use default backend
torchaudio.save(
str(output_path),
audio_tensor,
sample_rate,
channels_first=True,
)
logger.debug(f"[AudioSaver] Saved audio to {output_path} ({format}, {sample_rate}Hz)")
return str(output_path)
except Exception as e:
try:
import soundfile as sf
audio_np = audio_tensor.transpose(0, 1).numpy() # -> [samples, channels]
# Handle wav32 fallback formatting
if format == "wav32":
sf_format = "WAV"
subtype = "FLOAT"
else:
sf_format = format.upper()
subtype = None
sf.write(str(output_path), audio_np, sample_rate, format=sf_format, subtype=subtype)
logger.debug(f"[AudioSaver] Fallback soundfile Saved audio to {output_path} ({format}, {sample_rate}Hz)")
return str(output_path)
except Exception as inner_e:
logger.error(f"[AudioSaver] Failed to save audio: {e} -> Fallback failed: {inner_e}")
raise
def convert_audio(
self,
input_path: Union[str, Path],
output_path: Union[str, Path],
output_format: str,
remove_input: bool = False,
) -> str:
"""
Convert audio format
Args:
input_path: Input audio file path
output_path: Output audio file path
output_format: Target format ('flac', 'wav', 'mp3')
remove_input: Whether to delete input file
Returns:
Output file path
"""
input_path = Path(input_path)
output_path = Path(output_path)
if not input_path.exists():
raise FileNotFoundError(f"Input file not found: {input_path}")
# Load audio
audio_tensor, sample_rate = torchaudio.load(str(input_path))
# Save as new format
output_path = self.save_audio(
audio_tensor,
output_path,
sample_rate=sample_rate,
format=output_format,
channels_first=True
)
# Delete input file if needed
if remove_input:
input_path.unlink()
logger.debug(f"[AudioSaver] Removed input file: {input_path}")
return output_path
def save_batch(
self,
audio_batch: Union[List[torch.Tensor], torch.Tensor],
output_dir: Union[str, Path],
file_prefix: str = "audio",
sample_rate: int = 48000,
format: Optional[str] = None,
channels_first: bool = True,
) -> List[str]:
"""
Save audio batch
Args:
audio_batch: Audio batch, List[tensor] or tensor [batch, channels, samples]
output_dir: Output directory
file_prefix: File prefix
sample_rate: Sample rate
format: Audio format
channels_first: Tensor format flag
Returns:
List of saved file paths
"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Process batch
if isinstance(audio_batch, torch.Tensor) and audio_batch.dim() == 3:
# [batch, channels, samples]
audio_list = [audio_batch[i] for i in range(audio_batch.shape[0])]
elif isinstance(audio_batch, list):
audio_list = audio_batch
else:
audio_list = [audio_batch]
saved_paths = []
for i, audio in enumerate(audio_list):
output_path = output_dir / f"{file_prefix}_{i:04d}"
saved_path = self.save_audio(
audio,
output_path,
sample_rate=sample_rate,
format=format,
channels_first=channels_first
)
saved_paths.append(saved_path)
return saved_paths
def get_audio_file_hash(audio_file) -> str:
"""
Get hash identifier for an audio file.
Args:
audio_file: Path to audio file (str) or file-like object
Returns:
Hash string or empty string
"""
if audio_file is None:
return ""
try:
if isinstance(audio_file, str):
if os.path.exists(audio_file):
with open(audio_file, 'rb') as f:
return hashlib.sha256(f.read()).hexdigest()
return hashlib.sha256(audio_file.encode('utf-8')).hexdigest()
elif hasattr(audio_file, 'name'):
return hashlib.sha256(str(audio_file.name).encode('utf-8')).hexdigest()
return hashlib.sha256(str(audio_file).encode('utf-8')).hexdigest()
except Exception:
return hashlib.sha256(str(audio_file).encode('utf-8')).hexdigest()
def generate_uuid_from_params(params_dict) -> str:
"""
Generate deterministic UUID from generation parameters.
Same parameters will always generate the same UUID.
Args:
params_dict: Dictionary of parameters
Returns:
UUID string
"""
params_json = json.dumps(params_dict, sort_keys=True, ensure_ascii=False)
hash_obj = hashlib.sha256(params_json.encode('utf-8'))
hash_hex = hash_obj.hexdigest()
uuid_str = f"{hash_hex[0:8]}-{hash_hex[8:12]}-{hash_hex[12:16]}-{hash_hex[16:20]}-{hash_hex[20:32]}"
return uuid_str
def generate_uuid_from_audio_data(
audio_data: Union[torch.Tensor, np.ndarray],
seed: Optional[int] = None
) -> str:
"""
Generate UUID from audio data (for caching/deduplication)
Args:
audio_data: Audio data
seed: Optional seed value
Returns:
UUID string
"""
if isinstance(audio_data, torch.Tensor):
# Convert to numpy and calculate hash
audio_np = audio_data.cpu().numpy()
else:
audio_np = audio_data
# Calculate data hash
data_hash = hashlib.sha256(audio_np.tobytes()).hexdigest()
if seed is not None:
combined = f"{data_hash}_{seed}"
return hashlib.sha256(combined.encode()).hexdigest()
return data_hash
# Global default instance
_default_saver = AudioSaver(default_format="flac")
def save_audio(
audio_data: Union[torch.Tensor, np.ndarray],
output_path: Union[str, Path],
sample_rate: int = 48000,
format: Optional[str] = None,
channels_first: bool = True,
) -> str:
"""
Convenience function: save audio (using default configuration)
Args:
audio_data: Audio data
output_path: Output path
sample_rate: Sample rate
format: Format (default flac)
channels_first: Tensor format flag
Returns:
Saved file path
"""
return _default_saver.save_audio(
audio_data, output_path, sample_rate, format, channels_first
)
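
The `normalize_audio` function added in this file converts a target level in dB to a linear amplitude and scales the signal so that its peak reaches that amplitude. The arithmetic can be checked without torch or numpy. The sketch below works on a plain list of floats, and `peak_normalize` is a hypothetical name.

```python
import math
from typing import List


def peak_normalize(samples: List[float], target_db: float = -1.0) -> List[float]:
    """Scale samples so the largest absolute value equals the target level."""
    peak = max((abs(s) for s in samples), default=0.0)
    if peak < 1e-6:
        # Near-silent input: return unchanged to avoid extreme gain.
        return list(samples)
    target_amp = 10 ** (target_db / 20.0)  # -1 dB is roughly 0.891 linear
    gain = target_amp / peak
    return [s * gain for s in samples]


normalized = peak_normalize([0.25, -0.5, 0.1])
print(max(abs(s) for s in normalized))
print(math.isclose(max(abs(s) for s in normalized), 10 ** (-1.0 / 20.0)))
```

Peak normalization only guarantees the maximum sample level. Perceived loudness can still differ between tracks with the same peak.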

View file

@ -87,6 +87,28 @@ TASK_TYPES_TURBO = ["text2music", "repaint", "cover"]
TASK_TYPES_BASE = ["text2music", "repaint", "cover", "extract", "lego", "complete"]
# ==============================================================================
# Generation Mode Constants (UI-level modes that map to task types)
# ==============================================================================
# Default modes for turbo and SFT models (restricted set)
GENERATION_MODES_TURBO = ["Simple", "Custom", "Remix", "Repaint"]
# Extended modes for pure base models only — adds Extract/Lego/Complete
GENERATION_MODES_BASE = ["Simple", "Custom", "Remix", "Repaint", "Extract", "Lego", "Complete"]
# Mapping from generation mode to task_type value
MODE_TO_TASK_TYPE = {
"Simple": "text2music",
"Custom": "text2music",
"Remix": "cover",
"Repaint": "repaint",
"Extract": "extract",
"Lego": "lego",
"Complete": "complete",
}
# ==============================================================================
# Instruction Constants
# ==============================================================================
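
The mode constants above can be combined into a small lookup that rejects modes a model variant does not offer. The constants are copied from the hunk. The function `task_type_for_mode` and its `is_base_model` flag are assumptions for illustration, not part of the commit.

```python
GENERATION_MODES_TURBO = ["Simple", "Custom", "Remix", "Repaint"]
GENERATION_MODES_BASE = ["Simple", "Custom", "Remix", "Repaint", "Extract", "Lego", "Complete"]

MODE_TO_TASK_TYPE = {
    "Simple": "text2music",
    "Custom": "text2music",
    "Remix": "cover",
    "Repaint": "repaint",
    "Extract": "extract",
    "Lego": "lego",
    "Complete": "complete",
}


def task_type_for_mode(mode: str, is_base_model: bool) -> str:
    """Hypothetical helper: map a UI mode to a task type, checking availability."""
    allowed = GENERATION_MODES_BASE if is_base_model else GENERATION_MODES_TURBO
    if mode not in allowed:
        raise ValueError(f"Mode {mode!r} is not available for this model")
    return MODE_TO_TASK_TYPE[mode]


print(task_type_for_mode("Remix", is_base_model=False))
print(task_type_for_mode("Lego", is_base_model=True))
```

Both "Simple" and "Custom" map to `text2music`, so the mapping is not invertible. Code that needs the UI mode back has to store it separately.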

View file

@ -1,6 +1,29 @@
"""Handler decomposition components."""
from .audio_codes import AudioCodesMixin
from .batch_prep import BatchPrepMixin
from .diffusion import DiffusionMixin
from .init_service import InitServiceMixin
from .io_audio import IoAudioMixin
from .lora_manager import LoraManagerMixin
from .memory_utils import MemoryUtilsMixin
from .metadata_utils import MetadataMixin
from .padding_utils import PaddingMixin
from .prompt_utils import PromptMixin
from .progress import ProgressMixin
from .task_utils import TaskUtilsMixin
__all__ = ["LoraManagerMixin", "ProgressMixin"]
__all__ = [
"AudioCodesMixin",
"BatchPrepMixin",
"DiffusionMixin",
"InitServiceMixin",
"IoAudioMixin",
"LoraManagerMixin",
"MemoryUtilsMixin",
"MetadataMixin",
"PaddingMixin",
"PromptMixin",
"ProgressMixin",
"TaskUtilsMixin",
]

View file

@ -0,0 +1,99 @@
"""Audio-code parsing and conversion helpers for handler decomposition."""
import re
import traceback
from typing import List, Optional
import torch
from loguru import logger
class AudioCodesMixin:
"""Mixin containing audio-code parsing and latent conversion helpers.
Depends on host members:
- Attributes: ``model``, ``vae``, ``device``, ``dtype``, ``silence_latent``.
- Methods: ``_load_model_context``, ``process_src_audio``, ``is_silence``,
``_encode_audio_to_latents``.
"""
def _parse_audio_code_string(self, code_str: str) -> List[int]:
"""Extract integer audio codes from tokens like ``<|audio_code_123|>``."""
if not code_str:
return []
try:
max_audio_code = 63999
codes = []
clamped_count = 0
for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str):
code_value = int(x)
clamped_value = max(0, min(code_value, max_audio_code))
if clamped_value != code_value:
clamped_count += 1
logger.warning(
f"[_parse_audio_code_string] Clamped audio code value from {code_value} to {clamped_value}"
)
codes.append(clamped_value)
if clamped_count > 0:
logger.warning(
f"[_parse_audio_code_string] Clamped {clamped_count} audio code value(s) "
f"to valid range [0, {max_audio_code}]"
)
return codes
except Exception as e:
logger.debug(f"[_parse_audio_code_string] Failed to parse audio code string: {e}")
return []
def _decode_audio_codes_to_latents(self, code_str: str) -> Optional[torch.Tensor]:
"""Convert serialized audio-code string into 25Hz latents."""
if self.model is None or not hasattr(self.model, "tokenizer") or not hasattr(self.model, "detokenizer"):
return None
code_ids = self._parse_audio_code_string(code_str)
if len(code_ids) == 0:
return None
with self._load_model_context("model"):
quantizer = self.model.tokenizer.quantizer
detokenizer = self.model.detokenizer
indices = torch.tensor(code_ids, device=self.device, dtype=torch.long)
indices = indices.unsqueeze(0).unsqueeze(-1)
quantized = quantizer.get_output_from_indices(indices)
if quantized.dtype != self.dtype:
quantized = quantized.to(self.dtype)
lm_hints_25hz = detokenizer(quantized)
return lm_hints_25hz
def convert_src_audio_to_codes(self, audio_file) -> str:
"""Convert uploaded source audio into serialized audio code tokens."""
if audio_file is None:
return "❌ Please upload source audio first"
if self.model is None or self.vae is None:
return "❌ Model not initialized. Please initialize the service first."
try:
processed_audio = self.process_src_audio(audio_file)
if processed_audio is None:
return "❌ Failed to process audio file"
with torch.inference_mode():
with self._load_model_context("vae"):
if self.is_silence(processed_audio.unsqueeze(0)):
return "❌ Audio file appears to be silent"
latents = self._encode_audio_to_latents(processed_audio)
attention_mask = torch.ones(latents.shape[0], dtype=torch.bool, device=self.device)
with self._load_model_context("model"):
hidden_states = latents.unsqueeze(0)
_, indices, _ = self.model.tokenize(
hidden_states, self.silence_latent, attention_mask.unsqueeze(0)
)
indices_flat = indices.flatten().cpu().tolist()
codes_string = "".join([f"<|audio_code_{idx}|>" for idx in indices_flat])
logger.info(f"[convert_src_audio_to_codes] Generated {len(indices_flat)} audio codes")
return codes_string
except Exception as e:
error_msg = f"❌ Error converting audio to codes: {str(e)}\n{traceback.format_exc()}"
logger.exception("[convert_src_audio_to_codes] Error converting audio to codes")
return error_msg
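
The serialized token format used by `_parse_audio_code_string` and `convert_src_audio_to_codes` can be illustrated without loading a model. The following is a minimal standalone sketch, not the handler's own code: the helper names `parse_audio_codes` and `serialize_audio_codes` are hypothetical, and only the token pattern and the upper bound of 63999 are taken from the file above.

```python
import re
from typing import List

MAX_AUDIO_CODE = 63999  # same upper bound as max_audio_code in the mixin above


def parse_audio_codes(code_str: str, max_code: int = MAX_AUDIO_CODE) -> List[int]:
    """Extract integers from ``<|audio_code_N|>`` tokens, clamped to [0, max_code]."""
    if not code_str:
        return []
    matches = re.findall(r"<\|audio_code_(\d+)\|>", code_str)
    return [max(0, min(int(m), max_code)) for m in matches]


def serialize_audio_codes(codes: List[int]) -> str:
    """Join integer codes back into the token string format (hypothetical helper)."""
    return "".join(f"<|audio_code_{c}|>" for c in codes)


# Round trip: serialize, then parse back.
round_trip = parse_audio_codes(serialize_audio_codes([0, 123, 63999]))

# Out-of-range values are clamped; text between tokens is ignored.
clamped = parse_audio_codes("<|audio_code_70000|>noise<|audio_code_5|>")
```

Because the pattern only matches `\d+`, negative values never reach the clamp, so in practice only the upper bound takes effect.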

@ -0,0 +1,104 @@
"""Batch preparation helpers for handler decomposition."""
from typing import Dict, List, Optional, Union
import torch
from acestep.constants import DEFAULT_DIT_INSTRUCTION
class BatchPrepMixin:
"""Mixin containing batch and input normalization helpers.
Depends on host members:
- Attributes: ``device``, ``dtype``.
- Methods: ``tiled_encode``, ``extract_caption_from_sft_format``,
``_build_metadata_dict``.
"""
def _normalize_audio_code_hints(
self, audio_code_hints: Optional[Union[str, List[str]]], batch_size: int
) -> List[Optional[str]]:
"""Normalize ``audio_code_hints`` into a batch-length list."""
if audio_code_hints is None:
normalized: List[Optional[str]] = [None] * batch_size
elif isinstance(audio_code_hints, str):
normalized = [audio_code_hints] * batch_size
elif len(audio_code_hints) == 1 and batch_size > 1:
normalized = audio_code_hints * batch_size
elif len(audio_code_hints) != batch_size:
normalized = list(audio_code_hints[:batch_size])
while len(normalized) < batch_size:
normalized.append(None)
else:
normalized = list(audio_code_hints)
return [hint if isinstance(hint, str) and hint.strip() else None for hint in normalized]
def _normalize_instructions(
self,
instructions: Optional[Union[str, List[str]]],
batch_size: int,
default: Optional[str] = None,
) -> List[str]:
"""Normalize instructions into a batch-length list."""
if instructions is None:
default_instruction = default or DEFAULT_DIT_INSTRUCTION
return [default_instruction] * batch_size
if isinstance(instructions, str):
return [instructions] * batch_size
if len(instructions) == 1:
return instructions * batch_size
if len(instructions) != batch_size:
normalized = list(instructions[:batch_size])
default_instruction = default or DEFAULT_DIT_INSTRUCTION
while len(normalized) < batch_size:
normalized.append(default_instruction)
return normalized
return list(instructions)
def _encode_audio_to_latents(self, audio: torch.Tensor) -> torch.Tensor:
"""Encode audio to latents using tiled VAE encode path."""
input_was_2d = audio.dim() == 2
if input_was_2d:
audio = audio.unsqueeze(0)
with torch.inference_mode():
latents = self.tiled_encode(audio, offload_latent_to_cpu=True)
latents = latents.to(self.device).to(self.dtype)
latents = latents.transpose(1, 2)
if input_was_2d:
latents = latents.squeeze(0)
return latents
def prepare_batch_data(
self,
actual_batch_size,
processed_src_audio,
audio_duration,
captions,
lyrics,
vocal_language,
instruction,
bpm,
key_scale,
time_signature,
):
"""Prepare repeated batch-level caption/instruction/metadata values."""
pure_caption = self.extract_caption_from_sft_format(captions)
captions_batch = [pure_caption] * actual_batch_size
instructions_batch = [instruction] * actual_batch_size
lyrics_batch = [lyrics] * actual_batch_size
vocal_languages_batch = [vocal_language] * actual_batch_size
calculated_duration = None
if processed_src_audio is not None:
# Duration in seconds; the divisor assumes 48 kHz source audio.
calculated_duration = processed_src_audio.shape[-1] / 48000.0
elif audio_duration is not None and float(audio_duration) > 0:
calculated_duration = float(audio_duration)
metadata_dict: Dict[str, Union[str, int]] = self._build_metadata_dict(
bpm, key_scale, time_signature, calculated_duration
)
metas_batch = [metadata_dict.copy() for _ in range(actual_batch_size)]
return captions_batch, instructions_batch, lyrics_batch, vocal_languages_batch, metas_batch
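
The broadcasting rules in `_normalize_audio_code_hints` are easier to see on plain lists. This is a hedged standalone sketch that mirrors the branches above; the free function `normalize_hints` is a hypothetical name, not part of the mixin.

```python
from typing import List, Optional, Sequence, Union


def normalize_hints(
    hints: Optional[Union[str, Sequence[Optional[str]]]], batch_size: int
) -> List[Optional[str]]:
    """Return a list of exactly ``batch_size`` hints, using None for missing or blank entries."""
    if hints is None:
        normalized: List[Optional[str]] = [None] * batch_size
    elif isinstance(hints, str):
        # A single string is repeated for every batch item.
        normalized = [hints] * batch_size
    elif len(hints) == 1 and batch_size > 1:
        # A one-element list is broadcast across the batch.
        normalized = list(hints) * batch_size
    else:
        # Truncate long lists, then pad short ones with None.
        normalized = list(hints[:batch_size])
        normalized += [None] * (batch_size - len(normalized))
    # Blank or non-string entries are treated as "no hint".
    return [h if isinstance(h, str) and h.strip() else None for h in normalized]


broadcast = normalize_hints("a", 3)
padded = normalize_hints(["a", "b"], 3)
truncated = normalize_hints(["a", "  ", "c", "d"], 3)
```

Note that `_normalize_instructions` differs in one respect: it pads with the default instruction instead of `None`.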

@ -0,0 +1,136 @@
"""Diffusion-related handler helpers."""
from typing import Any, Dict
import torch
from acestep.mlx_dit.generate import mlx_generate_diffusion
class DiffusionMixin:
"""Mixin containing diffusion execution helpers.
Required host attributes:
- ``mlx_decoder``: MLX decoder object passed to ``mlx_generate_diffusion``.
- ``device``: torch device string used for output tensor placement.
- ``dtype``: torch dtype used for output tensor conversion.
"""
def _mlx_run_diffusion(
self,
encoder_hidden_states,
encoder_attention_mask,
context_latents,
src_latents,
seed,
infer_method: str = "ode",
shift: float = 3.0,
timesteps=None,
audio_cover_strength: float = 1.0,
encoder_hidden_states_non_cover=None,
encoder_attention_mask_non_cover=None,
context_latents_non_cover=None,
) -> Dict[str, Any]:
"""Run the MLX diffusion loop and return generated latents.
This method accepts the same signature as the handler diffusion path for
API compatibility. Attention-mask parameters are intentionally accepted
but unused because the MLX generator consumes hidden states/latents only.
Args:
encoder_hidden_states: Prompt conditioning tensor.
encoder_attention_mask: Unused; accepted for API compatibility.
context_latents: Context/reference latent tensor.
src_latents: Source latent tensor used for shape and initialization.
seed: Random seed used by MLX diffusion.
infer_method: Diffusion method, one of ``"ode"`` or ``"sde"``.
shift: Timestep shift value.
timesteps: Optional iterable or tensor-like custom timesteps.
audio_cover_strength: Blend factor for cover conditioning.
encoder_hidden_states_non_cover: Optional non-cover conditioning tensor.
encoder_attention_mask_non_cover: Unused; accepted for API compatibility.
context_latents_non_cover: Optional non-cover context latent tensor.
Returns:
Dict[str, Any]: ``{"target_latents": torch.Tensor, "time_costs": dict}``.
Raises:
AttributeError: If required host attributes are missing.
ValueError: If infer method is unsupported or batch dimensions mismatch.
TypeError: If ``timesteps`` is neither iterable nor tensor-like.
"""
import numpy as np
# Kept for API compatibility with non-MLX diffusion path.
_ = encoder_attention_mask, encoder_attention_mask_non_cover
for required_attr in ("mlx_decoder", "device", "dtype"):
if not hasattr(self, required_attr):
raise AttributeError(f"DiffusionMixin host is missing required attribute '{required_attr}'")
if infer_method not in {"ode", "sde"}:
raise ValueError(f"Unsupported infer_method '{infer_method}'. Expected 'ode' or 'sde'.")
if timesteps is not None and not (hasattr(timesteps, "__iter__") or hasattr(timesteps, "tolist")):
raise TypeError("timesteps must be iterable, tensor-like, or None")
if encoder_hidden_states.shape[0] != context_latents.shape[0]:
raise ValueError(
"Batch dimension mismatch: encoder_hidden_states and context_latents must share dim 0"
)
if encoder_hidden_states.shape[0] != src_latents.shape[0]:
raise ValueError(
"Batch dimension mismatch: encoder_hidden_states and src_latents must share dim 0"
)
if encoder_hidden_states_non_cover is not None and encoder_hidden_states_non_cover.shape[0] != encoder_hidden_states.shape[0]:
raise ValueError(
"Batch dimension mismatch: encoder_hidden_states_non_cover must share dim 0 with encoder_hidden_states"
)
if context_latents_non_cover is not None and context_latents_non_cover.shape[0] != context_latents.shape[0]:
raise ValueError(
"Batch dimension mismatch: context_latents_non_cover must share dim 0 with context_latents"
)
# Convert inputs to numpy (float32)
enc_np = encoder_hidden_states.detach().cpu().float().numpy()
ctx_np = context_latents.detach().cpu().float().numpy()
src_shape = (src_latents.shape[0], src_latents.shape[1], src_latents.shape[2])
enc_nc_np = (
encoder_hidden_states_non_cover.detach().cpu().float().numpy()
if encoder_hidden_states_non_cover is not None else None
)
ctx_nc_np = (
context_latents_non_cover.detach().cpu().float().numpy()
if context_latents_non_cover is not None else None
)
# Convert timesteps tensor if present
ts_list = None
if timesteps is not None:
if hasattr(timesteps, "tolist"):
ts_list = timesteps.tolist()
else:
ts_list = list(timesteps)
result = mlx_generate_diffusion(
mlx_decoder=self.mlx_decoder,
encoder_hidden_states_np=enc_np,
context_latents_np=ctx_np,
src_latents_shape=src_shape,
seed=seed,
infer_method=infer_method,
shift=shift,
timesteps=ts_list,
audio_cover_strength=audio_cover_strength,
encoder_hidden_states_non_cover_np=enc_nc_np,
context_latents_non_cover_np=ctx_nc_np,
)
# Convert result latents back to PyTorch tensor on the correct device
target_np = result["target_latents"]
target_tensor = torch.from_numpy(target_np).to(device=self.device, dtype=self.dtype)
return {
"target_latents": target_tensor,
"time_costs": result["time_costs"],
}
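
The timestep handling in `_mlx_run_diffusion` (validate, then convert to a plain list) can be sketched with the standard library only. This is an illustration under assumptions: `normalize_timesteps` is a hypothetical helper, and `array.array` stands in for a tensor because it also exposes `tolist()`.

```python
from array import array
from typing import Any, List, Optional


def normalize_timesteps(timesteps: Any) -> Optional[List[float]]:
    """Convert tensor-like or iterable timesteps to a list; pass None through."""
    if timesteps is None:
        return None
    if hasattr(timesteps, "tolist"):
        # Tensor-like objects (torch tensors, numpy arrays, array.array).
        return timesteps.tolist()
    if hasattr(timesteps, "__iter__"):
        # Any other iterable, including generators.
        return list(timesteps)
    raise TypeError("timesteps must be iterable, tensor-like, or None")


from_tensor_like = normalize_timesteps(array("d", [1.0, 0.5]))
from_generator = normalize_timesteps(t for t in (0.9, 0.8, 0.7))

try:
    normalize_timesteps(123)
    rejected = False
except TypeError:
    rejected = True
```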

@ -0,0 +1,155 @@
import unittest
from unittest.mock import patch
import numpy as np
import torch
from acestep.core.generation.handler.diffusion import DiffusionMixin
class _Host(DiffusionMixin):
def __init__(self, device: str = "cpu", dtype: torch.dtype = torch.float32):
self.mlx_decoder = object()
self.device = device
self.dtype = dtype
class _IterableTimesteps:
def __init__(self, values):
self._values = values
def __iter__(self):
return iter(self._values)
class DiffusionMixinTests(unittest.TestCase):
def test_mlx_run_diffusion_converts_inputs_and_outputs_tensor(self):
host = _Host(dtype=torch.float16)
encoder_hidden_states = torch.randn(2, 4, 8, dtype=torch.float64)
encoder_attention_mask = torch.ones(2, 4, dtype=torch.int64)
context_latents = torch.randn(2, 16, 8, dtype=torch.float64)
src_latents = torch.zeros(2, 3, 5, dtype=torch.float32)
timesteps = torch.tensor([1.0, 0.5], dtype=torch.float32)
non_cover_hidden = torch.randn(2, 4, 8, dtype=torch.float64)
non_cover_mask = torch.ones(2, 4, dtype=torch.int64)
non_cover_context = torch.randn(2, 16, 8, dtype=torch.float64)
fake_target = np.ones((2, 3, 5), dtype=np.float32)
def _fake_generate(**kwargs):
self.assertIs(kwargs["mlx_decoder"], host.mlx_decoder)
self.assertEqual(kwargs["src_latents_shape"], (2, 3, 5))
self.assertEqual(kwargs["timesteps"], [1.0, 0.5])
self.assertEqual(kwargs["infer_method"], "sde")
self.assertEqual(kwargs["shift"], 2.0)
self.assertEqual(kwargs["audio_cover_strength"], 0.6)
self.assertEqual(kwargs["encoder_hidden_states_np"].dtype, np.float32)
self.assertEqual(kwargs["context_latents_np"].dtype, np.float32)
self.assertEqual(kwargs["encoder_hidden_states_non_cover_np"].dtype, np.float32)
self.assertEqual(kwargs["context_latents_non_cover_np"].dtype, np.float32)
return {"target_latents": fake_target, "time_costs": {"diffusion_time_cost": 1.2}}
with patch("acestep.core.generation.handler.diffusion.mlx_generate_diffusion", side_effect=_fake_generate):
result = host._mlx_run_diffusion(
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
context_latents=context_latents,
src_latents=src_latents,
seed=123,
infer_method="sde",
shift=2.0,
timesteps=timesteps,
audio_cover_strength=0.6,
encoder_hidden_states_non_cover=non_cover_hidden,
encoder_attention_mask_non_cover=non_cover_mask,
context_latents_non_cover=non_cover_context,
)
self.assertIn("target_latents", result)
self.assertIn("time_costs", result)
self.assertEqual(result["time_costs"]["diffusion_time_cost"], 1.2)
self.assertEqual(result["target_latents"].dtype, torch.float16)
self.assertEqual(result["target_latents"].device.type, "cpu")
self.assertTrue(torch.allclose(result["target_latents"], torch.ones_like(result["target_latents"])))
def test_mlx_run_diffusion_handles_optional_and_iterable_timesteps(self):
host = _Host(dtype=torch.float32)
encoder_hidden_states = torch.randn(1, 2, 3, dtype=torch.float32)
encoder_attention_mask = torch.ones(1, 2, dtype=torch.int64)
context_latents = torch.randn(1, 4, 3, dtype=torch.float32)
src_latents = torch.zeros(1, 2, 3, dtype=torch.float32)
timesteps = _IterableTimesteps([0.9, 0.8, 0.7])
def _fake_generate(**kwargs):
self.assertEqual(kwargs["timesteps"], [0.9, 0.8, 0.7])
self.assertIsNone(kwargs["encoder_hidden_states_non_cover_np"])
self.assertIsNone(kwargs["context_latents_non_cover_np"])
return {"target_latents": np.zeros((1, 2, 3), dtype=np.float32), "time_costs": {}}
with patch("acestep.core.generation.handler.diffusion.mlx_generate_diffusion", side_effect=_fake_generate):
result = host._mlx_run_diffusion(
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
context_latents=context_latents,
src_latents=src_latents,
seed=1,
timesteps=timesteps,
)
self.assertEqual(tuple(result["target_latents"].shape), (1, 2, 3))
self.assertEqual(result["target_latents"].dtype, torch.float32)
def test_mlx_run_diffusion_rejects_invalid_infer_method(self):
host = _Host()
x = torch.randn(1, 2, 3)
with self.assertRaises(ValueError):
host._mlx_run_diffusion(
encoder_hidden_states=x,
encoder_attention_mask=torch.ones(1, 2, dtype=torch.int64),
context_latents=torch.randn(1, 4, 3),
src_latents=torch.randn(1, 2, 3),
seed=1,
infer_method="bad",
)
def test_mlx_run_diffusion_rejects_non_iterable_timesteps(self):
host = _Host()
x = torch.randn(1, 2, 3)
with self.assertRaises(TypeError):
host._mlx_run_diffusion(
encoder_hidden_states=x,
encoder_attention_mask=torch.ones(1, 2, dtype=torch.int64),
context_latents=torch.randn(1, 4, 3),
src_latents=torch.randn(1, 2, 3),
seed=1,
timesteps=123,
)
def test_mlx_run_diffusion_rejects_batch_mismatch(self):
host = _Host()
with self.assertRaises(ValueError):
host._mlx_run_diffusion(
encoder_hidden_states=torch.randn(2, 2, 3),
encoder_attention_mask=torch.ones(2, 2, dtype=torch.int64),
context_latents=torch.randn(1, 4, 3),
src_latents=torch.randn(2, 2, 3),
seed=1,
)
def test_mlx_run_diffusion_requires_host_attributes(self):
class _BrokenHost(DiffusionMixin):
pass
host = _BrokenHost()
x = torch.randn(1, 2, 3)
with self.assertRaises(AttributeError):
host._mlx_run_diffusion(
encoder_hidden_states=x,
encoder_attention_mask=torch.ones(1, 2, dtype=torch.int64),
context_latents=torch.randn(1, 4, 3),
src_latents=torch.randn(1, 2, 3),
seed=1,
)
if __name__ == "__main__":
unittest.main()

@ -0,0 +1,414 @@
"""Initialization-adjacent utility mixin for AceStepHandler."""
import os
import time
from contextlib import contextmanager
from typing import List, Optional
import torch
from loguru import logger
class InitServiceMixin:
def _device_type(self) -> str:
"""Normalize the host device value to a backend type string."""
if isinstance(self.device, str):
return self.device.split(":", 1)[0]
return self.device.type
def get_available_checkpoints(self) -> List[str]:
"""Return available checkpoint directory paths under the project root.
Uses ``self._get_project_root()`` to resolve the checkpoints directory and
returns a single-item list when present, otherwise an empty list.
"""
# Resolve the project root through the host-provided helper.
project_root = self._get_project_root()
# default checkpoints
checkpoint_dir = os.path.join(project_root, "checkpoints")
if os.path.exists(checkpoint_dir):
return [checkpoint_dir]
else:
return []
def get_available_acestep_v15_models(self) -> List[str]:
"""Scan and return all model directory names starting with 'acestep-v15-'."""
# Get project root
project_root = self._get_project_root()
checkpoint_dir = os.path.join(project_root, "checkpoints")
models = []
if os.path.exists(checkpoint_dir):
# Scan all directories starting with 'acestep-v15-' in checkpoints folder
for item in os.listdir(checkpoint_dir):
item_path = os.path.join(checkpoint_dir, item)
if os.path.isdir(item_path) and item.startswith("acestep-v15-"):
models.append(item)
# Sort by name
models.sort()
return models
def is_flash_attention_available(self, device: Optional[str] = None) -> bool:
"""Check whether flash attention can be used on the target device."""
target_device = str(device or self.device or "auto").split(":", 1)[0]
if target_device == "auto":
if not torch.cuda.is_available():
return False
else:
if target_device != "cuda" or not torch.cuda.is_available():
return False
# FlashAttention requires Ampere (compute capability >= 8.0) or newer
try:
major, _ = torch.cuda.get_device_capability()
if major < 8:
logger.info(
f"[is_flash_attention_available] GPU compute capability {major}.x < 8.0 "
f"(pre-Ampere) — FlashAttention not supported, will use SDPA instead."
)
return False
except Exception:
return False
try:
import flash_attn
return True
except ImportError:
return False
def is_turbo_model(self) -> bool:
"""Check if the currently loaded model is a turbo model."""
if self.config is None:
return False
return getattr(self.config, "is_turbo", False)
def _empty_cache(self):
"""Clear accelerator memory cache (CUDA, XPU, or MPS)."""
device_type = self._device_type()
if device_type == "cuda" and torch.cuda.is_available():
torch.cuda.empty_cache()
elif device_type == "xpu" and hasattr(torch, "xpu") and torch.xpu.is_available():
torch.xpu.empty_cache()
elif device_type == "mps" and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
torch.mps.empty_cache()
def _synchronize(self):
"""Synchronize accelerator operations (CUDA, XPU, or MPS)."""
device_type = self._device_type()
if device_type == "cuda" and torch.cuda.is_available():
torch.cuda.synchronize()
elif device_type == "xpu" and hasattr(torch, "xpu") and torch.xpu.is_available():
torch.xpu.synchronize()
elif device_type == "mps" and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
torch.mps.synchronize()
def _memory_allocated(self):
"""Get current accelerator memory usage in bytes, or 0 for unsupported backends."""
device_type = self._device_type()
if device_type == "cuda" and torch.cuda.is_available():
return torch.cuda.memory_allocated()
# Other backends (MPS, XPU) are not tracked here, so report 0.
return 0
def _max_memory_allocated(self):
"""Get peak accelerator memory usage in bytes, or 0 for unsupported backends."""
device_type = self._device_type()
if device_type == "cuda" and torch.cuda.is_available():
return torch.cuda.max_memory_allocated()
return 0
def _is_on_target_device(self, tensor, target_device):
"""Check if tensor is on the target device (handles cuda vs cuda:0 comparison)."""
if tensor is None:
return True
try:
if isinstance(target_device, torch.device):
target_type = target_device.type
else:
target_type = torch.device(str(target_device)).type
except Exception:
# Keep fallback conservative: derive backend token instead of assuming CUDA.
target_type = str(target_device).strip().lower().split(":", 1)[0]
if not target_type:
logger.warning(
"[_is_on_target_device] Malformed target device value: {!r}",
target_device,
)
return False
return tensor.device.type == target_type
@staticmethod
def _get_affine_quantized_tensor_class():
"""Return the AffineQuantizedTensor class from torchao, or None if unavailable.
Supports both old (torchao.quantization.affine_quantized) and new
(torchao.dtypes.affine_quantized_tensor) import paths across torchao versions.
"""
try:
from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
return AffineQuantizedTensor
except ImportError:
pass
try:
from torchao.quantization.affine_quantized import AffineQuantizedTensor
return AffineQuantizedTensor
except ImportError:
pass
return None
def _is_quantized_tensor(self, t):
"""True if t is a torchao AffineQuantizedTensor (calling .to() on it can raise NotImplementedError)."""
if t is None:
return False
cls = self._get_affine_quantized_tensor_class()
if cls is None:
return False
return isinstance(t, cls)
def _has_quantized_params(self, module):
"""True if module (or any submodule) has at least one AffineQuantizedTensor parameter."""
cls = self._get_affine_quantized_tensor_class()
if cls is None:
return False
for _, param in module.named_parameters():
if param is not None and isinstance(param, cls):
return True
return False
def _ensure_silence_latent_on_device(self):
"""Ensure silence_latent is on the correct device (self.device)."""
if hasattr(self, "silence_latent") and self.silence_latent is not None:
if not self._is_on_target_device(self.silence_latent, self.device):
self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
def _move_module_recursive(self, module, target_device, dtype=None, visited=None):
"""
Recursively move a module and all its submodules to the target device.
This handles modules that may not be properly registered.
"""
if visited is None:
visited = set()
module_id = id(module)
if module_id in visited:
return
visited.add(module_id)
# Move the module itself
module.to(target_device)
if dtype is not None:
module.to(dtype)
# Move all direct parameters
for param_name, param in module._parameters.items():
if param is not None and not self._is_on_target_device(param, target_device):
if self._is_quantized_tensor(param):
moved_param = self._move_quantized_param(param, target_device)
else:
moved_param = torch.nn.Parameter(
param.data.to(target_device), requires_grad=param.requires_grad
)
if dtype is not None and moved_param.is_floating_point():
moved_param = torch.nn.Parameter(
moved_param.data.to(dtype), requires_grad=param.requires_grad
)
module._parameters[param_name] = moved_param
# Move all direct buffers
for buf_name, buf in module._buffers.items():
if buf is not None and not self._is_on_target_device(buf, target_device):
module._buffers[buf_name] = buf.to(target_device)
# Recursively process all submodules (registered and unregistered)
for name, child in module._modules.items():
if child is not None:
self._move_module_recursive(child, target_device, dtype, visited)
# Also check for any nn.Module attributes that might not be in _modules
for attr_name in dir(module):
if attr_name.startswith('_'):
continue
try:
attr = getattr(module, attr_name, None)
if isinstance(attr, torch.nn.Module) and id(attr) not in visited:
self._move_module_recursive(attr, target_device, dtype, visited)
except Exception:
pass
def _move_quantized_param(self, param, target_device):
"""Move an AffineQuantizedTensor to target_device using _apply_fn_to_data.
This is the safe fallback for older torch versions where model.to(device) raises
NotImplementedError on AffineQuantizedTensor (because aten._has_compatible_shallow_copy_type
is not implemented). _apply_fn_to_data recursively applies a function to all inner
tensors (int_data, scale, zero_point, etc.) without going through Module._apply.
"""
if hasattr(param, '_apply_fn_to_data'):
return torch.nn.Parameter(
param._apply_fn_to_data(lambda x: x.to(target_device)),
requires_grad=param.requires_grad,
)
# Last resort: try direct .to() (may raise), but preserve Parameter registration.
moved = param.to(target_device)
return torch.nn.Parameter(moved, requires_grad=param.requires_grad)
def _recursive_to_device(self, model, device, dtype=None):
"""
Recursively move all parameters and buffers of a model to the specified device.
This is more thorough than model.to() for some custom HuggingFace models.
Handles torchao AffineQuantizedTensor parameters that may raise NotImplementedError
on model.to(device) in older torch versions (where Module._apply calls
_has_compatible_shallow_copy_type, which is not implemented for AffineQuantizedTensor).
In that case, falls back to moving quantized parameters individually via _apply_fn_to_data.
"""
target_device = torch.device(device) if isinstance(device, str) else device
# Method 1: Standard .to() call — works on newer torch where _apply uses swap_tensors
try:
model.to(target_device)
if dtype is not None:
model.to(dtype)
except NotImplementedError:
# Older torch: Module._apply calls _has_compatible_shallow_copy_type which is
# not implemented for AffineQuantizedTensor. Move parameters manually.
logger.info(
"[_recursive_to_device] model.to() raised NotImplementedError "
"(AffineQuantizedTensor on older torch). Moving parameters individually."
)
for module in model.modules():
# Move parameters (quantized ones via _move_quantized_param) and buffers individually
for param_name, param in module._parameters.items():
if param is None:
continue
if self._is_on_target_device(param, target_device):
continue
if self._is_quantized_tensor(param):
module._parameters[param_name] = self._move_quantized_param(param, target_device)
else:
module._parameters[param_name] = torch.nn.Parameter(
param.data.to(target_device), requires_grad=param.requires_grad
)
if dtype is not None and module._parameters[param_name].is_floating_point():
module._parameters[param_name] = torch.nn.Parameter(
module._parameters[param_name].data.to(dtype),
requires_grad=param.requires_grad,
)
for buf_name, buf in module._buffers.items():
if buf is not None and not self._is_on_target_device(buf, target_device):
module._buffers[buf_name] = buf.to(target_device)
# Method 2: Run the thorough recursive move to catch any missed modules.
# This always runs; if model.to() failed above, the NotImplementedError raised again here is ignored.
try:
self._move_module_recursive(model, target_device, dtype)
except NotImplementedError:
pass # Already handled above
# Method 3: Retry parameter by parameter if any are still on the wrong device
wrong_device_params = []
for name, param in model.named_parameters():
if not self._is_on_target_device(param, device):
wrong_device_params.append(name)
if wrong_device_params and target_device.type != "cpu":
logger.warning(f"[_recursive_to_device] {len(wrong_device_params)} parameters on wrong device after initial move, retrying individually")
for module in model.modules():
for param_name, param in module._parameters.items():
if param is None or self._is_on_target_device(param, target_device):
continue
if self._is_quantized_tensor(param):
module._parameters[param_name] = self._move_quantized_param(param, target_device)
else:
module._parameters[param_name] = torch.nn.Parameter(
param.data.to(target_device), requires_grad=param.requires_grad
)
if dtype is not None and module._parameters[param_name].is_floating_point():
module._parameters[param_name] = torch.nn.Parameter(
module._parameters[param_name].data.to(dtype),
requires_grad=param.requires_grad,
)
# Synchronize accelerator to ensure all transfers are complete
if target_device.type != "cpu":
self._synchronize()
# Final verification
if target_device.type != "cpu":
still_wrong = []
for name, param in model.named_parameters():
if not self._is_on_target_device(param, device):
still_wrong.append(f"{name} on {param.device}")
if still_wrong:
logger.error(f"[_recursive_to_device] CRITICAL: {len(still_wrong)} parameters still on wrong device: {still_wrong[:10]}")
@contextmanager
def _load_model_context(self, model_name: str):
"""
Context manager to load a model to GPU and offload it back to CPU after use.
Args:
model_name: Name of the model to load ("text_encoder", "vae", "model")
"""
if not self.offload_to_cpu:
yield
return
# If model is DiT ("model") and offload_dit_to_cpu is False, do not offload
if model_name == "model" and not self.offload_dit_to_cpu:
# Ensure it's on device if not already (should be handled by init, but safe to check)
model = getattr(self, model_name, None)
if model is not None:
# Check if model is on CPU, if so move to device (one-time move if it was somehow on CPU)
# We check the first parameter's device
try:
param = next(model.parameters())
if param.device.type == "cpu":
logger.info(f"[_load_model_context] Moving {model_name} to {self.device} (persistent)")
self._recursive_to_device(model, self.device, self.dtype)
if getattr(self, "silence_latent", None) is not None:
self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
except StopIteration:
pass
yield
return
model = getattr(self, model_name, None)
if model is None:
yield
return
# Load to GPU
logger.info(f"[_load_model_context] Loading {model_name} to {self.device}")
start_time = time.time()
if model_name == "vae":
vae_dtype = self._get_vae_dtype()
self._recursive_to_device(model, self.device, vae_dtype)
else:
self._recursive_to_device(model, self.device, self.dtype)
if model_name == "model" and getattr(self, "silence_latent", None) is not None:
self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
load_time = time.time() - start_time
self.current_offload_cost += load_time
logger.info(f"[_load_model_context] Loaded {model_name} to {self.device} in {load_time:.4f}s")
try:
yield
finally:
# Offload to CPU
logger.info(f"[_load_model_context] Offloading {model_name} to CPU")
start_time = time.time()
if model_name == "vae":
self._recursive_to_device(model, "cpu", self._get_vae_dtype("cpu"))
else:
self._recursive_to_device(model, "cpu")
# NOTE: Do NOT offload silence_latent to CPU here!
# silence_latent is used in many places outside of model context,
# so it should stay on GPU to avoid device mismatch errors.
self._empty_cache()
offload_time = time.time() - start_time
self.current_offload_cost += offload_time
logger.info(f"[_load_model_context] Offloaded {model_name} to CPU in {offload_time:.4f}s")
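
The load/offload pattern in `_load_model_context` relies on `try`/`finally` so the model returns to CPU even when the body raises. Below is a minimal toy sketch of that control flow, assuming a hypothetical `ToyOffloadHost` that tracks placement as strings instead of moving real modules.

```python
import time
from contextlib import contextmanager


class ToyOffloadHost:
    """Hypothetical stand-in for the handler; records where each model 'lives'."""

    def __init__(self, offload_to_cpu: bool = True):
        self.offload_to_cpu = offload_to_cpu
        self.placement = {"vae": "cpu", "model": "cpu"}
        self.current_offload_cost = 0.0

    @contextmanager
    def load_model_context(self, model_name: str):
        # With offloading disabled, or for an unknown model, do nothing.
        if not self.offload_to_cpu or model_name not in self.placement:
            yield
            return
        start = time.time()
        self.placement[model_name] = "accelerator"
        self.current_offload_cost += time.time() - start
        try:
            yield
        finally:
            # Runs on normal exit and on exceptions alike.
            start = time.time()
            self.placement[model_name] = "cpu"
            self.current_offload_cost += time.time() - start


host = ToyOffloadHost()
with host.load_model_context("vae"):
    inside = host.placement["vae"]
after = host.placement["vae"]
```

The real context manager additionally skips offloading for the DiT when `offload_dit_to_cpu` is false and clears the accelerator cache after the move; the toy version omits both.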

@ -0,0 +1,179 @@
import os
import builtins
import types
import unittest
from unittest.mock import Mock, patch
import torch
from acestep.core.generation.handler.init_service import InitServiceMixin
class _Host(InitServiceMixin):
def __init__(self, project_root: str, device: str = "cpu", config=None):
self._project_root = project_root
self.device = device
self.config = config
def _get_project_root(self):
return self._project_root
class InitServiceMixinTests(unittest.TestCase):
def test_device_type_normalizes_device(self):
host = _Host(project_root="K:/fake_root", device="cuda:0")
self.assertEqual(host._device_type(), "cuda")
def test_is_on_target_device_handles_device_alias(self):
host = _Host(project_root="K:/fake_root", device="cpu")
t = types.SimpleNamespace(device=types.SimpleNamespace(type="cuda"))
self.assertTrue(host._is_on_target_device(t, "cuda:0"))
self.assertFalse(host._is_on_target_device(t, "cpu"))
def test_is_on_target_device_fallback_does_not_assume_cuda(self):
host = _Host(project_root="K:/fake_root", device="cpu")
t = types.SimpleNamespace(device=types.SimpleNamespace(type="cuda"))
self.assertFalse(host._is_on_target_device(t, "mps:0"))
def test_is_on_target_device_malformed_target_logs_and_returns_false(self):
host = _Host(project_root="K:/fake_root", device="cpu")
t = types.SimpleNamespace(device=types.SimpleNamespace(type="cuda"))
with patch("acestep.core.generation.handler.init_service.logger.warning") as warning:
self.assertFalse(host._is_on_target_device(t, ":0"))
warning.assert_called_once()
def test_move_module_recursive_preserves_parameter_type(self):
host = _Host(project_root="K:/fake_root", device="cpu")
module = torch.nn.Linear(2, 2)
with patch.object(host, "_is_on_target_device", return_value=False):
host._move_module_recursive(module, "cpu")
self.assertIsInstance(module.weight, torch.nn.Parameter)
self.assertIsInstance(module.bias, torch.nn.Parameter)
def test_move_quantized_param_fallback_wraps_parameter(self):
host = _Host(project_root="K:/fake_root", device="cpu")
param = torch.nn.Parameter(torch.randn(2), requires_grad=True)
moved = host._move_quantized_param(param, "cpu")
self.assertIsInstance(moved, torch.nn.Parameter)
self.assertTrue(moved.requires_grad)
def test_get_available_checkpoints_returns_expected_list(self):
host = _Host(project_root="K:/fake_root")
with patch("os.path.exists", return_value=False):
self.assertEqual(host.get_available_checkpoints(), [])
with patch("os.path.exists", return_value=True):
self.assertEqual(host.get_available_checkpoints(), [os.path.join("K:/fake_root", "checkpoints")])
def test_get_available_acestep_v15_models_filters_and_sorts(self):
host = _Host(project_root="K:/fake_root")
with patch("os.path.exists", return_value=True), patch(
"os.listdir",
return_value=["acestep-v15-zeta", "acestep-v15-alpha", "not-a-model", "acestep-v15-file"],
), patch(
"os.path.isdir",
side_effect=lambda p: p.endswith("acestep-v15-zeta")
or p.endswith("acestep-v15-alpha")
or p.endswith("not-a-model"),
):
self.assertEqual(
host.get_available_acestep_v15_models(),
["acestep-v15-alpha", "acestep-v15-zeta"],
)
def test_is_turbo_model_uses_config_flag(self):
host = _Host(project_root="K:/fake_root", config=None)
self.assertFalse(host.is_turbo_model())
host.config = types.SimpleNamespace(is_turbo=True)
self.assertTrue(host.is_turbo_model())
def test_is_flash_attention_available_rejects_non_cuda(self):
host = _Host(project_root="K:/fake_root", device="cpu")
self.assertFalse(host.is_flash_attention_available())
self.assertFalse(host.is_flash_attention_available(device="mps"))
def test_is_flash_attention_available_true_when_cuda_and_module_present(self):
host = _Host(project_root="K:/fake_root", device="cuda")
with patch("torch.cuda.is_available", return_value=True):
with patch("torch.cuda.get_device_capability", return_value=(8, 0)):
with patch.dict("sys.modules", {"flash_attn": types.SimpleNamespace()}):
self.assertTrue(host.is_flash_attention_available())
def test_is_flash_attention_available_false_when_pre_ampere_gpu(self):
host = _Host(project_root="K:/fake_root", device="cuda")
with patch("torch.cuda.is_available", return_value=True):
with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
with patch.dict("sys.modules", {"flash_attn": types.SimpleNamespace()}):
self.assertFalse(host.is_flash_attention_available())
def test_is_flash_attention_available_false_when_module_missing(self):
host = _Host(project_root="K:/fake_root", device="cuda")
real_import = builtins.__import__
def fake_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == "flash_attn":
raise ImportError("flash_attn missing")
return real_import(name, globals, locals, fromlist, level)
with patch("torch.cuda.is_available", return_value=True):
with patch("torch.cuda.get_device_capability", return_value=(8, 0)):
with patch("builtins.__import__", side_effect=fake_import):
self.assertFalse(host.is_flash_attention_available())
def test_empty_cache_routes_to_cuda(self):
host = _Host(project_root="K:/fake_root", device="cuda")
with patch("torch.cuda.is_available", return_value=True), patch("torch.cuda.empty_cache") as empty_cache:
host._empty_cache()
empty_cache.assert_called_once()
def test_empty_cache_routes_to_xpu(self):
host = _Host(project_root="K:/fake_root", device="xpu")
empty_cache = Mock()
xpu_stub = types.SimpleNamespace(is_available=lambda: True, empty_cache=empty_cache)
with patch("torch.xpu", new=xpu_stub, create=True):
host._empty_cache()
empty_cache.assert_called_once()
def test_empty_cache_routes_to_mps(self):
host = _Host(project_root="K:/fake_root", device="mps")
with patch("torch.backends.mps.is_available", return_value=True), patch("torch.mps.empty_cache") as empty_cache:
host._empty_cache()
empty_cache.assert_called_once()
def test_synchronize_routes_to_cuda(self):
host = _Host(project_root="K:/fake_root", device="cuda")
with patch("torch.cuda.is_available", return_value=True), patch("torch.cuda.synchronize") as sync:
host._synchronize()
sync.assert_called_once()
def test_synchronize_routes_to_xpu(self):
host = _Host(project_root="K:/fake_root", device="xpu")
sync = Mock()
xpu_stub = types.SimpleNamespace(is_available=lambda: True, synchronize=sync)
with patch("torch.xpu", new=xpu_stub, create=True):
host._synchronize()
sync.assert_called_once()
def test_synchronize_routes_to_mps(self):
host = _Host(project_root="K:/fake_root", device="mps")
with patch("torch.backends.mps.is_available", return_value=True), patch("torch.mps.synchronize") as sync:
host._synchronize()
sync.assert_called_once()
def test_memory_queries_use_cuda_only(self):
host = _Host(project_root="K:/fake_root", device="cpu")
self.assertEqual(host._memory_allocated(), 0)
self.assertEqual(host._max_memory_allocated(), 0)
host.device = "cuda"
with patch("torch.cuda.is_available", return_value=True):
with patch("torch.cuda.memory_allocated", return_value=123), patch(
"torch.cuda.max_memory_allocated", return_value=456
):
self.assertEqual(host._memory_allocated(), 123)
self.assertEqual(host._max_memory_allocated(), 456)
if __name__ == "__main__":
unittest.main()
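
Review note: the `_is_on_target_device` tests above pin down three behaviours (match on device type, no CUDA assumption, warn and return `False` on a malformed target), but the implementation itself is not visible in this hunk. The following is only a minimal sketch of logic that would satisfy those cases. The names `parse_device_type` and `is_on_target_device_type` are hypothetical and are not taken from the repository.

```python
import logging
from typing import Optional

logger = logging.getLogger("device_check_sketch")


def parse_device_type(target: str) -> Optional[str]:
    """Return the type part of strings like 'cuda:0', or None if it is empty."""
    device_type = target.split(":", 1)[0].strip().lower()
    return device_type or None


def is_on_target_device_type(tensor_device_type: str, target: str) -> bool:
    """Compare a tensor's device type with the type named by ``target``."""
    expected = parse_device_type(target)
    if expected is None:
        # Hypothetical handling: a target such as ':0' has no device type.
        logger.warning("Malformed target device %r", target)
        return False
    return tensor_device_type == expected


print(is_on_target_device_type("cuda", "cuda:0"))
print(is_on_target_device_type("cuda", "mps:0"))
```

Comparing only the type part is what lets a tensor on `cuda` count as being on `cuda:0`, which is the case the first assertion in the tests exercises.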

@ -0,0 +1,133 @@
"""Audio IO and normalization helpers for handler decomposition."""
import math
import random
from typing import Optional
import torch
from loguru import logger
class IoAudioMixin:
"""Mixin containing audio file loading and normalization helpers.
Depends on host members:
- Method: ``is_silence`` (provided by ``MemoryUtilsMixin`` in this decomposition).
"""
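    def _normalize_audio_to_stereo_48k(self, audio: torch.Tensor, sr: int) -> torch.Tensor:
        """Normalize audio tensor to stereo at 48kHz.
        Args:
            audio: Tensor in [channels, samples] or [samples] format.
            sr: Source sample rate.
        Returns:
            Tensor in [2, samples] at 48kHz, clamped to [-1.0, 1.0].
        """
        # Promote 1-D [samples] input to [1, samples] so the channel logic below applies.
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)
        # Duplicate mono to stereo, then keep at most the first two channels.
        if audio.shape[0] == 1:
            audio = torch.cat([audio, audio], dim=0)
        audio = audio[:2]
        if sr != 48000:
            # Imported lazily so torchaudio is only needed when resampling.
            import torchaudio
            audio = torchaudio.transforms.Resample(sr, 48000)(audio)
        return torch.clamp(audio, -1.0, 1.0)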
def process_target_audio(self, audio_file: Optional[str]) -> Optional[torch.Tensor]:
"""Load and normalize target audio file.
Args:
audio_file: Path to target audio file.
Returns:
Normalized stereo 48kHz tensor, or ``None`` on error/empty input.
"""
if audio_file is None:
return None
try:
import soundfile as sf
audio_np, sr = sf.read(audio_file, dtype="float32")
if audio_np.ndim == 1:
audio = torch.from_numpy(audio_np).unsqueeze(0)
else:
audio = torch.from_numpy(audio_np.T)
return self._normalize_audio_to_stereo_48k(audio, sr)
except (OSError, RuntimeError, ValueError):
logger.exception("[process_target_audio] Error processing target audio")
return None
def process_reference_audio(self, audio_file: Optional[str]) -> Optional[torch.Tensor]:
"""Load and normalize reference audio, then sample 3x10s segments.
Args:
audio_file: Path to reference audio file.
Returns:
30-second stereo tensor from sampled front/middle/back segments,
or ``None`` for empty/silent/error cases.
"""
if audio_file is None:
return None
try:
import torchaudio
audio, sr = torchaudio.load(audio_file)
logger.debug(f"[process_reference_audio] Reference audio shape: {audio.shape}")
logger.debug(f"[process_reference_audio] Reference audio sample rate: {sr}")
logger.debug(
f"[process_reference_audio] Reference audio duration: {audio.shape[-1] / sr:.6f} seconds"
)
audio = self._normalize_audio_to_stereo_48k(audio, sr)
if self.is_silence(audio):
return None
target_frames = 30 * 48000
segment_frames = 10 * 48000
if audio.shape[-1] < target_frames:
repeat_times = math.ceil(target_frames / audio.shape[-1])
audio = audio.repeat(1, repeat_times)
total_frames = audio.shape[-1]
segment_size = total_frames // 3
front_start = random.randint(0, max(0, segment_size - segment_frames))
front_audio = audio[:, front_start : front_start + segment_frames]
middle_start = segment_size + random.randint(0, max(0, segment_size - segment_frames))
middle_audio = audio[:, middle_start : middle_start + segment_frames]
back_start = 2 * segment_size + random.randint(
0, max(0, (total_frames - 2 * segment_size) - segment_frames)
)
back_audio = audio[:, back_start : back_start + segment_frames]
return torch.cat([front_audio, middle_audio, back_audio], dim=-1)
except (OSError, RuntimeError, ValueError):
logger.exception("[process_reference_audio] Error processing reference audio")
return None
def process_src_audio(self, audio_file: Optional[str]) -> Optional[torch.Tensor]:
"""Load and normalize source audio for remix/extract flows.
Args:
audio_file: Path to source audio file.
Returns:
Normalized stereo 48kHz tensor, or ``None`` on error/empty input.
"""
if audio_file is None:
return None
try:
import torchaudio
audio, sr = torchaudio.load(audio_file)
return self._normalize_audio_to_stereo_48k(audio, sr)
except (OSError, RuntimeError, ValueError):
logger.exception("[process_src_audio] Error processing source audio")
return None
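
Review note: the window arithmetic in `process_reference_audio` can be checked without loading any audio. The sketch below reproduces only the index math (repeat short clips up to 30 s, split into thirds, pick one random 10 s window per third). The function name `pick_segment_starts` and the injected `rng` argument are assumptions added for illustration; the mixin itself uses the module-level `random`.

```python
import math
import random
from typing import List, Tuple

SAMPLE_RATE = 48_000
SEGMENT_FRAMES = 10 * SAMPLE_RATE  # one 10 s window
TARGET_FRAMES = 30 * SAMPLE_RATE   # minimum length before sampling


def pick_segment_starts(total_frames: int, rng: random.Random) -> Tuple[int, List[int]]:
    """Return the effective length and the front/middle/back window starts."""
    if total_frames <= 0:
        raise ValueError("total_frames must be positive")
    # Short clips are tiled until they cover at least 30 s.
    if total_frames < TARGET_FRAMES:
        total_frames *= math.ceil(TARGET_FRAMES / total_frames)
    third = total_frames // 3
    front = rng.randint(0, max(0, third - SEGMENT_FRAMES))
    middle = third + rng.randint(0, max(0, third - SEGMENT_FRAMES))
    back = 2 * third + rng.randint(
        0, max(0, (total_frames - 2 * third) - SEGMENT_FRAMES)
    )
    return total_frames, [front, middle, back]


total, starts = pick_segment_starts(5_000_000, random.Random(0))
print(total, starts)
```

Because the tiled length is at least 30 s, each third is at least 10 s long, so every window is a full 10 s and the three windows never overlap.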

@ -0,0 +1,76 @@
"""Unit tests for audio IO mixin extraction."""
import sys
import types
import unittest
from unittest.mock import patch
import numpy as np
import torch
from acestep.core.generation.handler.io_audio import IoAudioMixin
class _Host(IoAudioMixin):
"""Minimal host implementing methods used by ``IoAudioMixin``."""
def is_silence(self, audio: torch.Tensor) -> bool:
"""Treat near-zero tensors as silence."""
return torch.all(audio.abs() < 1e-6).item()
def _fake_torchaudio_module(load_fn):
"""Create fake ``torchaudio`` module with minimal API used by tests."""
module = types.ModuleType("torchaudio")
module.load = load_fn
module.transforms = types.SimpleNamespace(Resample=lambda *_args, **_kwargs: (lambda x: x))
return module
class IoAudioMixinTests(unittest.TestCase):
"""Tests for normalization and audio loading helpers."""
def test_normalize_audio_to_stereo_48k_duplicates_mono_and_clamps(self):
"""Mono input should duplicate to stereo and clamp values."""
host = _Host()
audio = torch.tensor([[2.0, -2.0, 0.5]], dtype=torch.float32)
result = host._normalize_audio_to_stereo_48k(audio, 48000)
self.assertEqual(tuple(result.shape), (2, 3))
self.assertTrue(torch.all(result <= 1.0))
self.assertTrue(torch.all(result >= -1.0))
def test_process_target_audio_loads_and_normalizes(self):
"""Target audio should be loaded and normalized through helper."""
host = _Host()
fake_np = np.array([0.1, -0.1, 0.2], dtype=np.float32)
fake_sf = types.ModuleType("soundfile")
fake_sf.read = lambda *_args, **_kwargs: (fake_np, 32000)
with patch.dict(sys.modules, {"soundfile": fake_sf}):
with patch.object(host, "_normalize_audio_to_stereo_48k", return_value=torch.zeros(2, 3)) as norm:
result = host.process_target_audio("fake.wav")
self.assertIsNotNone(result)
norm.assert_called_once()
def test_process_src_audio_handles_load_error(self):
"""Source audio processing should return None on load failure."""
host = _Host()
fake_ta = _fake_torchaudio_module(lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("bad")))
with patch.dict(sys.modules, {"torchaudio": fake_ta}):
result = host.process_src_audio("bad.wav")
self.assertIsNone(result)
def test_process_reference_audio_returns_none_for_silence(self):
"""Reference audio should short-circuit for silent input."""
host = _Host()
silent = torch.zeros(2, 16, dtype=torch.float32)
fake_ta = _fake_torchaudio_module(lambda *_args, **_kwargs: (silent, 48000))
with patch.dict(sys.modules, {"torchaudio": fake_ta}):
result = host.process_reference_audio("silent.wav")
self.assertIsNone(result)
if __name__ == "__main__":
unittest.main()
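
Review note: these tests rely on the fact that a function-level `import` consults `sys.modules` first, so `patch.dict(sys.modules, ...)` can swap in a stub for the duration of a test. The sketch below shows that mechanism in isolation. The module name `hypothetical_audio_backend` and the helper `load_sample_rate` are made up for the example, and `Mock(side_effect=...)` is offered as one possible alternative to the generator `.throw(...)` trick, not as what the test file uses.

```python
import sys
import types
from unittest.mock import Mock, patch


def load_sample_rate(path: str) -> int:
    """Hypothetical helper that imports its backend lazily, like the mixin does."""
    import hypothetical_audio_backend  # resolved through sys.modules at call time

    _, sample_rate = hypothetical_audio_backend.load(path)
    return sample_rate


def make_stub(load_fn) -> types.ModuleType:
    """Build a stand-in module exposing only ``load``."""
    module = types.ModuleType("hypothetical_audio_backend")
    module.load = load_fn
    return module


ok_stub = make_stub(lambda _path: ([0.0, 0.1], 48000))
with patch.dict(sys.modules, {"hypothetical_audio_backend": ok_stub}):
    print(load_sample_rate("fake.wav"))

# A Mock with side_effect raises on call, which simulates a load failure.
bad_stub = make_stub(Mock(side_effect=RuntimeError("bad")))
with patch.dict(sys.modules, {"hypothetical_audio_backend": bad_stub}):
    try:
        load_sample_rate("bad.wav")
    except RuntimeError as exc:
        print("load failed:", exc)
```

`patch.dict` restores `sys.modules` on exit, so the stub does not leak into other tests.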

@ -37,20 +37,46 @@ def load_lora(self, lora_path: str) -> str:
return "❌ PEFT library not installed. Please install with: pip install peft"
try:
import copy
# Memory-efficient state_dict backup instead of deepcopy
if self._base_decoder is None:
self._base_decoder = copy.deepcopy(self.model.decoder)
logger.info("Base decoder backed up")
# Log memory before backup
if hasattr(self, '_memory_allocated'):
mem_before = self._memory_allocated() / (1024**3)
logger.info(f"VRAM before LoRA load: {mem_before:.2f}GB")
# Save only the base model weights as state_dict (CPU to save VRAM)
try:
state_dict = self.model.decoder.state_dict()
if not state_dict:
raise ValueError("state_dict is empty - cannot backup decoder")
self._base_decoder = {k: v.detach().cpu().clone() for k, v in state_dict.items()}
except Exception as e:
logger.error(f"Failed to create state_dict backup: {e}")
raise
# Calculate backup size in MB
backup_size_mb = sum(v.numel() * v.element_size() for v in self._base_decoder.values()) / (1024**2)
logger.info(f"Base decoder state_dict backed up to CPU ({backup_size_mb:.1f}MB)")
else:
self.model.decoder = copy.deepcopy(self._base_decoder)
logger.info("Restored base decoder before loading new LoRA")
# Restore base decoder from state_dict backup
logger.info("Restoring base decoder from state_dict backup")
load_result = self.model.decoder.load_state_dict(self._base_decoder, strict=False)
if load_result.missing_keys:
logger.warning(f"Missing keys when restoring decoder: {load_result.missing_keys[:5]}")
if load_result.unexpected_keys:
logger.warning(f"Unexpected keys when restoring decoder: {load_result.unexpected_keys[:5]}")
self.model.decoder = self.model.decoder.to(self.device).to(self.dtype)
logger.info(f"Loading LoRA adapter from {lora_path}")
self.model.decoder = PeftModel.from_pretrained(self.model.decoder, lora_path, is_trainable=False)
self.model.decoder = self.model.decoder.to(self.device).to(self.dtype)
self.model.decoder.eval()
# Log memory after LoRA load
if hasattr(self, '_memory_allocated'):
mem_after = self._memory_allocated() / (1024**3)
logger.info(f"VRAM after LoRA load: {mem_after:.2f}GB")
self.lora_loaded = True
self.use_lora = True
self._ensure_lora_registry()
@ -81,9 +107,35 @@ def unload_lora(self) -> str:
return "❌ Base decoder backup not found. Cannot restore."
try:
import copy
self.model.decoder = copy.deepcopy(self._base_decoder)
# Log memory before unload (track before any operations)
mem_before = None
if hasattr(self, '_memory_allocated'):
mem_before = self._memory_allocated() / (1024**3)
logger.info(f"VRAM before LoRA unload: {mem_before:.2f}GB")
# Get the base model from the PEFT wrapper if it exists
# This is more memory-efficient than recreating from state_dict
from peft import PeftModel
if isinstance(self.model.decoder, PeftModel):
logger.info("Extracting base model from PEFT wrapper")
# PEFT's get_base_model() returns the underlying base model without copying
self.model.decoder = self.model.decoder.get_base_model()
# Restore state_dict from backup to ensure clean state
load_result = self.model.decoder.load_state_dict(self._base_decoder, strict=False)
if load_result.missing_keys:
logger.warning(f"Missing keys when restoring decoder: {load_result.missing_keys[:5]}")
if load_result.unexpected_keys:
logger.warning(f"Unexpected keys when restoring decoder: {load_result.unexpected_keys[:5]}")
else:
# Fallback: restore from state_dict backup
logger.info("Restoring base decoder from state_dict backup")
load_result = self.model.decoder.load_state_dict(self._base_decoder, strict=False)
if load_result.missing_keys:
logger.warning(f"Missing keys when restoring decoder: {load_result.missing_keys[:5]}")
if load_result.unexpected_keys:
logger.warning(f"Unexpected keys when restoring decoder: {load_result.unexpected_keys[:5]}")
self.model.decoder = self.model.decoder.to(self.device).to(self.dtype)
self.model.decoder.eval()
@ -99,6 +151,12 @@ def unload_lora(self) -> str:
self._lora_active_adapter = None
self._lora_scale_state = {}
# Log memory after unload
if mem_before is not None and hasattr(self, '_memory_allocated'):
mem_after = self._memory_allocated() / (1024**3)
freed = mem_before - mem_after
logger.info(f"VRAM after LoRA unload: {mem_after:.2f}GB (freed: {freed:.2f}GB)")
logger.info("LoRA unloaded, base decoder restored")
return "✅ LoRA unloaded, using base model"
except Exception as e:
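
Review note: the change above replaces a `deepcopy` of the decoder with a CPU copy of its `state_dict`. A minimal sketch of that backup and restore pattern follows, using a small `torch.nn.Linear` as a stand-in for the decoder. The helper names `backup_state_dict` and `restore_state_dict` are hypothetical, and the PEFT wrapping step is deliberately left out.

```python
import torch


def backup_state_dict(module: torch.nn.Module) -> dict:
    """Copy every tensor in the state_dict to CPU, detached from autograd."""
    return {k: v.detach().cpu().clone() for k, v in module.state_dict().items()}


def restore_state_dict(module: torch.nn.Module, backup: dict):
    """Load the backup in place and report any key mismatches."""
    result = module.load_state_dict(backup, strict=False)
    return list(result.missing_keys), list(result.unexpected_keys)


layer = torch.nn.Linear(4, 2)  # stand-in for the decoder
backup = backup_state_dict(layer)

# Simulate weights being modified after the backup was taken.
with torch.no_grad():
    layer.weight.add_(1.0)

missing, unexpected = restore_state_dict(layer, backup)
print(missing, unexpected, torch.equal(layer.weight, backup["weight"]))
```

With `strict=False`, mismatched keys are reported rather than raised, which is why the diff logs `missing_keys` and `unexpected_keys` after each restore.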

@ -0,0 +1,166 @@
"""Memory and VRAM helper methods for handler decomposition."""
import os
from typing import Optional
import torch
from loguru import logger
from acestep.gpu_config import get_effective_free_vram_gb, get_global_gpu_config
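class MemoryUtilsMixin:
"""Mixin containing memory sizing and VRAM guard helpers.
Depends on host members:
- Attributes: ``device``, ``dtype``, ``offload_to_cpu``.
- Optional attributes: ``model``, ``config_path`` (read defensively by ``_vram_guard_reduce_batch``).
"""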
def is_silence(self, audio: torch.Tensor) -> bool:
"""Return True when audio is effectively silent."""
return bool(torch.all(audio.abs() < 1e-6))
def _get_system_memory_gb(self) -> Optional[float]:
"""Return total system RAM in GB when available."""
try:
page_size = os.sysconf("SC_PAGE_SIZE")
page_count = os.sysconf("SC_PHYS_PAGES")
if page_size and page_count:
return (page_size * page_count) / (1024**3)
except (ValueError, OSError, AttributeError):
return None
return None
def _get_effective_mps_memory_gb(self) -> Optional[float]:
"""Best-effort MPS memory estimate (recommended max or system RAM)."""
if hasattr(torch, "mps") and hasattr(torch.mps, "recommended_max_memory"):
try:
return torch.mps.recommended_max_memory() / (1024**3)
except Exception:
pass
system_gb = self._get_system_memory_gb()
if system_gb is None:
return None
return system_gb * 0.75
VAE_DECODE_MAX_CHUNK_SIZE = 512
def _get_auto_decode_chunk_size(self) -> int:
"""Choose a conservative VAE decode chunk size based on available memory."""
override = os.environ.get("ACESTEP_VAE_DECODE_CHUNK_SIZE")
if override:
try:
value = int(override)
if value > 0:
return value
except ValueError:
pass
max_chunk = self.VAE_DECODE_MAX_CHUNK_SIZE
if self.device == "mps":
mem_gb = self._get_effective_mps_memory_gb()
if mem_gb is not None:
if mem_gb >= 48:
return min(1536, max_chunk)
if mem_gb >= 24:
return min(1024, max_chunk)
return min(512, max_chunk)
if self.device == "cuda" or (isinstance(self.device, str) and self.device.startswith("cuda")):
try:
free_gb = get_effective_free_vram_gb()
except Exception:
free_gb = 0
logger.debug(f"[_get_auto_decode_chunk_size] Effective free VRAM: {free_gb:.2f} GB")
if free_gb >= 24.0:
return min(512, max_chunk)
if free_gb >= 16.0:
return min(384, max_chunk)
if free_gb >= 12.0:
return min(256, max_chunk)
return min(128, max_chunk)
return min(256, max_chunk)
def _should_offload_wav_to_cpu(self) -> bool:
"""Decide whether to offload decoded wavs to CPU for memory safety."""
override = os.environ.get("ACESTEP_MPS_DECODE_OFFLOAD")
if override:
return override.lower() in ("1", "true", "yes")
if self.device == "mps":
mem_gb = self._get_effective_mps_memory_gb()
if mem_gb is not None and mem_gb >= 32:
return False
return True
if self.device == "cuda" or (isinstance(self.device, str) and self.device.startswith("cuda")):
try:
free_gb = get_effective_free_vram_gb()
logger.debug(f"[_should_offload_wav_to_cpu] Effective free VRAM: {free_gb:.2f} GB")
if free_gb >= 24.0:
return False
except Exception:
pass
return True
def _vram_guard_reduce_batch(
self,
batch_size: int,
audio_duration: Optional[float] = None,
use_lm: bool = False,
) -> int:
"""Auto-reduce batch_size when free VRAM is too tight."""
if batch_size <= 1:
return batch_size
device = self.device
if device == "cpu" or device == "mps":
return batch_size
if self.offload_to_cpu:
gpu_config = get_global_gpu_config()
if gpu_config is not None:
tier_max = gpu_config.max_batch_size_with_lm
if batch_size <= tier_max:
logger.debug(
f"[VRAM guard] offload_to_cpu=True, batch_size={batch_size} <= "
f"tier limit {tier_max} — skipping dynamic VRAM check"
)
return batch_size
try:
free_gb = get_effective_free_vram_gb()
except Exception:
return batch_size
duration_sec = float(audio_duration) if audio_duration and float(audio_duration) > 0 else 60.0
per_sample_gb = 0.5 + max(0.0, 0.15 * (duration_sec - 60.0) / 60.0)
if hasattr(self, "model") and self.model is not None:
model_name = getattr(self, "config_path", "") or ""
if "base" in model_name.lower():
per_sample_gb *= 2.0
safety_margin_gb = 1.5
available_for_batch = free_gb - safety_margin_gb
if available_for_batch <= 0:
logger.warning(f"[VRAM guard] Only {free_gb:.1f} GB free — reducing batch_size to 1")
return 1
max_safe_batch = max(1, int(available_for_batch / per_sample_gb))
if max_safe_batch < batch_size:
logger.warning(
f"[VRAM guard] Free VRAM {free_gb:.1f} GB can safely fit ~{max_safe_batch} samples "
f"(requested {batch_size}). Reducing batch_size to {max_safe_batch}."
)
return max_safe_batch
return batch_size
def _get_vae_dtype(self, device: Optional[str] = None) -> torch.dtype:
"""Get VAE dtype for the target device; unrecognized devices fall back to ``self.dtype``."""
target_device = device or self.device
if target_device in ["cuda", "xpu"]:
return torch.bfloat16
if target_device == "mps":
return torch.float16
if target_device == "cpu":
return torch.float32
return self.dtype
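
Review note: the sizing rule inside `_vram_guard_reduce_batch` is easier to review as a pure function. The sketch below lifts only the arithmetic (0.5 GB per sample at 60 s, plus 0.15 GB per extra minute, doubled for base models, with a 1.5 GB safety margin). The function name `estimate_safe_batch` and the `is_base_model` flag are assumptions for the example; the mixin derives the flag from `config_path` and also has an `offload_to_cpu` shortcut that is not modelled here.

```python
def estimate_safe_batch(
    free_gb: float,
    batch_size: int,
    duration_sec: float = 60.0,
    is_base_model: bool = False,
    safety_margin_gb: float = 1.5,
) -> int:
    """Return the batch size that the free VRAM estimate can accommodate."""
    if batch_size <= 1:
        return batch_size
    # Per-sample cost grows linearly with duration beyond the first minute.
    per_sample_gb = 0.5 + max(0.0, 0.15 * (duration_sec - 60.0) / 60.0)
    if is_base_model:
        per_sample_gb *= 2.0
    available = free_gb - safety_margin_gb
    if available <= 0:
        return 1
    return min(batch_size, max(1, int(available / per_sample_gb)))


# 4 GB free, 8 requested, 60 s clips: (4.0 - 1.5) / 0.5 = 5 samples.
print(estimate_safe_batch(4.0, 8))
# Same VRAM with 180 s clips: per-sample cost is 0.8 GB, so 3 samples.
print(estimate_safe_batch(4.0, 8, duration_sec=180.0))
```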

@ -0,0 +1,79 @@
"""Metadata formatting helpers for handler decomposition."""
from typing import Any, Dict, List, Optional, Union
class MetadataMixin:
"""Mixin containing metadata parsing and formatting helpers.
Depends on host members:
- No cross-mixin runtime dependencies; pure formatting/parsing helpers.
"""
def _create_default_meta(self) -> str:
"""Create default metadata string."""
return (
"- bpm: N/A\n"
"- timesignature: N/A\n"
"- keyscale: N/A\n"
"- duration: 30 seconds\n"
)
def _dict_to_meta_string(self, meta_dict: Dict[str, Any]) -> str:
"""Convert metadata dict to formatted string."""
bpm = meta_dict.get("bpm", meta_dict.get("tempo", "N/A"))
timesignature = meta_dict.get("timesignature", meta_dict.get("time_signature", "N/A"))
keyscale = meta_dict.get("keyscale", meta_dict.get("key", meta_dict.get("scale", "N/A")))
duration = meta_dict.get("duration", meta_dict.get("length", 30))
if isinstance(duration, (int, float)):
duration = f"{int(duration)} seconds"
elif not isinstance(duration, str):
duration = "30 seconds"
return (
f"- bpm: {bpm}\n"
f"- timesignature: {timesignature}\n"
f"- keyscale: {keyscale}\n"
f"- duration: {duration}\n"
)
def _parse_metas(self, metas: List[Union[str, Dict[str, Any]]]) -> List[str]:
"""Parse and normalize metadata values with safe fallbacks."""
parsed_metas = []
for meta in metas:
if meta is None:
parsed_meta = self._create_default_meta()
elif isinstance(meta, str):
parsed_meta = meta
elif isinstance(meta, dict):
parsed_meta = self._dict_to_meta_string(meta)
else:
parsed_meta = self._create_default_meta()
parsed_metas.append(parsed_meta)
return parsed_metas
def prepare_metadata(
self, bpm: Optional[Union[int, str]], key_scale: str, time_signature: str
) -> Dict[str, Any]:
"""Build metadata dict for generation."""
return self._build_metadata_dict(bpm, key_scale, time_signature)
def _build_metadata_dict(
self,
bpm: Optional[Union[int, str]],
key_scale: str,
time_signature: str,
duration: Optional[float] = None,
) -> Dict[str, Any]:
"""Build metadata dictionary with defaults for missing fields."""
metadata_dict: Dict[str, Any] = {}
metadata_dict["bpm"] = bpm if bpm else "N/A"
metadata_dict["keyscale"] = key_scale if key_scale.strip() else "N/A"
if time_signature.strip() and time_signature != "N/A":
metadata_dict["timesignature"] = time_signature
else:
metadata_dict["timesignature"] = "N/A"
if duration is not None:
metadata_dict["duration"] = f"{int(duration)} seconds"
return metadata_dict
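
Review note: `_dict_to_meta_string` accepts several alias keys (`tempo`, `time_signature`, `key`, `scale`, `length`). The standalone sketch below mirrors that lookup order so the fallback behaviour can be seen on a concrete input. The free function `meta_to_string` is a hypothetical copy made for illustration, not part of the mixin.

```python
from typing import Any, Dict


def meta_to_string(meta: Dict[str, Any]) -> str:
    """Format a metadata dict, preferring canonical keys over their aliases."""
    bpm = meta.get("bpm", meta.get("tempo", "N/A"))
    timesignature = meta.get("timesignature", meta.get("time_signature", "N/A"))
    keyscale = meta.get("keyscale", meta.get("key", meta.get("scale", "N/A")))
    duration = meta.get("duration", meta.get("length", 30))
    if isinstance(duration, (int, float)):
        # Numeric durations are truncated to whole seconds.
        duration = f"{int(duration)} seconds"
    elif not isinstance(duration, str):
        duration = "30 seconds"
    return (
        f"- bpm: {bpm}\n"
        f"- timesignature: {timesignature}\n"
        f"- keyscale: {keyscale}\n"
        f"- duration: {duration}\n"
    )


print(meta_to_string({"tempo": 120, "key": "C major", "length": 95.7}))
```

Note that a canonical key wins over its alias when both are present, and that a duration already given as a string is passed through unchanged.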

@ -0,0 +1,151 @@
"""Padding helpers for handler batch preparation."""
import torch
from loguru import logger
class PaddingMixin:
"""Mixin containing repaint/lego padding helpers.
Depends on host members:
- Method: ``create_target_wavs`` (provided by ``TaskUtilsMixin`` in this decomposition).
"""
def prepare_padding_info(
self,
actual_batch_size,
processed_src_audio,
audio_duration,
repainting_start,
repainting_end,
is_repaint_task,
is_lego_task,
is_cover_task,
can_use_repainting,
):
"""Prepare padded target wavs and repaint coordinates for each batch item."""
try:
target_wavs_batch = []
# Store padding info for each batch item to adjust repainting coordinates
padding_info_batch = []
for i in range(actual_batch_size):
if processed_src_audio is not None:
if is_cover_task:
# Cover task: Use src_audio directly without padding
batch_target_wavs = processed_src_audio
padding_info_batch.append({"left_padding_duration": 0.0, "right_padding_duration": 0.0})
elif is_repaint_task or is_lego_task:
# Repaint/lego task: May need padding for outpainting
src_audio_duration = processed_src_audio.shape[-1] / 48000.0
# Determine actual end time
if repainting_end is None or repainting_end < 0:
actual_end = src_audio_duration
else:
actual_end = repainting_end
left_padding_duration = max(0, -repainting_start) if repainting_start is not None else 0
right_padding_duration = max(0, actual_end - src_audio_duration)
# Create padded audio
left_padding_frames = int(left_padding_duration * 48000)
right_padding_frames = int(right_padding_duration * 48000)
if left_padding_frames > 0 or right_padding_frames > 0:
# Pad the src audio
batch_target_wavs = torch.nn.functional.pad(
processed_src_audio, (left_padding_frames, right_padding_frames), "constant", 0
)
else:
batch_target_wavs = processed_src_audio
# Store padding info for coordinate adjustment
padding_info_batch.append(
{
"left_padding_duration": left_padding_duration,
"right_padding_duration": right_padding_duration,
}
)
else:
# Other tasks: Use src_audio directly without padding
batch_target_wavs = processed_src_audio
padding_info_batch.append({"left_padding_duration": 0.0, "right_padding_duration": 0.0})
else:
padding_info_batch.append({"left_padding_duration": 0.0, "right_padding_duration": 0.0})
if audio_duration is not None and float(audio_duration) > 0:
batch_target_wavs = self.create_target_wavs(float(audio_duration))
else:
import random
random_duration = random.uniform(10.0, 120.0)
batch_target_wavs = self.create_target_wavs(random_duration)
target_wavs_batch.append(batch_target_wavs)
# Stack target_wavs into batch tensor
# Ensure all tensors have the same shape by padding to max length
max_frames = max(wav.shape[-1] for wav in target_wavs_batch)
padded_target_wavs = []
for wav in target_wavs_batch:
if wav.shape[-1] < max_frames:
pad_frames = max_frames - wav.shape[-1]
padded_wav = torch.nn.functional.pad(wav, (0, pad_frames), "constant", 0)
padded_target_wavs.append(padded_wav)
else:
padded_target_wavs.append(wav)
target_wavs_tensor = torch.stack(padded_target_wavs, dim=0) # [batch_size, 2, frames]
if can_use_repainting:
# Repaint task: Set repainting parameters
if repainting_start is None:
repainting_start_batch = None
elif isinstance(repainting_start, (int, float)):
if processed_src_audio is not None:
adjusted_start = repainting_start + padding_info_batch[0]["left_padding_duration"]
repainting_start_batch = [adjusted_start] * actual_batch_size
else:
repainting_start_batch = [repainting_start] * actual_batch_size
else:
# List input - adjust each item
repainting_start_batch = []
for i in range(actual_batch_size):
if processed_src_audio is not None:
adjusted_start = repainting_start[i] + padding_info_batch[i]["left_padding_duration"]
repainting_start_batch.append(adjusted_start)
else:
repainting_start_batch.append(repainting_start[i])
# Handle repainting_end - use src audio duration if not specified or negative
if processed_src_audio is not None:
# If src audio is provided, use its duration as default end
src_audio_duration = processed_src_audio.shape[-1] / 48000.0
if repainting_end is None or repainting_end < 0:
# Use src audio duration (before padding), then adjust for padding
adjusted_end = src_audio_duration + padding_info_batch[0]["left_padding_duration"]
repainting_end_batch = [adjusted_end] * actual_batch_size
else:
# Adjust repainting_end to be relative to padded audio
adjusted_end = repainting_end + padding_info_batch[0]["left_padding_duration"]
repainting_end_batch = [adjusted_end] * actual_batch_size
else:
# No src audio - repainting doesn't make sense without it
if repainting_end is None or repainting_end < 0:
repainting_end_batch = None
elif isinstance(repainting_end, (int, float)):
repainting_end_batch = [repainting_end] * actual_batch_size
else:
# List input - copy each item unchanged (no padding offset without src audio)
repainting_end_batch = []
for i in range(actual_batch_size):
repainting_end_batch.append(repainting_end[i])
else:
# All other tasks (cover, text2music, extract, complete): No repainting
# Only repaint and lego tasks should have repainting parameters
repainting_start_batch = None
repainting_end_batch = None
return repainting_start_batch, repainting_end_batch, target_wavs_tensor
except (TypeError, ValueError, RuntimeError, IndexError):
logger.exception("[prepare_padding_info] Error preparing padding information")
fallback = torch.stack([self.create_target_wavs(30.0) for _ in range(actual_batch_size)], dim=0)
return None, None, fallback
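
Review note: the outpainting branch of `prepare_padding_info` shifts repaint coordinates by the amount of left padding. The sketch below isolates that arithmetic for the scalar case so it can be checked by hand. The function name `outpaint_padding` and the returned dict layout are assumptions for the example; list-valued start and end inputs are not covered.

```python
from typing import Dict, Optional

SAMPLE_RATE = 48_000


def outpaint_padding(
    src_duration: float,
    repainting_start: Optional[float],
    repainting_end: Optional[float],
) -> Dict[str, Optional[float]]:
    """Compute padding frames and repaint coordinates in the padded timeline."""
    # A missing or negative end means "repaint up to the end of the source".
    if repainting_end is None or repainting_end < 0:
        actual_end = src_duration
    else:
        actual_end = repainting_end
    left = max(0.0, -repainting_start) if repainting_start is not None else 0.0
    right = max(0.0, actual_end - src_duration)
    adjusted_start = None if repainting_start is None else repainting_start + left
    return {
        "left_frames": int(left * SAMPLE_RATE),
        "right_frames": int(right * SAMPLE_RATE),
        "adjusted_start": adjusted_start,
        "adjusted_end": actual_end + left,
    }


# 20 s source, repaint from -5 s to 30 s: pad 5 s on the left and 10 s on the right.
print(outpaint_padding(20.0, -5.0, 30.0))
```

In this example the padded clip is 35 s long and the repaint region becomes 0 s to 35 s, which covers the original audio plus both padded regions.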

@ -0,0 +1,162 @@
"""Prompt and text-input helpers for handler decomposition."""
import re
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from loguru import logger
from acestep.constants import DEFAULT_DIT_INSTRUCTION, SFT_GEN_PROMPT
class PromptMixin:
"""Mixin containing prompt formatting and text-encoder helpers.
Depends on host members:
- Attributes: ``text_tokenizer``, ``text_encoder``, ``device``, ``dtype``.
- Methods: ``_parse_metas`` (from ``MetadataMixin``), ``_load_model_context``
(from ``InitServiceMixin``).
"""
def _format_instruction(self, instruction: str) -> str:
"""Ensure instruction ends with a colon."""
if not instruction.endswith(":"):
instruction = instruction + ":"
return instruction
def _format_lyrics(self, lyrics: str, language: str) -> str:
"""Format lyrics text with language header."""
return f"# Languages\n{language}\n\n# Lyric\n{lyrics}<|endoftext|>"
def _pad_sequences(
self, sequences: List[torch.Tensor], max_length: int, pad_value: int = 0
) -> torch.Tensor:
"""Pad sequence tensors to the same length."""
return torch.stack(
[
torch.nn.functional.pad(seq, (0, max_length - len(seq)), "constant", pad_value)
for seq in sequences
]
)
def extract_caption_from_sft_format(self, caption: str) -> str:
"""Extract caption body from SFT-formatted prompt when present."""
try:
if "# Instruction" in caption and "# Caption" in caption:
pattern = r"#\s*Caption\s*\n(.*?)(?:\n\s*#\s*Metas|$)"
match = re.search(pattern, caption, re.DOTALL)
if match:
return match.group(1).strip()
return caption
except (AttributeError, TypeError, re.error):
logger.exception("[extract_caption_from_sft_format] Error extracting caption")
return caption
def build_dit_inputs(
self,
task: str,
instruction: Optional[str],
caption: str,
lyrics: str,
metas: Optional[Union[str, Dict[str, Any]]] = None,
vocal_language: str = "en",
) -> Tuple[str, str]:
"""Build caption and lyric input text for DiT branches.
Args:
task: Task name (currently informational; reserved for task-specific formatting).
instruction: Instruction text; falls back to default when empty.
caption: Caption fallback value.
lyrics: Raw lyric text.
metas: Optional metadata (string or dict) that may include caption/language.
vocal_language: Fallback lyric language when not present in metadata.
Returns:
Tuple of ``(caption_input, lyrics_input)`` for caption and lyric encoder branches.
"""
final_instruction = self._format_instruction(instruction or DEFAULT_DIT_INSTRUCTION)
actual_caption = caption
actual_language = vocal_language
if metas is not None:
try:
if isinstance(metas, str):
parsed_metas = self._parse_metas([metas])
meta_dict = parsed_metas[0] if parsed_metas and isinstance(parsed_metas[0], dict) else {}
elif isinstance(metas, dict):
meta_dict = metas
else:
meta_dict = {}
except (TypeError, ValueError, KeyError, IndexError):
logger.exception("[build_dit_inputs] Error parsing metas")
meta_dict = {}
if "caption" in meta_dict and meta_dict["caption"]:
actual_caption = str(meta_dict["caption"])
if "language" in meta_dict and meta_dict["language"]:
actual_language = str(meta_dict["language"])
parsed_meta = self._parse_metas([metas])[0]
caption_input = SFT_GEN_PROMPT.format(final_instruction, actual_caption, parsed_meta)
lyrics_input = self._format_lyrics(lyrics, actual_language)
return caption_input, lyrics_input
def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get hidden states and attention mask from text encoder."""
if self.text_tokenizer is None or self.text_encoder is None:
raise ValueError("Text encoder not initialized")
try:
with self._load_model_context("text_encoder"):
text_inputs = self.text_tokenizer(
text_prompt,
padding="longest",
truncation=True,
max_length=256,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(self.device)
text_attention_mask = text_inputs.attention_mask.to(self.device).bool()
with torch.inference_mode():
text_outputs = self.text_encoder(text_input_ids)
if hasattr(text_outputs, "last_hidden_state"):
text_hidden_states = text_outputs.last_hidden_state
elif isinstance(text_outputs, tuple):
text_hidden_states = text_outputs[0]
else:
text_hidden_states = text_outputs
text_hidden_states = text_hidden_states.to(self.dtype)
return text_hidden_states, text_attention_mask
except (AttributeError, RuntimeError, TypeError, ValueError):
logger.exception("[_get_text_hidden_states] Failed to encode text prompt")
raise
def _extract_caption_and_language(
self,
metas: List[Union[str, Dict[str, Any]]],
captions: List[str],
vocal_languages: List[str],
) -> Tuple[List[str], List[str]]:
"""Extract caption/language values from metas with fallback values."""
actual_captions = list(captions)
actual_languages = list(vocal_languages)
for i, meta in enumerate(metas):
if i >= len(actual_captions):
break
meta_dict = None
if isinstance(meta, str):
parsed = self._parse_metas([meta])
if parsed and isinstance(parsed[0], dict):
meta_dict = parsed[0]
elif isinstance(meta, dict):
meta_dict = meta
if meta_dict:
if "caption" in meta_dict and meta_dict["caption"]:
actual_captions[i] = str(meta_dict["caption"])
if "language" in meta_dict and meta_dict["language"]:
actual_languages[i] = str(meta_dict["language"])
return actual_captions, actual_languages
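The fallback behaviour of `_extract_caption_and_language` can be sketched as a standalone function. This is a simplified illustration, not the handler's code: string metas are left to the fallbacks because the real `_parse_metas` helper is not shown here, and the sample values are made up.

```python
from typing import Any, Dict, List, Tuple, Union


def extract_caption_and_language(
    metas: List[Union[str, Dict[str, Any]]],
    captions: List[str],
    vocal_languages: List[str],
) -> Tuple[List[str], List[str]]:
    """Override fallback captions/languages with non-empty values from dict metas."""
    actual_captions = list(captions)
    actual_languages = list(vocal_languages)
    for i, meta in enumerate(metas):
        if i >= len(actual_captions):
            break
        # Assumption: only dict metas are inspected in this sketch; the real
        # method also parses string metas through self._parse_metas.
        meta_dict = meta if isinstance(meta, dict) else None
        if meta_dict:
            if meta_dict.get("caption"):
                actual_captions[i] = str(meta_dict["caption"])
            if meta_dict.get("language"):
                actual_languages[i] = str(meta_dict["language"])
    return actual_captions, actual_languages


# Made-up inputs: the first meta overrides the caption only (its language is
# empty), the second is a string and therefore keeps both fallbacks.
captions, languages = extract_caption_and_language(
    metas=[{"caption": "lo-fi piano", "language": ""}, "bpm: 120"],
    captions=["default caption", "second caption"],
    vocal_languages=["en", "zh"],
)
```

An empty string in the metadata counts as "not provided", so the fallback language survives for the first item.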


@ -0,0 +1,123 @@
"""Task and seed helpers for handler decomposition."""
import random
from typing import List, Optional, Tuple
import torch
from loguru import logger
from acestep.constants import TASK_INSTRUCTIONS
class TaskUtilsMixin:
"""Mixin containing generation task and seed utility helpers.
Depends on host members:
- No required cross-mixin attributes for seed/instruction helpers.
"""
def prepare_seeds(
self, actual_batch_size: int, seed, use_random_seed: bool
) -> Tuple[List[int], str]:
"""Prepare per-item seeds and UI seed string."""
actual_seed_list: List[int] = []
seed_value_for_ui = ""
try:
if use_random_seed:
actual_seed_list = [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size)]
seed_value_for_ui = ", ".join(str(s) for s in actual_seed_list)
else:
seed_list: List[int] = []
if isinstance(seed, str):
for s in [s.strip() for s in seed.split(",")]:
if s == "-1" or s == "":
seed_list.append(-1)
else:
try:
seed_list.append(int(float(s)))
except (ValueError, TypeError) as exc:
logger.debug(f"[prepare_seeds] Failed to parse seed value '{s}': {exc}")
seed_list.append(-1)
elif seed is None or (isinstance(seed, (int, float)) and seed < 0):
seed_list = [-1] * actual_batch_size
elif isinstance(seed, (int, float)):
seed_list = [int(seed)]
else:
seed_list = [-1] * actual_batch_size
has_single_non_negative_seed = len(seed_list) == 1 and seed_list[0] != -1
for i in range(actual_batch_size):
seed_val = seed_list[i] if i < len(seed_list) else -1
if has_single_non_negative_seed and actual_batch_size > 1 and i > 0:
actual_seed_list.append(random.randint(0, 2**32 - 1))
elif seed_val == -1:
actual_seed_list.append(random.randint(0, 2**32 - 1))
else:
actual_seed_list.append(int(seed_val))
seed_value_for_ui = ", ".join(str(s) for s in actual_seed_list)
except (TypeError, ValueError, OverflowError):
logger.exception("[prepare_seeds] Failed to prepare seeds")
actual_seed_list = [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size)]
seed_value_for_ui = ", ".join(str(s) for s in actual_seed_list)
return actual_seed_list, seed_value_for_ui
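The non-random branch of `prepare_seeds` boils down to two steps: parse the seed field into a list where `-1` means "draw a random seed", then fill the batch. The sketch below restates that logic outside the mixin. The helper names `parse_seed_field` and `resolve_seeds` are hypothetical, and an explicit `random.Random` is passed in only so the example is easy to test.

```python
import random
from typing import List, Union

MAX_SEED = 2**32 - 1


def parse_seed_field(seed: Union[str, int, float, None], batch_size: int) -> List[int]:
    """Turn a seed field into a list where -1 marks 'pick a random seed'."""
    if isinstance(seed, str):
        parsed: List[int] = []
        for part in (p.strip() for p in seed.split(",")):
            if part in ("", "-1"):
                parsed.append(-1)
                continue
            try:
                parsed.append(int(float(part)))
            except (ValueError, OverflowError):
                # Unparseable entries fall back to a random seed.
                parsed.append(-1)
        return parsed
    if isinstance(seed, (int, float)) and seed >= 0:
        return [int(seed)]
    return [-1] * batch_size


def resolve_seeds(seed, batch_size: int, rng: random.Random) -> List[int]:
    """Fill one seed per batch item, randomising every slot marked -1."""
    seed_list = parse_seed_field(seed, batch_size)
    single_fixed = len(seed_list) == 1 and seed_list[0] != -1
    resolved: List[int] = []
    for i in range(batch_size):
        value = seed_list[i] if i < len(seed_list) else -1
        # A single fixed seed applies to the first item only; the rest vary.
        if value == -1 or (single_fixed and i > 0):
            resolved.append(rng.randint(0, MAX_SEED))
        else:
            resolved.append(value)
    return resolved


rng = random.Random(0)
mixed = resolve_seeds("42, -1, abc", 4, rng)
single = resolve_seeds(7, 3, rng)
```

Only the first item keeps a single fixed seed, which keeps one reproducible result while the remaining batch items still differ from each other.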
def generate_instruction(
self,
task_type: str,
track_name: Optional[str] = None,
complete_track_classes: Optional[List[str]] = None,
) -> str:
"""Generate task instruction text from task type and track context."""
if task_type == "text2music":
return TASK_INSTRUCTIONS["text2music"]
if task_type == "repaint":
return TASK_INSTRUCTIONS["repaint"]
if task_type == "cover":
return TASK_INSTRUCTIONS["cover"]
if task_type == "extract":
return (
TASK_INSTRUCTIONS["extract"].format(TRACK_NAME=track_name.upper())
if track_name
else TASK_INSTRUCTIONS["extract_default"]
)
if task_type == "lego":
return (
TASK_INSTRUCTIONS["lego"].format(TRACK_NAME=track_name.upper())
if track_name
else TASK_INSTRUCTIONS["lego_default"]
)
if task_type == "complete":
if complete_track_classes and len(complete_track_classes) > 0:
track_classes_upper = [t.upper() for t in complete_track_classes]
return TASK_INSTRUCTIONS["complete"].format(
TRACK_CLASSES=" | ".join(track_classes_upper)
)
return TASK_INSTRUCTIONS["complete_default"]
return TASK_INSTRUCTIONS["text2music"]
def determine_task_type(self, task_type, audio_code_string):
"""Compute task-mode booleans for downstream generation logic."""
is_repaint_task = task_type == "repaint"
is_lego_task = task_type == "lego"
is_cover_task = task_type == "cover"
if isinstance(audio_code_string, list):
has_codes = any((c or "").strip() for c in audio_code_string)
else:
has_codes = bool(audio_code_string and str(audio_code_string).strip())
if has_codes:
is_cover_task = True
can_use_repainting = is_repaint_task or is_lego_task
return is_repaint_task, is_lego_task, is_cover_task, can_use_repainting
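One consequence of `determine_task_type` is easy to miss: any non-blank audio code string switches the request to cover behaviour, whatever `task_type` says. A minimal sketch of that check follows; the helper names and the `"placeholder-code"` value are hypothetical and only illustrate the blank-versus-present distinction.

```python
from typing import List, Optional, Union


def has_audio_codes(audio_code_string: Union[str, List[Optional[str]], None]) -> bool:
    """Return True when at least one non-blank code string is present."""
    if isinstance(audio_code_string, list):
        return any((c or "").strip() for c in audio_code_string)
    return bool(audio_code_string and str(audio_code_string).strip())


def is_cover(task_type: str, audio_code_string) -> bool:
    """Cover mode is active for the 'cover' task or whenever codes are supplied."""
    return task_type == "cover" or has_audio_codes(audio_code_string)


# Blank strings, None entries, and whitespace do not count as codes.
blank_cases = [None, "", "   ", ["", None, "  "]]
blank_results = [has_audio_codes(case) for case in blank_cases]

forced_cover = is_cover("text2music", "placeholder-code")
plain_repaint = is_cover("repaint", ["", ""])
```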
def create_target_wavs(self, duration_seconds: float) -> torch.Tensor:
"""Create silent stereo target audio with safe duration handling."""
try:
duration_seconds = max(0.1, round(duration_seconds, 1))
frames = int(duration_seconds * 48000)
return torch.zeros(2, frames)
except (TypeError, ValueError, OverflowError):
logger.exception("[create_target_wavs] Error creating target audio")
return torch.zeros(2, 30 * 48000)
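The frame arithmetic in `create_target_wavs` can be checked without torch. The sketch below assumes the same 48 kHz sample rate and 0.1 s floor that appear in the method; `target_frame_count` is a hypothetical helper name.

```python
SAMPLE_RATE = 48000  # same constant as in create_target_wavs


def target_frame_count(duration_seconds: float) -> int:
    """Clamp the duration to at least 0.1 s, round to 0.1 s, and convert to frames."""
    duration_seconds = max(0.1, round(duration_seconds, 1))
    return int(duration_seconds * SAMPLE_RATE)


thirty_seconds = target_frame_count(30)
clamped = target_frame_count(-5)
```

A negative or zero duration is clamped rather than rejected, so the caller always receives a non-empty tensor shape.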


@ -19,6 +19,93 @@ from acestep.constants import (
DEBUG_GPU,
)
# ----------------------------------------------------------------------
# CPU thread configuration
# ----------------------------------------------------------------------
# When running on CPU we want to make use of most of the available cores
# but leave a couple free for the OS / other processes. The logic is:
# * If the system has ≤2 logical CPUs, use all of them.
# * Otherwise, use (cpu_count-2) threads.
# This mirrors the common "all-but-two" heuristic while guaranteeing at
# least one thread.
# The configuration is opt-in (see ACESTEP_CONFIGURE_THREADS below) and,
# when enabled, is applied at import time so that any subsequent torch
# operations respect the setting.
import os
import torch
def _configure_cpu_threads() -> None:
"""Set torch's intra-op and inter-op thread counts based on available CPUs.
This function configures PyTorch to use most available CPU cores while
leaving a couple free for the OS and other processes. The logic is:
* If the system has ≤2 logical CPUs, use all of them.
* Otherwise, use (cpu_count - 2) threads.
This mirrors the common "all-but-two" heuristic while guaranteeing at
least one thread.
Raises:
RuntimeError: If torch.set_num_threads or torch.set_num_interop_threads
fails (e.g., if called after threads have already been used).
"""
cpu_cnt = os.cpu_count() or 1
# Ensure we never set a non-positive number of threads.
threads = cpu_cnt - 2 if cpu_cnt > 2 else cpu_cnt
threads = max(threads, 1)
try:
torch.set_num_threads(threads)
except RuntimeError as exc:
raise RuntimeError(
f"Failed to set torch intra-op threads to {threads}: {exc}"
) from exc
try:
torch.set_num_interop_threads(threads)
except RuntimeError as exc:
raise RuntimeError(
f"Failed to set torch inter-op threads to {threads}: {exc}"
) from exc
# Track whether CPU threads have been configured to avoid redundant calls.
_cpu_threads_configured = False
def configure_cpu_threads_if_needed() -> bool:
"""Configure CPU threads if enabled via environment variable.
This function provides an opt-in mechanism for configuring PyTorch's
thread counts. It only takes effect if the environment variable
``ACESTEP_CONFIGURE_THREADS`` is set to a truthy value ("1", "true", "yes", or "on").
The configuration is applied at most once per process; subsequent calls
are no-ops.
Returns:
True if configuration was applied, False if skipped (either because
the environment variable is not set or configuration was already done).
Raises:
RuntimeError: If thread configuration fails (propagated from
``_configure_cpu_threads``).
"""
global _cpu_threads_configured
if _cpu_threads_configured:
return False
env_value = os.environ.get("ACESTEP_CONFIGURE_THREADS", "").strip().lower()
if env_value not in ("1", "true", "yes", "on"):
return False
_configure_cpu_threads()
_cpu_threads_configured = True
return True
# Apply opt-in CPU thread configuration early so torch respects it.
configure_cpu_threads_if_needed()
def _normalize_mode(mode: str) -> str:
return (mode or "").strip().upper()
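The thread-count heuristic and the opt-in check can be exercised without importing torch. The sketch below separates both decisions into pure functions; `pick_thread_count` and `threads_opted_in` are hypothetical names, and the environment is passed in as a mapping so the example does not depend on the real process environment.

```python
from typing import Mapping, Optional

TRUTHY_VALUES = ("1", "true", "yes", "on")


def pick_thread_count(cpu_count: Optional[int]) -> int:
    """All-but-two heuristic: keep two cores free unless there are two or fewer."""
    cpu_cnt = cpu_count or 1  # os.cpu_count() may return None
    threads = cpu_cnt - 2 if cpu_cnt > 2 else cpu_cnt
    return max(threads, 1)


def threads_opted_in(environ: Mapping[str, str]) -> bool:
    """Mirror the ACESTEP_CONFIGURE_THREADS check: case-insensitive, whitespace-tolerant."""
    value = environ.get("ACESTEP_CONFIGURE_THREADS", "").strip().lower()
    return value in TRUTHY_VALUES


counts = {n: pick_thread_count(n) for n in (None, 1, 2, 3, 8)}
enabled = threads_opted_in({"ACESTEP_CONFIGURE_THREADS": " On "})
disabled = threads_opted_in({"ACESTEP_CONFIGURE_THREADS": "0"})
```

Note that a 3-core machine ends up with a single thread under this rule, which is fewer than the 2 threads a 2-core machine receives.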


@ -37,6 +37,22 @@ PYTORCH_CUDA_INSTALL_URL = "https://download.pytorch.org/whl/cu121"
PYTORCH_ROCM_INSTALL_URL = "https://download.pytorch.org/whl/rocm6.0"
def is_mps_platform() -> bool:
"""Check if running on macOS with MPS (Apple Silicon) available.
This is the canonical check used across the codebase to apply
Mac-specific configuration overrides (no compile, no quantization,
mlx backend, no offload, etc.).
"""
if sys.platform != "darwin":
return False
try:
import torch
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
except Exception:
return False
# ===========================================================================
# Empirical VRAM measurements (GB) -- model weights only, bf16 precision
# These values should be calibrated using scripts/profile_vram.py
@ -95,7 +111,7 @@ class GPUConfig:
recommended_lm_model: str # Recommended default LM model path (empty if LM not available)
# LM backend restriction
# "all" = any backend, "pt_mlx_only" = only pt/mlx (no vllm), used for very low VRAM
# "all" = any backend, "pt_mlx_only" = only pt/mlx (no vllm), used for MPS (vllm requires CUDA)
lm_backend_restriction: str # "all" or "pt_mlx_only"
recommended_backend: str # Recommended default backend: "vllm", "pt", or "mlx"
@ -126,8 +142,8 @@ GPU_TIER_CONFIGS = {
"init_lm_default": False,
"available_lm_models": [],
"recommended_lm_model": "",
"lm_backend_restriction": "pt_mlx_only", # vllm KV cache won't fit
"recommended_backend": "pt",
"lm_backend_restriction": "all",
"recommended_backend": "vllm",
"offload_to_cpu_default": True,
"offload_dit_to_cpu_default": True,
"quantization_default": True, # INT8 essential to fit DiT in ~4GB
@ -145,8 +161,8 @@ GPU_TIER_CONFIGS = {
"init_lm_default": False,
"available_lm_models": [],
"recommended_lm_model": "",
"lm_backend_restriction": "pt_mlx_only",
"recommended_backend": "pt",
"lm_backend_restriction": "all",
"recommended_backend": "vllm",
"offload_to_cpu_default": True,
"offload_dit_to_cpu_default": True,
"quantization_default": True,
@ -156,7 +172,7 @@ GPU_TIER_CONFIGS = {
"tier3": { # 6-8GB
# Offload mode. DiT(4.46) + context(0.5) ≈ 5.0GB.
# ~1.5-3GB headroom allows LM 0.6B (1.2+0.6=1.8GB) and batch=2.
# vllm KV cache is tight; pt backend is safer for 0.6B on this tier.
# With CPU offload, DiT is offloaded before LM runs → vllm can use freed VRAM.
"max_duration_with_lm": 480, # 8 minutes
"max_duration_without_lm": 600, # 10 minutes (max supported)
"max_batch_size_with_lm": 2,
@ -164,8 +180,8 @@ GPU_TIER_CONFIGS = {
"init_lm_default": True,
"available_lm_models": ["acestep-5Hz-lm-0.6B"],
"recommended_lm_model": "acestep-5Hz-lm-0.6B",
"lm_backend_restriction": "pt_mlx_only", # vllm KV cache too greedy for <8GB
"recommended_backend": "pt",
"lm_backend_restriction": "all",
"recommended_backend": "vllm",
"offload_to_cpu_default": True,
"offload_dit_to_cpu_default": True,
"quantization_default": True,
@ -511,6 +527,20 @@ def get_gpu_config(gpu_memory_gb: Optional[float] = None) -> GPUConfig:
"""
Get GPU configuration based on detected or provided GPU memory.
On macOS with MPS (Apple Silicon), several overrides are applied
automatically regardless of the tier selected by memory size:
- ``compile_model_default = False``: ``torch.compile`` is not supported
on MPS and would error or silently fall back to eager mode.
- ``quantization_default = False``: torchao INT8 quantization is
incompatible with MPS / macOS.
- ``recommended_backend = "mlx"``: MLX provides native Apple Silicon
acceleration for the 5Hz LM; vllm requires CUDA.
- ``lm_backend_restriction = "pt_mlx_only"``: vllm cannot run on MPS.
- ``offload_to_cpu_default = False``: Apple Silicon uses unified memory;
offloading to CPU provides no benefit and adds overhead.
- ``offload_dit_to_cpu_default = False``: same reason.
Args:
gpu_memory_gb: GPU memory in GB. If None, will be auto-detected.
@ -523,6 +553,15 @@ def get_gpu_config(gpu_memory_gb: Optional[float] = None) -> GPUConfig:
tier = get_gpu_tier(gpu_memory_gb)
config = GPU_TIER_CONFIGS[tier]
# --- MPS (Apple Silicon) overrides ---
_mps = is_mps_platform()
if _mps:
logger.info(
f"macOS MPS detected ({gpu_memory_gb:.1f} GB unified memory, tier={tier}). "
"Applying Apple Silicon optimizations: no compile, no quantization, "
"mlx backend, no CPU offload."
)
return GPUConfig(
tier=tier,
gpu_memory_gb=gpu_memory_gb,
@ -533,12 +572,15 @@ def get_gpu_config(gpu_memory_gb: Optional[float] = None) -> GPUConfig:
init_lm_default=config["init_lm_default"],
available_lm_models=config["available_lm_models"],
recommended_lm_model=config.get("recommended_lm_model", ""),
lm_backend_restriction=config.get("lm_backend_restriction", "all"),
recommended_backend=config.get("recommended_backend", "vllm"),
offload_to_cpu_default=config.get("offload_to_cpu_default", True),
offload_dit_to_cpu_default=config.get("offload_dit_to_cpu_default", True),
quantization_default=config.get("quantization_default", True),
compile_model_default=config.get("compile_model_default", True),
# MPS: vllm requires CUDA, restrict to pt/mlx; prefer mlx for native acceleration
lm_backend_restriction="pt_mlx_only" if _mps else config.get("lm_backend_restriction", "all"),
recommended_backend="mlx" if _mps else config.get("recommended_backend", "vllm"),
# MPS: unified memory — offloading to CPU is pointless overhead
offload_to_cpu_default=False if _mps else config.get("offload_to_cpu_default", True),
offload_dit_to_cpu_default=False if _mps else config.get("offload_dit_to_cpu_default", True),
# MPS: torch.compile and torchao quantization are not supported
quantization_default=False if _mps else config.get("quantization_default", True),
compile_model_default=False if _mps else config.get("compile_model_default", True),
lm_memory_gb=config["lm_memory_gb"],
)
@ -1056,6 +1098,97 @@ def print_gpu_config_info(gpu_config: GPUConfig):
logger.info(f" - Available LM Models: {gpu_config.available_lm_models or 'None'}")
# Human-readable tier labels for UI display
GPU_TIER_LABELS = {
"tier1": "tier1 (≤4GB)",
"tier2": "tier2 (4-6GB)",
"tier3": "tier3 (6-8GB)",
"tier4": "tier4 (8-12GB)",
"tier5": "tier5 (12-16GB)",
"tier6a": "tier6a (16-20GB)",
"tier6b": "tier6b (20-24GB)",
"unlimited": "unlimited (≥24GB)",
}
# Ordered list of tier keys for dropdown
GPU_TIER_CHOICES = list(GPU_TIER_LABELS.items()) # [(value, label), ...]
def get_gpu_device_name() -> str:
"""
Get the GPU device name string.
Returns:
Human-readable GPU name, e.g. "NVIDIA GeForce RTX 4060 Ti",
"Apple M2 Pro (MPS)", "CPU only", etc.
"""
try:
import torch
if torch.cuda.is_available():
return torch.cuda.get_device_name(0)
elif hasattr(torch, 'xpu') and torch.xpu.is_available():
props = torch.xpu.get_device_properties(0)
return getattr(props, 'name', 'Intel XPU')
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
# MPS doesn't expose a device name; use platform info
try:
import platform
chip = platform.processor() or "Apple Silicon"
return f"{chip} (MPS)"
except Exception:
return "Apple Silicon (MPS)"
else:
return "CPU only"
except ImportError:
return "Unknown (PyTorch not available)"
def get_gpu_config_for_tier(tier: str) -> GPUConfig:
"""
Create a GPUConfig for a specific tier, applying platform overrides.
This is used when the user manually selects a different tier in the UI.
The actual gpu_memory_gb is preserved from the real hardware detection,
but all tier-based settings come from the selected tier's config.
Args:
tier: Tier key, e.g. "tier3", "tier6a", "unlimited"
Returns:
GPUConfig with the selected tier's settings
"""
if tier not in GPU_TIER_CONFIGS:
logger.warning(f"Unknown tier '{tier}', falling back to auto-detected config")
return get_gpu_config()
# Keep the real GPU memory for informational purposes
real_gpu_memory = get_gpu_memory_gb()
config = GPU_TIER_CONFIGS[tier]
_mps = is_mps_platform()
if _mps:
logger.info(f"Manual tier override to {tier} on macOS MPS — applying Apple Silicon overrides")
return GPUConfig(
tier=tier,
gpu_memory_gb=real_gpu_memory,
max_duration_with_lm=config["max_duration_with_lm"],
max_duration_without_lm=config["max_duration_without_lm"],
max_batch_size_with_lm=config["max_batch_size_with_lm"],
max_batch_size_without_lm=config["max_batch_size_without_lm"],
init_lm_default=config["init_lm_default"],
available_lm_models=config["available_lm_models"],
recommended_lm_model=config.get("recommended_lm_model", ""),
lm_backend_restriction="pt_mlx_only" if _mps else config.get("lm_backend_restriction", "all"),
recommended_backend="mlx" if _mps else config.get("recommended_backend", "vllm"),
offload_to_cpu_default=False if _mps else config.get("offload_to_cpu_default", True),
offload_dit_to_cpu_default=False if _mps else config.get("offload_dit_to_cpu_default", True),
quantization_default=False if _mps else config.get("quantization_default", True),
compile_model_default=False if _mps else config.get("compile_model_default", True),
lm_memory_gb=config["lm_memory_gb"],
)
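Both `get_gpu_config` and `get_gpu_config_for_tier` apply the same six MPS overrides field by field. One way to picture the rule, and a possible way to keep the two call sites in sync, is a single override table merged over the tier values. This is only a sketch: `TIER_DEFAULTS`, `MPS_OVERRIDES`, and `resolve_platform_settings` are hypothetical names, and the sample tier dict is abbreviated.

```python
from typing import Any, Dict

# Fallbacks matching the config.get(...) defaults used in the constructors.
TIER_DEFAULTS: Dict[str, Any] = {
    "lm_backend_restriction": "all",
    "recommended_backend": "vllm",
    "offload_to_cpu_default": True,
    "offload_dit_to_cpu_default": True,
    "quantization_default": True,
    "compile_model_default": True,
}

# Values forced on Apple Silicon regardless of the selected tier.
MPS_OVERRIDES: Dict[str, Any] = {
    "lm_backend_restriction": "pt_mlx_only",
    "recommended_backend": "mlx",
    "offload_to_cpu_default": False,
    "offload_dit_to_cpu_default": False,
    "quantization_default": False,
    "compile_model_default": False,
}


def resolve_platform_settings(tier_config: Dict[str, Any], on_mps: bool) -> Dict[str, Any]:
    """Read the platform-sensitive fields from a tier, then apply MPS overrides."""
    settings = {key: tier_config.get(key, default) for key, default in TIER_DEFAULTS.items()}
    if on_mps:
        settings.update(MPS_OVERRIDES)
    return settings


# Abbreviated, made-up tier entry; missing keys fall back to TIER_DEFAULTS.
tier_like = {"lm_backend_restriction": "all", "recommended_backend": "vllm"}
cuda_settings = resolve_platform_settings(tier_like, on_mps=False)
mps_settings = resolve_platform_settings(tier_like, on_mps=True)
```

The resulting dict could be unpacked into the `GPUConfig` constructor alongside the tier-only fields, assuming the field names stay as shown above.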
# Global GPU config instance (initialized lazily)
_global_gpu_config: Optional[GPUConfig] = None


@ -10,6 +10,7 @@ from typing import Any, Dict, List, Optional
from uuid import uuid4
from fastapi import APIRouter, HTTPException, Request, Depends, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
# Global results directory inside project root
@ -467,10 +468,30 @@ async def release_task(request: Request, authorization: Optional[str] = Header(N
lm_negative_prompt=get_param("lm_negative_prompt", default="NO USER INPUT") or "NO USER INPUT",
)
# Resolve seed(s) into List[int] for GenerationConfig.seeds
use_random_seed = get_param("use_random_seed", default=True)
resolved_seeds = None
if not use_random_seed:
raw_seed = get_param("seed", default=-1)
if isinstance(raw_seed, str) and raw_seed.strip():
resolved_seeds = []
for s in raw_seed.split(","):
s = s.strip()
if s and s != "-1":
try:
resolved_seeds.append(int(float(s)))
except (ValueError, TypeError):
pass
if not resolved_seeds:
resolved_seeds = None
elif isinstance(raw_seed, (int, float)) and int(raw_seed) >= 0:
resolved_seeds = [int(raw_seed)]
config = GenerationConfig(
batch_size=get_param("batch_size", default=2),
use_random_seed=get_param("use_random_seed", default=True),
audio_format=get_param("audio_format", default="mp3"),
use_random_seed=use_random_seed,
seeds=resolved_seeds,
audio_format=get_param("audio_format", default="flac"),
)
# Get output directory
@ -512,6 +533,38 @@ async def release_task(request: Request, authorization: Optional[str] = Header(N
raise HTTPException(status_code=500, detail=str(e))
# Origins that are expected to call the API:
# - "null" → studio.html opened via file:// protocol
# - http(s)://localhost:* → local dev servers / Gradio UI
# - http(s)://127.0.0.1:* → same, numeric form
_CORS_KWARGS = dict(
allow_origins=["null", "http://localhost", "http://127.0.0.1"],
allow_origin_regex=r"^https?://(localhost|127\.0\.0\.1)(:\d+)?$",
allow_methods=["GET", "POST", "OPTIONS"],
allow_headers=["Content-Type", "Authorization"],
)
def _add_cors_middleware(app):
"""Add CORS middleware so browser-based frontends (e.g. studio.html via file://) can call the API."""
app.add_middleware(CORSMiddleware, **_CORS_KWARGS)
def _add_cors_middleware_post_launch(app):
"""Wrap an already-started app's middleware stack with CORS.
``add_middleware`` raises after Starlette has started, so we patch the
compiled middleware stack directly instead.
"""
from starlette.middleware.cors import CORSMiddleware as _CORSImpl
if app.middleware_stack is not None:
app.middleware_stack = _CORSImpl(app=app.middleware_stack, **_CORS_KWARGS)
else:
# App hasn't built its stack yet, so the normal path is safe to use
_add_cors_middleware(app)
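The origin rules in `_CORS_KWARGS` can be sanity-checked with the standard library alone. The sketch below approximates the decision as "exact match in `allow_origins`, or full match of `allow_origin_regex`"; that combination is an assumption about how the middleware evaluates origins, and `is_allowed_origin` is a hypothetical helper, not part of Starlette or FastAPI.

```python
import re

ALLOW_ORIGINS = ["null", "http://localhost", "http://127.0.0.1"]
ALLOW_ORIGIN_REGEX = r"^https?://(localhost|127\.0\.0\.1)(:\d+)?$"
_PATTERN = re.compile(ALLOW_ORIGIN_REGEX)


def is_allowed_origin(origin: str) -> bool:
    """Approximate the CORS decision: exact list match or regex full match."""
    return origin in ALLOW_ORIGINS or _PATTERN.fullmatch(origin) is not None


checks = {
    "null": is_allowed_origin("null"),  # studio.html opened via file://
    "http://localhost:7860": is_allowed_origin("http://localhost:7860"),
    "https://127.0.0.1": is_allowed_origin("https://127.0.0.1"),
    "http://localhost.evil.com": is_allowed_origin("http://localhost.evil.com"),
    "http://example.com": is_allowed_origin("http://example.com"),
}
```

The anchored pattern rejects look-alike hosts such as `localhost.evil.com`, because nothing except an optional port may follow the host name.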
def setup_api_routes_to_app(app, dit_handler, llm_handler, api_key: Optional[str] = None):
"""
Mount API routes to a FastAPI application (for use with gr.mount_gradio_app)
@ -523,6 +576,7 @@ def setup_api_routes_to_app(app, dit_handler, llm_handler, api_key: Optional[str
api_key: Optional API key for authentication
"""
set_api_key(api_key)
_add_cors_middleware(app)
app.state.dit_handler = dit_handler
app.state.llm_handler = llm_handler
app.include_router(router)
@ -540,6 +594,7 @@ def setup_api_routes(demo, dit_handler, llm_handler, api_key: Optional[str] = No
"""
set_api_key(api_key)
app = demo.app
_add_cors_middleware_post_launch(app)
app.state.dit_handler = dit_handler
app.state.llm_handler = llm_handler
app.include_router(router)


@ -31,7 +31,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
generation_section["config_path"].change(
fn=gen_h.update_model_type_settings,
inputs=[generation_section["config_path"]],
inputs=[generation_section["config_path"], generation_section["generation_mode"]],
outputs=[
generation_section["inference_steps"],
generation_section["guidance_scale"],
@ -40,6 +40,25 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
generation_section["cfg_interval_start"],
generation_section["cfg_interval_end"],
generation_section["task_type"],
generation_section["generation_mode"],
]
)
# ========== Tier Override ==========
generation_section["tier_dropdown"].change(
fn=lambda tier: gen_h.on_tier_change(tier, llm_handler),
inputs=[generation_section["tier_dropdown"]],
outputs=[
generation_section["offload_to_cpu_checkbox"],
generation_section["offload_dit_to_cpu_checkbox"],
generation_section["compile_model_checkbox"],
generation_section["quantization_checkbox"],
generation_section["backend_dropdown"],
generation_section["lm_model_path"],
generation_section["init_llm_checkbox"],
generation_section["batch_size_input"],
generation_section["audio_duration"],
generation_section["gpu_info_display"],
]
)
@ -57,6 +76,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
generation_section["offload_dit_to_cpu_checkbox"],
generation_section["compile_model_checkbox"],
generation_section["quantization_checkbox"],
generation_section["mlx_dit_checkbox"],
generation_section["generation_mode"], # preserve current mode across init
],
outputs=[
generation_section["init_status"],
@ -70,9 +91,12 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
generation_section["cfg_interval_start"],
generation_section["cfg_interval_end"],
generation_section["task_type"],
generation_section["generation_mode"],
# GPU-config-aware limits (updated after initialization)
generation_section["audio_duration"],
generation_section["batch_size_input"],
# Think checkbox: enable if LLM initialized
generation_section["think_checkbox"],
]
)
@ -115,24 +139,6 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
outputs=[generation_section["lm_negative_prompt"]]
)
generation_section["init_llm_checkbox"].change(
fn=gen_h.update_audio_cover_strength_visibility,
inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"], generation_section["reference_audio"]],
outputs=[generation_section["audio_cover_strength"]]
)
generation_section["task_type"].change(
fn=gen_h.update_audio_cover_strength_visibility,
inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"], generation_section["reference_audio"]],
outputs=[generation_section["audio_cover_strength"]]
)
generation_section["reference_audio"].change(
fn=gen_h.update_audio_cover_strength_visibility,
inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"], generation_section["reference_audio"]],
outputs=[generation_section["audio_cover_strength"]]
)
generation_section["batch_size_input"].change(
fn=gen_h.update_audio_components_visibility,
inputs=[generation_section["batch_size_input"]],
@ -149,13 +155,34 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
]
)
# ========== Audio Conversion ==========
# ========== Audio Conversion (LM Codes Hints accordion in Custom mode) ==========
generation_section["convert_src_to_codes_btn"].click(
fn=lambda src: gen_h.convert_src_audio_to_codes_wrapper(dit_handler, src),
inputs=[generation_section["src_audio"]],
inputs=[generation_section["lm_codes_audio_upload"]],
outputs=[generation_section["text2music_audio_code_string"]]
)
# ========== Analyze Source Audio (Remix/Repaint: convert to codes + transcribe) ==========
generation_section["analyze_btn"].click(
fn=lambda src, debug: gen_h.analyze_src_audio(dit_handler, llm_handler, src, debug),
inputs=[
generation_section["src_audio"],
generation_section["constrained_decoding_debug"],
],
outputs=[
generation_section["text2music_audio_code_string"],
results_section["status_output"],
generation_section["captions"],
generation_section["lyrics"],
generation_section["bpm"],
generation_section["audio_duration"],
generation_section["key_scale"],
generation_section["vocal_language"],
generation_section["time_signature"],
results_section["is_format_caption_state"],
]
)
# ========== Instruction UI Updates ==========
for trigger in [generation_section["task_type"], generation_section["track_name"], generation_section["complete_track_classes"], generation_section["reference_audio"]]:
trigger.change(
@ -164,7 +191,6 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
generation_section["task_type"],
generation_section["track_name"],
generation_section["complete_track_classes"],
generation_section["text2music_audio_code_string"],
generation_section["init_llm_checkbox"],
generation_section["reference_audio"],
],
@ -172,9 +198,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
generation_section["instruction_display_gen"],
generation_section["track_name"],
generation_section["complete_track_classes"],
generation_section["audio_cover_strength"],
generation_section["repainting_group"],
generation_section["text2music_audio_codes_group"],
]
)
@ -233,25 +257,23 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
outputs=[results_section["is_format_caption_state"]]
)
# ========== Audio Uploads Accordion ==========
for trigger in [generation_section["reference_audio"], generation_section["src_audio"]]:
trigger.change(
fn=gen_h.update_audio_uploads_accordion,
inputs=[generation_section["reference_audio"], generation_section["src_audio"]],
outputs=[generation_section["audio_uploads_accordion"]]
)
# ========== Instrumental Checkbox ==========
generation_section["instrumental_checkbox"].change(
fn=gen_h.handle_instrumental_checkbox,
inputs=[generation_section["instrumental_checkbox"], generation_section["lyrics"]],
outputs=[generation_section["lyrics"]]
inputs=[
generation_section["instrumental_checkbox"],
generation_section["lyrics"],
generation_section["lyrics_before_instrumental"],
],
outputs=[
generation_section["lyrics"],
generation_section["lyrics_before_instrumental"],
]
)
# ========== Format Button ==========
# Note: cfg_scale and negative_prompt are not supported in format mode
generation_section["format_btn"].click(
fn=lambda caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug: gen_h.handle_format_sample(
# ========== Format Caption Button ==========
generation_section["format_caption_btn"].click(
fn=lambda caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug: gen_h.handle_format_caption(
llm_handler, caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug
),
inputs=[
@ -268,6 +290,34 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
],
outputs=[
generation_section["captions"],
generation_section["bpm"],
generation_section["audio_duration"],
generation_section["key_scale"],
generation_section["vocal_language"],
generation_section["time_signature"],
results_section["is_format_caption_state"],
results_section["status_output"],
]
)
# ========== Format Lyrics Button ==========
generation_section["format_lyrics_btn"].click(
fn=lambda caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug: gen_h.handle_format_lyrics(
llm_handler, caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug
),
inputs=[
generation_section["captions"],
generation_section["lyrics"],
generation_section["bpm"],
generation_section["audio_duration"],
generation_section["key_scale"],
generation_section["time_signature"],
generation_section["lm_temperature"],
generation_section["lm_top_k"],
generation_section["lm_top_p"],
generation_section["constrained_decoding_debug"],
],
outputs=[
generation_section["lyrics"],
generation_section["bpm"],
generation_section["audio_duration"],
@ -279,17 +329,30 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
]
)
# ========== Simple/Custom Mode Toggle ==========
# ========== Generation Mode Change ==========
generation_section["generation_mode"].change(
fn=gen_h.handle_generation_mode_change,
fn=lambda mode: gen_h.handle_generation_mode_change(mode, llm_handler),
inputs=[generation_section["generation_mode"]],
outputs=[
generation_section["simple_mode_group"],
generation_section["caption_accordion"],
generation_section["lyrics_accordion"],
generation_section["custom_mode_group"],
generation_section["generate_btn"],
generation_section["simple_sample_created"],
generation_section["optional_params_accordion"],
generation_section["task_type"],
generation_section["src_audio_row"],
generation_section["repainting_group"],
generation_section["text2music_audio_codes_group"],
generation_section["track_name"],
generation_section["complete_track_classes"],
generation_section["generate_btn_row"],
generation_section["generation_mode"],
generation_section["results_wrapper"],
generation_section["think_checkbox"],
generation_section["load_file_col"],
generation_section["load_file"],
generation_section["audio_cover_strength"],
generation_section["cover_noise_strength"],
]
)
@ -337,13 +400,12 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
generation_section["simple_vocal_language"],
generation_section["time_signature"],
generation_section["instrumental_checkbox"],
generation_section["caption_accordion"],
generation_section["lyrics_accordion"],
generation_section["generate_btn"],
generation_section["simple_sample_created"],
generation_section["think_checkbox"],
results_section["is_format_caption_state"],
results_section["status_output"],
generation_section["generation_mode"],
]
)
@ -381,6 +443,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
generation_section["use_cot_caption"],
generation_section["use_cot_language"],
generation_section["audio_cover_strength"],
generation_section["cover_noise_strength"],
generation_section["think_checkbox"],
generation_section["text2music_audio_code_string"],
generation_section["repainting_start"],
@ -470,26 +533,62 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
],
js=download_existing_js # Run the above JS
)
# ========== Send to SRC Handlers ==========
# ========== Send to Remix / Repaint Handlers ==========
# Mode-UI outputs shared with generation_mode.change — applied atomically
# so we don't rely on a chained .change() event for visibility/label updates.
_mode_ui_outputs = [
generation_section["simple_mode_group"],
generation_section["custom_mode_group"],
generation_section["generate_btn"],
generation_section["simple_sample_created"],
generation_section["optional_params_accordion"],
generation_section["task_type"],
generation_section["src_audio_row"],
generation_section["repainting_group"],
generation_section["text2music_audio_codes_group"],
generation_section["track_name"],
generation_section["complete_track_classes"],
generation_section["generate_btn_row"],
generation_section["generation_mode"],
generation_section["results_wrapper"],
generation_section["think_checkbox"],
generation_section["load_file_col"],
generation_section["load_file"],
generation_section["audio_cover_strength"],
generation_section["cover_noise_strength"],
]
for btn_idx in range(1, 9):
results_section[f"send_to_src_btn_{btn_idx}"].click(
fn=res_h.send_audio_to_src_with_metadata,
results_section[f"send_to_remix_btn_{btn_idx}"].click(
fn=lambda audio, lm, ly, cap: res_h.send_audio_to_remix(
audio, lm, ly, cap, llm_handler),
inputs=[
results_section[f"generated_audio_{btn_idx}"],
results_section["lm_metadata_state"]
results_section["lm_metadata_state"],
generation_section["lyrics"],
generation_section["captions"],
],
outputs=[
generation_section["src_audio"],
generation_section["bpm"],
generation_section["captions"],
generation_section["generation_mode"],
generation_section["lyrics"],
generation_section["audio_duration"],
generation_section["key_scale"],
generation_section["vocal_language"],
generation_section["time_signature"],
results_section["is_format_caption_state"],
generation_section["audio_uploads_accordion"]
]
generation_section["captions"],
] + _mode_ui_outputs,
)
results_section[f"send_to_repaint_btn_{btn_idx}"].click(
fn=lambda audio, lm, ly, cap: res_h.send_audio_to_repaint(
audio, lm, ly, cap, llm_handler),
inputs=[
results_section[f"generated_audio_{btn_idx}"],
results_section["lm_metadata_state"],
generation_section["lyrics"],
generation_section["captions"],
],
outputs=[
generation_section["src_audio"],
generation_section["generation_mode"],
generation_section["lyrics"],
generation_section["captions"],
] + _mode_ui_outputs,
)
# ========== Score Calculation Handlers ==========
@ -539,6 +638,25 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
]
)
# ========== Convert To Codes Handlers ==========
for btn_idx in range(1, 9):
results_section[f"convert_to_codes_btn_{btn_idx}"].click(
fn=lambda audio: res_h.convert_result_audio_to_codes(dit_handler, audio),
inputs=[results_section[f"generated_audio_{btn_idx}"]],
outputs=[
results_section[f"codes_display_{btn_idx}"],
results_section[f"details_accordion_{btn_idx}"],
]
)
# ========== Save LRC Handlers ==========
for btn_idx in range(1, 9):
results_section[f"save_lrc_btn_{btn_idx}"].click(
fn=res_h.save_lrc_to_file,
inputs=[results_section[f"lrc_display_{btn_idx}"]],
outputs=[results_section[f"lrc_download_file_{btn_idx}"]]
)
def generation_wrapper(*args):
yield from res_h.generate_with_batch_management(dit_handler, llm_handler, *args)
# ========== Generation Handler ==========
@ -577,6 +695,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
generation_section["repainting_end"],
generation_section["instruction_display_gen"],
generation_section["audio_cover_strength"],
generation_section["cover_noise_strength"],
generation_section["task_type"],
generation_section["use_adg"],
generation_section["cfg_interval_start"],
@ -603,6 +722,10 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
generation_section["lm_batch_chunk_size"],
generation_section["track_name"],
generation_section["complete_track_classes"],
generation_section["enable_normalization"],
generation_section["normalization_db"],
generation_section["latent_shift"],
generation_section["latent_rescale"],
generation_section["autogen_checkbox"],
results_section["current_batch_index"],
results_section["total_batches"],
@ -765,6 +888,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
generation_section["repainting_end"],
generation_section["instruction_display_gen"],
generation_section["audio_cover_strength"],
generation_section["cover_noise_strength"],
generation_section["task_type"],
generation_section["use_adg"],
generation_section["cfg_interval_start"],
@ -790,6 +914,10 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
generation_section["lm_batch_chunk_size"],
generation_section["track_name"],
generation_section["complete_track_classes"],
generation_section["enable_normalization"],
generation_section["normalization_db"],
generation_section["latent_shift"],
generation_section["latent_rescale"],
],
outputs=[results_section["generation_params_state"]]
).then(
@ -897,6 +1025,10 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
generation_section["allow_lm_batch"],
generation_section["track_name"],
generation_section["complete_track_classes"],
generation_section["enable_normalization"],
generation_section["normalization_db"],
generation_section["latent_shift"],
generation_section["latent_rescale"],
]
)
@ -1178,11 +1310,12 @@ def setup_training_event_handlers(demo, dit_handler, llm_handler, training_secti
# Preprocess dataset to tensor files
training_section["preprocess_btn"].click(
fn=lambda output_dir, state: train_h.preprocess_dataset(
output_dir, dit_handler, state
fn=lambda output_dir, mode, state: train_h.preprocess_dataset(
output_dir, mode, dit_handler, state
),
inputs=[
training_section["preprocess_output_dir"],
training_section["preprocess_mode"],
training_section["dataset_builder_state"],
],
outputs=[training_section["preprocess_progress"]]
@ -1256,3 +1389,92 @@ def setup_training_event_handlers(demo, dit_handler, llm_handler, training_secti
],
outputs=[training_section["export_status"]]
)
# ========== LoKr Training Tab Handlers ==========
# Load preprocessed tensor dataset for LoKr
training_section["lokr_load_dataset_btn"].click(
fn=train_h.load_training_dataset,
inputs=[training_section["lokr_training_tensor_dir"]],
outputs=[training_section["lokr_training_dataset_info"]]
)
# Start LoKr training from preprocessed tensors
def lokr_training_wrapper(
tensor_dir, ldim, lalpha, factor, decompose_both, use_tucker,
use_scalar, weight_decompose, lr, ep, bs, ga, se, sh, sd, od, ts,
):
from loguru import logger
if not isinstance(ts, dict):
ts = {"is_training": False, "should_stop": False}
try:
for progress, log_msg, plot, state in train_h.start_lokr_training(
tensor_dir, dit_handler,
ldim, lalpha, factor, decompose_both, use_tucker,
use_scalar, weight_decompose,
lr, ep, bs, ga, se, sh, sd, od, ts,
):
yield progress, log_msg, plot, state
except Exception as e:
logger.exception("LoKr training wrapper error")
yield f"❌ Error: {str(e)}", str(e), None, ts
training_section["start_lokr_training_btn"].click(
fn=lokr_training_wrapper,
inputs=[
training_section["lokr_training_tensor_dir"],
training_section["lokr_linear_dim"],
training_section["lokr_linear_alpha"],
training_section["lokr_factor"],
training_section["lokr_decompose_both"],
training_section["lokr_use_tucker"],
training_section["lokr_use_scalar"],
training_section["lokr_weight_decompose"],
training_section["lokr_learning_rate"],
training_section["lokr_train_epochs"],
training_section["lokr_train_batch_size"],
training_section["lokr_gradient_accumulation"],
training_section["lokr_save_every_n_epochs"],
training_section["lokr_training_shift"],
training_section["lokr_training_seed"],
training_section["lokr_output_dir"],
training_section["training_state"],
],
outputs=[
training_section["lokr_training_progress"],
training_section["lokr_training_log"],
training_section["lokr_training_loss_plot"],
training_section["training_state"],
]
)
# Stop LoKr training (reuses same stop mechanism)
training_section["stop_lokr_training_btn"].click(
fn=train_h.stop_training,
inputs=[training_section["training_state"]],
outputs=[
training_section["lokr_training_progress"],
training_section["training_state"],
]
)
# Refresh LoKr export epochs
training_section["refresh_lokr_export_epochs_btn"].click(
fn=train_h.list_lokr_export_epochs,
inputs=[training_section["lokr_output_dir"]],
outputs=[
training_section["lokr_export_epoch"],
training_section["lokr_export_status"],
]
)
# Export LoKr
training_section["export_lokr_btn"].click(
fn=train_h.export_lokr,
inputs=[
training_section["lokr_export_path"],
training_section["lokr_output_dir"],
training_section["lokr_export_epoch"],
],
outputs=[training_section["lokr_export_status"]]
)
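Review note on `lokr_training_wrapper` above: it forwards every tuple the training generator yields and, on failure, emits one final error tuple instead of letting the exception escape the event handler. A minimal Gradio-free sketch of that pattern follows. `fake_training` and `guarded_stream` are hypothetical names used only for illustration; they are not part of the repository.

```python
from typing import Any, Dict, Iterator, Optional, Tuple

Update = Tuple[str, str, Optional[Any], Dict[str, bool]]


def fake_training(fail_at: Optional[int] = None) -> Iterator[Update]:
    """Hypothetical stand-in for a training generator: yields up to 3 progress tuples."""
    state = {"is_training": True, "should_stop": False}
    for step in range(3):
        if fail_at == step:
            raise RuntimeError(f"boom at step {step}")
        yield f"step {step}", f"log {step}", None, state


def guarded_stream(source: Iterator[Update], state: Any) -> Iterator[Update]:
    """Forward updates; on failure emit one final error tuple instead of raising."""
    if not isinstance(state, dict):
        # Same defensive default as the wrapper in the diff.
        state = {"is_training": False, "should_stop": False}
    try:
        for update in source:
            yield update
    except Exception as exc:
        # Keep the tuple shape identical so the UI outputs still line up.
        yield f"Error: {exc}", str(exc), None, state


updates = list(guarded_stream(fake_training(fail_at=1), state=None))
```

The error tuple has the same four positions as a normal update, which is what lets a single `outputs=[...]` list serve both cases.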

File diff suppressed because it is too large


@ -0,0 +1,88 @@
"""Unit tests for generation input event handlers."""
import unittest
from types import SimpleNamespace
from unittest.mock import patch
try:
from acestep.gradio_ui.events import generation_handlers
_IMPORT_ERROR = None
except Exception as exc: # pragma: no cover - environment dependency guard
generation_handlers = None
_IMPORT_ERROR = exc
class _FakeDitHandler:
"""Minimal DiT handler stub for analyze-src-audio tests."""
def __init__(self, convert_result):
self._convert_result = convert_result
def convert_src_audio_to_codes(self, _src_audio):
"""Return configured conversion output."""
return self._convert_result
@unittest.skipIf(generation_handlers is None, f"generation_handlers import unavailable: {_IMPORT_ERROR}")
class GenerationHandlersTests(unittest.TestCase):
"""Tests for source-audio analysis validation behavior."""
@patch("acestep.gradio_ui.events.generation_handlers.gr.Warning")
@patch("acestep.gradio_ui.events.generation_handlers.understand_music")
def test_analyze_src_audio_rejects_non_audio_code_output(
self,
understand_music_mock,
warning_mock,
):
"""Reject conversion output that has no serialized audio-code tokens."""
dit_handler = _FakeDitHandler("ERROR: not an audio file")
llm_handler = SimpleNamespace(llm_initialized=True)
result = generation_handlers.analyze_src_audio(
dit_handler=dit_handler,
llm_handler=llm_handler,
src_audio="fake.mp3",
constrained_decoding_debug=False,
)
self.assertEqual(result, ("", "", "", "", None, None, "", "", "", False))
understand_music_mock.assert_not_called()
warning_mock.assert_called_once()
@patch("acestep.gradio_ui.events.generation_handlers.gr.Warning")
@patch("acestep.gradio_ui.events.generation_handlers.understand_music")
def test_analyze_src_audio_allows_valid_audio_code_output(
self,
understand_music_mock,
warning_mock,
):
"""Pass valid audio codes through to LM understanding."""
dit_handler = _FakeDitHandler("<|audio_code_123|><|audio_code_456|>")
llm_handler = SimpleNamespace(llm_initialized=True)
understand_music_mock.return_value = SimpleNamespace(
success=True,
status_message="ok",
caption="caption",
lyrics="lyrics",
bpm=120,
duration=30.0,
keyscale="C major",
language="en",
timesignature="4",
)
result = generation_handlers.analyze_src_audio(
dit_handler=dit_handler,
llm_handler=llm_handler,
src_audio="real.mp3",
constrained_decoding_debug=False,
)
self.assertEqual(result[0], "<|audio_code_123|><|audio_code_456|>")
self.assertEqual(result[1], "ok")
understand_music_mock.assert_called_once()
warning_mock.assert_not_called()
if __name__ == "__main__":
unittest.main()
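Review note: the two tests above imply that `analyze_src_audio` only calls the LM when the conversion output contains serialized audio-code tokens. The actual check inside `generation_handlers` is not visible in this diff, so the following is only a sketch of one way such a guard could look. The helper name `looks_like_audio_codes` and the regular expression are assumptions derived from the fixture string `<|audio_code_123|>`.

```python
import re

# Assumed token shape, taken from the test fixture "<|audio_code_123|>".
_AUDIO_CODE_RE = re.compile(r"<\|audio_code_\d+\|>")


def looks_like_audio_codes(text: object) -> bool:
    """Hypothetical helper: True if the value is a string with at least one audio-code token."""
    return isinstance(text, str) and _AUDIO_CODE_RE.search(text) is not None
```

With a guard like this, an error string such as `"ERROR: not an audio file"` would be rejected before any LM call, which matches what the first test asserts.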


@ -5,14 +5,21 @@ Contains event handlers and helper functions related to result display, scoring,
import os
import json
import datetime
import math
import re
import tempfile
import shutil
import zipfile
import time as time_module
import sys
from typing import Dict, Any, Optional, List
import gradio as gr
from loguru import logger
from acestep.gradio_ui.i18n import t
from acestep.gradio_ui.events.generation_handlers import parse_and_validate_timesteps
from acestep.gradio_ui.events.generation_handlers import (
parse_and_validate_timesteps,
compute_mode_ui_updates,
)
from acestep.inference import generate_music, GenerationParams, GenerationConfig
from acestep.audio_utils import save_audio
from acestep.gpu_config import (
@ -21,7 +28,7 @@ from acestep.gpu_config import (
check_batch_size_limit,
)
# Platform detection for Windows-specific fixess
# Platform detection for Windows-specific fixes
IS_WINDOWS = sys.platform == "win32"
# Global results directory inside project root
@ -272,122 +279,59 @@ def _build_generation_info(
seed_value: str,
inference_steps: int,
num_audios: int,
audio_format: str = "flac",
) -> str:
"""Build generation info string from result data.
"""Build a compact generation timing summary.
Args:
lm_metadata: LM-generated metadata dictionary
lm_metadata: LM-generated metadata dictionary (unused, kept for API compat)
time_costs: Unified time costs dictionary
seed_value: Seed value string
inference_steps: Number of inference steps
seed_value: Seed value string (unused, kept for API compat)
inference_steps: Number of inference steps (unused, kept for API compat)
num_audios: Number of generated audios
audio_format: Output audio format name (e.g. "flac", "mp3", "wav32")
Returns:
Formatted generation info string
"""
if not time_costs or num_audios <= 0:
return ""
songs_label = f"({num_audios} song{'s' if num_audios > 1 else ''})"
info_parts = []
# Part 1: Per-track average time (prominently displayed at the top)
# Only count model time (LM + DiT), not post-processing like audio conversion
if time_costs and num_audios > 0:
lm_total = time_costs.get('lm_total_time', 0.0)
dit_total = time_costs.get('dit_total_time_cost', 0.0)
model_total = lm_total + dit_total
if model_total > 0:
avg_time_per_track = model_total / num_audios
avg_section = f"**🎯 Average Time per Track: {avg_time_per_track:.2f}s** ({num_audios} track(s))"
info_parts.append(avg_section)
# Part 2: LM-generated metadata (if available)
if lm_metadata:
metadata_lines = []
if lm_metadata.get('bpm'):
metadata_lines.append(f"- **BPM:** {lm_metadata['bpm']}")
if lm_metadata.get('caption'):
metadata_lines.append(f"- **Refined Caption:** {lm_metadata['caption']}")
if lm_metadata.get('lyrics'):
metadata_lines.append(f"- **Refined Lyrics:** {lm_metadata['lyrics']}")
if lm_metadata.get('duration'):
metadata_lines.append(f"- **Duration:** {lm_metadata['duration']} seconds")
if lm_metadata.get('keyscale'):
metadata_lines.append(f"- **Key Scale:** {lm_metadata['keyscale']}")
if lm_metadata.get('language'):
metadata_lines.append(f"- **Language:** {lm_metadata['language']}")
if lm_metadata.get('timesignature'):
metadata_lines.append(f"- **Time Signature:** {lm_metadata['timesignature']}")
if metadata_lines:
metadata_section = "**🤖 LM-Generated Metadata:**\n" + "\n".join(metadata_lines)
info_parts.append(metadata_section)
# Part 3: Time costs breakdown (formatted and beautified)
if time_costs:
time_lines = []
# LM time costs
lm_phase1 = time_costs.get('lm_phase1_time', 0.0)
lm_phase2 = time_costs.get('lm_phase2_time', 0.0)
lm_total = time_costs.get('lm_total_time', 0.0)
# --- Block 1: Generation time (LM + DiT) ---
lm_total = time_costs.get('lm_total_time', 0.0)
dit_total = time_costs.get('dit_total_time_cost', 0.0)
gen_total = lm_total + dit_total
if gen_total > 0:
avg = gen_total / num_audios
lines = [f"**🎵 Total generation time {songs_label}: {gen_total:.2f}s**"]
lines.append(f"- {avg:.2f}s per song")
if lm_total > 0:
time_lines.append("**🧠 LM Time:**")
if lm_phase1 > 0:
time_lines.append(f" - Phase 1 (CoT): {lm_phase1:.2f}s")
if lm_phase2 > 0:
time_lines.append(f" - Phase 2 (Codes): {lm_phase2:.2f}s")
time_lines.append(f" - Total: {lm_total:.2f}s")
# DiT time costs
dit_encoder = time_costs.get('dit_encoder_time_cost', 0.0)
dit_model = time_costs.get('dit_model_time_cost', 0.0)
dit_vae_decode = time_costs.get('dit_vae_decode_time_cost', 0.0)
dit_offload = time_costs.get('dit_offload_time_cost', 0.0)
dit_total = time_costs.get('dit_total_time_cost', 0.0)
lines.append(f"- LM phase {songs_label}: {lm_total:.2f}s")
if dit_total > 0:
time_lines.append("\n**🎵 DiT Time:**")
if dit_encoder > 0:
time_lines.append(f" - Encoder: {dit_encoder:.2f}s")
if dit_model > 0:
time_lines.append(f" - Model: {dit_model:.2f}s")
if dit_vae_decode > 0:
time_lines.append(f" - VAE Decode: {dit_vae_decode:.2f}s")
if dit_offload > 0:
time_lines.append(f" - Offload: {dit_offload:.2f}s")
time_lines.append(f" - Total: {dit_total:.2f}s")
# Post-processing time costs
audio_conversion_time = time_costs.get('audio_conversion_time', 0.0)
auto_score_time = time_costs.get('auto_score_time', 0.0)
auto_lrc_time = time_costs.get('auto_lrc_time', 0.0)
if audio_conversion_time > 0 or auto_score_time > 0 or auto_lrc_time > 0:
time_lines.append("\n**🔧 Post-processing Time:**")
if audio_conversion_time > 0:
time_lines.append(f" - Audio Conversion: {audio_conversion_time:.2f}s")
if auto_score_time > 0:
time_lines.append(f" - Auto Score: {auto_score_time:.2f}s")
if auto_lrc_time > 0:
time_lines.append(f" - Auto LRC: {auto_lrc_time:.2f}s")
if time_lines:
time_section = "\n".join(time_lines)
info_parts.append(time_section)
# Part 4: Generation summary
summary_lines = [
"**🎵 Generation Complete**",
f" - **Seeds:** {seed_value}",
f" - **Steps:** {inference_steps}",
f" - **Audio Count:** {num_audios} audio(s)",
]
info_parts.append("\n".join(summary_lines))
# Part 5: Pipeline total time (at the end)
pipeline_total = time_costs.get('pipeline_total_time', 0.0) if time_costs else 0.0
if pipeline_total > 0:
info_parts.append(f"**⏱️ Total Time: {pipeline_total:.2f}s**")
# Combine all parts
lines.append(f"- DiT phase {songs_label}: {dit_total:.2f}s")
info_parts.append("\n".join(lines))
# --- Block 2: Processing time (conversion + scoring + LRC) ---
audio_conversion_time = time_costs.get('audio_conversion_time', 0.0)
auto_score_time = time_costs.get('auto_score_time', 0.0)
auto_lrc_time = time_costs.get('auto_lrc_time', 0.0)
proc_total = audio_conversion_time + auto_score_time + auto_lrc_time
if proc_total > 0:
fmt_label = audio_format.upper() if audio_format != "wav32" else "WAV 32-bit"
lines = [f"**🔧 Total processing time {songs_label}: {proc_total:.2f}s**"]
if audio_conversion_time > 0:
lines.append(f"- to {fmt_label} {songs_label}: {audio_conversion_time:.2f}s")
if auto_score_time > 0:
lines.append(f"- scoring {songs_label}: {auto_score_time:.2f}s")
if auto_lrc_time > 0:
lines.append(f"- LRC detection {songs_label}: {auto_lrc_time:.2f}s")
info_parts.append("\n".join(lines))
return "\n\n".join(info_parts)
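Review note: the rewritten `_build_generation_info` reduces the report to two blocks whose headlines depend on a pluralised song label and a format label. The sketch below isolates those two pieces and works through the per-song average. The helper names `songs_label` and `format_label` are hypothetical (the diff builds both inline), the timing numbers are illustrative rather than measured, and the emoji prefixes are omitted.

```python
def songs_label(num_audios: int) -> str:
    """Pluralised count label, e.g. "(2 songs)"."""
    return f"({num_audios} song{'s' if num_audios > 1 else ''})"


def format_label(audio_format: str) -> str:
    """Display name for the output format; "wav32" is special-cased as in the diff."""
    return audio_format.upper() if audio_format != "wav32" else "WAV 32-bit"


# Illustrative numbers only, not measurements.
time_costs = {"lm_total_time": 4.0, "dit_total_time_cost": 6.0}
num_audios = 2

gen_total = time_costs.get("lm_total_time", 0.0) + time_costs.get("dit_total_time_cost", 0.0)
summary = [
    f"Total generation time {songs_label(num_audios)}: {gen_total:.2f}s",
    f"- {gen_total / num_audios:.2f}s per song",
]
```

Only LM and DiT time enter the per-song average; conversion, scoring, and LRC time are reported separately in the second block.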
@ -450,7 +394,6 @@ def send_audio_to_src_with_metadata(audio_file, lm_metadata):
This function ONLY sets the src_audio field. All other metadata fields (caption, lyrics, etc.)
are preserved by returning gr.skip() to avoid overwriting user's existing inputs.
Also opens the audio_uploads_accordion so the user can see the uploaded audio.
Args:
audio_file: Audio file path
@ -458,7 +401,7 @@ def send_audio_to_src_with_metadata(audio_file, lm_metadata):
Returns:
Tuple of (audio_file, bpm, caption, lyrics, duration, key_scale, language, time_signature, is_format_caption, audio_uploads_accordion)
All values except audio_file and audio_uploads_accordion are gr.skip() to preserve existing UI values
All values except audio_file and accordion are gr.skip() to preserve existing UI values
"""
if audio_file is None:
# Return all skip to not modify anything
@ -476,7 +419,106 @@ def send_audio_to_src_with_metadata(audio_file, lm_metadata):
gr.skip(), # language - preserve existing value
gr.skip(), # time_signature - preserve existing value
gr.skip(), # is_format_caption - preserve existing value
gr.Accordion(open=True), # audio_uploads_accordion - expand to show uploaded audio
gr.Accordion(open=True), # audio_uploads_accordion - open to show the src audio
)
def _extract_metadata_for_editing(lm_metadata, current_lyrics="", current_caption=""):
"""Extract lyrics and caption from lm_metadata for repaint/remix operations.
Falls back to current UI values when lm_metadata is missing or incomplete,
so that existing user input is not overwritten with empty strings.
Args:
lm_metadata: Metadata dictionary from LM generation (or None)
current_lyrics: Current lyrics value from the UI (fallback)
current_caption: Current caption value from the UI (fallback)
Returns:
Tuple of (lyrics, caption) as strings
"""
lyrics = current_lyrics or ""
caption = current_caption or ""
if lm_metadata and isinstance(lm_metadata, dict):
lyrics = lm_metadata.get("lyrics", lyrics)
caption = lm_metadata.get("caption", caption)
return lyrics, caption
def send_audio_to_remix(audio_file, lm_metadata, current_lyrics, current_caption,
llm_handler=None):
"""Send generated audio to src_audio and switch mode to Remix.
Also populate lyrics and caption fields from the generated audio,
and apply all Remix-mode UI updates atomically (visibility, labels)
so the UI is correct without relying on a chained .change() event.
Args:
audio_file: Generated audio file path
lm_metadata: LM metadata dict (may be None)
current_lyrics: Current lyrics text in the UI
current_caption: Current caption text in the UI
llm_handler: Optional LLM handler (for think-checkbox state)
Returns:
Tuple of (src_audio, generation_mode, lyrics, caption,
*mode_ui_updates): 4 + 19 = 23 values.
"""
# 4 data outputs + 19 mode-UI outputs
n_outputs = 23
if audio_file is None:
return (gr.skip(),) * n_outputs
lyrics, caption = _extract_metadata_for_editing(
lm_metadata, current_lyrics, current_caption
)
mode_updates = compute_mode_ui_updates("Remix", llm_handler)
return (
audio_file, # src_audio
gr.update(value="Remix"), # generation_mode -> Remix
lyrics, # lyrics
caption, # caption
*mode_updates, # 19 mode-UI updates
)
def send_audio_to_repaint(audio_file, lm_metadata, current_lyrics, current_caption,
llm_handler=None):
"""Send generated audio to src_audio and switch mode to Repaint.
Also populate lyrics and caption fields from the generated audio,
and apply all Repaint-mode UI updates atomically (visibility, labels)
so the UI is correct without relying on a chained .change() event.
Args:
audio_file: Generated audio file path
lm_metadata: LM metadata dict (may be None)
current_lyrics: Current lyrics text in the UI
current_caption: Current caption text in the UI
llm_handler: Optional LLM handler (for think-checkbox state)
Returns:
Tuple of (src_audio, generation_mode, lyrics, caption,
*mode_ui_updates): 4 + 19 = 23 values.
"""
n_outputs = 23
if audio_file is None:
return (gr.skip(),) * n_outputs
lyrics, caption = _extract_metadata_for_editing(
lm_metadata, current_lyrics, current_caption
)
mode_updates = compute_mode_ui_updates("Repaint", llm_handler)
return (
audio_file, # src_audio
gr.update(value="Repaint"), # generation_mode -> Repaint
lyrics, # lyrics
caption, # caption
*mode_updates, # 19 mode-UI updates
)
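Review note on `_extract_metadata_for_editing`: `dict.get(key, default)` only uses the default when the key is absent. If `lm_metadata` contains `"lyrics": ""` or `"lyrics": None`, the current UI text would be replaced by that empty value, which seems at odds with the docstring's stated fallback. The sketch below shows one possible variant that also falls back on empty values. It is an assumption about the intended behavior, not the committed code, and `extract_for_editing` is a hypothetical name.

```python
from typing import Optional, Tuple


def extract_for_editing(
    lm_metadata: Optional[dict],
    current_lyrics: str = "",
    current_caption: str = "",
) -> Tuple[str, str]:
    """Hypothetical variant: prefer LM metadata, keep UI text when it is missing or empty."""
    lyrics = current_lyrics or ""
    caption = current_caption or ""
    if isinstance(lm_metadata, dict):
        # `or` keeps the UI value when the key is missing, None, or an empty string.
        lyrics = lm_metadata.get("lyrics") or lyrics
        caption = lm_metadata.get("caption") or caption
    return lyrics, caption
```

Whether empty LM lyrics should win over user text (for example, for instrumental results) is a design decision this sketch does not settle.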
@ -486,7 +528,7 @@ def generate_with_progress(
inference_steps, guidance_scale, random_seed_checkbox, seed,
reference_audio, audio_duration, batch_size_input, src_audio,
text2music_audio_code_string, repainting_start, repainting_end,
instruction_display_gen, audio_cover_strength, task_type,
instruction_display_gen, audio_cover_strength, cover_noise_strength, task_type,
use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, custom_timesteps, audio_format, lm_temperature,
think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
@ -495,9 +537,15 @@ def generate_with_progress(
auto_score,
auto_lrc,
score_scale,
lm_batch_chunk_size,
enable_normalization,
normalization_db,
latent_shift,
latent_rescale,
progress=gr.Progress(track_tqdm=True),
):
"""Generate audio with progress tracking"""
# ========== GPU Memory Validation ==========
@ -533,11 +581,14 @@ def generate_with_progress(
logger.info("[generate_with_progress] Skipping Phase 1 metas COT: sample is already formatted (is_format_caption=True)")
gr.Info(t("messages.skipping_metas_cot"))
# Parse and validate custom timesteps
parsed_timesteps, has_timesteps_warning, _ = parse_and_validate_timesteps(custom_timesteps, inference_steps)
actual_inference_steps = int(inference_steps) if inference_steps is not None else 8
# Update inference_steps if custom timesteps provided (to match UI display)
actual_inference_steps = inference_steps
if parsed_timesteps is not None:
actual_inference_steps = len(parsed_timesteps) - 1
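Review note: this hunk derives the displayed step count from the custom timesteps as `len(parsed_timesteps) - 1`, and one of its variants falls back to 8 when `inference_steps` is `None`. A small sketch that combines both behaviors is shown here. `steps_from_timesteps` is a hypothetical helper, and merging the two variants is an assumption made for illustration.

```python
from typing import Optional, Sequence


def steps_from_timesteps(
    timesteps: Optional[Sequence[float]],
    inference_steps: Optional[int],
    fallback: int = 8,
) -> int:
    """Hypothetical helper: step count shown in the UI for a given schedule."""
    if timesteps is not None:
        # A schedule with N + 1 entries is treated as N steps, as in the diff.
        return len(timesteps) - 1
    return int(inference_steps) if inference_steps is not None else fallback
```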
# step 1: prepare inputs
# generate_music, GenerationParams, GenerationConfig
gen_params = GenerationParams(
@ -565,6 +616,7 @@ def generate_with_progress(
repainting_start=repainting_start,
repainting_end=repainting_end,
audio_cover_strength=audio_cover_strength,
cover_noise_strength=cover_noise_strength,
thinking=think_checkbox,
lm_temperature=lm_temperature,
lm_cfg_scale=lm_cfg_scale,
@ -575,20 +627,18 @@ def generate_with_progress(
use_cot_caption=use_cot_caption,
use_cot_language=use_cot_language,
use_constrained_decoding=True,
enable_normalization=enable_normalization,
normalization_db=normalization_db,
latent_shift=latent_shift,
latent_rescale=latent_rescale,
)
if isinstance(seed, (int, float)):
seed_list = [int(seed)] if seed >= 0 else None
elif isinstance(seed, str) and seed.strip():
# seed string to list
if isinstance(seed, str) and seed.strip():
if "," in seed:
try:
seed_list = [int(s.strip()) for s in seed.split(",")]
except (ValueError, TypeError):
seed_list = None
seed_list = [int(s.strip()) for s in seed.split(",")]
else:
try:
seed_list = [int(seed.strip())]
except (ValueError, TypeError):
seed_list = None
seed_list = [int(seed.strip())]
else:
seed_list = None
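Review note on seed handling: this hunk interleaves two variants, one that accepts numeric seeds and wraps `int(...)` in `try/except`, and one that converts seed strings directly. In the unguarded variant a non-numeric string such as `"abc"` would raise `ValueError`. The sketch below shows a defensive parser that covers both input types. `parse_seeds` is a hypothetical helper, not a function from the repository.

```python
import math
from typing import List, Optional, Union


def parse_seeds(seed: Union[int, float, str, None]) -> Optional[List[int]]:
    """Hypothetical helper: turn a UI seed value into a list of ints, or None."""
    if isinstance(seed, bool):
        # bool is a subclass of int; treat it as "no seed" rather than 0 or 1.
        return None
    if isinstance(seed, (int, float)):
        if not math.isfinite(seed) or seed < 0:
            return None
        return [int(seed)]
    if isinstance(seed, str) and seed.strip():
        try:
            # "1, 2,3" -> [1, 2, 3]; a single value has no comma and yields one item.
            return [int(part.strip()) for part in seed.split(",")]
        except ValueError:
            return None
    return None
```

Returning `None` for unparsable input mirrors the guarded variant in the hunk, where `seed_list = None` is the fallback.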
gen_config = GenerationConfig(
@ -637,6 +687,7 @@ def generate_with_progress(
seed_value=seed_value_for_ui,
inference_steps=inference_steps,
num_audios=len(result.audios) if result.success else 0,
audio_format=audio_format,
)
if not result.success:
@ -669,7 +720,7 @@ def generate_with_progress(
# Clear lrc_display with empty string - this triggers .change() to clear subtitles
clear_lrcs = [gr.update(value="", visible=True) for _ in range(8)]
clear_accordions = [gr.skip() for _ in range(8)] # Don't change accordion visibility
dump_audio = [gr.update(value=None, subtitles=None) for _ in range(8)]
dump_audio = [gr.update(value=None, subtitles=None, playback_position=0) for _ in range(8)]
yield (
# Audio outputs - just skip, value will be updated in loop
# Subtitles will be cleared via lrc_display.change()
@ -697,29 +748,33 @@ def generate_with_progress(
)
time_module.sleep(0.1)
final_codes_display_updates = [gr.skip() for _ in range(8)]
for i in range(8):
if i < len(audios):
key = audios[i]["key"]
audio_tensor = audios[i]["tensor"]
sample_rate = audios[i]["sample_rate"]
audio_params = audios[i]["params"]
is_silent = audios[i].get("silent", False)
# Use local output directory instead of system temp
timestamp = int(time_module.time())
temp_dir = os.path.join(DEFAULT_RESULTS_DIR, f"batch_{timestamp}")
temp_dir = os.path.abspath(temp_dir).replace("\\", "/")
os.makedirs(temp_dir, exist_ok=True)
json_path = os.path.join(temp_dir, f"{key}.json").replace("\\", "/")
audio_path = os.path.join(temp_dir, f"{key}.{audio_format}").replace("\\", "/")
if not is_silent and audio_tensor is not None:
save_audio(audio_data=audio_tensor, output_path=audio_path, sample_rate=sample_rate, format=audio_format, channels_first=True)
with open(json_path, 'w', encoding='utf-8') as f:
json.dump(audio_params, f, indent=2, ensure_ascii=False)
audio_outputs[i] = audio_path
all_audio_paths.append(audio_path)
all_audio_paths.append(json_path)
else:
audio_outputs[i] = None
# Handle wav32 extension
ext = "wav" if audio_format == "wav32" else audio_format
audio_path = os.path.join(temp_dir, f"{key}.{ext}").replace("\\", "/")
# Save audio and capture actual saved path
saved_path = save_audio(audio_data=audio_tensor, output_path=audio_path, sample_rate=sample_rate, format=audio_format, channels_first=True)
if saved_path:
audio_path = saved_path.replace("\\", "/")
with open(json_path, 'w', encoding='utf-8') as f:
json.dump(audio_params, f, indent=2, ensure_ascii=False)
audio_outputs[i] = audio_path
all_audio_paths.append(audio_path)
all_audio_paths.append(json_path)
code_str = audio_params.get("audio_codes", "")
final_codes_list[i] = code_str
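Review note: the new code maps the `wav32` format name to a `.wav` file extension before building the output path, and normalises backslashes to forward slashes. A standalone sketch of that mapping is given below. `audio_output_path` is a hypothetical helper (the diff does this inline), and the directory and key values in the usage lines are made up.

```python
import os


def audio_output_path(directory: str, key: str, audio_format: str) -> str:
    """Hypothetical helper: build the output path for one generated sample."""
    # "wav32" names an encoding, not a file extension, so it is saved as .wav.
    ext = "wav" if audio_format == "wav32" else audio_format
    return os.path.join(directory, f"{key}.{ext}").replace("\\", "/")


wav_path = audio_output_path("out", "sample_a", "wav32")
flac_path = audio_output_path("out", "sample_a", "flac")
```

In the diff the path returned by `save_audio` then replaces this computed path when it is non-empty, so the final extension is whatever the saver actually wrote.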
@ -832,7 +887,6 @@ def generate_with_progress(
codes_display_updates = [gr.skip() for _ in range(8)]
codes_display_updates[i] = gr.update(value=code_str, visible=True) # Keep visible=True
final_codes_display_updates[i] = gr.update(value=code_str, visible=True) # Keep visible=True
details_accordion_updates = [gr.skip() for _ in range(8)]
# Don't change accordion visibility - keep it always expandable
@ -927,18 +981,28 @@ def generate_with_progress(
seed_value=seed_value_for_ui,
inference_steps=inference_steps,
num_audios=len(result.audios),
audio_format=audio_format,
)
# Build final codes display, LRC display, accordion visibility updates
final_codes_display_updates = [gr.skip() for _ in range(8)]
# final_lrc_display_updates = [gr.skip() for _ in range(8)]
final_accordion_updates = [gr.skip() for _ in range(8)]
# Audio was already sent in loop yields, just reset playback position to 0
# This resets the playback cursor to the beginning without reloading the audio
audio_playback_updates = [gr.update(playback_position=0) for _ in range(8)]
# On Windows, progressive yields are disabled, so we must return actual audio paths
# On other platforms, audio was already sent in loop yields, just reset playback position
# Use gr.update() to force Gradio to update the audio component (Issue #113)
audio_playback_updates = []
for idx in range(8):
path = audio_outputs[idx]
if path:
audio_playback_updates.append(gr.update(value=path, label=f"Sample {idx+1} (Ready)", interactive=False))
logger.info(f"[generate_with_progress] Audio {idx+1} path: {path}")
else:
audio_playback_updates.append(gr.update(value=None, label="None", interactive=False))
yield (
# Audio outputs - reset playback position to beginning
# Audio outputs - use gr.update() to force component refresh
audio_playback_updates[0], audio_playback_updates[1], audio_playback_updates[2], audio_playback_updates[3],
audio_playback_updates[4], audio_playback_updates[5], audio_playback_updates[6], audio_playback_updates[7],
all_audio_paths,
@ -1292,7 +1356,8 @@ def generate_lrc_handler(dit_handler, sample_idx, current_batch_index, batch_que
Tuple of (lrc_display_update, details_accordion_update, batch_queue)
Note: No audio_update - subtitles updated via lrc_display.change()
"""
import torch
if current_batch_index not in batch_queue:
return gr.skip(), gr.skip(), batch_queue
@ -1420,13 +1485,17 @@ def capture_current_params(
inference_steps, guidance_scale, random_seed_checkbox, seed,
reference_audio, audio_duration, batch_size_input, src_audio,
text2music_audio_code_string, repainting_start, repainting_end,
instruction_display_gen, audio_cover_strength, task_type,
instruction_display_gen, audio_cover_strength, cover_noise_strength, task_type,
use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, custom_timesteps, audio_format, lm_temperature,
think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
use_cot_metas, use_cot_caption, use_cot_language,
constrained_decoding_debug, allow_lm_batch, auto_score, auto_lrc, score_scale, lm_batch_chunk_size,
track_name, complete_track_classes
track_name, complete_track_classes,
enable_normalization, normalization_db,
latent_shift, latent_rescale
):
"""Capture current UI parameters for next batch generation
IMPORTANT: For AutoGen batches, we clear audio codes to ensure:
@ -1453,6 +1522,7 @@ def capture_current_params(
"repainting_end": repainting_end,
"instruction_display_gen": instruction_display_gen,
"audio_cover_strength": audio_cover_strength,
"cover_noise_strength": cover_noise_strength,
"task_type": task_type,
"use_adg": use_adg,
"cfg_interval_start": cfg_interval_start,
@ -1478,16 +1548,22 @@ def capture_current_params(
"lm_batch_chunk_size": lm_batch_chunk_size,
"track_name": track_name,
"complete_track_classes": complete_track_classes,
"enable_normalization": enable_normalization,
"normalization_db": normalization_db,
"latent_shift": latent_shift,
"latent_rescale": latent_rescale,
}
def generate_with_batch_management(
dit_handler, llm_handler,
captions, lyrics, bpm, key_scale, time_signature, vocal_language,
inference_steps, guidance_scale, random_seed_checkbox, seed,
reference_audio, audio_duration, batch_size_input, src_audio,
text2music_audio_code_string, repainting_start, repainting_end,
instruction_display_gen, audio_cover_strength, task_type,
instruction_display_gen, audio_cover_strength, cover_noise_strength, task_type,
use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, custom_timesteps, audio_format, lm_temperature,
think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
@ -1499,7 +1575,13 @@ def generate_with_batch_management(
lm_batch_chunk_size,
track_name,
complete_track_classes,
enable_normalization,
normalization_db,
latent_shift,
latent_rescale,
autogen_checkbox,
current_batch_index,
total_batches,
batch_queue,
@ -1516,7 +1598,7 @@ def generate_with_batch_management(
inference_steps, guidance_scale, random_seed_checkbox, seed,
reference_audio, audio_duration, batch_size_input, src_audio,
text2music_audio_code_string, repainting_start, repainting_end,
instruction_display_gen, audio_cover_strength, task_type,
instruction_display_gen, audio_cover_strength, cover_noise_strength, task_type,
use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, custom_timesteps, audio_format, lm_temperature,
think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
@ -1526,18 +1608,28 @@ def generate_with_batch_management(
auto_lrc,
score_scale,
lm_batch_chunk_size,
enable_normalization,
normalization_db,
latent_shift,
latent_rescale,
progress
)
final_result_from_inner = None
for partial_result in generator:
final_result_from_inner = partial_result
# Forward intermediate yields to UI for progressive streaming updates
# (audio outputs appear one-by-one as they are ready)
# Pad with gr.skip() for the extra batch management outputs
yield tuple(partial_result[:46]) + (
gr.skip(), gr.skip(), gr.skip(), gr.skip(),
gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(),
)
# Progressive yields disabled on Windows to prevent UI freeze
# On other platforms, yield progress updates normally
if not IS_WINDOWS:
# current_batch_index, total_batches, batch_queue, next_params,
# batch_indicator_text, prev_btn, next_btn, next_status, restore_btn
# Slice off extra_outputs and raw_codes_list (last 2 items) before re-yielding to UI
ui_result = partial_result[:-2] if len(partial_result) > 47 else (partial_result[:-1] if len(partial_result) > 46 else partial_result)
yield ui_result + (
gr.skip(), gr.skip(), gr.skip(), gr.skip(),
gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
)
result = final_result_from_inner
all_audio_paths = result[8]
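The slicing in the progressive-yield branch above can be read as "trim up to two trailing internal items, then pad with nine skip markers". The following is a minimal sketch of that idea only, not the project's implementation: `SKIP` is a hypothetical stand-in for `gr.skip()`, and `pad_for_ui` is an illustrative helper name that does not appear in the diff.

```python
from typing import Any, Tuple

SKIP = object()          # hypothetical stand-in for gr.skip()
N_BATCH_OUTPUTS = 9      # the branch above appends nine gr.skip() values


def pad_for_ui(partial_result: Tuple[Any, ...]) -> Tuple[Any, ...]:
    """Trim trailing internal items, then pad for the batch-management outputs."""
    n = len(partial_result)
    if n > 47:
        # Drop the last two items (extra_outputs and raw_codes_list in the diff).
        ui_result = partial_result[:-2]
    elif n > 46:
        # Only one trailing internal item is present.
        ui_result = partial_result[:-1]
    else:
        ui_result = partial_result
    return tuple(ui_result) + (SKIP,) * N_BATCH_OUTPUTS


if __name__ == "__main__":
    print(len(pad_for_ui(tuple(range(48)))))
```

Writing the trim as an explicit `if`/`elif` chain keeps the same behavior as the nested conditional expression while making the two thresholds easier to audit.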
@ -1591,6 +1683,7 @@ def generate_with_batch_management(
"repainting_end": repainting_end,
"instruction_display_gen": instruction_display_gen,
"audio_cover_strength": audio_cover_strength,
"cover_noise_strength": cover_noise_strength,
"task_type": task_type,
"use_adg": use_adg,
"cfg_interval_start": cfg_interval_start,
@ -1615,7 +1708,13 @@ def generate_with_batch_management(
"lm_batch_chunk_size": lm_batch_chunk_size,
"track_name": track_name,
"complete_track_classes": complete_track_classes,
"enable_normalization": enable_normalization,
"normalization_db": normalization_db,
"latent_shift": latent_shift,
"latent_rescale": latent_rescale,
}
# Next batch parameters (with cleared codes & random seed)
# Next batch parameters
@ -1771,6 +1870,7 @@ def generate_next_batch_background(
params.setdefault("repainting_end", -1)
params.setdefault("instruction_display_gen", "")
params.setdefault("audio_cover_strength", 1.0)
params.setdefault("cover_noise_strength", 0.0)
params.setdefault("task_type", "text2music")
params.setdefault("use_adg", False)
params.setdefault("cfg_interval_start", 0.0)
@ -1778,7 +1878,7 @@ def generate_next_batch_background(
params.setdefault("shift", 1.0)
params.setdefault("infer_method", "ode")
params.setdefault("custom_timesteps", "")
params.setdefault("audio_format", "mp3")
params.setdefault("audio_format", "flac")
params.setdefault("lm_temperature", 0.85)
params.setdefault("think_checkbox", True)
params.setdefault("lm_cfg_scale", 2.0)
@ -1796,6 +1896,10 @@ def generate_next_batch_background(
params.setdefault("lm_batch_chunk_size", 8)
params.setdefault("track_name", None)
params.setdefault("complete_track_classes", [])
params.setdefault("enable_normalization", True)
params.setdefault("normalization_db", -1.0)
params.setdefault("latent_shift", 0.0)
params.setdefault("latent_rescale", 1.0)
# Call generate_with_progress with the saved parameters
# Note: generate_with_progress is a generator, need to iterate through it
@ -1823,6 +1927,7 @@ def generate_next_batch_background(
repainting_end=params.get("repainting_end"),
instruction_display_gen=params.get("instruction_display_gen"),
audio_cover_strength=params.get("audio_cover_strength"),
cover_noise_strength=params.get("cover_noise_strength", 0.0),
task_type=params.get("task_type"),
use_adg=params.get("use_adg"),
cfg_interval_start=params.get("cfg_interval_start"),
@ -1847,6 +1952,10 @@ def generate_next_batch_background(
auto_lrc=params.get("auto_lrc"),
score_scale=params.get("score_scale"),
lm_batch_chunk_size=params.get("lm_batch_chunk_size"),
enable_normalization=params.get("enable_normalization"),
normalization_db=params.get("normalization_db"),
latent_shift=params.get("latent_shift", 0.0),
latent_rescale=params.get("latent_rescale", 1.0),
progress=progress
)
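The `params.setdefault(...)` calls above only fill in keys that a saved batch does not already carry, so changing a default (for example `audio_format` from `"mp3"` to `"flac"`) should not alter batches that stored an explicit value. A minimal sketch of that behavior, assuming a plain `dict` of saved parameters; `NEW_DEFAULTS` and `backfill_defaults` are illustrative names, not part of the diff.

```python
from typing import Any, Dict

# Defaults taken from the setdefault calls shown in the diff.
NEW_DEFAULTS: Dict[str, Any] = {
    "cover_noise_strength": 0.0,
    "audio_format": "flac",
    "enable_normalization": True,
    "normalization_db": -1.0,
    "latent_shift": 0.0,
    "latent_rescale": 1.0,
}


def backfill_defaults(params: Dict[str, Any]) -> Dict[str, Any]:
    """Add missing keys in place; keys that already exist are left untouched."""
    for key, value in NEW_DEFAULTS.items():
        params.setdefault(key, value)
    return params


if __name__ == "__main__":
    saved = {"audio_format": "mp3"}   # a batch saved before this change
    print(backfill_defaults(saved))
```

Note that `setdefault` also leaves an explicitly stored `None` in place, which may be why the later `params.get("latent_shift", 0.0)` calls repeat the fallback value.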
@ -2005,7 +2114,6 @@ def navigate_to_previous_batch(current_batch_index, batch_queue):
for idx in range(8):
if idx < len(real_audio_paths):
audio_path = real_audio_paths[idx].replace("\\", "/") # Normalize path
# Pass path directly; Gradio Audio component with type="filepath" expects a string path
audio_updates.append(gr.update(value=audio_path))
else:
audio_updates.append(gr.update(value=None))
@ -2129,7 +2237,6 @@ def navigate_to_next_batch(autogen_enabled, current_batch_index, total_batches,
for idx in range(8):
if idx < len(real_audio_paths):
audio_path = real_audio_paths[idx].replace("\\", "/") # Normalize path
# Pass path directly; Gradio Audio component with type="filepath" expects a string path
audio_updates.append(gr.update(value=audio_path))
else:
audio_updates.append(gr.update(value=None))
@ -2257,6 +2364,15 @@ def restore_batch_parameters(current_batch_index, batch_queue):
track_name = params.get("track_name", None)
complete_track_classes = params.get("complete_track_classes", [])
# Audio Normalization
enable_normalization = params.get("enable_normalization", True)
normalization_db = params.get("normalization_db", -1.0)
# Latent Shift / Rescale
latent_shift = params.get("latent_shift", 0.0)
latent_rescale = params.get("latent_rescale", 1.0)
# Extract codes - only restore to single input
stored_codes = batch_data.get("codes", "")
if stored_codes:
@ -2276,5 +2392,65 @@ def restore_batch_parameters(current_batch_index, batch_queue):
vocal_language, audio_duration, batch_size_input, inference_steps,
lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, think_checkbox,
use_cot_caption, use_cot_language, allow_lm_batch,
track_name, complete_track_classes
)
track_name, complete_track_classes,
enable_normalization, normalization_db,
latent_shift, latent_rescale
)
def convert_result_audio_to_codes(dit_handler, generated_audio):
"""Convert a generated audio sample to LM audio codes.
Args:
dit_handler: DiT handler instance with convert_src_audio_to_codes method
generated_audio: File path to the generated audio
Returns:
Tuple of (codes_display_update, details_accordion_update)
"""
if not generated_audio:
gr.Warning("No audio to convert.")
return gr.skip(), gr.skip()
if not dit_handler or dit_handler.model is None:
gr.Warning(t("messages.service_not_initialized"))
return gr.skip(), gr.skip()
try:
codes_string = dit_handler.convert_src_audio_to_codes(generated_audio)
if not codes_string or codes_string.startswith("❌"):
gr.Warning(f"Failed to convert audio to codes: {codes_string}")
return gr.skip(), gr.skip()
gr.Info("Audio converted to codes successfully.")
return gr.update(value=codes_string), gr.update(open=True)
except Exception as e:
gr.Warning(f"Error converting audio to codes: {e}")
return gr.skip(), gr.skip()
def save_lrc_to_file(lrc_text):
"""Save LRC text to a downloadable .lrc file.
Args:
lrc_text: The LRC text content to save
Returns:
gr.update for the File component with the .lrc file path
"""
if not lrc_text or not lrc_text.strip():
gr.Warning("No LRC content to save.")
return gr.skip()
try:
# Create a temporary file with .lrc extension
tmp_dir = tempfile.mkdtemp()
lrc_path = os.path.join(tmp_dir, "lyrics.lrc")
with open(lrc_path, "w", encoding="utf-8") as f:
f.write(lrc_text)
gr.Info("LRC file ready for download.")
return gr.update(value=lrc_path, visible=True)
except Exception as e:
gr.Warning(f"Error saving LRC file: {e}")
return gr.skip()

View file

@ -9,11 +9,19 @@ import json
from typing import Any, Dict, List, Tuple, Optional
from loguru import logger
import gradio as gr
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from acestep.training.dataset_builder import DatasetBuilder, AudioSample
from acestep.debug_utils import debug_log_for, debug_start_for, debug_end_for
from acestep.gpu_config import get_global_gpu_config
# Define a safe root directory for all training-related filesystem operations.
# This limits user-supplied paths (for checkpoints, exports, etc.) to stay
# within the server's working tree, preventing directory traversal outside it.
SAFE_TRAINING_ROOT = os.path.abspath(os.getcwd())
def create_dataset_builder() -> DatasetBuilder:
"""Create a new DatasetBuilder instance."""
@ -42,7 +50,7 @@ def scan_directory(
Tuple of (table_data, status, slider_update, builder_state)
"""
if not audio_dir or not audio_dir.strip():
return [], " Please enter a directory path", _safe_slider(0, value=0, visible=False), builder_state
# Create or use existing builder
builder = builder_state if builder_state else DatasetBuilder()
@ -99,17 +107,17 @@ def auto_label_all(
Tuple of (table_data, status, builder_state)
"""
if builder_state is None:
return [], " Please scan a directory first", builder_state
if not builder_state.samples:
return [], " No samples to label. Please scan a directory first.", builder_state
# Check if handlers are initialized
if dit_handler is None or dit_handler.model is None:
return builder_state.get_samples_dataframe_data(), "❌ Model not initialized. Please initialize the service first.", builder_state
if llm_handler is None or not llm_handler.llm_initialized:
return builder_state.get_samples_dataframe_data(), " LLM not initialized. Please initialize the service with LLM enabled.", builder_state
def progress_callback(msg):
if progress:
@ -210,7 +218,7 @@ def save_sample_edit(
Tuple of (table_data, status, builder_state)
"""
if builder_state is None:
return [], " No dataset loaded", builder_state
idx = int(sample_idx)
@ -281,13 +289,13 @@ def save_dataset(
Status message
"""
if builder_state is None:
return " No dataset to save. Please scan a directory first.", gr.update()
if not builder_state.samples:
return " No samples in dataset.", gr.update()
if not save_path or not save_path.strip():
return " Please enter a save path.", gr.update()
save_path = save_path.strip()
if not save_path.lower().endswith(".json"):
@ -296,7 +304,7 @@ def save_dataset(
# Check if any samples are labeled
labeled_count = builder_state.get_labeled_count()
if labeled_count == 0:
return " Warning: No samples have been labeled. Consider auto-labeling first.\nSaving anyway...", gr.update(value=save_path)
return builder_state.save_dataset(save_path, dataset_name), gr.update(value=save_path)
@ -321,14 +329,14 @@ def load_existing_dataset_for_preprocess(
if not dataset_path or not dataset_path.strip():
updates = (gr.update(), gr.update(), gr.update(), gr.update(), gr.update())
return (" Please enter a dataset path", [], _safe_slider(0, value=0, visible=False), builder_state) + empty_preview + updates
dataset_path = dataset_path.strip()
debug_log_for("dataset", f"UI load_existing_dataset_for_preprocess: path='{dataset_path}'")
if not os.path.exists(dataset_path):
updates = (gr.update(), gr.update(), gr.update(), gr.update(), gr.update())
return (f" Dataset not found: {dataset_path}", [], _safe_slider(0, value=0, visible=False), builder_state) + empty_preview + updates
# Create new builder (don't reuse old state when loading a file)
builder = DatasetBuilder()
@ -350,12 +358,12 @@ def load_existing_dataset_for_preprocess(
# Create info text
labeled_count = builder.get_labeled_count()
info = f"📂 Loaded dataset: {builder.metadata.name}\n"
info += f"🔢 Samples: {len(samples)} ({labeled_count} labeled)\n"
info += f"🏷 Custom Tag: {builder.metadata.custom_tag or '(none)'}\n"
info += " Ready for preprocessing! You can also edit samples below."
if any((s.formatted_lyrics and not s.lyrics) for s in builder.samples):
info += "\n Showing formatted lyrics where lyrics are empty."
# Get first sample preview
first_sample = builder.samples[0]
@ -401,6 +409,7 @@ def load_existing_dataset_for_preprocess(
def preprocess_dataset(
output_dir: str,
preprocess_mode: str,
dit_handler,
builder_state: Optional[DatasetBuilder],
progress=None,
@ -413,20 +422,20 @@ def preprocess_dataset(
Status message
"""
if builder_state is None:
return " No dataset loaded. Please scan a directory first."
if not builder_state.samples:
return " No samples in dataset."
labeled_count = builder_state.get_labeled_count()
if labeled_count == 0:
return " No labeled samples. Please auto-label or manually label samples first."
if not output_dir or not output_dir.strip():
return " Please enter an output directory."
if dit_handler is None or dit_handler.model is None:
return "❌ Model not initialized. Please initialize the service first."
def progress_callback(msg):
if progress:
@ -436,10 +445,15 @@ def preprocess_dataset(
pass
# Run preprocessing
mode = str(preprocess_mode or "lora").strip().lower()
if mode not in {"lora", "lokr"}:
mode = "lora"
t0 = debug_start_for("dataset", "preprocess_to_tensors")
output_paths, status = builder_state.preprocess_to_tensors(
dit_handler=dit_handler,
output_dir=output_dir.strip(),
preprocess_mode=mode,
progress_callback=progress_callback,
)
debug_end_for("dataset", "preprocess_to_tensors", t0)
@ -456,15 +470,15 @@ def load_training_dataset(
Info text about the dataset
"""
if not tensor_dir or not tensor_dir.strip():
return "❌ Please enter a tensor directory path"
tensor_dir = tensor_dir.strip()
if not os.path.exists(tensor_dir):
return f" Directory not found: {tensor_dir}"
if not os.path.isdir(tensor_dir):
return f" Not a directory: {tensor_dir}"
# Check for manifest
manifest_path = os.path.join(tensor_dir, "manifest.json")
@ -478,9 +492,9 @@ def load_training_dataset(
name = metadata.get("name", "Unknown")
custom_tag = metadata.get("custom_tag", "")
info = f"📂 Loaded preprocessed dataset: {name}\n"
info += f"🔢 Samples: {num_samples} preprocessed tensors\n"
info += f"🏷 Custom Tag: {custom_tag or '(none)'}"
return info
except Exception as e:
@ -490,10 +504,10 @@ def load_training_dataset(
pt_files = [f for f in os.listdir(tensor_dir) if f.endswith('.pt')]
if not pt_files:
return f" No .pt tensor files found in {tensor_dir}"
info = f"📂 Found {len(pt_files)} tensor files in {tensor_dir}\n"
info += " No manifest.json found - using all .pt files"
return info
@ -515,6 +529,43 @@ def _format_duration(seconds):
return f"{seconds // 3600}h {(seconds % 3600) // 60}m"
def _training_loss_figure(
training_state: Dict,
step_list: List[int],
loss_list: List[float],
) -> Optional[Any]:
"""Build a training/validation loss plot (matplotlib Figure) for gr.Plot."""
steps = training_state.get("plot_steps") or step_list
loss = training_state.get("plot_loss") or loss_list
if not steps or not loss:
fig, ax = plt.subplots(figsize=(6, 3))
ax.set_xlabel("Step")
ax.set_ylabel("Loss")
ax.set_title("Training loss")
fig.tight_layout()
return fig
ema = training_state.get("plot_ema")
val_steps = training_state.get("plot_val_steps") or []
val_loss = training_state.get("plot_val_loss") or []
best_step = training_state.get("plot_best_step")
fig, ax = plt.subplots(figsize=(6, 3))
ax.plot(steps, loss, color="tab:blue", alpha=0.35, label="Loss (raw)", linewidth=1)
if ema and len(ema) == len(steps):
ax.plot(steps, ema, color="tab:blue", alpha=1.0, label="Loss (smoothed)", linewidth=1.5)
if val_steps and val_loss:
ax.scatter(val_steps, val_loss, color="tab:orange", s=24, zorder=5, label="Validation")
if best_step is not None:
ax.axvline(x=best_step, color="tab:green", linestyle="--", alpha=0.8, label="Best checkpoint")
ax.set_xlabel("Step")
ax.set_ylabel("Loss")
ax.set_title("Training loss")
ax.legend(loc="upper right", fontsize=8)
ax.grid(True, alpha=0.3)
fig.tight_layout()
return fig
def start_training(
tensor_dir: str,
dit_handler,
@ -538,17 +589,17 @@ def start_training(
This is a generator function that yields progress updates.
"""
if not tensor_dir or not tensor_dir.strip():
yield "❌ Please enter a tensor directory path", "", None, training_state
return
tensor_dir = tensor_dir.strip()
if not os.path.exists(tensor_dir):
yield f"❌ Tensor directory not found: {tensor_dir}", "", None, training_state
return
if dit_handler is None or dit_handler.model is None:
yield "❌ Model not initialized. Please initialize the service first.", "", None, training_state
return
# Training preset: LoRA training must run on non-quantized DiT.
@ -575,11 +626,11 @@ def start_training(
if hasattr(dit_handler, "switch_to_training_preset"):
switch_status, switched = dit_handler.switch_to_training_preset()
if not switched:
yield f" {switch_status}", "", None, training_state
return
yield f" {switch_status}", "", None, training_state
else:
yield "❌ Training requires non-quantized DiT, and auto-switch is unavailable in this build.", "", None, training_state
return
# Check for required training dependencies
@ -587,7 +638,7 @@ def start_training(
from lightning.fabric import Fabric
from peft import get_peft_model, LoraConfig
except ImportError as e:
yield f"❌ Missing required packages: {e}\nPlease install: pip install peft lightning", "", None, training_state
return
training_state["is_training"] = True
@ -623,14 +674,14 @@ def start_training(
pin_memory = True
prefetch_factor = 2
persistent_workers = True
pin_memory_device = None
pin_memory_device = ""
mixed_precision = "bf16"
elif device_type == "mps":
num_workers = 0
pin_memory = False
prefetch_factor = 2
persistent_workers = False
pin_memory_device = None
pin_memory_device = ""
mixed_precision = "fp16"
else:
cpu_count = os.cpu_count() or 2
@ -638,7 +689,7 @@ def start_training(
pin_memory = False
prefetch_factor = 2
persistent_workers = num_workers > 0
pin_memory_device = None
pin_memory_device = ""
mixed_precision = "fp32"
logger.info(
@ -663,16 +714,16 @@ def start_training(
mixed_precision=mixed_precision,
)
import pandas as pd
# Initialize training log and loss history
log_lines = []
loss_data = pd.DataFrame({"step": [0], "loss": [0.0]})
step_list = []
loss_list = []
initial_plot = _training_loss_figure(training_state, step_list, loss_list)
# Start timer
start_time = time.time()
yield f"🚀 Starting training from {tensor_dir}...", "", loss_data, training_state
yield f"🚀 Starting training from {tensor_dir}...", "", initial_plot, training_state
# Create trainer
trainer = LoRATrainer(
@ -681,9 +732,6 @@ def start_training(
training_config=training_config,
)
# Collect loss history
step_list = []
loss_list = []
training_failed = False
failure_message = ""
@ -731,37 +779,64 @@ def start_training(
if step > 0 and loss is not None and loss == loss: # Check for NaN
step_list.append(step)
loss_list.append(float(loss))
loss_data = pd.DataFrame({"step": step_list, "loss": loss_list})
yield display_status, log_text, loss_data, training_state
plot_figure = _training_loss_figure(training_state, step_list, loss_list)
yield display_status, log_text, plot_figure, training_state
if training_state.get("should_stop", False):
logger.info("⏹️ Training stopped by user")
log_lines.append("⏹️ Training stopped by user")
yield f"⏹️ Stopped ({time_info})", "\n".join(log_lines[-15:]), loss_data, training_state
yield f"⏹️ Stopped ({time_info})", "\n".join(log_lines[-15:]), plot_figure, training_state
break
total_time = time.time() - start_time
training_state["is_training"] = False
final_plot = _training_loss_figure(training_state, step_list, loss_list)
if training_failed:
final_msg = f"{failure_message}\nElapsed: {_format_duration(total_time)}"
logger.warning(final_msg)
log_lines.append(failure_message)
yield final_msg, "\n".join(log_lines[-15:]), loss_data, training_state
yield final_msg, "\n".join(log_lines[-15:]), final_plot, training_state
return
completion_msg = f"✅ Training completed! Total time: {_format_duration(total_time)}"
logger.info(completion_msg)
log_lines.append(completion_msg)
yield completion_msg, "\n".join(log_lines[-15:]), loss_data, training_state
yield completion_msg, "\n".join(log_lines[-15:]), final_plot, training_state
except Exception as e:
logger.exception("Training error")
training_state["is_training"] = False
import pandas as pd
empty_df = pd.DataFrame({"step": [], "loss": []})
yield f"❌ Error: {str(e)}", str(e), empty_df, training_state
yield f"❌ Error: {str(e)}", str(e), _training_loss_figure({}, [], []), training_state
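The training loop above records a loss value only when `loss == loss`, which is a compact NaN test (NaN is the only float that is not equal to itself). A minimal sketch of the same filter written with `math.isnan`, assuming the loss arrives as a Python `float` or `None`; `record_loss` is an illustrative helper, not a function from this commit.

```python
import math
from typing import List, Optional


def record_loss(
    step: int,
    loss: Optional[float],
    step_list: List[int],
    loss_list: List[float],
) -> bool:
    """Append (step, loss) to the history unless the sample is unusable.

    Returns True when the point was recorded.
    """
    if step <= 0 or loss is None or math.isnan(loss):
        return False
    step_list.append(step)
    loss_list.append(float(loss))
    return True


if __name__ == "__main__":
    steps: List[int] = []
    losses: List[float] = []
    for s, l in [(0, 1.0), (1, 0.9), (2, float("nan")), (3, None), (4, 0.7)]:
        record_loss(s, l, steps, losses)
    print(steps, losses)
```

Skipping NaN points keeps the plotted curve continuous; whether a NaN should also abort training is a separate decision that this sketch does not address.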
def _safe_join(base_root: str, user_path: str) -> Optional[str]:
"""Safely join a user-supplied path to a base root, preventing escapes.
Returns an absolute path within base_root, or None if the path is invalid.
"""
if not user_path:
return None
# Normalize whitespace
candidate = user_path.strip()
if not candidate:
return None
# Disallow absolute paths outright; force everything under base_root
if os.path.isabs(candidate):
return None
# Join with base_root and normalize to an absolute path
abs_root = os.path.abspath(base_root)
joined = os.path.abspath(os.path.join(abs_root, candidate))
try:
common = os.path.commonpath([abs_root, joined])
except ValueError:
# Different drives on Windows or other path issues
return None
if common != abs_root:
# Attempted to escape the allowed root
return None
return joined
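To make the effect of the check above concrete, here is a small standalone sketch that mirrors the logic of `_safe_join` and runs it against three kinds of input. The root directory name and the sample paths are hypothetical; the function body is a condensed restatement of the diff, not a replacement for it.

```python
import os
from typing import Optional


def safe_join(base_root: str, user_path: str) -> Optional[str]:
    """Condensed mirror of the diff's `_safe_join`, for illustration only."""
    candidate = (user_path or "").strip()
    if not candidate or os.path.isabs(candidate):
        return None  # empty input and absolute paths are rejected outright
    abs_root = os.path.abspath(base_root)
    joined = os.path.abspath(os.path.join(abs_root, candidate))
    try:
        common = os.path.commonpath([abs_root, joined])
    except ValueError:
        return None  # e.g. different drives on Windows
    return joined if common == abs_root else None


root = os.path.abspath("training_root")                   # hypothetical safe root
inside = safe_join(root, "lora_output/final")             # stays under root
escape = safe_join(root, "../outside")                    # climbs out of root
absolute = safe_join(root, os.path.abspath("elsewhere"))  # absolute input

if __name__ == "__main__":
    print(inside, escape, absolute)
```

The comparison is purely lexical: `os.path.abspath` normalizes `..` segments but does not resolve symlinks, so a symlink inside the root that points elsewhere would still pass this check.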
def stop_training(training_state: Dict) -> Tuple[str, Dict]:
@ -771,7 +846,7 @@ def stop_training(training_state: Dict) -> Tuple[str, Dict]:
Tuple of (status, training_state)
"""
if not training_state.get("is_training", False):
return " No training in progress", training_state
training_state["should_stop"] = True
return "⏹️ Stopping training...", training_state
@ -787,11 +862,16 @@ def export_lora(
Status message
"""
if not export_path or not export_path.strip():
return "❌ Please enter an export path"
# Resolve and validate the base LoRA output directory within the safe root
safe_lora_dir = _safe_join(SAFE_TRAINING_ROOT, lora_output_dir)
if safe_lora_dir is None:
return "❌ Invalid LoRA output directory"
# Check if there's a trained model to export
final_dir = os.path.join(lora_output_dir, "final")
checkpoint_dir = os.path.join(lora_output_dir, "checkpoints")
final_dir = os.path.join(safe_lora_dir, "final")
checkpoint_dir = os.path.join(safe_lora_dir, "checkpoints")
# Prefer final, fallback to checkpoints
if os.path.exists(final_dir):
@ -800,30 +880,337 @@ def export_lora(
# Find the latest checkpoint
checkpoints = [d for d in os.listdir(checkpoint_dir) if d.startswith("epoch_")]
if not checkpoints:
return " No checkpoints found"
checkpoints.sort(key=lambda x: int(x.split("_")[1]))
latest = checkpoints[-1]
source_path = os.path.join(checkpoint_dir, latest)
else:
return f"❌ No trained model found in {lora_output_dir}"
# Resolve and validate the export destination within the safe root
safe_export_path = _safe_join(SAFE_TRAINING_ROOT, export_path)
if safe_export_path is None:
return "❌ Invalid export path"
try:
import shutil
export_path = export_path.strip()
os.makedirs(os.path.dirname(export_path) if os.path.dirname(export_path) else ".", exist_ok=True)
# Ensure parent directory exists
parent_dir = os.path.dirname(safe_export_path) or "."
os.makedirs(parent_dir, exist_ok=True)
if os.path.exists(export_path):
shutil.rmtree(export_path)
if os.path.exists(safe_export_path):
shutil.rmtree(safe_export_path)
shutil.copytree(source_path, export_path)
shutil.copytree(source_path, safe_export_path)
return f"✅ LoRA exported to {export_path}"
return f"✅ LoRA exported to {safe_export_path}"
except Exception as e:
logger.exception("Export error")
return f" Export failed: {str(e)}"
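`export_lora` picks the newest checkpoint by sorting directory names on the integer after `epoch_`. The sketch below shows why the numeric key matters, under the assumption that checkpoint directories are named `epoch_<int>` (optionally followed by further `_`-separated fields); `latest_epoch_dir` is an illustrative helper and the sample names are made up.

```python
from typing import List, Optional


def latest_epoch_dir(names: List[str]) -> Optional[str]:
    """Return the name with the highest epoch number, or None if there is none."""
    epochs = [n for n in names if n.startswith("epoch_")]
    if not epochs:
        return None
    # Compare on the integer epoch, not on the raw string.
    return max(epochs, key=lambda n: int(n.split("_")[1]))


sample = ["epoch_2", "epoch_10", "epoch_9", "notes.txt"]

if __name__ == "__main__":
    print(latest_epoch_dir(sample))   # numeric ordering
    print(sorted(sample)[-2])         # plain string ordering, for comparison
```

A plain string sort would rank `epoch_9` above `epoch_10`, which is the failure the integer key avoids. A directory such as `epoch_final` would raise `ValueError` in `int(...)`, so callers may want to guard against names that do not follow the assumed pattern.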
def start_lokr_training(
tensor_dir: str,
dit_handler,
lokr_linear_dim: int,
lokr_linear_alpha: int,
lokr_factor: int,
lokr_decompose_both: bool,
lokr_use_tucker: bool,
lokr_use_scalar: bool,
lokr_weight_decompose: bool,
learning_rate: float,
train_epochs: int,
train_batch_size: int,
gradient_accumulation: int,
save_every_n_epochs: int,
training_shift: float,
training_seed: int,
lokr_output_dir: str,
training_state: Dict,
progress=None,
):
"""Start LoKr training from preprocessed tensors."""
if not tensor_dir or not tensor_dir.strip():
yield "❌ Please enter a tensor directory path", "", None, training_state
return
tensor_dir = tensor_dir.strip()
if not os.path.exists(tensor_dir):
yield f"❌ Tensor directory not found: {tensor_dir}", "", None, training_state
return
if dit_handler is None or dit_handler.model is None:
yield "❌ Model not initialized. Please initialize the service first.", "", None, training_state
return
if getattr(dit_handler, "quantization", None) is not None:
yield "Switching model to training preset (disable quantization)...", "", None, training_state
if hasattr(dit_handler, "switch_to_training_preset"):
switch_status, switched = dit_handler.switch_to_training_preset()
if not switched:
yield f"{switch_status}", "", None, training_state
return
yield f"{switch_status}", "", None, training_state
else:
yield "❌ Training requires non-quantized DiT, and auto-switch is unavailable in this build.", "", None, training_state
return
try:
from lightning.fabric import Fabric # noqa: F401
except ImportError as e:
yield f"❌ Missing required packages: {e}\nPlease install: pip install lightning lycoris-lora", "", None, training_state
return
training_state["is_training"] = True
training_state["should_stop"] = False
training_state["adapter_type"] = "lokr"
try:
from acestep.training.configs import LoKRConfig as LoKRConfigClass, TrainingConfig
from acestep.training.trainer import LoKRTrainer
device_attr = getattr(dit_handler, "device", "")
if hasattr(device_attr, "type"):
device_type = str(device_attr.type).lower()
else:
device_type = str(device_attr).split(":", 1)[0].lower()
if device_type == "cuda":
num_workers = 4
pin_memory = True
prefetch_factor = 2
persistent_workers = True
pin_memory_device = "cuda"
mixed_precision = "bf16"
elif device_type == "xpu":
num_workers = 4
pin_memory = True
prefetch_factor = 2
persistent_workers = True
pin_memory_device = ""
mixed_precision = "bf16"
elif device_type == "mps":
num_workers = 0
pin_memory = False
prefetch_factor = 2
persistent_workers = False
pin_memory_device = ""
mixed_precision = "fp16"
else:
num_workers = 0
pin_memory = False
prefetch_factor = 2
persistent_workers = False
pin_memory_device = ""
mixed_precision = "fp32"
lokr_config = LoKRConfigClass(
linear_dim=lokr_linear_dim,
linear_alpha=lokr_linear_alpha,
factor=lokr_factor,
decompose_both=lokr_decompose_both,
use_tucker=lokr_use_tucker,
use_scalar=lokr_use_scalar,
weight_decompose=lokr_weight_decompose,
)
training_config = TrainingConfig(
shift=training_shift,
learning_rate=learning_rate,
batch_size=train_batch_size,
gradient_accumulation_steps=gradient_accumulation,
max_epochs=train_epochs,
save_every_n_epochs=save_every_n_epochs,
seed=training_seed,
output_dir=lokr_output_dir,
num_workers=num_workers,
pin_memory=pin_memory,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
pin_memory_device=pin_memory_device,
mixed_precision=mixed_precision,
)
log_lines = []
step_list = []
loss_list = []
initial_plot = _training_loss_figure(training_state, step_list, loss_list)
start_time = time.time()
yield f"🚀 Starting LoKr training from {tensor_dir}...", "", initial_plot, training_state
trainer = LoKRTrainer(
dit_handler=dit_handler,
lokr_config=lokr_config,
training_config=training_config,
)
training_failed = False
failure_message = ""
for step, loss, status in trainer.train_from_preprocessed(tensor_dir, training_state):
status_text = str(status)
status_lower = status_text.lower()
if (
status_text.startswith("❌")
or "training failed" in status_lower
or "error:" in status_lower
or "module not found" in status_lower
):
training_failed = True
failure_message = status_text
elapsed_seconds = time.time() - start_time
time_info = f"⏱️ Elapsed: {_format_duration(elapsed_seconds)}"
match = re.search(r"Epoch\s+(\d+)/(\d+)", status_text)
if match:
current_ep = int(match.group(1))
total_ep = int(match.group(2))
if current_ep > 0:
eta_seconds = (elapsed_seconds / current_ep) * (total_ep - current_ep)
time_info += f" | ETA: ~{_format_duration(eta_seconds)}"
display_status = f"{status_text}\n{time_info}"
log_lines.append(status_text)
if len(log_lines) > 15:
log_lines = log_lines[-15:]
log_text = "\n".join(log_lines)
if step > 0 and loss is not None and loss == loss:
step_list.append(step)
loss_list.append(float(loss))
plot_figure = _training_loss_figure(training_state, step_list, loss_list)
yield display_status, log_text, plot_figure, training_state
if training_state.get("should_stop", False):
log_lines.append("⏹️ Training stopped by user")
yield f"⏹️ Stopped ({time_info})", "\n".join(log_lines[-15:]), plot_figure, training_state
break
total_time = time.time() - start_time
training_state["is_training"] = False
final_plot = _training_loss_figure(training_state, step_list, loss_list)
if training_failed:
final_msg = f"{failure_message}\nElapsed: {_format_duration(total_time)}"
log_lines.append(failure_message)
yield final_msg, "\n".join(log_lines[-15:]), final_plot, training_state
return
completion_msg = f"✅ LoKr training completed! Total time: {_format_duration(total_time)}"
log_lines.append(completion_msg)
yield completion_msg, "\n".join(log_lines[-15:]), final_plot, training_state
except Exception as e:
logger.exception("LoKr training error")
training_state["is_training"] = False
yield f"❌ Error: {str(e)}", str(e), _training_loss_figure({}, [], []), training_state
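The ETA shown during training is derived from the `Epoch N/M` fragment of each status string. The sketch below isolates that calculation so it can be checked on its own. `estimate_eta` and `EPOCH_PATTERN` are hypothetical names that do not appear in this commit, and the sketch assumes every epoch takes roughly the same time.

```python
import re
from typing import Optional

# Same pattern the training loop uses to find "Epoch <current>/<total>".
EPOCH_PATTERN = re.compile(r"Epoch\s+(\d+)/(\d+)")


def estimate_eta(status_text: str, elapsed_seconds: float) -> Optional[float]:
    """Return the estimated remaining seconds, or None if no estimate is possible.

    Hypothetical helper: mirrors the inline arithmetic of the training loop.
    """
    match = EPOCH_PATTERN.search(status_text)
    if not match:
        return None  # status line carries no epoch counter
    current_ep = int(match.group(1))
    total_ep = int(match.group(2))
    if current_ep <= 0:
        return None  # avoid dividing by zero before the first epoch finishes
    # Average time per finished epoch, times the epochs still to run.
    return (elapsed_seconds / current_ep) * (total_ep - current_ep)


if __name__ == "__main__":
    print(estimate_eta("Epoch 2/10 - loss 0.1", 60.0))
```

Because the estimate is a plain average over finished epochs, it will drift if early epochs include one-off costs such as compilation or cache warm-up.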
def list_lokr_export_epochs(lokr_output_dir: str) -> Tuple[Any, str]:
"""List available LoKr checkpoint epochs for export dropdown."""
default_choice = "Latest (auto)"
if not lokr_output_dir or not lokr_output_dir.strip():
return gr.update(choices=[default_choice], value=default_choice), "⚠️ Enter LoKr output directory first"
checkpoint_dir = os.path.join(lokr_output_dir.strip(), "checkpoints")
if not os.path.isdir(checkpoint_dir):
return gr.update(choices=[default_choice], value=default_choice), " No checkpoints found; export will use latest available weights"
checkpoints = []
for d in os.listdir(checkpoint_dir):
if not d.startswith("epoch_"):
continue
weight_file = os.path.join(checkpoint_dir, d, "lokr_weights.safetensors")
if not os.path.exists(weight_file):
continue
try:
epoch_num = int(d.split("_")[1])
except Exception:
continue
checkpoints.append((epoch_num, d))
if not checkpoints:
return gr.update(choices=[default_choice], value=default_choice), " No exportable epoch checkpoints found"
checkpoints.sort(key=lambda x: x[0], reverse=True)
choices = [default_choice] + [d for _, d in checkpoints]
return gr.update(choices=choices, value=default_choice), f"✅ Found {len(checkpoints)} LoKr checkpoints"
def export_lokr(
export_path: str,
lokr_output_dir: str,
selected_epoch: Optional[str] = None,
) -> str:
"""Export trained LoKr weights.
Returns:
Status message
"""
if not export_path or not export_path.strip():
return "❌ Please enter an export path"
final_dir = os.path.join(lokr_output_dir, "final")
checkpoint_dir = os.path.join(lokr_output_dir, "checkpoints")
default_epoch_choice = "Latest (auto)"
chosen_epoch = (selected_epoch or "").strip()
if not chosen_epoch:
chosen_epoch = default_epoch_choice
checkpoint_names: List[str] = []
if os.path.isdir(checkpoint_dir):
for d in os.listdir(checkpoint_dir):
if not d.startswith("epoch_"):
continue
try:
int(d.split("_")[1])
except Exception:
continue
checkpoint_names.append(d)
checkpoint_names.sort(key=lambda x: int(x.split("_")[1]))
# Determine source
explicit_epoch = chosen_epoch not in {default_epoch_choice, "latest", "Latest", "auto", "Auto"}
if explicit_epoch:
requested = chosen_epoch
if requested.isdigit():
requested = f"epoch_{requested}"
if requested not in checkpoint_names:
return (
f"❌ Selected epoch not found: {chosen_epoch}. "
f"Available: {', '.join(checkpoint_names) if checkpoint_names else '(none)'}"
)
source_file = os.path.join(checkpoint_dir, requested, "lokr_weights.safetensors")
if not os.path.exists(source_file):
return f"❌ No LoKr weights found for selected epoch: {requested}"
elif os.path.exists(os.path.join(final_dir, "lokr_weights.safetensors")):
source_file = os.path.join(final_dir, "lokr_weights.safetensors")
elif checkpoint_names:
latest_checkpoint = checkpoint_names[-1]
source_file = os.path.join(checkpoint_dir, latest_checkpoint, "lokr_weights.safetensors")
if not os.path.exists(source_file):
return f"❌ No LoKr weights found in latest checkpoint: {latest_checkpoint}"
else:
return f"❌ No trained LoKr weights found in {lokr_output_dir}"
try:
import shutil
export_path = export_path.strip()
if export_path.lower().endswith(".safetensors"):
os.makedirs(os.path.dirname(export_path) if os.path.dirname(export_path) else ".", exist_ok=True)
shutil.copy2(source_file, export_path)
return f"✅ LoKr exported to {export_path}"
os.makedirs(export_path, exist_ok=True)
dst_file = os.path.join(export_path, "lokr_weights.safetensors")
shutil.copy2(source_file, dst_file)
return f"✅ LoKr exported to {dst_file}"
except Exception as e:
logger.exception("LoKr export error")
return f"❌ Export failed: {str(e)}"
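The source-selection order in `export_lokr` (explicit epoch first, then `final/`, then the highest-numbered checkpoint) can be exercised without a trained model. The following is a minimal sketch of that order, not the code in this commit: `resolve_weights`, `list_epoch_dirs`, `WEIGHTS_NAME` and `AUTO_CHOICES` are hypothetical names, and the helper returns `None` where the original returns an error message.

```python
import os
import tempfile
from typing import List, Optional

WEIGHTS_NAME = "lokr_weights.safetensors"
AUTO_CHOICES = {"Latest (auto)", "latest", "Latest", "auto", "Auto"}


def _epoch_number(name: str) -> Optional[int]:
    """Parse the N out of 'epoch_N'; return None for anything else."""
    if not name.startswith("epoch_"):
        return None
    try:
        return int(name.split("_")[1])
    except (IndexError, ValueError):
        return None


def list_epoch_dirs(checkpoint_dir: str) -> List[str]:
    """Return epoch_<N> directory names sorted by ascending epoch number."""
    if not os.path.isdir(checkpoint_dir):
        return []
    names = [d for d in os.listdir(checkpoint_dir) if _epoch_number(d) is not None]
    # Numeric sort, so epoch_10 comes after epoch_2.
    return sorted(names, key=_epoch_number)


def resolve_weights(output_dir: str, selected: Optional[str] = None) -> Optional[str]:
    """Hypothetical helper: pick the weights file an export would copy."""
    checkpoint_dir = os.path.join(output_dir, "checkpoints")
    epochs = list_epoch_dirs(checkpoint_dir)
    choice = (selected or "").strip()
    if choice and choice not in AUTO_CHOICES:
        # Explicit request: accept "3" as shorthand for "epoch_3".
        name = f"epoch_{choice}" if choice.isdigit() else choice
        if name not in epochs:
            return None
        candidate = os.path.join(checkpoint_dir, name, WEIGHTS_NAME)
    elif os.path.exists(os.path.join(output_dir, "final", WEIGHTS_NAME)):
        # Automatic mode prefers the final weights when they exist.
        candidate = os.path.join(output_dir, "final", WEIGHTS_NAME)
    elif epochs:
        # Otherwise fall back to the highest-numbered epoch checkpoint.
        candidate = os.path.join(checkpoint_dir, epochs[-1], WEIGHTS_NAME)
    else:
        return None
    return candidate if os.path.exists(candidate) else None


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        for epoch in ("epoch_2", "epoch_10"):
            os.makedirs(os.path.join(root, "checkpoints", epoch))
            open(os.path.join(root, "checkpoints", epoch, WEIGHTS_NAME), "wb").close()
        print(resolve_weights(root))       # automatic selection
        print(resolve_weights(root, "2"))  # explicit epoch
```

Sorting by the parsed integer matters here: a plain string sort would rank `epoch_2` above `epoch_10` and export a stale checkpoint.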

View file

@ -53,12 +53,19 @@
"compile_model_info": "Use torch.compile to optimize model (required for quantization)",
"quantization_label": "INT8 Quantization",
"quantization_info": "Enable INT8 weight-only quantization to reduce VRAM usage (requires Compile Model)",
"mlx_dit_label": "MLX DiT (Apple Silicon)",
"mlx_dit_info_enabled": "Use native MLX for DiT diffusion on Apple Silicon (faster than MPS)",
"mlx_dit_info_disabled": "MLX not available (requires macOS + Apple Silicon + mlx package)",
"init_btn": "Initialize Service",
"status_label": "Status",
"language_label": "UI Language",
"language_info": "Select interface language"
"language_info": "Select interface language",
"gpu_auto_tier": "Auto-detected Tier",
"tier_label": "GPU Tier Override",
"tier_info": "Manually select GPU tier to adjust optimization defaults (offload, quantization, backend, etc.)"
},
"generation": {
"tab_title": "🎵 Generation",
"required_inputs": "📝 Required Inputs",
"task_type_label": "Task Type",
"task_type_info": "Select the task type for generation",
@ -71,8 +78,11 @@
"track_classes_info": "Select multiple track classes for complete task",
"audio_uploads": "🎵 Audio Uploads",
"reference_audio": "Reference Audio (optional)",
"source_audio": "Source Audio (optional)",
"source_audio": "Source Audio",
"convert_codes_btn": "Convert to Codes",
"analyze_btn": "🔍 Analyze",
"sample_btn": "🎲 Click Me",
"load_btn": "📂 Load",
"lm_codes_hints": "🎼 LM Codes Hints",
"lm_codes_label": "LM Codes Hints",
"lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
@ -84,13 +94,20 @@
"repainting_start": "Repainting Start",
"repainting_end": "Repainting End",
"mode_label": "Generation Mode",
"mode_info": "Simple: describe music in natural language. Custom: full control over caption and lyrics.",
"mode_info": "Select a generation mode to get started.",
"mode_info_simple": "Describe your music in natural language. AI will generate caption, lyrics and metadata for you.",
"mode_info_custom": "Full control over caption, lyrics and all parameters.",
"mode_info_remix": "Upload source audio to create a remix version with your caption and lyrics.",
"mode_info_repaint": "Upload source audio and repaint a specific time range.",
"mode_info_extract": "Extract a specific track (vocals, drums, etc.) from source audio.",
"mode_info_lego": "Reassemble tracks: replace a specific track in the source audio.",
"mode_info_complete": "Complete missing tracks in the source audio.",
"mode_simple": "Simple",
"mode_custom": "Custom",
"simple_query_label": "Song Description",
"simple_query_placeholder": "Describe the music you want to create, e.g., 'a soft Bengali love song for a quiet evening'. Leave empty for a random sample.",
"simple_query_info": "Enter a natural language description of the music you want to generate",
"simple_vocal_language_label": "Vocal Language (optional)",
"simple_vocal_language_label": "Vocal Language",
"simple_vocal_language_info": "Select preferred language(s) for lyrics. Use 'unknown' for any language.",
"create_sample_btn": "Create Sample",
"caption_title": "📝 Music Caption",
@ -103,12 +120,20 @@
"lyrics_info": "Song lyrics with structure",
"instrumental_label": "Instrumental",
"format_btn": "Format",
"format_caption_btn": "Enhance Caption",
"format_lyrics_btn": "Enhance Lyrics",
"optional_params": "⚙️ Optional Parameters",
"optional_music_props": "🎵 Music Properties",
"optional_gen_settings": "📐 Generation Settings",
"advanced_dit_section": "🎛️ DiT Diffusion",
"advanced_lm_section": "🤖 LM Generation",
"advanced_output_section": "🔊 Audio Output & Post-processing",
"advanced_automation_section": "⚡ Automation & Batch",
"vocal_language_label": "Vocal Language (optional)",
"vocal_language_info": "use `unknown` for inst",
"vocal_language_info": "'unknown' = instrumental / auto",
"bpm_label": "BPM (optional)",
"bpm_info": "leave empty for N/A",
"keyscale_label": "KeyScale (optional)",
"keyscale_label": "Key (optional)",
"keyscale_placeholder": "Leave empty for N/A",
"keyscale_info": "A-G, #/♭, major/minor",
"timesig_label": "Time Signature (optional)",
@ -117,7 +142,7 @@
"duration_info": "Use -1 for random",
"batch_size_label": "Batch Size",
"batch_size_info": "Number of audio to generate (max 8)",
"advanced_settings": "🔧 Advanced Settings",
"advanced_settings": "⚙️ Settings",
"inference_steps_label": "DiT Inference Steps",
"inference_steps_info": "Turbo: max 8, Base: max 200",
"guidance_scale_label": "DiT Guidance Scale (Only support for base model)",
@ -150,6 +175,7 @@
"lm_negative_prompt_label": "LM Negative Prompt",
"lm_negative_prompt_placeholder": "Enter negative prompt for CFG (default: NO USER INPUT)",
"lm_negative_prompt_info": "Negative prompt (use when LM CFG Scale > 1.0)",
"advanced_dit_params": "Advanced DiT Parameters",
"cot_metas_label": "CoT Metas",
"cot_metas_info": "Use LM to generate CoT metadata (uncheck to skip LM CoT generation)",
"cot_language_label": "CoT Language",
@ -164,25 +190,32 @@
"lm_batch_chunk_info": "Max items per LM batch chunk (default: 8, limited by GPU memory)",
"codes_strength_label": "LM Codes Strength",
"codes_strength_info": "Control how many denoising steps use LM-generated codes",
"similarity_denoise_label": "Similarity / Denoise",
"similarity_denoise_info": "Controls how closely the output follows the reference audio. Higher values preserve more structure.",
"cover_strength_label": "Audio Cover Strength",
"cover_strength_info": "Control how many denoising steps use cover mode",
"remix_strength_label": "Remix Strength",
"remix_strength_info": "Control how much the remix deviates from the source audio (lower = closer to original)",
"cover_noise_strength_label": "Cover Strength",
"cover_noise_strength_info": "Controls melody retention in Remix mode. Recommended: use the SFT model with a value of 0.1-0.25. A small increase restores the melody, but style transfer may require additional prompt tuning. (0 = pure noise/no cover, 1 = closest to original audio)",
"score_sensitivity_label": "Quality Score Sensitivity",
"score_sensitivity_info": "Lower = more sensitive (default: 1.0). Adjusts how PMI maps to [0,1]",
"think_label": "Think",
"parallel_thinking_label": "ParallelThinking",
"parallel_thinking_info": "Process batch samples in parallel for faster generation",
"generate_btn": "🎵 Generate Music",
"autogen_label": "AutoGen",
"caption_rewrite_label": "CaptionRewrite"
"caption_rewrite_label": "CaptionRewrite",
"caption_rewrite_info": "Use LM to rewrite caption before generation"
},
"results": {
"title": "🎵 Results",
"generated_music": "🎵 Generated Music (Sample {n})",
"send_to_src_btn": "🔗 Send To Src Audio",
"send_to_remix_btn": "🔗 Send To Remix",
"send_to_repaint_btn": "🔗 Send To Repaint",
"save_btn": "💾 Save",
"score_btn": "📊 Score",
"lrc_btn": "🎵 LRC",
"score_btn": "📊 Get Score",
"lrc_btn": "🎵 Get LRC",
"save_lrc_btn": "💾 Save LRC",
"convert_to_codes_btn": "🔄 Convert To Codes",
"quality_score_label": "Quality Score (Sample {n})",
"quality_score_placeholder": "Click 'Score' to calculate perplexity-based quality score",
"codes_label": "LM Codes (Sample {n})",
@ -196,7 +229,7 @@
"prev_btn": "◀ Previous",
"next_btn": "Next ▶",
"restore_params_btn": "↙️ Apply These Settings to UI (Restore Batch Parameters)",
"batch_results_title": "👇 Click here to view batch results & generation details",
"batch_results_title": "📁 Batch Results & Generation Details",
"all_files_label": "📁 All Generated Files (Download)",
"generation_details": "Generation Details"
},
@ -214,6 +247,7 @@
"lm_generated": "🤖 Generated example using LM",
"lm_fallback": "Failed to generate example using LM, falling back to examples directory",
"lm_not_initialized": "❌ 5Hz LM not initialized. Please initialize it first.",
"think_requires_lm": "⚠️ 'Think' requires 5Hz LM to be initialized. Think has been disabled — generation will proceed without LM thinking.",
"autogen_enabled": "🔄 AutoGen enabled - next batch will generate after this",
"batch_ready": "✅ Batch {n} ready! Click 'Next' to view.",
"batch_generating": "🔄 Starting background generation for Batch {n}...",
@ -263,57 +297,57 @@
"dataset_name": "Dataset Name",
"dataset_name_placeholder": "Enter dataset name",
"dataset_settings_header": "Dataset Settings",
"tag_prepend": "Prepend (tag, caption)",
"tag_append": "Append (caption, tag)",
"tag_replace": "Replace caption",
"step2_title": "Step 2: Auto-Label with AI",
"step2_instruction": "Click the button below to automatically generate metadata for all audio files using AI:\n- **Caption**: Music style, genre, mood description\n- **BPM**: Beats per minute\n- **Key**: Musical key (e.g., C Major, Am)\n- **Time Signature**: 4/4, 3/4, etc.",
"tag_prepend": "Prepend (Tag, Caption)",
"tag_append": "Append (Caption, Tag)",
"tag_replace": "Replace Caption",
"step2_title": "Step 2: AI Auto-Labeling",
"step2_instruction": "Click the button below to use AI to automatically generate metadata for all audio files:\n- **Caption**: Music style, genre, mood description\n- **BPM**: Beats per minute\n- **Key**: Music key (e.g. C Major, Am)\n- **Time Signature**: 4/4, 3/4 etc.",
"step3_title": "Step 3: Preview & Edit",
"step4_title": "Step 4: Save Dataset",
"step5_title": "Step 5: Preprocess to Tensors",
"step5_intro": "**Preprocessing converts your dataset to pre-computed tensors for fast training.**\n\nYou can either:\n- Use the dataset from Steps 1-4 above, **OR**\n- Load an existing dataset JSON file (if you've already saved one)",
"step5_details": "This step:\n- Encodes audio to VAE latents\n- Encodes captions and lyrics to text embeddings\n- Runs the condition encoder\n- Saves all tensors to `.pt` files\n\n⚠ **This requires the model to be loaded and may take a few minutes.**",
"train_tensor_selection_desc": "Select the directory containing preprocessed tensor files (`.pt` files).\nThese are created in the \"Dataset Builder\" tab using the \"Preprocess\" button.",
"step5_intro": "**Preprocessing converts your dataset into pre-computed tensors for fast training.**\n\nYou can:\n- Use the dataset from steps 1-4 above, **OR**\n- Load an existing dataset JSON file (if you have one saved)",
"step5_details": "This step will:\n- Encode audio to VAE latents\n- Encode captions and lyrics to text embeddings\n- Run condition encoders\n- Save all tensors to `.pt` files\n\n⚠ **This requires loading models and may take a few minutes.**",
"train_tensor_selection_desc": "Select the directory containing preprocessed tensor files (`.pt` files).\nThese are created using the 'Preprocess' button in the 'Dataset Builder' tab.",
"all_instrumental": "All Instrumental",
"all_instrumental_info": "Check if all tracks are instrumental (no vocals)",
"custom_tag": "Custom Activation Tag",
"custom_tag_info": "Unique tag to activate this LoRA's style",
"custom_tag": "Custom Trigger Tag",
"custom_tag_info": "Unique tag to activate this LoRA style",
"tag_position": "Tag Position",
"tag_position_info": "Where to place the custom tag in the caption",
"genre_ratio": "Genre Ratio (%)",
"genre_ratio_info": "0%=all Caption, 100%=all Genre. Per-sample override takes priority.",
"skip_metas": "Skip BPM/Key/Time Signature",
"skip_metas_info": "Skip BPM/Key/Time Signature generation. Caption and Genre are still generated by LLM.",
"genre_ratio_info": "0%=All Caption, 100%=All Genre. Single sample override takes precedence.",
"skip_metas": "Skip BPM/Key/TimeSig",
"skip_metas_info": "Skip BPM/Key/TimeSig generation. Captions and genres are still generated by LM.",
"only_unlabeled": "Only Unlabeled",
"only_unlabeled_info": "Only label samples without caption (useful for resuming failed labeling)",
"only_unlabeled_info": "Only label samples with no caption (for continuing failed labeling)",
"auto_label_btn": "🏷️ Auto-Label All",
"label_progress": "Labeling Progress",
"select_sample": "Select Sample #",
"select_sample_info": "Choose a sample to preview and edit",
"select_sample_info": "Select sample to preview and edit",
"audio_preview": "Audio Preview",
"filename": "Filename",
"caption": "Caption",
"genre": "Genre",
"prompt_override_label": "Prompt Override (this sample)",
"prompt_override_label": "Prompt Override (This Sample)",
"prompt_override_info": "Override global ratio for this sample",
"lyrics_editable_label": "Lyrics (editable, used for training)",
"lyrics_editable_label": "Lyrics (Editable for Training)",
"raw_lyrics_label": "Raw Lyrics (from .txt file)",
"no_lyrics_placeholder": "(no .txt lyrics file)",
"no_lyrics_placeholder": "(No .txt lyrics file)",
"bpm": "BPM",
"key_label": "Key",
"key_placeholder": "C Major",
"time_sig": "Time Signature",
"time_sig": "Time Sig",
"duration_s": "Duration (s)",
"language": "Language",
"instrumental": "Instrumental",
"save_changes_btn": "💾 Save Changes",
"edit_status": "Edit Status",
"save_path": "Save Path",
"save_path_info": "Path where the dataset JSON will be saved",
"save_path_info": "Path to save dataset JSON",
"save_dataset_btn": "💾 Save Dataset",
"save_status": "Save Status",
"load_existing_label": "Load Existing Dataset (Optional)",
"load_existing_info": "Path to a previously saved dataset JSON file",
"load_existing_info": "Path to previously saved dataset JSON file",
"load_dataset_btn": "📂 Load Dataset",
"tensor_output_dir": "Tensor Output Directory",
"tensor_output_info": "Directory to save preprocessed tensor files",
@ -321,25 +355,22 @@
"preprocess_progress": "Preprocessing Progress",
"preprocessed_tensors_dir": "Preprocessed Tensors Directory",
"preprocessed_tensors_info": "Directory containing preprocessed .pt tensor files",
"train_section_tensors": "Preprocessed Dataset Selection",
"train_section_lora": "LoRA Settings",
"train_section_params": "Training Parameters",
"dataset_info": "Dataset Info",
"lora_rank": "LoRA Rank (r)",
"lora_rank_info": "Higher = more capacity, more memory",
"lora_rank_info": "Higher capacity but more VRAM",
"lora_alpha": "LoRA Alpha",
"lora_alpha_info": "Scaling factor (typically 2x rank)",
"lora_alpha_info": "Scaling factor (usually 2x Rank)",
"lora_dropout": "LoRA Dropout",
"learning_rate": "Learning Rate",
"learning_rate_info": "Start with 3e-4, adjust if needed",
"learning_rate_info": "Start with 3e-4, adjust as needed",
"max_epochs": "Max Epochs",
"batch_size": "Batch Size",
"batch_size_info": "Increase if you have enough VRAM",
"batch_size_info": "Increase if VRAM allows",
"gradient_accumulation": "Gradient Accumulation",
"gradient_accumulation_info": "Effective batch = batch_size × accumulation",
"gradient_accumulation_info": "Effective Batch Size = batch_size * accum_steps",
"save_every_n_epochs": "Save Every N Epochs",
"shift": "Shift",
"shift_info": "Timestep shift for turbo model",
"shift_info": "Timestep shift for Turbo models",
"seed": "Seed",
"output_dir": "Output Directory",
"output_dir_info": "Directory to save trained LoRA weights",
@ -354,5 +385,15 @@
"export_path": "Export Path",
"export_lora_btn": "📦 Export LoRA",
"export_status": "Export Status"
},
"gen": {
"enable_normalization": "Enable Normalization",
"enable_normalization_info": "Normalize audio volume to target peak level to prevent clipping or ensure consistent loudness.",
"normalization_db": "Target Peak (dB)",
"normalization_db_info": "Target peak level in decibels. -1.0 dB is standard safe peak. -0.1 dB is max.",
"latent_shift": "Latent Shift",
"latent_shift_info": "Shift applied to DiT latents before VAE decode. Default 0 (no shift). Negative values (e.g. -0.04) can reduce clipping.",
"latent_rescale": "Latent Rescale",
"latent_rescale_info": "Rescale factor for DiT latents before VAE decode. Default 1.0 (no rescale). Values < 1.0 (e.g. 0.91) can reduce clipping."
}
}
}

View file

@ -1,356 +1,391 @@
{
"app": {
"title": "🎛️ סביבת העבודה ACE-Step V1.5 Playground💡",
"subtitle": "פורצים את גבולות יצירת המוזיקה בקוד פתוח"
},
"dataset": {
"title": "📊 סייר מערכי נתונים (Dataset Explorer)",
"dataset_label": "מערך נתונים",
"dataset_info": "בחר מערך נתונים לחקירה",
"import_btn": "📥 ייבוא מערך נתונים",
"search_type_label": "סוג חיפוש",
"search_type_info": "כיצד למצוא פריטים",
"search_value_label": "ערך חיפוש",
"search_value_placeholder": "הזן מפתחות או אינדקס (השאר ריק לבחירה אקראית)",
"search_value_info": "מפתחות: התאמה מדויקת, אינדקס: 0 עד גודל המערך פחות 1",
"instruction_label": "📝 הנחיה (Instruction)",
"instruction_placeholder": "אין הנחיה זמינה",
"metadata_title": "📋 מטא-דאטה של הפריט (JSON)",
"metadata_label": "מידע מלא על הפריט",
"source_audio": "אודיו מקור",
"target_audio": "אודיו יעד",
"reference_audio": "אודיו ייחוס",
"get_item_btn": "🔍 קבל פריט",
"use_src_checkbox": "השתמש באודיו מקור ממערך הנתונים",
"use_src_info": "סמן כדי להשתמש באודיו המקור מתוך מערך הנתונים",
"data_status_label": "📊 מצב נתונים",
"data_status_default": "❌ לא יובא מערך נתונים",
"autofill_btn": "📋 מילוי אוטומטי של טופס היצירה"
},
"service": {
"title": "🔧 הגדרות שירות",
"checkpoint_label": "קובץ נקודת ביקורת (Checkpoint)",
"checkpoint_info": "בחר קובץ נקודת ביקורת של מודל מאומן (נתיב מלא או שם קובץ)",
"refresh_btn": "🔄 רענון",
"model_path_label": "נתיב מודל ראשי",
"model_path_info": "בחר את ספריית הגדרות המודל (נסרק אוטומטית מנקודות הביקורת)",
"device_label": "מכשיר (Device)",
"device_info": "מכשיר עיבוד (מומלץ זיהוי אוטומטי)",
"lm_model_path_label": "נתיב מודל 5Hz LM",
"lm_model_path_info": "בחר את קובץ נקודת הביקורת של מודל ה-5Hz LM",
"backend_label": "מנוע (Backend) 5Hz LM",
"backend_info": "בחר מנוע עבור 5Hz LM: vllm (מהיר יותר) או pt (PyTorch, תואם יותר)",
"init_llm_label": "אתחול 5Hz LM",
"init_llm_info": "סמן כדי לאתחל את ה-5Hz LM במהלך אתחול השירות",
"flash_attention_label": "השתמש ב-Flash Attention",
"flash_attention_info_enabled": "הפעל Flash Attention להסקה מהירה יותר (דורש חבילת flash_attn)",
"flash_attention_info_disabled": "Flash Attention אינו זמין (חבילת flash_attn לא מותקנת)",
"offload_cpu_label": "העברה ל-CPU (Offload)",
"offload_cpu_info": "העבר מודלים ל-CPU כשאינם בשימוש כדי לחסוך בזיכרון גרפי (VRAM)",
"offload_dit_cpu_label": "העברת DiT ל-CPU",
"offload_dit_cpu_info": "העבר DiT ל-CPU (דורש 'העברה ל-CPU')",
"compile_model_label": "הידור מודל (Compile)",
"compile_model_info": "השתמש ב-torch.compile לאופטימיזציה של המודל (נדרש עבור קוונטיזציה)",
"quantization_label": "קוונטיזציה INT8",
"quantization_info": "הפעל קוונטיזציה של משקולות בלבד (INT8) להפחתת שימוש ב-VRAM (דורש הידור מודל)",
"init_btn": "אתחול שירות",
"status_label": "מצב",
"language_label": "שפת ממשק",
"language_info": "בחר את שפת הממשק"
},
"generation": {
"required_inputs": "📝 קלטים נדרשים",
"task_type_label": "סוג משימה",
"task_type_info": "בחר את סוג המשימה ליצירה",
"instruction_label": "הנחיה",
"instruction_info": "ההנחיה נוצרת אוטומטית בהתאם לסוג המשימה",
"load_btn": "טעינה",
"track_name_label": "שם רצועה",
"track_name_info": "בחר שם רצועה עבור משימות lego/extract",
"track_classes_label": "שמות רצועות",
"track_classes_info": "בחר מספר מחלקות רצועה עבור משימה מלאה",
"audio_uploads": "🎵 העלאות אודיו",
"reference_audio": "אודיו ייחוס (אופציונלי)",
"source_audio": "אודיו מקור (אופציונלי)",
"convert_codes_btn": "המר לקודים",
"lm_codes_hints": "🎼 רמזי קודי LM",
"lm_codes_label": "רמזי קודי LM",
"lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
"lm_codes_info": "הדבק רמזי קודי LM עבור יצירת טקסט למוזיקה (text2music)",
"lm_codes_sample": "רמזי קודי LM (דגימה {n})",
"lm_codes_sample_info": "קודים עבור דגימה {n}",
"transcribe_btn": "תמלול",
"repainting_controls": "🎨 בקרת צביעה מחדש (בשניות)",
"repainting_start": "תחילת צביעה מחדש",
"repainting_end": "סיום צביעה מחדש",
"mode_label": "מצב יצירה",
"mode_info": "פשוט: תאר מוזיקה בשפה טבעית. מותאם אישית: שליטה מלאה בתיאור ומילים.",
"mode_simple": "פשוט",
"mode_custom": "מותאם אישית",
"simple_query_label": "תיאור השיר",
"simple_query_placeholder": "תאר את המוזיקה שברצונך ליצור, למשל: 'שיר אהבה אקוסטי שקט לערב רגוע'. השאר ריק לדגימה אקראית.",
"simple_query_info": "הזן תיאור בשפה טבעית של המוזיקה שברצונך ליצור",
"simple_vocal_language_label": "שפת שירה (אופציונלי)",
"simple_vocal_language_info": "בחר שפות מועדפות למילים. השתמש ב-'unknown' לכל שפה.",
"create_sample_btn": "צור דגימה",
"caption_title": "📝 תיאור מוזיקלי (Caption)",
"caption_label": "תיאור מוזיקלי (אופציונלי)",
"caption_placeholder": "מנגינת גיטרה אקוסטית שלווה עם שירה רכה...",
"caption_info": "תאר את הסגנון, הז'אנר, הכלים והאווירה",
"lyrics_title": "📝 מילים",
"lyrics_label": "מילים (אופציונלי)",
"lyrics_placeholder": "[בית 1]\\nתחת שמי הלילה...\\nאני מרגיש חי...",
"lyrics_info": "מילות השיר עם מבנה",
"instrumental_label": "אינסטרומנטלי (ללא שירה)",
"format_btn": "פרמוט",
"optional_params": "⚙️ פרמטרים אופציונליים",
"vocal_language_label": "שפת שירה (אופציונלי)",
"vocal_language_info": "השתמש ב-`unknown` לקטעים כליים",
"bpm_label": "קצב (BPM) (אופציונלי)",
"bpm_info": "השאר ריק אם לא ידוע",
"keyscale_label": "סולם (KeyScale) (אופציונלי)",
"keyscale_placeholder": "השאר ריק אם לא ידוע",
"keyscale_info": "A-G, #/♭, מז'ור/מינור",
"timesig_label": "משקל מוזיקלי (אופציונלי)",
"timesig_info": "2/4, 3/4, 4/4...",
"duration_label": "אורך אודיו (שניות)",
"duration_info": "השתמש ב-1- לאקראי",
"batch_size_label": "גודל מנה (Batch Size)",
"batch_size_info": "מספר קטעי אודיו ליצירה (מקסימום 8)",
"advanced_settings": "🔧 הגדרות מתקדמות",
"inference_steps_label": "צעדי הסקה של DiT",
"inference_steps_info": "Turbo: מקסימום 8, Base: מקסימום 200",
"guidance_scale_label": "קנה מידה להנחיה (רק למודל base)",
"guidance_scale_info": "ערכים גבוהים יותר נצמדים יותר לטקסט",
"seed_label": "גרעין (Seed)",
"seed_info": "השתמש בערכים מופרדים בפסיקים עבור מנות",
"random_seed_label": "גרעין אקראי",
"random_seed_info": "אפשר ליצירה אוטומטית של גרעינים",
"audio_format_label": "פורמט אודיו",
"audio_format_info": "פורמט האודיו עבור הקבצים שיישמרו",
"use_adg_label": "השתמש ב-ADG",
"use_adg_info": "הפעל Angle Domain Guidance",
"shift_label": "Shift",
"shift_info": "פקטור הסטת צעדי זמן למודלי base (טווח 1.0~5.0, ברירת מחדל 3.0). לא משפיע על מודלי turbo.",
"infer_method_label": "שיטת הסקה",
"infer_method_info": "שיטת הסקת הדיפוזיה. ODE (Euler) מהירה יותר, SDE (stochastic) עשויה להפיק תוצאות שונות.",
"custom_timesteps_label": "צעדי זמן מותאמים אישית",
"custom_timesteps_info": "אופציונלי: ערכים מופרדים בפסיקים מ-1.0 עד 0.0. דורס את צעדי ההסקה וה-shift.",
"cfg_interval_start": "תחילת מרווח CFG",
"cfg_interval_end": "סיום מרווח CFG",
"lm_params_title": "🤖 פרמטרי יצירת LM",
"lm_temperature_label": "טמפרטורת LM",
"lm_temperature_info": "טמפרטורת 5Hz LM (גבוה יותר = אקראי יותר)",
"lm_cfg_scale_label": "קנה מידה LM CFG",
"lm_cfg_scale_info": "5Hz LM CFG (1.0 = ללא CFG)",
"lm_top_k_label": "LM Top-K",
"lm_top_k_info": "Top-K (0 = מושבת)",
"lm_top_p_label": "LM Top-P",
"lm_top_p_info": "Top-P (1.0 = מושבת)",
"lm_negative_prompt_label": "הנחיה שלילית ל-LM",
"lm_negative_prompt_placeholder": "הזן הנחיה שלילית עבור CFG",
"lm_negative_prompt_info": "הנחיה שלילית (בשימוש כאשר LM CFG Scale > 1.0)",
"cot_metas_label": "CoT Metas",
"cot_metas_info": "השתמש ב-LM ליצירת מטא-דאטה CoT (בטל סימון כדי לדלג)",
"cot_language_label": "שפת CoT",
"cot_language_info": "יצירת שפה ב-CoT (שרשרת מחשבה)",
"constrained_debug_label": "ניקוי באגים של פענוח מוגבל",
"constrained_debug_info": "הפעל לוגים של ניקוי באגים עבור פענוח מוגבל",
"auto_score_label": "דירוג אוטומטי",
"auto_score_info": "חשב אוטומטית ציוני איכות לכל קטעי האודיו שנוצרו",
"auto_lrc_label": "LRC אוטומטי",
"auto_lrc_info": "צור אוטומטית חותמות זמן למילים (LRC) לכל קטעי האודיו",
"lm_batch_chunk_label": "גודל מקטע מנת LM",
"lm_batch_chunk_info": "מקסימום פריטים למקטע מנת LM (ברירת מחדל: 8, מוגבל ע\"י זיכרון ה-GPU)",
"codes_strength_label": "חוזק קודי LM",
"codes_strength_info": "שליטה בכמות צעדי הניקוי מרעש המשתמשים בקודים שנוצרו ע\"י ה-LM",
"cover_strength_label": "חוזק כיסוי אודיו (Audio Cover)",
"cover_strength_info": "שליטה בכמות צעדי הניקוי מרעש המשתמשים במצב כיסוי",
"score_sensitivity_label": "רגישות ציון איכות",
"score_sensitivity_info": "נמוך יותר = רגיש יותר (ברירת מחדל: 1.0)",
"think_label": "חשיבה (Think)",
"parallel_thinking_label": "חשיבה מקבילית",
"generate_btn": "🎵 צור מוזיקה",
"autogen_label": "יצירה אוטומטית",
"caption_rewrite_label": "שכתוב תיאור"
},
"results": {
"title": "🎵 תוצאות",
"generated_music": "🎵 מוזיקה שנוצרה (דגימה {n})",
"send_to_src_btn": "🔗 שלח לאודיו מקור",
"save_btn": "💾 שמירה",
"score_btn": "📊 דירוג",
"lrc_btn": "🎵 LRC",
"quality_score_label": "ציון איכות (דגימה {n})",
"quality_score_placeholder": "לחץ על 'דירוג' לחישוב ציון איכות מבוסס מורכבות (Perplexity)",
"codes_label": "קודי LM (דגימה {n})",
"lrc_label": "חותמות זמן למילים (דגימה {n})",
"lrc_placeholder": "לחץ על 'LRC' ליצירת חותמות זמן",
"details_accordion": "📊 דירוג, LRC וקודי LM",
"generation_status": "מצב יצירה",
"current_batch": "מנה נוכחית",
"batch_indicator": "מנה {current} / {total}",
"next_batch_status": "מצב המנה הבאה",
"prev_btn": "◀ הקודם",
"next_btn": "הבא ▶",
"restore_params_btn": "↙️ החל הגדרות אלו על הממשק (שחזור פרמטרי מנה)",
"batch_results_title": "📁 תוצאות המנה ופרטי יצירה",
"all_files_label": "📁 כל הקבצים שנוצרו (הורדה)",
"generation_details": "פרטי יצירה"
},
"messages": {
"no_audio_to_save": "❌ אין אודיו לשמירה",
"save_success": "✅ האודיו והמטא-דאטה נשמרו ב-{filename}",
"save_failed": "❌ השמירה נכשלה: {error}",
"no_file_selected": "⚠️ לא נבחר קובץ",
"params_loaded": "✅ הפרמטרים נטענו מ-{filename}",
"invalid_json": "❌ קובץ JSON לא תקין: {error}",
"load_error": "❌ שגיאה בטעינת הקובץ: {error}",
"example_loaded": "📁 נטען דגם מ-{filename}",
"example_failed": "נכשל ניתוח קובץ ה-JSON ב-{filename}: {error}",
"example_error": "שגיאה בטעינת הדגם: {error}",
"lm_generated": "🤖 נוצר דגם באמצעות ה-LM",
"lm_fallback": "יצירת דגם באמצעות ה-LM נכשלה, חוזר לשימוש בספריית הדגמים",
"lm_not_initialized": "❌ 5Hz LM לא מאותחל. נא לאתחל אותו תחילה.",
"autogen_enabled": "🔄 יצירה אוטומטית הופעלה - המנה הבאה תיווצר לאחר מכן",
"batch_ready": "✅ מנה {n} מוכנה! לחץ על 'הבא' לצפייה.",
"batch_generating": "🔄 מתחיל יצירת רקע עבור מנה {n}...",
"batch_failed": "❌ יצירת הרקע נכשלה: {error}",
"viewing_batch": "✅ צופה במנה {n}",
"at_first_batch": "נמצא כבר במנה הראשונה",
"at_last_batch": "אין מנה באה זמינה",
"batch_not_found": "מנה {n} לא נמצאה בתור",
"no_batch_data": "לא נמצאו נתוני מנה לשחזור.",
"params_restored": "✅ פרמטרי הממשק שוחזרו ממנה {n}",
"scoring_failed": "❌ שגיאה: נתוני המנה לא נמצאו",
"no_codes": "❌ אין קודי אודיו זמינים. נא ליצור מוזיקה תחילה.",
"score_failed": "❌ הדירוג נכשל: {error}",
"score_error": "❌ שגיאה בחישוב הציון: {error}",
"lrc_no_batch_data": "❌ לא נמצאו נתוני מנה. נא ליצור מוזיקה תחילה.",
"lrc_no_extra_outputs": "❌ לא נמצאו פלטים נוספים. טנזורי התניה אינם זמינים.",
"lrc_missing_tensors": "❌ חסרים טנזורים נדרשים ליצירת LRC.",
"lrc_sample_not_exist": "❌ הדגימה אינה קיימת במנה הנוכחית.",
"lrc_empty_result": "⚠️ יצירת ה-LRC הפיקה תוצאה ריקה.",
"empty_query": "⚠️ נא להזין תיאור מוזיקלי.",
"sample_creation_failed": "❌ יצירת הדגימה נכשלה. נא לנסות שוב.",
"sample_created": "✅ הדגימה נוצרה! בדוק את התיאור והמילים, ולאחר מכן לחץ על 'צור מוזיקה'.",
"simple_examples_not_found": "⚠️ ספריית הדגמים של המצב הפשוט לא נמצאה.",
"simple_examples_empty": "⚠️ לא נמצאו קבצי דוגמה במצב פשוט.",
"simple_example_loaded": "🎲 נטענה דוגמה אקראית מ-{filename}",
"format_success": "✅ התיאור והמילים פורמטו בהצלחה",
"format_failed": "❌ הפירמוט נכשל: {error}",
"skipping_metas_cot": "⚡ מדלג על שלב 1 של מטא-דאטה COT (הדגימה כבר מפורמטת)",
"invalid_timesteps_format": "⚠️ פורמט צעדי זמן לא תקין. משתמש בלוח זמנים כברירת מחדל.",
"timesteps_out_of_range": "⚠️ צעדי הזמן חייבים להיות בטווח [0, 1]. משתמש בלוח זמנים כברירת מחדל.",
"timesteps_count_mismatch": "⚠️ מספר צעדי הזמן ({actual}) שונה מצעדי ההסקה ({expected}). משתמש במספר צעדי הזמן."
},
"training": {
"tab_title": "🎓 אימון LoRA",
"tab_dataset_builder": "📁 בונה מערך נתונים",
"tab_train_lora": "🚀 אימון LoRA",
"quick_start_title": "🚀 התחלה מהירה",
"load_dataset_label": "נתיב קובץ JSON של מערך הנתונים",
"load_dataset_info": "טעינת מערך נתונים שנשמר בעבר",
"load_btn": "📂 טעינה",
"load_status": "מצב טעינה",
"scan_label": "נתיב ספריית אודיו",
"scan_info": "סריקה אחר קבצי אודיו (wav, mp3, flac, ogg, opus)",
"scan_btn": "🔍 סריקה",
"scan_status": "מצב סריקה",
"found_audio_files": "קבצי אודיו שנמצאו",
"dataset_name": "שם מערך הנתונים",
"dataset_name_placeholder": "הזן שם למערך הנתונים",
"dataset_settings_header": "הגדרות מערך נתונים",
"tag_prepend": "הוספה בהתחלה (תגית, תיאור)",
"tag_append": "הוספה בסוף (תיאור, תגית)",
"tag_replace": "החלפת התיאור",
"step2_title": "שלב 2: תיוג אוטומטי באמצעות AI",
"step2_instruction": "לחץ על הכפתור למטה כדי ליצור אוטומטית מטא-נתונים עבור כל קבצי האודיו באמצעות AI:\n• **תיאור**: סגנון מוזיקה, ז'אנר, תיאור מצב רוח\n• **BPM**: פעימות לדקה\n• **מפתח**: מפתח מוזיקלי (לדוגמה, C Major, Am)\n• **חתימת זמן**: 4/4, 3/4, וכו'",
"step3_title": "שלב 3: תצוגה מקדימה ועריכה",
"step4_title": "שלב 4: שמירת מערך הנתונים",
"step5_title": "שלב 5: עיבוד מקדים לטנזורים (Tensors)",
"step5_intro": "**העיבוד המקדים ממיר את מערך הנתונים שלך לטנזורים מחושבים מראש לאימון מהיר.**\n\nאתה יכול:\n• להשתמש במערך הנתונים משלבים 1-4 למעלה, **או**\n• לטעון קובץ JSON של מערך נתונים קיים (אם כבר שמרת אחד)",
"step5_details": "שלב זה:\n• מקודד אודיו ל-latents של VAE\n• מקודד תיאורים ומילים ל-embeddings טקסט\n• מפעיל את מקודד התנאי\n• שומר את כל הטנזורים לקבצי `.pt`\n\n⚠ **זה דורש טעינת המודל ועשוי לקחת מספר דקות.**",
"train_tensor_selection_desc": "בחר את הספרייה המכילה קבצי טנזור מעובדים מראש (קבצי `.pt`).\nאלה נוצרים בכרטיסייה \"בונה מערך נתונים\" באמצעות כפתור \"עיבוד מקדים\".",
"all_instrumental": "הכל אינסטרומנטלי",
"all_instrumental_info": "סמן אם כל הרצועות הן כליות (ללא שירה)",
"custom_tag": "תגית הפעלה מותאמת אישית",
"custom_tag_info": "תגית ייחודית להפעלת הסגנון של LoRA זו",
"tag_position": "מיקום התגית",
"tag_position_info": "היכן למקם את התגית המותאמת אישית בתוך התיאור",
"genre_ratio": "יחס ז'אנר (%)",
"genre_ratio_info": "0% = הכל תיאור, 100% = הכל ז'אנר. הגדרה פר-דגימה קודמת להגדרת הכלל.",
"skip_metas": "דלג על BPM/סולם/משקל",
"skip_metas_info": "דלג על יצירת BPM/סולם/משקל. התיאור והז'אנר עדיין ייווצרו על ידי ה-LLM.",
"only_unlabeled": "רק כאלו ללא תיוג",
"only_unlabeled_info": "תייג רק דגימות ללא תיאור (שימושי להמשך תיוג שנכשל)",
"auto_label_btn": "🏷️ תיוג אוטומטי של הכל",
"label_progress": "התקדמות התיוג",
"select_sample": "בחר דגימה #",
"select_sample_info": "בחר דגימה לצפייה ועריכה",
"audio_preview": "תצוגה מקדימה של אודיו",
"filename": "שם קובץ",
"caption": "תיאור",
"genre": "ז'אנר",
"prompt_override_label": "דריסת פרומפט (לדגימה זו)",
"prompt_override_info": "דריסת היחס הכללי עבור דגימה זו",
"lyrics_editable_label": "מילים (ניתן לעריכה, משמש לאימון)",
"raw_lyrics_label": "מילים גולמיות (מתוך קובץ .txt)",
"no_lyrics_placeholder": "(אין קובץ מילים .txt)",
"bpm": "BPM",
"key_label": "סולם (Key)",
"key_placeholder": "C Major",
"time_sig": "משקל מוזיקלי",
"duration_s": "משך (שניות)",
"language": "שפה",
"instrumental": "אינסטרומנטלי",
"save_changes_btn": "💾 שמירת שינויים",
"edit_status": "מצב עריכה",
"save_path": "נתיב שמירה",
"save_path_info": "הנתיב שבו יישמר קובץ ה-JSON של מערך הנתונים",
"save_dataset_btn": "💾 שמירת מערך נתונים",
"save_status": "מצב שמירה",
"load_existing_label": "טעינת מערך נתונים קיים (אופציונלי)",
"load_existing_info": "נתיב לקובץ JSON של מערך נתונים שנשמר בעבר",
"load_dataset_btn": "📂 טעינת מערך נתונים",
"tensor_output_dir": "ספריית פלט של טנזורים",
"tensor_output_info": "הספרייה לשמירת קבצי טנזור שעברו עיבוד מקדים",
"preprocess_btn": "⚡ עיבוד מקדים",
"preprocess_progress": "התקדמות עיבוד מקדים",
"preprocessed_tensors_dir": "ספריית טנזורים מעובדים",
"preprocessed_tensors_info": "ספרייה המכילה קבצי .pt של טנזורים מעובדים",
"train_section_tensors": "בחירת מערך נתונים מעובד",
"train_section_lora": "הגדרות LoRA",
"train_section_params": "פרמטרי אימון",
"dataset_info": "מידע על מערך הנתונים",
"lora_rank": "דרגת LoRA (Rank)",
"lora_rank_info": "גבוה יותר = יותר קיבולת, יותר זיכרון",
"lora_alpha": "LoRA Alpha",
"lora_alpha_info": "פקטור קנה מידה (בדרך כלל פי 2 מה-Rank)",
"lora_dropout": "LoRA Dropout",
"learning_rate": "קצב למידה (Learning Rate)",
"learning_rate_info": "התחל עם 3e-4, שנה במידת הצורך",
"max_epochs": "מקסימום תקופות (Epochs)",
"batch_size": "גודל מנה (Batch Size)",
"batch_size_info": "הגדל אם יש לך מספיק זיכרון גרפי (VRAM)",
"gradient_accumulation": "צבירת גרדיאנטים (Accumulation)",
"gradient_accumulation_info": "גודל מנה אפקטיבי = גודל מנה × צבירה",
"save_every_n_epochs": "שמור כל N תקופות (Epochs)",
"shift": "Shift (הסטה)",
"shift_info": "הסטת צעדי זמן עבור מודל turbo",
"seed": "גרעין (Seed)",
"output_dir": "ספריית פלט",
"output_dir_info": "ספרייה לשמירת משקולות ה-LoRA המאומנות",
"start_training_btn": "🚀 התחלת אימון",
"stop_training_btn": "⏹️ עצירת אימון",
"training_progress": "התקדמות האימון",
"training_log": "יומן אימון",
"training_loss_title": "הפסד אימון (Training Loss)",
"step": "צעד",
"loss": "הפסד (Loss)",
"export_header": "ייצוא LoRA",
"export_path": "נתיב ייצוא",
"export_lora_btn": "📦 ייצוא LoRA",
"export_status": "מצב ייצוא"
}
}
{
"app": {
"title": "🎛️ סביבת העבודה ACE-Step V1.5 Playground💡",
"subtitle": "פורצים את גבולות יצירת המוזיקה בקוד פתוח"
},
"dataset": {
"title": "📊 סייר מערכי נתונים (Dataset Explorer)",
"dataset_label": "מערך נתונים",
"dataset_info": "בחר מערך נתונים לחקירה",
"import_btn": "📥 ייבוא מערך נתונים",
"search_type_label": "סוג חיפוש",
"search_type_info": "כיצד למצוא פריטים",
"search_value_label": "ערך חיפוש",
"search_value_placeholder": "הזן מפתחות או אינדקס (השאר ריק לבחירה אקראית)",
"search_value_info": "מפתחות: התאמה מדויקת, אינדקס: 0 עד גודל המערך פחות 1",
"instruction_label": "📝 הנחיה (Instruction)",
"instruction_placeholder": "אין הנחיה זמינה",
"metadata_title": "📋 מטא-דאטה של הפריט (JSON)",
"metadata_label": "מידע מלא על הפריט",
"source_audio": "אודיו מקור",
"target_audio": "אודיו יעד",
"reference_audio": "אודיו ייחוס",
"get_item_btn": "🔍 קבל פריט",
"use_src_checkbox": "השתמש באודיו מקור ממערך הנתונים",
"use_src_info": "סמן כדי להשתמש באודיו המקור מתוך מערך הנתונים",
"data_status_label": "📊 מצב נתונים",
"data_status_default": "❌ לא יובא מערך נתונים",
"autofill_btn": "📋 מילוי אוטומטי של טופס היצירה"
},
"service": {
"title": "🔧 הגדרות שירות",
"checkpoint_label": "קובץ נקודת ביקורת (Checkpoint)",
"checkpoint_info": "בחר קובץ נקודת ביקורת של מודל מאומן (נתיב מלא או שם קובץ)",
"refresh_btn": "🔄 רענון",
"model_path_label": "נתיב מודל ראשי",
"model_path_info": "בחר את ספריית הגדרות המודל (נסרק אוטומטית מנקודות הביקורת)",
"device_label": "מכשיר (Device)",
"device_info": "מכשיר עיבוד (מומלץ זיהוי אוטומטי)",
"lm_model_path_label": "נתיב מודל 5Hz LM",
"lm_model_path_info": "בחר את קובץ נקודת הביקורת של מודל ה-5Hz LM",
"backend_label": "מנוע (Backend) 5Hz LM",
"backend_info": "בחר מנוע עבור 5Hz LM: vllm (מהיר יותר) או pt (PyTorch, תואם יותר)",
"init_llm_label": "אתחול 5Hz LM",
"init_llm_info": "סמן כדי לאתחל את ה-5Hz LM במהלך אתחול השירות",
"flash_attention_label": "השתמש ב-Flash Attention",
"flash_attention_info_enabled": "הפעל Flash Attention להסקה מהירה יותר (דורש חבילת flash_attn)",
"flash_attention_info_disabled": "Flash Attention אינו זמין (חבילת flash_attn לא מותקנת)",
"offload_cpu_label": "העברה ל-CPU (Offload)",
"offload_cpu_info": "העבר מודלים ל-CPU כשאינם בשימוש כדי לחסוך בזיכרון גרפי (VRAM)",
"offload_dit_cpu_label": "העברת DiT ל-CPU",
"offload_dit_cpu_info": "העבר DiT ל-CPU (דורש 'העברה ל-CPU')",
"compile_model_label": "הידור מודל (Compile)",
"compile_model_info": "השתמש ב-torch.compile לאופטימיזציה של המודל (נדרש עבור קוונטיזציה)",
"quantization_label": "קוונטיזציה INT8",
"quantization_info": "הפעל קוונטיזציה של משקולות בלבד (INT8) להפחתת שימוש ב-VRAM (דורש הידור מודל)",
"mlx_dit_label": "MLX DiT (Apple Silicon)",
"mlx_dit_info_enabled": "השתמש ב-MLX מקורי להפצת DiT על Apple Silicon (מהיר יותר מ-MPS)",
"mlx_dit_info_disabled": "MLX לא זמין (דורש macOS + Apple Silicon + חבילת mlx)",
"init_btn": "אתחול שירות",
"status_label": "מצב",
"language_label": "שפת ממשק",
"language_info": "בחר את שפת הממשק",
"gpu_auto_tier": "שכבה שזוהתה אוטומטית",
"tier_label": "דריסת שכבת GPU",
"tier_info": "בחר שכבת GPU באופן ידני כדי להתאים ברירות מחדל של אופטימיזציה (העברה, קוונטיזציה, מנוע וכו')"
},
"generation": {
"tab_title": "🎵 יצירה",
"required_inputs": "📝 קלטים נדרשים",
"task_type_label": "סוג משימה",
"task_type_info": "בחר את סוג המשימה ליצירה",
"instruction_label": "הנחיה",
"instruction_info": "ההנחיה נוצרת אוטומטית בהתאם לסוג המשימה",
"load_btn": "טעינה",
"track_name_label": "שם רצועה",
"track_name_info": "בחר שם רצועה עבור משימות lego/extract",
"track_classes_label": "שמות רצועות",
"track_classes_info": "בחר מספר מחלקות רצועה עבור משימה מלאה",
"audio_uploads": "🎵 העלאות אודיו",
"reference_audio": "אודיו ייחוס (אופציונלי)",
"source_audio": "אודיו מקור",
"convert_codes_btn": "המר לקודים",
"analyze_btn": "🔍 ניתוח",
"sample_btn": "🎲 לחץ כאן",
"load_btn": "📂 טעינה",
"lm_codes_hints": "🎼 רמזי קודי LM",
"lm_codes_label": "רמזי קודי LM",
"lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
"lm_codes_info": "הדבק רמזי קודי LM עבור יצירת טקסט למוזיקה (text2music)",
"lm_codes_sample": "רמזי קודי LM (דגימה {n})",
"lm_codes_sample_info": "קודים עבור דגימה {n}",
"transcribe_btn": "תמלול",
"repainting_controls": "🎨 בקרת צביעה מחדש (בשניות)",
"repainting_start": "תחילת צביעה מחדש",
"repainting_end": "סיום צביעה מחדש",
"mode_label": "מצב יצירה",
"mode_info": "בחר מצב יצירה כדי להתחיל.",
"mode_info_simple": "תאר את המוזיקה שלך בשפה טבעית. הבינה המלאכותית תיצור תיאור, מילים ומטא-נתונים עבורך.",
"mode_info_custom": "שליטה מלאה בתיאור, מילים וכל הפרמטרים.",
"mode_info_remix": "העלה שמע מקור כדי ליצור גרסת רמיקס עם התיאור והמילים שלך.",
"mode_info_repaint": "העלה שמע מקור וצבע מחדש טווח זמן מסוים.",
"mode_info_extract": "חלץ רצועה מסוימת (שירה, תופים וכו') משמע המקור.",
"mode_info_lego": "הרכב מחדש רצועות: החלף רצועה מסוימת בשמע המקור.",
"mode_info_complete": "השלם רצועות חסרות בשמע המקור.",
"mode_simple": "פשוט",
"mode_custom": "מותאם אישית",
"simple_query_label": "תיאור השיר",
"simple_query_placeholder": "תאר את המוזיקה שברצונך ליצור, למשל: 'שיר אהבה אקוסטי שקט לערב רגוע'. השאר ריק לדגימה אקראית.",
"simple_query_info": "הזן תיאור בשפה טבעית של המוזיקה שברצונך ליצור",
"simple_vocal_language_label": "שפת שירה",
"simple_vocal_language_info": "בחר שפות מועדפות למילים. השתמש ב-'unknown' לכל שפה.",
"create_sample_btn": "צור דגימה",
"caption_title": "📝 תיאור מוזיקלי (Caption)",
"caption_label": "תיאור מוזיקלי (אופציונלי)",
"caption_placeholder": "מנגינת גיטרה אקוסטית שלווה עם שירה רכה...",
"caption_info": "תאר את הסגנון, הז'אנר, הכלים והאווירה",
"lyrics_title": "📝 מילים",
"lyrics_label": "מילים (אופציונלי)",
"lyrics_placeholder": "[בית 1]\\nתחת שמי הלילה...\\nאני מרגיש חי...",
"lyrics_info": "מילות השיר עם מבנה",
"instrumental_label": "אינסטרומנטלי (ללא שירה)",
"format_btn": "פרמוט",
"format_caption_btn": "שפר תיאור",
"format_lyrics_btn": "שפר מילים",
"optional_params": "⚙️ פרמטרים אופציונליים",
"optional_music_props": "🎵 מאפייני מוזיקה",
"optional_gen_settings": "📐 הגדרות יצירה",
"advanced_dit_section": "🎛️ פרמטרי DiT",
"advanced_lm_section": "🤖 פרמטרי LM",
"advanced_output_section": "🔊 פלט שמע ועיבוד",
"advanced_automation_section": "⚡ אוטומציה ואצוות",
"vocal_language_label": "שפת שירה (אופציונלי)",
"vocal_language_info": "'unknown' = כלי/אוטומטי",
"bpm_label": "קצב (BPM) (אופציונלי)",
"bpm_info": "השאר ריק אם לא ידוע",
"keyscale_label": "מפתח (אופציונלי)",
"keyscale_placeholder": "השאר ריק אם לא ידוע",
"keyscale_info": "A-G, #/♭, מז'ור/מינור",
"timesig_label": "משקל מוזיקלי (אופציונלי)",
"timesig_info": "2/4, 3/4, 4/4...",
"duration_label": "אורך אודיו (שניות)",
"duration_info": "השתמש ב-1- לאקראי",
"batch_size_label": "גודל מנה (Batch Size)",
"batch_size_info": "מספר קטעי אודיו ליצירה (מקסימום 8)",
"advanced_settings": "⚙️ הגדרות",
"inference_steps_label": "צעדי הסקה של DiT",
"inference_steps_info": "Turbo: מקסימום 8, Base: מקסימום 200",
"guidance_scale_label": "קנה מידה להנחיה (רק למודל base)",
"guidance_scale_info": "ערכים גבוהים יותר נצמדים יותר לטקסט",
"seed_label": "גרעין (Seed)",
"seed_info": "השתמש בערכים מופרדים בפסיקים עבור מנות",
"random_seed_label": "גרעין אקראי",
"random_seed_info": "אפשר ליצירה אוטומטית של גרעינים",
"audio_format_label": "פורמט אודיו",
"audio_format_info": "פורמט האודיו עבור הקבצים שיישמרו",
"use_adg_label": "השתמש ב-ADG",
"use_adg_info": "הפעל Angle Domain Guidance",
"shift_label": "Shift",
"shift_info": "פקטור הסטת צעדי זמן למודלי base (טווח 1.0~5.0, ברירת מחדל 3.0). לא משפיע על מודלי turbo.",
"infer_method_label": "שיטת הסקה",
"infer_method_info": "שיטת הסקת הדיפוזיה. ODE (Euler) מהירה יותר, SDE (stochastic) עשויה להפיק תוצאות שונות.",
"custom_timesteps_label": "צעדי זמן מותאמים אישית",
"custom_timesteps_info": "אופציונלי: ערכים מופרדים בפסיקים מ-1.0 עד 0.0. דורס את צעדי ההסקה וה-shift.",
"cfg_interval_start": "תחילת מרווח CFG",
"cfg_interval_end": "סיום מרווח CFG",
"lm_params_title": "🤖 פרמטרי יצירת LM",
"lm_temperature_label": "טמפרטורת LM",
"lm_temperature_info": "טמפרטורת 5Hz LM (גבוה יותר = אקראי יותר)",
"lm_cfg_scale_label": "קנה מידה LM CFG",
"lm_cfg_scale_info": "5Hz LM CFG (1.0 = ללא CFG)",
"lm_top_k_label": "LM Top-K",
"lm_top_k_info": "Top-K (0 = מושבת)",
"lm_top_p_label": "LM Top-P",
"lm_top_p_info": "Top-P (1.0 = מושבת)",
"lm_negative_prompt_label": "הנחיה שלילית ל-LM",
"lm_negative_prompt_placeholder": "הזן הנחיה שלילית עבור CFG",
"lm_negative_prompt_info": "הנחיה שלילית (בשימוש כאשר LM CFG Scale > 1.0)",
"cot_metas_label": "CoT Metas",
"cot_metas_info": "השתמש ב-LM ליצירת מטא-דאטה CoT (בטל סימון כדי לדלג)",
"cot_language_label": "שפת CoT",
"cot_language_info": "יצירת שפה ב-CoT (שרשרת מחשבה)",
"constrained_debug_label": "ניקוי באגים של פענוח מוגבל",
"constrained_debug_info": "הפעל לוגים של ניקוי באגים עבור פענוח מוגבל",
"auto_score_label": "דירוג אוטומטי",
"auto_score_info": "חשב אוטומטית ציוני איכות לכל קטעי האודיו שנוצרו",
"auto_lrc_label": "LRC אוטומטי",
"auto_lrc_info": "צור אוטומטית חותמות זמן למילים (LRC) לכל קטעי האודיו",
"lm_batch_chunk_label": "גודל מקטע מנת LM",
"lm_batch_chunk_info": "מקסימום פריטים למקטע מנת LM (ברירת מחדל: 8, מוגבל ע\"י זיכרון ה-GPU)",
"codes_strength_label": "חוזק קודי LM",
"codes_strength_info": "שליטה בכמות צעדי הניקוי מרעש המשתמשים בקודים שנוצרו ע\"י ה-LM",
"cover_strength_label": "חוזק כיסוי אודיו (Audio Cover)",
"cover_strength_info": "שליטה בכמות צעדי הניקוי מרעש המשתמשים במצב כיסוי",
"remix_strength_label": "חוזק רמיקס",
"remix_strength_info": "שליטה בכמה הרמיקס סוטה מהאודיו המקורי (נמוך = קרוב יותר למקור)",
"cover_noise_strength_label": "חוזק כיסוי",
"cover_noise_strength_info": "שליטה בשחזור המלודיה במצב Remix. מומלץ: השתמשו במודל SFT עם ערך של 0.1~0.25. העלאה קלה משחזרת את המלודיה, אך העברת סגנון עשויה לדרוש כוונון פרומפט נוסף. (0 = רעש טהור/ללא כיסוי, 1 = הכי קרוב לאודיו המקורי)",
"score_sensitivity_label": "רגישות ציון איכות",
"score_sensitivity_info": "נמוך יותר = רגיש יותר (ברירת מחדל: 1.0)",
"think_label": "חשיבה (Think)",
"parallel_thinking_label": "חשיבה מקבילית",
"parallel_thinking_info": "עיבוד מקבילי של דגימות אצווה ליצירה מהירה",
"generate_btn": "🎵 צור מוזיקה",
"autogen_label": "יצירה אוטומטית",
"caption_rewrite_label": "שכתוב תיאור",
"caption_rewrite_info": "שימוש ב-LM לשכתוב תיאור לפני יצירה"
},
"results": {
"title": "🎵 תוצאות",
"generated_music": "🎵 מוזיקה שנוצרה (דגימה {n})",
"send_to_remix_btn": "🔗 שלח לרמיקס",
"send_to_repaint_btn": "🔗 שלח לצביעה מחדש",
"save_btn": "💾 שמירה",
"score_btn": "📊 קבל דירוג",
"lrc_btn": "🎵 קבל LRC",
"save_lrc_btn": "💾 שמור LRC",
"convert_to_codes_btn": "🔄 המר לקודים",
"quality_score_label": "ציון איכות (דגימה {n})",
"quality_score_placeholder": "לחץ על 'דירוג' לחישוב ציון איכות מבוסס מורכבות (Perplexity)",
"codes_label": "קודי LM (דגימה {n})",
"lrc_label": "חותמות זמן למילים (דגימה {n})",
"lrc_placeholder": "לחץ על 'LRC' ליצירת חותמות זמן",
"details_accordion": "📊 דירוג, LRC וקודי LM",
"generation_status": "מצב יצירה",
"current_batch": "מנה נוכחית",
"batch_indicator": "מנה {current} / {total}",
"next_batch_status": "מצב המנה הבאה",
"prev_btn": "◀ הקודם",
"next_btn": "הבא ▶",
"restore_params_btn": "↙️ החל הגדרות אלו על הממשק (שחזור פרמטרי מנה)",
"batch_results_title": "📁 תוצאות המנה ופרטי יצירה",
"all_files_label": "📁 כל הקבצים שנוצרו (הורדה)",
"generation_details": "פרטי יצירה"
},
"messages": {
"no_audio_to_save": "❌ אין אודיו לשמירה",
"save_success": "✅ האודיו והמטא-דאטה נשמרו ב-{filename}",
"save_failed": "❌ השמירה נכשלה: {error}",
"no_file_selected": "⚠️ לא נבחר קובץ",
"params_loaded": "✅ הפרמטרים נטענו מ-{filename}",
"invalid_json": "❌ קובץ JSON לא תקין: {error}",
"load_error": "❌ שגיאה בטעינת הקובץ: {error}",
"example_loaded": "📁 נטען דגם מ-{filename}",
"example_failed": "נכשל ניתוח קובץ ה-JSON ב-{filename}: {error}",
"example_error": "שגיאה בטעינת הדגם: {error}",
"lm_generated": "🤖 נוצר דגם באמצעות ה-LM",
"lm_fallback": "יצירת דגם באמצעות ה-LM נכשלה, חוזר לשימוש בספריית הדגמים",
"lm_not_initialized": "❌ 5Hz LM לא מאותחל. נא לאתחל אותו תחילה.",
"think_requires_lm": "⚠️ 'Think' דורש שה-5Hz LM יהיה מאותחל. Think הושבת — היצירה תמשיך ללא חשיבת LM.",
"autogen_enabled": "🔄 יצירה אוטומטית הופעלה - המנה הבאה תיווצר לאחר מכן",
"batch_ready": "✅ מנה {n} מוכנה! לחץ על 'הבא' לצפייה.",
"batch_generating": "🔄 מתחיל יצירת רקע עבור מנה {n}...",
"batch_failed": "❌ יצירת הרקע נכשלה: {error}",
"viewing_batch": "✅ צופה במנה {n}",
"at_first_batch": "נמצא כבר במנה הראשונה",
"at_last_batch": "אין מנה באה זמינה",
"batch_not_found": "מנה {n} לא נמצאה בתור",
"no_batch_data": "לא נמצאו נתוני מנה לשחזור.",
"params_restored": "✅ פרמטרי הממשק שוחזרו ממנה {n}",
"scoring_failed": "❌ שגיאה: נתוני המנה לא נמצאו",
"no_codes": "❌ אין קודי אודיו זמינים. נא ליצור מוזיקה תחילה.",
"score_failed": "❌ הדירוג נכשל: {error}",
"score_error": "❌ שגיאה בחישוב הציון: {error}",
"lrc_no_batch_data": "❌ לא נמצאו נתוני מנה. נא ליצור מוזיקה תחילה.",
"lrc_no_extra_outputs": "❌ לא נמצאו פלטים נוספים. טנזורי התניה אינם זמינים.",
"lrc_missing_tensors": "❌ חסרים טנזורים נדרשים ליצירת LRC.",
"lrc_sample_not_exist": "❌ הדגימה אינה קיימת במנה הנוכחית.",
"lrc_empty_result": "⚠️ יצירת ה-LRC הפיקה תוצאה ריקה.",
"empty_query": "⚠️ נא להזין תיאור מוזיקלי.",
"sample_creation_failed": "❌ יצירת הדגימה נכשלה. נא לנסות שוב.",
"sample_created": "✅ הדגימה נוצרה! בדוק את התיאור והמילים, ולאחר מכן לחץ על 'צור מוזיקה'.",
"simple_examples_not_found": "⚠️ ספריית הדגמים של המצב הפשוט לא נמצאה.",
"simple_examples_empty": "⚠️ לא נמצאו קבצי דוגמה במצב פשוט.",
"simple_example_loaded": "🎲 נטענה דוגמה אקראית מ-{filename}",
"format_success": "✅ התיאור והמילים פורמטו בהצלחה",
"format_failed": "❌ הפירמוט נכשל: {error}",
"skipping_metas_cot": "⚡ מדלג על שלב 1 של מטא-דאטה COT (הדגימה כבר מפורמטת)",
"invalid_timesteps_format": "⚠️ פורמט צעדי זמן לא תקין. משתמש בלוח זמנים כברירת מחדל.",
"timesteps_out_of_range": "⚠️ צעדי הזמן חייבים להיות בטווח [0, 1]. משתמש בלוח זמנים כברירת מחדל.",
"timesteps_count_mismatch": "⚠️ מספר צעדי הזמן ({actual}) שונה מצעדי ההסקה ({expected}). משתמש במספר צעדי הזמן."
},
"training": {
"tab_title": "🎓 אימון LoRA",
"tab_dataset_builder": "📁 בונה מערך נתונים",
"tab_train_lora": "🚀 אימון LoRA",
"quick_start_title": "🚀 התחלה מהירה",
"load_dataset_label": "נתיב קובץ JSON של מערך הנתונים",
"load_dataset_info": "טעינת מערך נתונים שנשמר בעבר",
"load_btn": "📂 טעינה",
"load_status": "מצב טעינה",
"scan_label": "נתיב ספריית אודיו",
"scan_info": "סריקה אחר קבצי אודיו (wav, mp3, flac, ogg, opus)",
"scan_btn": "🔍 סריקה",
"scan_status": "מצב סריקה",
"found_audio_files": "קבצי אודיו שנמצאו",
"dataset_name": "שם מערך הנתונים",
"dataset_name_placeholder": "הזן שם למערך הנתונים",
"dataset_settings_header": "הגדרות מערך נתונים",
"tag_prepend": "הוספה בהתחלה (תגית, תיאור)",
"tag_append": "הוספה בסוף (תיאור, תגית)",
"tag_replace": "החלפת התיאור",
"step2_title": "שלב 2: תיוג אוטומטי באמצעות AI",
"step2_instruction": "לחץ על הכפתור למטה כדי ליצור אוטומטית מטא-נתונים עבור כל קבצי האודיו באמצעות AI:\n• **תיאור**: סגנון מוזיקה, ז'אנר, תיאור מצב רוח\n• **BPM**: פעימות לדקה\n• **מפתח**: מפתח מוזיקלי (לדוגמה, C Major, Am)\n• **חתימת זמן**: 4/4, 3/4, וכו'",
"step3_title": "שלב 3: תצוגה מקדימה ועריכה",
"step4_title": "שלב 4: שמירת מערך הנתונים",
"step5_title": "שלב 5: עיבוד מקדים לטנזורים (Tensors)",
"step5_intro": "**העיבוד המקדים ממיר את מערך הנתונים שלך לטנזורים מחושבים מראש לאימון מהיר.**\n\nאתה יכול:\n• להשתמש במערך הנתונים משלבים 1-4 למעלה, **או**\n• לטעון קובץ JSON של מערך נתונים קיים (אם כבר שמרת אחד)",
"step5_details": "שלב זה:\n• מקודד אודיו ל-latents של VAE\n• מקודד תיאורים ומילים ל-embeddings טקסט\n• מפעיל את מקודד התנאי\n• שומר את כל הטנזורים לקבצי `.pt`\n\n⚠ **זה דורש טעינת המודל ועשוי לקחת מספר דקות.**",
"train_tensor_selection_desc": "בחר את הספרייה המכילה קבצי טנזור מעובדים מראש (קבצי `.pt`).\nאלה נוצרים בכרטיסייה \"בונה מערך נתונים\" באמצעות כפתור \"עיבוד מקדים\".",
"all_instrumental": "הכל אינסטרומנטלי",
"all_instrumental_info": "סמן אם כל הרצועות הן כליות (ללא שירה)",
"custom_tag": "תגית הפעלה מותאמת אישית",
"custom_tag_info": "תגית ייחודית להפעלת הסגנון של LoRA זו",
"tag_position": "מיקום התגית",
"tag_position_info": "היכן למקם את התגית המותאמת אישית בתוך התיאור",
"genre_ratio": "יחס ז'אנר (%)",
"genre_ratio_info": "0% = הכל תיאור, 100% = הכל ז'אנר. הגדרה פר-דגימה קודמת להגדרת הכלל.",
"skip_metas": "דלג על BPM/סולם/משקל",
"skip_metas_info": "דלג על יצירת BPM/סולם/משקל. התיאור והז'אנר עדיין ייווצרו על ידי ה-LLM.",
"only_unlabeled": "רק כאלו ללא תיוג",
"only_unlabeled_info": "תייג רק דגימות ללא תיאור (שימושי להמשך תיוג שנכשל)",
"auto_label_btn": "🏷️ תיוג אוטומטי של הכל",
"label_progress": "התקדמות התיוג",
"select_sample": "בחר דגימה #",
"select_sample_info": "בחר דגימה לצפייה ועריכה",
"audio_preview": "תצוגה מקדימה של אודיו",
"filename": "שם קובץ",
"caption": "תיאור",
"genre": "ז'אנר",
"prompt_override_label": "דריסת פרומפט (לדגימה זו)",
"prompt_override_info": "דריסת היחס הכללי עבור דגימה זו",
"lyrics_editable_label": "מילים (ניתן לעריכה, משמש לאימון)",
"raw_lyrics_label": "מילים גולמיות (מתוך קובץ .txt)",
"no_lyrics_placeholder": "(אין קובץ מילים .txt)",
"bpm": "BPM",
"key_label": "סולם (Key)",
"key_placeholder": "C Major",
"time_sig": "משקל מוזיקלי",
"duration_s": "משך (שניות)",
"language": "שפה",
"instrumental": "אינסטרומנטלי",
"save_changes_btn": "💾 שמירת שינויים",
"edit_status": "מצב עריכה",
"save_path": "נתיב שמירה",
"save_path_info": "הנתיב שבו יישמר קובץ ה-JSON של מערך הנתונים",
"save_dataset_btn": "💾 שמירת מערך נתונים",
"save_status": "מצב שמירה",
"load_existing_label": "טעינת מערך נתונים קיים (אופציונלי)",
"load_existing_info": "נתיב לקובץ JSON של מערך נתונים שנשמר בעבר",
"load_dataset_btn": "📂 טעינת מערך נתונים",
"tensor_output_dir": "ספריית פלט של טנזורים",
"tensor_output_info": "הספרייה לשמירת קבצי טנזור שעברו עיבוד מקדים",
"preprocess_btn": "⚡ עיבוד מקדים",
"preprocess_progress": "התקדמות עיבוד מקדים",
"preprocessed_tensors_dir": "ספריית טנזורים מעובדים",
"preprocessed_tensors_info": "ספרייה המכילה קבצי .pt של טנזורים מעובדים",
"train_section_tensors": "בחירת מערך נתונים מעובד",
"train_section_lora": "הגדרות LoRA",
"train_section_params": "פרמטרי אימון",
"dataset_info": "מידע על מערך הנתונים",
"lora_rank": "דרגת LoRA (Rank)",
"lora_rank_info": "גבוה יותר = יותר קיבולת, יותר זיכרון",
"lora_alpha": "LoRA Alpha",
"lora_alpha_info": "פקטור קנה מידה (בדרך כלל פי 2 מה-Rank)",
"lora_dropout": "LoRA Dropout",
"learning_rate": "קצב למידה (Learning Rate)",
"learning_rate_info": "התחל עם 3e-4, שנה במידת הצורך",
"max_epochs": "מקסימום תקופות (Epochs)",
"batch_size": "גודל מנה (Batch Size)",
"batch_size_info": "הגדל אם יש לך מספיק זיכרון גרפי (VRAM)",
"gradient_accumulation": "צבירת גרדיאנטים (Accumulation)",
"gradient_accumulation_info": "גודל מנה אפקטיבי = גודל מנה × צבירה",
"save_every_n_epochs": "שמור כל N תקופות (Epochs)",
"shift": "Shift (הסטה)",
"shift_info": "הסטת צעדי זמן עבור מודל turbo",
"seed": "גרעין (Seed)",
"output_dir": "ספריית פלט",
"output_dir_info": "ספרייה לשמירת משקולות ה-LoRA המאומנות",
"start_training_btn": "🚀 התחלת אימון",
"stop_training_btn": "⏹️ עצירת אימון",
"training_progress": "התקדמות האימון",
"training_log": "יומן אימון",
"training_loss_title": "הפסד אימון (Training Loss)",
"step": "צעד",
"loss": "הפסד (Loss)",
"export_header": "ייצוא LoRA",
"export_path": "נתיב ייצוא",
"export_lora_btn": "📦 ייצוא LoRA",
"export_status": "מצב ייצוא"
}
}

@ -53,12 +53,19 @@
"compile_model_info": "torch.compileでモデルを最適化(量子化に必要)",
"quantization_label": "INT8 量子化",
"quantization_info": "INT8重み量子化を有効にしてVRAMを節約(モデルのコンパイルが必要)",
"mlx_dit_label": "MLX DiT (Apple Silicon)",
"mlx_dit_info_enabled": "Apple SiliconでMLXネイティブDiT拡散を使用(MPSより高速)",
"mlx_dit_info_disabled": "MLXは利用不可(macOS + Apple Silicon + mlxパッケージが必要)",
"init_btn": "サービスを初期化",
"status_label": "ステータス",
"language_label": "UI言語",
"language_info": "インターフェース言語を選択"
"language_info": "インターフェース言語を選択",
"gpu_auto_tier": "自動検出ティア",
"tier_label": "GPU ティアの手動選択",
"tier_info": "GPUティアを手動で選択して最適化のデフォルト(オフロード、量子化、バックエンドなど)を調整します"
},
"generation": {
"tab_title": "🎵 生成",
"required_inputs": "📝 必須入力",
"task_type_label": "タスクタイプ",
"task_type_info": "生成のタスクタイプを選択",
@ -71,8 +78,11 @@
"track_classes_info": "completeタスクの複数のトラッククラスを選択",
"audio_uploads": "🎵 オーディオアップロード",
"reference_audio": "リファレンスオーディオ(オプション)",
"source_audio": "ソースオーディオ(オプション)",
"source_audio": "ソースオーディオ",
"convert_codes_btn": "コードに変換",
"analyze_btn": "🔍 分析",
"sample_btn": "🎲 お試し",
"load_btn": "📂 読込",
"lm_codes_hints": "🎼 LM コードヒント",
"lm_codes_label": "LM コードヒント",
"lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
@ -84,13 +94,20 @@
"repainting_start": "再描画開始",
"repainting_end": "再描画終了",
"mode_label": "生成モード",
"mode_info": "シンプル:自然言語で音楽を説明。カスタム:キャプションと歌詞を完全にコントロール。",
"mode_info": "生成モードを選択して開始します。",
"mode_info_simple": "自然言語で音楽を説明すると、AIがキャプション、歌詞、メタデータを自動生成します。",
"mode_info_custom": "キャプション、歌詞、すべてのパラメータを完全にコントロール。",
"mode_info_remix": "ソースオーディオをアップロードして、キャプションと歌詞でリミックスバージョンを作成。",
"mode_info_repaint": "ソースオーディオをアップロードして、指定した時間範囲を再描画。",
"mode_info_extract": "ソースオーディオから特定のトラック(ボーカル、ドラムなど)を抽出。",
"mode_info_lego": "トラックの再構成:ソースオーディオの特定のトラックを置き換え。",
"mode_info_complete": "ソースオーディオの欠落したトラックを補完。",
"mode_simple": "シンプル",
"mode_custom": "カスタム",
"simple_query_label": "曲の説明",
"simple_query_placeholder": "作成したい音楽を説明してください。例:'静かな夜のための優しいベンガルのラブソング'。空欄の場合はランダムなサンプルが生成されます。",
"simple_query_info": "生成したい音楽の自然言語の説明を入力",
"simple_vocal_language_label": "ボーカル言語(オプション)",
"simple_vocal_language_label": "ボーカル言語",
"simple_vocal_language_info": "歌詞の希望言語を選択。任意の言語の場合は'unknown'を使用。",
"create_sample_btn": "サンプル作成",
"caption_title": "📝 音楽キャプション",
@ -103,12 +120,20 @@
"lyrics_info": "構造を持つ曲の歌詞",
"instrumental_label": "インストゥルメンタル",
"format_btn": "フォーマット",
"format_caption_btn": "キャプションを強化",
"format_lyrics_btn": "歌詞を強化",
"optional_params": "⚙️ オプションパラメータ",
"optional_music_props": "🎵 音楽プロパティ",
"optional_gen_settings": "📐 生成設定",
"advanced_dit_section": "🎛️ DiT 拡散パラメータ",
"advanced_lm_section": "🤖 LM 生成パラメータ",
"advanced_output_section": "🔊 オーディオ出力と後処理",
"advanced_automation_section": "⚡ 自動化とバッチ",
"vocal_language_label": "ボーカル言語(オプション)",
"vocal_language_info": "インストには`unknown`を使用",
"vocal_language_info": "'unknown' = インスト/自動",
"bpm_label": "BPM(オプション)",
"bpm_info": "空白の場合はN/A",
"keyscale_label": "キースケール(オプション)",
"keyscale_label": "キー(オプション)",
"keyscale_placeholder": "空白の場合はN/A",
"keyscale_info": "A-G, #/♭, メジャー/マイナー",
"timesig_label": "拍子記号(オプション)",
@ -117,7 +142,7 @@
"duration_info": "ランダムの場合は-1を使用",
"batch_size_label": "バッチサイズ",
"batch_size_info": "生成するオーディオの数(最大8)",
"advanced_settings": "🔧 詳細設定",
"advanced_settings": "⚙️ 設定",
"inference_steps_label": "DiT 推論ステップ",
"inference_steps_info": "Turbo: 最大8、Base: 最大200",
"guidance_scale_label": "DiT ガイダンススケール(baseモデルのみサポート)",
@ -168,21 +193,31 @@
"similarity_denoise_info": "出力が参照オーディオにどれだけ忠実かを制御します。高い値ほど構造を保持します。",
"cover_strength_label": "オーディオカバー強度",
"cover_strength_info": "カバーモードを使用するデノイジングステップ数を制御",
"remix_strength_label": "リミックス強度",
"remix_strength_info": "リミックスがソースオーディオからどれだけ逸脱するかを制御(低い = オリジナルに近い)",
"cover_noise_strength_label": "カバー強度",
"cover_noise_strength_info": "Remixモードでのメロディ復元を制御します。推奨:SFTモデルを使用し、値を0.1〜0.25に設定。わずかに上げるとメロディが復元されますが、スタイル変換には追加のプロンプト調整が必要な場合があります。(0 = 純粋なノイズ/カバーなし、1 = オリジナルに最も近い)",
"score_sensitivity_label": "品質スコア感度",
"score_sensitivity_info": "低い = より敏感(デフォルト: 1.0)。PMIを[0,1]にマッピングする方法を調整",
"think_label": "思考",
"parallel_thinking_label": "並列思考",
"parallel_thinking_info": "バッチサンプルを並列処理して高速生成",
"generate_btn": "🎵 音楽を生成",
"autogen_label": "自動生成",
"caption_rewrite_label": "キャプション書き換え"
"caption_rewrite_label": "キャプション書き換え",
"caption_rewrite_info": "生成前にLMでキャプションを書き換え",
"advanced_dit_params": "DiT 詳細パラメータ"
},
"results": {
"title": "🎵 結果",
"generated_music": "🎵 生成された音楽(サンプル {n})",
"send_to_src_btn": "🔗 ソースオーディオに送信",
"send_to_remix_btn": "🔗 リミックスに送信",
"send_to_repaint_btn": "🔗 リペイントに送信",
"save_btn": "💾 保存",
"score_btn": "📊 スコア",
"lrc_btn": "🎵 LRC",
"score_btn": "📊 スコア取得",
"lrc_btn": "🎵 LRC 取得",
"save_lrc_btn": "💾 LRC 保存",
"convert_to_codes_btn": "🔄 コードに変換",
"quality_score_label": "品質スコア(サンプル {n})",
"quality_score_placeholder": "'スコア'をクリックしてパープレキシティベースの品質スコアを計算",
"codes_label": "LM コード(サンプル {n})",
@ -214,6 +249,7 @@
"lm_generated": "🤖 LMを使用してサンプルを生成しました",
"lm_fallback": "LMを使用したサンプル生成に失敗、サンプルディレクトリにフォールバック",
"lm_not_initialized": "❌ 5Hz LMが初期化されていません。最初に初期化してください。",
"think_requires_lm": "⚠️ 「Think」機能には5Hz LMの初期化が必要です。Thinkは無効化されました — LM思考なしで生成を続行します。",
"autogen_enabled": "🔄 自動生成が有効 - このあと次のバッチを生成します",
"batch_ready": "✅ バッチ {n} の準備完了!'次へ'をクリックして表示。",
"batch_generating": "🔄 バッチ {n} のバックグラウンド生成を開始...",
@ -354,5 +390,16 @@
"export_path": "エクスポートパス",
"export_lora_btn": "📦 LoRA をエクスポート",
"export_status": "エクスポート状態"
},
"gen": {
"enable_normalization": "ノーマライズを有効にする",
"enable_normalization_info": "クリッピングを防ぎ、ラウドネスを一定にするために、オーディオ音量をターゲットピークレベルに正規化します。",
"normalization_db": "ターゲットピーク (dB)",
"normalization_db_info": "デシベル単位のターゲットピークレベル。-1.0 dBが標準的な安全ピークです。-0.1 dBが最大です。",
"advanced_dit_params": "DiT 詳細パラメータ",
"latent_shift": "Latent シフト",
"latent_shift_info": "VAEデコード前にDiT latentに適用するシフト量。デフォルト0(シフトなし)。負の値(例: -0.04)でクリッピングを軽減できます。",
"latent_rescale": "Latent リスケール",
"latent_rescale_info": "VAEデコード前にDiT latentに適用するスケール係数。デフォルト1.0(スケーリングなし)。1.0未満の値(例: 0.91)でクリッピングを軽減できます。"
}
}
}

@ -53,12 +53,19 @@
"compile_model_info": "使用 torch.compile 优化模型(量化必需)",
"quantization_label": "INT8 量化",
"quantization_info": "启用 INT8 仅权重量化以减少显存占用(需要启用编译模型)",
"mlx_dit_label": "MLX DiT (Apple Silicon)",
"mlx_dit_info_enabled": "使用原生 MLX 加速 DiT 扩散推理(比 MPS 更快)",
"mlx_dit_info_disabled": "MLX 不可用(需要 macOS + Apple Silicon + mlx 包)",
"init_btn": "初始化服务",
"status_label": "状态",
"language_label": "界面语言",
"language_info": "选择界面语言"
"language_info": "选择界面语言",
"gpu_auto_tier": "自动检测层级",
"tier_label": "GPU 层级覆盖",
"tier_info": "手动选择 GPU 层级以调整优化默认值(卸载、量化、后端等)"
},
"generation": {
"tab_title": "🎵 生成",
"required_inputs": "📝 必需输入",
"task_type_label": "任务类型",
"task_type_info": "选择生成的任务类型",
@ -71,8 +78,11 @@
"track_classes_info": "为complete任务选择多个音轨类别",
"audio_uploads": "🎵 音频上传",
"reference_audio": "参考音频(可选)",
"source_audio": "源音频(可选)",
"source_audio": "源音频",
"convert_codes_btn": "转换为代码",
"analyze_btn": "🔍 分析",
"sample_btn": "🎲 试试看",
"load_btn": "📂 加载",
"lm_codes_hints": "🎼 LM 代码提示",
"lm_codes_label": "LM 代码提示",
"lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
@ -84,13 +94,20 @@
"repainting_start": "重绘开始",
"repainting_end": "重绘结束",
"mode_label": "生成模式",
"mode_info": "简单模式:用自然语言描述音乐。自定义模式:完全控制描述和歌词。",
"mode_info": "选择一种生成模式开始创作。",
"mode_info_simple": "用自然语言描述你想要的音乐,AI 将自动生成描述、歌词和元数据。",
"mode_info_custom": "完全控制描述、歌词和所有参数。",
"mode_info_remix": "上传源音频,用你的描述和歌词创建混音版本。",
"mode_info_repaint": "上传源音频,重绘指定时间范围的内容。",
"mode_info_extract": "从源音频中提取特定音轨(人声、鼓等)。",
"mode_info_lego": "重新组装音轨:替换源音频中的特定音轨。",
"mode_info_complete": "补全源音频中缺失的音轨。",
"mode_simple": "简单",
"mode_custom": "自定义",
"simple_query_label": "歌曲描述",
"simple_query_placeholder": "描述你想创作的音乐,例如:'给我生成一首暗黑的戏剧古风,歌词要华丽'。留空则随机生成样本。",
"simple_query_info": "输入你想生成的音乐的自然语言描述",
"simple_vocal_language_label": "人声语言(可选)",
"simple_vocal_language_label": "人声语言",
"simple_vocal_language_info": "选择歌词的首选语言。使用 'unknown' 表示任意语言。",
"create_sample_btn": "创建样本",
"caption_title": "📝 音乐描述",
@ -103,12 +120,20 @@
"lyrics_info": "带有结构的歌曲歌词",
"instrumental_label": "纯音乐",
"format_btn": "格式化",
"format_caption_btn": "增强描述",
"format_lyrics_btn": "增强歌词",
"optional_params": "⚙️ 可选参数",
"optional_music_props": "🎵 音乐属性",
"optional_gen_settings": "📐 生成设置",
"advanced_dit_section": "🎛️ DiT 扩散参数",
"advanced_lm_section": "🤖 LM 生成参数",
"advanced_output_section": "🔊 音频输出与后处理",
"advanced_automation_section": "⚡ 自动化与批量",
"vocal_language_label": "人声语言(可选)",
"vocal_language_info": "纯音乐使用 `unknown`",
"vocal_language_info": "'unknown' = 纯音乐/自动",
"bpm_label": "BPM(可选)",
"bpm_info": "留空表示N/A",
"keyscale_label": "调性(可选)",
"keyscale_label": "调 (可选)",
"keyscale_placeholder": "留空表示N/A",
"keyscale_info": "A-G, #/♭, 大调/小调",
"timesig_label": "拍号(可选)",
@ -117,7 +142,7 @@
"duration_info": "使用-1表示随机",
"batch_size_label": "批量大小",
"batch_size_info": "要生成的音频数量(最多8个)",
"advanced_settings": "🔧 高级设置",
"advanced_settings": "⚙️ 设置",
"inference_steps_label": "DiT 推理步数",
"inference_steps_info": "Turbo: 最多8, Base: 最多200",
"guidance_scale_label": "DiT 引导比例(仅支持base模型)",
@ -150,8 +175,8 @@
"lm_negative_prompt_label": "LM 负面提示",
"lm_negative_prompt_placeholder": "输入CFG的负面提示(默认: NO USER INPUT)",
"lm_negative_prompt_info": "负面提示(当LM CFG比例 > 1.0时使用)",
"cot_metas_label": "CoT 元数据",
"cot_metas_info": "使用LM生成CoT元数据(取消勾选以跳过LM CoT生成)",
"cot_metas_label": "CoT メタデータ",
"cot_metas_info": "LMを使用してCoTメタデータを生成(チェックを外すとLM CoT生成をスキップ)",
"cot_language_label": "CoT 语言",
"cot_language_info": "在CoT中生成语言(思维链)",
"constrained_debug_label": "约束解码调试",
@ -168,21 +193,31 @@
"similarity_denoise_info": "控制输出与参考音频的贴合程度。数值越高保留越多结构。",
"cover_strength_label": "音频覆盖强度",
"cover_strength_info": "控制使用覆盖模式的去噪步骤数量",
"remix_strength_label": "混音强度",
"remix_strength_info": "控制混音与原始音频的偏离程度(越低越接近原始音频)",
"cover_noise_strength_label": "翻唱强度",
"cover_noise_strength_info": "控制 Remix 模式下的旋律还原程度。推荐:使用 SFT 模型,值设为 0.1~0.25。稍微提升即可还原旋律,但风格迁移可能需要额外的 prompt 调参。(0 = 纯噪声/无翻唱效果,1 = 最接近原始音频)",
"score_sensitivity_label": "质量评分敏感度",
"score_sensitivity_info": "更低 = 更敏感(默认: 1.0). 调整PMI如何映射到[0,1]",
"think_label": "思考",
"parallel_thinking_label": "并行思考",
"parallel_thinking_info": "并行处理批量样本以加速生成",
"generate_btn": "🎵 生成音乐",
"autogen_label": "自动生成",
"caption_rewrite_label": "描述重写"
"caption_rewrite_label": "描述重写",
"caption_rewrite_info": "生成前使用LM重写描述",
"advanced_dit_params": "DiT 高级参数"
},
"results": {
"title": "🎵 结果",
"generated_music": "🎵 生成的音乐(样本 {n})",
"send_to_src_btn": "🔗 发送到源音频",
"send_to_remix_btn": "🔗 发送到混音",
"send_to_repaint_btn": "🔗 发送到重绘",
"save_btn": "💾 保存",
"score_btn": "📊 评分",
"lrc_btn": "🎵 LRC",
"score_btn": "📊 获取评分",
"lrc_btn": "🎵 获取 LRC",
"save_lrc_btn": "💾 保存 LRC",
"convert_to_codes_btn": "🔄 转换为代码",
"quality_score_label": "质量分数(样本 {n})",
"quality_score_placeholder": "点击'评分'以计算基于困惑度的质量分数",
"codes_label": "LM 代码(样本 {n})",
@ -214,6 +249,7 @@
"lm_generated": "🤖 使用LM生成的示例",
"lm_fallback": "使用LM生成示例失败,回退到示例目录",
"lm_not_initialized": "❌ 5Hz LM未初始化。请先初始化它。",
"think_requires_lm": "⚠️ \"Think\"功能需要先初始化 5Hz LM。Think 已自动关闭,生成将不使用 LM 思考。",
"autogen_enabled": "🔄 已启用自动生成 - 下一批次将在此之后生成",
"batch_ready": "✅ 批次 {n} 就绪!点击'下一个'查看。",
"batch_generating": "🔄 开始为批次 {n} 进行后台生成...",
@ -252,6 +288,7 @@
"tab_train_lora": "🚀 训练 LoRA",
"quick_start_title": "🚀 快速开始",
"load_dataset_label": "数据集 JSON 路径",
"load_dataset_info": "加载之前保存的数据集",
"load_btn": "📂 加载",
"load_status": "加载状态",
"scan_label": "音频目录路径",
@ -272,6 +309,9 @@
"step5_title": "步骤 5:预处理为张量",
"step5_intro": "**预处理将您的数据集转换为预计算的张量,以便快速训练。**\n\n您可以:\n- 使用上述步骤 1-4 的数据集,**或者**\n- 加载现有的数据集 JSON 文件(如果您已经保存了一个)",
"step5_details": "此步骤:\n- 将音频编码为 VAE 潜变量\n- 将描述和歌词编码为文本嵌入\n- 运行条件编码器\n- 将所有张量保存到 `.pt` 文件\n\n⚠ **这需要加载模型,可能需要几分钟。**",
"train_section_tensors": "预处理数据集选择",
"train_section_lora": "LoRA 设置",
"train_section_params": "训练参数",
"train_tensor_selection_desc": "选择包含预处理张量文件(`.pt` 文件)的目录。\n这些文件在「数据集构建」标签页中使用「预处理」按钮创建。",
"all_instrumental": "全部为纯音乐",
"all_instrumental_info": "勾选表示所有曲目均为纯音乐(无人声)",
@ -350,5 +390,15 @@
"export_path": "导出路径",
"export_lora_btn": "📦 导出 LoRA",
"export_status": "导出状态"
},
"gen": {
"enable_normalization": "启用归一化",
"enable_normalization_info": "将音频音量归一化到目标峰值电平,以防止削波或确保一致的响度。",
"normalization_db": "目标峰值 (dB)",
"normalization_db_info": "分贝单位的目标峰值电平。-1.0 dB 是标准安全峰值。-0.1 dB 是最大值。",
"latent_shift": "Latent 偏移",
"latent_shift_info": "VAE 解码前对 DiT latent 施加的偏移量。默认 0(不偏移)。负值(如 -0.04)可减少爆音。",
"latent_rescale": "Latent 缩放",
"latent_rescale_info": "VAE 解码前对 DiT latent 施加的缩放因子。默认 1.0(不缩放)。小于 1.0 的值(如 0.91)可减少爆音。"
}
}
}
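When several locale files are edited in one commit, strings can land in the wrong file; in the zh hunk above, for example, the new `cot_metas_label` and `cot_metas_info` values contain kana. A small lint pass can flag this. The sketch below is an assumption-laden illustration, not part of the project: `find_kana` and `iter_strings` are hypothetical helpers, and the check only looks for Hiragana/Katakana code points, which are not expected in Chinese UI text.

```python
import re
from typing import Any, Dict, Iterator, List, Tuple

# Hiragana (U+3040-U+309F) and Katakana (U+30A0-U+30FF).
KANA = re.compile(r"[\u3040-\u30ff]")


def iter_strings(node: Any, prefix: str = "") -> Iterator[Tuple[str, str]]:
    """Yield (dotted_key, text) for every string leaf in a nested dict."""
    if isinstance(node, dict):
        for name, child in node.items():
            yield from iter_strings(child, f"{prefix}.{name}" if prefix else name)
    elif isinstance(node, str):
        yield prefix, node


def find_kana(catalog: Dict[str, Any]) -> List[str]:
    """Return the dotted keys whose text contains Japanese kana."""
    return [key for key, text in iter_strings(catalog) if KANA.search(text)]


# Hypothetical excerpt of a zh catalogue, using two entries from the hunk above.
zh_excerpt = {
    "generation": {
        "cot_metas_label": "CoT メタデータ",
        "cot_language_label": "CoT 语言",
    }
}
print(find_kana(zh_excerpt))
```

In practice the catalogue would be loaded with `json.load` from the locale file; the path is left out here because the diff does not show it.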


@ -1,11 +1,33 @@
"""
Gradio UI Components Module
Contains all Gradio interface component definitions and layouts
Layout:
Header
Dataset Explorer (hidden accordion)
Settings (accordion, collapsed)
Service Configuration
DiT Parameters
LM Parameters
Output / Automation
Generation Training
Mode Radio Dataset/LoRA
Inputs
Results
"""
import gradio as gr
from acestep.gradio_ui.i18n import get_i18n, t
from acestep.gradio_ui.interfaces.dataset import create_dataset_section
from acestep.gradio_ui.interfaces.generation import create_generation_section
from acestep.gradio_ui.interfaces.generation import (
create_advanced_settings_section,
create_generation_tab_section,
)
from acestep.gradio_ui.interfaces.result import create_results_section
from acestep.gradio_ui.interfaces.training import create_training_section
from acestep.gradio_ui.events import setup_event_handlers, setup_training_event_handlers
@ -29,6 +51,9 @@ def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_para
# Initialize i18n with selected language
i18n = get_i18n(language)
# Check if running in service mode (hide training tab)
service_mode = init_params is not None and init_params.get('service_mode', False)
with gr.Blocks(
title=t("app.title"),
theme=gr.themes.Soft(),
@ -62,6 +87,34 @@ def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_para
.component-wrapper > .timestamps {
transform: translateY(15px);
}
/* Equal-height row for instrumental checkbox + enhance lyrics button */
.instrumental-row {
align-items: stretch !important;
}
.instrumental-row > div {
display: flex !important;
align-items: stretch !important;
}
.instrumental-row > div > div {
flex: 1;
display: flex;
align-items: center;
}
.instrumental-row button {
height: 100% !important;
min-height: 42px;
}
/* Ensure buttons in instrumental-row fill height */
.instrumental-row > div > button {
height: 100% !important;
min-height: 42px;
}
/* Two-line icon buttons: emoji on top, text below */
.icon-btn-wrap button, .icon-btn-wrap > button {
word-spacing: 100vw;
text-align: center;
line-height: 1.4;
}
""",
) as demo:
@ -72,21 +125,52 @@ def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_para
</div>
""")
# Dataset Explorer Section
# Dataset Explorer Section (hidden)
dataset_section = create_dataset_section(dataset_handler)
# Generation Section (pass init_params and language to support pre-initialization)
generation_section = create_generation_section(dit_handler, llm_handler, init_params=init_params, language=language)
# ═══════════════════════════════════════════
# Top-level: Settings (contains Service Config + Advanced Settings)
# ═══════════════════════════════════════════
settings_section = create_advanced_settings_section(
dit_handler, llm_handler, init_params=init_params, language=language
)
# Results Section
results_section = create_results_section(dit_handler)
# ═══════════════════════════════════════════
# Tabs: Generation | Training
# ═══════════════════════════════════════════
with gr.Tabs():
# --- Generation Tab ---
with gr.Tab(t("generation.tab_title")):
gen_section = create_generation_tab_section(
dit_handler, llm_handler, init_params=init_params, language=language
)
# Results Section (inside the Generation tab, wrapped for visibility control)
with gr.Column(visible=True) as results_wrapper:
results_section = create_results_section(dit_handler)
# Store the wrapper in gen_section so event handlers can toggle it
gen_section["results_wrapper"] = results_wrapper
# --- Training Tab ---
with gr.Tab(t("training.tab_title"), visible=not service_mode):
training_section = create_training_section(
dit_handler, llm_handler, init_params=init_params
)
# Training Section (LoRA training and dataset builder)
# Pass init_params to support hiding in service mode
training_section = create_training_section(dit_handler, llm_handler, init_params=init_params)
# ═══════════════════════════════════════════
# Merge all generation-related component dicts for event wiring
# ═══════════════════════════════════════════
# The event handlers expect a single "generation_section" dict with all
# components from settings (service config + advanced) and generation tab.
generation_section = {}
generation_section.update(settings_section)
generation_section.update(gen_section)
# Connect event handlers
setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section)
setup_event_handlers(
demo, dit_handler, llm_handler, dataset_handler,
dataset_section, generation_section, results_section
)
# Connect training event handlers
setup_training_event_handlers(demo, dit_handler, llm_handler, training_section)
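The change above merges `settings_section` and `gen_section` with two `dict.update` calls. `dict.update` silently overwrites existing keys, so if both factories ever returned a component under the same name, one of them would be dropped from event wiring without any error. A guarded merge is one way to surface that early. This is only a sketch: `merge_sections` is a hypothetical helper, and plain strings stand in for Gradio components.

```python
from typing import Any, Dict


def merge_sections(*sections: Dict[str, Any]) -> Dict[str, Any]:
    """Merge component dicts, refusing to overwrite an existing key."""
    merged: Dict[str, Any] = {}
    for section in sections:
        overlap = merged.keys() & section.keys()
        if overlap:
            raise KeyError(f"duplicate component keys: {sorted(overlap)}")
        merged.update(section)
    return merged


# Strings stand in for Gradio components in this illustration.
settings = {"init_btn": "Button#1", "language": "Dropdown#1"}
generation_tab = {"generate_btn": "Button#2", "results_wrapper": "Column#1"}
generation_section = merge_sections(settings, generation_tab)
print(sorted(generation_section))
```

Whether a collision should raise or merely log is a judgment call; raising keeps the failure close to the cause.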

File diff suppressed because it is too large Load diff


@ -6,6 +6,93 @@ import gradio as gr
from acestep.gradio_ui.i18n import t
def _create_audio_column(n, visible=True):
"""Create a single audio sample column with all its sub-components.
Layout:
Audio player
Row: [Send To Remix] [Send To Repaint] [Save]
Accordion (Score & LRC & LM Codes):
codes_display + convert_to_codes_btn
score_display + score_btn
lrc_display
Row: lrc_btn + save_lrc_btn
"""
with gr.Column(visible=visible) as audio_col:
generated_audio = gr.Audio(
label=t("results.generated_music", n=n),
type="filepath",
interactive=False,
buttons=[]
)
with gr.Row(equal_height=True):
send_to_remix_btn = gr.Button(
t("results.send_to_remix_btn"),
variant="secondary", size="sm", scale=1
)
send_to_repaint_btn = gr.Button(
t("results.send_to_repaint_btn"),
variant="secondary", size="sm", scale=1
)
save_btn = gr.Button(
t("results.save_btn"),
variant="primary", size="sm", scale=1
)
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion:
codes_display = gr.Textbox(
label=t("results.codes_label", n=n),
interactive=False, buttons=["copy"],
lines=4, max_lines=4, visible=True
)
convert_to_codes_btn = gr.Button(
t("results.convert_to_codes_btn"),
variant="secondary", size="sm"
)
score_display = gr.Textbox(
label=t("results.quality_score_label", n=n),
interactive=False, buttons=["copy"],
lines=6, max_lines=6, visible=True
)
score_btn = gr.Button(
t("results.score_btn"),
variant="secondary", size="sm"
)
lrc_display = gr.Textbox(
label=t("results.lrc_label", n=n),
interactive=True, buttons=["copy"],
lines=8, max_lines=8, visible=True
)
with gr.Row(equal_height=True):
lrc_btn = gr.Button(
t("results.lrc_btn"),
variant="secondary", size="sm"
)
save_lrc_btn = gr.Button(
t("results.save_lrc_btn"),
variant="secondary", size="sm"
)
lrc_download_file = gr.File(
label="LRC Download",
visible=False,
interactive=False,
)
return {
"audio_col": audio_col,
"generated_audio": generated_audio,
"send_to_remix_btn": send_to_remix_btn,
"send_to_repaint_btn": send_to_repaint_btn,
"save_btn": save_btn,
"details_accordion": details_accordion,
"codes_display": codes_display,
"convert_to_codes_btn": convert_to_codes_btn,
"score_display": score_display,
"score_btn": score_btn,
"lrc_display": lrc_display,
"lrc_btn": lrc_btn,
"save_lrc_btn": save_lrc_btn,
"lrc_download_file": lrc_download_file,
}
def create_results_section(dit_handler) -> dict:
"""Create results display section"""
with gr.Accordion(t("results.title"), open=True):
@ -16,393 +103,25 @@ def create_results_section(dit_handler) -> dict:
is_format_caption_state = gr.State(value=False)
# Batch management states
current_batch_index = gr.State(value=0) # Currently displayed batch index
total_batches = gr.State(value=1) # Total number of batches generated
batch_queue = gr.State(value={}) # Dictionary storing all batch data
generation_params_state = gr.State(value={}) # Store generation parameters for next batches
is_generating_background = gr.State(value=False) # Background generation flag
current_batch_index = gr.State(value=0)
total_batches = gr.State(value=1)
batch_queue = gr.State(value={})
generation_params_state = gr.State(value={})
is_generating_background = gr.State(value=False)
# All audio components in one row with dynamic visibility
# Row 1: samples 1-4
with gr.Row():
with gr.Column(visible=True) as audio_col_1:
generated_audio_1 = gr.Audio(
label=t("results.generated_music", n=1),
type="filepath",
interactive=False,
buttons=[]
)
with gr.Row(equal_height=True):
send_to_src_btn_1 = gr.Button(
t("results.send_to_src_btn"),
variant="secondary",
size="sm",
scale=1
)
save_btn_1 = gr.Button(
t("results.save_btn"),
variant="primary",
size="sm",
scale=1
)
score_btn_1 = gr.Button(
t("results.score_btn"),
variant="secondary",
size="sm",
scale=1
)
lrc_btn_1 = gr.Button(
t("results.lrc_btn"),
variant="secondary",
size="sm",
scale=1
)
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_1:
codes_display_1 = gr.Textbox(
label=t("results.codes_label", n=1),
interactive=False,
buttons=["copy"],
lines=4,
max_lines=4,
visible=True
)
score_display_1 = gr.Textbox(
label=t("results.quality_score_label", n=1),
interactive=False,
buttons=["copy"],
lines=6,
max_lines=6,
visible=True
)
lrc_display_1 = gr.Textbox(
label=t("results.lrc_label", n=1),
interactive=True,
buttons=["copy"],
lines=8,
max_lines=8,
visible=True
)
with gr.Column(visible=True) as audio_col_2:
generated_audio_2 = gr.Audio(
label=t("results.generated_music", n=2),
type="filepath",
interactive=False,
buttons=[]
)
with gr.Row(equal_height=True):
send_to_src_btn_2 = gr.Button(
t("results.send_to_src_btn"),
variant="secondary",
size="sm",
scale=1
)
save_btn_2 = gr.Button(
t("results.save_btn"),
variant="primary",
size="sm",
scale=1
)
score_btn_2 = gr.Button(
t("results.score_btn"),
variant="secondary",
size="sm",
scale=1
)
lrc_btn_2 = gr.Button(
t("results.lrc_btn"),
variant="secondary",
size="sm",
scale=1
)
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_2:
codes_display_2 = gr.Textbox(
label=t("results.codes_label", n=2),
interactive=False,
buttons=["copy"],
lines=4,
max_lines=4,
visible=True
)
score_display_2 = gr.Textbox(
label=t("results.quality_score_label", n=2),
interactive=False,
buttons=["copy"],
lines=6,
max_lines=6,
visible=True
)
lrc_display_2 = gr.Textbox(
label=t("results.lrc_label", n=2),
interactive=True,
buttons=["copy"],
lines=8,
max_lines=8,
visible=True
)
with gr.Column(visible=False) as audio_col_3:
generated_audio_3 = gr.Audio(
label=t("results.generated_music", n=3),
type="filepath",
interactive=False,
buttons=[]
)
with gr.Row(equal_height=True):
send_to_src_btn_3 = gr.Button(
t("results.send_to_src_btn"),
variant="secondary",
size="sm",
scale=1
)
save_btn_3 = gr.Button(
t("results.save_btn"),
variant="primary",
size="sm",
scale=1
)
score_btn_3 = gr.Button(
t("results.score_btn"),
variant="secondary",
size="sm",
scale=1
)
lrc_btn_3 = gr.Button(
t("results.lrc_btn"),
variant="secondary",
size="sm",
scale=1
)
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_3:
codes_display_3 = gr.Textbox(
label=t("results.codes_label", n=3),
interactive=False,
buttons=["copy"],
lines=4,
max_lines=4,
visible=True
)
score_display_3 = gr.Textbox(
label=t("results.quality_score_label", n=3),
interactive=False,
buttons=["copy"],
lines=6,
max_lines=6,
visible=True
)
lrc_display_3 = gr.Textbox(
label=t("results.lrc_label", n=3),
interactive=True,
buttons=["copy"],
lines=8,
max_lines=8,
visible=True
)
with gr.Column(visible=False) as audio_col_4:
generated_audio_4 = gr.Audio(
label=t("results.generated_music", n=4),
type="filepath",
interactive=False,
buttons=[]
)
with gr.Row(equal_height=True):
send_to_src_btn_4 = gr.Button(
t("results.send_to_src_btn"),
variant="secondary",
size="sm",
scale=1
)
save_btn_4 = gr.Button(
t("results.save_btn"),
variant="primary",
size="sm",
scale=1
)
score_btn_4 = gr.Button(
t("results.score_btn"),
variant="secondary",
size="sm",
scale=1
)
lrc_btn_4 = gr.Button(
t("results.lrc_btn"),
variant="secondary",
size="sm",
scale=1
)
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_4:
codes_display_4 = gr.Textbox(
label=t("results.codes_label", n=4),
interactive=False,
buttons=["copy"],
lines=4,
max_lines=4,
visible=True
)
score_display_4 = gr.Textbox(
label=t("results.quality_score_label", n=4),
interactive=False,
buttons=["copy"],
lines=6,
max_lines=6,
visible=True
)
lrc_display_4 = gr.Textbox(
label=t("results.lrc_label", n=4),
interactive=True,
buttons=["copy"],
lines=8,
max_lines=8,
visible=True
)
cols_1_4 = []
for i in range(1, 5):
cols_1_4.append(_create_audio_column(i, visible=(i <= 2)))
# Second row for batch size 5-8 (initially hidden)
# Row 2: samples 5-8 (initially hidden)
with gr.Row(visible=False) as audio_row_5_8:
with gr.Column() as audio_col_5:
generated_audio_5 = gr.Audio(
label=t("results.generated_music", n=5),
type="filepath",
interactive=False,
buttons=[]
)
with gr.Row(equal_height=True):
send_to_src_btn_5 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
save_btn_5 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
score_btn_5 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
lrc_btn_5 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_5:
codes_display_5 = gr.Textbox(
label=t("results.codes_label", n=5),
interactive=False,
buttons=["copy"],
lines=4,
max_lines=4,
visible=True
)
score_display_5 = gr.Textbox(
label=t("results.quality_score_label", n=5),
interactive=False,
buttons=["copy"],
lines=6,
max_lines=6,
visible=True
)
lrc_display_5 = gr.Textbox(
label=t("results.lrc_label", n=5),
interactive=True,
buttons=["copy"],
lines=8,
max_lines=8,
visible=True
)
with gr.Column() as audio_col_6:
generated_audio_6 = gr.Audio(
label=t("results.generated_music", n=6),
type="filepath",
interactive=False,
buttons=[]
)
with gr.Row(equal_height=True):
send_to_src_btn_6 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
save_btn_6 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
score_btn_6 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
lrc_btn_6 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_6:
codes_display_6 = gr.Textbox(
label=t("results.codes_label", n=6),
interactive=False,
buttons=["copy"],
lines=4,
max_lines=4,
visible=True
)
score_display_6 = gr.Textbox(
label=t("results.quality_score_label", n=6),
interactive=False,
buttons=["copy"],
lines=6,
max_lines=6,
visible=True
)
lrc_display_6 = gr.Textbox(
label=t("results.lrc_label", n=6),
interactive=True,
buttons=["copy"],
lines=8,
max_lines=8,
visible=True
)
with gr.Column() as audio_col_7:
generated_audio_7 = gr.Audio(
label=t("results.generated_music", n=7),
type="filepath",
interactive=False,
buttons=[]
)
with gr.Row(equal_height=True):
send_to_src_btn_7 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
save_btn_7 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
score_btn_7 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
lrc_btn_7 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_7:
codes_display_7 = gr.Textbox(
label=t("results.codes_label", n=7),
interactive=False,
buttons=["copy"],
lines=4,
max_lines=4,
visible=True
)
score_display_7 = gr.Textbox(
label=t("results.quality_score_label", n=7),
interactive=False,
buttons=["copy"],
lines=6,
max_lines=6,
visible=True
)
lrc_display_7 = gr.Textbox(
label=t("results.lrc_label", n=7),
interactive=True,
buttons=["copy"],
lines=8,
max_lines=8,
visible=True
)
with gr.Column() as audio_col_8:
generated_audio_8 = gr.Audio(
label=t("results.generated_music", n=8),
type="filepath",
interactive=False,
buttons=[]
)
with gr.Row(equal_height=True):
send_to_src_btn_8 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
save_btn_8 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
score_btn_8 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
lrc_btn_8 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_8:
codes_display_8 = gr.Textbox(
label=t("results.codes_label", n=8),
interactive=False,
buttons=["copy"],
lines=4,
max_lines=4,
visible=True
)
score_display_8 = gr.Textbox(
label=t("results.quality_score_label", n=8),
interactive=False,
buttons=["copy"],
lines=6,
max_lines=6,
visible=True
)
lrc_display_8 = gr.Textbox(
label=t("results.lrc_label", n=8),
interactive=True,
buttons=["copy"],
lines=8,
max_lines=8,
visible=True
)
cols_5_8 = []
for i in range(5, 9):
cols_5_8.append(_create_audio_column(i, visible=True))
all_cols = cols_1_4 + cols_5_8
status_output = gr.Textbox(label=t("results.generation_status"), interactive=False)
@ -410,48 +129,37 @@ def create_results_section(dit_handler) -> dict:
with gr.Row(equal_height=True):
prev_batch_btn = gr.Button(
t("results.prev_btn"),
variant="secondary",
interactive=False,
scale=1,
size="sm"
variant="secondary", interactive=False, scale=1, size="sm"
)
batch_indicator = gr.Textbox(
label=t("results.current_batch"),
value=t("results.batch_indicator", current=1, total=1),
interactive=False,
scale=3
interactive=False, scale=3
)
next_batch_status = gr.Textbox(
label=t("results.next_batch_status"),
value="",
interactive=False,
scale=3
value="", interactive=False, scale=3
)
next_batch_btn = gr.Button(
t("results.next_btn"),
variant="primary",
interactive=False,
scale=1,
size="sm"
variant="primary", interactive=False, scale=1, size="sm"
)
# One-click restore parameters button
restore_params_btn = gr.Button(
t("results.restore_params_btn"),
variant="secondary",
interactive=False, # Initially disabled, enabled after generation
size="sm"
variant="secondary", interactive=False, size="sm"
)
with gr.Accordion(t("results.batch_results_title"), open=True):
generated_audio_batch = gr.File(
label=t("results.all_files_label"),
file_count="multiple",
interactive=False
file_count="multiple", interactive=False
)
generation_info = gr.Markdown(label=t("results.generation_details"))
return {
# Build return dict from all_cols
result = {
"lm_metadata_state": lm_metadata_state,
"is_format_caption_state": is_format_caption_state,
"current_batch_index": current_batch_index,
@ -465,88 +173,25 @@ def create_results_section(dit_handler) -> dict:
"next_batch_btn": next_batch_btn,
"next_batch_status": next_batch_status,
"restore_params_btn": restore_params_btn,
"generated_audio_1": generated_audio_1,
"generated_audio_2": generated_audio_2,
"generated_audio_3": generated_audio_3,
"generated_audio_4": generated_audio_4,
"generated_audio_5": generated_audio_5,
"generated_audio_6": generated_audio_6,
"generated_audio_7": generated_audio_7,
"generated_audio_8": generated_audio_8,
"audio_row_5_8": audio_row_5_8,
"audio_col_1": audio_col_1,
"audio_col_2": audio_col_2,
"audio_col_3": audio_col_3,
"audio_col_4": audio_col_4,
"audio_col_5": audio_col_5,
"audio_col_6": audio_col_6,
"audio_col_7": audio_col_7,
"audio_col_8": audio_col_8,
"send_to_src_btn_1": send_to_src_btn_1,
"send_to_src_btn_2": send_to_src_btn_2,
"send_to_src_btn_3": send_to_src_btn_3,
"send_to_src_btn_4": send_to_src_btn_4,
"send_to_src_btn_5": send_to_src_btn_5,
"send_to_src_btn_6": send_to_src_btn_6,
"send_to_src_btn_7": send_to_src_btn_7,
"send_to_src_btn_8": send_to_src_btn_8,
"save_btn_1": save_btn_1,
"save_btn_2": save_btn_2,
"save_btn_3": save_btn_3,
"save_btn_4": save_btn_4,
"save_btn_5": save_btn_5,
"save_btn_6": save_btn_6,
"save_btn_7": save_btn_7,
"save_btn_8": save_btn_8,
"score_btn_1": score_btn_1,
"score_btn_2": score_btn_2,
"score_btn_3": score_btn_3,
"score_btn_4": score_btn_4,
"score_btn_5": score_btn_5,
"score_btn_6": score_btn_6,
"score_btn_7": score_btn_7,
"score_btn_8": score_btn_8,
"score_display_1": score_display_1,
"score_display_2": score_display_2,
"score_display_3": score_display_3,
"score_display_4": score_display_4,
"score_display_5": score_display_5,
"score_display_6": score_display_6,
"score_display_7": score_display_7,
"score_display_8": score_display_8,
"codes_display_1": codes_display_1,
"codes_display_2": codes_display_2,
"codes_display_3": codes_display_3,
"codes_display_4": codes_display_4,
"codes_display_5": codes_display_5,
"codes_display_6": codes_display_6,
"codes_display_7": codes_display_7,
"codes_display_8": codes_display_8,
"lrc_btn_1": lrc_btn_1,
"lrc_btn_2": lrc_btn_2,
"lrc_btn_3": lrc_btn_3,
"lrc_btn_4": lrc_btn_4,
"lrc_btn_5": lrc_btn_5,
"lrc_btn_6": lrc_btn_6,
"lrc_btn_7": lrc_btn_7,
"lrc_btn_8": lrc_btn_8,
"lrc_display_1": lrc_display_1,
"lrc_display_2": lrc_display_2,
"lrc_display_3": lrc_display_3,
"lrc_display_4": lrc_display_4,
"lrc_display_5": lrc_display_5,
"lrc_display_6": lrc_display_6,
"lrc_display_7": lrc_display_7,
"lrc_display_8": lrc_display_8,
"details_accordion_1": details_accordion_1,
"details_accordion_2": details_accordion_2,
"details_accordion_3": details_accordion_3,
"details_accordion_4": details_accordion_4,
"details_accordion_5": details_accordion_5,
"details_accordion_6": details_accordion_6,
"details_accordion_7": details_accordion_7,
"details_accordion_8": details_accordion_8,
"generated_audio_batch": generated_audio_batch,
"generation_info": generation_info,
}
for idx, col_data in enumerate(all_cols, start=1):
result[f"generated_audio_{idx}"] = col_data["generated_audio"]
result[f"audio_col_{idx}"] = col_data["audio_col"]
result[f"send_to_remix_btn_{idx}"] = col_data["send_to_remix_btn"]
result[f"send_to_repaint_btn_{idx}"] = col_data["send_to_repaint_btn"]
result[f"save_btn_{idx}"] = col_data["save_btn"]
result[f"score_btn_{idx}"] = col_data["score_btn"]
result[f"score_display_{idx}"] = col_data["score_display"]
result[f"codes_display_{idx}"] = col_data["codes_display"]
result[f"convert_to_codes_btn_{idx}"] = col_data["convert_to_codes_btn"]
result[f"lrc_btn_{idx}"] = col_data["lrc_btn"]
result[f"lrc_display_{idx}"] = col_data["lrc_display"]
result[f"save_lrc_btn_{idx}"] = col_data["save_lrc_btn"]
result[f"lrc_download_file_{idx}"] = col_data["lrc_download_file"]
result[f"details_accordion_{idx}"] = col_data["details_accordion"]
return result
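The loop above replaces roughly a hundred hand-written `"name_N": name_N` entries by suffixing each per-column key with its 1-based index. Because the explicit assignments cover every key that `_create_audio_column` returns, the same flattening can be expressed generically. The sketch below assumes exactly that equivalence; `flatten_columns` is a hypothetical helper and strings stand in for Gradio components.

```python
from typing import Any, Dict, List


def flatten_columns(columns: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Turn [{'save_btn': x}, {'save_btn': y}] into
    {'save_btn_1': x, 'save_btn_2': y} using 1-based suffixes."""
    flat: Dict[str, Any] = {}
    for idx, column in enumerate(columns, start=1):
        for name, component in column.items():
            flat[f"{name}_{idx}"] = component
    return flat


# Strings stand in for Gradio components in this illustration.
columns = [
    {"generated_audio": "audio-a", "save_btn": "save-a"},
    {"generated_audio": "audio-b", "save_btn": "save-b"},
]
flat = flatten_columns(columns)
print(flat["save_btn_2"])
```

The explicit version in the diff has one advantage over this generic form: a key added to the column dict does not leak into the public result dict until someone opts in.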

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff


@ -0,0 +1,33 @@
# Native MLX implementation of the AceStep DiT decoder for Apple Silicon.
# Provides pure MLX inference with graceful fallback to PyTorch.

import logging
import platform

logger = logging.getLogger(__name__)


def is_mlx_available() -> bool:
    """Check if MLX is available on this platform (macOS + Apple Silicon)."""
    if platform.system() != "Darwin":
        return False
    try:
        import mlx.core as mx
        import mlx.nn
        # Verify we can actually create arrays (Metal backend works)
        _ = mx.array([1.0])
        mx.eval(_)
        return True
    except Exception:
        return False


_MLX_AVAILABLE = None


def mlx_available() -> bool:
    """Cached check for MLX availability."""
    global _MLX_AVAILABLE
    if _MLX_AVAILABLE is None:
        _MLX_AVAILABLE = is_mlx_available()
    return _MLX_AVAILABLE
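The module above caches the availability probe in a module-level global so that the import and Metal check run only once. The same pattern can be written with `functools.lru_cache`, shown here as a rough sketch rather than a proposed replacement. `select_backend` is a hypothetical function, and unlike the code above it only checks that the `mlx` package can be found; it does not create an array to confirm that the Metal backend works.

```python
import functools
import importlib.util
import platform


@functools.lru_cache(maxsize=1)
def select_backend() -> str:
    """Return "mlx" on macOS when the mlx package is importable,
    otherwise fall back to "pytorch". The result is computed once."""
    if platform.system() == "Darwin" and importlib.util.find_spec("mlx") is not None:
        return "mlx"
    return "pytorch"


first = select_backend()
second = select_backend()  # served from the cache, no second probe
print(first, select_backend.cache_info().hits)
```

`lru_cache` also gives `select_backend.cache_clear()` for tests, which the global-variable version would need to emulate by resetting `_MLX_AVAILABLE`.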


@ -0,0 +1,84 @@
# Weight conversion from PyTorch AceStep DiT decoder to native MLX format.

import logging
from typing import Dict, List, Tuple

import numpy as np

logger = logging.getLogger(__name__)


def convert_decoder_weights(
    pytorch_model,
) -> List[Tuple[str, "mx.array"]]:
    """Convert PyTorch decoder weights to a list of (name, mx.array) pairs
    suitable for ``mlx_decoder.load_weights()``.

    The function extracts weights from
    ``pytorch_model.decoder`` (``AceStepDiTModel``) and converts them to MLX
    format, handling:

    - Conv1d weight layout: PT ``[out, in, K]`` -> MLX ``[out, K, in]``
    - ConvTranspose1d layout: PT ``[in, out, K]`` -> MLX ``[out, K, in]``
    - nn.Sequential index remapping (Lambda wrappers removed in MLX)
    - All other weights are transferred as-is

    Args:
        pytorch_model: The full ``AceStepConditionGenerationModel`` (PyTorch).

    Returns:
        List of (param_name, mx.array) pairs ready for ``model.load_weights()``.
    """
    import mlx.core as mx

    decoder = pytorch_model.decoder
    state_dict = decoder.state_dict()
    weights: List[Tuple[str, "mx.array"]] = []

    for key, value in state_dict.items():
        np_val = value.detach().cpu().float().numpy()
        new_key = key

        # PyTorch proj_in is Sequential(Lambda, Conv1d, Lambda)
        # The Conv1d is at index 1. In MLX we use a bare Conv1d.
        if key.startswith("proj_in.1."):
            new_key = key.replace("proj_in.1.", "proj_in.")
            if new_key.endswith(".weight"):
                # PT Conv1d weight: [out, in, K] -> MLX: [out, K, in]
                np_val = np_val.swapaxes(1, 2)

        # PyTorch proj_out is Sequential(Lambda, ConvTranspose1d, Lambda)
        elif key.startswith("proj_out.1."):
            new_key = key.replace("proj_out.1.", "proj_out.")
            if new_key.endswith(".weight"):
                # PT ConvTranspose1d weight: [in, out, K] -> MLX: [out, K, in]
                np_val = np_val.transpose(1, 2, 0)

        # Skip rotary embedding buffers (recomputed in MLX)
        elif "rotary_emb" in key:
            continue

        weights.append((new_key, mx.array(np_val)))

    logger.info(
        "[MLX-DiT] Converted %d decoder parameters to MLX format.", len(weights)
    )
    return weights


def convert_and_load(
    pytorch_model,
    mlx_decoder: "MLXDiTDecoder",
) -> None:
    """Convert PyTorch decoder weights and load them into an MLX decoder.

    Args:
        pytorch_model: The full AceStepConditionGenerationModel (PyTorch).
        mlx_decoder: An instance of ``MLXDiTDecoder`` (already constructed).
    """
    import mlx.core as mx

    weights = convert_decoder_weights(pytorch_model)
    mlx_decoder.load_weights(weights)
    mx.eval(mlx_decoder.parameters())
    logger.info("[MLX-DiT] Weights loaded and evaluated successfully.")
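The two axis permutations in `convert_decoder_weights` are easy to get wrong, so it can help to check them on a tiny array where every element is identifiable. The sketch below uses only NumPy and needs neither PyTorch nor MLX. The helper names and the toy shapes are made up for illustration; the layouts themselves are the ones stated in the docstring above.

```python
import numpy as np


def conv1d_pt_to_mlx(weight: np.ndarray) -> np.ndarray:
    """PyTorch Conv1d [out, in, K] -> MLX [out, K, in]."""
    return np.swapaxes(weight, 1, 2)


def conv_transpose1d_pt_to_mlx(weight: np.ndarray) -> np.ndarray:
    """PyTorch ConvTranspose1d [in, out, K] -> MLX [out, K, in]."""
    return np.transpose(weight, (1, 2, 0))


out_ch, in_ch, kernel = 4, 3, 2

# Conv1d: element (o, i, k) must end up at (o, k, i).
w_conv = np.arange(out_ch * in_ch * kernel, dtype=np.float32).reshape(out_ch, in_ch, kernel)
w_conv_mlx = conv1d_pt_to_mlx(w_conv)

# ConvTranspose1d: element (i, o, k) must end up at (o, k, i).
w_deconv = np.arange(in_ch * out_ch * kernel, dtype=np.float32).reshape(in_ch, out_ch, kernel)
w_deconv_mlx = conv_transpose1d_pt_to_mlx(w_deconv)

print(w_conv_mlx.shape, w_deconv_mlx.shape)
```

Both results have shape `(out, K, in)`, which is why a single MLX-side layout can serve both layer types after conversion.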

acestep/mlx_dit/generate.py Normal file

@ -0,0 +1,213 @@
# MLX diffusion generation loop for AceStep DiT decoder.
#
# Replicates the timestep scheduling and ODE/SDE stepping from
# ``AceStepConditionGenerationModel.generate_audio`` using pure MLX arrays.
import logging
import time
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
logger = logging.getLogger(__name__)
# Pre-defined timestep schedules (from modeling_acestep_v15_turbo.py)
VALID_SHIFTS = [1.0, 2.0, 3.0]
VALID_TIMESTEPS = [
1.0, 0.9545454545454546, 0.9333333333333333, 0.9, 0.875,
0.8571428571428571, 0.8333333333333334, 0.7692307692307693, 0.75,
0.6666666666666666, 0.6428571428571429, 0.625, 0.5454545454545454,
0.5, 0.4, 0.375, 0.3, 0.25, 0.2222222222222222, 0.125,
]
SHIFT_TIMESTEPS = {
1.0: [1.0, 0.875, 0.75, 0.625, 0.5, 0.375, 0.25, 0.125],
2.0: [1.0, 0.9333333333333333, 0.8571428571428571, 0.7692307692307693,
0.6666666666666666, 0.5454545454545454, 0.4, 0.2222222222222222],
3.0: [1.0, 0.9545454545454546, 0.9, 0.8333333333333334, 0.75,
0.6428571428571429, 0.5, 0.3],
}
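# Illustrative sketch, not part of the original module: the preset tables
# above are numerically consistent with applying the shift mapping
# t' = s * t / (1 + (s - 1) * t) to a uniform 8-step grid. This is an
# observation about the numbers, not something the source states, and the
# helper name `shifted_timestep` is hypothetical.

```python
def shifted_timestep(t: float, shift: float) -> float:
    """Map a uniform timestep t in (0, 1] through the shift s."""
    return shift * t / (1.0 + (shift - 1.0) * t)


# Uniform grid 1.0, 0.875, ..., 0.125 (the shift=1.0 row of the table).
base_grid = [1.0 - i / 8 for i in range(8)]
shift2 = [shifted_timestep(t, 2.0) for t in base_grid]
shift3 = [shifted_timestep(t, 3.0) for t in base_grid]
```

# Comparing `shift2` and `shift3` with the 2.0 and 3.0 rows above is a quick
# way to sanity-check the table or to reason about what other shifts would do.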
def get_timestep_schedule(
    shift: float = 3.0,
    timesteps: Optional[list] = None,
) -> List[float]:
    """Compute the timestep schedule for diffusion sampling.

    Replicates the logic from the turbo model's ``generate_audio`` method.

    Args:
        shift: Diffusion timestep shift; snapped to the nearest of 1, 2, or 3.
        timesteps: Optional custom list of timesteps. Each value is snapped
            to the nearest entry of ``VALID_TIMESTEPS``; at most 20 are kept.

    Returns:
        List of timestep values without a trailing 0. Preset schedules are
        descending; custom schedules keep the caller's order.
    """
    t_schedule_list = None
    if timesteps is not None:
        ts_list = list(timesteps)
        # Remove trailing zeros
        while ts_list and ts_list[-1] == 0:
            ts_list.pop()
        if len(ts_list) < 1:
            logger.warning("timesteps empty after removing zeros; using default shift=%s", shift)
        else:
            if len(ts_list) > 20:
                logger.warning("timesteps length=%d > 20; truncating", len(ts_list))
                ts_list = ts_list[:20]
            # Map each timestep to the nearest valid value
            mapped = [min(VALID_TIMESTEPS, key=lambda x, t=t: abs(x - t)) for t in ts_list]
            t_schedule_list = mapped
    if t_schedule_list is None:
        original_shift = shift
        shift = min(VALID_SHIFTS, key=lambda x: abs(x - shift))
        if original_shift != shift:
            logger.warning("shift=%.2f rounded to nearest valid shift=%.1f", original_shift, shift)
        # Return a copy so callers cannot mutate the module-level table.
        t_schedule_list = list(SHIFT_TIMESTEPS[shift])
    return t_schedule_list
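# Illustrative sketch, not part of the original module: how custom timesteps
# are cleaned up before use (trailing zeros dropped, each value snapped to the
# nearest allowed entry). `snap_timesteps` and `toy_valid` are hypothetical
# stand-ins; the real code snaps against VALID_TIMESTEPS.

```python
def snap_timesteps(timesteps, valid):
    """Drop trailing zeros, then snap each value to its nearest valid entry."""
    ts = list(timesteps)
    while ts and ts[-1] == 0:
        ts.pop()
    # `t=t` binds the loop value explicitly; the lambda runs immediately
    # inside min(), so this mainly keeps linters quiet.
    return [min(valid, key=lambda v, t=t: abs(v - t)) for t in ts]


toy_valid = [1.0, 0.75, 0.5, 0.25]  # toy table, not VALID_TIMESTEPS
snapped = snap_timesteps([0.97, 0.6, 0.26, 0.0], toy_valid)
all_zero = snap_timesteps([0.0, 0.0], toy_valid)
```

# An input that is empty after trimming (like `all_zero`) is the case where
# the function above falls back to the preset schedule for `shift`.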
def mlx_generate_diffusion(
mlx_decoder,
encoder_hidden_states_np: np.ndarray,
context_latents_np: np.ndarray,
src_latents_shape: Tuple[int, ...],
seed: Optional[Union[int, List[int]]] = None,
infer_method: str = "ode",
shift: float = 3.0,
timesteps: Optional[list] = None,
audio_cover_strength: float = 1.0,
encoder_hidden_states_non_cover_np: Optional[np.ndarray] = None,
context_latents_non_cover_np: Optional[np.ndarray] = None,
) -> Dict[str, object]:
"""Run the complete MLX diffusion loop.
This is the core generation function. It accepts numpy arrays (converted
from PyTorch tensors by the handler) and returns numpy arrays that the
handler converts back to PyTorch.
Args:
mlx_decoder: ``MLXDiTDecoder`` instance with loaded weights.
encoder_hidden_states_np: [B, enc_L, D] from prepare_condition (numpy).
context_latents_np: [B, T, C] from prepare_condition (numpy).
src_latents_shape: shape tuple [B, T, 64] for noise generation.
seed: random seed (int, list[int], or None).
infer_method: "ode" or "sde".
shift: timestep shift factor.
timesteps: optional custom timestep list.
audio_cover_strength: cover strength (0-1).
encoder_hidden_states_non_cover_np: optional [B, enc_L, D] for non-cover.
context_latents_non_cover_np: optional [B, T, C] for non-cover.
Returns:
Dict with ``"target_latents"`` (numpy) and ``"time_costs"`` dict.
"""
import mlx.core as mx
from .model import MLXCrossAttentionCache
time_costs = {}
total_start = time.time()
# Convert numpy arrays to MLX
enc_hs = mx.array(encoder_hidden_states_np)
ctx = mx.array(context_latents_np)
enc_hs_nc = mx.array(encoder_hidden_states_non_cover_np) if encoder_hidden_states_non_cover_np is not None else None
ctx_nc = mx.array(context_latents_non_cover_np) if context_latents_non_cover_np is not None else None
bsz = src_latents_shape[0]
T = src_latents_shape[1]
C = src_latents_shape[2]
# ---- Noise preparation ----
if seed is None:
noise = mx.random.normal((bsz, T, C))
elif isinstance(seed, list):
parts = []
for s in seed:
if s is None or s < 0:
parts.append(mx.random.normal((1, T, C)))
else:
key = mx.random.key(int(s))
parts.append(mx.random.normal((1, T, C), key=key))
noise = mx.concatenate(parts, axis=0)
else:
key = mx.random.key(int(seed))
noise = mx.random.normal((bsz, T, C), key=key)
# ---- Timestep schedule ----
t_schedule_list = get_timestep_schedule(shift, timesteps)
num_steps = len(t_schedule_list)
cover_steps = int(num_steps * audio_cover_strength)
# ---- Diffusion loop ----
cache = MLXCrossAttentionCache()
xt = noise
diff_start = time.time()
for step_idx in range(num_steps):
current_t = t_schedule_list[step_idx]
t_curr = mx.full((bsz,), current_t)
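        # Switch to non-cover conditions once, at the first step past the
        # cover window. The identity check keeps later steps from rebuilding
        # the cache, which would recompute cross-attention K/V on every step.
        if (
            enc_hs_nc is not None
            and step_idx >= cover_steps
            and enc_hs is not enc_hs_nc
        ):
            enc_hs = enc_hs_nc
            ctx = ctx_nc
            # The cached cross-attention K/V belong to the cover conditions.
            cache = MLXCrossAttentionCache()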
vt, cache = mlx_decoder(
hidden_states=xt,
timestep=t_curr,
timestep_r=t_curr,
encoder_hidden_states=enc_hs,
context_latents=ctx,
cache=cache,
use_cache=True,
)
# Evaluate to ensure computation is complete before next step
mx.eval(vt)
# Final step: compute x0
if step_idx == num_steps - 1:
t_unsq = mx.expand_dims(mx.expand_dims(t_curr, axis=-1), axis=-1)
xt = xt - vt * t_unsq
mx.eval(xt)
break
# ODE / SDE update
next_t = t_schedule_list[step_idx + 1]
if infer_method == "sde":
t_unsq = mx.expand_dims(mx.expand_dims(t_curr, axis=-1), axis=-1)
pred_clean = xt - vt * t_unsq
# Re-noise with next timestep
new_noise = mx.random.normal(xt.shape)
xt = next_t * new_noise + (1.0 - next_t) * pred_clean
else:
            # ODE Euler step toward t=0: x_next = x_t - v_t * (t - t_next)
dt = current_t - next_t
dt_arr = mx.full((bsz, 1, 1), dt)
xt = xt - vt * dt_arr
mx.eval(xt)
diff_end = time.time()
total_end = time.time()
time_costs["diffusion_time_cost"] = diff_end - diff_start
time_costs["diffusion_per_step_time_cost"] = time_costs["diffusion_time_cost"] / max(num_steps, 1)
time_costs["total_time_cost"] = total_end - total_start
# Convert result back to numpy
result_np = np.array(xt)
return {
"target_latents": result_np,
"time_costs": time_costs,
}
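# Illustrative sketch, not part of the original module: the ODE branch of the
# loop above on plain Python floats. It assumes the straight-line path
# x_t = t * noise + (1 - t) * x0 (the same convention the SDE branch uses when
# re-noising), for which the ideal velocity is the constant noise - x0.
# `euler_sample` and the toy numbers are hypothetical.

```python
def euler_sample(noise, velocity_fn, schedule):
    """Integrate from t=schedule[0] down to t=0 with Euler steps."""
    x = noise
    for i, t in enumerate(schedule):
        v = velocity_fn(x, t)
        if i == len(schedule) - 1:
            # Final step: jump straight to t = 0 (predicted clean sample).
            return x - v * t
        # Intermediate step: move by dt = t - t_next.
        x = x - v * (t - schedule[i + 1])
    return x


x0_true, noise = 2.0, -1.0
result = euler_sample(
    noise,
    lambda x, t: noise - x0_true,  # ideal constant velocity for this path
    [1.0, 0.75, 0.5, 0.3],
)
```

# With an exact velocity the step sizes sum to 1, so `result` recovers
# `x0_true` regardless of how the schedule is spaced.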

629
acestep/mlx_dit/model.py Normal file

@ -0,0 +1,629 @@
# This module re-implements the diffusion transformer decoder from
# modeling_acestep_v15_turbo.py using pure MLX operations for optimal
# performance on Apple Silicon.
import math
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
# ---------------------------------------------------------------------------
# Utility helpers
# ---------------------------------------------------------------------------
def _rotate_half(x: mx.array) -> mx.array:
"""Rotate the last dimension by splitting in half and swapping with negation."""
half = x.shape[-1] // 2
x1 = x[..., :half]
x2 = x[..., half:]
return mx.concatenate([-x2, x1], axis=-1)
def _apply_rotary_pos_emb(
q: mx.array, k: mx.array, cos: mx.array, sin: mx.array
) -> Tuple[mx.array, mx.array]:
"""Apply rotary position embeddings to query and key tensors.
Args:
q, k: [B, n_heads, L, head_dim]
cos, sin: [1, 1, L, head_dim]
"""
q_embed = (q * cos) + (_rotate_half(q) * sin)
k_embed = (k * cos) + (_rotate_half(k) * sin)
return q_embed, k_embed
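# Illustrative sketch, not part of the original module: a NumPy check of the
# two properties that make the rotate-half formulation a valid rotary
# embedding: it preserves vector norms, and the q.k dot product depends only
# on the relative offset between positions. The helper `rope` and the toy
# base/dimension are hypothetical.

```python
import numpy as np


def rotate_half(x: np.ndarray) -> np.ndarray:
    half = x.shape[-1] // 2
    return np.concatenate([-x[..., half:], x[..., :half]], axis=-1)


def rope(x: np.ndarray, pos: int, base: float = 10000.0) -> np.ndarray:
    """Rotate a single head vector x to position `pos`."""
    d = x.shape[-1]
    inv_freq = 1.0 / (base ** (np.arange(0, d, 2) / d))
    angles = np.concatenate([pos * inv_freq, pos * inv_freq])
    return x * np.cos(angles) + rotate_half(x) * np.sin(angles)


rng = np.random.default_rng(0)
q, k = rng.normal(size=8), rng.normal(size=8)

# Same relative offset (2) at two different absolute positions.
score_near = rope(q, 5) @ rope(k, 3)
score_far = rope(q, 12) @ rope(k, 10)
norm_before, norm_after = np.linalg.norm(q), np.linalg.norm(rope(q, 7))
```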
def _create_sliding_window_mask(
seq_len: int, window_size: int, dtype: mx.Dtype = mx.float32
) -> mx.array:
"""Create a bidirectional sliding-window additive attention mask.
Positions within ``window_size`` of each other get ``0``; all others
receive a large negative value (``-1e9``).
Returns:
[1, 1, seq_len, seq_len]
"""
indices = mx.arange(seq_len)
# diff[i, j] = |i - j|
diff = mx.abs(indices[:, None] - indices[None, :])
zeros = mx.zeros(diff.shape, dtype=dtype)
neginf = mx.full(diff.shape, -1e9, dtype=dtype)
mask = mx.where(diff <= window_size, zeros, neginf)
return mask[None, None, :, :] # [1, 1, L, L]
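# Illustrative sketch, not part of the original module: the same
# bidirectional window rule written with plain Python lists, so the mask
# pattern is easy to inspect. `sliding_window_mask_rows` is a hypothetical
# helper.

```python
def sliding_window_mask_rows(seq_len: int, window: int, neg: float = -1e9):
    """Return a seq_len x seq_len additive mask as nested lists."""
    return [
        [0.0 if abs(i - j) <= window else neg for j in range(seq_len)]
        for i in range(seq_len)
    ]


mask_rows = sliding_window_mask_rows(5, 1)
```

# Row i allows positions i-1, i, and i+1 (where they exist); everything else
# gets the large negative value that softmax turns into ~0 attention weight.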
# ---------------------------------------------------------------------------
# Rotary Position Embedding
# ---------------------------------------------------------------------------
class MLXRotaryEmbedding(nn.Module):
"""Pre-computes and caches cos/sin tables for rotary position embeddings."""
def __init__(self, head_dim: int, max_len: int = 32768, base: float = 1_000_000.0):
super().__init__()
self.head_dim = head_dim
self.max_len = max_len
self.base = base
inv_freq = 1.0 / (
base ** (mx.arange(0, head_dim, 2).astype(mx.float32) / head_dim)
)
positions = mx.arange(max_len).astype(mx.float32)
freqs = positions[:, None] * inv_freq[None, :] # [max_len, head_dim//2]
freqs = mx.concatenate([freqs, freqs], axis=-1) # [max_len, head_dim]
self._cos = mx.cos(freqs) # [max_len, head_dim]
self._sin = mx.sin(freqs) # [max_len, head_dim]
def __call__(self, seq_len: int) -> Tuple[mx.array, mx.array]:
"""Return (cos, sin) each shaped [1, 1, seq_len, head_dim]."""
cos = self._cos[:seq_len][None, None, :, :]
sin = self._sin[:seq_len][None, None, :, :]
return cos, sin
# ---------------------------------------------------------------------------
# Cross-Attention KV Cache
# ---------------------------------------------------------------------------
class MLXCrossAttentionCache:
"""Simple KV cache for cross-attention layers.
Cross-attention K/V are computed from encoder hidden states once on the
first diffusion step and re-used for all subsequent steps.
"""
def __init__(self):
self._keys: dict[int, mx.array] = {}
self._values: dict[int, mx.array] = {}
self._updated: set[int] = set()
def update(self, key: mx.array, value: mx.array, layer_idx: int):
self._keys[layer_idx] = key
self._values[layer_idx] = value
self._updated.add(layer_idx)
def is_updated(self, layer_idx: int) -> bool:
return layer_idx in self._updated
def get(self, layer_idx: int) -> Tuple[mx.array, mx.array]:
return self._keys[layer_idx], self._values[layer_idx]
# ---------------------------------------------------------------------------
# Core Layers
# ---------------------------------------------------------------------------
class MLXSwiGLUMLP(nn.Module):
    """SwiGLU MLP (equivalent to Qwen3MLP): down_proj(silu(gate_proj(x)) * up_proj(x))."""
def __init__(self, hidden_size: int, intermediate_size: int):
super().__init__()
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
def __call__(self, x: mx.array) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class MLXAttention(nn.Module):
"""Multi-head attention with QK-RMSNorm for the AceStep DiT.
Supports both self-attention (with RoPE) and cross-attention (with
optional KV caching).
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
rms_norm_eps: float,
attention_bias: bool,
layer_idx: int,
is_cross_attention: bool = False,
sliding_window: Optional[int] = None,
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_attention_heads
self.num_kv_heads = num_key_value_heads
self.head_dim = head_dim
self.n_rep = num_attention_heads // num_key_value_heads
self.scale = head_dim ** -0.5
self.layer_idx = layer_idx
self.is_cross_attention = is_cross_attention
self.sliding_window = sliding_window
self.q_proj = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=attention_bias)
self.k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
self.v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
self.o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size, bias=attention_bias)
self.q_norm = nn.RMSNorm(head_dim, eps=rms_norm_eps)
self.k_norm = nn.RMSNorm(head_dim, eps=rms_norm_eps)
@staticmethod
def _repeat_kv(x: mx.array, n_rep: int) -> mx.array:
"""Repeat KV heads for GQA: [B, n_kv, L, D] -> [B, n_kv*n_rep, L, D]."""
if n_rep == 1:
return x
B, n_kv, L, D = x.shape
x = mx.expand_dims(x, axis=2) # [B, n_kv, 1, L, D]
x = mx.broadcast_to(x, (B, n_kv, n_rep, L, D))
return x.reshape(B, n_kv * n_rep, L, D)
def __call__(
self,
hidden_states: mx.array,
position_cos_sin: Optional[Tuple[mx.array, mx.array]] = None,
attention_mask: Optional[mx.array] = None,
encoder_hidden_states: Optional[mx.array] = None,
cache: Optional[MLXCrossAttentionCache] = None,
use_cache: bool = False,
) -> mx.array:
B, L, _ = hidden_states.shape
# Project queries (always from hidden_states)
q = self.q_proj(hidden_states)
q = self.q_norm(q.reshape(B, L, self.num_heads, self.head_dim))
q = q.transpose(0, 2, 1, 3) # [B, n_heads, L, D]
if self.is_cross_attention and encoder_hidden_states is not None:
# Cross-attention: K,V come from encoder
if cache is not None and cache.is_updated(self.layer_idx):
k, v = cache.get(self.layer_idx)
else:
enc_L = encoder_hidden_states.shape[1]
k = self.k_proj(encoder_hidden_states)
k = self.k_norm(k.reshape(B, enc_L, self.num_kv_heads, self.head_dim))
k = k.transpose(0, 2, 1, 3)
v = self.v_proj(encoder_hidden_states).reshape(
B, enc_L, self.num_kv_heads, self.head_dim
).transpose(0, 2, 1, 3)
if cache is not None and use_cache:
cache.update(k, v, self.layer_idx)
else:
# Self-attention: K,V come from hidden_states
k = self.k_proj(hidden_states)
k = self.k_norm(k.reshape(B, L, self.num_kv_heads, self.head_dim))
k = k.transpose(0, 2, 1, 3)
v = self.v_proj(hidden_states).reshape(
B, L, self.num_kv_heads, self.head_dim
).transpose(0, 2, 1, 3)
# Apply RoPE to self-attention Q,K
if position_cos_sin is not None:
cos, sin = position_cos_sin
q, k = _apply_rotary_pos_emb(q, k, cos, sin)
# GQA: repeat KV heads to match Q heads
k = self._repeat_kv(k, self.n_rep)
v = self._repeat_kv(v, self.n_rep)
# Scaled dot-product attention
attn_out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=self.scale, mask=attention_mask
)
# Merge heads and project output: [B, n_heads, L, D] -> [B, L, hidden]
attn_out = attn_out.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(attn_out)
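# Illustrative sketch, not part of the original module: the expand/broadcast/
# reshape pattern used by `_repeat_kv` above, reproduced in NumPy and compared
# against `np.repeat` to show the resulting head order. The toy shapes are
# hypothetical.

```python
import numpy as np


def repeat_kv(x: np.ndarray, n_rep: int) -> np.ndarray:
    """[B, n_kv, L, D] -> [B, n_kv * n_rep, L, D], each KV head repeated in place."""
    if n_rep == 1:
        return x
    B, n_kv, L, D = x.shape
    x = np.broadcast_to(x[:, :, None, :, :], (B, n_kv, n_rep, L, D))
    return x.reshape(B, n_kv * n_rep, L, D)


kv = np.arange(2 * 2 * 3 * 4).reshape(2, 2, 3, 4)  # B=2, n_kv=2, L=3, D=4
expanded = repeat_kv(kv, 2)
```

# Query heads 0 and 1 share KV head 0, heads 2 and 3 share KV head 1, which is
# the grouping that grouped-query attention expects.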
# ---------------------------------------------------------------------------
# DiT Layer
# ---------------------------------------------------------------------------
class MLXDiTLayer(nn.Module):
"""A single DiT transformer layer with AdaLN modulation.
Implements:
1. Self-attention with adaptive layer norm (AdaLN)
2. Cross-attention to encoder hidden states
3. Feed-forward MLP with adaptive layer norm
"""
def __init__(
self,
hidden_size: int,
intermediate_size: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
rms_norm_eps: float,
attention_bias: bool,
layer_idx: int,
layer_type: str,
sliding_window: Optional[int] = None,
):
super().__init__()
self.layer_type = layer_type
sw = sliding_window if layer_type == "sliding_attention" else None
# 1. Self-attention
self.self_attn_norm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)
self.self_attn = MLXAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
layer_idx=layer_idx,
is_cross_attention=False,
sliding_window=sw,
)
# 2. Cross-attention
self.cross_attn_norm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)
self.cross_attn = MLXAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
layer_idx=layer_idx,
is_cross_attention=True,
)
# 3. MLP
self.mlp_norm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)
self.mlp = MLXSwiGLUMLP(hidden_size, intermediate_size)
# AdaLN modulation table (6 values: shift/scale/gate for self-attn & MLP)
self.scale_shift_table = mx.zeros((1, 6, hidden_size))
def __call__(
self,
hidden_states: mx.array,
position_cos_sin: Tuple[mx.array, mx.array],
temb: mx.array,
self_attn_mask: Optional[mx.array],
encoder_hidden_states: Optional[mx.array],
encoder_attention_mask: Optional[mx.array],
cache: Optional[MLXCrossAttentionCache] = None,
use_cache: bool = False,
) -> mx.array:
# AdaLN modulation from timestep embeddings
# scale_shift_table: [1, 6, D], temb: [B, 6, D]
modulation = self.scale_shift_table + temb # [B, 6, D]
parts = mx.split(modulation, 6, axis=1)
# Each part: [B, 1, D]
shift_msa, scale_msa, gate_msa = parts[0], parts[1], parts[2]
c_shift_msa, c_scale_msa, c_gate_msa = parts[3], parts[4], parts[5]
# Step 1: Self-attention with AdaLN
normed = self.self_attn_norm(hidden_states)
normed = normed * (1.0 + scale_msa) + shift_msa
attn_out = self.self_attn(
normed,
position_cos_sin=position_cos_sin,
attention_mask=self_attn_mask,
)
hidden_states = hidden_states + attn_out * gate_msa
# Step 2: Cross-attention
normed = self.cross_attn_norm(hidden_states)
cross_out = self.cross_attn(
normed,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
cache=cache,
use_cache=use_cache,
)
hidden_states = hidden_states + cross_out
# Step 3: MLP with AdaLN
normed = self.mlp_norm(hidden_states)
normed = normed * (1.0 + c_scale_msa) + c_shift_msa
ff_out = self.mlp(normed)
hidden_states = hidden_states + ff_out * c_gate_msa
return hidden_states
# ---------------------------------------------------------------------------
# Timestep Embedding
# ---------------------------------------------------------------------------
class MLXTimestepEmbedding(nn.Module):
"""Sinusoidal timestep embedding followed by MLP."""
def __init__(self, in_channels: int = 256, time_embed_dim: int = 2048, scale: float = 1000.0):
super().__init__()
self.in_channels = in_channels
self.scale = scale
self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
self.act1 = nn.SiLU()
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, bias=True)
self.act2 = nn.SiLU()
self.time_proj = nn.Linear(time_embed_dim, time_embed_dim * 6, bias=True)
def _sinusoidal_embedding(self, t: mx.array, dim: int, max_period: int = 10000) -> mx.array:
"""Create sinusoidal timestep embeddings.
Args:
t: 1-D array of shape [N]
dim: embedding dimension
Returns:
[N, dim]
"""
t = t * self.scale
half = dim // 2
freqs = mx.exp(
-math.log(max_period)
* mx.arange(half).astype(mx.float32) / half
)
args = t[:, None].astype(mx.float32) * freqs[None, :]
embedding = mx.concatenate([mx.cos(args), mx.sin(args)], axis=-1)
if dim % 2:
embedding = mx.concatenate(
[embedding, mx.zeros_like(embedding[:, :1])], axis=-1
)
return embedding
def __call__(self, t: mx.array) -> Tuple[mx.array, mx.array]:
"""
Args:
t: [B] timestep values
Returns:
temb: [B, D]
timestep_proj: [B, 6, D]
"""
t_freq = self._sinusoidal_embedding(t, self.in_channels)
temb = self.linear_1(t_freq.astype(t.dtype))
temb = self.act1(temb)
temb = self.linear_2(temb)
proj = self.time_proj(self.act2(temb)) # [B, D*6]
timestep_proj = proj.reshape(proj.shape[0], 6, -1) # [B, 6, D]
return temb, timestep_proj
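# Illustrative sketch, not part of the original module: the sinusoidal
# embedding from `_sinusoidal_embedding` for a single scalar timestep, using
# only the standard library. The function name is hypothetical; the formula
# follows the method above (scale, cos half first, then sin half, zero pad for
# odd dims).

```python
import math


def sinusoidal_embedding(t: float, dim: int, scale: float = 1000.0,
                         max_period: int = 10000) -> list:
    t = t * scale
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    emb = [math.cos(t * f) for f in freqs] + [math.sin(t * f) for f in freqs]
    if dim % 2:
        emb.append(0.0)  # pad odd dimensions with a trailing zero
    return emb


emb_zero = sinusoidal_embedding(0.0, 6)
emb_odd = sinusoidal_embedding(0.5, 5)
```

# At t = 0 every cosine is 1 and every sine is 0, which makes the layout of
# the two halves easy to see.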
# ---------------------------------------------------------------------------
# Full DiT Decoder
# ---------------------------------------------------------------------------
class MLXDiTDecoder(nn.Module):
"""Native MLX implementation of AceStepDiTModel (the diffusion transformer decoder).
Mirrors the PyTorch ``AceStepDiTModel`` class exactly:
- Patch-based input projection (Conv1d)
- Timestep conditioning via dual TimestepEmbedding
- N DiT transformer layers with self/cross-attention and AdaLN
- Patch-based output projection (ConvTranspose1d)
- Adaptive output layer norm
"""
def __init__(
self,
hidden_size: int = 2048,
intermediate_size: int = 6144,
num_hidden_layers: int = 24,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
head_dim: int = 128,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
in_channels: int = 192,
audio_acoustic_hidden_dim: int = 64,
patch_size: int = 2,
sliding_window: int = 128,
layer_types: Optional[list] = None,
rope_theta: float = 1_000_000.0,
max_position_embeddings: int = 32768,
):
super().__init__()
self.hidden_size = hidden_size
self.patch_size = patch_size
inner_dim = hidden_size
if layer_types is None:
layer_types = [
"sliding_attention" if bool((i + 1) % 2) else "full_attention"
for i in range(num_hidden_layers)
]
# Rotary position embeddings
self.rotary_emb = MLXRotaryEmbedding(
head_dim, max_len=max_position_embeddings, base=rope_theta
)
# Input projection: Conv1d patch embedding
# MLX Conv1d uses channels-last: [B, L, C] -> [B, L//stride, out_C]
self.proj_in = nn.Conv1d(
in_channels=in_channels,
out_channels=inner_dim,
kernel_size=patch_size,
stride=patch_size,
padding=0,
)
# Timestep embeddings (two: t and t-r)
self.time_embed = MLXTimestepEmbedding(in_channels=256, time_embed_dim=inner_dim)
self.time_embed_r = MLXTimestepEmbedding(in_channels=256, time_embed_dim=inner_dim)
# Condition embedder
self.condition_embedder = nn.Linear(inner_dim, inner_dim, bias=True)
# Transformer layers
self.layers = [
MLXDiTLayer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
layer_idx=i,
layer_type=layer_types[i],
sliding_window=sliding_window,
)
for i in range(num_hidden_layers)
]
# Output
self.norm_out = nn.RMSNorm(inner_dim, eps=rms_norm_eps)
self.proj_out = nn.ConvTranspose1d(
in_channels=inner_dim,
out_channels=audio_acoustic_hidden_dim,
kernel_size=patch_size,
stride=patch_size,
padding=0,
)
# Output adaptive layer norm modulation (2 values: shift, scale)
self.scale_shift_table = mx.zeros((1, 2, inner_dim))
        # Cache of sliding-window masks, built lazily per sequence length on first use
self._sliding_masks: dict[int, mx.array] = {}
self._sliding_window = sliding_window
self._layer_types = layer_types
def _get_sliding_mask(self, seq_len: int, dtype: mx.Dtype) -> mx.array:
if seq_len not in self._sliding_masks:
self._sliding_masks[seq_len] = _create_sliding_window_mask(
seq_len, self._sliding_window, dtype
)
return self._sliding_masks[seq_len]
def __call__(
self,
hidden_states: mx.array,
timestep: mx.array,
timestep_r: mx.array,
encoder_hidden_states: mx.array,
context_latents: mx.array,
cache: Optional[MLXCrossAttentionCache] = None,
use_cache: bool = True,
) -> Tuple[mx.array, Optional[MLXCrossAttentionCache]]:
"""
Args:
hidden_states: noisy latents [B, T, 64]
timestep: [B] current timestep
timestep_r: [B] reference timestep
encoder_hidden_states: [B, enc_L, D] from condition encoder
context_latents: [B, T, C_ctx] (src_latents + chunk_masks)
cache: cross-attention KV cache
use_cache: whether to cache cross-attention KV
Returns:
(output_hidden_states, cache)
"""
# Timestep embeddings
temb_t, proj_t = self.time_embed(timestep)
temb_r, proj_r = self.time_embed_r(timestep - timestep_r)
temb = temb_t + temb_r # [B, D]
timestep_proj = proj_t + proj_r # [B, 6, D]
# Concatenate context with hidden states: [B, T, C_ctx + 64] -> [B, T, in_channels]
hidden_states = mx.concatenate([context_latents, hidden_states], axis=-1)
original_seq_len = hidden_states.shape[1]
# Pad to multiple of patch_size
pad_length = 0
if hidden_states.shape[1] % self.patch_size != 0:
pad_length = self.patch_size - (hidden_states.shape[1] % self.patch_size)
# Pad along time dimension
padding = mx.zeros(
(hidden_states.shape[0], pad_length, hidden_states.shape[2]),
dtype=hidden_states.dtype,
)
hidden_states = mx.concatenate([hidden_states, padding], axis=1)
# Patch embedding: [B, T, in_ch] -> [B, T//patch, D]
hidden_states = self.proj_in(hidden_states)
# Project encoder states
encoder_hidden_states = self.condition_embedder(encoder_hidden_states)
seq_len = hidden_states.shape[1]
dtype = hidden_states.dtype
# Position embeddings (RoPE)
cos, sin = self.rotary_emb(seq_len)
# Attention masks
# Self-attention: full layers get None; sliding layers get windowed mask
# Cross-attention: always None (no masking)
sliding_mask = None
has_sliding = any(lt == "sliding_attention" for lt in self._layer_types)
if has_sliding:
sliding_mask = self._get_sliding_mask(seq_len, dtype)
# Process through transformer layers
for layer in self.layers:
self_attn_mask = sliding_mask if layer.layer_type == "sliding_attention" else None
hidden_states = layer(
hidden_states,
position_cos_sin=(cos, sin),
temb=timestep_proj,
self_attn_mask=self_attn_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=None,
cache=cache,
use_cache=use_cache,
)
# Output adaptive layer norm
shift, scale = mx.split(
self.scale_shift_table + mx.expand_dims(temb, axis=1), 2, axis=1
)
hidden_states = self.norm_out(hidden_states) * (1.0 + scale) + shift
# De-patchify: [B, T//patch, D] -> [B, T, out_channels]
hidden_states = self.proj_out(hidden_states)
# Crop back to original sequence length
hidden_states = hidden_states[:, :original_seq_len, :]
return hidden_states, cache
@classmethod
def from_config(cls, config) -> "MLXDiTDecoder":
"""Construct from an AceStepConfig (transformers PretrainedConfig)."""
return cls(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
num_hidden_layers=config.num_hidden_layers,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
head_dim=getattr(config, "head_dim", config.hidden_size // config.num_attention_heads),
rms_norm_eps=config.rms_norm_eps,
attention_bias=config.attention_bias,
in_channels=config.in_channels,
audio_acoustic_hidden_dim=config.audio_acoustic_hidden_dim,
patch_size=config.patch_size,
sliding_window=config.sliding_window if config.sliding_window else 128,
layer_types=config.layer_types,
rope_theta=config.rope_theta,
max_position_embeddings=config.max_position_embeddings,
)


@ -0,0 +1,33 @@
# Native MLX implementation of the Oobleck VAE (AutoencoderOobleck) for Apple Silicon.
# Provides pure MLX encode/decode with graceful fallback to PyTorch.
import logging
import platform
logger = logging.getLogger(__name__)
def is_mlx_available() -> bool:
    """Check if MLX is available on this platform (macOS + Apple Silicon)."""
    if platform.system() != "Darwin":
        return False
    try:
        import mlx.core as mx
        import mlx.nn  # noqa: F401  (import check only)

        # Verify we can actually create arrays (Metal backend works)
        probe = mx.array([1.0])
        mx.eval(probe)
        return True
    except Exception:
        return False


_MLX_AVAILABLE = None


def mlx_available() -> bool:
    """Cached check for MLX availability."""
    global _MLX_AVAILABLE
    if _MLX_AVAILABLE is None:
        _MLX_AVAILABLE = is_mlx_available()
    return _MLX_AVAILABLE

150
acestep/mlx_vae/convert.py Normal file

@ -0,0 +1,150 @@
# Weight conversion from PyTorch AutoencoderOobleck to native MLX format.
#
# Handles:
# - weight_norm fusion: weight_g + weight_v → fused weight
# - Conv1d axis swap: PT [out, in, K] → MLX [out, K, in]
# - ConvTranspose1d: PT [in, out, K] → MLX [out, K, in]
# - Snake1d parameters: PT [1, C, 1] → MLX [C]
# - Bias: no change
import logging
from typing import Dict, List, Tuple
import numpy as np
logger = logging.getLogger(__name__)
def _fuse_weight_norm(
weight_g: np.ndarray,
weight_v: np.ndarray,
eps: float = 1e-9,
) -> np.ndarray:
"""Fuse weight_norm parameters into a single weight tensor.
weight_norm decomposes ``w = g * v / ||v||`` where:
weight_g: per-output-channel scale [out, 1, 1] (Conv1d)
or [in, 1, 1] for ConvTranspose1d
weight_v: direction tensor with same shape as the original weight
Returns the fused weight in the *original PyTorch shape* (before axis swap).
"""
v_flat = weight_v.reshape(weight_v.shape[0], -1)
norm = np.linalg.norm(v_flat, axis=1).reshape(weight_g.shape)
return weight_g * weight_v / (norm + eps)
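# Illustrative sketch, not part of the original module: a NumPy check that
# fusing weight_norm parameters yields a weight whose per-channel norm along
# the first axis equals `weight_g`. The function body mirrors
# `_fuse_weight_norm` above; the toy shapes and values are hypothetical.

```python
import numpy as np


def fuse_weight_norm(weight_g: np.ndarray, weight_v: np.ndarray,
                     eps: float = 1e-9) -> np.ndarray:
    """w = g * v / ||v||, with the norm taken per slice along axis 0."""
    v_flat = weight_v.reshape(weight_v.shape[0], -1)
    norm = np.linalg.norm(v_flat, axis=1).reshape(weight_g.shape)
    return weight_g * weight_v / (norm + eps)


rng = np.random.default_rng(0)
v = rng.normal(size=(4, 3, 7))                         # [out, in, K]
g = np.array([1.0, 2.0, 0.5, 3.0]).reshape(4, 1, 1)    # [out, 1, 1]
w = fuse_weight_norm(g, v)
channel_norms = np.linalg.norm(w.reshape(4, -1), axis=1)
```

# The fused weight keeps the direction of `v` and takes its magnitude from
# `g`, so `channel_norms` matches `g` up to the small `eps` term.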
def convert_vae_weights(
pytorch_vae,
) -> List[Tuple[str, "mx.array"]]:
"""Convert PyTorch AutoencoderOobleck weights to MLX format.
The function extracts the state dict from ``pytorch_vae`` and converts
each parameter to the format expected by ``MLXAutoEncoderOobleck``,
handling weight_norm fusion, Conv axis reordering, and Snake1d
parameter reshaping.
Args:
pytorch_vae: A ``diffusers.AutoencoderOobleck`` instance.
Returns:
List of ``(param_name, mx.array)`` pairs for ``model.load_weights()``.
"""
import mlx.core as mx
state_dict = pytorch_vae.state_dict()
weights: List[Tuple[str, "mx.array"]] = []
processed: set = set()
# Sort keys so we process weight_v / weight_g pairs together
all_keys = sorted(state_dict.keys())
for key in all_keys:
if key in processed:
continue
# ------------------------------------------------------------------
        # 1) weight_norm fusion: *.weight_g + *.weight_v → *.weight
# ------------------------------------------------------------------
if key.endswith(".weight_g"):
# The companion weight_v key
base = key[: -len(".weight_g")]
v_key = base + ".weight_v"
if v_key not in state_dict:
logger.warning(
"[MLX-VAE] weight_g without weight_v: %s — skipping", key
)
processed.add(key)
continue
g = state_dict[key].detach().cpu().float().numpy()
v = state_dict[v_key].detach().cpu().float().numpy()
w = _fuse_weight_norm(g, v)
# Determine layer type and swap axes
if "conv_t1" in base:
# ConvTranspose1d: PT [in, out, K] → MLX [out, K, in]
w = w.transpose(1, 2, 0)
else:
# Conv1d: PT [out, in, K] → MLX [out, K, in]
w = w.swapaxes(1, 2)
new_key = base + ".weight"
weights.append((new_key, mx.array(w)))
processed.add(key)
processed.add(v_key)
continue
if key.endswith(".weight_v"):
# Will be / was handled together with weight_g above
continue
# ------------------------------------------------------------------
# 2) Snake1d parameters: PT [1, C, 1] → MLX [C]
# ------------------------------------------------------------------
if key.endswith(".alpha") or key.endswith(".beta"):
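            # reshape(-1) rather than squeeze(): squeeze() would return a
            # 0-d array when C == 1, which no longer matches the [C] shape.
            val = state_dict[key].detach().cpu().float().numpy().reshape(-1)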
weights.append((key, mx.array(val)))
processed.add(key)
continue
# ------------------------------------------------------------------
# 3) Bias: shape [C] — no transformation needed
# ------------------------------------------------------------------
if key.endswith(".bias"):
val = state_dict[key].detach().cpu().float().numpy()
weights.append((key, mx.array(val)))
processed.add(key)
continue
# ------------------------------------------------------------------
# 4) Catch-all for any remaining parameters (unlikely in this model)
# ------------------------------------------------------------------
val = state_dict[key].detach().cpu().float().numpy()
logger.debug("[MLX-VAE] Pass-through key: %s shape=%s", key, val.shape)
weights.append((key, mx.array(val)))
processed.add(key)
logger.info(
"[MLX-VAE] Converted %d parameters to MLX format.", len(weights)
)
return weights
def convert_and_load(
pytorch_vae,
mlx_vae: "MLXAutoEncoderOobleck",
) -> None:
"""Convert PyTorch VAE weights and load them into an MLX VAE.
Args:
pytorch_vae: ``diffusers.AutoencoderOobleck`` instance (PyTorch).
mlx_vae: An ``MLXAutoEncoderOobleck`` instance (already constructed).
"""
import mlx.core as mx
weights = convert_vae_weights(pytorch_vae)
mlx_vae.load_weights(weights)
mx.eval(mlx_vae.parameters())
logger.info("[MLX-VAE] Weights loaded and evaluated successfully.")

336
acestep/mlx_vae/model.py Normal file

@ -0,0 +1,336 @@
# Pure MLX re-implementation of diffusers' AutoencoderOobleck for Apple Silicon.
#
# Architecture mirrors the PyTorch version exactly:
# Snake1d -> OobleckResidualUnit -> EncoderBlock / DecoderBlock
# -> OobleckEncoder / OobleckDecoder -> MLXAutoEncoderOobleck
#
# All operations use MLX channels-last (NLC) convention internally.
# The public encode/decode API accepts and returns NLC arrays.
import math
import logging
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Snake1d Activation
# ---------------------------------------------------------------------------
class MLXSnake1d(nn.Module):
"""Snake activation: x + (1/beta) * sin(alpha * x)^2.
Parameters ``alpha`` and ``beta`` are stored as 1-D vectors of shape [C]
and broadcast over (B, L) automatically. When ``logscale=True`` (default)
the actual scale is ``exp(alpha)`` / ``exp(beta)``.
"""
def __init__(self, hidden_dim: int, logscale: bool = True):
super().__init__()
self.alpha = mx.zeros(hidden_dim)
self.beta = mx.zeros(hidden_dim)
self.logscale = logscale
def __call__(self, x: mx.array) -> mx.array:
# x: [B, L, C] (NLC)
        # NOTE: The active code below computes exp/sin/power in the dtype of
        # the stored parameters and the input. With float16 weights, exp()
        # overflows once alpha exceeds ~11, so a float32-upcast variant is
        # kept here (commented out) in case float16 weights are needed. The
        # surrounding Conv1d layers would still run in the caller's dtype
        # (float16) for speed.
        #
        #   alpha = mx.exp(self.alpha.astype(mx.float32)) if self.logscale else self.alpha
        #   beta = mx.exp(self.beta.astype(mx.float32)) if self.logscale else self.beta
        #   x_f32 = x.astype(mx.float32)
        #   result = x_f32 + mx.reciprocal(beta + 1e-9) * mx.power(mx.sin(alpha * x_f32), 2)
        #   return result.astype(x.dtype)
alpha = mx.exp(self.alpha) if self.logscale else self.alpha
beta = mx.exp(self.beta) if self.logscale else self.beta
# All ops broadcast [C] over [B, L, C]
return x + mx.reciprocal(beta + 1e-9) * mx.power(mx.sin(alpha * x), 2)
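The overflow note in `MLXSnake1d` can be checked with plain arithmetic. The following is a minimal scalar sketch, not the MLX implementation: the function name `snake` and the constant `FLOAT16_MAX` are hypothetical names introduced here, and only the standard library is used.

```python
import math

FLOAT16_MAX = 65504.0  # largest finite IEEE 754 half-precision value


def snake(x: float, alpha: float, beta: float, logscale: bool = True) -> float:
    """Scalar Snake activation: x + sin(a * x)^2 / (b + 1e-9)."""
    # With logscale=True the stored parameters are logarithms of the real scales.
    a = math.exp(alpha) if logscale else alpha
    b = math.exp(beta) if logscale else beta
    return x + math.sin(a * x) ** 2 / (b + 1e-9)


# Largest log-scale parameter whose exponential still fits in float16.
overflow_threshold = math.log(FLOAT16_MAX)

print(snake(0.5, 0.0, 0.0))
print(round(overflow_threshold, 2))
```

With zero-initialised parameters both scales are `exp(0) = 1`, so the activation reduces to `x + sin(x)^2`.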
# ---------------------------------------------------------------------------
# Residual Unit
# ---------------------------------------------------------------------------
class MLXOobleckResidualUnit(nn.Module):
"""Two weight-normalised Conv1d layers (k=7 dilated + k=1) wrapped with
Snake1d activations and a residual skip connection."""
def __init__(self, dimension: int = 16, dilation: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.snake1 = MLXSnake1d(dimension)
self.conv1 = nn.Conv1d(
dimension, dimension, kernel_size=7, dilation=dilation, padding=pad
)
self.snake2 = MLXSnake1d(dimension)
self.conv2 = nn.Conv1d(dimension, dimension, kernel_size=1)
def __call__(self, hidden_state: mx.array) -> mx.array:
# hidden_state: [B, L, C]
output = self.conv1(self.snake1(hidden_state))
output = self.conv2(self.snake2(output))
# Safety trim (should be no-op with correct padding)
padding = (hidden_state.shape[1] - output.shape[1]) // 2
if padding > 0:
hidden_state = hidden_state[:, padding:-padding, :]
return hidden_state + output
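The padding choice in `MLXOobleckResidualUnit` keeps the sequence length unchanged, which is why the safety trim is expected to be a no-op. A short sketch of the standard Conv1d output-length formula shows this for the three dilations used by the blocks. The helper names `conv1d_out_len` and `same_padding` are illustrative and not part of the repository.

```python
def conv1d_out_len(length: int, kernel_size: int, dilation: int = 1,
                   padding: int = 0, stride: int = 1) -> int:
    """Output length of a 1-D convolution with symmetric zero padding."""
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (length + 2 * padding - effective_kernel) // stride + 1


def same_padding(kernel_size: int, dilation: int) -> int:
    """Padding used by the residual unit: ((k - 1) * dilation) // 2."""
    return ((kernel_size - 1) * dilation) // 2


for dilation in (1, 3, 9):
    pad = same_padding(7, dilation)
    # The k=7 dilated conv preserves the length for every dilation.
    print(dilation, pad, conv1d_out_len(1000, 7, dilation, pad))
```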
# ---------------------------------------------------------------------------
# Encoder / Decoder Blocks
# ---------------------------------------------------------------------------
class MLXOobleckEncoderBlock(nn.Module):
"""3 residual units (dilations 1, 3, 9) -> Snake -> strided Conv1d down."""
def __init__(self, input_dim: int, output_dim: int, stride: int = 1):
super().__init__()
self.res_unit1 = MLXOobleckResidualUnit(input_dim, dilation=1)
self.res_unit2 = MLXOobleckResidualUnit(input_dim, dilation=3)
self.res_unit3 = MLXOobleckResidualUnit(input_dim, dilation=9)
self.snake1 = MLXSnake1d(input_dim)
self.conv1 = nn.Conv1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
)
def __call__(self, hidden_state: mx.array) -> mx.array:
hidden_state = self.res_unit1(hidden_state)
hidden_state = self.res_unit2(hidden_state)
hidden_state = self.snake1(self.res_unit3(hidden_state))
hidden_state = self.conv1(hidden_state)
return hidden_state
class MLXOobleckDecoderBlock(nn.Module):
"""Snake -> strided ConvTranspose1d up -> 3 residual units (dilations 1, 3, 9)."""
def __init__(self, input_dim: int, output_dim: int, stride: int = 1):
super().__init__()
self.snake1 = MLXSnake1d(input_dim)
self.conv_t1 = nn.ConvTranspose1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
)
self.res_unit1 = MLXOobleckResidualUnit(output_dim, dilation=1)
self.res_unit2 = MLXOobleckResidualUnit(output_dim, dilation=3)
self.res_unit3 = MLXOobleckResidualUnit(output_dim, dilation=9)
def __call__(self, hidden_state: mx.array) -> mx.array:
hidden_state = self.snake1(hidden_state)
hidden_state = self.conv_t1(hidden_state)
hidden_state = self.res_unit1(hidden_state)
hidden_state = self.res_unit2(hidden_state)
hidden_state = self.res_unit3(hidden_state)
return hidden_state
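The stride, kernel and padding choices in the encoder and decoder blocks can be sanity-checked by chaining the usual length formulas. This sketch assumes the standard Conv1d and ConvTranspose1d output-length formulas (no output padding, dilation 1) and uses the default ratios from the VAE docstring. `down_len` and `up_len` are hypothetical helper names.

```python
import math


def down_len(length: int, stride: int) -> int:
    """Length after the encoder block's strided Conv1d (kernel = 2 * stride)."""
    kernel, pad = 2 * stride, math.ceil(stride / 2)
    return (length + 2 * pad - kernel) // stride + 1


def up_len(length: int, stride: int) -> int:
    """Length after the decoder block's ConvTranspose1d (kernel = 2 * stride)."""
    kernel, pad = 2 * stride, math.ceil(stride / 2)
    return (length - 1) * stride - 2 * pad + kernel


ratios = [2, 4, 4, 8, 8]
hop_length = math.prod(ratios)

audio_len = hop_length * 10
n = audio_len
for s in ratios:
    n = down_len(n, s)
latent_len = n

# The decoder applies the same ratios in reverse order.
for s in reversed(ratios):
    n = up_len(n, s)
decoded_len = n

print(hop_length, latent_len, decoded_len)
```

For audio lengths that are multiples of the hop length, each latent frame corresponds to 2048 samples and decoding restores the original length.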
# ---------------------------------------------------------------------------
# Encoder / Decoder
# ---------------------------------------------------------------------------
class MLXOobleckEncoder(nn.Module):
"""Oobleck Encoder: Conv1d -> N encoder blocks -> Snake -> Conv1d."""
def __init__(
self,
encoder_hidden_size: int,
audio_channels: int,
downsampling_ratios: List[int],
channel_multiples: List[int],
):
super().__init__()
strides = downsampling_ratios
cm = [1] + list(channel_multiples)
self.conv1 = nn.Conv1d(
audio_channels, encoder_hidden_size, kernel_size=7, padding=3
)
self.block = []
for i, stride in enumerate(strides):
self.block.append(
MLXOobleckEncoderBlock(
input_dim=encoder_hidden_size * cm[i],
output_dim=encoder_hidden_size * cm[i + 1],
stride=stride,
)
)
d_model = encoder_hidden_size * cm[-1]
self.snake1 = MLXSnake1d(d_model)
self.conv2 = nn.Conv1d(d_model, encoder_hidden_size, kernel_size=3, padding=1)
def __call__(self, hidden_state: mx.array) -> mx.array:
hidden_state = self.conv1(hidden_state)
for module in self.block:
hidden_state = module(hidden_state)
hidden_state = self.snake1(hidden_state)
hidden_state = self.conv2(hidden_state)
return hidden_state
class MLXOobleckDecoder(nn.Module):
"""Oobleck Decoder: Conv1d -> N decoder blocks -> Snake -> Conv1d."""
def __init__(
self,
channels: int,
input_channels: int,
audio_channels: int,
upsampling_ratios: List[int],
channel_multiples: List[int],
):
super().__init__()
strides = upsampling_ratios
cm = [1] + list(channel_multiples)
self.conv1 = nn.Conv1d(
input_channels, channels * cm[-1], kernel_size=7, padding=3
)
self.block = []
for i, stride in enumerate(strides):
self.block.append(
MLXOobleckDecoderBlock(
input_dim=channels * cm[len(strides) - i],
output_dim=channels * cm[len(strides) - i - 1],
stride=stride,
)
)
self.snake1 = MLXSnake1d(channels)
self.conv2 = nn.Conv1d(
channels, audio_channels, kernel_size=7, padding=3, bias=False
)
def __call__(self, hidden_state: mx.array) -> mx.array:
hidden_state = self.conv1(hidden_state)
for layer in self.block:
hidden_state = layer(hidden_state)
hidden_state = self.snake1(hidden_state)
hidden_state = self.conv2(hidden_state)
return hidden_state
# ---------------------------------------------------------------------------
# Full VAE
# ---------------------------------------------------------------------------
class MLXAutoEncoderOobleck(nn.Module):
"""Pure-MLX re-implementation of ``diffusers.AutoencoderOobleck``.
Default configuration matches the Stable Audio / ACE-Step VAE:
encoder_hidden_size = 128
downsampling_ratios = [2, 4, 4, 8, 8] (hop_length = 2048)
channel_multiples = [1, 2, 4, 8, 16]
decoder_channels = 128
decoder_input_channels = 64 (latent dim)
audio_channels = 2 (stereo)
Data flows in NLC (batch, length, channels) format throughout.
"""
def __init__(
self,
encoder_hidden_size: int = 128,
downsampling_ratios: Optional[List[int]] = None,
channel_multiples: Optional[List[int]] = None,
decoder_channels: int = 128,
decoder_input_channels: int = 64,
audio_channels: int = 2,
):
super().__init__()
if downsampling_ratios is None:
downsampling_ratios = [2, 4, 4, 8, 8]
if channel_multiples is None:
channel_multiples = [1, 2, 4, 8, 16]
self.encoder_hidden_size = encoder_hidden_size
self.decoder_input_channels = decoder_input_channels
self.encoder = MLXOobleckEncoder(
encoder_hidden_size=encoder_hidden_size,
audio_channels=audio_channels,
downsampling_ratios=downsampling_ratios,
channel_multiples=channel_multiples,
)
self.decoder = MLXOobleckDecoder(
channels=decoder_channels,
input_channels=decoder_input_channels,
audio_channels=audio_channels,
upsampling_ratios=downsampling_ratios[::-1],
channel_multiples=channel_multiples,
)
# -- public API ---------------------------------------------------------
def encode_and_sample(self, audio_nlc: mx.array) -> mx.array:
"""Encode audio -> sample latent.
Args:
audio_nlc: [B, L_audio, C_audio] in NLC format.
Returns:
z: [B, L_latent, C_latent] sampled latent.
"""
h = self.encoder(audio_nlc) # [B, L', encoder_hidden_size]
# Diagonal Gaussian: split channels into mean and pre-softplus scale
mean, scale = mx.split(h, 2, axis=-1)
# softplus(scale) + epsilon (numerically stable)
std = mx.where(scale > 20.0, scale, mx.log(1.0 + mx.exp(scale))) + 1e-4
noise = mx.random.normal(mean.shape)
z = mean + std * noise
return z
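The sampling step in `encode_and_sample` is a reparameterised draw with a softplus-activated standard deviation. Below is a scalar sketch under stated assumptions: it uses `math.log1p` rather than `log(1 + exp(x))`, takes the threshold of 20 from the code above, and the names `stable_softplus` and `sample_latent` are illustrative only.

```python
import math
import random


def stable_softplus(x: float, threshold: float = 20.0) -> float:
    """softplus(x) = log(1 + exp(x)), returning x directly for large inputs."""
    if x > threshold:
        # For large x, softplus(x) is numerically equal to x and exp(x) may overflow.
        return x
    return math.log1p(math.exp(x))


def sample_latent(mean: float, scale: float, rng: random.Random) -> float:
    """Reparameterised sample z = mean + std * noise with std = softplus(scale) + 1e-4."""
    std = stable_softplus(scale) + 1e-4
    return mean + std * rng.gauss(0.0, 1.0)


rng = random.Random(0)
z = sample_latent(0.5, -2.0, rng)
print(stable_softplus(0.0), stable_softplus(1000.0), z)
```

In pure Python the guard is required because `math.exp(1000.0)` raises `OverflowError`; in an array library the same guard avoids propagating `inf`.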
def encode_mean(self, audio_nlc: mx.array) -> mx.array:
"""Encode audio -> return mean (no sampling noise)."""
h = self.encoder(audio_nlc)
mean, _scale = mx.split(h, 2, axis=-1)
return mean
def decode(self, latents_nlc: mx.array) -> mx.array:
"""Decode latents -> audio.
Args:
latents_nlc: [B, L_latent, C_latent] in NLC format.
Returns:
audio: [B, L_audio, C_audio] in NLC format.
"""
return self.decoder(latents_nlc)
# -- construction helpers -----------------------------------------------
@classmethod
def from_pytorch_config(cls, pt_vae) -> "MLXAutoEncoderOobleck":
"""Construct from a PyTorch ``AutoencoderOobleck`` instance's config."""
cfg = pt_vae.config
return cls(
encoder_hidden_size=cfg.encoder_hidden_size,
downsampling_ratios=list(cfg.downsampling_ratios),
channel_multiples=list(cfg.channel_multiples),
decoder_channels=cfg.decoder_channels,
decoder_input_channels=cfg.decoder_input_channels,
audio_channels=cfg.audio_channels,
)


@ -8,6 +8,8 @@ with intelligent fallback between download sources.
import os
import sys
import hashlib
import shutil
import argparse
from typing import Optional, List, Dict, Tuple
from pathlib import Path
@ -15,6 +17,118 @@ from pathlib import Path
from loguru import logger
# =============================================================================
# Model Code File Sync (GitHub repo -> checkpoint directories)
# =============================================================================
# Mapping from checkpoint directory name to source model variant in acestep/models/
_CHECKPOINT_TO_VARIANT: Dict[str, str] = {
"acestep-v15-turbo": "turbo",
"acestep-v15-sft": "sft",
"acestep-v15-base": "base",
# SFT variants (base-SFT uses the same model code as SFT)
"acestep-v15-base-sft-fix-inst": "sft",
# Turbo variants all share the turbo model code
"acestep-v15-turbo-shift1": "turbo",
"acestep-v15-turbo-shift3": "turbo",
"acestep-v15-turbo-continuous": "turbo",
"acestep-v15-turbo-fix-inst-shift3": "turbo",
"acestep-v15-turbo-fix-inst-shift-continous": "turbo",
"acestep-v15-turbo-fix-inst-shift-dynamic": "turbo",
"acestep-v15-turbo-rl": "turbo",
}
def _get_models_source_dir() -> Path:
"""Get the acestep/models/ directory (authoritative source for model code)."""
return Path(__file__).resolve().parent / "models"
def _file_hash(filepath: Path) -> str:
"""Compute SHA-256 hash of a file's contents."""
h = hashlib.sha256()
with open(filepath, "rb") as f:
for chunk in iter(lambda: f.read(8192), b""):
h.update(chunk)
return h.hexdigest()
def _check_code_mismatch(model_name: str, checkpoints_dir) -> List[str]:
"""
Compare .py files in acestep/models/{variant}/ with those in the checkpoint directory.
Args:
model_name: Checkpoint directory name (e.g. "acestep-v15-turbo")
checkpoints_dir: Path to the checkpoints root directory
Returns:
List of filenames that differ (empty list if all match or model_name is unknown)
"""
variant = _CHECKPOINT_TO_VARIANT.get(model_name)
if variant is None:
return []
source_dir = _get_models_source_dir() / variant
if not source_dir.exists():
return []
if isinstance(checkpoints_dir, str):
checkpoints_dir = Path(checkpoints_dir)
target_dir = checkpoints_dir / model_name
mismatched = []
for src_file in source_dir.glob("*.py"):
if src_file.name == "__init__.py":
continue
dst_file = target_dir / src_file.name
if not dst_file.exists():
mismatched.append(src_file.name)
elif _file_hash(src_file) != _file_hash(dst_file):
mismatched.append(src_file.name)
return mismatched
def _sync_model_code_files(model_name: str, checkpoints_dir) -> List[str]:
"""
Copy .py files from acestep/models/{variant}/ into the checkpoint directory,
overwriting the HuggingFace-downloaded versions.
Args:
model_name: Checkpoint directory name (e.g. "acestep-v15-turbo")
checkpoints_dir: Path to the checkpoints root directory
Returns:
List of filenames that were synced (empty if model_name is unknown or no source)
"""
variant = _CHECKPOINT_TO_VARIANT.get(model_name)
if variant is None:
return []
source_dir = _get_models_source_dir() / variant
if not source_dir.exists():
logger.warning(f"[Model Sync] Source directory not found: {source_dir}")
return []
if isinstance(checkpoints_dir, str):
checkpoints_dir = Path(checkpoints_dir)
target_dir = checkpoints_dir / model_name
if not target_dir.exists():
logger.warning(f"[Model Sync] Target directory not found: {target_dir}")
return []
synced = []
for src_file in source_dir.glob("*.py"):
if src_file.name == "__init__.py":
continue
dst_file = target_dir / src_file.name
shutil.copy2(src_file, dst_file)
synced.append(src_file.name)
logger.debug(f"[Model Sync] Synced {src_file.name} -> {dst_file}")
return synced
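The check-then-sync flow above can be exercised end to end on temporary directories. This is a self-contained sketch rather than the project's code: the helper names and the file names (`modeling.py`, `config.py`) are hypothetical, and the directory layout only mimics the `models/{variant}` and `checkpoints/{model_name}` structure described in the docstrings.

```python
import hashlib
import shutil
import tempfile
from pathlib import Path


def file_hash(path: Path) -> str:
    """SHA-256 of a file, read in 8 KiB chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            h.update(chunk)
    return h.hexdigest()


def mismatched_files(source_dir: Path, target_dir: Path) -> list:
    """Names of .py files that are missing from target_dir or differ from source_dir."""
    names = []
    for src in sorted(source_dir.glob("*.py")):
        if src.name == "__init__.py":
            continue
        dst = target_dir / src.name
        if not dst.exists() or file_hash(src) != file_hash(dst):
            names.append(src.name)
    return names


def demo() -> tuple:
    with tempfile.TemporaryDirectory() as tmp:
        source = Path(tmp) / "models" / "turbo"
        target = Path(tmp) / "checkpoints" / "acestep-v15-turbo"
        source.mkdir(parents=True)
        target.mkdir(parents=True)
        (source / "__init__.py").write_text("")
        (source / "modeling.py").write_text("VERSION = 2\n")
        (target / "modeling.py").write_text("VERSION = 1\n")  # stale copy
        (source / "config.py").write_text("X = 1\n")          # missing in target
        before = mismatched_files(source, target)
        for name in before:
            shutil.copy2(source / name, target / name)
        after = mismatched_files(source, target)
        return before, after


before, after = demo()
print(before, after)
```

Sorting the glob results is a choice made here for deterministic output; the original functions iterate in filesystem order.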
# =============================================================================
# Network Detection & Smart Download
# =============================================================================
@ -302,7 +416,15 @@ def download_main_model(
print("This may take a while depending on your internet connection...")
# Use smart download with automatic fallback
-    return _smart_download(MAIN_MODEL_REPO, checkpoints_dir, token, prefer_source)
+    success, msg = _smart_download(MAIN_MODEL_REPO, checkpoints_dir, token, prefer_source)
+    if success:
+        # Sync model code files for all DiT components in the main model
+        for component in MAIN_MODEL_COMPONENTS:
+            if component in _CHECKPOINT_TO_VARIANT:
+                synced = _sync_model_code_files(component, checkpoints_dir)
+                if synced:
+                    logger.info(f"[Model Download] Synced code files for {component}: {synced}")
+    return success, msg
def download_submodel(
@ -348,7 +470,13 @@ def download_submodel(
print(f"Destination: {model_path}")
# Use smart download with automatic fallback
-    return _smart_download(repo_id, model_path, token, prefer_source)
+    success, msg = _smart_download(repo_id, model_path, token, prefer_source)
+    if success and model_name in _CHECKPOINT_TO_VARIANT:
+        # Sync model code files after successful download
+        synced = _sync_model_code_files(model_name, checkpoints_dir)
+        if synced:
+            logger.info(f"[Model Download] Synced code files for {model_name}: {synced}")
+    return success, msg
def download_all_models(


@ -0,0 +1,3 @@
# ACE-Step model definitions
# These files are the authoritative source for model code.
# They are auto-synced to checkpoint directories on startup.


@ -0,0 +1,220 @@
import torch
import torch.nn.functional as F
class MomentumBuffer:
def __init__(self, momentum: float = -0.75):
self.momentum = momentum
self.running_average = 0
def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
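The update rule in `MomentumBuffer` is `running_average = update_value + momentum * running_average`. A scalar sketch makes the effect of the negative default momentum visible. The class name `ScalarMomentumBuffer` is hypothetical and works on floats instead of tensors.

```python
class ScalarMomentumBuffer:
    """Float-only analogue of MomentumBuffer for illustration."""

    def __init__(self, momentum: float = -0.75):
        self.momentum = momentum
        self.running_average = 0.0

    def update(self, update_value: float) -> None:
        # Same recurrence as the tensor version: new = value + momentum * old.
        self.running_average = update_value + self.momentum * self.running_average


buf = ScalarMomentumBuffer()
history = []
for diff in (1.0, 1.0, 1.0):
    buf.update(diff)
    history.append(buf.running_average)

print(history)
```

With a negative momentum, each step subtracts a fraction of the previous accumulated value, so a constant input produces a damped, oscillating sequence rather than a growing sum.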
def project(
    v0: torch.Tensor,  # [B, C, T]
    v1: torch.Tensor,  # [B, C, T]
    dims=[-1],
):
    dtype = v0.dtype
    # Keep the full device (including its index), not just the device type,
    # so results return to the tensor's original device on multi-GPU hosts.
    device = v0.device
    if device.type == "mps":
        # MPS has no float64 support, so do the double-precision math on CPU.
        v0, v1 = v0.cpu(), v1.cpu()
    v0, v1 = v0.double(), v1.double()
    v1 = torch.nn.functional.normalize(v1, dim=dims)
    v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    return v0_parallel.to(dtype).to(device), v0_orthogonal.to(dtype).to(device)
def apg_forward(
pred_cond: torch.Tensor, # [B, C, T]
pred_uncond: torch.Tensor, # [B, C, T]
guidance_scale: float,
momentum_buffer: MomentumBuffer = None,
eta: float = 0.0,
norm_threshold: float = 2.5,
dims=[-1],
):
diff = pred_cond - pred_uncond
if momentum_buffer is not None:
momentum_buffer.update(diff)
diff = momentum_buffer.running_average
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=dims, keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
diff_parallel, diff_orthogonal = project(diff, pred_cond, dims)
normalized_update = diff_orthogonal + eta * diff_parallel
pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
return pred_guided
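The core of `apg_forward` is the split of the guidance difference into components parallel and orthogonal to the conditional prediction. The following list-based sketch illustrates that step only: it omits the momentum buffer and the norm threshold, assumes a non-zero `v1`, and the name `apg_update` is hypothetical.

```python
import math


def project(v0: list, v1: list) -> tuple:
    """Split v0 into components parallel and orthogonal to v1 (v1 must be non-zero)."""
    norm = math.sqrt(sum(c * c for c in v1))
    unit = [c / norm for c in v1]
    coeff = sum(a * b for a, b in zip(v0, unit))
    parallel = [coeff * u for u in unit]
    orthogonal = [a - p for a, p in zip(v0, parallel)]
    return parallel, orthogonal


def apg_update(pred_cond: list, pred_uncond: list,
               guidance_scale: float, eta: float = 0.0) -> list:
    """pred_cond + (scale - 1) * (orthogonal + eta * parallel), without momentum or clipping."""
    diff = [c - u for c, u in zip(pred_cond, pred_uncond)]
    parallel, orthogonal = project(diff, pred_cond)
    return [c + (guidance_scale - 1) * (o + eta * p)
            for c, o, p in zip(pred_cond, orthogonal, parallel)]


parallel, orthogonal = project([3.0, 4.0], [1.0, 0.0])
guided = apg_update([1.0, 0.0], [0.0, -1.0], guidance_scale=3.0)
print(parallel, orthogonal, guided)
```

With `eta = 0`, only the orthogonal part of the difference moves the prediction, so the component along `pred_cond` is left unchanged.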
def cfg_forward(cond_output, uncond_output, cfg_strength):
return uncond_output + cfg_strength * (cond_output - uncond_output)
def call_cos_tensor(tensor1, tensor2):
"""
Calculate the cosine similarity between two tensors along dim 1 (each input is L2-normalized first).
Args:
tensor1: First tensor [B, ...]
tensor2: Second tensor [B, ...]
Returns:
Cosine similarity value [B, 1]
"""
tensor1 = tensor1 / torch.linalg.norm(tensor1, dim=1, keepdim=True)
tensor2 = tensor2 / torch.linalg.norm(tensor2, dim=1, keepdim=True)
cosvalue = torch.sum(tensor1 * tensor2, dim=1, keepdim=True)
return cosvalue
def compute_perpendicular_component(latent_diff, latent_hat_uncond):
"""
Decompose latent_diff into parallel and perpendicular components relative to latent_hat_uncond.
Args:
latent_diff: Difference tensor [N, T, C]
latent_hat_uncond: Unconditional prediction tensor [N, T, C]
Returns:
projection: Parallel component [N, T, C]
perpendicular_component: Perpendicular component [N, T, C]
"""
n, t, c = latent_diff.shape
latent_diff = latent_diff.view(n * t, c).float()
latent_hat_uncond = latent_hat_uncond.view(n * t, c).float()
if latent_diff.size() != latent_hat_uncond.size():
raise ValueError("latent_diff and latent_hat_uncond must have the same shape [N, T, C].")
dot_product = torch.sum(latent_diff * latent_hat_uncond, dim=1, keepdim=True) # [n, 1]
norm_square = torch.sum(latent_hat_uncond * latent_hat_uncond, dim=1, keepdim=True) # [n, 1]
projection = (dot_product / (norm_square + 1e-8)) * latent_hat_uncond
perpendicular_component = latent_diff - projection
return projection.view(n, t, c), perpendicular_component.reshape(n, t, c)
def adg_forward(
latents: torch.Tensor,
noise_pred_cond: torch.Tensor,
noise_pred_uncond: torch.Tensor,
sigma: torch.Tensor,
guidance_scale: float,
angle_clip: float = 3.14 / 6, # pi/6 by default
apply_norm: bool = False,
apply_clip: bool = True,
):
"""
ADG (Angle-based Dynamic Guidance) forward pass for Flow Matching.
In flow matching (including SD3), sigma represents the current timestep t_curr.
The predictions are velocity fields v(x_t, t).
Args:
latents: Current state x_t [N, T, d] where d=64
noise_pred_cond: Conditional velocity prediction v_cond [N, T, d]
noise_pred_uncond: Unconditional velocity prediction v_uncond [N, T, d]
sigma: Current timestep t_curr (not t_prev!)
guidance_scale: Guidance strength
angle_clip: Maximum angle for clipping (default: pi/6)
apply_norm: Whether to normalize the result (ADG_w_norm variant)
apply_clip: Whether to clip the angle (ADG_wo_clip when False)
Returns:
Guided velocity prediction [N, T, d]
"""
# Get batch size
n = noise_pred_cond.shape[0]
noise_pred_text = noise_pred_cond
n, t, c = noise_pred_text.shape
# Ensure sigma/t has the right shape for broadcasting [N, 1, 1]
if isinstance(sigma, (int, float)):
sigma = torch.tensor(sigma, device=latents.device, dtype=latents.dtype)
sigma = sigma.view(1, 1, 1).expand(n, 1, 1)
elif torch.is_tensor(sigma):
if sigma.numel() == 1:
sigma = sigma.view(1, 1, 1).expand(n, 1, 1)
elif sigma.numel() == n:
sigma = sigma.view(n, 1, 1)
else:
raise ValueError(f"sigma has incompatible shape. Expected scalar or size {n}, got {sigma.shape}")
else:
raise TypeError(f"sigma must be a number or tensor, got {type(sigma)}")
# Adjust guidance weight
weight = guidance_scale - 1
weight = weight * (weight > 0) + 1e-3
latent_hat_text = latents - sigma * noise_pred_text
latent_hat_uncond = latents - sigma * noise_pred_uncond
latent_diff = latent_hat_text - latent_hat_uncond
# Calculate angle between conditional and unconditional predicted data
latent_theta = torch.acos(
call_cos_tensor(latent_hat_text.view(-1, c).to(float),
latent_hat_uncond.reshape(-1, c).contiguous().to(float)))
latent_theta_new = torch.clip(weight * latent_theta, -angle_clip, angle_clip) if apply_clip else weight * latent_theta
proj, perp = compute_perpendicular_component(latent_diff, latent_hat_uncond)
latent_v_new = torch.cos(latent_theta_new) * latent_hat_text
latent_p_new = perp * torch.sin(latent_theta_new) / torch.sin(latent_theta) * (
torch.sin(latent_theta) > 1e-3) + perp * weight * (torch.sin(latent_theta) <= 1e-3)
latent_new = latent_v_new + latent_p_new
if apply_norm:
latent_new = latent_new * torch.linalg.norm(latent_hat_text, dim=1, keepdim=True) / torch.linalg.norm(
latent_new, dim=1, keepdim=True)
noise_pred = (latents - latent_new) / sigma
noise_pred = noise_pred.reshape(n, t, c).to(latents.dtype)
return noise_pred
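The angle handling in `adg_forward` can be illustrated on two plain vectors. This sketch covers only the angle and clipping step, not the full guidance update. The clamp before `acos` is an addition made here for numerical safety and is not in the original. The name `clipped_angle` is hypothetical.

```python
import math


def clipped_angle(v_text: list, v_uncond: list, guidance_scale: float,
                  angle_clip: float = 3.14 / 6) -> tuple:
    """Return (theta, theta_new): the angle between the vectors and its scaled, clipped value."""
    dot = sum(a * b for a, b in zip(v_text, v_uncond))
    norm_text = math.sqrt(sum(a * a for a in v_text))
    norm_uncond = math.sqrt(sum(b * b for b in v_uncond))
    cos = dot / (norm_text * norm_uncond)
    cos = max(-1.0, min(1.0, cos))  # assumption: guard the acos domain against rounding
    theta = math.acos(cos)
    # Same weight rule as the code above: positive part of (scale - 1), plus 1e-3.
    weight = max(guidance_scale - 1, 0.0) + 1e-3
    theta_new = max(-angle_clip, min(angle_clip, weight * theta))
    return theta, theta_new


theta, theta_new = clipped_angle([1.0, 0.0], [0.0, 1.0], guidance_scale=2.0)
theta_small, theta_small_new = clipped_angle([1.0, 0.0], [0.0, 1.0], guidance_scale=1.0)
print(theta, theta_new, theta_small_new)
```

The default `3.14 / 6` is slightly smaller than `math.pi / 6`, so the clip bound is a close approximation of pi/6 rather than the exact value.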
def adg_w_norm_forward(
latents: torch.Tensor,
noise_pred_cond: torch.Tensor,
noise_pred_uncond: torch.Tensor,
sigma: float,
guidance_scale: float,
angle_clip: float = 3.14 / 3,
):
"""
ADG with normalization - preserves the magnitude of latent predictions.
This variant normalizes the final latent to maintain the same norm as the
conditional prediction, which can help preserve output quality.
"""
return adg_forward(latents,
noise_pred_cond,
noise_pred_uncond,
sigma,
guidance_scale,
angle_clip=angle_clip,
apply_norm=True,
apply_clip=True)
def adg_wo_clip_forward(
latents: torch.Tensor,
noise_pred_cond: torch.Tensor,
noise_pred_uncond: torch.Tensor,
sigma: float,
guidance_scale: float,
):
"""
ADG without angle clipping - allows unbounded angle adjustments.
This variant doesn't clip the angle, which may result in more aggressive
guidance but could be less stable.
"""
return adg_forward(latents, noise_pred_cond, noise_pred_uncond, sigma, guidance_scale, apply_norm=False, apply_clip=False)


@ -0,0 +1,263 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AceStep model configuration"""
from transformers.configuration_utils import PretrainedConfig, layer_type_validation
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
logger = logging.get_logger(__name__)
class AceStepConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`AceStepModel`]. It is used to instantiate an
AceStep model according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 64003):
Vocabulary size of the AceStep model. Defines the number of different tokens that can be represented by the
`input_ids` passed when calling the model.
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 6144):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://huggingface.co/papers/2305.13245). If `None`, it falls back to `num_attention_heads`.
head_dim (`int`, *optional*, defaults to 128):
The attention head dimension.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 32768):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 1000000):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`list[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`list[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
use_sliding_window (`bool`, *optional*, defaults to `True`):
Whether to use sliding window attention.
sliding_window (`int`, *optional*, defaults to 128):
Sliding window attention (SWA) window size. Stored as `None` when `use_sliding_window` is `False`.
layer_types (`list`, *optional*):
Attention pattern for each layer. If `None`, layers alternate between `"sliding_attention"` and
`"full_attention"`, starting with `"sliding_attention"`.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from acestep.models import AceStepConfig
>>> # Initializing an AceStep configuration
>>> configuration = AceStepConfig()
>>> # Initializing a model from the configuration
>>> model = AceStepConditionGenerationModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "acestep"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for the base model
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size=64003,
fsq_dim=2048,
fsq_input_levels=[8, 8, 8, 5, 5, 5],
fsq_input_num_quantizers=1,
hidden_size=2048,
intermediate_size=6144,
num_hidden_layers=24,
num_attention_heads=16,
num_key_value_heads=8,
head_dim=128,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=True,
rope_theta=1000000,
rope_scaling=None,
attention_bias=False,
use_sliding_window=True,
sliding_window=128,
layer_types=None,
attention_dropout=0.0,
num_lyric_encoder_hidden_layers=8,
audio_acoustic_hidden_dim=64,
pool_window_size=5,
text_hidden_dim=1024,
in_channels=192,
data_proportion=0.5,
timestep_mu=-0.4,
timestep_sigma=1.0,
timbre_hidden_dim=64,
num_timbre_encoder_hidden_layers=4,
timbre_fix_frame=750,
patch_size=2,
num_attention_pooler_hidden_layers=2,
num_audio_decoder_hidden_layers=24,
model_version="turbo",
**kwargs,
):
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window if self.use_sliding_window else None
# Text encoder configuration
self.text_hidden_dim = text_hidden_dim
# Lyric encoder configuration
self.num_lyric_encoder_hidden_layers = num_lyric_encoder_hidden_layers
self.patch_size = patch_size
# Audio semantic token generation configuration
self.audio_acoustic_hidden_dim = audio_acoustic_hidden_dim
self.pool_window_size = pool_window_size
self.in_channels = in_channels
self.data_proportion = data_proportion
self.timestep_mu = timestep_mu
self.timestep_sigma = timestep_sigma
# FSQ (Finite Scalar Quantization) configuration
self.fsq_dim = fsq_dim
self.fsq_input_levels = fsq_input_levels
self.fsq_input_num_quantizers = fsq_input_num_quantizers
# Timbre encoder configuration
self.timbre_hidden_dim = timbre_hidden_dim
self.num_timbre_encoder_hidden_layers = num_timbre_encoder_hidden_layers
self.timbre_fix_frame = timbre_fix_frame
self.num_attention_pooler_hidden_layers = num_attention_pooler_hidden_layers
self.num_audio_decoder_hidden_layers = num_audio_decoder_hidden_layers
self.vocab_size = vocab_size
# Backward compatibility: ensure num_key_value_heads is set
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.model_version = model_version
# Validate rotary position embeddings parameters
# Backward compatibility: if there is a 'type' field, move it to 'rope_type'
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
self.layer_types = layer_types
# Set default layer types if not specified
if self.layer_types is None:
self.layer_types = [
"sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
]
layer_type_validation(self.layer_types)
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
__all__ = ["AceStepConfig"]
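The default `layer_types` pattern built in `AceStepConfig.__init__` can be reproduced without `transformers`. This is a small standalone sketch of the same list comprehension. The function name `default_layer_types` is hypothetical.

```python
def default_layer_types(num_hidden_layers: int) -> list:
    """Alternate sliding and full attention, starting with sliding attention at layer 0."""
    return [
        "sliding_attention" if (i + 1) % 2 else "full_attention"
        for i in range(num_hidden_layers)
    ]


layer_types = default_layer_types(24)
print(layer_types[:4])
print(layer_types.count("sliding_attention"), layer_types.count("full_attention"))
```

With the default of 24 hidden layers, even-indexed layers use the sliding window and odd-indexed layers attend over the full sequence, giving 12 layers of each kind.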

File diff suppressed because it is too large

View file

View file

@ -0,0 +1,220 @@
import torch
import torch.nn.functional as F
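class MomentumBuffer:
    """Running average of guidance updates, weighted by a (by default negative) momentum."""

    def __init__(self, momentum: float = -0.75):
        self.momentum = momentum
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        # running_average <- update_value + momentum * running_average
        new_average = self.momentum * self.running_average
        self.running_average = update_value + new_average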
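The update rule of `MomentumBuffer` is easier to see on scalars. The sketch below repeats the class with plain floats standing in for tensors (an assumption made only so that it runs without PyTorch) and feeds it a constant input.

```python
class MomentumBuffer:
    def __init__(self, momentum: float = -0.75):
        self.momentum = momentum
        self.running_average = 0

    def update(self, update_value):
        # Same recurrence as above: a <- d + momentum * a
        self.running_average = update_value + self.momentum * self.running_average


buf = MomentumBuffer()
history = []
for _ in range(3):
    buf.update(1.0)
    history.append(buf.running_average)
print(history)
```

For a constant input `d` and `|momentum| < 1`, the recurrence settles at `d / (1 - momentum)`, so with the default momentum of -0.75 the buffer damps the raw difference rather than amplifying it.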
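def project(
    v0: torch.Tensor,  # [B, C, T]
    v1: torch.Tensor,  # [B, C, T]
    dims=[-1],
):
    """Split v0 into components parallel and orthogonal to v1 along `dims`."""
    dtype = v0.dtype
    # Keep the full device (including its index), not only the device type,
    # so results return to the GPU they came from in multi-GPU setups.
    device = v0.device
    if device.type == "mps":
        # float64 is not supported on MPS, so compute on the CPU
        v0, v1 = v0.cpu(), v1.cpu()
    v0, v1 = v0.double(), v1.double()
    v1 = torch.nn.functional.normalize(v1, dim=dims)
    v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    return v0_parallel.to(dtype).to(device), v0_orthogonal.to(dtype).to(device)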
def apg_forward(
    pred_cond: torch.Tensor,  # [B, C, T]
    pred_uncond: torch.Tensor,  # [B, C, T]
    guidance_scale: float,
    momentum_buffer: MomentumBuffer = None,
    eta: float = 0.0,
    norm_threshold: float = 2.5,
    dims=[-1],
):
    """
    APG forward pass: guide along the part of (pred_cond - pred_uncond) that is
    orthogonal to pred_cond, plus `eta` times the parallel part.
    """
    diff = pred_cond - pred_uncond
    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average
    if norm_threshold > 0:
        # Rescale diff so that its L2 norm along `dims` is at most norm_threshold
        ones = torch.ones_like(diff)
        diff_norm = diff.norm(p=2, dim=dims, keepdim=True)
        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
        diff = diff * scale_factor
    diff_parallel, diff_orthogonal = project(diff, pred_cond, dims)
    normalized_update = diff_orthogonal + eta * diff_parallel
    pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
    return pred_guided
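To make the projection step concrete, here is a small sketch of the same arithmetic on plain Python lists instead of tensors. It leaves out the momentum buffer and the norm threshold, and the names `apg_step` and `cfg_step` are hypothetical. It checks two properties: with `eta=0` the added update is orthogonal to the conditional prediction, and with `eta=1` the result coincides with ordinary classifier-free guidance.

```python
import math


def project(v0, v1):
    """Split v0 into parts parallel and orthogonal to v1 (plain Python lists)."""
    norm = math.sqrt(sum(x * x for x in v1))
    unit = [x / norm for x in v1]
    coeff = sum(a * b for a, b in zip(v0, unit))
    parallel = [coeff * u for u in unit]
    orthogonal = [a - p for a, p in zip(v0, parallel)]
    return parallel, orthogonal


def apg_step(cond, uncond, guidance_scale, eta=0.0):
    # No momentum and no norm clipping here, to keep the sketch short.
    diff = [c - u for c, u in zip(cond, uncond)]
    parallel, orthogonal = project(diff, cond)
    update = [o + eta * p for o, p in zip(orthogonal, parallel)]
    return [c + (guidance_scale - 1) * d for c, d in zip(cond, update)]


def cfg_step(cond, uncond, guidance_scale):
    return [u + guidance_scale * (c - u) for c, u in zip(cond, uncond)]


cond, uncond = [3.0, 4.0], [1.0, 0.0]
apg = apg_step(cond, uncond, guidance_scale=3.0, eta=0.0)
apg_eta1 = apg_step(cond, uncond, guidance_scale=3.0, eta=1.0)
cfg = cfg_step(cond, uncond, guidance_scale=3.0)

# With eta=0 the update added to cond is orthogonal to cond.
update = [a - c for a, c in zip(apg, cond)]
dot = sum(u * c for u, c in zip(update, cond))
print(dot)            # close to 0
print(apg_eta1, cfg)  # agree up to rounding
```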
def cfg_forward(cond_output, uncond_output, cfg_strength):
    """Standard classifier-free guidance: uncond + strength * (cond - uncond)."""
    return uncond_output + cfg_strength * (cond_output - uncond_output)


def call_cos_tensor(tensor1, tensor2):
    """
    Calculate the cosine similarity between two tensors along dim 1.
    Both inputs are L2-normalized along dim 1 inside this function.

    Args:
        tensor1: First tensor [B, ...]
        tensor2: Second tensor [B, ...]
    Returns:
        Cosine similarity value [B, 1]
    """
    tensor1 = tensor1 / torch.linalg.norm(tensor1, dim=1, keepdim=True)
    tensor2 = tensor2 / torch.linalg.norm(tensor2, dim=1, keepdim=True)
    cosvalue = torch.sum(tensor1 * tensor2, dim=1, keepdim=True)
    return cosvalue
def compute_perpendicular_component(latent_diff, latent_hat_uncond):
    """
    Decompose latent_diff into parallel and perpendicular components relative to latent_hat_uncond.
    The decomposition is done per (batch, time) position along the channel dimension.

    Args:
        latent_diff: Difference tensor [N, T, C]
        latent_hat_uncond: Unconditional prediction tensor [N, T, C]
    Returns:
        projection: Parallel component [N, T, C]
        perpendicular_component: Perpendicular component [N, T, C]
    """
    # Check shapes before flattening, so a mismatch raises this error and not a view error
    if latent_diff.size() != latent_hat_uncond.size():
        raise ValueError("latent_diff and latent_hat_uncond must have the same shape [N, T, C].")
    n, t, c = latent_diff.shape
    latent_diff = latent_diff.reshape(n * t, c).float()
    latent_hat_uncond = latent_hat_uncond.reshape(n * t, c).float()
    dot_product = torch.sum(latent_diff * latent_hat_uncond, dim=1, keepdim=True)  # [N*T, 1]
    norm_square = torch.sum(latent_hat_uncond * latent_hat_uncond, dim=1, keepdim=True)  # [N*T, 1]
    projection = (dot_product / (norm_square + 1e-8)) * latent_hat_uncond
    perpendicular_component = latent_diff - projection
    return projection.reshape(n, t, c), perpendicular_component.reshape(n, t, c)
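Unlike `project`, this decomposition divides by the squared norm plus a small epsilon, so a zero reference vector yields a zero projection instead of NaN. The sketch below shows that on a single vector using plain Python lists; the function name `decompose` is hypothetical.

```python
def decompose(diff, ref, eps=1e-8):
    """Projection of diff onto ref and the remainder, with the same eps guard."""
    dot = sum(d * r for d, r in zip(diff, ref))
    norm_sq = sum(r * r for r in ref)
    scale = dot / (norm_sq + eps)
    projection = [scale * r for r in ref]
    perpendicular = [d - p for d, p in zip(diff, projection)]
    return projection, perpendicular


# Zero reference vector: nothing is projected, everything is "perpendicular".
proj_zero, perp_zero = decompose([1.0, 2.0], [0.0, 0.0])
print(proj_zero, perp_zero)

# Ordinary case: the component along the x axis is split off (up to the eps bias).
proj, perp = decompose([1.0, 2.0], [2.0, 0.0])
print(proj, perp)
```

The epsilon biases the projection slightly toward zero, which is negligible unless the reference vector itself has a norm near `1e-4`.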
def adg_forward(
    latents: torch.Tensor,
    noise_pred_cond: torch.Tensor,
    noise_pred_uncond: torch.Tensor,
    sigma: torch.Tensor,
    guidance_scale: float,
    angle_clip: float = 3.14 / 6,  # approximately pi/6
    apply_norm: bool = False,
    apply_clip: bool = True,
):
    """
    ADG (Angle-based Dynamic Guidance) forward pass for Flow Matching.
    In flow matching (including SD3), sigma represents the current timestep t_curr.
    The predictions are velocity fields v(x_t, t).

    Args:
        latents: Current state x_t [N, T, d] where d=64
        noise_pred_cond: Conditional velocity prediction v_cond [N, T, d]
        noise_pred_uncond: Unconditional velocity prediction v_uncond [N, T, d]
        sigma: Current timestep t_curr (not t_prev!)
        guidance_scale: Guidance strength
        angle_clip: Maximum angle for clipping (default: approximately pi/6)
        apply_norm: Whether to normalize the result (ADG_w_norm variant)
        apply_clip: Whether to clip the angle (ADG_wo_clip when False)
    Returns:
        Guided velocity prediction [N, T, d]
    """
    n, t, c = noise_pred_cond.shape
    # Ensure sigma/t has the right shape for broadcasting [N, 1, 1]
    if isinstance(sigma, (int, float)):
        sigma = torch.tensor(sigma, device=latents.device, dtype=latents.dtype)
        sigma = sigma.view(1, 1, 1).expand(n, 1, 1)
    elif torch.is_tensor(sigma):
        if sigma.numel() == 1:
            sigma = sigma.view(1, 1, 1).expand(n, 1, 1)
        elif sigma.numel() == n:
            sigma = sigma.view(n, 1, 1)
        else:
            raise ValueError(f"sigma has incompatible shape. Expected scalar or size {n}, got {sigma.shape}")
    else:
        raise TypeError(f"sigma must be a number or tensor, got {type(sigma)}")
    # Adjust guidance weight: zero for guidance_scale <= 1, plus a small floor of 1e-3
    weight = guidance_scale - 1
    weight = weight * (weight > 0) + 1e-3
    # Predicted clean data for the conditional and unconditional branches
    latent_hat_text = latents - sigma * noise_pred_cond
    latent_hat_uncond = latents - sigma * noise_pred_uncond
    latent_diff = latent_hat_text - latent_hat_uncond
    # Calculate angle between conditional and unconditional predicted data
    cos_theta = call_cos_tensor(
        latent_hat_text.reshape(-1, c).to(float),
        latent_hat_uncond.reshape(-1, c).to(float),
    )
    # Clamp guards acos against values marginally outside [-1, 1] due to rounding.
    # Reshape to [N, T, 1] so the angle broadcasts over channels for any batch size.
    latent_theta = torch.acos(cos_theta.clamp(-1.0, 1.0)).reshape(n, t, 1)
    latent_theta_new = torch.clip(weight * latent_theta, -angle_clip, angle_clip) if apply_clip else weight * latent_theta
    proj, perp = compute_perpendicular_component(latent_diff, latent_hat_uncond)
    perp = perp.to(latent_theta.dtype)
    latent_v_new = torch.cos(latent_theta_new) * latent_hat_text
    # torch.where (not a multiplicative mask) keeps a division by sin(theta) == 0
    # from leaking NaN/inf into the small-angle branch.
    sin_theta = torch.sin(latent_theta)
    small_angle = sin_theta <= 1e-3
    safe_sin = torch.where(small_angle, torch.ones_like(sin_theta), sin_theta)
    latent_p_new = torch.where(
        small_angle,
        perp * weight,
        perp * torch.sin(latent_theta_new) / safe_sin,
    )
    latent_new = latent_v_new + latent_p_new
    if apply_norm:
        # Note: for [N, T, d] inputs, dim=1 is the T axis
        latent_new = latent_new * torch.linalg.norm(latent_hat_text, dim=1, keepdim=True) / torch.linalg.norm(
            latent_new, dim=1, keepdim=True)
    noise_pred = (latents - latent_new) / sigma
    noise_pred = noise_pred.reshape(n, t, c).to(latents.dtype)
    return noise_pred
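The guidance weight and the angle clipping in `adg_forward` can be checked in isolation. The following sketch reproduces only those two steps on scalar angles; `adg_angle` is a hypothetical helper name, and the default clip value mirrors the `3.14 / 6` used above.

```python
import math


def adg_angle(theta, guidance_scale, angle_clip=3.14 / 6, apply_clip=True):
    """Return the rotated angle for an original angle theta (radians)."""
    weight = guidance_scale - 1
    # Zero out non-positive weights, then add a small floor of 1e-3.
    weight = weight * (weight > 0) + 1e-3
    scaled = weight * theta
    if apply_clip:
        scaled = min(max(scaled, -angle_clip), angle_clip)
    return scaled


theta = math.radians(10)
for scale in (1.0, 2.0, 5.0):
    print(scale, math.degrees(adg_angle(theta, scale)))
```

For a 10 degree angle, a scale of 2 leaves the angle almost unchanged, while a scale of 5 would rotate by roughly 40 degrees and is therefore capped at the clip value of just under 30 degrees.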
def adg_w_norm_forward(
    latents: torch.Tensor,
    noise_pred_cond: torch.Tensor,
    noise_pred_uncond: torch.Tensor,
    sigma: float,
    guidance_scale: float,
    angle_clip: float = 3.14 / 3,  # approximately pi/3
):
    """
    ADG with normalization - preserves the magnitude of latent predictions.
    This variant normalizes the final latent to maintain the same norm as the
    conditional prediction, which can help preserve output quality.
    """
    return adg_forward(latents,
                       noise_pred_cond,
                       noise_pred_uncond,
                       sigma,
                       guidance_scale,
                       angle_clip=angle_clip,
                       apply_norm=True,
                       apply_clip=True)


def adg_wo_clip_forward(
    latents: torch.Tensor,
    noise_pred_cond: torch.Tensor,
    noise_pred_uncond: torch.Tensor,
    sigma: float,
    guidance_scale: float,
):
    """
    ADG without angle clipping - allows unbounded angle adjustments.
    This variant doesn't clip the angle, which may result in more aggressive
    guidance but could be less stable.
    """
    return adg_forward(latents, noise_pred_cond, noise_pred_uncond, sigma, guidance_scale, apply_norm=False, apply_clip=False)


@ -0,0 +1,263 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AceStep model configuration"""
from transformers.configuration_utils import PretrainedConfig, layer_type_validation
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
logger = logging.get_logger(__name__)
class AceStepConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of an [`AceStepModel`]. It is used to instantiate an
AceStep model according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 64003):
Vocabulary size of the AceStep model. Defines the number of different tokens that can be represented by the
`input_ids` passed when calling the model.
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 6144):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by mean-pooling all the original heads within that group. For more details, check out [this
paper](https://huggingface.co/papers/2305.13245). If `None` is passed, it falls back to `num_attention_heads`.
head_dim (`int`, *optional*, defaults to 128):
The attention head dimension.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 32768):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 1000000):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`list[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`list[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
use_sliding_window (`bool`, *optional*, defaults to `True`):
Whether to use sliding window attention.
sliding_window (`int`, *optional*, defaults to 128):
Sliding window attention (SWA) window size. Stored as `None` when `use_sliding_window` is `False`.
layer_types (`list`, *optional*):
Attention pattern for each layer.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from acestep.models import AceStepConfig
>>> # Initializing an AceStep configuration
>>> configuration = AceStepConfig()
>>> # Initializing a model from the configuration
>>> model = AceStepConditionGenerationModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "acestep"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for the base model
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size=64003,
fsq_dim=2048,
fsq_input_levels=[8, 8, 8, 5, 5, 5],
fsq_input_num_quantizers=1,
hidden_size=2048,
intermediate_size=6144,
num_hidden_layers=24,
num_attention_heads=16,
num_key_value_heads=8,
head_dim=128,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=True,
rope_theta=1000000,
rope_scaling=None,
attention_bias=False,
use_sliding_window=True,
sliding_window=128,
layer_types=None,
attention_dropout=0.0,
num_lyric_encoder_hidden_layers=8,
audio_acoustic_hidden_dim=64,
pool_window_size=5,
text_hidden_dim=1024,
in_channels=192,
data_proportion=0.5,
timestep_mu=-0.4,
timestep_sigma=1.0,
timbre_hidden_dim=64,
num_timbre_encoder_hidden_layers=4,
timbre_fix_frame=750,
patch_size=2,
num_attention_pooler_hidden_layers=2,
num_audio_decoder_hidden_layers=24,
model_version="turbo",
**kwargs,
):
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window if self.use_sliding_window else None
# Text encoder configuration
self.text_hidden_dim = text_hidden_dim
# Lyric encoder configuration
self.num_lyric_encoder_hidden_layers = num_lyric_encoder_hidden_layers
self.patch_size = patch_size
# Audio semantic token generation configuration
self.audio_acoustic_hidden_dim = audio_acoustic_hidden_dim
self.pool_window_size = pool_window_size
self.in_channels = in_channels
self.data_proportion = data_proportion
self.timestep_mu = timestep_mu
self.timestep_sigma = timestep_sigma
# FSQ (Finite Scalar Quantization) configuration
self.fsq_dim = fsq_dim
self.fsq_input_levels = fsq_input_levels
self.fsq_input_num_quantizers = fsq_input_num_quantizers
# Timbre encoder configuration
self.timbre_hidden_dim = timbre_hidden_dim
self.num_timbre_encoder_hidden_layers = num_timbre_encoder_hidden_layers
self.timbre_fix_frame = timbre_fix_frame
self.num_attention_pooler_hidden_layers = num_attention_pooler_hidden_layers
self.num_audio_decoder_hidden_layers = num_audio_decoder_hidden_layers
self.vocab_size = vocab_size
# Backward compatibility: ensure num_key_value_heads is set
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.model_version = model_version
# Validate rotary position embeddings parameters
# Backward compatibility: if there is a 'type' field, move it to 'rope_type'
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
self.layer_types = layer_types
# Set default layer types if not specified
if self.layer_types is None:
self.layer_types = [
"sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
]
layer_type_validation(self.layer_types)
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
__all__ = ["AceStepConfig"]
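The backward-compatibility branch for `rope_scaling` copies a legacy `type` key to `rope_type` and leaves the old key in place. A minimal sketch of that branch as a standalone function follows; the name `normalize_rope_scaling` is hypothetical, and the real constructor additionally calls `rope_config_validation`, which is not reproduced here.

```python
def normalize_rope_scaling(rope_scaling):
    # Mirrors the branch in AceStepConfig.__init__: the legacy "type" value is
    # copied to "rope_type"; the original "type" key is not removed.
    if rope_scaling is not None and "type" in rope_scaling:
        rope_scaling["rope_type"] = rope_scaling["type"]
    return rope_scaling


legacy = normalize_rope_scaling({"type": "linear", "factor": 2.0})
print(legacy)

untouched = normalize_rope_scaling(None)
print(untouched)
```

Note that the dictionary is modified in place, so a caller that reuses the same dict elsewhere will see the added key.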

File diff suppressed because it is too large

@ -0,0 +1,263 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AceStep model configuration"""
from transformers.configuration_utils import PretrainedConfig, layer_type_validation
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
logger = logging.get_logger(__name__)
class AceStepConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of an [`AceStepModel`]. It is used to instantiate an
AceStep model according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 64003):
Vocabulary size of the AceStep model. Defines the number of different tokens that can be represented by the
`input_ids` passed when calling the model.
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 6144):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by mean-pooling all the original heads within that group. For more details, check out [this
paper](https://huggingface.co/papers/2305.13245). If `None` is passed, it falls back to `num_attention_heads`.
head_dim (`int`, *optional*, defaults to 128):
The attention head dimension.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 32768):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 1000000):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`list[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`list[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
use_sliding_window (`bool`, *optional*, defaults to `True`):
Whether to use sliding window attention.
sliding_window (`int`, *optional*, defaults to 128):
Sliding window attention (SWA) window size. Stored as `None` when `use_sliding_window` is `False`.
layer_types (`list`, *optional*):
Attention pattern for each layer.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from acestep.models import AceStepConfig
>>> # Initializing an AceStep configuration
>>> configuration = AceStepConfig()
>>> # Initializing a model from the configuration
>>> model = AceStepConditionGenerationModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "acestep"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for the base model
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size=64003,
fsq_dim=2048,
fsq_input_levels=[8, 8, 8, 5, 5, 5],
fsq_input_num_quantizers=1,
hidden_size=2048,
intermediate_size=6144,
num_hidden_layers=24,
num_attention_heads=16,
num_key_value_heads=8,
head_dim=128,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=True,
rope_theta=1000000,
rope_scaling=None,
attention_bias=False,
use_sliding_window=True,
sliding_window=128,
layer_types=None,
attention_dropout=0.0,
num_lyric_encoder_hidden_layers=8,
audio_acoustic_hidden_dim=64,
pool_window_size=5,
text_hidden_dim=1024,
in_channels=192,
data_proportion=0.5,
timestep_mu=-0.4,
timestep_sigma=1.0,
timbre_hidden_dim=64,
num_timbre_encoder_hidden_layers=4,
timbre_fix_frame=750,
patch_size=2,
num_attention_pooler_hidden_layers=2,
num_audio_decoder_hidden_layers=24,
model_version="turbo",
**kwargs,
):
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window if self.use_sliding_window else None
# Text encoder configuration
self.text_hidden_dim = text_hidden_dim
# Lyric encoder configuration
self.num_lyric_encoder_hidden_layers = num_lyric_encoder_hidden_layers
self.patch_size = patch_size
# Audio semantic token generation configuration
self.audio_acoustic_hidden_dim = audio_acoustic_hidden_dim
self.pool_window_size = pool_window_size
self.in_channels = in_channels
self.data_proportion = data_proportion
self.timestep_mu = timestep_mu
self.timestep_sigma = timestep_sigma
# FSQ (Finite Scalar Quantization) configuration
self.fsq_dim = fsq_dim
self.fsq_input_levels = fsq_input_levels
self.fsq_input_num_quantizers = fsq_input_num_quantizers
# Timbre encoder configuration
self.timbre_hidden_dim = timbre_hidden_dim
self.num_timbre_encoder_hidden_layers = num_timbre_encoder_hidden_layers
self.timbre_fix_frame = timbre_fix_frame
self.num_attention_pooler_hidden_layers = num_attention_pooler_hidden_layers
self.num_audio_decoder_hidden_layers = num_audio_decoder_hidden_layers
self.vocab_size = vocab_size
# Backward compatibility: ensure num_key_value_heads is set
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.model_version = model_version
# Validate rotary position embeddings parameters
# Backward compatibility: if there is a 'type' field, move it to 'rope_type'
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
self.layer_types = layer_types
# Set default layer types if not specified
if self.layer_types is None:
self.layer_types = [
"sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
]
layer_type_validation(self.layer_types)
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
__all__ = ["AceStepConfig"]
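Two of the constructor's derived settings are worth spelling out: the fallback of `num_key_value_heads` to `num_attention_heads`, and `sliding_window` being stored as `None` when sliding-window attention is disabled. The sketch below reproduces both rules with the defaults from the signature above; the function `describe_attention` and the `kind` label are illustrative additions, not part of the config class.

```python
def describe_attention(
    num_attention_heads=16,
    num_key_value_heads=8,
    use_sliding_window=True,
    sliding_window=128,
):
    # Same fallback as in AceStepConfig.__init__
    if num_key_value_heads is None:
        num_key_value_heads = num_attention_heads
    if num_key_value_heads == num_attention_heads:
        kind = "MHA"
    elif num_key_value_heads == 1:
        kind = "MQA"
    else:
        kind = "GQA"
    return {
        "kind": kind,
        "queries_per_kv_head": num_attention_heads // num_key_value_heads,
        # Same rule as in AceStepConfig.__init__
        "sliding_window": sliding_window if use_sliding_window else None,
    }


defaults = describe_attention()
print(defaults)
print(describe_attention(num_key_value_heads=None, use_sliding_window=False))
```

With the default values, each key/value head is shared by two query heads, and sliding-attention layers use a window of 128 positions.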

File diff suppressed because it is too large

@ -0,0 +1,99 @@
"""
Distributed utilities for nano-vllm.

This module provides wrapper functions for torch.distributed that gracefully
handle single-GPU mode (world_size == 1) without requiring distributed initialization.
"""
import torch.distributed as dist

# Global flag to track if distributed is actually initialized
_distributed_initialized = False


def initialize_distributed(backend: str, init_method: str, world_size: int, rank: int) -> bool:
    """
    Initialize distributed process group only if world_size > 1.

    Args:
        backend: Distributed backend (e.g., "nccl" or "gloo")
        init_method: Initialization method (e.g., "tcp://127.0.0.1:2333")
        world_size: Total number of processes
        rank: Rank of current process
    Returns:
        True if distributed was initialized, False otherwise
    """
    global _distributed_initialized
    if world_size == 1:
        # Single GPU mode - no distributed needed
        _distributed_initialized = False
        return False
    # Multi-GPU mode - initialize distributed
    dist.init_process_group(backend, init_method, world_size=world_size, rank=rank)
    _distributed_initialized = True
    return True


def is_initialized() -> bool:
    """Check if distributed is initialized."""
    return _distributed_initialized


def get_rank() -> int:
    """Get current process rank. Returns 0 if distributed is not initialized."""
    if _distributed_initialized:
        return dist.get_rank()
    return 0


def get_world_size() -> int:
    """Get world size. Returns 1 if distributed is not initialized."""
    if _distributed_initialized:
        return dist.get_world_size()
    return 1


def barrier():
    """Synchronize all processes. No-op if distributed is not initialized."""
    if _distributed_initialized:
        dist.barrier()


def all_reduce(tensor, op=None):
    """
    All-reduce operation. No-op if distributed is not initialized.

    Args:
        tensor: Tensor to reduce
        op: Reduce operation (default: SUM)
    """
    if _distributed_initialized:
        if op is None:
            op = dist.ReduceOp.SUM
        dist.all_reduce(tensor, op)


def gather(tensor, gather_list=None, dst=0):
    """
    Gather tensors from all processes. No-op if distributed is not initialized.

    Args:
        tensor: Tensor to gather
        gather_list: List to gather into (only used on dst rank)
        dst: Destination rank
    """
    if _distributed_initialized:
        dist.gather(tensor, gather_list, dst)


def destroy_process_group():
    """Destroy process group. No-op if distributed is not initialized."""
    global _distributed_initialized
    if _distributed_initialized:
        dist.destroy_process_group()
        _distributed_initialized = False
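The guard pattern used by these wrappers can be exercised without PyTorch. The sketch below is an illustration only: `FakeDist` is a hypothetical stand-in for `torch.distributed` that records calls, and `DistWrapper` keeps the flag on an object where the module above uses a global.

```python
class FakeDist:
    """Hypothetical stand-in for torch.distributed that records calls."""

    def __init__(self):
        self.calls = []

    def init_process_group(self, backend, init_method, world_size, rank):
        self.calls.append("init_process_group")

    def barrier(self):
        self.calls.append("barrier")


class DistWrapper:
    """Same guard as the module above, held on an object instead of a global."""

    def __init__(self, dist):
        self.dist = dist
        self.initialized = False

    def initialize(self, backend, init_method, world_size, rank):
        if world_size == 1:
            # Single-process mode: never touch the backend
            self.initialized = False
            return False
        self.dist.init_process_group(backend, init_method, world_size=world_size, rank=rank)
        self.initialized = True
        return True

    def barrier(self):
        if self.initialized:
            self.dist.barrier()


single = DistWrapper(FakeDist())
single.initialize("gloo", "tcp://127.0.0.1:2333", world_size=1, rank=0)
single.barrier()
print(single.dist.calls)

multi = DistWrapper(FakeDist())
multi.initialize("gloo", "tcp://127.0.0.1:2333", world_size=2, rank=0)
multi.barrier()
print(multi.dist.calls)
```

In single-process mode no backend call is made at all, which is why callers such as `get_rank()` and `barrier()` stay safe to use without a process group.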


@ -8,6 +8,7 @@ import sys
 from nanovllm.config import Config
 from acestep.debug_utils import debug_start, debug_end
+from nanovllm import distributed as dist_utils
 # Debug logging - enable with NANOVLLM_DEBUG=1
 _DEBUG = os.environ.get("NANOVLLM_DEBUG", "0") == "1"
@ -64,11 +65,15 @@ class ModelRunner:
         self.world_size = config.tensor_parallel_size
         self.rank = rank
         self.event = event
-        dist_port = find_available_port()
-        print(f"[debug]dist_port: {dist_port}")
-        # Use gloo backend on Windows, nccl on Linux/other platforms
-        backend = "gloo" if sys.platform == "win32" else "nccl"
-        dist.init_process_group(backend, f"tcp://127.0.0.1:{dist_port}", world_size=self.world_size, rank=rank)
+        # Only initialize distributed if world_size > 1
+        if self.world_size > 1:
+            dist_port = find_available_port()
+            print(f"[debug]dist_port: {dist_port}")
+            # Use gloo backend on Windows, nccl on Linux/other platforms
+            backend = "gloo" if sys.platform == "win32" else "nccl"
+            dist_utils.initialize_distributed(backend, f"tcp://127.0.0.1:{dist_port}", world_size=self.world_size, rank=rank)
         torch.cuda.set_device(rank)
         default_dtype = torch.get_default_dtype()
         # Use dtype instead of deprecated torch_dtype
@ -119,9 +124,9 @@ class ModelRunner:
         if self.world_size > 1:
             if rank == 0:
                 self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
-                dist.barrier()
+                dist_utils.barrier()
             else:
-                dist.barrier()
+                dist_utils.barrier()
                 self.shm = SharedMemory(name="nanovllm")
                 self.loop()
@ -163,13 +168,13 @@ class ModelRunner:
     def exit(self):
         if self.world_size > 1:
             self.shm.close()
-            dist.barrier()
+            dist_utils.barrier()
             if self.rank == 0:
                 self.shm.unlink()
         if not self.enforce_eager:
             del self.graphs, self.graph_pool
         torch.cuda.synchronize()
-        dist.destroy_process_group()
+        dist_utils.destroy_process_group()
     def loop(self):
         while True:


@ -4,6 +4,7 @@ import torch.nn.functional as F
import torch.distributed as dist
from nanovllm.utils.context import get_context
from nanovllm import distributed as dist_utils
class VocabParallelEmbedding(nn.Module):
@ -14,8 +15,8 @@ class VocabParallelEmbedding(nn.Module):
embedding_dim: int,
):
super().__init__()
self.tp_rank = dist.get_rank()
self.tp_size = dist.get_world_size()
self.tp_rank = dist_utils.get_rank()
self.tp_size = dist_utils.get_world_size()
assert num_embeddings % self.tp_size == 0
self.num_embeddings = num_embeddings
self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
@ -38,7 +39,7 @@ class VocabParallelEmbedding(nn.Module):
y = F.embedding(x, self.weight)
if self.tp_size > 1:
y = mask.unsqueeze(1) * y
dist.all_reduce(y)
dist_utils.all_reduce(y)
return y
@ -59,8 +60,10 @@ class ParallelLMHead(VocabParallelEmbedding):
last_indices = context.cu_seqlens_q[1:] - 1
x = x[last_indices].contiguous()
logits = F.linear(x, self.weight)
# In multi-GPU mode, gather logits from all ranks and concatenate
# In single-GPU mode (tp_size=1), skip gathering and return logits directly
if self.tp_size > 1:
all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
dist.gather(logits, all_logits, 0)
dist_utils.gather(logits, all_logits, 0)
logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
return logits
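
For readers unfamiliar with why `all_reduce` is only needed when `tp_size > 1`: in a vocabulary-parallel embedding each rank holds a contiguous slice of the table, zeroes the rows for tokens it does not own, and a SUM all-reduce reassembles the full lookup. The sketch below shows this with plain Python lists. The contiguous, equal-sized partitioning and the helper names `shard_lookup` and `all_reduce_sum` are assumptions made for illustration, since the mask computation is not part of the hunks shown here.

```python
def shard_lookup(token_ids, full_table, rank, tp_size):
    """Lookup on one shard; rows for tokens owned by other shards are zero."""
    per_part = len(full_table) // tp_size
    start, end = rank * per_part, (rank + 1) * per_part
    shard = full_table[start:end]  # the rows this rank would actually store
    dim = len(full_table[0])
    out = []
    for tok in token_ids:
        if start <= tok < end:
            out.append(list(shard[tok - start]))
        else:
            out.append([0.0] * dim)  # masked: another rank owns this token
    return out


def all_reduce_sum(per_rank_outputs):
    """Element-wise sum across ranks, standing in for all_reduce with SUM."""
    return [
        [sum(vals) for vals in zip(*rows)]
        for rows in zip(*per_rank_outputs)
    ]
```

With `tp_size == 1` the single shard owns every token, so the mask is all ones and the reduction has nothing to add, which is why the guarded branch can be skipped.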


@ -3,6 +3,8 @@ from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
from nanovllm import distributed as dist_utils
def divide(numerator, denominator):
assert numerator % denominator == 0
@ -20,8 +22,8 @@ class LinearBase(nn.Module):
):
super().__init__()
self.tp_dim = tp_dim
self.tp_rank = dist.get_rank()
self.tp_size = dist.get_world_size()
self.tp_rank = dist_utils.get_rank()
self.tp_size = dist_utils.get_world_size()
self.weight = nn.Parameter(torch.empty(output_size, input_size))
self.weight.weight_loader = self.weight_loader
if bias:
@ -59,7 +61,7 @@ class ColumnParallelLinear(LinearBase):
output_size: int,
bias: bool = False,
):
tp_size = dist.get_world_size()
tp_size = dist_utils.get_world_size()
super().__init__(input_size, divide(output_size, tp_size), bias, 0)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
@ -103,7 +105,7 @@ class QKVParallelLinear(ColumnParallelLinear):
total_num_kv_heads: int | None = None,
bias: bool = False,
):
tp_size = dist.get_world_size()
tp_size = dist_utils.get_world_size()
total_num_kv_heads = total_num_kv_heads or total_num_heads
self.head_size = head_size
self.num_heads = divide(total_num_heads, tp_size)
@ -136,7 +138,7 @@ class RowParallelLinear(LinearBase):
output_size: int,
bias: bool = False,
):
tp_size = dist.get_world_size()
tp_size = dist_utils.get_world_size()
super().__init__(divide(input_size, tp_size), output_size, bias, 1)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
@ -149,5 +151,5 @@ class RowParallelLinear(LinearBase):
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
if self.tp_size > 1:
dist.all_reduce(y)
dist_utils.all_reduce(y)
return y


@ -9,6 +9,7 @@ from nanovllm.layers.layernorm import RMSNorm
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
from nanovllm.layers.rotary_embedding import get_rope
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
from nanovllm import distributed as dist_utils
class Qwen3Attention(nn.Module):
@ -26,7 +27,7 @@ class Qwen3Attention(nn.Module):
rope_scaling: tuple | None = None,
) -> None:
super().__init__()
tp_size = dist.get_world_size()
tp_size = dist_utils.get_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size


@ -6,7 +6,7 @@ including dataset building, audio labeling, and training utilities.
"""
from acestep.training.dataset_builder import DatasetBuilder, AudioSample
from acestep.training.configs import LoRAConfig, TrainingConfig
from acestep.training.configs import LoRAConfig, LoKRConfig, TrainingConfig
from acestep.training.lora_utils import (
inject_lora_into_dit,
save_lora_weights,
@ -14,6 +14,12 @@ from acestep.training.lora_utils import (
merge_lora_weights,
check_peft_available,
)
from acestep.training.lokr_utils import (
inject_lokr_into_dit,
save_lokr_weights,
load_lokr_weights,
check_lycoris_available,
)
from acestep.training.data_module import (
# Preprocessed (recommended)
PreprocessedTensorDataset,
@ -25,7 +31,13 @@ from acestep.training.data_module import (
collate_training_batch,
load_dataset_from_json,
)
from acestep.training.trainer import LoRATrainer, PreprocessedLoRAModule, LIGHTNING_AVAILABLE
from acestep.training.trainer import (
LoRATrainer,
LoKRTrainer,
PreprocessedLoRAModule,
PreprocessedLoKRModule,
LIGHTNING_AVAILABLE,
)
def check_lightning_available():
"""Check if Lightning Fabric is available."""
@ -37,6 +49,7 @@ __all__ = [
"AudioSample",
# Configs
"LoRAConfig",
"LoKRConfig",
"TrainingConfig",
# LoRA Utils
"inject_lora_into_dit",
@ -44,6 +57,11 @@ __all__ = [
"load_lora_weights",
"merge_lora_weights",
"check_peft_available",
# LoKr Utils
"inject_lokr_into_dit",
"save_lokr_weights",
"load_lokr_weights",
"check_lycoris_available",
# Data Module (Preprocessed - Recommended)
"PreprocessedTensorDataset",
"PreprocessedDataModule",
@ -55,7 +73,9 @@ __all__ = [
"load_dataset_from_json",
# Trainer
"LoRATrainer",
"LoKRTrainer",
"PreprocessedLoRAModule",
"PreprocessedLoKRModule",
"check_lightning_available",
"LIGHTNING_AVAILABLE",
]


@ -5,7 +5,7 @@ Contains dataclasses for LoRA and training configurations.
"""
from dataclasses import dataclass, field
from typing import List, Optional
from typing import List
@dataclass
@ -38,6 +38,43 @@ class LoRAConfig:
}
@dataclass
class LoKRConfig:
"""Configuration for LoKr (Low-Rank Kronecker) training."""
linear_dim: int = 64
linear_alpha: int = 128
factor: int = -1
decompose_both: bool = False
use_tucker: bool = False
use_scalar: bool = False
weight_decompose: bool = False
target_modules: List[str] = field(default_factory=lambda: [
"q_proj", "k_proj", "v_proj", "o_proj"
])
full_matrix: bool = False
bypass_mode: bool = False
rs_lora: bool = False
unbalanced_factorization: bool = False
def to_dict(self):
"""Convert to dictionary for LyCORIS config."""
return {
"linear_dim": self.linear_dim,
"linear_alpha": self.linear_alpha,
"factor": self.factor,
"decompose_both": self.decompose_both,
"use_tucker": self.use_tucker,
"use_scalar": self.use_scalar,
"weight_decompose": self.weight_decompose,
"target_modules": self.target_modules,
"full_matrix": self.full_matrix,
"bypass_mode": self.bypass_mode,
"rs_lora": self.rs_lora,
"unbalanced_factorization": self.unbalanced_factorization,
}
@dataclass
class TrainingConfig:
"""Configuration for LoRA training process.
@ -83,11 +120,18 @@ class TrainingConfig:
pin_memory: bool = True
prefetch_factor: int = 2
persistent_workers: bool = True
pin_memory_device: Optional[str] = None
pin_memory_device: str = ""
# Logging
log_every_n_steps: int = 10
# Validation (for loss curve and best-checkpoint tracking)
val_split: float = 0.0
def __post_init__(self) -> None:
if not 0.0 <= self.val_split < 1.0:
raise ValueError("val_split must be in [0.0, 1.0).")
def to_dict(self):
"""Convert to dictionary."""
return {
@ -110,4 +154,5 @@ class TrainingConfig:
"persistent_workers": self.persistent_workers,
"pin_memory_device": self.pin_memory_device,
"log_every_n_steps": self.log_every_n_steps,
"val_split": self.val_split,
}
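
The `__post_init__` check added to `TrainingConfig` rejects a `val_split` outside the half-open interval [0.0, 1.0) at construction time. A minimal sketch of the same check on a reduced dataclass follows. `SplitConfig` and `split_sizes` are hypothetical names, and the truncating rounding rule in `split_sizes` is an assumption, not something this diff specifies.

```python
from dataclasses import dataclass


@dataclass
class SplitConfig:
    """Hypothetical reduced config holding only the validation fraction."""

    val_split: float = 0.0

    def __post_init__(self) -> None:
        # Same bounds as the check in the diff: 0.0 allowed, 1.0 rejected.
        if not 0.0 <= self.val_split < 1.0:
            raise ValueError("val_split must be in [0.0, 1.0).")


def split_sizes(num_samples: int, val_split: float) -> tuple[int, int]:
    """Return (train_count, val_count); truncation is an assumed rule."""
    val = int(num_samples * val_split)
    return num_samples - val, val
```

Excluding 1.0 guarantees the training set can never be empty because of the split alone, and the default of 0.0 keeps validation disabled unless it is requested.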


@ -177,7 +177,7 @@ class PreprocessedDataModule(LightningDataModule if LIGHTNING_AVAILABLE else obj
pin_memory: bool = True,
prefetch_factor: int = 2,
persistent_workers: bool = True,
pin_memory_device: Optional[str] = None,
pin_memory_device: str = "",
val_split: float = 0.0,
):
"""Initialize the data module.
@ -226,7 +226,7 @@ class PreprocessedDataModule(LightningDataModule if LIGHTNING_AVAILABLE else obj
"""Create training dataloader."""
prefetch_factor = None if self.num_workers == 0 else self.prefetch_factor
persistent_workers = False if self.num_workers == 0 else self.persistent_workers
pin_memory_device = self.pin_memory_device if self.pin_memory else None
pin_memory_device = self.pin_memory_device if self.pin_memory else ""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
@ -246,7 +246,7 @@ class PreprocessedDataModule(LightningDataModule if LIGHTNING_AVAILABLE else obj
return None
prefetch_factor = None if self.num_workers == 0 else self.prefetch_factor
persistent_workers = False if self.num_workers == 0 else self.persistent_workers
pin_memory_device = self.pin_memory_device if self.pin_memory else None
pin_memory_device = self.pin_memory_device if self.pin_memory else ""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,


@ -1,39 +1,62 @@
import os
from typing import Tuple
import torchaudio
from loguru import logger
def load_lyrics_file(audio_path: str) -> Tuple[str, bool]:
"""Load lyrics from a .txt file with the same name as the audio file."""
def _read_text_file(path: str) -> Tuple[str, bool]:
"""Read a text file; return (content.strip(), True) if present and non-empty."""
if not os.path.exists(path):
return "", False
try:
with open(path, "r", encoding="utf-8") as f:
content = f.read().strip()
if content:
return content, True
return "", False
except Exception as e:
logger.warning(f"Failed to read {path}: {e}")
return "", False
def load_caption_file(audio_path: str) -> Tuple[str, bool]:
"""Load caption from <basename>.caption.txt (explicit convention)."""
base_path = os.path.splitext(audio_path)[0]
lyrics_path = base_path + ".txt"
caption_path = base_path + ".caption.txt"
content, ok = _read_text_file(caption_path)
if ok:
logger.debug(f"Loaded caption from {caption_path}")
return content, ok
if os.path.exists(lyrics_path):
try:
with open(lyrics_path, "r", encoding="utf-8") as f:
lyrics_content = f.read().strip()
if lyrics_content:
logger.info(f"Loaded lyrics from {lyrics_path}")
return lyrics_content, True
logger.warning(f"Lyrics file is empty: {lyrics_path}")
return "", False
except Exception as e:
logger.warning(f"Failed to read lyrics file {lyrics_path}: {e}")
return "", False
def load_lyrics_file(audio_path: str) -> Tuple[str, bool]:
"""Load lyrics from <basename>.lyrics.txt, then fall back to <basename>.txt for backward compatibility."""
base_path = os.path.splitext(audio_path)[0]
for suffix in (".lyrics.txt", ".txt"):
path = base_path + suffix
content, ok = _read_text_file(path)
if ok:
if suffix == ".lyrics.txt":
logger.debug(f"Loaded lyrics from {path}")
else:
logger.debug(f"Loaded lyrics from {path} (legacy .txt)")
return content, True
return "", False
def get_audio_duration(audio_path: str) -> int:
"""Get the duration of an audio file in seconds."""
# Primary: torchcodec (ships with torchaudio >=2.9, supports all ffmpeg formats)
# Note: torchcodec is optional on ROCM/Intel platforms due to CUDA dependencies
try:
info = torchaudio.info(audio_path)
return int(info.num_frames / info.sample_rate)
from torchcodec.decoders import AudioDecoder
decoder = AudioDecoder(audio_path)
return int(decoder.metadata.duration_seconds)
except ImportError:
logger.debug("torchcodec not available (expected on ROCM/Intel platforms), using soundfile fallback")
except Exception as e:
logger.warning(f"torchaudio failed for {audio_path}: {e}, trying soundfile")
logger.debug(f"torchcodec failed for {audio_path}: {e}, trying soundfile")
# Fallback: soundfile (fast for wav/flac/ogg, works on all platforms)
try:
import soundfile as sf
info = sf.info(audio_path)
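
The sidecar lookup introduced in this file tries `<basename>.lyrics.txt` first, then the legacy `<basename>.txt`, and treats an empty file as missing. The sketch below condenses that ordering into one function. `find_sidecar` is a hypothetical helper written for illustration; the code in the diff returns a `(content, found)` tuple and logs read failures instead.

```python
import os
import tempfile


def find_sidecar(audio_path: str, suffixes: tuple) -> str:
    """Return stripped content of the first non-empty sidecar file, else ''."""
    base = os.path.splitext(audio_path)[0]
    for suffix in suffixes:
        path = base + suffix
        if not os.path.exists(path):
            continue
        with open(path, "r", encoding="utf-8") as f:
            content = f.read().strip()
        if content:  # an empty file falls through to the next suffix
            return content
    return ""


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        audio = os.path.join(d, "song.mp3")
        with open(os.path.join(d, "song.txt"), "w", encoding="utf-8") as f:
            f.write("legacy lyrics\n")
        print(find_sidecar(audio, (".lyrics.txt", ".txt")))
```

Because the caption uses its own `.caption.txt` suffix, a directory can carry `song.caption.txt`, `song.lyrics.txt`, and a legacy `song.txt` side by side without ambiguity.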


@ -29,10 +29,18 @@ class PreprocessMixin:
dit_handler,
output_dir: str,
max_duration: float = 240.0,
preprocess_mode: str = "lora",
progress_callback=None,
) -> Tuple[List[str], str]:
"""Preprocess all labeled samples to tensor files for efficient training."""
debug_log_for("dataset", f"preprocess_to_tensors: output_dir='{output_dir}', max_duration={max_duration}")
mode = str(preprocess_mode or "lora").strip().lower()
if mode not in {"lora", "lokr"}:
mode = "lora"
debug_log_for(
"dataset",
f"preprocess_to_tensors: output_dir='{output_dir}', max_duration={max_duration}, mode={mode}",
)
if not self.samples:
return [], "❌ No samples to preprocess"
@ -145,6 +153,17 @@ class PreprocessMixin:
if lyric_hidden_states.dtype != model_dtype:
lyric_hidden_states = lyric_hidden_states.to(model_dtype)
refer_audio_hidden = None
refer_audio_order_mask_val = None
if mode == "lokr":
# LoKr mode uses per-sample audio latents as reference-audio conditioning.
refer_audio_hidden = target_latents.to(device=model_device, dtype=model_dtype)
refer_audio_order_mask_val = torch.zeros(
refer_audio_hidden.shape[0],
device=model_device,
dtype=torch.long,
)
encoder_hidden_states, encoder_attention_mask = run_encoder(
model,
text_hidden_states=text_hidden_states,
@ -153,6 +172,8 @@ class PreprocessMixin:
lyric_attention_mask=lyric_attention_mask,
device=model_device,
dtype=model_dtype,
refer_audio_hidden_states_packed=refer_audio_hidden,
refer_audio_order_mask=refer_audio_order_mask_val,
)
debug_end_verbose_for("dataset", f"run_encoder[{i}]", t0)
debug_log_verbose_for(
@ -162,7 +183,14 @@ class PreprocessMixin:
)
t0 = debug_start_verbose_for("dataset", f"build_context_latents[{i}]")
context_latents = build_context_latents(silence_latent, latent_length, device, dtype)
context_src = target_latents if mode == "lokr" else None
context_latents = build_context_latents(
silence_latent,
latent_length,
device,
dtype,
src_latents=context_src,
)
debug_end_verbose_for("dataset", f"build_context_latents[{i}]", t0)
output_data = {
@ -182,6 +210,7 @@ class PreprocessMixin:
"timesignature": sample.timesignature,
"language": sample.language,
"is_instrumental": sample.is_instrumental,
"preprocess_mode": mode,
},
}
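
The mode handling at the top of `preprocess_to_tensors` normalizes its input and silently falls back to `"lora"` for anything unrecognized, including `None` and empty strings. The same logic as a standalone function, under the hypothetical name `normalize_preprocess_mode`:

```python
def normalize_preprocess_mode(value) -> str:
    """Lower-case and trim the mode; unknown or empty values become 'lora'."""
    mode = str(value or "lora").strip().lower()
    return mode if mode in {"lora", "lokr"} else "lora"
```

One consequence worth keeping in mind is that a typo such as `"lokar"` does not raise. It preprocesses in LoRA mode, and the only trace is the `preprocess_mode` field written into each sample's metadata.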


@ -1,9 +1,20 @@
import torch
def build_context_latents(silence_latent, latent_length: int, device, dtype):
def build_context_latents(
silence_latent,
latent_length: int,
device,
dtype,
src_latents: torch.Tensor | None = None,
):
"""Build context latents for text2music."""
src_latents = silence_latent[:, :latent_length, :].to(dtype)
if src_latents is None:
src_latents = silence_latent[:, :latent_length, :].to(dtype)
else:
if src_latents.dim() == 2:
src_latents = src_latents.unsqueeze(0)
src_latents = src_latents.to(device=device, dtype=dtype)
if src_latents.shape[0] < 1:
src_latents = src_latents.expand(1, -1, -1)


@ -9,10 +9,26 @@ def run_encoder(
lyric_attention_mask,
device,
dtype,
refer_audio_hidden_states_packed=None,
refer_audio_order_mask=None,
):
"""Run model encoder to get hidden states and attention mask."""
refer_audio_hidden = torch.zeros(1, 1, 64, device=device, dtype=dtype)
refer_audio_order_mask = torch.zeros(1, device=device, dtype=torch.long)
if refer_audio_hidden_states_packed is None:
refer_audio_hidden = torch.zeros(1, 1, 64, device=device, dtype=dtype)
else:
refer_audio_hidden = refer_audio_hidden_states_packed
if refer_audio_hidden.dim() == 2:
refer_audio_hidden = refer_audio_hidden.unsqueeze(0)
refer_audio_hidden = refer_audio_hidden.to(device=device, dtype=dtype)
if refer_audio_order_mask is None:
refer_audio_order_mask = torch.zeros(
refer_audio_hidden.shape[0],
device=device,
dtype=torch.long,
)
else:
refer_audio_order_mask = refer_audio_order_mask.to(device=device, dtype=torch.long)
with torch.no_grad():
encoder_hidden_states, encoder_attention_mask = model.encoder(


@ -3,7 +3,7 @@ from typing import List, Tuple
from loguru import logger
from .audio_io import get_audio_duration, load_lyrics_file
from .audio_io import get_audio_duration, load_caption_file, load_lyrics_file
from .csv_metadata import load_csv_metadata
from .models import AudioSample, SUPPORTED_AUDIO_FORMATS
@ -39,17 +39,23 @@ class ScanMixin:
csv_metadata = load_csv_metadata(directory)
csv_count = 0
caption_count = 0
lyrics_count = 0
for audio_path in audio_files:
try:
duration = get_audio_duration(audio_path)
caption_content, has_caption_file = load_caption_file(audio_path)
lyrics_content, has_lyrics_file = load_lyrics_file(audio_path)
if has_caption_file:
caption_count += 1
if has_lyrics_file:
lyrics_count += 1
is_instrumental = self.metadata.all_instrumental
if has_lyrics_file:
is_instrumental = False
lyrics_count += 1
sample = AudioSample(
audio_path=audio_path,
@ -57,9 +63,12 @@ class ScanMixin:
duration=duration,
is_instrumental=is_instrumental,
custom_tag=self.metadata.custom_tag,
caption=caption_content if has_caption_file else "",
lyrics=lyrics_content if has_lyrics_file else "[Instrumental]",
raw_lyrics=lyrics_content if has_lyrics_file else "",
)
if has_caption_file:
sample.labeled = True
if csv_metadata and sample.filename in csv_metadata:
meta = csv_metadata[sample.filename]
@ -79,8 +88,10 @@ class ScanMixin:
self.metadata.num_samples = len(self.samples)
status = f"✅ Found {len(self.samples)} audio files in {directory}"
if caption_count > 0:
status += f"\n 📋 Detected {caption_count} captions (.caption.txt)"
if lyrics_count > 0:
status += f"\n 📝 {lyrics_count} files have accompanying lyrics (.txt)"
status += f"\n 📝 Detected {lyrics_count} lyrics (.lyrics.txt / .txt)"
if csv_count > 0:
status += f"\n 📊 {csv_count} files have metadata from CSV"


@ -0,0 +1,253 @@
"""
LoKr utilities for ACE-Step training and inference.
This module integrates LyCORIS LoKr adapters with the ACE-Step decoder.
"""
import json
import os
from typing import Any, Dict, Optional, Tuple
import torch
from loguru import logger
from acestep.training.configs import LoKRConfig
try:
from lycoris import LycorisNetwork, create_lycoris
LYCORIS_AVAILABLE = True
except ImportError:
LYCORIS_AVAILABLE = False
LycorisNetwork = Any # type: ignore[assignment,misc]
logger.warning(
"LyCORIS library not installed. LoKr training/inference unavailable. "
"Install with: pip install lycoris-lora"
)
def check_lycoris_available() -> bool:
"""Check if LyCORIS is importable."""
return LYCORIS_AVAILABLE
def _matches_target_module_name(module_name: str, target_modules) -> bool:
"""Return True if a LyCORIS module name maps to one of target module suffixes."""
if not module_name:
return False
name = module_name.lower()
for target in target_modules or []:
t = str(target).strip().lower()
if not t:
continue
if name.endswith(t) or f"_{t}" in name or f".{t}" in name:
return True
return False
def inject_lokr_into_dit(
model,
lokr_config: LoKRConfig,
multiplier: float = 1.0,
) -> Tuple[Any, "LycorisNetwork", Dict[str, Any]]:
"""
Inject LoKr adapters into the decoder.
Returns:
Tuple: (model, lycoris_network, info_dict)
"""
if not LYCORIS_AVAILABLE:
raise ImportError(
"LyCORIS library is required for LoKr training. "
"Install with: pip install lycoris-lora"
)
decoder = model.decoder
# Freeze all existing params before creating adapter params.
for _, param in model.named_parameters():
param.requires_grad = False
LycorisNetwork.apply_preset(
{
"unet_target_name": lokr_config.target_modules,
"target_name": lokr_config.target_modules,
}
)
lycoris_net = create_lycoris(
decoder,
multiplier,
linear_dim=lokr_config.linear_dim,
linear_alpha=lokr_config.linear_alpha,
algo="lokr",
factor=lokr_config.factor,
decompose_both=lokr_config.decompose_both,
use_tucker=lokr_config.use_tucker,
use_scalar=lokr_config.use_scalar,
full_matrix=lokr_config.full_matrix,
bypass_mode=lokr_config.bypass_mode,
rs_lora=lokr_config.rs_lora,
unbalanced_factorization=lokr_config.unbalanced_factorization,
)
if lokr_config.weight_decompose:
try:
lycoris_net = create_lycoris(
decoder,
multiplier,
linear_dim=lokr_config.linear_dim,
linear_alpha=lokr_config.linear_alpha,
algo="lokr",
factor=lokr_config.factor,
decompose_both=lokr_config.decompose_both,
use_tucker=lokr_config.use_tucker,
use_scalar=lokr_config.use_scalar,
full_matrix=lokr_config.full_matrix,
bypass_mode=lokr_config.bypass_mode,
rs_lora=lokr_config.rs_lora,
unbalanced_factorization=lokr_config.unbalanced_factorization,
dora_wd=True,
)
except Exception as exc:
logger.warning(f"DoRA mode not supported in current LyCORIS build: {exc}")
lycoris_net.apply_to()
# Keep a reference on decoder so it stays discoverable after wrappers.
# Always refresh this reference to avoid stale nets from earlier runs.
decoder._lycoris_net = lycoris_net
lokr_param_list = []
enabled_module_count = 0
disabled_module_count = 0
disabled_examples = []
for idx, module in enumerate(getattr(lycoris_net, "loras", []) or []):
module_name = (
getattr(module, "lora_name", None)
or getattr(module, "name", None)
or f"{module.__class__.__name__}#{idx}"
)
enabled = _matches_target_module_name(module_name, lokr_config.target_modules)
if enabled:
enabled_module_count += 1
else:
disabled_module_count += 1
if len(disabled_examples) < 8:
disabled_examples.append(module_name)
for param in module.parameters():
param.requires_grad = enabled
if enabled:
lokr_param_list.append(param)
logger.info(
f"LoKr target filter: enabled {enabled_module_count} LyCORIS modules "
f"(disabled {disabled_module_count}) for targets={lokr_config.target_modules}"
)
if disabled_examples:
logger.info("LoKr disabled non-target modules (sample): " + ", ".join(disabled_examples))
if not lokr_param_list:
for param in lycoris_net.parameters():
param.requires_grad = True
lokr_param_list.append(param)
# De-duplicate possible shared params.
unique_params = {id(p): p for p in lokr_param_list}
total_params = sum(p.numel() for p in model.parameters())
lokr_params = sum(p.numel() for p in unique_params.values())
trainable_params = sum(p.numel() for p in unique_params.values() if p.requires_grad)
info = {
"total_params": total_params,
"lokr_params": lokr_params,
"trainable_params": trainable_params,
"trainable_ratio": trainable_params / total_params if total_params > 0 else 0.0,
"linear_dim": lokr_config.linear_dim,
"linear_alpha": lokr_config.linear_alpha,
"factor": lokr_config.factor,
"algo": "lokr",
"target_modules": lokr_config.target_modules,
}
logger.info("LoKr injected into decoder")
logger.info(
f"LoKr trainable params: {trainable_params:,}/{total_params:,} "
f"({info['trainable_ratio']:.2%})"
)
return model, lycoris_net, info
def save_lokr_weights(
lycoris_net: "LycorisNetwork",
output_dir: str,
dtype: Optional[torch.dtype] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> str:
"""Save LoKr weights to safetensors."""
os.makedirs(output_dir, exist_ok=True)
weights_path = os.path.join(output_dir, "lokr_weights.safetensors")
save_metadata: Dict[str, str] = {"algo": "lokr", "format": "lycoris"}
if metadata:
for key, value in metadata.items():
if value is None:
continue
if isinstance(value, str):
save_metadata[key] = value
else:
save_metadata[key] = json.dumps(value, ensure_ascii=True)
lycoris_net.save_weights(weights_path, dtype=dtype, metadata=save_metadata)
logger.info(f"LoKr weights saved to {weights_path}")
return weights_path
def load_lokr_weights(lycoris_net: "LycorisNetwork", weights_path: str) -> Dict[str, Any]:
"""Load LoKr weights into an injected LyCORIS network."""
if not os.path.exists(weights_path):
raise FileNotFoundError(f"LoKr weights not found: {weights_path}")
result = lycoris_net.load_weights(weights_path)
logger.info(f"LoKr weights loaded from {weights_path}")
return result
def save_lokr_training_checkpoint(
lycoris_net: "LycorisNetwork",
optimizer,
scheduler,
epoch: int,
global_step: int,
output_dir: str,
lokr_config: Optional[LoKRConfig] = None,
run_metadata: Optional[Dict[str, Any]] = None,
) -> str:
"""Save LoKr weights plus optimizer/scheduler state."""
os.makedirs(output_dir, exist_ok=True)
metadata: Dict[str, Any] = {}
if lokr_config is not None:
metadata["lokr_config"] = lokr_config.to_dict()
if run_metadata is not None:
metadata["run_metadata"] = run_metadata
metadata = metadata or None
save_lokr_weights(lycoris_net, output_dir, metadata=metadata)
state = {
"epoch": epoch,
"global_step": global_step,
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
}
if lokr_config is not None:
state["lokr_config"] = lokr_config.to_dict()
if run_metadata is not None:
state["run_metadata"] = run_metadata
state_path = os.path.join(output_dir, "training_state.pt")
torch.save(state, state_path)
logger.info(f"LoKr checkpoint saved to {output_dir} (epoch={epoch}, step={global_step})")
return output_dir
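
The target filter in `inject_lokr_into_dit` decides which LyCORIS modules stay trainable by matching on module names, and the match is looser than a pure suffix test. The sketch below restates `_matches_target_module_name` so its behaviour can be checked in isolation. The module names used with it are hypothetical examples, not names taken from an actual LyCORIS network.

```python
def matches_target(module_name: str, target_modules) -> bool:
    """True if the name ends with a target or contains it after '_' or '.'."""
    if not module_name:
        return False
    name = module_name.lower()
    for target in target_modules or []:
        t = str(target).strip().lower()
        if not t:
            continue
        # Suffix match, or a substring match preceded by a separator.
        if name.endswith(t) or f"_{t}" in name or f".{t}" in name:
            return True
    return False
```

Because of the substring branches, a name such as `lycoris_q_proj_norm` would also match the target `q_proj`. Whether that is intended is not stated in the diff, so it may be worth confirming against the real module names.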


@ -8,6 +8,7 @@ Uses PEFT (Parameter-Efficient Fine-Tuning) library for LoRA implementation.
import os
from typing import Optional, List, Dict, Any, Tuple
from loguru import logger
import types
import torch
import torch.nn as nn
@ -95,8 +96,61 @@ def inject_lora_into_dit(
if not PEFT_AVAILABLE:
raise ImportError("PEFT library is required for LoRA training. Install with: pip install peft")
# Get the decoder (DiT model)
# Get the decoder (DiT model). Previous failed training runs may leave
# Fabric/PEFT wrappers attached; unwrap to a clean base module first.
decoder = model.decoder
while hasattr(decoder, "_forward_module"):
decoder = decoder._forward_module
if hasattr(decoder, "base_model"):
base_model = decoder.base_model
if hasattr(base_model, "model"):
decoder = base_model.model
else:
decoder = base_model
if hasattr(decoder, "model") and isinstance(decoder.model, nn.Module):
decoder = decoder.model
model.decoder = decoder
# PEFT may call enable_input_require_grads() when is_gradient_checkpointing
# is true. AceStepDiTModel doesn't implement get_input_embeddings, so the
# default implementation raises NotImplementedError. Guard this path.
if hasattr(decoder, "enable_input_require_grads"):
orig_enable_input_require_grads = decoder.enable_input_require_grads
def _safe_enable_input_require_grads(self):
try:
result = orig_enable_input_require_grads()
try:
self._acestep_input_grads_hook_enabled = True
except Exception:
pass
return result
except NotImplementedError:
try:
self._acestep_input_grads_hook_enabled = False
except Exception:
pass
if not getattr(self, "_acestep_input_grads_warning_emitted", False):
logger.info(
"Skipping enable_input_require_grads for decoder: "
"get_input_embeddings is not implemented (expected for DiT)"
)
try:
self._acestep_input_grads_warning_emitted = True
except Exception:
pass
return None
decoder.enable_input_require_grads = types.MethodType(
_safe_enable_input_require_grads, decoder
)
# Avoid PEFT auto-prep path on non-embedding diffusion decoder.
if hasattr(decoder, "is_gradient_checkpointing"):
try:
decoder.is_gradient_checkpointing = False
except Exception:
pass
# Create PEFT LoRA config
peft_lora_config = LoraConfig(
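
The `enable_input_require_grads` guard above works by capturing the original bound method and rebinding a tolerant wrapper onto the instance with `types.MethodType`. A reduced sketch of that technique follows. `Decoder` and `make_safe` are hypothetical stand-ins, and the attribute name `hook_enabled` is shortened from the one used in the diff.

```python
import types


class Decoder:
    """Hypothetical module whose hook raises, like a model without input embeddings."""

    def enable_input_require_grads(self):
        raise NotImplementedError("get_input_embeddings is not implemented")


def make_safe(obj):
    """Patch one instance so the hook swallows NotImplementedError."""
    original = obj.enable_input_require_grads  # bound method, captured before patching

    def _safe(self):
        try:
            result = original()
            self.hook_enabled = True
            return result
        except NotImplementedError:
            self.hook_enabled = False
            return None

    # Instance-level rebinding: other Decoder instances keep the raising method.
    obj.enable_input_require_grads = types.MethodType(_safe, obj)
    return obj
```

Patching the instance rather than the class keeps the change local to the decoder being adapted, which fits the diff's stated aim of tolerating a missing `get_input_embeddings` without altering the model class.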

File diff suppressed because it is too large


@ -0,0 +1,15 @@
"""
ACE-Step Training V2 -- Corrected LoRA Fine-Tuning CLI
Non-destructive parallel module providing corrected training procedures
that match each model variant's own forward() training logic.
Subcommands:
vanilla -- Reproduce existing (bugged) training for backward compatibility
fixed -- Corrected training: continuous timesteps + CFG dropout
selective -- Fixed + dataset-specific layer/module selection
estimate -- Gradient sensitivity analysis (no training)
compare-configs -- Compare module configs across datasets
"""
__version__ = "2.0.0"

Some files were not shown because too many files have changed in this diff