mirror of
https://github.com/ace-step/ACE-Step-1.5.git
synced 2026-07-02 16:37:04 +00:00
Merge branch 'main' into fix_xpu_req
This commit is contained in:
commit
cf5b27e8f0
166 changed files with 40083 additions and 7035 deletions
173
.claude/skills/acestep-lyrics-transcription/SKILL.md
Normal file
173
.claude/skills/acestep-lyrics-transcription/SKILL.md
Normal file
|
|
@ -0,0 +1,173 @@
|
|||
---
|
||||
name: acestep-lyrics-transcription
|
||||
description: Transcribe audio to timestamped lyrics using OpenAI Whisper or ElevenLabs Scribe API. Outputs LRC, SRT, or JSON with word-level timestamps. Use when users want to transcribe songs, generate LRC files, or extract lyrics with timestamps from audio.
|
||||
allowed-tools: Read, Write, Bash
|
||||
---
|
||||
|
||||
# Lyrics Transcription Skill
|
||||
|
||||
Transcribe audio files to timestamped lyrics (LRC/SRT/JSON) via OpenAI Whisper or ElevenLabs Scribe API.
|
||||
|
||||
## API Key Setup Guide
|
||||
|
||||
**Before transcribing, you MUST check whether the user's API key is configured.** Run the following command to check:
|
||||
|
||||
```bash
|
||||
cd "{project_root}/{.claude or .codex}/skills/acestep-lyrics-transcription/" && bash ./scripts/acestep-lyrics-transcription.sh config --check-key
|
||||
```
|
||||
|
||||
This command only reports whether the active provider's API key is set or empty — it does NOT print the actual key value. **NEVER read or display the user's API key content.** Do not use `config --get` on key fields or read `config.json` directly. The `config --list` command is safe — it automatically masks API keys as `***` in output.
|
||||
|
||||
**If the command reports the key is empty**, you MUST stop and guide the user to configure it before proceeding. Do NOT attempt transcription without a valid key — it will fail.
|
||||
|
||||
Use `AskUserQuestion` to ask the user to provide their API key, with the following options and guidance:
|
||||
|
||||
1. Tell the user which provider is currently active (openai or elevenlabs) and that its API key is not configured. Explain that transcription cannot proceed without it.
|
||||
2. Provide clear instructions on where to obtain a key:
|
||||
- **OpenAI**: Get an API key at https://platform.openai.com/api-keys — requires an OpenAI account with billing enabled. The Whisper API costs ~$0.006/min.
|
||||
- **ElevenLabs**: Get an API key at https://elevenlabs.io/app/settings/api-keys — requires an ElevenLabs account. Free tier includes limited credits.
|
||||
3. Also offer the option to switch to the other provider if they already have a key for it.
|
||||
4. Once the user provides the key, configure it using:
|
||||
```bash
|
||||
cd "{project_root}/{.claude or .codex}/skills/acestep-lyrics-transcription/" && bash ./scripts/acestep-lyrics-transcription.sh config --set <provider>.api_key <KEY>
|
||||
```
|
||||
5. If the user wants to switch providers, also run:
|
||||
```bash
|
||||
cd "{project_root}/{.claude or .codex}/skills/acestep-lyrics-transcription/" && bash ./scripts/acestep-lyrics-transcription.sh config --set provider <provider_name>
|
||||
```
|
||||
6. After configuring, re-run `config --check-key` to verify the key is set before proceeding.
|
||||
|
||||
**If the API key is already configured**, proceed directly to transcription without asking.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# 1. cd to this skill's directory
|
||||
cd {project_root}/{.claude or .codex}/skills/acestep-lyrics-transcription/
|
||||
|
||||
# 2. Configure API key (choose one)
|
||||
./scripts/acestep-lyrics-transcription.sh config --set openai.api_key sk-...
|
||||
# or
|
||||
./scripts/acestep-lyrics-transcription.sh config --set elevenlabs.api_key ...
|
||||
./scripts/acestep-lyrics-transcription.sh config --set provider elevenlabs
|
||||
|
||||
# 3. Transcribe
|
||||
./scripts/acestep-lyrics-transcription.sh transcribe --audio /path/to/song.mp3 --language zh
|
||||
|
||||
# 4. Output saved to: {project_root}/acestep_output/<filename>.lrc
|
||||
```
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- curl, jq, python3 (or python)
|
||||
- An API key for OpenAI or ElevenLabs
|
||||
|
||||
## Script Usage
|
||||
|
||||
```bash
|
||||
./scripts/acestep-lyrics-transcription.sh transcribe --audio <file> [options]
|
||||
|
||||
Options:
|
||||
-a, --audio Audio file path (required)
|
||||
-l, --language Language code (zh, en, ja, etc.)
|
||||
-f, --format Output format: lrc, srt, json (default: lrc)
|
||||
-p, --provider API provider: openai, elevenlabs (overrides config)
|
||||
-o, --output Output file path (default: acestep_output/<filename>.lrc)
|
||||
```
|
||||
|
||||
## Post-Transcription Lyrics Correction (MANDATORY)
|
||||
|
||||
**CRITICAL**: After transcription, you MUST manually correct the LRC file before using it for MV rendering. Transcription models frequently produce errors on sung lyrics:
|
||||
|
||||
- Proper nouns: "ACE-Step" → "AC step", "Spotify" → "spot a fly"
|
||||
- Similar-sounding words: "arrives" → "eyes", "open source" → "open sores"
|
||||
- Merged/split words: "lighting up" → "lightin' nup"
|
||||
|
||||
### Correction Workflow
|
||||
|
||||
1. **Read the transcribed LRC file** using the Read tool
|
||||
2. **Read the original lyrics** from the ACE-Step output JSON file
|
||||
3. **Use original lyrics as a whole reference**: Do NOT attempt line-by-line alignment — transcription often splits, merges, or reorders lines differently from the original. Instead, read the original lyrics in full to understand the correct wording, then scan each LRC line and fix any misrecognized words based on your knowledge of what the original lyrics say.
|
||||
4. **Fix transcription errors**: Replace misrecognized words with the correct original words, keeping the timestamps intact
|
||||
5. **Write the corrected LRC** back using the Write tool
|
||||
|
||||
### What to Correct
|
||||
|
||||
- Replace misrecognized words with their correct original versions
|
||||
- Keep all `[MM:SS.cc]` timestamps exactly as-is (timestamps from transcription are accurate)
|
||||
- Do NOT add structure tags like `[Verse]` or `[Chorus]` — the LRC should only have timestamped text lines
|
||||
|
||||
### Example
|
||||
|
||||
**Transcribed (wrong):**
|
||||
```
|
||||
[00:46.96]AC step alive,
|
||||
[00:50.80]one point five eyes.
|
||||
```
|
||||
|
||||
**Original lyrics reference:**
|
||||
```
|
||||
ACE-Step alive
|
||||
One point five arrives
|
||||
```
|
||||
|
||||
**Corrected (right):**
|
||||
```
|
||||
[00:46.96]ACE-Step alive,
|
||||
[00:50.80]One point five arrives.
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
Config file: `scripts/config.json`
|
||||
|
||||
```bash
|
||||
# Switch provider
|
||||
./scripts/acestep-lyrics-transcription.sh config --set provider openai
|
||||
./scripts/acestep-lyrics-transcription.sh config --set provider elevenlabs
|
||||
|
||||
# Set API keys
|
||||
./scripts/acestep-lyrics-transcription.sh config --set openai.api_key sk-...
|
||||
./scripts/acestep-lyrics-transcription.sh config --set elevenlabs.api_key ...
|
||||
|
||||
# View config
|
||||
./scripts/acestep-lyrics-transcription.sh config --list
|
||||
```
|
||||
|
||||
| Option | Default | Description |
|
||||
|--------|---------|-------------|
|
||||
| `provider` | `openai` | Active provider: `openai` or `elevenlabs` |
|
||||
| `output_format` | `lrc` | Default output: `lrc`, `srt`, or `json` |
|
||||
| `openai.api_key` | `""` | OpenAI API key |
|
||||
| `openai.api_url` | `https://api.openai.com/v1` | OpenAI API base URL |
|
||||
| `openai.model` | `whisper-1` | OpenAI model (whisper-1 for word timestamps) |
|
||||
| `elevenlabs.api_key` | `""` | ElevenLabs API key |
|
||||
| `elevenlabs.api_url` | `https://api.elevenlabs.io/v1` | ElevenLabs API base URL |
|
||||
| `elevenlabs.model` | `scribe_v2` | ElevenLabs model |
|
||||
|
||||
## Provider Notes
|
||||
|
||||
| Provider | Model | Word Timestamps | Pricing |
|
||||
|----------|-------|-----------------|---------|
|
||||
| OpenAI | whisper-1 | Yes (segment + word) | $0.006/min |
|
||||
| ElevenLabs | scribe_v2 | Yes (word-level) | Varies by plan |
|
||||
|
||||
- OpenAI `whisper-1` is the only OpenAI model supporting word-level timestamps
|
||||
- ElevenLabs `scribe_v2` returns word-level timestamps with type filtering
|
||||
- Both support multilingual transcription
|
||||
|
||||
## Examples
|
||||
|
||||
```bash
|
||||
# Basic transcription (uses config defaults)
|
||||
./scripts/acestep-lyrics-transcription.sh transcribe --audio song.mp3
|
||||
|
||||
# Chinese song to LRC
|
||||
./scripts/acestep-lyrics-transcription.sh transcribe --audio song.mp3 --language zh
|
||||
|
||||
# Use ElevenLabs, output SRT
|
||||
./scripts/acestep-lyrics-transcription.sh transcribe --audio song.mp3 --provider elevenlabs --format srt
|
||||
|
||||
# Custom output path
|
||||
./scripts/acestep-lyrics-transcription.sh transcribe --audio song.mp3 --output ./my_lyrics.lrc
|
||||
```
|
||||
|
|
@ -0,0 +1,584 @@
|
|||
#!/bin/bash
|
||||
#
|
||||
# acestep-lyrics-transcription.sh - Transcribe audio to timestamped lyrics (LRC/SRT/JSON)
|
||||
#
|
||||
# Requirements: curl, jq
|
||||
#
|
||||
# Usage:
|
||||
# ./acestep-lyrics-transcription.sh transcribe --audio <file> [options]
|
||||
# ./acestep-lyrics-transcription.sh config [--get|--set|--reset]
|
||||
#
|
||||
# Output:
|
||||
# - LRC/SRT/JSON files saved to output directory
|
||||
|
||||
set -e
|
||||
|
||||
export LANG="${LANG:-en_US.UTF-8}"
|
||||
export LC_ALL="${LC_ALL:-en_US.UTF-8}"
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
CONFIG_FILE="${SCRIPT_DIR}/config.json"
|
||||
OUTPUT_DIR="$(cd "${SCRIPT_DIR}/../../../.." && pwd)/acestep_output"
|
||||
|
||||
# Colors
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
CYAN='\033[0;36m'
|
||||
NC='\033[0m'
|
||||
|
||||
# Convert MSYS2/Cygwin paths to Windows-native paths for Python
|
||||
to_python_path() {
|
||||
if command -v cygpath &> /dev/null; then
|
||||
cygpath -m "$1"
|
||||
else
|
||||
echo "$1"
|
||||
fi
|
||||
}
|
||||
|
||||
# Detect python executable (python3 or python)
|
||||
PYTHON_CMD=""
|
||||
find_python() {
|
||||
if [ -n "$PYTHON_CMD" ]; then return; fi
|
||||
# Test actual execution, not just existence (Windows Store python3 shim returns exit 49)
|
||||
if python3 -c "pass" &> /dev/null; then
|
||||
PYTHON_CMD="python3"
|
||||
elif python -c "pass" &> /dev/null; then
|
||||
PYTHON_CMD="python"
|
||||
else
|
||||
echo -e "${RED}Error: python3 or python is required but not found.${NC}"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# ─── Dependencies ───
|
||||
|
||||
check_deps() {
|
||||
if ! command -v curl &> /dev/null; then
|
||||
echo -e "${RED}Error: curl is required but not installed.${NC}"
|
||||
exit 1
|
||||
fi
|
||||
if ! command -v jq &> /dev/null; then
|
||||
echo -e "${RED}Error: jq is required but not installed.${NC}"
|
||||
echo "Install: apt install jq / brew install jq / choco install jq"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# ─── Config ───
|
||||
|
||||
DEFAULT_CONFIG='{
|
||||
"provider": "openai",
|
||||
"output_format": "lrc",
|
||||
"openai": {
|
||||
"api_key": "",
|
||||
"api_url": "https://api.openai.com/v1",
|
||||
"model": "whisper-1"
|
||||
},
|
||||
"elevenlabs": {
|
||||
"api_key": "",
|
||||
"api_url": "https://api.elevenlabs.io/v1",
|
||||
"model": "scribe_v2"
|
||||
}
|
||||
}'
|
||||
|
||||
ensure_config() {
|
||||
if [ ! -f "$CONFIG_FILE" ]; then
|
||||
local example="${SCRIPT_DIR}/config.example.json"
|
||||
if [ -f "$example" ]; then
|
||||
cp "$example" "$CONFIG_FILE"
|
||||
echo -e "${YELLOW}Config file created from config.example.json. Please configure your settings:${NC}"
|
||||
echo -e " ${CYAN}./scripts/acestep-lyrics-transcription.sh config --set provider <openai|elevenlabs>${NC}"
|
||||
echo -e " ${CYAN}./scripts/acestep-lyrics-transcription.sh config --set <provider>.api_key <key>${NC}"
|
||||
else
|
||||
echo "$DEFAULT_CONFIG" > "$CONFIG_FILE"
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
get_config() {
|
||||
local key="$1"
|
||||
ensure_config
|
||||
local jq_path=".${key}"
|
||||
local value
|
||||
value=$(jq -r "$jq_path" "$CONFIG_FILE" 2>/dev/null)
|
||||
if [ "$value" = "null" ]; then
|
||||
echo ""
|
||||
else
|
||||
echo "$value" | tr -d '\r\n'
|
||||
fi
|
||||
}
|
||||
|
||||
set_config() {
|
||||
local key="$1"
|
||||
local value="$2"
|
||||
ensure_config
|
||||
local tmp_file="${CONFIG_FILE}.tmp"
|
||||
local jq_path=".${key}"
|
||||
|
||||
if [ "$value" = "true" ] || [ "$value" = "false" ]; then
|
||||
jq "$jq_path = $value" "$CONFIG_FILE" > "$tmp_file"
|
||||
elif [[ "$value" =~ ^-?[0-9]+$ ]] || [[ "$value" =~ ^-?[0-9]+\.[0-9]+$ ]]; then
|
||||
jq "$jq_path = $value" "$CONFIG_FILE" > "$tmp_file"
|
||||
else
|
||||
jq "$jq_path = \"$value\"" "$CONFIG_FILE" > "$tmp_file"
|
||||
fi
|
||||
|
||||
mv "$tmp_file" "$CONFIG_FILE"
|
||||
echo "Set $key = $value"
|
||||
}
|
||||
|
||||
ensure_output_dir() {
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
}
|
||||
|
||||
# ─── Format Conversion ───
|
||||
|
||||
# Convert word-level timestamps to LRC format
|
||||
# Input: JSON array of {word, start, end} on stdin
|
||||
# Output: LRC text
|
||||
words_to_lrc() {
|
||||
local json_file="$(to_python_path "$1")"
|
||||
local output_file="$(to_python_path "$2")"
|
||||
local line_gap="${3:-1.5}"
|
||||
find_python
|
||||
|
||||
$PYTHON_CMD -c "
|
||||
import json, sys, unicodedata
|
||||
|
||||
def is_cjk(ch):
|
||||
cp = ord(ch)
|
||||
return (0x4E00 <= cp <= 0x9FFF or 0x3400 <= cp <= 0x4DBF or
|
||||
0x20000 <= cp <= 0x2A6DF or 0x2A700 <= cp <= 0x2B73F or
|
||||
0x2B740 <= cp <= 0x2B81F or 0x2B820 <= cp <= 0x2CEAF or
|
||||
0xF900 <= cp <= 0xFAFF or 0x2F800 <= cp <= 0x2FA1F or
|
||||
0x3000 <= cp <= 0x303F or 0x3040 <= cp <= 0x309F or
|
||||
0x30A0 <= cp <= 0x30FF or 0xFF00 <= cp <= 0xFFEF)
|
||||
|
||||
def smart_join(word_list):
|
||||
if not word_list:
|
||||
return ''
|
||||
result = word_list[0]
|
||||
for j in range(1, len(word_list)):
|
||||
prev_w = word_list[j-1]
|
||||
curr_w = word_list[j]
|
||||
prev_last = prev_w[-1] if prev_w else ''
|
||||
curr_first = curr_w[0] if curr_w else ''
|
||||
if is_cjk(prev_last) or is_cjk(curr_first):
|
||||
result += curr_w
|
||||
else:
|
||||
result += ' ' + curr_w
|
||||
return result.strip()
|
||||
|
||||
with open('$json_file', 'r', encoding='utf-8') as f:
|
||||
words = json.load(f)
|
||||
|
||||
if not words:
|
||||
sys.exit(0)
|
||||
|
||||
lines = []
|
||||
current_line = []
|
||||
current_start = words[0]['start']
|
||||
|
||||
for i, w in enumerate(words):
|
||||
current_line.append(w['word'])
|
||||
is_last = (i == len(words) - 1)
|
||||
has_punct = w['word'].rstrip().endswith(('.', '!', '?', '。', '!', '?', ',', ','))
|
||||
has_gap = (not is_last and words[i+1]['start'] - w['end'] > $line_gap)
|
||||
|
||||
if is_last or has_punct or has_gap:
|
||||
text = smart_join(current_line)
|
||||
text = text.rstrip(',。,.')
|
||||
if text:
|
||||
mins = int(current_start) // 60
|
||||
secs = current_start - mins * 60
|
||||
lines.append(f'[{mins:02d}:{secs:05.2f}]{text}')
|
||||
current_line = []
|
||||
if not is_last:
|
||||
current_start = words[i+1]['start']
|
||||
|
||||
with open('$output_file', 'w', encoding='utf-8') as f:
|
||||
for line in lines:
|
||||
f.write(line + '\n')
|
||||
"
|
||||
}
|
||||
|
||||
# Convert word-level timestamps to SRT format
|
||||
words_to_srt() {
|
||||
local json_file="$(to_python_path "$1")"
|
||||
local output_file="$(to_python_path "$2")"
|
||||
local line_gap="${3:-1.5}"
|
||||
find_python
|
||||
|
||||
$PYTHON_CMD -c "
|
||||
import json, sys
|
||||
|
||||
def is_cjk(ch):
|
||||
cp = ord(ch)
|
||||
return (0x4E00 <= cp <= 0x9FFF or 0x3400 <= cp <= 0x4DBF or
|
||||
0x20000 <= cp <= 0x2A6DF or 0x2A700 <= cp <= 0x2B73F or
|
||||
0x2B740 <= cp <= 0x2B81F or 0x2B820 <= cp <= 0x2CEAF or
|
||||
0xF900 <= cp <= 0xFAFF or 0x2F800 <= cp <= 0x2FA1F or
|
||||
0x3000 <= cp <= 0x303F or 0x3040 <= cp <= 0x309F or
|
||||
0x30A0 <= cp <= 0x30FF or 0xFF00 <= cp <= 0xFFEF)
|
||||
|
||||
def smart_join(word_list):
|
||||
if not word_list:
|
||||
return ''
|
||||
result = word_list[0]
|
||||
for j in range(1, len(word_list)):
|
||||
prev_w = word_list[j-1]
|
||||
curr_w = word_list[j]
|
||||
prev_last = prev_w[-1] if prev_w else ''
|
||||
curr_first = curr_w[0] if curr_w else ''
|
||||
if is_cjk(prev_last) or is_cjk(curr_first):
|
||||
result += curr_w
|
||||
else:
|
||||
result += ' ' + curr_w
|
||||
return result.strip()
|
||||
|
||||
with open('$json_file', 'r', encoding='utf-8') as f:
|
||||
words = json.load(f)
|
||||
|
||||
if not words:
|
||||
sys.exit(0)
|
||||
|
||||
def fmt(t):
|
||||
h = int(t) // 3600
|
||||
m = (int(t) % 3600) // 60
|
||||
s = t - h*3600 - m*60
|
||||
return f'{h:02d}:{m:02d}:{s:06.3f}'.replace('.', ',')
|
||||
|
||||
lines = []
|
||||
current_line = []
|
||||
current_start = words[0]['start']
|
||||
current_end = words[0]['end']
|
||||
|
||||
for i, w in enumerate(words):
|
||||
current_line.append(w['word'])
|
||||
current_end = w['end']
|
||||
is_last = (i == len(words) - 1)
|
||||
has_punct = w['word'].rstrip().endswith(('.', '!', '?', '。', '!', '?', ',', ','))
|
||||
has_gap = (not is_last and words[i+1]['start'] - w['end'] > $line_gap)
|
||||
|
||||
if is_last or has_punct or has_gap:
|
||||
text = smart_join(current_line)
|
||||
text = text.rstrip(',。,.')
|
||||
if text:
|
||||
lines.append((current_start, current_end, text))
|
||||
current_line = []
|
||||
if not is_last:
|
||||
current_start = words[i+1]['start']
|
||||
|
||||
with open('$output_file', 'w', encoding='utf-8') as f:
|
||||
for idx, (s, e, text) in enumerate(lines, 1):
|
||||
f.write(f'{idx}\n')
|
||||
f.write(f'{fmt(s)} --> {fmt(e)}\n')
|
||||
f.write(f'{text}\n')
|
||||
f.write('\n')
|
||||
"
|
||||
}
|
||||
|
||||
# ─── OpenAI Whisper ───
|
||||
|
||||
transcribe_openai() {
|
||||
local audio_file="$1"
|
||||
local language="$2"
|
||||
local words_file="$3"
|
||||
|
||||
local api_key=$(get_config "openai.api_key")
|
||||
local api_url=$(get_config "openai.api_url")
|
||||
local model=$(get_config "openai.model")
|
||||
|
||||
[ -z "$api_key" ] && { echo -e "${RED}Error: OpenAI API key not configured.${NC}"; echo "Run: ./acestep-lyrics-transcription.sh config --set openai.api_key YOUR_KEY"; exit 1; }
|
||||
[ -z "$api_url" ] && api_url="https://api.openai.com/v1"
|
||||
[ -z "$model" ] && model="whisper-1"
|
||||
|
||||
echo -e " Provider: OpenAI (${model})"
|
||||
|
||||
local resp_file=$(mktemp)
|
||||
|
||||
# Build curl command
|
||||
local curl_args=(
|
||||
-s -w "%{http_code}"
|
||||
-o "$resp_file"
|
||||
-X POST "${api_url}/audio/transcriptions"
|
||||
-H "Authorization: Bearer ${api_key}"
|
||||
-F "file=@${audio_file}"
|
||||
-F "model=${model}"
|
||||
-F "response_format=verbose_json"
|
||||
-F "timestamp_granularities[]=word"
|
||||
-F "timestamp_granularities[]=segment"
|
||||
)
|
||||
|
||||
[ -n "$language" ] && curl_args+=(-F "language=${language}")
|
||||
|
||||
local http_code
|
||||
http_code=$(curl "${curl_args[@]}")
|
||||
|
||||
if [ "$http_code" != "200" ]; then
|
||||
local err
|
||||
err=$(jq -r '.error.message // .detail // "Unknown error"' "$resp_file" 2>/dev/null)
|
||||
echo -e "${RED}Error: HTTP $http_code - $err${NC}"
|
||||
rm -f "$resp_file"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Extract word-level timestamps into normalized format [{word, start, end}]
|
||||
jq '[.words[] | {word: .word, start: .start, end: .end}]' "$resp_file" > "$words_file" 2>/dev/null
|
||||
|
||||
# Show transcription text
|
||||
local text
|
||||
text=$(jq -r '.text // empty' "$resp_file" 2>/dev/null)
|
||||
echo -e " ${GREEN}Transcription complete${NC}"
|
||||
echo ""
|
||||
echo "$text"
|
||||
|
||||
rm -f "$resp_file"
|
||||
}
|
||||
|
||||
# ─── ElevenLabs Scribe ───
|
||||
|
||||
transcribe_elevenlabs() {
|
||||
local audio_file="$1"
|
||||
local language="$2"
|
||||
local words_file="$3"
|
||||
|
||||
local api_key=$(get_config "elevenlabs.api_key")
|
||||
local api_url=$(get_config "elevenlabs.api_url")
|
||||
local model=$(get_config "elevenlabs.model")
|
||||
|
||||
[ -z "$api_key" ] && { echo -e "${RED}Error: ElevenLabs API key not configured.${NC}"; echo "Run: ./acestep-lyrics-transcription.sh config --set elevenlabs.api_key YOUR_KEY"; exit 1; }
|
||||
[ -z "$api_url" ] && api_url="https://api.elevenlabs.io/v1"
|
||||
[ -z "$model" ] && model="scribe_v2"
|
||||
|
||||
echo -e " Provider: ElevenLabs (${model})"
|
||||
|
||||
local resp_file=$(mktemp)
|
||||
|
||||
local curl_args=(
|
||||
-s -w "%{http_code}"
|
||||
-o "$resp_file"
|
||||
-X POST "${api_url}/speech-to-text"
|
||||
-H "xi-api-key: ${api_key}"
|
||||
-F "file=@${audio_file}"
|
||||
-F "model_id=${model}"
|
||||
)
|
||||
|
||||
[ -n "$language" ] && curl_args+=(-F "language_code=${language}")
|
||||
|
||||
local http_code
|
||||
http_code=$(curl "${curl_args[@]}")
|
||||
|
||||
if [ "$http_code" != "200" ]; then
|
||||
local err
|
||||
err=$(jq -r '.detail.message // .detail // "Unknown error"' "$resp_file" 2>/dev/null)
|
||||
echo -e "${RED}Error: HTTP $http_code - $err${NC}"
|
||||
rm -f "$resp_file"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# ElevenLabs response: { text, words: [{text, start, end, type}...] }
|
||||
# Normalize to [{word, start, end}], timestamps already in seconds, filter only "word" type
|
||||
jq '[.words[] | select(.type == "word") | {word: .text, start: .start, end: .end}]' "$resp_file" > "$words_file" 2>/dev/null
|
||||
|
||||
local text
|
||||
text=$(jq -r '.text // empty' "$resp_file" 2>/dev/null)
|
||||
echo -e " ${GREEN}Transcription complete${NC}"
|
||||
echo ""
|
||||
echo "$text"
|
||||
|
||||
rm -f "$resp_file"
|
||||
}
|
||||
|
||||
# ─── Commands ───
|
||||
|
||||
cmd_transcribe() {
|
||||
check_deps
|
||||
ensure_config
|
||||
|
||||
local audio="" language="" output="" format="" provider=""
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--audio|-a) audio="$2"; shift 2 ;;
|
||||
--language|-l) language="$2"; shift 2 ;;
|
||||
--output|-o) output="$2"; shift 2 ;;
|
||||
--format|-f) format="$2"; shift 2 ;;
|
||||
--provider|-p) provider="$2"; shift 2 ;;
|
||||
*) [ -z "$audio" ] && audio="$1"; shift ;;
|
||||
esac
|
||||
done
|
||||
|
||||
[ -z "$audio" ] && { echo -e "${RED}Error: --audio is required${NC}"; echo "Usage: $0 transcribe --audio <file> [options]"; exit 1; }
|
||||
[ ! -f "$audio" ] && { echo -e "${RED}Error: audio file not found: $audio${NC}"; exit 1; }
|
||||
|
||||
# Resolve absolute path
|
||||
audio="$(cd "$(dirname "$audio")" && pwd)/$(basename "$audio")"
|
||||
|
||||
[ -z "$provider" ] && provider=$(get_config "provider")
|
||||
[ -z "$provider" ] && provider="openai"
|
||||
|
||||
[ -z "$format" ] && format=$(get_config "output_format")
|
||||
[ -z "$format" ] && format="lrc"
|
||||
|
||||
# Default output path
|
||||
if [ -z "$output" ]; then
|
||||
ensure_output_dir
|
||||
local basename="$(basename "${audio%.*}")"
|
||||
output="${OUTPUT_DIR}/${basename}.${format}"
|
||||
fi
|
||||
|
||||
echo "Transcribing..."
|
||||
echo " Audio: $(basename "$audio")"
|
||||
echo " Format: $format"
|
||||
|
||||
# Transcribe to normalized word timestamps
|
||||
local words_file=$(mktemp)
|
||||
|
||||
case "$provider" in
|
||||
openai) transcribe_openai "$audio" "$language" "$words_file" ;;
|
||||
elevenlabs) transcribe_elevenlabs "$audio" "$language" "$words_file" ;;
|
||||
*) echo -e "${RED}Error: unknown provider: $provider${NC}"; echo "Supported: openai, elevenlabs"; rm -f "$words_file"; exit 1 ;;
|
||||
esac
|
||||
|
||||
# Check if we got words
|
||||
local word_count
|
||||
word_count=$(jq 'length' "$words_file" 2>/dev/null)
|
||||
if [ -z "$word_count" ] || [ "$word_count" = "0" ]; then
|
||||
echo -e "${YELLOW}Warning: no word-level timestamps returned${NC}"
|
||||
rm -f "$words_file"
|
||||
return 1
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo " Words detected: $word_count"
|
||||
|
||||
# Convert to output format
|
||||
mkdir -p "$(dirname "$output")"
|
||||
|
||||
case "$format" in
|
||||
lrc)
|
||||
words_to_lrc "$words_file" "$output"
|
||||
;;
|
||||
srt)
|
||||
words_to_srt "$words_file" "$output"
|
||||
;;
|
||||
json)
|
||||
cp "$words_file" "$output"
|
||||
;;
|
||||
*)
|
||||
echo -e "${RED}Error: unknown format: $format (supported: lrc, srt, json)${NC}"
|
||||
rm -f "$words_file"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
rm -f "$words_file"
|
||||
|
||||
echo -e " ${GREEN}Saved: $output${NC}"
|
||||
echo ""
|
||||
echo -e "${GREEN}Done!${NC}"
|
||||
}
|
||||
|
||||
cmd_config() {
|
||||
check_deps
|
||||
ensure_config
|
||||
|
||||
local action="" key="" value=""
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--get) action="get"; key="$2"; shift 2 ;;
|
||||
--set) action="set"; key="$2"; value="$3"; shift 3 ;;
|
||||
--reset) action="reset"; shift ;;
|
||||
--list) action="list"; shift ;;
|
||||
--check-key) action="check-key"; shift ;;
|
||||
*) shift ;;
|
||||
esac
|
||||
done
|
||||
|
||||
case "$action" in
|
||||
"check-key")
|
||||
local provider=$(get_config "provider")
|
||||
[ -z "$provider" ] && provider="openai"
|
||||
local api_key=$(get_config "${provider}.api_key")
|
||||
echo "provider: $provider"
|
||||
if [ -n "$api_key" ]; then
|
||||
echo "api_key: configured"
|
||||
else
|
||||
echo "api_key: empty"
|
||||
fi
|
||||
;;
|
||||
"get")
|
||||
[ -z "$key" ] && { echo -e "${RED}Error: --get requires KEY${NC}"; exit 1; }
|
||||
local result=$(get_config "$key")
|
||||
[ -n "$result" ] && echo "$key = $result" || echo "Key not found: $key"
|
||||
;;
|
||||
"set")
|
||||
[ -z "$key" ] || [ -z "$value" ] && { echo -e "${RED}Error: --set requires KEY VALUE${NC}"; exit 1; }
|
||||
set_config "$key" "$value"
|
||||
;;
|
||||
"reset")
|
||||
echo "$DEFAULT_CONFIG" > "$CONFIG_FILE"
|
||||
echo -e "${GREEN}Configuration reset to defaults.${NC}"
|
||||
jq 'walk(if type == "object" and has("api_key") and (.api_key | length) > 0 then .api_key = "***" else . end)' "$CONFIG_FILE"
|
||||
;;
|
||||
"list")
|
||||
echo "Current configuration:"
|
||||
jq 'walk(if type == "object" and has("api_key") and (.api_key | length) > 0 then .api_key = "***" else . end)' "$CONFIG_FILE"
|
||||
;;
|
||||
*)
|
||||
echo "Config file: $CONFIG_FILE"
|
||||
echo "----------------------------------------"
|
||||
jq 'walk(if type == "object" and has("api_key") and (.api_key | length) > 0 then .api_key = "***" else . end)' "$CONFIG_FILE"
|
||||
echo ""
|
||||
echo "----------------------------------------"
|
||||
echo ""
|
||||
echo "Usage:"
|
||||
echo " config --list Show config"
|
||||
echo " config --get <key> Get value"
|
||||
echo " config --set <key> <val> Set value"
|
||||
echo " config --reset Reset to defaults"
|
||||
echo ""
|
||||
echo "Examples:"
|
||||
echo " config --set provider elevenlabs"
|
||||
echo " config --set openai.api_key sk-..."
|
||||
echo " config --set elevenlabs.api_key ..."
|
||||
echo " config --set output_format srt"
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
show_help() {
|
||||
echo "Lyrics Transcription CLI"
|
||||
echo ""
|
||||
echo "Requirements: curl, jq, python3"
|
||||
echo ""
|
||||
echo "Usage: $0 <command> [options]"
|
||||
echo ""
|
||||
echo "Commands:"
|
||||
echo " transcribe Transcribe audio to timestamped lyrics"
|
||||
echo " config Manage configuration"
|
||||
echo ""
|
||||
echo "Transcribe Options:"
|
||||
echo " -a, --audio Audio file path (required)"
|
||||
echo " -l, --language Language code (e.g. zh, en, ja)"
|
||||
echo " -f, --format Output format: lrc, srt, json (default: lrc)"
|
||||
echo " -p, --provider API provider: openai, elevenlabs"
|
||||
echo " -o, --output Output file path"
|
||||
echo ""
|
||||
echo "Examples:"
|
||||
echo " $0 transcribe --audio song.mp3"
|
||||
echo " $0 transcribe --audio song.mp3 --language zh --format lrc"
|
||||
echo " $0 config --set provider openai"
|
||||
}
|
||||
|
||||
# ─── Main ───
|
||||
|
||||
case "$1" in
|
||||
transcribe) shift; cmd_transcribe "$@" ;;
|
||||
config) shift; cmd_config "$@" ;;
|
||||
help|--help|-h) show_help ;;
|
||||
*) show_help; exit 1 ;;
|
||||
esac
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
{
|
||||
"provider": "elevenlabs",
|
||||
"output_format": "lrc",
|
||||
"openai": {
|
||||
"api_key": "",
|
||||
"api_url": "https://api.openai.com/v1",
|
||||
"model": "whisper-1"
|
||||
},
|
||||
"elevenlabs": {
|
||||
"api_key": "",
|
||||
"api_url": "https://api.elevenlabs.io/v1",
|
||||
"model": "scribe_v2"
|
||||
}
|
||||
}
|
||||
133
.claude/skills/acestep-simplemv/SKILL.md
Normal file
133
.claude/skills/acestep-simplemv/SKILL.md
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
---
|
||||
name: acestep-simplemv
|
||||
description: Render music videos from audio files and lyrics using Remotion. Accepts audio + LRC/JSON lyrics + title to produce MP4 videos with waveform visualization and synced lyrics display. Use when users mention MV generation, music video rendering, creating video from audio/lyrics, or visualizing songs.
|
||||
---
|
||||
|
||||
# MV Render
|
||||
|
||||
Render music videos with waveform visualization and synced lyrics from audio + lyrics input.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Remotion project at `scripts/` directory within this skill
|
||||
- Node.js + npm dependencies installed
|
||||
- ffprobe available (for audio duration detection)
|
||||
|
||||
### First-Time Setup
|
||||
|
||||
Before first use, check and install dependencies:
|
||||
|
||||
```bash
|
||||
# 1. Check Node.js
|
||||
node --version
|
||||
|
||||
# 2. Install npm dependencies
|
||||
cd {project_root}/{.claude or .codex}/skills/acestep-simplemv/scripts && npm install
|
||||
|
||||
# 3. Check ffprobe
|
||||
ffprobe -version
|
||||
```
|
||||
|
||||
If ffprobe is not available, install ffmpeg (which includes ffprobe):
|
||||
- **Windows**: `choco install ffmpeg` or download from https://ffmpeg.org/download.html and add to PATH
|
||||
- **macOS**: `brew install ffmpeg`
|
||||
- **Linux**: `sudo apt-get install ffmpeg` (Debian/Ubuntu) or `sudo dnf install ffmpeg` (Fedora)
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
cd {project_root}/{.claude or .codex}/skills/acestep-simplemv/
|
||||
./scripts/render-mv.sh --audio /path/to/song.mp3 --lyrics /path/to/song.lrc --title "Song Title"
|
||||
```
|
||||
|
||||
Output: MP4 file at `out/<audio_basename>.mp4` (or custom `--output` path).
|
||||
|
||||
## Script Usage
|
||||
|
||||
```bash
|
||||
./scripts/render-mv.sh --audio <file> --lyrics <lrc_file> --title "Title" [options]
|
||||
|
||||
Options:
|
||||
--audio Audio file path (absolute paths supported)
|
||||
--lyrics LRC format lyrics file (timestamped)
|
||||
--lyrics-json JSON lyrics file [{start, end, text}] (alternative to --lyrics)
|
||||
--title Video title (default: "Music Video")
|
||||
--subtitle Subtitle text
|
||||
--credit Bottom credit text
|
||||
--offset Lyric timing offset in seconds (default: -0.5)
|
||||
--output Output file path (default: out/<audio_basename>.mp4)
|
||||
--codec h264|h265|vp8|vp9 (default: h264)
|
||||
--browser Custom browser executable path (Chrome/Edge/Chromium)
|
||||
|
||||
Environment variables:
|
||||
BROWSER_EXECUTABLE Path to browser executable (overrides auto-detection)
|
||||
```
|
||||
|
||||
## Browser Detection
|
||||
|
||||
Remotion requires a Chromium-based browser for rendering. The script auto-detects browsers in this priority order:
|
||||
|
||||
1. `BROWSER_EXECUTABLE` environment variable
|
||||
2. `--browser` CLI argument
|
||||
3. Remotion cache (`chrome-headless-shell`, downloaded by Remotion)
|
||||
4. System Chrome (auto-uses `--chrome-mode=chrome-for-testing`)
|
||||
5. **System Edge** (pre-installed on Windows 10/11, auto-uses `--chrome-mode=chrome-for-testing`)
|
||||
6. System Chromium (auto-uses `--chrome-mode=chrome-for-testing`)
|
||||
|
||||
**Important**: New versions of Chrome/Edge removed the old headless mode. When using regular Chrome/Edge/Chromium, the script automatically sets `--chrome-mode=chrome-for-testing` (which uses `--headless=new`). When using `chrome-headless-shell`, it uses the default `headless-shell` mode (which uses `--headless=old`). This is handled transparently.
|
||||
|
||||
If no browser is found, Remotion will attempt to download `chrome-headless-shell` from Google servers. **This will fail if Google servers are inaccessible from your network.**
|
||||
|
||||
### Workarounds for restricted networks
|
||||
|
||||
Since **Edge is pre-installed on Windows 10/11**, it should be auto-detected without any manual configuration. The script automatically detects Chrome/Edge and uses the correct headless mode. If auto-detection fails:
|
||||
|
||||
```bash
|
||||
# Option 1: Set environment variable
|
||||
export BROWSER_EXECUTABLE="/path/to/msedge.exe"
|
||||
|
||||
# Option 2: Pass as CLI argument
|
||||
./scripts/render-mv.sh --audio song.mp3 --lyrics song.lrc --title "Song" --browser "/path/to/msedge.exe"
|
||||
|
||||
# Option 3: Enable proxy and let Remotion download chrome-headless-shell
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
```bash
|
||||
# Basic render
|
||||
./scripts/render-mv.sh --audio /tmp/abc123_1.mp3 --lyrics /tmp/abc123.lrc --title "夜桜"
|
||||
|
||||
# Custom output path
|
||||
./scripts/render-mv.sh --audio song.mp3 --lyrics song.lrc --title "My Song" --output /tmp/my_mv.mp4
|
||||
|
||||
# With subtitle and credit
|
||||
./scripts/render-mv.sh --audio song.mp3 --lyrics song.lrc --title "Song" --subtitle "Artist Name" --credit "Generated by ACE-Step"
|
||||
```
|
||||
|
||||
## File Naming
|
||||
|
||||
**IMPORTANT**: Use the audio file's job ID as the output filename to avoid overwriting. Do NOT use custom names like `--output my_song.mp4`. Let the default naming handle it (derives from audio filename).
|
||||
|
||||
Default output uses the audio filename as base:
|
||||
- Audio: `acestep_output/{job_id}_1.mp3`
|
||||
- Lyrics: `acestep_output/{job_id}_1.lrc`
|
||||
- Video: Pass `--output acestep_output/{job_id}.mp4` (use the job ID from the audio file)
|
||||
|
||||
Example: if audio is `chatcmpl-abc123_1.mp3`, pass `--output acestep_output/chatcmpl-abc123.mp4`
|
||||
|
||||
## Title Guidelines
|
||||
|
||||
- Keep `--title` short and single-line (max ~50 chars, auto-truncated)
|
||||
- Use `--subtitle` for additional info
|
||||
- Do NOT put newlines in `--title`
|
||||
|
||||
Good: `--title "Open Source" --subtitle "ACE-Step v1.5"`
|
||||
Bad: `--title "Open Source - ACE-Step v1.5\nCelebrating Music AI"`
|
||||
|
||||
## Notes
|
||||
|
||||
- Audio files with absolute paths are auto-copied to `public/` by render.mjs
|
||||
- Duration is auto-detected via ffprobe
|
||||
- Typical render time: ~1-2 minutes for a 90s song
|
||||
- Output resolution: 1920x1080, 30fps
|
||||
2720
.claude/skills/acestep-simplemv/scripts/package-lock.json
generated
Normal file
2720
.claude/skills/acestep-simplemv/scripts/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load diff
27
.claude/skills/acestep-simplemv/scripts/package.json
Normal file
27
.claude/skills/acestep-simplemv/scripts/package.json
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
{
|
||||
"name": "acestep-video",
|
||||
"version": "1.0.0",
|
||||
"description": "",
|
||||
"main": "index.js",
|
||||
"scripts": {
|
||||
"start": "remotion preview",
|
||||
"build": "remotion render MusicVideo out/video.mp4",
|
||||
"render": "node render.mjs",
|
||||
"upgrade": "remotion upgrade"
|
||||
},
|
||||
"keywords": [],
|
||||
"author": "",
|
||||
"license": "ISC",
|
||||
"type": "commonjs",
|
||||
"dependencies": {
|
||||
"@remotion/cli": "^4.0.417",
|
||||
"@remotion/media-utils": "^4.0.417",
|
||||
"react": "^18.3.1",
|
||||
"react-dom": "^18.3.1",
|
||||
"remotion": "^4.0.417"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/react": "^19.2.13",
|
||||
"typescript": "^5.9.3"
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
import {Config} from '@remotion/cli/config';
|
||||
|
||||
Config.setVideoImageFormat('jpeg');
|
||||
Config.setOverwriteOutput(true);
|
||||
123
.claude/skills/acestep-simplemv/scripts/render-mv.sh
Normal file
123
.claude/skills/acestep-simplemv/scripts/render-mv.sh
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
#!/bin/bash
|
||||
# render-mv.sh - Render a music video from audio + lyrics
|
||||
#
|
||||
# Usage:
|
||||
# ./render-mv.sh --audio <file> --lyrics <lrc_file> --title "Title" [options]
|
||||
#
|
||||
# Options:
|
||||
# --audio Audio file path (absolute or relative)
|
||||
# --lyrics LRC format lyrics file
|
||||
# --lyrics-json JSON lyrics file [{start, end, text}]
|
||||
# --title Video title (default: "Music Video")
|
||||
# --subtitle Subtitle text
|
||||
# --credit Bottom credit text
|
||||
# --offset Lyric timing offset in seconds (default: -0.5)
|
||||
# --output Output file path (default: acestep_output/<audio_basename>.mp4)
|
||||
# --codec h264|h265|vp8|vp9 (default: h264)
|
||||
# --browser Custom browser executable path (Chrome/Edge/Chromium)
|
||||
#
|
||||
# Environment variables:
|
||||
# BROWSER_EXECUTABLE Path to browser executable (overrides auto-detection)
|
||||
#
|
||||
# Examples:
|
||||
# ./render-mv.sh --audio song.mp3 --lyrics song.lrc --title "My Song"
|
||||
# ./render-mv.sh --audio /path/to/abc123_1.mp3 --lyrics /path/to/abc123.lrc --title "夜桜"
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
RENDER_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
|
||||
# Ensure output directory exists
|
||||
mkdir -p "${RENDER_DIR}/out"
|
||||
|
||||
# Cross-platform realpath alternative (works on macOS/Linux/Windows MSYS2)
|
||||
resolve_path() {
|
||||
local dir base
|
||||
dir="$(cd "$(dirname "$1")" && pwd)"
|
||||
base="$(basename "$1")"
|
||||
echo "${dir}/${base}"
|
||||
}
|
||||
|
||||
AUDIO=""
|
||||
LYRICS=""
|
||||
LYRICS_JSON=""
|
||||
TITLE="Music Video"
|
||||
SUBTITLE=""
|
||||
CREDIT=""
|
||||
OFFSET="-0.5"
|
||||
OUTPUT=""
|
||||
CODEC="h264"
|
||||
BROWSER=""
|
||||
|
||||
# Parse args
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--audio) AUDIO="$2"; shift 2 ;;
|
||||
--lyrics) LYRICS="$2"; shift 2 ;;
|
||||
--lyrics-json) LYRICS_JSON="$2"; shift 2 ;;
|
||||
--title) TITLE="$2"; shift 2 ;;
|
||||
--subtitle) SUBTITLE="$2"; shift 2 ;;
|
||||
--credit) CREDIT="$2"; shift 2 ;;
|
||||
--offset) OFFSET="$2"; shift 2 ;;
|
||||
--output) OUTPUT="$2"; shift 2 ;;
|
||||
--codec) CODEC="$2"; shift 2 ;;
|
||||
--browser) BROWSER="$2"; shift 2 ;;
|
||||
-h|--help)
|
||||
head -20 "$0" | tail -18
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
echo "Error: unknown argument: $1" >&2
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [[ -z "$AUDIO" ]]; then
|
||||
echo "Error: --audio is required" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ ! -f "$AUDIO" ]]; then
|
||||
echo "Error: audio file not found: $AUDIO" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Resolve absolute path for audio
|
||||
AUDIO="$(resolve_path "$AUDIO")"
|
||||
|
||||
# Default output: acestep_output/<audio_basename>.mp4
|
||||
if [[ -z "$OUTPUT" ]]; then
|
||||
BASENAME="$(basename "${AUDIO%.*}")"
|
||||
# Strip trailing _1, _2 etc from audio filename for cleaner video name
|
||||
OUTPUT="${RENDER_DIR}/out/${BASENAME}.mp4"
|
||||
fi
|
||||
|
||||
# Ensure output directory exists
|
||||
mkdir -p "$(dirname "$OUTPUT")"
|
||||
|
||||
# Build node args array (safe quoting, no eval)
|
||||
NODE_ARGS=(render.mjs --audio "$AUDIO" --title "$TITLE" --offset "$OFFSET" --output "$OUTPUT" --codec "$CODEC")
|
||||
|
||||
if [[ -n "$LYRICS" ]]; then
|
||||
LYRICS="$(resolve_path "$LYRICS")"
|
||||
NODE_ARGS+=(--lyrics "$LYRICS")
|
||||
elif [[ -n "$LYRICS_JSON" ]]; then
|
||||
LYRICS_JSON="$(resolve_path "$LYRICS_JSON")"
|
||||
NODE_ARGS+=(--lyrics-json "$LYRICS_JSON")
|
||||
fi
|
||||
|
||||
[[ -n "$SUBTITLE" ]] && NODE_ARGS+=(--subtitle "$SUBTITLE")
|
||||
[[ -n "$CREDIT" ]] && NODE_ARGS+=(--credit "$CREDIT")
|
||||
[[ -n "$BROWSER" ]] && NODE_ARGS+=(--browser "$BROWSER")
|
||||
|
||||
echo "Rendering MV..."
|
||||
echo " Audio: $(basename "$AUDIO")"
|
||||
echo " Title: $TITLE"
|
||||
echo " Output: $OUTPUT"
|
||||
|
||||
cd "$RENDER_DIR"
|
||||
node "${NODE_ARGS[@]}"
|
||||
|
||||
echo ""
|
||||
echo "Output: $OUTPUT"
|
||||
345
.claude/skills/acestep-simplemv/scripts/render.mjs
Normal file
345
.claude/skills/acestep-simplemv/scripts/render.mjs
Normal file
|
|
@ -0,0 +1,345 @@
|
|||
#!/usr/bin/env node
|
||||
|
||||
/**
|
||||
* CLI entry point for rendering music videos.
|
||||
*
|
||||
* Usage:
|
||||
* node render.mjs --audio music.mp3 --lyrics lyrics.lrc --title "Song Name" --output out/video.mp4
|
||||
* node render.mjs --audio music.mp3 --lyrics-json lyrics.json --title "Song Name"
|
||||
*
|
||||
* Options:
|
||||
* --audio Audio file path (absolute paths auto-copied to public/) or filename in public/
|
||||
* --lyrics Path to LRC format lyrics file
|
||||
* --lyrics-json Path to JSON lyrics file [{start, end, text}]
|
||||
* --title Main title (default: "Music Video")
|
||||
* --subtitle Subtitle (default: "")
|
||||
* --credit Bottom credit text (default: "")
|
||||
* --duration Audio duration in seconds (auto-detected if omitted)
|
||||
* --offset Lyric timing offset in seconds (default: -0.5)
|
||||
* --output Output file path (default: out/video.mp4)
|
||||
* --codec Video codec: h264, h265, vp8, vp9 (default: h264)
|
||||
*/
|
||||
|
||||
import {execSync} from 'child_process';
|
||||
import {readFileSync, readdirSync, existsSync, copyFileSync, mkdirSync} from 'fs';
|
||||
import {resolve, basename, isAbsolute, join} from 'path';
|
||||
import {homedir} from 'os';
|
||||
|
||||
/**
|
||||
* Resolve a file path that may be a MSYS2/Cygwin-style path on Windows.
|
||||
* Converts paths like /e/foo/bar to E:/foo/bar for Node.js compatibility.
|
||||
*/
|
||||
function resolveFilePath(p) {
|
||||
if (process.platform === 'win32' && /^\/[a-zA-Z]\//.test(p)) {
|
||||
// Convert MSYS2 path /x/... to X:/...
|
||||
return p[1].toUpperCase() + ':' + p.slice(2);
|
||||
}
|
||||
return resolve(p);
|
||||
}
|
||||
|
||||
/**
|
||||
* Find a usable browser executable for Remotion rendering.
|
||||
*
|
||||
* Search priority:
|
||||
* 1. Environment variable BROWSER_EXECUTABLE
|
||||
* 2. CLI argument --browser
|
||||
* 3. Remotion cache (chrome-headless-shell)
|
||||
* 4. System Chrome (requires --chrome-mode=chrome-for-testing)
|
||||
* 5. System Edge (requires --chrome-mode=chrome-for-testing)
|
||||
* 6. System Chromium (requires --chrome-mode=chrome-for-testing)
|
||||
*
|
||||
* Returns {path, chromeMode} or {path: null, chromeMode: 'headless-shell'} if not found.
|
||||
*
|
||||
* chromeMode:
|
||||
* - 'headless-shell': for chrome-headless-shell binary (uses --headless=old)
|
||||
* - 'chrome-for-testing': for regular Chrome/Edge/Chromium (uses --headless=new)
|
||||
*/
|
||||
function findBrowserExecutable(cliOverride) {
|
||||
// 1. Environment variable — highest priority
|
||||
const envExe = process.env.BROWSER_EXECUTABLE;
|
||||
if (envExe && existsSync(envExe)) {
|
||||
const mode = isHeadlessShell(envExe) ? 'headless-shell' : 'chrome-for-testing';
|
||||
return {path: envExe, chromeMode: mode};
|
||||
}
|
||||
|
||||
// 2. CLI argument
|
||||
if (cliOverride && existsSync(cliOverride)) {
|
||||
const mode = isHeadlessShell(cliOverride) ? 'headless-shell' : 'chrome-for-testing';
|
||||
return {path: cliOverride, chromeMode: mode};
|
||||
}
|
||||
|
||||
const platform = process.platform;
|
||||
const home = homedir();
|
||||
|
||||
// 3. Local node_modules/.remotion (chrome-headless-shell) — uses --headless=old
|
||||
const localCacheDir = join(process.cwd(), 'node_modules', '.remotion', 'chrome-headless-shell');
|
||||
if (existsSync(localCacheDir)) {
|
||||
try {
|
||||
// Structure: chrome-headless-shell/linux64/chrome-headless-shell-linux64/chrome-headless-shell
|
||||
const platformDir = platform === 'win32' ? 'win64' : platform === 'darwin' ? 'mac-arm64' : 'linux64';
|
||||
const exeName = platform === 'win32' ? 'chrome-headless-shell.exe' : 'chrome-headless-shell';
|
||||
const platformPath = join(localCacheDir, platformDir);
|
||||
|
||||
if (existsSync(platformPath)) {
|
||||
const subdirs = readdirSync(platformPath);
|
||||
for (const subdir of subdirs) {
|
||||
const exe = join(platformPath, subdir, exeName);
|
||||
if (existsSync(exe)) return {path: exe, chromeMode: 'headless-shell'};
|
||||
}
|
||||
}
|
||||
} catch {}
|
||||
}
|
||||
|
||||
// 4. User home Remotion cache (chrome-headless-shell) — uses --headless=old
|
||||
let cacheDir;
|
||||
if (platform === 'win32') {
|
||||
cacheDir = join(home, 'AppData', 'Local', 'remotion', 'chrome-headless-shell');
|
||||
} else if (platform === 'darwin') {
|
||||
cacheDir = join(home, 'Library', 'Caches', 'remotion', 'chrome-headless-shell');
|
||||
} else {
|
||||
cacheDir = join(home, '.cache', 'remotion', 'chrome-headless-shell');
|
||||
}
|
||||
|
||||
if (existsSync(cacheDir)) {
|
||||
try {
|
||||
const versions = readdirSync(cacheDir).sort().reverse();
|
||||
const exeName = platform === 'win32' ? 'chrome-headless-shell.exe' : 'chrome-headless-shell';
|
||||
for (const ver of versions) {
|
||||
const exe = join(cacheDir, ver, exeName);
|
||||
if (existsSync(exe)) return {path: exe, chromeMode: 'headless-shell'};
|
||||
}
|
||||
} catch {}
|
||||
}
|
||||
|
||||
// 4-6. System browsers: Chrome, Edge, Chromium — require --chrome-mode=chrome-for-testing
|
||||
const browserPaths = platform === 'win32' ? [
|
||||
// Chrome
|
||||
'C:\\Program Files\\Google\\Chrome\\Application\\chrome.exe',
|
||||
'C:\\Program Files (x86)\\Google\\Chrome\\Application\\chrome.exe',
|
||||
// Edge (pre-installed on Windows 10/11)
|
||||
'C:\\Program Files (x86)\\Microsoft\\Edge\\Application\\msedge.exe',
|
||||
'C:\\Program Files\\Microsoft\\Edge\\Application\\msedge.exe',
|
||||
] : platform === 'darwin' ? [
|
||||
'/Applications/Google Chrome.app/Contents/MacOS/Google Chrome',
|
||||
'/Applications/Microsoft Edge.app/Contents/MacOS/Microsoft Edge',
|
||||
'/Applications/Chromium.app/Contents/MacOS/Chromium',
|
||||
] : [
|
||||
'/usr/bin/google-chrome',
|
||||
'/usr/bin/google-chrome-stable',
|
||||
'/usr/bin/chromium',
|
||||
'/usr/bin/chromium-browser',
|
||||
'/usr/bin/microsoft-edge',
|
||||
'/usr/bin/microsoft-edge-stable',
|
||||
];
|
||||
|
||||
for (const p of browserPaths) {
|
||||
if (existsSync(p)) return {path: p, chromeMode: 'chrome-for-testing'};
|
||||
}
|
||||
|
||||
return {path: null, chromeMode: 'headless-shell'};
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the given executable path is a chrome-headless-shell binary.
|
||||
*/
|
||||
function isHeadlessShell(exePath) {
|
||||
const name = exePath.toLowerCase().replace(/\\/g, '/');
|
||||
return name.includes('chrome-headless-shell');
|
||||
}
|
||||
|
||||
function parseLrc(content) {
|
||||
const lines = content.split(/\r?\n/).filter(l => l.trim());
|
||||
const parsed = [];
|
||||
for (const line of lines) {
|
||||
const match = line.match(/^\[(\d{2}):(\d{2})(?:\.(\d{2,3}))?\]\s*(.*)$/);
|
||||
if (match) {
|
||||
const minutes = parseInt(match[1], 10);
|
||||
const seconds = parseInt(match[2], 10);
|
||||
const cs = match[3] ? parseInt(match[3].padEnd(3, '0'), 10) / 1000 : 0;
|
||||
const time = minutes * 60 + seconds + cs;
|
||||
const text = match[4].trim();
|
||||
parsed.push({time, text});
|
||||
}
|
||||
}
|
||||
const result = [];
|
||||
for (let i = 0; i < parsed.length; i++) {
|
||||
const start = parsed[i].time;
|
||||
const end = i < parsed.length - 1 ? parsed[i + 1].time : start + 5;
|
||||
if (parsed[i].text) {
|
||||
result.push({start, end, text: parsed[i].text});
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
function getAudioDuration(filePath) {
|
||||
try {
|
||||
const result = execSync(
|
||||
`ffprobe -v error -show_entries format=duration -of csv=p=0 "${filePath}"`,
|
||||
{encoding: 'utf-8'}
|
||||
).trim();
|
||||
return parseFloat(result);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function parseArgs(argv) {
|
||||
const args = {};
|
||||
for (let i = 2; i < argv.length; i++) {
|
||||
const key = argv[i];
|
||||
if (key.startsWith('--') && i + 1 < argv.length) {
|
||||
const name = key.slice(2);
|
||||
args[name] = argv[i + 1];
|
||||
i++;
|
||||
}
|
||||
}
|
||||
return args;
|
||||
}
|
||||
|
||||
const args = parseArgs(process.argv);
|
||||
|
||||
// Validate required args
|
||||
if (!args.audio) {
|
||||
console.error('Error: --audio is required');
|
||||
console.error('Usage: node render.mjs --audio music.mp3 --lyrics lyrics.lrc --title "Song"');
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
// If audio is an absolute path, copy it into public/ and use the filename
|
||||
let audioFileName = args.audio;
|
||||
const resolvedAudio = resolveFilePath(args.audio);
|
||||
if (isAbsolute(resolvedAudio)) {
|
||||
if (!existsSync(resolvedAudio)) {
|
||||
console.error(`Error: Audio file not found: ${resolvedAudio}`);
|
||||
process.exit(1);
|
||||
}
|
||||
const pubDir = resolve('public');
|
||||
mkdirSync(pubDir, {recursive: true});
|
||||
const fname = basename(resolvedAudio);
|
||||
const dest = resolve(pubDir, fname);
|
||||
if (resolve(resolvedAudio) !== dest) {
|
||||
copyFileSync(resolvedAudio, dest);
|
||||
console.log(`Copied audio to public/${fname}`);
|
||||
}
|
||||
audioFileName = fname;
|
||||
} else {
|
||||
// Relative name — must exist in public/
|
||||
const audioPath = resolve('public', args.audio);
|
||||
if (!existsSync(audioPath)) {
|
||||
console.error(`Error: Audio file not found in public/: ${args.audio}`);
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
// Parse lyrics
|
||||
let lyrics = [];
|
||||
if (args.lyrics) {
|
||||
const lrcPath = resolveFilePath(args.lyrics);
|
||||
if (!existsSync(lrcPath)) {
|
||||
console.error(`Error: LRC file not found: ${lrcPath}`);
|
||||
process.exit(1);
|
||||
}
|
||||
lyrics = parseLrc(readFileSync(lrcPath, 'utf-8'));
|
||||
console.log(`Parsed ${lyrics.length} lyric lines from LRC file`);
|
||||
} else if (args['lyrics-json']) {
|
||||
const jsonPath = resolveFilePath(args['lyrics-json']);
|
||||
if (!existsSync(jsonPath)) {
|
||||
console.error(`Error: JSON lyrics file not found: ${jsonPath}`);
|
||||
process.exit(1);
|
||||
}
|
||||
lyrics = JSON.parse(readFileSync(jsonPath, 'utf-8'));
|
||||
console.log(`Loaded ${lyrics.length} lyric lines from JSON file`);
|
||||
}
|
||||
|
||||
// Determine audio duration
|
||||
let duration = args.duration ? parseFloat(args.duration) : null;
|
||||
if (!duration) {
|
||||
const audioPath = resolve('public', audioFileName);
|
||||
if (existsSync(audioPath)) {
|
||||
duration = getAudioDuration(audioPath);
|
||||
if (duration) {
|
||||
console.log(`Auto-detected audio duration: ${duration.toFixed(2)}s`);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!duration) {
|
||||
console.error('Error: Could not detect audio duration. Please provide --duration');
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
// Build input props
|
||||
// Sanitize title: single-line, max 50 chars
|
||||
const rawTitle = (args.title || 'Music Video').replace(/[\r\n]+/g, ' ').trim();
|
||||
const title = rawTitle.length > 50 ? rawTitle.slice(0, 47) + '...' : rawTitle;
|
||||
|
||||
const inputProps = {
|
||||
audioFileName: audioFileName,
|
||||
lyrics,
|
||||
title,
|
||||
subtitle: (args.subtitle || '').replace(/[\r\n]+/g, ' ').trim(),
|
||||
creditText: args.credit || '',
|
||||
durationInSeconds: duration,
|
||||
lyricOffset: args.offset ? parseFloat(args.offset) : -0.5,
|
||||
};
|
||||
|
||||
const output = args.output ? resolveFilePath(args.output) : 'out/video.mp4';
|
||||
const codec = args.codec || 'h264';
|
||||
|
||||
// Write props to temp file to avoid shell escaping issues
|
||||
const propsFile = resolve('.render-props.json');
|
||||
const {writeFileSync} = await import('fs');
|
||||
writeFileSync(propsFile, JSON.stringify(inputProps));
|
||||
|
||||
// Find browser executable to avoid re-downloading
|
||||
const {path: browserExe, chromeMode} = findBrowserExecutable(args.browser);
|
||||
|
||||
if (!browserExe) {
|
||||
console.warn('⚠️ No browser found. Remotion will attempt to download chrome-headless-shell from Google servers.');
|
||||
console.warn(' If download fails (e.g. Google servers inaccessible), try one of these:');
|
||||
console.warn(' 1. Set environment variable: BROWSER_EXECUTABLE=/path/to/chrome-or-edge');
|
||||
console.warn(' 2. Pass CLI argument: --browser /path/to/chrome-or-edge');
|
||||
console.warn(' 3. Enable proxy and retry');
|
||||
console.warn('');
|
||||
}
|
||||
|
||||
const cmd = [
|
||||
'npx remotion render',
|
||||
'MusicVideo',
|
||||
`"${output}"`,
|
||||
`--props="${propsFile}"`,
|
||||
`--codec=${codec}`,
|
||||
'--log=error',
|
||||
browserExe ? `--browser-executable="${browserExe}"` : '',
|
||||
chromeMode !== 'headless-shell' ? `--chrome-mode=${chromeMode}` : '',
|
||||
].filter(Boolean).join(' ');
|
||||
|
||||
console.log(`\nRendering video...`);
|
||||
console.log(` Audio: ${args.audio}`);
|
||||
console.log(` Title: ${inputProps.title}`);
|
||||
console.log(` Duration: ${duration.toFixed(1)}s`);
|
||||
console.log(` Lyrics: ${lyrics.length} lines`);
|
||||
console.log(` Output: ${output}`);
|
||||
console.log(` Codec: ${codec}`);
|
||||
if (browserExe) console.log(` Browser: ${browserExe}`);
|
||||
if (chromeMode !== 'headless-shell') console.log(` Chrome mode: ${chromeMode}`);
|
||||
console.log('');
|
||||
|
||||
try {
|
||||
const result = execSync(cmd, {encoding: 'utf-8', stdio: ['pipe', 'pipe', 'pipe']});
|
||||
// Only show the final output file line (starts with '+') and size info
|
||||
const outputLines = result.split(/\r?\n/).filter(l => l.includes(output) || /^\+/.test(l.replace(/\x1b\[[0-9;]*m/g, '').trim()));
|
||||
if (outputLines.length) console.log(outputLines.join('\n'));
|
||||
console.log(`\n✅ Video rendered successfully: ${output}`);
|
||||
} catch (e) {
|
||||
// Show stderr on failure for debugging
|
||||
if (e.stderr) console.error(e.stderr.toString());
|
||||
console.error('\n❌ Render failed');
|
||||
process.exit(1);
|
||||
} finally {
|
||||
// Clean up temp props file
|
||||
try {
|
||||
const {unlinkSync} = await import('fs');
|
||||
unlinkSync(propsFile);
|
||||
} catch {}
|
||||
}
|
||||
12
.claude/skills/acestep-simplemv/scripts/render.sh
Normal file
12
.claude/skills/acestep-simplemv/scripts/render.sh
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
#!/bin/bash
|
||||
# render.sh - Convenience wrapper for rendering music videos
|
||||
#
|
||||
# Usage:
|
||||
# ./render.sh --audio music.mp3 --lyrics lyrics.lrc --title "Song Name"
|
||||
# ./render.sh --audio music.mp3 --lyrics-json lyrics.json --title "Song" --output out/mv.mp4
|
||||
#
|
||||
# All options are passed through to render.mjs. See render.mjs for full options list.
|
||||
|
||||
set -e
|
||||
cd "$(dirname "$0")"
|
||||
node render.mjs "$@"
|
||||
|
|
@ -0,0 +1,314 @@
|
|||
import React from 'react';
|
||||
import {
|
||||
AbsoluteFill,
|
||||
Audio,
|
||||
useCurrentFrame,
|
||||
useVideoConfig,
|
||||
interpolate,
|
||||
Easing,
|
||||
staticFile,
|
||||
} from 'remotion';
|
||||
import {useAudioData, visualizeAudio} from '@remotion/media-utils';
|
||||
import {MVInputProps} from './types';
|
||||
|
||||
export const AudioVisualization: React.FC<MVInputProps> = ({
|
||||
audioFileName,
|
||||
lyrics,
|
||||
title,
|
||||
subtitle,
|
||||
creditText,
|
||||
lyricOffset,
|
||||
}) => {
|
||||
const frame = useCurrentFrame();
|
||||
const {fps, durationInFrames} = useVideoConfig();
|
||||
|
||||
const audioSrc = audioFileName.startsWith('http')
|
||||
? audioFileName
|
||||
: staticFile(audioFileName);
|
||||
|
||||
const audioData = useAudioData(audioSrc);
|
||||
|
||||
if (!audioData) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const visualization = visualizeAudio({
|
||||
fps,
|
||||
frame,
|
||||
audioData,
|
||||
numberOfSamples: 128,
|
||||
optimizeFor: 'speed',
|
||||
});
|
||||
|
||||
const currentTime = frame / fps + lyricOffset;
|
||||
|
||||
const currentLyric = lyrics.find(
|
||||
(lyric) => currentTime >= lyric.start && currentTime < lyric.end
|
||||
);
|
||||
|
||||
const lyricProgress = currentLyric
|
||||
? interpolate(
|
||||
currentTime,
|
||||
[currentLyric.start, currentLyric.start + 0.3],
|
||||
[0, 1],
|
||||
{extrapolateRight: 'clamp'}
|
||||
)
|
||||
: 0;
|
||||
|
||||
const titleOpacity = interpolate(frame, [0, 30], [0, 1], {
|
||||
extrapolateRight: 'clamp',
|
||||
});
|
||||
|
||||
const titleY = interpolate(frame, [0, 30], [-50, 0], {
|
||||
extrapolateRight: 'clamp',
|
||||
easing: Easing.out(Easing.ease),
|
||||
});
|
||||
|
||||
const hue = interpolate(frame, [0, durationInFrames], [200, 320], {
|
||||
extrapolateRight: 'wrap',
|
||||
});
|
||||
|
||||
const avgAmplitude =
|
||||
visualization.reduce((sum, val) => sum + val, 0) / visualization.length;
|
||||
|
||||
return (
|
||||
<AbsoluteFill>
|
||||
{/* Animated gradient background */}
|
||||
<AbsoluteFill
|
||||
style={{
|
||||
background: `linear-gradient(135deg, hsl(${hue}, 80%, 12%) 0%, hsl(${hue + 80}, 70%, 8%) 100%)`,
|
||||
}}
|
||||
/>
|
||||
|
||||
{/* Radial glow effect */}
|
||||
<AbsoluteFill
|
||||
style={{
|
||||
background: `radial-gradient(circle at 50% 50%, hsla(${hue}, 100%, 50%, ${avgAmplitude * 0.3}) 0%, transparent 50%)`,
|
||||
}}
|
||||
/>
|
||||
|
||||
{/* Audio source */}
|
||||
<Audio src={audioSrc} />
|
||||
|
||||
{/* Bottom frequency bars */}
|
||||
<AbsoluteFill
|
||||
style={{
|
||||
justifyContent: 'flex-end',
|
||||
alignItems: 'center',
|
||||
}}
|
||||
>
|
||||
<div
|
||||
style={{
|
||||
display: 'flex',
|
||||
alignItems: 'flex-end',
|
||||
justifyContent: 'center',
|
||||
gap: 4,
|
||||
height: 350,
|
||||
width: '90%',
|
||||
marginBottom: 180,
|
||||
}}
|
||||
>
|
||||
{visualization.map((value, index) => {
|
||||
const scaledValue = Math.pow(value, 0.6);
|
||||
const barHeight = Math.max(scaledValue * 800, 20);
|
||||
const colorIndex = (index / visualization.length) * 360;
|
||||
|
||||
return (
|
||||
<div
|
||||
key={index}
|
||||
style={{
|
||||
width: `${100 / visualization.length}%`,
|
||||
height: barHeight,
|
||||
background: `linear-gradient(to top,
|
||||
hsl(${(colorIndex + hue) % 360}, 90%, 60%),
|
||||
hsl(${(colorIndex + hue + 40) % 360}, 90%, 70%))`,
|
||||
borderRadius: '4px 4px 0 0',
|
||||
boxShadow: `0 0 ${10 + scaledValue * 30}px hsla(${(colorIndex + hue) % 360}, 100%, 60%, ${scaledValue})`,
|
||||
transition: 'height 0.05s ease-out',
|
||||
}}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</AbsoluteFill>
|
||||
|
||||
{/* Symmetrical side bars */}
|
||||
<AbsoluteFill
|
||||
style={{
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
}}
|
||||
>
|
||||
{/* Left bars */}
|
||||
<div
|
||||
style={{
|
||||
position: 'absolute',
|
||||
left: 40,
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
gap: 8,
|
||||
height: '80%',
|
||||
justifyContent: 'space-around',
|
||||
}}
|
||||
>
|
||||
{visualization.slice(0, 20).map((value, index) => {
|
||||
const scaledValue = Math.pow(value, 0.6);
|
||||
const barWidth = Math.max(scaledValue * 300, 10);
|
||||
const colorIndex = (index / 20) * 360;
|
||||
return (
|
||||
<div
|
||||
key={index}
|
||||
style={{
|
||||
width: barWidth,
|
||||
height: 12,
|
||||
background: `linear-gradient(to right,
|
||||
hsl(${(colorIndex + hue) % 360}, 90%, 60%),
|
||||
hsl(${(colorIndex + hue + 40) % 360}, 90%, 70%))`,
|
||||
borderRadius: '0 6px 6px 0',
|
||||
boxShadow: `0 0 ${10 + scaledValue * 20}px hsla(${(colorIndex + hue) % 360}, 100%, 60%, ${scaledValue})`,
|
||||
}}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
|
||||
{/* Right bars */}
|
||||
<div
|
||||
style={{
|
||||
position: 'absolute',
|
||||
right: 40,
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
gap: 8,
|
||||
height: '80%',
|
||||
justifyContent: 'space-around',
|
||||
alignItems: 'flex-end',
|
||||
}}
|
||||
>
|
||||
{visualization.slice(0, 20).map((value, index) => {
|
||||
const scaledValue = Math.pow(value, 0.6);
|
||||
const barWidth = Math.max(scaledValue * 300, 10);
|
||||
const colorIndex = (index / 20) * 360;
|
||||
return (
|
||||
<div
|
||||
key={index}
|
||||
style={{
|
||||
width: barWidth,
|
||||
height: 12,
|
||||
background: `linear-gradient(to left,
|
||||
hsl(${(colorIndex + hue + 180) % 360}, 90%, 60%),
|
||||
hsl(${(colorIndex + hue + 220) % 360}, 90%, 70%))`,
|
||||
borderRadius: '6px 0 0 6px',
|
||||
boxShadow: `0 0 ${10 + scaledValue * 20}px hsla(${(colorIndex + hue + 180) % 360}, 100%, 60%, ${scaledValue})`,
|
||||
}}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</AbsoluteFill>
|
||||
|
||||
{/* Center title area */}
|
||||
<AbsoluteFill
|
||||
style={{
|
||||
justifyContent: 'flex-start',
|
||||
alignItems: 'center',
|
||||
paddingTop: 60,
|
||||
}}
|
||||
>
|
||||
<div
|
||||
style={{
|
||||
textAlign: 'center',
|
||||
transform: `scale(${1 + avgAmplitude * 0.1})`,
|
||||
transition: 'transform 0.1s ease-out',
|
||||
}}
|
||||
>
|
||||
<div
|
||||
style={{
|
||||
fontSize: 96,
|
||||
fontWeight: 'bold',
|
||||
color: 'white',
|
||||
opacity: titleOpacity,
|
||||
transform: `translateY(${titleY}px)`,
|
||||
textShadow: `0 0 40px hsla(${hue}, 100%, 70%, 0.8), 0 4px 20px rgba(0,0,0,0.5)`,
|
||||
fontFamily: '"Noto Sans CJK JP", "Noto Sans CJK SC", Arial, sans-serif',
|
||||
marginBottom: 10,
|
||||
}}
|
||||
>
|
||||
{title}
|
||||
</div>
|
||||
<div
|
||||
style={{
|
||||
fontSize: 56,
|
||||
fontWeight: '600',
|
||||
color: 'rgba(255,255,255,0.95)',
|
||||
opacity: titleOpacity,
|
||||
transform: `translateY(${titleY}px)`,
|
||||
textShadow: `0 0 30px hsla(${hue + 60}, 100%, 70%, 0.6), 0 2px 10px rgba(0,0,0,0.5)`,
|
||||
fontFamily: '"Noto Sans CJK JP", "Noto Sans CJK SC", Arial, sans-serif',
|
||||
letterSpacing: '4px',
|
||||
}}
|
||||
>
|
||||
{subtitle}
|
||||
</div>
|
||||
</div>
|
||||
</AbsoluteFill>
|
||||
|
||||
{/* Lyrics display */}
|
||||
{currentLyric && currentLyric.text && (
|
||||
<AbsoluteFill
|
||||
style={{
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
paddingTop: 100,
|
||||
}}
|
||||
>
|
||||
<div
|
||||
style={{
|
||||
fontSize: 48,
|
||||
fontWeight: '600',
|
||||
color: 'white',
|
||||
textAlign: 'center',
|
||||
maxWidth: '85%',
|
||||
opacity: lyricProgress,
|
||||
transform: `translateY(${(1 - lyricProgress) * 30}px)`,
|
||||
textShadow: `0 0 40px hsla(${hue}, 100%, 70%, 0.8), 0 4px 30px rgba(0,0,0,0.9)`,
|
||||
fontFamily: '"Noto Sans CJK JP", "Noto Sans CJK SC", Arial, sans-serif',
|
||||
lineHeight: 1.5,
|
||||
padding: '25px 50px',
|
||||
background: `linear-gradient(135deg, rgba(0,0,0,0.4), rgba(0,0,0,0.2))`,
|
||||
backdropFilter: 'blur(15px)',
|
||||
borderRadius: '20px',
|
||||
border: `2px solid hsla(${hue}, 80%, 60%, 0.3)`,
|
||||
boxShadow: `0 8px 32px rgba(0,0,0,0.5), inset 0 0 40px hsla(${hue}, 100%, 50%, 0.1)`,
|
||||
}}
|
||||
>
|
||||
{currentLyric.text}
|
||||
</div>
|
||||
</AbsoluteFill>
|
||||
)}
|
||||
|
||||
{/* Bottom credit text */}
|
||||
<AbsoluteFill
|
||||
style={{
|
||||
justifyContent: 'flex-end',
|
||||
alignItems: 'center',
|
||||
padding: 50,
|
||||
}}
|
||||
>
|
||||
<div
|
||||
style={{
|
||||
fontSize: 32,
|
||||
fontWeight: '500',
|
||||
color: 'white',
|
||||
opacity: 0.8,
|
||||
textAlign: 'center',
|
||||
textShadow: `0 0 20px hsla(${hue}, 100%, 70%, 0.6), 0 2px 10px rgba(0,0,0,0.7)`,
|
||||
fontFamily: '"Noto Sans CJK JP", "Noto Sans CJK SC", Arial, sans-serif',
|
||||
}}
|
||||
>
|
||||
{creditText}
|
||||
</div>
|
||||
</AbsoluteFill>
|
||||
</AbsoluteFill>
|
||||
);
|
||||
};
|
||||
31
.claude/skills/acestep-simplemv/scripts/src/Root.tsx
Normal file
31
.claude/skills/acestep-simplemv/scripts/src/Root.tsx
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
import React from 'react';
|
||||
import {Composition, CalculateMetadataFunction} from 'remotion';
|
||||
import {AudioVisualization} from './AudioVisualization';
|
||||
import {MVInputProps, defaultProps} from './types';
|
||||
|
||||
const calculateMetadata: CalculateMetadataFunction<MVInputProps> = ({props}) => {
|
||||
const fps = 30;
|
||||
const durationInFrames = Math.ceil(props.durationInSeconds * fps);
|
||||
return {
|
||||
durationInFrames,
|
||||
fps,
|
||||
width: 1920,
|
||||
height: 1080,
|
||||
};
|
||||
};
|
||||
|
||||
export const RemotionRoot: React.FC = () => {
|
||||
return (
|
||||
<>
|
||||
<Composition
|
||||
id="MusicVideo"
|
||||
component={AudioVisualization}
|
||||
fps={30}
|
||||
width={1920}
|
||||
height={1080}
|
||||
defaultProps={defaultProps}
|
||||
calculateMetadata={calculateMetadata}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
};
|
||||
4
.claude/skills/acestep-simplemv/scripts/src/index.ts
Normal file
4
.claude/skills/acestep-simplemv/scripts/src/index.ts
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
import {registerRoot} from 'remotion';
|
||||
import {RemotionRoot} from './Root';
|
||||
|
||||
registerRoot(RemotionRoot);
|
||||
40
.claude/skills/acestep-simplemv/scripts/src/parseLrc.ts
Normal file
40
.claude/skills/acestep-simplemv/scripts/src/parseLrc.ts
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
import {LyricLine} from './types';
|
||||
|
||||
/**
|
||||
* Parse LRC format lyrics into LyricLine array.
|
||||
* LRC format: [mm:ss.xx] lyrics text
|
||||
*
|
||||
* Example:
|
||||
* [00:02.99] Version one point five is here today
|
||||
* [00:07.00] ACE-Step's rising, leading the way
|
||||
*/
|
||||
export function parseLrc(lrcContent: string): LyricLine[] {
|
||||
const lines = lrcContent.split('\n').filter((line) => line.trim());
|
||||
const parsed: {time: number; text: string}[] = [];
|
||||
|
||||
for (const line of lines) {
|
||||
// Match [mm:ss.xx] or [mm:ss] format
|
||||
const match = line.match(/^\[(\d{2}):(\d{2})(?:\.(\d{2,3}))?\]\s*(.*)$/);
|
||||
if (match) {
|
||||
const minutes = parseInt(match[1], 10);
|
||||
const seconds = parseInt(match[2], 10);
|
||||
const centiseconds = match[3] ? parseInt(match[3].padEnd(3, '0'), 10) / 1000 : 0;
|
||||
const time = minutes * 60 + seconds + centiseconds;
|
||||
const text = match[4].trim();
|
||||
parsed.push({time, text});
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to LyricLine with start/end
|
||||
const result: LyricLine[] = [];
|
||||
for (let i = 0; i < parsed.length; i++) {
|
||||
const start = parsed[i].time;
|
||||
const end = i < parsed.length - 1 ? parsed[i + 1].time : start + 5;
|
||||
const text = parsed[i].text;
|
||||
if (text) {
|
||||
result.push({start, end, text});
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
32
.claude/skills/acestep-simplemv/scripts/src/types.ts
Normal file
32
.claude/skills/acestep-simplemv/scripts/src/types.ts
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
export interface LyricLine {
|
||||
start: number;
|
||||
end: number;
|
||||
text: string;
|
||||
}
|
||||
|
||||
export interface MVInputProps extends Record<string, unknown> {
|
||||
/** Path to audio file (relative to public/ or absolute URL) */
|
||||
audioFileName: string;
|
||||
/** Lyrics as JSON array [{start, end, text}] */
|
||||
lyrics: LyricLine[];
|
||||
/** Main title displayed at top */
|
||||
title: string;
|
||||
/** Subtitle displayed below title */
|
||||
subtitle: string;
|
||||
/** Bottom credit text */
|
||||
creditText: string;
|
||||
/** Audio duration in seconds (used to calculate total frames) */
|
||||
durationInSeconds: number;
|
||||
/** Lyric timing offset in seconds (positive = delay, negative = advance) */
|
||||
lyricOffset: number;
|
||||
}
|
||||
|
||||
export const defaultProps: MVInputProps = {
|
||||
audioFileName: 'celebration.mp3',
|
||||
lyrics: [],
|
||||
title: 'ACE-Step',
|
||||
subtitle: 'v1.5',
|
||||
creditText: 'Powered by Claude Code + ACE-Step',
|
||||
durationInSeconds: 150,
|
||||
lyricOffset: -0.5,
|
||||
};
|
||||
18
.claude/skills/acestep-simplemv/scripts/tsconfig.json
Normal file
18
.claude/skills/acestep-simplemv/scripts/tsconfig.json
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2022",
|
||||
"module": "ES2022",
|
||||
"moduleResolution": "Bundler",
|
||||
"lib": ["DOM", "ES2022"],
|
||||
"jsx": "react-jsx",
|
||||
"skipLibCheck": true,
|
||||
"strict": true,
|
||||
"esModuleInterop": true,
|
||||
"allowSyntheticDefaultImports": true,
|
||||
"forceConsistentCasingInFileNames": true,
|
||||
"resolveJsonModule": true,
|
||||
"isolatedModules": true,
|
||||
"noEmit": true
|
||||
},
|
||||
"include": ["src/**/*"]
|
||||
}
|
||||
194
.claude/skills/acestep-songwriting/SKILL.md
Normal file
194
.claude/skills/acestep-songwriting/SKILL.md
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
---
|
||||
name: acestep-songwriting
|
||||
description: Music songwriting guide for ACE-Step. Provides professional knowledge on writing captions, lyrics, choosing BPM/key/duration, and structuring songs. Use this skill when users want to create, write, or plan a song before generating it with ACE-Step.
|
||||
allowed-tools: Read
|
||||
---
|
||||
|
||||
# ACE-Step Songwriting Guide
|
||||
|
||||
Professional music creation knowledge for writing captions, lyrics, and choosing music parameters for ACE-Step.
|
||||
|
||||
## Output Format
|
||||
|
||||
After using this guide, produce two things for the acestep skill:
|
||||
1. **Caption** (`-c`): Style/genre/instruments/emotion description
|
||||
2. **Lyrics** (`-l`): Complete structured lyrics with tags
|
||||
3. **Parameters**: `--duration`, `--bpm`, `--key`, `--time-signature`, `--language`
|
||||
|
||||
---
|
||||
|
||||
## Caption: The Most Important Input
|
||||
|
||||
**Caption is the most important factor affecting generated music.**
|
||||
|
||||
Supports multiple formats: simple style words, comma-separated tags, complex natural language descriptions.
|
||||
|
||||
### Common Dimensions
|
||||
|
||||
| Dimension | Examples |
|
||||
|-----------|----------|
|
||||
| **Style/Genre** | pop, rock, jazz, electronic, hip-hop, R&B, folk, classical, lo-fi, synthwave |
|
||||
| **Emotion/Atmosphere** | melancholic, uplifting, energetic, dreamy, dark, nostalgic, euphoric, intimate |
|
||||
| **Instruments** | acoustic guitar, piano, synth pads, 808 drums, strings, brass, electric bass |
|
||||
| **Timbre Texture** | warm, bright, crisp, muddy, airy, punchy, lush, raw, polished |
|
||||
| **Era Reference** | 80s synth-pop, 90s grunge, 2010s EDM, vintage soul, modern trap |
|
||||
| **Production Style** | lo-fi, high-fidelity, live recording, studio-polished, bedroom pop |
|
||||
| **Vocal Characteristics** | female vocal, male vocal, breathy, powerful, falsetto, raspy, choir |
|
||||
| **Speed/Rhythm** | slow tempo, mid-tempo, fast-paced, groovy, driving, laid-back |
|
||||
| **Structure Hints** | building intro, catchy chorus, dramatic bridge, fade-out ending |
|
||||
|
||||
### Caption Writing Principles
|
||||
|
||||
1. **Specific beats vague** — "sad piano ballad with female breathy vocal" > "a sad song"
|
||||
2. **Combine multiple dimensions** — style+emotion+instruments+timbre anchors direction precisely
|
||||
3. **Use references well** — "in the style of 80s synthwave" conveys complex aesthetic quickly
|
||||
4. **Texture words are useful** — warm, crisp, airy, punchy influence mixing and timbre
|
||||
5. **Don't pursue perfection** — Caption is a starting point, iterate based on results
|
||||
6. **Granularity determines freedom** — Less detail = more model creativity; more detail = more control
|
||||
7. **Avoid conflicting words** — "classical strings" + "hardcore metal" degrades output
|
||||
- **Fix: Repetition reinforcement** — Repeat the elements you want more
|
||||
- **Fix: Conflict to evolution** — "Start with soft strings, middle becomes metal rock, end turns to hip-hop"
|
||||
8. **Don't put BPM/key/tempo in Caption** — Use dedicated parameters instead
|
||||
|
||||
---
|
||||
|
||||
## Lyrics: The Temporal Script
|
||||
|
||||
Lyrics controls how music unfolds over time. It carries:
|
||||
- Lyric text itself
|
||||
- **Structure tags** ([Verse], [Chorus], [Bridge]...)
|
||||
- **Vocal style hints** ([raspy vocal], [whispered]...)
|
||||
- **Instrumental sections** ([guitar solo], [drum break]...)
|
||||
- **Energy changes** ([building energy], [explosive drop]...)
|
||||
|
||||
### Structure Tags
|
||||
|
||||
| Category | Tag | Description |
|
||||
|----------|-----|-------------|
|
||||
| **Basic Structure** | `[Intro]` | Opening, establish atmosphere |
|
||||
| | `[Verse]` / `[Verse 1]` | Verse, narrative progression |
|
||||
| | `[Pre-Chorus]` | Pre-chorus, build energy |
|
||||
| | `[Chorus]` | Chorus, emotional climax |
|
||||
| | `[Bridge]` | Bridge, transition or elevation |
|
||||
| | `[Outro]` | Ending, conclusion |
|
||||
| **Dynamic Sections** | `[Build]` | Energy gradually rising |
|
||||
| | `[Drop]` | Electronic music energy release |
|
||||
| | `[Breakdown]` | Reduced instrumentation, space |
|
||||
| **Instrumental** | `[Instrumental]` | Pure instrumental, no vocals |
|
||||
| | `[Guitar Solo]` | Guitar solo |
|
||||
| | `[Piano Interlude]` | Piano interlude |
|
||||
| **Special** | `[Fade Out]` | Fade out ending |
|
||||
| | `[Silence]` | Silence |
|
||||
|
||||
### Combining Tags
|
||||
|
||||
Use `-` for finer control, but keep it concise:
|
||||
|
||||
```
|
||||
✅ [Chorus - anthemic]
|
||||
❌ [Chorus - anthemic - stacked harmonies - high energy - powerful - epic]
|
||||
```
|
||||
|
||||
Put complex style descriptions in Caption, not in tags.
|
||||
|
||||
### Caption-Lyrics Consistency
|
||||
|
||||
**Models are not good at resolving conflicts.** Checklist:
|
||||
- Instruments in Caption ↔ Instrumental section tags in Lyrics
|
||||
- Emotion in Caption ↔ Energy tags in Lyrics
|
||||
- Vocal description in Caption ↔ Vocal control tags in Lyrics
|
||||
|
||||
### Vocal Control Tags
|
||||
|
||||
| Tag | Effect |
|
||||
|-----|--------|
|
||||
| `[raspy vocal]` | Raspy, textured vocals |
|
||||
| `[whispered]` | Whispered |
|
||||
| `[falsetto]` | Falsetto |
|
||||
| `[powerful belting]` | Powerful, high-pitched singing |
|
||||
| `[spoken word]` | Rap/recitation |
|
||||
| `[harmonies]` | Layered harmonies |
|
||||
| `[call and response]` | Call and response |
|
||||
| `[ad-lib]` | Improvised embellishments |
|
||||
|
||||
### Energy and Emotion Tags
|
||||
|
||||
| Tag | Effect |
|
||||
|-----|--------|
|
||||
| `[high energy]` | High energy, passionate |
|
||||
| `[low energy]` | Low energy, restrained |
|
||||
| `[building energy]` | Increasing energy |
|
||||
| `[explosive]` | Explosive energy |
|
||||
| `[melancholic]` | Melancholic |
|
||||
| `[euphoric]` | Euphoric |
|
||||
| `[dreamy]` | Dreamy |
|
||||
| `[aggressive]` | Aggressive |
|
||||
|
||||
### Lyric Writing Tips
|
||||
|
||||
1. **6-10 syllables per line** — Model aligns syllables to beats; keep similar counts for lines in same position (±1-2)
|
||||
2. **Uppercase = stronger intensity** — `WE ARE THE CHAMPIONS!` (shouting) vs `walking through the streets` (normal)
|
||||
3. **Parentheses = background vocals** — `We rise together (together)`
|
||||
4. **Extend vowels** — `Feeeling so aliiive` (use cautiously, effects unstable)
|
||||
5. **Clear section separation** — Blank lines between sections
|
||||
|
||||
### Avoiding "AI-flavored" Lyrics
|
||||
|
||||
| Red Flag | Description |
|
||||
|----------|-------------|
|
||||
| **Adjective stacking** | "neon skies, electric hearts, endless dreams" — vague imagery filler |
|
||||
| **Rhyme chaos** | Inconsistent patterns or forced rhymes breaking meaning |
|
||||
| **Blurred boundaries** | Lyric content crosses structure tags |
|
||||
| **No breathing room** | Lines too long to sing in one breath |
|
||||
| **Mixed metaphors** | Water → fire → flying — listeners can't anchor |
|
||||
|
||||
**Metaphor discipline**: One core metaphor per song, explore its multiple aspects.
|
||||
|
||||
---
|
||||
|
||||
## Music Metadata
|
||||
|
||||
**Most of the time, let LM auto-infer.** Only set manually when you have clear requirements.
|
||||
|
||||
| Parameter | Range | Description |
|
||||
|-----------|-------|-------------|
|
||||
| `bpm` | 30–300 | Slow 60–80, mid 90–120, fast 130–180 |
|
||||
| `keyscale` | Key | e.g. `C Major`, `Am`. Common keys (C, G, D, Am, Em) most stable |
|
||||
| `timesignature` | Time sig | `4/4` (most common), `3/4` (waltz), `6/8` (swing) |
|
||||
| `vocal_language` | Language | Usually auto-detected from lyrics |
|
||||
| `duration` | Seconds | See duration calculation below |
|
||||
|
||||
### When to Set Manually
|
||||
|
||||
| Scenario | Set |
|
||||
|----------|-----|
|
||||
| Daily generation | Let LM auto-infer |
|
||||
| Clear tempo requirement | `bpm` |
|
||||
| Specific style (waltz) | `timesignature=3/4` |
|
||||
| Match other material | `bpm` + `duration` |
|
||||
| Specific key color | `keyscale` |
|
||||
|
||||
---
|
||||
|
||||
## Duration Calculation
|
||||
|
||||
### Estimation Method
|
||||
|
||||
- **Intro/Outro**: 5-10 seconds each
|
||||
- **Instrumental sections**: 5-15 seconds each
|
||||
- **Typical structures**:
|
||||
- 2 verses + 2 choruses: 120-150s minimum
|
||||
- 2 verses + 2 choruses + bridge: 180-240s minimum
|
||||
- Full song with intro/outro: 210-270s (3.5-4.5 min)
|
||||
|
||||
### BPM and Duration Relationship
|
||||
|
||||
- **Slower BPM (60-80)**: Need MORE duration for same lyrics
|
||||
- **Medium BPM (100-130)**: Standard duration
|
||||
- **Faster BPM (150-180)**: Can fit more lyrics, but still need breathing room
|
||||
|
||||
**Rule of thumb**: When in doubt, estimate longer. A song too short feels rushed.
|
||||
|
||||
---
|
||||
|
||||
Note: Lyrics tags (piano, powerful, whispered) are consistent with Caption (piano ballad, building to powerful chorus, intimate).
|
||||
|
|
@ -6,154 +6,74 @@ allowed-tools: Read, Write, Bash, Skill
|
|||
|
||||
# ACE-Step Music Generation Skill
|
||||
|
||||
Use ACE-Step V1.5 API for music generation. Script: `scripts/acestep.sh` (requires curl + jq).
|
||||
Use ACE-Step V1.5 API for music generation. **Always use `scripts/acestep.sh` script** — do NOT call API endpoints directly.
|
||||
|
||||
## Prerequisites - ACE-Step API Service
|
||||
|
||||
**IMPORTANT**: This skill requires the ACE-Step API server to be running.
|
||||
|
||||
### Required Dependencies
|
||||
|
||||
The `scripts/acestep.sh` script requires the following tools:
|
||||
|
||||
**1. curl** - For making HTTP requests to the API
|
||||
**2. jq** - For parsing JSON responses
|
||||
|
||||
#### Check Dependencies
|
||||
|
||||
Before using this skill, verify that the required tools are installed:
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Check curl
|
||||
curl --version
|
||||
# 1. cd to this skill's directory
|
||||
cd {project_root}/{.claude or .codex}/skills/acestep/
|
||||
|
||||
# Check jq
|
||||
jq --version
|
||||
# 2. Check API service health
|
||||
./scripts/acestep.sh health
|
||||
|
||||
# 3. Generate with lyrics (recommended)
|
||||
./scripts/acestep.sh generate -c "pop, female vocal, piano" -l "[Verse] Your lyrics here..." --duration 120 --language zh
|
||||
|
||||
# 4. Output saved to: {project_root}/acestep_output/
|
||||
```
|
||||
|
||||
#### Installing jq
|
||||
## Workflow
|
||||
|
||||
If jq is not installed, the script will attempt to install it automatically. If automatic installation fails, install manually:
|
||||
|
||||
**Windows:**
|
||||
```bash
|
||||
# Using Chocolatey
|
||||
choco install jq
|
||||
|
||||
# Or download from: https://jqlang.github.io/jq/download/
|
||||
# Extract jq.exe and add to PATH
|
||||
```
|
||||
|
||||
**macOS:**
|
||||
```bash
|
||||
# Using Homebrew
|
||||
brew install jq
|
||||
|
||||
# Using MacPorts
|
||||
port install jq
|
||||
```
|
||||
|
||||
**Linux:**
|
||||
```bash
|
||||
# Debian/Ubuntu
|
||||
sudo apt-get install jq
|
||||
|
||||
# Fedora/RHEL/CentOS
|
||||
sudo yum install jq
|
||||
# or
|
||||
sudo dnf install jq
|
||||
|
||||
# Arch Linux
|
||||
sudo pacman -S jq
|
||||
```
|
||||
|
||||
**Verification:**
|
||||
```bash
|
||||
jq --version
|
||||
# Should output: jq-1.x
|
||||
```
|
||||
|
||||
If user reports jq installation issues, guide them through manual installation for their platform.
|
||||
|
||||
### Before First Use
|
||||
|
||||
**Ask the user about their setup:**
|
||||
|
||||
1. **"Do you have ACE-Step API service configured and running?"**
|
||||
|
||||
If **YES**:
|
||||
- Verify the API endpoint: `curl -s http://127.0.0.1:8001/health`
|
||||
- If using remote service, ask for the API URL and update `scripts/config.json`
|
||||
- Proceed with music generation
|
||||
|
||||
If **NO** or **NOT SURE**:
|
||||
- Ask: "Do you have ACE-Step installed?"
|
||||
|
||||
**If installed but not running**:
|
||||
- Use the acestep-docs skill to help them start the service
|
||||
- Guide them through startup process
|
||||
|
||||
**If not installed**:
|
||||
- Offer to help download and install ACE-Step
|
||||
- Ask: "Would you like to use the Windows portable package or install from source?"
|
||||
- Use acestep-docs skill to guide through installation
|
||||
|
||||
### Service Configuration
|
||||
|
||||
**Local Service (Default):**
|
||||
```json
|
||||
{
|
||||
"api_url": "http://127.0.0.1:8001",
|
||||
"api_key": ""
|
||||
}
|
||||
```
|
||||
|
||||
**Remote Service:**
|
||||
```json
|
||||
{
|
||||
"api_url": "http://your-server-ip:8001",
|
||||
"api_key": "your-api-key-if-needed"
|
||||
}
|
||||
```
|
||||
|
||||
To configure remote service, update `scripts/config.json` or use:
|
||||
```bash
|
||||
cd {skill_directory}/scripts/
|
||||
./acestep.sh config --set api_url "http://remote-server:8001"
|
||||
./acestep.sh config --set api_key "your-key"
|
||||
```
|
||||
|
||||
### Using acestep-docs Skill for Setup Help
|
||||
|
||||
**IMPORTANT**: For installation and startup, always use the acestep-docs skill to get complete and accurate guidance.
|
||||
|
||||
When user needs help with installation or startup, invoke the acestep-docs skill:
|
||||
|
||||
```
|
||||
Use the Skill tool to invoke: acestep-docs
|
||||
```
|
||||
|
||||
**DO NOT provide simplified startup commands** - each user's environment may be different. Always guide them to use acestep-docs for proper setup.
|
||||
|
||||
### Health Check
|
||||
|
||||
**To verify if service is running:**
|
||||
```bash
|
||||
curl http://127.0.0.1:8001/health
|
||||
# Should return: {"status":"ok",...}
|
||||
```
|
||||
|
||||
If health check fails, use acestep-docs skill to help user start the service properly.
|
||||
|
||||
---
|
||||
|
||||
**WORKFLOW**: For user requests requiring vocals, you should:
|
||||
1. Consult [Music Creation Guide](./music-creation-guide.md) for lyrics writing, caption creation, duration/BPM/key selection
|
||||
2. Write complete, well-structured lyrics yourself based on the guide
|
||||
For user requests requiring vocals:
|
||||
1. Use the **acestep-songwriting** skill for lyrics writing, caption creation, duration/BPM/key selection
|
||||
2. Write complete, well-structured lyrics yourself based on the songwriting guide
|
||||
3. Generate using Caption mode with `-c` and `-l` parameters
|
||||
|
||||
Only use Simple/Random mode (`-d` or `random`) for quick inspiration or instrumental exploration.
|
||||
|
||||
If the user needs a simple music video, use the **acestep-simplemv** skill to render one with waveform visualization and synced lyrics.
|
||||
|
||||
**MV Production Requirements**: Making a simple MV requires three additional skills to be installed:
|
||||
- **acestep-songwriting** — for writing lyrics and planning song structure
|
||||
- **acestep-lyrics-transcription** — for transcribing audio to timestamped lyrics (LRC)
|
||||
- **acestep-simplemv** — for rendering the final music video
|
||||
|
||||
## Script Commands
|
||||
|
||||
**CRITICAL - Complete Lyrics Input**: When providing lyrics via the `-l` parameter, you MUST pass ALL lyrics content WITHOUT any omission:
|
||||
- If user provides lyrics, pass the ENTIRE text they give you
|
||||
- If you generate lyrics yourself, pass the COMPLETE lyrics you created
|
||||
- NEVER truncate, shorten, or pass only partial lyrics
|
||||
- Missing lyrics will result in incomplete or incoherent songs
|
||||
|
||||
**Music Parameters**: Use the **acestep-songwriting** skill for guidance on duration, BPM, key scale, and time signature.
|
||||
|
||||
```bash
|
||||
# need to cd to this skill's directory first
|
||||
cd {project_root}/{.claude or .codex}/skills/acestep/
|
||||
|
||||
# Caption mode - RECOMMENDED: Write lyrics first, then generate
|
||||
./scripts/acestep.sh generate -c "Electronic pop, energetic synths" -l "[Verse] Your complete lyrics
|
||||
[Chorus] Full chorus here..." --duration 120 --bpm 128
|
||||
|
||||
# Instrumental only
|
||||
./scripts/acestep.sh generate "Jazz with saxophone"
|
||||
|
||||
# Quick exploration (Simple/Random mode)
|
||||
./scripts/acestep.sh generate -d "A cheerful song about spring"
|
||||
./scripts/acestep.sh random
|
||||
|
||||
# Options
|
||||
./scripts/acestep.sh generate "Rock" --duration 60 --batch 2
|
||||
./scripts/acestep.sh generate "EDM" --no-thinking # Faster
|
||||
|
||||
# Other commands
|
||||
./scripts/acestep.sh status <job_id>
|
||||
./scripts/acestep.sh health
|
||||
./scripts/acestep.sh models
|
||||
```
|
||||
|
||||
## Output Files
|
||||
|
||||
After generation, the script automatically saves results to the `acestep_output` folder in the project root (same level as `.claude`):
|
||||
|
|
@ -190,41 +110,6 @@ project_root/
|
|||
|
||||
To get the actual synthesized lyrics, parse the JSON and read the top-level `lyrics` field, not `metas.lyrics`.
|
||||
|
||||
## Script Commands
|
||||
|
||||
**CRITICAL - Complete Lyrics Input**: When providing lyrics via the `-l` parameter, you MUST pass ALL lyrics content WITHOUT any omission:
|
||||
- If user provides lyrics, pass the ENTIRE text they give you
|
||||
- If you generate lyrics yourself, pass the COMPLETE lyrics you created
|
||||
- NEVER truncate, shorten, or pass only partial lyrics
|
||||
- Missing lyrics will result in incomplete or incoherent songs
|
||||
|
||||
**Music Parameters**: Refer to [Music Creation Guide](./music-creation-guide.md) for how to calculate duration, choose BPM, key scale, and time signature.
|
||||
|
||||
```bash
|
||||
# need to cd skills path
|
||||
cd {project_root}/{.claude or .codex}/skills/acestep/
|
||||
|
||||
# Caption mode - RECOMMENDED: Write lyrics first, then generate
|
||||
./scripts/acestep.sh generate -c "Electronic pop, energetic synths" -l "[Verse] Your complete lyrics
|
||||
[Chorus] Full chorus here..." --duration 120 --bpm 128
|
||||
|
||||
# Instrumental only
|
||||
./scripts/acestep.sh generate "Jazz with saxophone"
|
||||
|
||||
# Quick exploration (Simple/Random mode)
|
||||
./scripts/acestep.sh generate -d "A cheerful song about spring"
|
||||
./scripts/acestep.sh random
|
||||
|
||||
# Options
|
||||
./scripts/acestep.sh generate "Rock" --duration 60 --batch 2
|
||||
./scripts/acestep.sh generate "EDM" --no-thinking # Faster
|
||||
|
||||
# Other commands
|
||||
./scripts/acestep.sh status <job_id>
|
||||
./scripts/acestep.sh health
|
||||
./scripts/acestep.sh models
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
**Important**: Configuration follows this priority (high to low):
|
||||
|
|
@ -239,6 +124,7 @@ cd {project_root}/{.claude or .codex}/skills/acestep/
|
|||
{
|
||||
"api_url": "http://127.0.0.1:8001",
|
||||
"api_key": "",
|
||||
"api_mode": "completion",
|
||||
"generation": {
|
||||
"thinking": true,
|
||||
"use_format": false,
|
||||
|
|
@ -255,102 +141,113 @@ cd {project_root}/{.claude or .codex}/skills/acestep/
|
|||
|--------|---------|-------------|
|
||||
| `api_url` | `http://127.0.0.1:8001` | API server address |
|
||||
| `api_key` | `""` | API authentication key (optional) |
|
||||
| `api_mode` | `completion` | API mode: `completion` (OpenRouter, default) or `native` (polling) |
|
||||
| `generation.thinking` | `true` | Enable 5Hz LM (higher quality, slower) |
|
||||
| `generation.audio_format` | `mp3` | Output format (mp3/wav/flac) |
|
||||
| `generation.vocal_language` | `en` | Vocal language |
|
||||
|
||||
## API Reference
|
||||
## Prerequisites - ACE-Step API Service
|
||||
|
||||
All responses wrapped: `{"data": <payload>, "code": 200, "error": null, "timestamp": ...}`
|
||||
**IMPORTANT**: This skill requires the ACE-Step API server to be running.
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
|----------|--------|-------------|
|
||||
| `/health` | GET | Health check |
|
||||
| `/release_task` | POST | Create generation task |
|
||||
| `/query_result` | POST | Query task status, body: `{"task_id_list": ["id"]}` |
|
||||
| `/v1/models` | GET | List available models |
|
||||
| `/v1/audio?path={path}` | GET | Download audio file |
|
||||
### Required Dependencies
|
||||
|
||||
### Query Result Response
|
||||
The `scripts/acestep.sh` script requires: **curl** and **jq**.
|
||||
|
||||
```json
|
||||
{
|
||||
"data": [{
|
||||
"task_id": "xxx",
|
||||
"status": 1,
|
||||
"result": "[{\"file\":\"/v1/audio?path=...\",\"metas\":{\"bpm\":120,\"duration\":60,\"keyscale\":\"C Major\"}}]"
|
||||
}]
|
||||
}
|
||||
```bash
|
||||
# Check dependencies
|
||||
curl --version
|
||||
jq --version
|
||||
```
|
||||
|
||||
Status codes: `0` = processing, `1` = success, `2` = failed
|
||||
If jq is not installed, the script will attempt to install it automatically. If automatic installation fails:
|
||||
- **Windows**: `choco install jq` or download from https://jqlang.github.io/jq/download/
|
||||
- **macOS**: `brew install jq`
|
||||
- **Linux**: `sudo apt-get install jq` (Debian/Ubuntu) or `sudo dnf install jq` (Fedora)
|
||||
|
||||
## Request Parameters (`/release_task`)
|
||||
### Before First Use
|
||||
|
||||
Parameters can be placed in `param_obj` object.
|
||||
**You MUST check the API key and URL status before proceeding.** Run:
|
||||
|
||||
### Generation Modes
|
||||
|
||||
| Mode | Usage | When to Use |
|
||||
|------|-------|-------------|
|
||||
| **Caption** (Recommended) | `generate -c "style" -l "lyrics"` | For vocal songs - write lyrics yourself first |
|
||||
| **Simple** | `generate -d "description"` | Quick exploration, LM generates everything |
|
||||
| **Random** | `random` | Random generation for inspiration |
|
||||
|
||||
### Core Parameters
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `prompt` | string | "" | Music style description (Caption mode) |
|
||||
| `lyrics` | string | "" | **Full lyrics content** - Pass ALL lyrics without omission. Use `[inst]` for instrumental. Partial/truncated lyrics = incomplete songs |
|
||||
| `sample_mode` | bool | false | Enable Simple/Random mode |
|
||||
| `sample_query` | string | "" | Description for Simple mode |
|
||||
| `thinking` | bool | false | Enable 5Hz LM for audio code generation |
|
||||
| `use_format` | bool | false | Use LM to enhance caption/lyrics |
|
||||
| `model` | string | - | DiT model name |
|
||||
| `batch_size` | int | 1 | Number of audio files to generate |
|
||||
|
||||
### Music Attributes
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `audio_duration` | float | - | Duration in seconds |
|
||||
| `bpm` | int | - | Tempo (beats per minute) |
|
||||
| `key_scale` | string | "" | Key (e.g. "C Major") |
|
||||
| `time_signature` | string | "" | Time signature (e.g. "4/4") |
|
||||
| `vocal_language` | string | "en" | Language code (en, zh, ja, etc.) |
|
||||
| `audio_format` | string | "mp3" | Output format (mp3/wav/flac) |
|
||||
|
||||
### Generation Parameters
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `inference_steps` | int | 8 | Diffusion steps |
|
||||
| `guidance_scale` | float | 7.0 | CFG scale |
|
||||
| `seed` | int | -1 | Random seed (-1 for random) |
|
||||
| `infer_method` | string | "ode" | Diffusion method (ode/sde) |
|
||||
|
||||
### Audio Task Parameters
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `task_type` | string | "text2music" | text2music / continuation / repainting |
|
||||
| `src_audio_path` | string | - | Source audio for continuation |
|
||||
| `repainting_start` | float | 0.0 | Repainting start position (seconds) |
|
||||
| `repainting_end` | float | - | Repainting end position (seconds) |
|
||||
|
||||
### Example Request (Simple Mode)
|
||||
|
||||
```json
|
||||
{
|
||||
"sample_mode": true,
|
||||
"sample_query": "A cheerful pop song about spring",
|
||||
"thinking": true,
|
||||
"param_obj": {
|
||||
"duration": 60,
|
||||
"bpm": 120,
|
||||
"language": "en"
|
||||
},
|
||||
"batch_size": 2
|
||||
}
|
||||
```bash
|
||||
cd "{project_root}/{.claude or .codex}/skills/acestep/" && bash ./scripts/acestep.sh config --check-key
|
||||
cd "{project_root}/{.claude or .codex}/skills/acestep/" && bash ./scripts/acestep.sh config --get api_url
|
||||
```
|
||||
|
||||
#### Case 1: Using Official Cloud API (`https://api.acemusic.ai`) without API key
|
||||
|
||||
If `api_url` is `https://api.acemusic.ai` and `api_key` is `empty`, you MUST stop and guide the user to configure their key:
|
||||
|
||||
1. Tell the user: "You're using the ACE-Step official cloud API, but no API key is configured. An API key is required to use this service."
|
||||
2. Explain how to get a key: API keys are currently available through the official ACE-Step Discord community (https://discord.gg/bGVxwUyD). Additional distribution methods will be added in the future.
|
||||
3. Use `AskUserQuestion` to ask the user to provide their API key.
|
||||
4. Once provided, configure it:
|
||||
```bash
|
||||
cd "{project_root}/{.claude or .codex}/skills/acestep/" && bash ./scripts/acestep.sh config --set api_key <KEY>
|
||||
```
|
||||
5. Additionally, inform the user: "If you also want to render music videos (MV), it's recommended to configure a lyrics transcription API key as well (OpenAI Whisper or ElevenLabs Scribe), so that lyrics can be automatically transcribed with accurate timestamps. You can configure it later via the `acestep-lyrics-transcription` skill."
|
||||
|
||||
#### Case 2: API key is configured
|
||||
|
||||
Verify the API endpoint: `./scripts/acestep.sh health` and proceed with music generation.
|
||||
|
||||
#### Case 3: Using local/custom API without key
|
||||
|
||||
Local services (`http://127.0.0.1:*`) typically don't require a key. Verify with `./scripts/acestep.sh health` and proceed.
|
||||
|
||||
If health check fails:
|
||||
- Ask: "Do you have ACE-Step installed?"
|
||||
- **If installed but not running**: Use the acestep-docs skill to help them start the service
|
||||
- **If not installed**: Use acestep-docs skill to guide through installation
|
||||
|
||||
### Service Configuration
|
||||
|
||||
**Official Cloud API:** ACE-Step provides an official API endpoint at `https://api.acemusic.ai`. To use it:
|
||||
```bash
|
||||
./scripts/acestep.sh config --set api_url "https://api.acemusic.ai"
|
||||
./scripts/acestep.sh config --set api_key "your-key"
|
||||
./scripts/acestep.sh config --set api_mode completion
|
||||
```
|
||||
API keys are currently available through the official ACE-Step Discord community. Additional distribution methods will be added in the future.
|
||||
|
||||
**Local Service (Default):** No configuration needed — connects to `http://127.0.0.1:8001`.
|
||||
|
||||
**Custom Remote Service:** Update `scripts/config.json` or use:
|
||||
```bash
|
||||
./scripts/acestep.sh config --set api_url "http://remote-server:8001"
|
||||
./scripts/acestep.sh config --set api_key "your-key"
|
||||
```
|
||||
|
||||
**API Key Handling**: When checking whether an API key is configured, use `config --check-key` which only reports `configured` or `empty` without printing the actual key. **NEVER use `config --get api_key`** or read `config.json` directly — these would expose the user's API key. The `config --list` command is safe — it automatically masks API keys as `***` in output.
|
||||
|
||||
### API Mode
|
||||
|
||||
The skill supports two API modes. Switch via `api_mode` in `scripts/config.json`:
|
||||
|
||||
| Mode | Endpoint | Description |
|
||||
|------|----------|-------------|
|
||||
| `completion` (default) | `/v1/chat/completions` | OpenRouter-compatible, sync request, audio returned as base64 |
|
||||
| `native` | `/release_task` + `/query_result` | Async polling mode, supports all parameters |
|
||||
|
||||
**Switch mode:**
|
||||
```bash
|
||||
./scripts/acestep.sh config --set api_mode completion
|
||||
./scripts/acestep.sh config --set api_mode native
|
||||
```
|
||||
|
||||
**Completion mode notes:**
|
||||
- No polling needed — single request returns result directly
|
||||
- Audio is base64-encoded inline in the response (auto-decoded and saved)
|
||||
- `inference_steps`, `infer_method`, `shift` are not configurable (server defaults)
|
||||
- `--no-wait` and `status` commands are not applicable in completion mode
|
||||
- Requires `model` field — auto-detected from `/v1/models` if not specified
|
||||
|
||||
### Using acestep-docs Skill for Setup Help
|
||||
|
||||
**IMPORTANT**: For installation and startup, always use the acestep-docs skill to get complete and accurate guidance.
|
||||
|
||||
**DO NOT provide simplified startup commands** - each user's environment may be different. Always guide them to use acestep-docs for proper setup.
|
||||
|
||||
---
|
||||
|
||||
For API debugging, see [API Reference](./api-reference.md).
|
||||
|
|
|
|||
149
.claude/skills/acestep/api-reference.md
Normal file
149
.claude/skills/acestep/api-reference.md
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
# ACE-Step API Reference
|
||||
|
||||
> For debugging and advanced usage only. Normal operations should use `scripts/acestep.sh`.
|
||||
|
||||
## Native Mode Endpoints
|
||||
|
||||
All responses wrapped: `{"data": <payload>, "code": 200, "error": null, "timestamp": ...}`
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
|----------|--------|-------------|
|
||||
| `/health` | GET | Health check |
|
||||
| `/release_task` | POST | Create generation task |
|
||||
| `/query_result` | POST | Query task status, body: `{"task_id_list": ["id"]}` |
|
||||
| `/v1/models` | GET | List available models |
|
||||
| `/v1/audio?path={path}` | GET | Download audio file |
|
||||
|
||||
## Completion Mode Endpoints
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
|----------|--------|-------------|
|
||||
| `/v1/chat/completions` | POST | Generate music (OpenRouter-compatible) |
|
||||
| `/v1/models` | GET | List available models (OpenRouter format) |
|
||||
|
||||
## Query Result Response
|
||||
|
||||
```json
|
||||
{
|
||||
"data": [{
|
||||
"task_id": "xxx",
|
||||
"status": 1,
|
||||
"result": "[{\"file\":\"/v1/audio?path=...\",\"metas\":{\"bpm\":120,\"duration\":60,\"keyscale\":\"C Major\"}}]"
|
||||
}]
|
||||
}
|
||||
```
|
||||
|
||||
Status codes: `0` = processing, `1` = success, `2` = failed
|
||||
|
||||
## Completion Mode Request (`/v1/chat/completions`)
|
||||
|
||||
**Caption mode** — prompt and lyrics wrapped in XML tags inside message content:
|
||||
```json
|
||||
{
|
||||
"model": "acestep/ACE-Step-v1.5",
|
||||
"messages": [{"role": "user", "content": "<prompt>Jazz with saxophone</prompt><lyrics>[Verse] Hello...</lyrics>"}],
|
||||
"stream": false,
|
||||
"thinking": true,
|
||||
"use_format": false,
|
||||
"audio_config": {"duration": 90, "bpm": 110, "format": "mp3", "vocal_language": "en"}
|
||||
}
|
||||
```
|
||||
|
||||
**Simple mode** — plain text message, set `sample_mode: true`:
|
||||
```json
|
||||
{
|
||||
"model": "acestep/ACE-Step-v1.5",
|
||||
"messages": [{"role": "user", "content": "A cheerful pop song about spring"}],
|
||||
"stream": false,
|
||||
"sample_mode": true,
|
||||
"thinking": true
|
||||
}
|
||||
```
|
||||
|
||||
## Completion Mode Response
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "chatcmpl-abc123",
|
||||
"choices": [{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "## Metadata\n**Caption:** ...\n**BPM:** 128\n\n## Lyrics\n...",
|
||||
"audio": [{"type": "audio_url", "audio_url": {"url": "data:audio/mpeg;base64,..."}}]
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}]
|
||||
}
|
||||
```
|
||||
|
||||
Audio is base64-encoded inline — the script auto-decodes and saves to `acestep_output/`.
|
||||
|
||||
## Request Parameters (`/release_task`)
|
||||
|
||||
Parameters can be placed in `param_obj` object.
|
||||
|
||||
### Generation Modes
|
||||
|
||||
| Mode | Usage | When to Use |
|
||||
|------|-------|-------------|
|
||||
| **Caption** (Recommended) | `generate -c "style" -l "lyrics"` | For vocal songs - write lyrics yourself first |
|
||||
| **Simple** | `generate -d "description"` | Quick exploration, LM generates everything |
|
||||
| **Random** | `random` | Random generation for inspiration |
|
||||
|
||||
### Core Parameters
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `prompt` | string | "" | Music style description (Caption mode) |
|
||||
| `lyrics` | string | "" | **Full lyrics content** - Pass ALL lyrics without omission. Use `[inst]` for instrumental. Partial/truncated lyrics = incomplete songs |
|
||||
| `sample_mode` | bool | false | Enable Simple/Random mode |
|
||||
| `sample_query` | string | "" | Description for Simple mode |
|
||||
| `thinking` | bool | false | Enable 5Hz LM for audio code generation |
|
||||
| `use_format` | bool | false | Use LM to enhance caption/lyrics |
|
||||
| `model` | string | - | DiT model name |
|
||||
| `batch_size` | int | 1 | Number of audio files to generate |
|
||||
|
||||
### Music Attributes
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `audio_duration` | float | - | Duration in seconds |
|
||||
| `bpm` | int | - | Tempo (beats per minute) |
|
||||
| `key_scale` | string | "" | Key (e.g. "C Major") |
|
||||
| `time_signature` | string | "" | Time signature (e.g. "4/4") |
|
||||
| `vocal_language` | string | "en" | Language code (en, zh, ja, etc.) |
|
||||
| `audio_format` | string | "mp3" | Output format (mp3/wav/flac) |
|
||||
|
||||
### Generation Parameters
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `inference_steps` | int | 8 | Diffusion steps |
|
||||
| `guidance_scale` | float | 7.0 | CFG scale |
|
||||
| `seed` | int | -1 | Random seed (-1 for random) |
|
||||
| `infer_method` | string | "ode" | Diffusion method (ode/sde) |
|
||||
|
||||
### Audio Task Parameters
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `task_type` | string | "text2music" | text2music / continuation / repainting |
|
||||
| `src_audio_path` | string | - | Source audio for continuation |
|
||||
| `repainting_start` | float | 0.0 | Repainting start position (seconds) |
|
||||
| `repainting_end` | float | - | Repainting end position (seconds) |
|
||||
|
||||
### Example Request (Simple Mode)
|
||||
|
||||
```json
|
||||
{
|
||||
"sample_mode": true,
|
||||
"sample_query": "A cheerful pop song about spring",
|
||||
"thinking": true,
|
||||
"param_obj": {
|
||||
"duration": 60,
|
||||
"bpm": 120,
|
||||
"language": "en"
|
||||
},
|
||||
"batch_size": 2
|
||||
}
|
||||
```
|
||||
|
|
@ -1,350 +0,0 @@
|
|||
# ACE-Step Music Creation Guide
|
||||
|
||||
> This guide contains professional music creation knowledge extracted from ACE-Step Tutorial. Use this as reference when creating music with ACE-Step.
|
||||
|
||||
---
|
||||
|
||||
## Input Control: What Do You Want?
|
||||
|
||||
This is the part where you communicate "creative intent" with the model—what kind of music you want to generate.
|
||||
|
||||
| Category | Parameter | Function |
|
||||
|----------|-----------|----------|
|
||||
| **Task Type** | `task_type` | Determines generation mode: text2music, cover, repaint, lego, extract, complete |
|
||||
| **Text Input** | `caption` | Description of overall music elements: style, instruments, emotion, atmosphere, timbre, vocal gender, progression, etc. |
|
||||
| | `lyrics` | Temporal element description: lyric content, music structure evolution, vocal changes, vocal/instrument performance style, start/end style, articulation, etc. (use `[Instrumental]` for instrumental music) |
|
||||
| **Music Metadata** | `bpm` | Tempo (30–300) |
|
||||
| | `keyscale` | Key (e.g., C Major, Am) |
|
||||
| | `timesignature` | Time signature (4/4, 3/4, 6/8) |
|
||||
| | `vocal_language` | Vocal language |
|
||||
| | `duration` | Target duration (seconds) |
|
||||
| **Audio Reference** | `reference_audio` | Global reference for timbre or style (for cover, style transfer) |
|
||||
| | `src_audio` | Source audio for non-text2music tasks (text2music defaults to silence, no input needed) |
|
||||
| | `audio_codes` | Semantic codes input to model in Cover mode (advanced: reuse codes for variants, convert songs to codes for extension, combine like DJ mixing) |
|
||||
| **Interval Control** | `repainting_start/end` | Time interval for operations (repaint redraw area / lego new track area) |
|
||||
|
||||
---
|
||||
|
||||
## About Caption: The Most Important Input
|
||||
|
||||
**Caption is the most important factor affecting generated music.**
|
||||
|
||||
It supports multiple input formats: simple style words, comma-separated tags, complex natural language descriptions. We've trained to be compatible with various formats, ensuring text format doesn't significantly affect model performance.
|
||||
|
||||
### Common Dimensions for Caption Writing
|
||||
|
||||
| Dimension | Examples |
|
||||
|-----------|----------|
|
||||
| **Style/Genre** | pop, rock, jazz, electronic, hip-hop, R&B, folk, classical, lo-fi, synthwave |
|
||||
| **Emotion/Atmosphere** | melancholic, uplifting, energetic, dreamy, dark, nostalgic, euphoric, intimate |
|
||||
| **Instruments** | acoustic guitar, piano, synth pads, 808 drums, strings, brass, electric bass |
|
||||
| **Timbre Texture** | warm, bright, crisp, muddy, airy, punchy, lush, raw, polished |
|
||||
| **Era Reference** | 80s synth-pop, 90s grunge, 2010s EDM, vintage soul, modern trap |
|
||||
| **Production Style** | lo-fi, high-fidelity, live recording, studio-polished, bedroom pop |
|
||||
| **Vocal Characteristics** | female vocal, male vocal, breathy, powerful, falsetto, raspy, choir |
|
||||
| **Speed/Rhythm** | slow tempo, mid-tempo, fast-paced, groovy, driving, laid-back |
|
||||
| **Structure Hints** | building intro, catchy chorus, dramatic bridge, fade-out ending |
|
||||
|
||||
### Practical Principles for Caption Writing
|
||||
|
||||
1. **Specific beats vague** — "sad piano ballad with female breathy vocal" works better than "a sad song."
|
||||
|
||||
2. **Combine multiple dimensions** — Single-dimension descriptions give the model too much room to play; combining style+emotion+instruments+timbre can more precisely anchor your desired direction.
|
||||
|
||||
3. **Use references well** — "in the style of 80s synthwave" or "reminiscent of Bon Iver" can quickly convey complex aesthetic preferences.
|
||||
|
||||
4. **Texture words are useful** — Adjectives like warm, crisp, airy, punchy can influence mixing and timbre tendencies.
|
||||
|
||||
5. **Don't pursue perfect descriptions** — Caption is a starting point, not an endpoint. Write a general direction first, then iterate based on results.
|
||||
|
||||
6. **Description granularity determines freedom** — More omitted descriptions give the model more room to play, more random factor influence; more detailed descriptions constrain the model more. Decide specificity based on your needs—want surprises? Write less. Want control? Write more details.
|
||||
|
||||
7. **Avoid conflicting words** — Conflicting style combinations easily lead to degraded output. For example, wanting both "classical strings" and "hardcore metal" simultaneously—the model will try to fuse but usually not ideal.
|
||||
|
||||
**Ways to resolve conflicts:**
|
||||
- **Repetition reinforcement** — Strengthen the elements you want more in mixed styles by repeating certain words
|
||||
- **Conflict to evolution** — Transform style conflicts into temporal style evolution. For example: "Start with soft strings, middle becomes noisy dynamic metal rock, end turns to hip-hop"—this gives the model clear guidance on how to handle different styles, rather than mixing them into a mess
|
||||
|
||||
---
|
||||
|
||||
## About Lyrics: The Temporal Script
|
||||
|
||||
If Caption describes the music's "overall portrait"—style, atmosphere, timbre—then **Lyrics is the music's "temporal script"**, controlling how music unfolds over time.
|
||||
|
||||
Lyrics is not just lyric content. It carries:
|
||||
- The lyric text itself
|
||||
- **Structure tags** ([Verse], [Chorus], [Bridge]...)
|
||||
- **Vocal style hints** ([raspy vocal], [whispered]...)
|
||||
- **Instrumental sections** ([guitar solo], [drum break]...)
|
||||
- **Energy changes** ([building energy], [explosive drop]...)
|
||||
|
||||
### Common Structure Tags
|
||||
|
||||
| Category | Tag | Description |
|
||||
|----------|-----|-------------|
|
||||
| **Basic Structure** | `[Intro]` | Opening, establish atmosphere |
|
||||
| | `[Verse]` / `[Verse 1]` | Verse, narrative progression |
|
||||
| | `[Pre-Chorus]` | Pre-chorus, build energy |
|
||||
| | `[Chorus]` | Chorus, emotional climax |
|
||||
| | `[Bridge]` | Bridge, transition or elevation |
|
||||
| | `[Outro]` | Ending, conclusion |
|
||||
| **Dynamic Sections** | `[Build]` | Energy gradually rising |
|
||||
| | `[Drop]` | Electronic music energy release |
|
||||
| | `[Breakdown]` | Reduced instrumentation, space |
|
||||
| **Instrumental Sections** | `[Instrumental]` | Pure instrumental, no vocals |
|
||||
| | `[Guitar Solo]` | Guitar solo |
|
||||
| | `[Piano Interlude]` | Piano interlude |
|
||||
| **Special Tags** | `[Fade Out]` | Fade out ending |
|
||||
| | `[Silence]` | Silence |
|
||||
|
||||
### Combining Tags: Use Moderately
|
||||
|
||||
Structure tags can be combined with `-` for finer control:
|
||||
|
||||
```
|
||||
[Chorus - anthemic]
|
||||
This is the chorus lyrics
|
||||
Dreams are burning
|
||||
|
||||
[Bridge - whispered]
|
||||
Whisper those words softly
|
||||
```
|
||||
|
||||
⚠️ **Note: Don't stack too many tags.**
|
||||
|
||||
```
|
||||
❌ Not recommended:
|
||||
[Chorus - anthemic - stacked harmonies - high energy - powerful - epic]
|
||||
|
||||
✅ Recommended:
|
||||
[Chorus - anthemic]
|
||||
```
|
||||
|
||||
**Principle**: Keep structure tags concise; put complex style descriptions in Caption.
|
||||
|
||||
### ⚠️ Key: Maintain Consistency Between Caption and Lyrics
|
||||
|
||||
**Models are not good at resolving conflicts.** If descriptions in Caption and Lyrics contradict, the model gets confused and output quality decreases.
|
||||
|
||||
**Checklist:**
|
||||
- Instruments in Caption ↔ Instrumental section tags in Lyrics
|
||||
- Emotion in Caption ↔ Energy tags in Lyrics
|
||||
- Vocal description in Caption ↔ Vocal control tags in Lyrics
|
||||
|
||||
Think of Caption as "overall setting" and Lyrics as "shot script"—they should tell the same story.
|
||||
|
||||
### Vocal Control Tags
|
||||
|
||||
| Tag | Effect |
|
||||
|-----|--------|
|
||||
| `[raspy vocal]` | Raspy, textured vocals |
|
||||
| `[whispered]` | Whispered |
|
||||
| `[falsetto]` | Falsetto |
|
||||
| `[powerful belting]` | Powerful, high-pitched singing |
|
||||
| `[spoken word]` | Rap/recitation |
|
||||
| `[harmonies]` | Layered harmonies |
|
||||
| `[call and response]` | Call and response |
|
||||
| `[ad-lib]` | Improvised embellishments |
|
||||
|
||||
### Energy and Emotion Tags
|
||||
|
||||
| Tag | Effect |
|
||||
|-----|--------|
|
||||
| `[high energy]` | High energy, passionate |
|
||||
| `[low energy]` | Low energy, restrained |
|
||||
| `[building energy]` | Increasing energy |
|
||||
| `[explosive]` | Explosive energy |
|
||||
| `[melancholic]` | Melancholic |
|
||||
| `[euphoric]` | Euphoric |
|
||||
| `[dreamy]` | Dreamy |
|
||||
| `[aggressive]` | Aggressive |
|
||||
|
||||
### Lyric Text Writing Tips
|
||||
|
||||
**1. Control Syllable Count**
|
||||
|
||||
**6-10 syllables per line** usually works best. The model aligns syllables to beats—if one line has 6 syllables and the next has 14, rhythm becomes strange.
|
||||
|
||||
**Tip**: Keep similar syllable counts for lines in the same position (e.g., first line of each verse) (±1-2 deviation).
|
||||
|
||||
**2. Use Case to Control Intensity**
|
||||
|
||||
Uppercase indicates stronger vocal intensity:
|
||||
|
||||
```
|
||||
[Verse]
|
||||
walking through the empty streets (normal intensity)
|
||||
|
||||
[Chorus]
|
||||
WE ARE THE CHAMPIONS! (high intensity, shouting)
|
||||
```
|
||||
|
||||
**3. Use Parentheses for Background Vocals**
|
||||
|
||||
```
|
||||
[Chorus]
|
||||
We rise together (together)
|
||||
Into the light (into the light)
|
||||
```
|
||||
|
||||
Content in parentheses is processed as background vocals or harmonies.
|
||||
|
||||
**4. Extend Vowels**
|
||||
|
||||
You can extend sounds by repeating vowels:
|
||||
|
||||
```
|
||||
Feeeling so aliiive
|
||||
```
|
||||
|
||||
But use cautiously—effects are unstable, sometimes ignored or mispronounced.
|
||||
|
||||
**5. Clear Section Separation**
|
||||
|
||||
Separate each section with blank lines:
|
||||
|
||||
```
|
||||
[Verse 1]
|
||||
First verse lyrics
|
||||
Continue first verse
|
||||
|
||||
[Chorus]
|
||||
Chorus lyrics
|
||||
Chorus continues
|
||||
```
|
||||
|
||||
### Avoiding "AI-flavored" Lyrics
|
||||
|
||||
These characteristics make lyrics seem mechanical and lack human touch:
|
||||
|
||||
| Red Flag 🚩 | Description |
|
||||
|-------------|-------------|
|
||||
| **Adjective stacking** | "neon skies, electric hearts, endless dreams"—filling a section with vague imagery |
|
||||
| **Rhyme chaos** | Inconsistent rhyme patterns, or forced rhymes causing semantic breaks |
|
||||
| **Blurred section boundaries** | Lyric content crosses structure tags, Verse content "flows" into Chorus |
|
||||
| **No breathing room** | Each line too long, can't sing in one breath |
|
||||
| **Mixed metaphors** | First verse uses water imagery, second suddenly becomes fire, third is flying—listeners can't anchor |
|
||||
|
||||
**Metaphor discipline**: Stick to one core metaphor per song, exploring its multiple aspects.
|
||||
|
||||
---
|
||||
|
||||
## About Music Metadata: Optional Fine Control
|
||||
|
||||
**Most of the time, you don't need to manually set metadata.**
|
||||
|
||||
When you enable `thinking` mode (or enable `use_cot_metas`), LM automatically infers appropriate BPM, key, time signature, etc. based on your Caption and Lyrics. This is usually good enough.
|
||||
|
||||
But if you have clear ideas, you can also manually control them:
|
||||
|
||||
| Parameter | Control Range | Description |
|
||||
|-----------|--------------|-------------|
|
||||
| `bpm` | 30–300 | Tempo. Common distribution: slow songs 60–80, mid-tempo 90–120, fast songs 130–180 |
|
||||
| `keyscale` | Key | e.g., `C Major`, `Am`, `F# Minor`. Affects overall pitch and emotional color |
|
||||
| `timesignature` | Time signature | `4/4` (most common), `3/4` (waltz), `6/8` (swing feel) |
|
||||
| `vocal_language` | Language | Vocal language. LM usually auto-detects from lyrics |
|
||||
| `duration` | Seconds | Target duration. Actual generation may vary slightly |
|
||||
|
||||
### Understanding Control Boundaries
|
||||
|
||||
These parameters are **guidance** rather than **precise commands**:
|
||||
|
||||
- **BPM**: Common range (60–180) works well; extreme values (like 30 or 280) have less training data, may be unstable
|
||||
- **Key**: Common keys (C, G, D, Am, Em) are stable; rare keys may be ignored or shifted
|
||||
- **Time signature**: `4/4` is most reliable; `3/4`, `6/8` usually OK; complex signatures (like `5/4`, `7/8`) are advanced, effects vary by style
|
||||
- **Duration**: Short songs (30–60s) and medium length (2–4min) are stable; very long generation may have repetition or structure issues
|
||||
|
||||
### When Do You Need Manual Settings?
|
||||
|
||||
| Scenario | Suggestion |
|
||||
|----------|------------|
|
||||
| Daily generation | Don't worry, let LM auto-infer |
|
||||
| Clear tempo requirement | Manually set `bpm` |
|
||||
| Specific style (e.g., waltz) | Manually set `timesignature=3/4` |
|
||||
| Need to match other material | Manually set `bpm` and `duration` |
|
||||
| Pursue specific key color | Manually set `keyscale` |
|
||||
|
||||
**Tip**: If you manually set metadata but generation results clearly don't match—check if there's conflict with Caption/Lyrics. For example, Caption says "slow ballad" but `bpm=160`, the model gets confused.
|
||||
|
||||
**Recommended Practice**: Don't write tempo, BPM, key, and other metadata information in Caption. These should be set through dedicated metadata parameters (`bpm`, `keyscale`, `timesignature`, etc.), not described in Caption. Caption should focus on style, emotion, instruments, timbre, and other musical characteristics, while metadata information is handled by corresponding parameters.
|
||||
|
||||
---
|
||||
|
||||
## Duration Calculation Guidelines
|
||||
|
||||
When creating music, you MUST calculate appropriate duration based on lyrics content and song structure:
|
||||
|
||||
### Estimation Method
|
||||
|
||||
- **Per line of lyrics**: 3-5 seconds
|
||||
- **Intro/Outro**: 5-10 seconds each
|
||||
- **Instrumental sections**: 5-15 seconds each
|
||||
- **Typical song structures**:
|
||||
- 2 verses + 2 choruses: 120-150 seconds minimum
|
||||
- 2 verses + 2 choruses + bridge: 180-240 seconds minimum
|
||||
- Full song with intro/outro: 210-270 seconds (3.5-4.5 minutes)
|
||||
|
||||
### Common Pitfalls
|
||||
|
||||
❌ **DON'T**: Set duration too short for the lyrics amount
|
||||
- Example: 10 lines of lyrics with 120 seconds → rushed, compressed
|
||||
|
||||
✅ **DO**: Calculate realistic duration
|
||||
- Example: 10 lines of lyrics → ~40 seconds of vocals + 20 seconds intro/outro = 60 seconds minimum
|
||||
|
||||
### BPM and Duration Relationship
|
||||
|
||||
The BPM affects how quickly lyrics are sung:
|
||||
- **Slower BPM (60-80)**: Need MORE duration for same lyrics
|
||||
- **Medium BPM (100-130)**: Standard duration
|
||||
- **Faster BPM (150-180)**: Can fit more lyrics in less time, but still need breathing room
|
||||
|
||||
**Rule of thumb**: When in doubt, estimate longer rather than shorter. A song that's too short will feel rushed and incomplete.
|
||||
|
||||
---
|
||||
|
||||
## Complete Example
|
||||
|
||||
Assuming Caption is: `female vocal, piano ballad, emotional, intimate atmosphere, strings, building to powerful chorus`
|
||||
|
||||
```
|
||||
[Intro - piano]
|
||||
|
||||
[Verse 1]
|
||||
月光洒在窗台上
|
||||
我听见你的呼吸
|
||||
城市在远处沉睡
|
||||
只有我们还醒着
|
||||
|
||||
[Pre-Chorus]
|
||||
这一刻如此安静
|
||||
却藏着汹涌的心
|
||||
|
||||
[Chorus - powerful]
|
||||
让我们燃烧吧
|
||||
像夜空中的烟火
|
||||
短暂却绚烂
|
||||
这就是我们的时刻
|
||||
|
||||
[Verse 2]
|
||||
时间在指尖流过
|
||||
我们抓不住什么
|
||||
但至少此刻拥有
|
||||
彼此眼中的火焰
|
||||
|
||||
[Bridge - whispered]
|
||||
如果明天一切消散
|
||||
至少我们曾经闪耀
|
||||
|
||||
[Final Chorus]
|
||||
让我们燃烧吧
|
||||
像夜空中的烟火
|
||||
短暂却绚烂
|
||||
THIS IS OUR MOMENT!
|
||||
|
||||
[Outro - fade out]
|
||||
```
|
||||
|
||||
Note: In this example, Lyrics tags (piano, powerful, whispered) are consistent with Caption descriptions (piano ballad, building to powerful chorus, intimate), with no conflicts.
|
||||
|
||||
---
|
||||
|
|
@ -72,6 +72,7 @@ ensure_output_dir() {
|
|||
DEFAULT_CONFIG='{
|
||||
"api_url": "http://127.0.0.1:8001",
|
||||
"api_key": "",
|
||||
"api_mode": "native",
|
||||
"generation": {
|
||||
"thinking": true,
|
||||
"use_format": true,
|
||||
|
|
@ -85,7 +86,15 @@ DEFAULT_CONFIG='{
|
|||
# Ensure config file exists
|
||||
ensure_config() {
|
||||
if [ ! -f "$CONFIG_FILE" ]; then
|
||||
echo "$DEFAULT_CONFIG" > "$CONFIG_FILE"
|
||||
local example="${SCRIPT_DIR}/config.example.json"
|
||||
if [ -f "$example" ]; then
|
||||
cp "$example" "$CONFIG_FILE"
|
||||
echo -e "${YELLOW}Config file created from config.example.json. Please configure your settings:${NC}"
|
||||
echo -e " ${CYAN}./scripts/acestep.sh config --set api_url <url>${NC}"
|
||||
echo -e " ${CYAN}./scripts/acestep.sh config --set api_key <key>${NC}"
|
||||
else
|
||||
echo "$DEFAULT_CONFIG" > "$CONFIG_FILE"
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
|
|
@ -244,11 +253,20 @@ cmd_config() {
|
|||
--set) action="set"; key="$2"; value="$3"; shift 3 ;;
|
||||
--reset) action="reset"; shift ;;
|
||||
--list) action="list"; shift ;;
|
||||
--check-key) action="check-key"; shift ;;
|
||||
*) shift ;;
|
||||
esac
|
||||
done
|
||||
|
||||
case "$action" in
|
||||
"check-key")
|
||||
local api_key=$(get_config "api_key")
|
||||
if [ -n "$api_key" ]; then
|
||||
echo "api_key: configured"
|
||||
else
|
||||
echo "api_key: empty"
|
||||
fi
|
||||
;;
|
||||
"get")
|
||||
[ -z "$key" ] && { echo -e "${RED}Error: --get requires KEY${NC}"; exit 1; }
|
||||
local result=$(get_config "$key")
|
||||
|
|
@ -261,11 +279,11 @@ cmd_config() {
|
|||
"reset")
|
||||
echo "$DEFAULT_CONFIG" > "$CONFIG_FILE"
|
||||
echo -e "${GREEN}Configuration reset to defaults.${NC}"
|
||||
cat "$CONFIG_FILE"
|
||||
jq 'walk(if type == "object" and has("api_key") and (.api_key | length) > 0 then .api_key = "***" else . end)' "$CONFIG_FILE"
|
||||
;;
|
||||
"list")
|
||||
echo "Current configuration:"
|
||||
cat "$CONFIG_FILE"
|
||||
jq 'walk(if type == "object" and has("api_key") and (.api_key | length) > 0 then .api_key = "***" else . end)' "$CONFIG_FILE"
|
||||
;;
|
||||
*)
|
||||
echo "Config file: $CONFIG_FILE"
|
||||
|
|
@ -488,6 +506,197 @@ download_audios() {
|
|||
done <<< "$audio_paths"
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
# Completion Mode (OpenRouter /v1/chat/completions)
|
||||
# =============================================================================
|
||||
|
||||
# Load api_mode from config (default: native)
|
||||
load_api_mode() {
|
||||
local mode=$(get_config "api_mode")
|
||||
echo "${mode:-native}"
|
||||
}
|
||||
|
||||
# Get model ID from /v1/models endpoint for completion mode
|
||||
get_completion_model() {
|
||||
local api_url="$1"
|
||||
local user_model="$2"
|
||||
local api_key=$(load_api_key)
|
||||
|
||||
# If user specified a model, prefix with acemusic/ if needed
|
||||
if [ -n "$user_model" ]; then
|
||||
if [[ "$user_model" == */* ]]; then
|
||||
echo "$user_model"
|
||||
else
|
||||
echo "acemusic/${user_model}"
|
||||
fi
|
||||
return
|
||||
fi
|
||||
|
||||
# Query /v1/models for the first available model
|
||||
local response
|
||||
if [ -n "$api_key" ]; then
|
||||
response=$(curl -s -H "Authorization: Bearer ${api_key}" "${api_url}/v1/models" 2>/dev/null)
|
||||
else
|
||||
response=$(curl -s "${api_url}/v1/models" 2>/dev/null)
|
||||
fi
|
||||
|
||||
local model_id
|
||||
model_id=$(echo "$response" | jq -r '.data[0].id // empty' 2>/dev/null)
|
||||
echo "${model_id:-acemusic/acestep-v15-turbo}"
|
||||
}
|
||||
|
||||
# Decode base64 audio data URL and save to file
|
||||
# Handles cross-platform compatibility (Linux/macOS/Windows MSYS)
|
||||
decode_base64_audio() {
|
||||
local data_url="$1"
|
||||
local output_file="$2"
|
||||
|
||||
# Strip data URL prefix: data:audio/mpeg;base64,...
|
||||
local b64_data="${data_url#data:*;base64,}"
|
||||
|
||||
local tmp_b64=$(mktemp)
|
||||
printf '%s' "$b64_data" > "$tmp_b64"
|
||||
|
||||
if command -v base64 &> /dev/null; then
|
||||
# Linux / macOS / MSYS2
|
||||
base64 -d < "$tmp_b64" > "$output_file" 2>/dev/null || \
|
||||
base64 -D < "$tmp_b64" > "$output_file" 2>/dev/null || \
|
||||
python3 -c "import base64,sys; sys.stdout.buffer.write(base64.b64decode(sys.stdin.read()))" < "$tmp_b64" > "$output_file" 2>/dev/null || \
|
||||
python -c "import base64,sys; sys.stdout.buffer.write(base64.b64decode(sys.stdin.read()))" < "$tmp_b64" > "$output_file" 2>/dev/null
|
||||
else
|
||||
# Fallback to python
|
||||
python3 -c "import base64,sys; sys.stdout.buffer.write(base64.b64decode(sys.stdin.read()))" < "$tmp_b64" > "$output_file" 2>/dev/null || \
|
||||
python -c "import base64,sys; sys.stdout.buffer.write(base64.b64decode(sys.stdin.read()))" < "$tmp_b64" > "$output_file" 2>/dev/null
|
||||
fi
|
||||
|
||||
local decode_ok=$?
|
||||
rm -f "$tmp_b64"
|
||||
return $decode_ok
|
||||
}
|
||||
|
||||
# Parse completion response: extract metadata, save audio files
|
||||
# Usage: parse_completion_response <response_file> <job_id>
|
||||
parse_completion_response() {
|
||||
local resp_file="$1"
|
||||
local job_id="$2"
|
||||
|
||||
ensure_output_dir
|
||||
|
||||
local audio_format=$(get_config "generation.audio_format")
|
||||
[ -z "$audio_format" ] && audio_format="mp3"
|
||||
|
||||
# Check for error
|
||||
local finish_reason
|
||||
finish_reason=$(jq -r '.choices[0].finish_reason // "stop"' "$resp_file" 2>/dev/null)
|
||||
if [ "$finish_reason" = "error" ]; then
|
||||
local err_content
|
||||
err_content=$(jq -r '.choices[0].message.content // "Unknown error"' "$resp_file" 2>/dev/null)
|
||||
echo -e "${RED}Generation failed: $err_content${NC}"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Extract and display text content (metadata + lyrics)
|
||||
local content
|
||||
content=$(jq -r '.choices[0].message.content // empty' "$resp_file" 2>/dev/null)
|
||||
if [ -n "$content" ]; then
|
||||
echo "$content"
|
||||
echo ""
|
||||
fi
|
||||
|
||||
# Extract and save audio files
|
||||
local audio_count
|
||||
audio_count=$(jq -r '.choices[0].message.audio | length // 0' "$resp_file" 2>/dev/null)
|
||||
|
||||
if [ "$audio_count" -gt 0 ] 2>/dev/null; then
|
||||
local i=0
|
||||
while [ "$i" -lt "$audio_count" ]; do
|
||||
local audio_url
|
||||
audio_url=$(jq -r ".choices[0].message.audio[$i].audio_url.url // empty" "$resp_file" 2>/dev/null)
|
||||
|
||||
if [ -n "$audio_url" ]; then
|
||||
local output_file="${OUTPUT_DIR}/${job_id}_$((i+1)).${audio_format}"
|
||||
echo -e " ${CYAN}Decoding audio $((i+1))...${NC}"
|
||||
|
||||
if decode_base64_audio "$audio_url" "$output_file"; then
|
||||
if [ -f "$output_file" ] && [ -s "$output_file" ]; then
|
||||
echo -e " ${GREEN}Saved: $output_file${NC}"
|
||||
else
|
||||
echo -e " ${RED}Failed to decode audio $((i+1))${NC}"
|
||||
rm -f "$output_file" 2>/dev/null
|
||||
fi
|
||||
else
|
||||
echo -e " ${RED}Failed to decode audio $((i+1))${NC}"
|
||||
rm -f "$output_file" 2>/dev/null
|
||||
fi
|
||||
fi
|
||||
i=$((i+1))
|
||||
done
|
||||
else
|
||||
echo -e " ${YELLOW}No audio files in response${NC}"
|
||||
fi
|
||||
|
||||
# Save full response JSON (strip base64 audio to keep file small)
|
||||
local clean_resp
|
||||
clean_resp=$(jq 'del(.choices[].message.audio[].audio_url.url)' "$resp_file" 2>/dev/null)
|
||||
if [ -n "$clean_resp" ]; then
|
||||
save_result "$job_id" "$clean_resp"
|
||||
else
|
||||
save_result "$job_id" "$(cat "$resp_file")"
|
||||
fi
|
||||
}
|
||||
|
||||
# Send request to /v1/chat/completions and handle response
|
||||
# Usage: send_completion_request <api_url> <payload_file> <job_id_var>
|
||||
send_completion_request() {
|
||||
local api_url="$1"
|
||||
local payload_file="$2"
|
||||
local api_key=$(load_api_key)
|
||||
|
||||
local resp_file=$(mktemp)
|
||||
|
||||
local http_code
|
||||
if [ -n "$api_key" ]; then
|
||||
http_code=$(curl -s -w "%{http_code}" --connect-timeout 10 --max-time 660 \
|
||||
-o "$resp_file" \
|
||||
-X POST "${api_url}/v1/chat/completions" \
|
||||
-H "Content-Type: application/json; charset=utf-8" \
|
||||
-H "Authorization: Bearer ${api_key}" \
|
||||
--data-binary "@${payload_file}")
|
||||
else
|
||||
http_code=$(curl -s -w "%{http_code}" --connect-timeout 10 --max-time 660 \
|
||||
-o "$resp_file" \
|
||||
-X POST "${api_url}/v1/chat/completions" \
|
||||
-H "Content-Type: application/json; charset=utf-8" \
|
||||
--data-binary "@${payload_file}")
|
||||
fi
|
||||
|
||||
rm -f "$payload_file"
|
||||
|
||||
if [ "$http_code" != "200" ]; then
|
||||
local err_detail
|
||||
err_detail=$(jq -r '.detail // .error.message // empty' "$resp_file" 2>/dev/null)
|
||||
echo -e "${RED}Error: HTTP $http_code${NC}"
|
||||
[ -n "$err_detail" ] && echo -e "${RED}$err_detail${NC}"
|
||||
rm -f "$resp_file"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Generate a job_id from the completion id
|
||||
local job_id
|
||||
job_id=$(jq -r '.id // empty' "$resp_file" 2>/dev/null)
|
||||
[ -z "$job_id" ] && job_id="completion-$(date +%s)"
|
||||
|
||||
echo ""
|
||||
echo -e "${GREEN}Generation completed!${NC}"
|
||||
echo ""
|
||||
|
||||
parse_completion_response "$resp_file" "$job_id"
|
||||
rm -f "$resp_file"
|
||||
|
||||
echo ""
|
||||
echo -e "${GREEN}Done! Files saved to: $OUTPUT_DIR${NC}"
|
||||
}
|
||||
|
||||
# Wait for job and download results
|
||||
wait_for_job() {
|
||||
local api_url="$1"
|
||||
|
|
@ -646,6 +855,8 @@ cmd_generate() {
|
|||
[ -n "$bpm" ] && payload=$(echo "$payload" | jq --argjson v "$bpm" '. + {bpm: $v}')
|
||||
[ -n "$batch" ] && payload=$(echo "$payload" | jq --argjson v "$batch" '. + {batch_size: $v}')
|
||||
|
||||
local api_mode=$(load_api_mode)
|
||||
|
||||
echo "Generating music..."
|
||||
if [ -n "$description" ]; then
|
||||
echo " Mode: Simple (description)"
|
||||
|
|
@ -655,38 +866,91 @@ cmd_generate() {
|
|||
echo " Caption: ${caption:0:50}..."
|
||||
fi
|
||||
echo " Thinking: $thinking, Format: $use_format"
|
||||
echo " API: $api_mode"
|
||||
echo " Output: $OUTPUT_DIR"
|
||||
echo ""
|
||||
|
||||
# Write payload to temp file to ensure proper UTF-8 encoding
|
||||
local temp_payload=$(mktemp)
|
||||
printf '%s' "$payload" > "$temp_payload"
|
||||
if [ "$api_mode" = "completion" ]; then
|
||||
# --- Completion mode: /v1/chat/completions ---
|
||||
local model_id=$(get_completion_model "$api_url" "$model")
|
||||
|
||||
local api_key=$(load_api_key)
|
||||
local response
|
||||
if [ -n "$api_key" ]; then
|
||||
response=$(curl -s -X POST "${api_url}/release_task" \
|
||||
-H "Content-Type: application/json; charset=utf-8" \
|
||||
-H "Authorization: Bearer ${api_key}" \
|
||||
--data-binary "@${temp_payload}")
|
||||
# Build message content
|
||||
local message_content=""
|
||||
local sample_mode=false
|
||||
if [ -n "$description" ]; then
|
||||
message_content="$description"
|
||||
sample_mode=true
|
||||
else
|
||||
message_content="<prompt>${caption}</prompt>"
|
||||
[ -n "$lyrics" ] && message_content="${message_content}<lyrics>${lyrics}</lyrics>"
|
||||
fi
|
||||
|
||||
# Build completion payload
|
||||
local payload_c=$(jq -n \
|
||||
--arg model "$model_id" \
|
||||
--arg content "$message_content" \
|
||||
--argjson thinking "$thinking" \
|
||||
--argjson use_format "$use_format" \
|
||||
--argjson sample_mode "$sample_mode" \
|
||||
--argjson use_cot_caption "$cot_caption" \
|
||||
--argjson use_cot_language "$cot_language" \
|
||||
--arg vocal_language "$language" \
|
||||
--arg format "${def_audio_format:-mp3}" \
|
||||
'{
|
||||
model: $model,
|
||||
messages: [{"role": "user", "content": $content}],
|
||||
stream: false,
|
||||
thinking: $thinking,
|
||||
use_format: $use_format,
|
||||
sample_mode: $sample_mode,
|
||||
use_cot_caption: $use_cot_caption,
|
||||
use_cot_language: $use_cot_language,
|
||||
audio_config: {
|
||||
format: $format,
|
||||
vocal_language: $vocal_language
|
||||
}
|
||||
}')
|
||||
|
||||
# Add optional parameters to completion payload
|
||||
[ -n "$guidance" ] && payload_c=$(echo "$payload_c" | jq --argjson v "$guidance" '. + {guidance_scale: $v}')
|
||||
[ -n "$seed" ] && payload_c=$(echo "$payload_c" | jq --argjson v "$seed" '. + {seed: $v}')
|
||||
[ -n "$batch" ] && payload_c=$(echo "$payload_c" | jq --argjson v "$batch" '. + {batch_size: $v}')
|
||||
[ -n "$duration" ] && payload_c=$(echo "$payload_c" | jq --argjson v "$duration" '.audio_config.duration = $v')
|
||||
[ -n "$bpm" ] && payload_c=$(echo "$payload_c" | jq --argjson v "$bpm" '.audio_config.bpm = $v')
|
||||
|
||||
local temp_payload=$(mktemp)
|
||||
printf '%s' "$payload_c" > "$temp_payload"
|
||||
|
||||
send_completion_request "$api_url" "$temp_payload"
|
||||
else
|
||||
response=$(curl -s -X POST "${api_url}/release_task" \
|
||||
-H "Content-Type: application/json; charset=utf-8" \
|
||||
--data-binary "@${temp_payload}")
|
||||
fi
|
||||
# --- Native mode: /release_task + polling ---
|
||||
local temp_payload=$(mktemp)
|
||||
printf '%s' "$payload" > "$temp_payload"
|
||||
|
||||
rm -f "$temp_payload"
|
||||
local api_key=$(load_api_key)
|
||||
local response
|
||||
if [ -n "$api_key" ]; then
|
||||
response=$(curl -s -X POST "${api_url}/release_task" \
|
||||
-H "Content-Type: application/json; charset=utf-8" \
|
||||
-H "Authorization: Bearer ${api_key}" \
|
||||
--data-binary "@${temp_payload}")
|
||||
else
|
||||
response=$(curl -s -X POST "${api_url}/release_task" \
|
||||
-H "Content-Type: application/json; charset=utf-8" \
|
||||
--data-binary "@${temp_payload}")
|
||||
fi
|
||||
|
||||
# Response is wrapped: {"data": {"task_id": ...}, "code": 200, ...}
|
||||
local job_id=$(echo "$response" | jq -r '.data.task_id // .task_id // empty')
|
||||
rm -f "$temp_payload"
|
||||
|
||||
[ -z "$job_id" ] && { echo -e "${RED}Error: Failed to create job${NC}"; echo "$response"; exit 1; }
|
||||
local job_id=$(echo "$response" | jq -r '.data.task_id // .task_id // empty')
|
||||
[ -z "$job_id" ] && { echo -e "${RED}Error: Failed to create job${NC}"; echo "$response"; exit 1; }
|
||||
|
||||
if [ "$no_wait" = true ]; then
|
||||
echo "Job ID: $job_id"
|
||||
echo "Use '$0 status $job_id' to check progress and download"
|
||||
else
|
||||
wait_for_job "$api_url" "$job_id"
|
||||
if [ "$no_wait" = true ]; then
|
||||
echo "Job ID: $job_id"
|
||||
echo "Use '$0 status $job_id' to check progress and download"
|
||||
else
|
||||
wait_for_job "$api_url" "$job_id"
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
|
|
@ -715,42 +979,67 @@ cmd_random() {
|
|||
# Normalize boolean for jq --argjson
|
||||
thinking=$(normalize_bool "$thinking" "true")
|
||||
|
||||
local api_mode=$(load_api_mode)
|
||||
|
||||
echo "Generating random music..."
|
||||
echo " Thinking: $thinking"
|
||||
echo " API: $api_mode"
|
||||
echo " Output: $OUTPUT_DIR"
|
||||
echo ""
|
||||
|
||||
local payload=$(jq -n --argjson thinking "$thinking" '{sample_mode: true, thinking: $thinking}')
|
||||
if [ "$api_mode" = "completion" ]; then
|
||||
# --- Completion mode ---
|
||||
local model_id=$(get_completion_model "$api_url" "")
|
||||
local def_audio_format=$(get_config "generation.audio_format")
|
||||
|
||||
# Write payload to temp file
|
||||
local temp_payload=$(mktemp)
|
||||
printf '%s' "$payload" > "$temp_payload"
|
||||
local payload_c=$(jq -n \
|
||||
--arg model "$model_id" \
|
||||
--argjson thinking "$thinking" \
|
||||
--arg format "${def_audio_format:-mp3}" \
|
||||
'{
|
||||
model: $model,
|
||||
messages: [{"role": "user", "content": "Generate a random song"}],
|
||||
stream: false,
|
||||
sample_mode: true,
|
||||
thinking: $thinking,
|
||||
audio_config: { format: $format }
|
||||
}')
|
||||
|
||||
local api_key=$(load_api_key)
|
||||
local response
|
||||
if [ -n "$api_key" ]; then
|
||||
response=$(curl -s -X POST "${api_url}/release_task" \
|
||||
-H "Content-Type: application/json; charset=utf-8" \
|
||||
-H "Authorization: Bearer ${api_key}" \
|
||||
--data-binary "@${temp_payload}")
|
||||
local temp_payload=$(mktemp)
|
||||
printf '%s' "$payload_c" > "$temp_payload"
|
||||
|
||||
send_completion_request "$api_url" "$temp_payload"
|
||||
else
|
||||
response=$(curl -s -X POST "${api_url}/release_task" \
|
||||
-H "Content-Type: application/json; charset=utf-8" \
|
||||
--data-binary "@${temp_payload}")
|
||||
fi
|
||||
# --- Native mode ---
|
||||
local payload=$(jq -n --argjson thinking "$thinking" '{sample_mode: true, thinking: $thinking}')
|
||||
|
||||
rm -f "$temp_payload"
|
||||
local temp_payload=$(mktemp)
|
||||
printf '%s' "$payload" > "$temp_payload"
|
||||
|
||||
# Response is wrapped: {"data": {"task_id": ...}, "code": 200, ...}
|
||||
local job_id=$(echo "$response" | jq -r '.data.task_id // .task_id // empty')
|
||||
local api_key=$(load_api_key)
|
||||
local response
|
||||
if [ -n "$api_key" ]; then
|
||||
response=$(curl -s -X POST "${api_url}/release_task" \
|
||||
-H "Content-Type: application/json; charset=utf-8" \
|
||||
-H "Authorization: Bearer ${api_key}" \
|
||||
--data-binary "@${temp_payload}")
|
||||
else
|
||||
response=$(curl -s -X POST "${api_url}/release_task" \
|
||||
-H "Content-Type: application/json; charset=utf-8" \
|
||||
--data-binary "@${temp_payload}")
|
||||
fi
|
||||
|
||||
[ -z "$job_id" ] && { echo -e "${RED}Error: Failed to create job${NC}"; echo "$response"; exit 1; }
|
||||
rm -f "$temp_payload"
|
||||
|
||||
if [ "$no_wait" = true ]; then
|
||||
echo "Job ID: $job_id"
|
||||
echo "Use '$0 status $job_id' to check progress and download"
|
||||
else
|
||||
wait_for_job "$api_url" "$job_id"
|
||||
local job_id=$(echo "$response" | jq -r '.data.task_id // .task_id // empty')
|
||||
[ -z "$job_id" ] && { echo -e "${RED}Error: Failed to create job${NC}"; echo "$response"; exit 1; }
|
||||
|
||||
if [ "$no_wait" = true ]; then
|
||||
echo "Job ID: $job_id"
|
||||
echo "Use '$0 status $job_id' to check progress and download"
|
||||
else
|
||||
wait_for_job "$api_url" "$job_id"
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
{
|
||||
"api_url": "http://127.0.0.1:8001",
|
||||
"api_url": "https://api.acemusic.ai",
|
||||
"api_key": "",
|
||||
"api_mode": "completion",
|
||||
"generation": {
|
||||
"thinking": true,
|
||||
"use_format": false,
|
||||
67
.github/copilot-instructions.md
vendored
Normal file
67
.github/copilot-instructions.md
vendored
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
# ACE-Step 1.5 - GitHub Copilot Instructions
|
||||
|
||||
## Project Overview
|
||||
|
||||
ACE-Step 1.5 is an open-source music foundation model combining a Language Model (LM) as a planner with a Diffusion Transformer (DiT) for audio synthesis. It generates commercial-grade music on consumer hardware (< 4GB VRAM).
|
||||
|
||||
## Tech Stack
|
||||
|
||||
- **Python 3.11** (strict requirement)
|
||||
- **PyTorch 2.7+** with CUDA 12.8 (Windows/Linux), MPS (macOS ARM64)
|
||||
- **Transformers 4.51.0-4.57.x** for LLM inference
|
||||
- **Diffusers** for diffusion models
|
||||
- **Gradio 6.2.0** for web UI
|
||||
- **FastAPI + Uvicorn** for REST API server
|
||||
- **uv** for dependency management
|
||||
- **MLX** (Apple Silicon native acceleration, macOS ARM64)
|
||||
- **nano-vllm** (optimized LLM inference, non-macOS ARM64)
|
||||
|
||||
## Multi-Platform Support
|
||||
|
||||
**CRITICAL**: Supports CUDA, ROCm, Intel XPU, MPS, MLX, and CPU. When fixing bugs or adding features:
|
||||
- **DO NOT alter non-target platform paths** unless explicitly required
|
||||
- Changes to CUDA code should not affect MPS/XPU/CPU paths
|
||||
- Use `gpu_config.py` for hardware detection and configuration
|
||||
|
||||
## Code Organization
|
||||
|
||||
### Main Entry Points
|
||||
- `acestep/acestep_v15_pipeline.py` - Gradio UI pipeline
|
||||
- `acestep/api_server.py` - REST API server
|
||||
- `cli.py` - Command-line interface
|
||||
- `acestep/model_downloader.py` - Model downloader
|
||||
|
||||
### Core Modules
|
||||
- `acestep/handler.py` - Audio generation handler (AceStepHandler)
|
||||
- `acestep/llm_inference.py` - LLM handler for text processing
|
||||
- `acestep/inference.py` - Generation logic and parameters
|
||||
- `acestep/gpu_config.py` - Hardware detection and GPU configuration
|
||||
- `acestep/audio_utils.py` - Audio processing utilities
|
||||
- `acestep/constants.py` - Global constants
|
||||
|
||||
### UI & Internationalization
|
||||
- `acestep/gradio_ui/` - Gradio interface components
|
||||
- `acestep/gradio_ui/i18n.py` - i18n system (50+ languages)
|
||||
- All user-facing strings must use i18n translation keys
|
||||
|
||||
### Training
|
||||
- `acestep/training/` - LoRA training pipeline
|
||||
- `acestep/dataset/` - Dataset handling
|
||||
|
||||
## Key Conventions
|
||||
|
||||
- **Python style**: PEP 8, 4 spaces, double quotes for strings
|
||||
- **Naming**: `snake_case` functions/variables, `PascalCase` classes, `UPPER_SNAKE_CASE` constants
|
||||
- **Logging**: Use `loguru` logger (not `print()` except CLI output)
|
||||
- **Dependencies**: Use `uv add <package>` to add to `pyproject.toml`
|
||||
|
||||
## Performance
|
||||
|
||||
- Target: 4GB VRAM - minimize memory allocations
|
||||
- Lazy load models when needed
|
||||
- Batch operations supported (up to 8 songs)
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- **AGENTS.md**: Detailed guidance for AI coding agents
|
||||
- **CONTRIBUTING.md**: Contribution workflow and guidelines
|
||||
99
.github/workflows/codeql.yml
vendored
Normal file
99
.github/workflows/codeql.yml
vendored
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
# For most projects, this workflow file will not need changing; you simply need
|
||||
# to commit it to your repository.
|
||||
#
|
||||
# You may wish to alter this file to override the set of languages analyzed,
|
||||
# or to provide custom queries or build logic.
|
||||
#
|
||||
# ******** NOTE ********
|
||||
# We have attempted to detect the languages in your repository. Please check
|
||||
# the `language` matrix defined below to confirm you have the correct set of
|
||||
# supported CodeQL languages.
|
||||
#
|
||||
name: "CodeQL Advanced"
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
schedule:
|
||||
- cron: '26 2 * * 5'
|
||||
|
||||
jobs:
|
||||
analyze:
|
||||
name: Analyze (${{ matrix.language }})
|
||||
# Runner size impacts CodeQL analysis time. To learn more, please see:
|
||||
# - https://gh.io/recommended-hardware-resources-for-running-codeql
|
||||
# - https://gh.io/supported-runners-and-hardware-resources
|
||||
# - https://gh.io/using-larger-runners (GitHub.com only)
|
||||
# Consider using larger runners or machines with greater resources for possible analysis time improvements.
|
||||
runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }}
|
||||
permissions:
|
||||
# required for all workflows
|
||||
security-events: write
|
||||
|
||||
# required to fetch internal or private CodeQL packs
|
||||
packages: read
|
||||
|
||||
# only required for workflows in private repositories
|
||||
actions: read
|
||||
contents: read
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- language: python
|
||||
build-mode: none
|
||||
# CodeQL supports the following values keywords for 'language': 'actions', 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'rust', 'swift'
|
||||
# Use `c-cpp` to analyze code written in C, C++ or both
|
||||
# Use 'java-kotlin' to analyze code written in Java, Kotlin or both
|
||||
# Use 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both
|
||||
# To learn more about changing the languages that are analyzed or customizing the build mode for your analysis,
|
||||
# see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/customizing-your-advanced-setup-for-code-scanning.
|
||||
# If you are analyzing a compiled language, you can modify the 'build-mode' for that language to customize how
|
||||
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
# Add any setup steps before running the `github/codeql-action/init` action.
|
||||
# This includes steps like installing compilers or runtimes (`actions/setup-node`
|
||||
# or others). This is typically only required for manual builds.
|
||||
# - name: Setup runtime (example)
|
||||
# uses: actions/setup-example@v1
|
||||
|
||||
# Initializes the CodeQL tools for scanning.
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v4
|
||||
with:
|
||||
languages: ${{ matrix.language }}
|
||||
build-mode: ${{ matrix.build-mode }}
|
||||
# If you wish to specify custom queries, you can do so here or in a config file.
|
||||
# By default, queries listed here will override any specified in a config file.
|
||||
# Prefix the list here with "+" to use these queries and those in the config file.
|
||||
|
||||
# For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
|
||||
# queries: security-extended,security-and-quality
|
||||
|
||||
# If the analyze step fails for one of the languages you are analyzing with
|
||||
# "We were unable to automatically build your code", modify the matrix above
|
||||
# to set the build mode to "manual" for that language. Then modify this step
|
||||
# to build your code.
|
||||
# ℹ️ Command-line programs to run using the OS shell.
|
||||
# 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
|
||||
- name: Run manual build steps
|
||||
if: matrix.build-mode == 'manual'
|
||||
shell: bash
|
||||
run: |
|
||||
echo 'If you are using a "manual" build mode for one or more of the' \
|
||||
'languages you are analyzing, replace this with the commands to build' \
|
||||
'your code, for example:'
|
||||
echo ' make bootstrap'
|
||||
echo ' make release'
|
||||
exit 1
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v4
|
||||
with:
|
||||
category: "/language:${{matrix.language}}"
|
||||
96
AGENTS.md
Normal file
96
AGENTS.md
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
# AGENTS.md
|
||||
|
||||
Guidance for AI coding agents working in `ace-step/ACE-Step-1.5`.
|
||||
|
||||
This document is aligned with the intent from:
|
||||
- Discussion #408: functional decomposition to reduce risk from large mixed-responsibility files.
|
||||
- Discussion #365: low-risk contribution workflow, minimal scope, and review rigor.
|
||||
|
||||
## Primary Objectives
|
||||
|
||||
1. Keep changes safe and reviewable.
|
||||
2. Prefer small, maintainable, decomposed modules.
|
||||
3. Preserve behavior outside the target fix.
|
||||
4. Validate with focused Python unit tests.
|
||||
|
||||
## Scope and Change Control (Required)
|
||||
|
||||
- Solve one problem per task/PR.
|
||||
- Keep edits minimal: touch only files/functions required for the requested change.
|
||||
- Do not make drive-by refactors, formatting sweeps, or opportunistic cleanups.
|
||||
- Do not alter non-target hardware/runtime paths (CPU/CUDA/MPS/XPU) unless required by the task.
|
||||
- If any cross-path change is necessary, isolate it and justify it in the PR notes.
|
||||
- Preserve existing public interfaces unless the task explicitly requires an interface change.
|
||||
|
||||
## Decomposition and Module Size Policy
|
||||
|
||||
- Prefer single-responsibility modules with clear boundaries.
|
||||
- Target module size:
|
||||
- Optimal: `<= 150` LOC @coderabbit: Exceeding this limit is NIGGLE, not hard cap or major issue.
|
||||
- Hard cap: `200` LOC
|
||||
- Function decomposition rules:
|
||||
- Do one thing at a time; if a function description naturally contains "and", split it.
|
||||
- Split by responsibility, not by convenience.
|
||||
- Keep data flow explicit (`data in, data out`); side effects must be obvious and deliberate.
|
||||
- Push decisions up and push work down (orchestration at higher layers, execution details in lower layers).
|
||||
- The call graph should read clearly from top-level orchestration to leaf operations.
|
||||
- If a module would exceed `200` LOC:
|
||||
- Split by responsibility before merging, or
|
||||
- Add a short justification in PR notes and include a concrete follow-up split plan.
|
||||
- Keep orchestrator/facade modules thin. Move logic into focused helpers/services.
|
||||
- Preserve stable facade imports when splitting large files so external callers are not broken.
|
||||
|
||||
## Python Unit Testing Expectations
|
||||
|
||||
- Add or update tests for every behavior change and bug fix.
|
||||
- Match repository conventions:
|
||||
- Use `unittest`-style tests.
|
||||
- Name test files as `*_test.py` or `test_*.py`.
|
||||
- Keep tests deterministic, fast, and scoped to changed behavior.
|
||||
- Use mocks/fakes for GPU, filesystem, network, and external services where possible.
|
||||
- If a change requires mocking a large portion of the system to test one unit, treat that as a decomposition smell and refactor boundaries.
|
||||
- Include at least:
|
||||
- One success-path test.
|
||||
- One regression/edge-case test for the bug being fixed.
|
||||
- One non-target behavior check when relevant.
|
||||
- Run targeted tests locally before submitting.
|
||||
|
||||
## Feature Gating and WIP Safety
|
||||
|
||||
- Do not expose unfinished or non-functional user-facing flows by default.
|
||||
- Gate WIP or unstable UI/API paths behind explicit feature/release flags.
|
||||
- Keep default behavior stable; "coming soon" paths must not appear as usable functionality unless they are operational and tested.
|
||||
|
||||
## Python Coding Best Practices
|
||||
|
||||
- Use explicit, readable code over clever shortcuts.
|
||||
- Docstrings are mandatory for all new or modified Python modules, classes, and functions.
|
||||
- Docstrings must be concise and include purpose plus key inputs/outputs (and raised exceptions when relevant).
|
||||
- Add type hints for new/modified functions when practical.
|
||||
- Keep functions focused and short; extract helpers instead of nesting complexity.
|
||||
- Use clear names that describe behavior, not implementation trivia.
|
||||
- Prefer pure functions for logic-heavy paths where possible.
|
||||
- Avoid duplicated logic, but do not introduce broad abstractions too early; prefer simple local duplication over unstable premature abstraction.
|
||||
- Handle errors explicitly; avoid bare `except`.
|
||||
- Keep logging actionable; avoid noisy logs and `print` debugging in committed code.
|
||||
- Avoid hidden state and unintended side effects.
|
||||
- Write comments only where intent is non-obvious; keep comments concise and technical.
|
||||
|
||||
## AI-Agent Workflow (Recommended)
|
||||
|
||||
1. Understand the task and define explicit in-scope/out-of-scope boundaries.
|
||||
2. Propose a minimal patch plan before editing.
|
||||
3. Implement the smallest viable change.
|
||||
4. Add/update focused tests.
|
||||
5. Self-review only changed hunks for regressions and scope creep.
|
||||
6. Summarize risk, validation, and non-target impact in PR notes.
|
||||
|
||||
## PR Readiness Checklist
|
||||
|
||||
- [ ] Change is tightly scoped to one problem.
|
||||
- [ ] Non-target paths are unchanged, or changes are explicitly justified.
|
||||
- [ ] New/updated tests cover changed behavior and edge cases.
|
||||
- [ ] No unrelated refactor/formatting churn.
|
||||
- [ ] Required docstrings are present for all new/modified modules, classes, and functions.
|
||||
- [ ] WIP/unstable functionality is feature-flagged and not exposed as default-ready behavior.
|
||||
- [ ] Module LOC policy is met (`<=150` target, `<=200` hard cap or justified exception).
|
||||
|
|
@ -36,7 +36,7 @@ try:
|
|||
from .llm_inference import LLMHandler
|
||||
from .dataset_handler import DatasetHandler
|
||||
from .gradio_ui import create_gradio_interface
|
||||
from .gpu_config import get_gpu_config, get_gpu_memory_gb, print_gpu_config_info, set_global_gpu_config, VRAM_16GB_MIN_GB, VRAM_AUTO_OFFLOAD_THRESHOLD_GB
|
||||
from .gpu_config import get_gpu_config, get_gpu_memory_gb, print_gpu_config_info, set_global_gpu_config, VRAM_16GB_MIN_GB, VRAM_AUTO_OFFLOAD_THRESHOLD_GB, is_mps_platform
|
||||
from .model_downloader import ensure_lm_model
|
||||
except ImportError:
|
||||
# When executed as a script: `python acestep/acestep_v15_pipeline.py`
|
||||
|
|
@ -47,7 +47,7 @@ except ImportError:
|
|||
from acestep.llm_inference import LLMHandler
|
||||
from acestep.dataset_handler import DatasetHandler
|
||||
from acestep.gradio_ui import create_gradio_interface
|
||||
from acestep.gpu_config import get_gpu_config, get_gpu_memory_gb, print_gpu_config_info, set_global_gpu_config, VRAM_16GB_MIN_GB, VRAM_AUTO_OFFLOAD_THRESHOLD_GB
|
||||
from acestep.gpu_config import get_gpu_config, get_gpu_memory_gb, print_gpu_config_info, set_global_gpu_config, VRAM_16GB_MIN_GB, VRAM_AUTO_OFFLOAD_THRESHOLD_GB, is_mps_platform
|
||||
from acestep.model_downloader import ensure_lm_model
|
||||
|
||||
|
||||
|
|
@ -93,11 +93,14 @@ def main():
|
|||
set_global_gpu_config(gpu_config) # Set global config for use across modules
|
||||
|
||||
gpu_memory_gb = gpu_config.gpu_memory_gb
|
||||
_is_mac = is_mps_platform()
|
||||
# Enable auto-offload for GPUs below 20 GB. 16 GB GPUs cannot hold all
|
||||
# models simultaneously (DiT ~4.7 + VAE ~0.3 + text_enc ~1.2 + LM ≥1.2 +
|
||||
# activations) so they *must* offload. The old threshold of 16 GB caused
|
||||
# 16 GB GPUs to never offload, leading to OOM.
|
||||
auto_offload = gpu_memory_gb > 0 and gpu_memory_gb < VRAM_AUTO_OFFLOAD_THRESHOLD_GB
|
||||
# Mac (Apple Silicon) uses unified memory — offloading provides no benefit.
|
||||
auto_offload = (not _is_mac) and gpu_memory_gb > 0 and gpu_memory_gb < VRAM_AUTO_OFFLOAD_THRESHOLD_GB
|
||||
_default_backend = "mlx" if _is_mac else "vllm"
|
||||
|
||||
# Print GPU configuration info
|
||||
print(f"\n{'='*60}")
|
||||
|
|
@ -113,7 +116,9 @@ def main():
|
|||
print(f" Available LM Models: {gpu_config.available_lm_models or 'None'}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
if auto_offload:
|
||||
if _is_mac:
|
||||
print(f"Apple Silicon (MPS) detected — unified memory {gpu_memory_gb:.1f}GB, no CPU offload needed, backend={_default_backend}")
|
||||
elif auto_offload:
|
||||
print(f"Auto-enabling CPU offload (GPU {gpu_memory_gb:.1f}GB < {VRAM_AUTO_OFFLOAD_THRESHOLD_GB}GB threshold)")
|
||||
elif gpu_memory_gb > 0:
|
||||
print(f"CPU offload disabled by default (GPU {gpu_memory_gb:.1f}GB >= {VRAM_AUTO_OFFLOAD_THRESHOLD_GB}GB threshold)")
|
||||
|
|
@ -152,7 +157,7 @@ def main():
|
|||
parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "mps", "xpu", "cpu"], help="Processing device (default: auto)")
|
||||
parser.add_argument("--init_llm", type=lambda x: x.lower() in ['true', '1', 'yes'], default=None, help="Initialize 5Hz LM (default: auto based on GPU memory)")
|
||||
parser.add_argument("--lm_model_path", type=str, default=None, help="5Hz LM model path (e.g., 'acestep-5Hz-lm-0.6B')")
|
||||
parser.add_argument("--backend", type=str, default="vllm", choices=["vllm", "pt", "mlx"], help="5Hz LM backend (default: vllm, use 'mlx' for native Apple Silicon acceleration)")
|
||||
parser.add_argument("--backend", type=str, default=_default_backend, choices=["vllm", "pt", "mlx"], help=f"5Hz LM backend (default: {_default_backend}, use 'mlx' for native Apple Silicon acceleration)")
|
||||
parser.add_argument("--use_flash_attention", type=lambda x: x.lower() in ['true', '1', 'yes'], default=None, help="Use flash attention (default: auto-detect)")
|
||||
parser.add_argument("--offload_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=auto_offload, help=f"Offload models to CPU (default: {'True' if auto_offload else 'False'}, auto-detected based on GPU VRAM)")
|
||||
parser.add_argument("--offload_dit_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Offload DiT to CPU (default: False)")
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ except ImportError: # Optional dependency
|
|||
load_dotenv = None # type: ignore
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request, Depends, Header
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Field
|
||||
from starlette.datastructures import UploadFile as StarletteUploadFile
|
||||
|
||||
|
|
@ -374,6 +375,8 @@ PARAM_ALIASES = {
|
|||
"use_cot_language": ["use_cot_language", "cot_language", "cot-language"],
|
||||
"is_format_caption": ["is_format_caption", "isFormatCaption"],
|
||||
"allow_lm_batch": ["allow_lm_batch", "allowLmBatch", "parallel_thinking"],
|
||||
"track_name": ["track_name", "trackName"],
|
||||
"track_classes": ["track_classes", "trackClasses", "instruments"],
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -519,6 +522,8 @@ class GenerateMusicRequest(BaseModel):
|
|||
use_cot_language: bool = True
|
||||
is_format_caption: bool = False
|
||||
allow_lm_batch: bool = True
|
||||
track_name: Optional[str] = None
|
||||
track_classes: Optional[List[str]] = None
|
||||
|
||||
lm_temperature: float = 0.85
|
||||
lm_cfg_scale: float = 2.5
|
||||
|
|
@ -893,13 +898,33 @@ class RequestParser:
|
|||
def _validate_audio_path(path: Optional[str]) -> Optional[str]:
|
||||
"""Validate a user-supplied audio file path to prevent path traversal attacks.
|
||||
|
||||
Rejects absolute paths and paths containing '..' traversal sequences.
|
||||
Returns the validated path or None if the input is None/empty.
|
||||
Accepts absolute paths strictly only if they are within the system temporary directory.
|
||||
Otherwise, rejects absolute paths and paths containing '..' traversal sequences.
|
||||
|
||||
Returns the validated, normalized path or None if the input is None/empty.
|
||||
Raises HTTPException 400 if the path is unsafe.
|
||||
"""
|
||||
if not path:
|
||||
return None
|
||||
# Reject absolute paths (Unix and Windows)
|
||||
|
||||
# Resolve requested path and system temp path to normalized absolute forms
|
||||
import tempfile
|
||||
system_temp = os.path.realpath(tempfile.gettempdir())
|
||||
requested_path = os.path.realpath(path)
|
||||
|
||||
# SECURE CHECK: Use os.path.commonpath to verify directory boundary integrity.
|
||||
# This prevents prefix bypasses (e.g., /tmp_evil when /tmp is allowed).
|
||||
try:
|
||||
is_in_temp = os.path.commonpath([system_temp, requested_path]) == system_temp
|
||||
except ValueError:
|
||||
# Occurs on Windows if paths are on different drives
|
||||
is_in_temp = False
|
||||
|
||||
if is_in_temp:
|
||||
# Accept server-generated files in temp
|
||||
return requested_path
|
||||
|
||||
# Reject manual absolute paths outside of temp
|
||||
if os.path.isabs(path):
|
||||
raise HTTPException(status_code=400, detail="absolute audio file paths are not allowed")
|
||||
# Reject path traversal via '..' components
|
||||
|
|
@ -1492,7 +1517,25 @@ def create_app() -> FastAPI:
|
|||
# This matches gradio behavior which uses TASK_INSTRUCTIONS for each task type
|
||||
instruction_to_use = req.instruction
|
||||
if instruction_to_use == DEFAULT_DIT_INSTRUCTION and req.task_type in TASK_INSTRUCTIONS:
|
||||
instruction_to_use = TASK_INSTRUCTIONS[req.task_type]
|
||||
raw_instruction = TASK_INSTRUCTIONS[req.task_type]
|
||||
|
||||
if req.task_type == "complete":
|
||||
# Use track_classes joined by pipes
|
||||
if req.track_classes:
|
||||
# Join list items: ["Drums", "Bass"] -> "DRUMS | BASS"
|
||||
classes_str = " | ".join([str(t).upper() for t in req.track_classes])
|
||||
# Use the raw instruction template from constants
|
||||
# Format: "Complete the track with {TRACK_CLASSES}:"
|
||||
instruction_to_use = raw_instruction.format(TRACK_CLASSES=classes_str)
|
||||
else:
|
||||
# Fallback if no classes provided
|
||||
instruction_to_use = TASK_INSTRUCTIONS.get("complete_default", raw_instruction)
|
||||
|
||||
elif "{TRACK_NAME}" in raw_instruction and req.track_name:
|
||||
# Logic for extract/lego
|
||||
instruction_to_use = raw_instruction.format(TRACK_NAME=req.track_name.upper())
|
||||
else:
|
||||
instruction_to_use = raw_instruction
|
||||
|
||||
# Build GenerationParams using unified interface
|
||||
# Note: thinking controls LM code generation, sample_mode only affects CoT metas
|
||||
|
|
@ -1542,11 +1585,30 @@ def create_app() -> FastAPI:
|
|||
|
||||
# Build GenerationConfig - default to 2 audios like gradio_ui
|
||||
batch_size = req.batch_size if req.batch_size is not None else 2
|
||||
|
||||
# Resolve seed(s) from req.seed into List[int] for GenerationConfig.seeds
|
||||
resolved_seeds = None
|
||||
if not req.use_random_seed and req.seed is not None:
|
||||
if isinstance(req.seed, int):
|
||||
if req.seed >= 0:
|
||||
resolved_seeds = [req.seed]
|
||||
elif isinstance(req.seed, str):
|
||||
resolved_seeds = []
|
||||
for s in req.seed.split(","):
|
||||
s = s.strip()
|
||||
if s and s != "-1":
|
||||
try:
|
||||
resolved_seeds.append(int(float(s)))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
if not resolved_seeds:
|
||||
resolved_seeds = None
|
||||
|
||||
config = GenerationConfig(
|
||||
batch_size=batch_size,
|
||||
allow_lm_batch=req.allow_lm_batch,
|
||||
use_random_seed=req.use_random_seed,
|
||||
seeds=None, # Let unified logic handle seed generation
|
||||
seeds=resolved_seeds,
|
||||
audio_format=req.audio_format,
|
||||
constrained_decoding_debug=req.constrained_decoding_debug,
|
||||
)
|
||||
|
|
@ -2155,6 +2217,16 @@ def create_app() -> FastAPI:
|
|||
|
||||
app = FastAPI(title="ACE-Step API", version="1.0", lifespan=lifespan)
|
||||
|
||||
# Enable CORS for browser-based frontends (e.g. studio.html opened via file://)
|
||||
# Restricted to localhost origins and the "null" origin (file:// protocol)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["null", "http://localhost", "http://127.0.0.1"],
|
||||
allow_origin_regex=r"^https?://(localhost|127\.0\.0\.1)(:\d+)?$",
|
||||
allow_methods=["GET", "POST", "OPTIONS"],
|
||||
allow_headers=["Content-Type", "Authorization"],
|
||||
)
|
||||
|
||||
# Mount OpenRouter-compatible endpoints (/v1/chat/completions, /v1/models)
|
||||
from acestep.openrouter_adapter import create_openrouter_router
|
||||
openrouter_router = create_openrouter_router(lambda: app.state)
|
||||
|
|
@ -2185,6 +2257,11 @@ def create_app() -> FastAPI:
|
|||
# when callers (multipart/form, url-encoded, raw body) pass them explicitly.
|
||||
ref_audio = kwargs.pop("reference_audio_path", None) or p.str("reference_audio_path") or None
|
||||
src_audio = kwargs.pop("src_audio_path", None) or p.str("src_audio_path") or None
|
||||
|
||||
t_classes = p.get("track_classes")
|
||||
if t_classes is not None and isinstance(t_classes, str):
|
||||
t_classes = [t_classes]
|
||||
|
||||
return GenerateMusicRequest(
|
||||
prompt=p.str("prompt"),
|
||||
lyrics=p.str("lyrics"),
|
||||
|
|
@ -2209,8 +2286,8 @@ def create_app() -> FastAPI:
|
|||
repainting_end=p.float("repainting_end"),
|
||||
instruction=p.str("instruction", DEFAULT_DIT_INSTRUCTION),
|
||||
audio_cover_strength=p.float("audio_cover_strength", 1.0),
|
||||
reference_audio_path=_validate_audio_path(ref_audio),
|
||||
src_audio_path=_validate_audio_path(src_audio),
|
||||
reference_audio_path=ref_audio,
|
||||
src_audio_path=src_audio,
|
||||
task_type=p.str("task_type", "text2music"),
|
||||
use_adg=p.bool("use_adg"),
|
||||
cfg_interval_start=p.float("cfg_interval_start", 0.0),
|
||||
|
|
@ -2233,6 +2310,8 @@ def create_app() -> FastAPI:
|
|||
use_cot_language=p.bool("use_cot_language", True),
|
||||
is_format_caption=p.bool("is_format_caption"),
|
||||
allow_lm_batch=p.bool("allow_lm_batch", True),
|
||||
track_name=p.str("track_name"),
|
||||
track_classes=t_classes,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
|
@ -2241,18 +2320,40 @@ def create_app() -> FastAPI:
|
|||
if not isinstance(body, dict):
|
||||
raise HTTPException(status_code=400, detail="JSON payload must be an object")
|
||||
verify_token_from_request(body, authorization)
|
||||
req = _build_request(RequestParser(body))
|
||||
|
||||
# Explicitly validate manual string paths from JSON input
|
||||
p = RequestParser(body)
|
||||
req = _build_request(
|
||||
p,
|
||||
reference_audio_path=_validate_audio_path(p.str("reference_audio_path") or None),
|
||||
src_audio_path=_validate_audio_path(p.str("src_audio_path") or None)
|
||||
)
|
||||
|
||||
elif content_type.endswith("+json"):
|
||||
body = await request.json()
|
||||
if not isinstance(body, dict):
|
||||
raise HTTPException(status_code=400, detail="JSON payload must be an object")
|
||||
verify_token_from_request(body, authorization)
|
||||
req = _build_request(RequestParser(body))
|
||||
|
||||
p = RequestParser(body)
|
||||
req = _build_request(
|
||||
p,
|
||||
reference_audio_path=_validate_audio_path(p.str("reference_audio_path") or None),
|
||||
src_audio_path=_validate_audio_path(p.str("src_audio_path") or None)
|
||||
)
|
||||
|
||||
elif content_type.startswith("multipart/form-data"):
|
||||
form = await request.form()
|
||||
form_dict = {k: v for k, v in form.items() if not hasattr(v, 'read')}
|
||||
|
||||
# Parse form data correctly to support lists ---
|
||||
form_dict = {}
|
||||
for k in form.keys():
|
||||
vals = [v for v in form.getlist(k) if not hasattr(v, 'read')]
|
||||
if len(vals) == 1:
|
||||
form_dict[k] = vals[0]
|
||||
elif len(vals) > 1:
|
||||
form_dict[k] = vals
|
||||
|
||||
verify_token_from_request(form_dict, authorization)
|
||||
|
||||
# Support both naming conventions: ref_audio/reference_audio, ctx_audio/src_audio
|
||||
|
|
@ -2275,7 +2376,7 @@ def create_app() -> FastAPI:
|
|||
src_audio_path = _validate_audio_path(str(form.get("ctx_audio_path") or form.get("src_audio_path") or "").strip() or None)
|
||||
|
||||
req = _build_request(
|
||||
RequestParser(dict(form)),
|
||||
RequestParser(dict(form_dict)),
|
||||
reference_audio_path=reference_audio_path,
|
||||
src_audio_path=src_audio_path,
|
||||
)
|
||||
|
|
@ -2301,7 +2402,12 @@ def create_app() -> FastAPI:
|
|||
body = json.loads(raw.decode("utf-8"))
|
||||
if isinstance(body, dict):
|
||||
verify_token_from_request(body, authorization)
|
||||
req = _build_request(RequestParser(body))
|
||||
p = RequestParser(body)
|
||||
req = _build_request(
|
||||
p,
|
||||
reference_audio_path=_validate_audio_path(p.str("reference_audio_path") or None),
|
||||
src_audio_path=_validate_audio_path(p.str("src_audio_path") or None)
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="JSON payload must be an object")
|
||||
except HTTPException:
|
||||
|
|
|
|||
|
|
@ -1,354 +1,406 @@
|
|||
"""
|
||||
Audio saving and transcoding utility module
|
||||
|
||||
Independent audio file operations outside of handler, supporting:
|
||||
- Save audio tensor/numpy to files (default FLAC format, fast)
|
||||
- Format conversion (FLAC/WAV/MP3)
|
||||
- Batch processing
|
||||
"""
|
||||
|
||||
import os
|
||||
import hashlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Union, Optional, List, Tuple
|
||||
import torch
|
||||
import numpy as np
|
||||
import torchaudio
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class AudioSaver:
|
||||
"""Audio saving and transcoding utility class"""
|
||||
|
||||
def __init__(self, default_format: str = "flac"):
|
||||
"""
|
||||
Initialize audio saver
|
||||
|
||||
Args:
|
||||
default_format: Default save format ('flac', 'wav', 'mp3')
|
||||
"""
|
||||
self.default_format = default_format.lower()
|
||||
if self.default_format not in ["flac", "wav", "mp3"]:
|
||||
logger.warning(f"Unsupported format {default_format}, using 'flac'")
|
||||
self.default_format = "flac"
|
||||
|
||||
def save_audio(
|
||||
self,
|
||||
audio_data: Union[torch.Tensor, np.ndarray],
|
||||
output_path: Union[str, Path],
|
||||
sample_rate: int = 48000,
|
||||
format: Optional[str] = None,
|
||||
channels_first: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
Save audio data to file
|
||||
|
||||
Args:
|
||||
audio_data: Audio data, torch.Tensor [channels, samples] or numpy.ndarray
|
||||
output_path: Output file path (extension can be omitted)
|
||||
sample_rate: Sample rate
|
||||
format: Audio format ('flac', 'wav', 'mp3'), defaults to default_format
|
||||
channels_first: If True, tensor format is [channels, samples], else [samples, channels]
|
||||
|
||||
Returns:
|
||||
Actual saved file path
|
||||
"""
|
||||
format = (format or self.default_format).lower()
|
||||
if format not in ["flac", "wav", "mp3"]:
|
||||
logger.warning(f"Unsupported format {format}, using {self.default_format}")
|
||||
format = self.default_format
|
||||
|
||||
# Ensure output path has correct extension
|
||||
output_path = Path(output_path)
|
||||
if output_path.suffix.lower() not in ['.flac', '.wav', '.mp3']:
|
||||
output_path = output_path.with_suffix(f'.{format}')
|
||||
|
||||
# Convert to torch tensor
|
||||
if isinstance(audio_data, np.ndarray):
|
||||
if channels_first:
|
||||
# numpy [samples, channels] -> tensor [channels, samples]
|
||||
audio_tensor = torch.from_numpy(audio_data.T).float()
|
||||
else:
|
||||
# numpy [samples, channels] -> tensor [samples, channels] -> [channels, samples]
|
||||
audio_tensor = torch.from_numpy(audio_data).float()
|
||||
if audio_tensor.dim() == 2 and audio_tensor.shape[0] < audio_tensor.shape[1]:
|
||||
audio_tensor = audio_tensor.T
|
||||
else:
|
||||
# torch tensor
|
||||
audio_tensor = audio_data.cpu().float()
|
||||
if not channels_first and audio_tensor.dim() == 2:
|
||||
# [samples, channels] -> [channels, samples]
|
||||
if audio_tensor.shape[0] > audio_tensor.shape[1]:
|
||||
audio_tensor = audio_tensor.T
|
||||
|
||||
# Ensure memory is contiguous
|
||||
audio_tensor = audio_tensor.contiguous()
|
||||
|
||||
# Select backend and save
|
||||
try:
|
||||
if format == "mp3":
|
||||
# MP3 uses ffmpeg backend
|
||||
torchaudio.save(
|
||||
str(output_path),
|
||||
audio_tensor,
|
||||
sample_rate,
|
||||
channels_first=True,
|
||||
backend='ffmpeg',
|
||||
)
|
||||
elif format in ["flac", "wav"]:
|
||||
# FLAC and WAV use soundfile backend (fastest)
|
||||
torchaudio.save(
|
||||
str(output_path),
|
||||
audio_tensor,
|
||||
sample_rate,
|
||||
channels_first=True,
|
||||
backend='soundfile',
|
||||
)
|
||||
else:
|
||||
# Other formats use default backend
|
||||
torchaudio.save(
|
||||
str(output_path),
|
||||
audio_tensor,
|
||||
sample_rate,
|
||||
channels_first=True,
|
||||
)
|
||||
|
||||
logger.debug(f"[AudioSaver] Saved audio to {output_path} ({format}, {sample_rate}Hz)")
|
||||
return str(output_path)
|
||||
|
||||
except Exception as e:
|
||||
try:
|
||||
import soundfile as sf
|
||||
audio_np = audio_tensor.transpose(0, 1).numpy() # -> [samples, channels]
|
||||
sf.write(str(output_path), audio_np, sample_rate, format=format.upper())
|
||||
logger.debug(f"[AudioSaver] Fallback soundfile Saved audio to {output_path} ({format}, {sample_rate}Hz)")
|
||||
return str(output_path)
|
||||
except Exception as e:
|
||||
logger.error(f"[AudioSaver] Failed to save audio: {e}")
|
||||
raise
|
||||
|
||||
def convert_audio(
|
||||
self,
|
||||
input_path: Union[str, Path],
|
||||
output_path: Union[str, Path],
|
||||
output_format: str,
|
||||
remove_input: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Convert audio format
|
||||
|
||||
Args:
|
||||
input_path: Input audio file path
|
||||
output_path: Output audio file path
|
||||
output_format: Target format ('flac', 'wav', 'mp3')
|
||||
remove_input: Whether to delete input file
|
||||
|
||||
Returns:
|
||||
Output file path
|
||||
"""
|
||||
input_path = Path(input_path)
|
||||
output_path = Path(output_path)
|
||||
|
||||
if not input_path.exists():
|
||||
raise FileNotFoundError(f"Input file not found: {input_path}")
|
||||
|
||||
# Load audio
|
||||
audio_tensor, sample_rate = torchaudio.load(str(input_path))
|
||||
|
||||
# Save as new format
|
||||
output_path = self.save_audio(
|
||||
audio_tensor,
|
||||
output_path,
|
||||
sample_rate=sample_rate,
|
||||
format=output_format,
|
||||
channels_first=True
|
||||
)
|
||||
|
||||
# Delete input file if needed
|
||||
if remove_input:
|
||||
input_path.unlink()
|
||||
logger.debug(f"[AudioSaver] Removed input file: {input_path}")
|
||||
|
||||
return output_path
|
||||
|
||||
def save_batch(
|
||||
self,
|
||||
audio_batch: Union[List[torch.Tensor], torch.Tensor],
|
||||
output_dir: Union[str, Path],
|
||||
file_prefix: str = "audio",
|
||||
sample_rate: int = 48000,
|
||||
format: Optional[str] = None,
|
||||
channels_first: bool = True,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Save audio batch
|
||||
|
||||
Args:
|
||||
audio_batch: Audio batch, List[tensor] or tensor [batch, channels, samples]
|
||||
output_dir: Output directory
|
||||
file_prefix: File prefix
|
||||
sample_rate: Sample rate
|
||||
format: Audio format
|
||||
channels_first: Tensor format flag
|
||||
|
||||
Returns:
|
||||
List of saved file paths
|
||||
"""
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Process batch
|
||||
if isinstance(audio_batch, torch.Tensor) and audio_batch.dim() == 3:
|
||||
# [batch, channels, samples]
|
||||
audio_list = [audio_batch[i] for i in range(audio_batch.shape[0])]
|
||||
elif isinstance(audio_batch, list):
|
||||
audio_list = audio_batch
|
||||
else:
|
||||
audio_list = [audio_batch]
|
||||
|
||||
saved_paths = []
|
||||
for i, audio in enumerate(audio_list):
|
||||
output_path = output_dir / f"{file_prefix}_{i:04d}"
|
||||
saved_path = self.save_audio(
|
||||
audio,
|
||||
output_path,
|
||||
sample_rate=sample_rate,
|
||||
format=format,
|
||||
channels_first=channels_first
|
||||
)
|
||||
saved_paths.append(saved_path)
|
||||
|
||||
return saved_paths
|
||||
|
||||
|
||||
def get_audio_file_hash(audio_file) -> str:
|
||||
"""
|
||||
Get hash identifier for an audio file.
|
||||
|
||||
Args:
|
||||
audio_file: Path to audio file (str) or file-like object
|
||||
|
||||
Returns:
|
||||
Hash string or empty string
|
||||
"""
|
||||
if audio_file is None:
|
||||
return ""
|
||||
|
||||
try:
|
||||
if isinstance(audio_file, str):
|
||||
if os.path.exists(audio_file):
|
||||
with open(audio_file, 'rb') as f:
|
||||
return hashlib.md5(f.read()).hexdigest()
|
||||
return hashlib.md5(audio_file.encode('utf-8')).hexdigest()
|
||||
elif hasattr(audio_file, 'name'):
|
||||
return hashlib.md5(str(audio_file.name).encode('utf-8')).hexdigest()
|
||||
return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
|
||||
except Exception:
|
||||
return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
|
||||
|
||||
|
||||
def generate_uuid_from_params(params_dict) -> str:
|
||||
"""
|
||||
Generate deterministic UUID from generation parameters.
|
||||
Same parameters will always generate the same UUID.
|
||||
|
||||
Args:
|
||||
params_dict: Dictionary of parameters
|
||||
|
||||
Returns:
|
||||
UUID string
|
||||
"""
|
||||
|
||||
params_json = json.dumps(params_dict, sort_keys=True, ensure_ascii=False)
|
||||
hash_obj = hashlib.sha256(params_json.encode('utf-8'))
|
||||
hash_hex = hash_obj.hexdigest()
|
||||
uuid_str = f"{hash_hex[0:8]}-{hash_hex[8:12]}-{hash_hex[12:16]}-{hash_hex[16:20]}-{hash_hex[20:32]}"
|
||||
return uuid_str
|
||||
|
||||
|
||||
def generate_uuid_from_audio_data(
|
||||
audio_data: Union[torch.Tensor, np.ndarray],
|
||||
seed: Optional[int] = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate UUID from audio data (for caching/deduplication)
|
||||
|
||||
Args:
|
||||
audio_data: Audio data
|
||||
seed: Optional seed value
|
||||
|
||||
Returns:
|
||||
UUID string
|
||||
"""
|
||||
if isinstance(audio_data, torch.Tensor):
|
||||
# Convert to numpy and calculate hash
|
||||
audio_np = audio_data.cpu().numpy()
|
||||
else:
|
||||
audio_np = audio_data
|
||||
|
||||
# Calculate data hash
|
||||
data_hash = hashlib.md5(audio_np.tobytes()).hexdigest()
|
||||
|
||||
if seed is not None:
|
||||
combined = f"{data_hash}_{seed}"
|
||||
return hashlib.md5(combined.encode()).hexdigest()
|
||||
|
||||
return data_hash
|
||||
|
||||
|
||||
# Global default instance
|
||||
_default_saver = AudioSaver(default_format="flac")
|
||||
|
||||
SILENT_RMS_THRESHOLD = 1e-5
|
||||
SILENT_PEAK_THRESHOLD = 1e-5
|
||||
|
||||
|
||||
def is_audio_silent(
|
||||
audio_data: Union[torch.Tensor, np.ndarray],
|
||||
rms_threshold: float = SILENT_RMS_THRESHOLD,
|
||||
peak_threshold: float = SILENT_PEAK_THRESHOLD,
|
||||
channels_first: bool = True,
|
||||
) -> Tuple[bool, float, float]:
|
||||
"""
|
||||
Check if audio is silent or near-silent (e.g. zeroed conditioning output).
|
||||
Returns (is_silent, rms, peak) where rms/peak are computed over the full signal.
|
||||
"""
|
||||
if audio_data is None:
|
||||
return True, 0.0, 0.0
|
||||
if isinstance(audio_data, np.ndarray):
|
||||
x = np.asarray(audio_data, dtype=np.float64).ravel()
|
||||
else:
|
||||
x = audio_data.cpu().float().numpy().ravel()
|
||||
if x.size == 0:
|
||||
return True, 0.0, 0.0
|
||||
rms = float(np.sqrt(np.mean(x * x)))
|
||||
peak = float(np.max(np.abs(x)))
|
||||
is_silent = rms <= rms_threshold and peak <= peak_threshold
|
||||
return is_silent, rms, peak
|
||||
|
||||
|
||||
def save_audio(
|
||||
audio_data: Union[torch.Tensor, np.ndarray],
|
||||
output_path: Union[str, Path],
|
||||
sample_rate: int = 48000,
|
||||
format: Optional[str] = None,
|
||||
channels_first: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
Convenience function: save audio (using default configuration)
|
||||
|
||||
Args:
|
||||
audio_data: Audio data
|
||||
output_path: Output path
|
||||
sample_rate: Sample rate
|
||||
format: Format (default flac)
|
||||
channels_first: Tensor format flag
|
||||
|
||||
Returns:
|
||||
Saved file path
|
||||
"""
|
||||
return _default_saver.save_audio(
|
||||
audio_data, output_path, sample_rate, format, channels_first
|
||||
)
|
||||
|
||||
"""
|
||||
Audio saving and transcoding utility module
|
||||
|
||||
Independent audio file operations outside of handler, supporting:
|
||||
- Save audio tensor/numpy to files (default FLAC format, fast)
|
||||
- Format conversion (FLAC/WAV/MP3)
|
||||
- Batch processing
|
||||
"""
|
||||
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import Union, Optional, List, Tuple
|
||||
import torch
|
||||
import numpy as np
|
||||
import torchaudio
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def normalize_audio(audio_data: Union[torch.Tensor, np.ndarray], target_db: float = -1.0) -> Union[torch.Tensor, np.ndarray]:
|
||||
"""
|
||||
Apply peak normalization to audio data.
|
||||
|
||||
Args:
|
||||
audio_data: Audio data as torch.Tensor or numpy.ndarray
|
||||
target_db: Target peak level in dB (default: -1.0)
|
||||
|
||||
Returns:
|
||||
Normalized audio data in the same format as input
|
||||
"""
|
||||
# Create a copy to avoid modifying original in-place
|
||||
if isinstance(audio_data, torch.Tensor):
|
||||
audio = audio_data.clone()
|
||||
is_tensor = True
|
||||
else:
|
||||
audio = audio_data.copy()
|
||||
is_tensor = False
|
||||
|
||||
# Calculate current peak
|
||||
if is_tensor:
|
||||
peak = torch.max(torch.abs(audio))
|
||||
else:
|
||||
peak = np.max(np.abs(audio))
|
||||
|
||||
# Handle silence/near-silence to avoid division by zero or extreme gain
|
||||
if peak < 1e-6:
|
||||
return audio_data
|
||||
|
||||
# Convert target dB to linear amplitude
|
||||
target_amp = 10 ** (target_db / 20.0)
|
||||
|
||||
# Calculate needed gain
|
||||
gain = target_amp / peak
|
||||
|
||||
# Apply gain
|
||||
audio = audio * gain
|
||||
|
||||
return audio
|
||||
|
||||
|
||||
|
||||
class AudioSaver:
|
||||
"""Audio saving and transcoding utility class"""
|
||||
|
||||
def __init__(self, default_format: str = "flac"):
|
||||
"""
|
||||
Initialize audio saver
|
||||
|
||||
Args:
|
||||
default_format: Default save format ('flac', 'wav', 'mp3', 'wav32')
|
||||
"""
|
||||
self.default_format = default_format.lower()
|
||||
if self.default_format not in ["flac", "wav", "mp3", "wav32"]:
|
||||
logger.warning(f"Unsupported format {default_format}, using 'flac'")
|
||||
self.default_format = "flac"
|
||||
|
||||
def save_audio(
|
||||
self,
|
||||
audio_data: Union[torch.Tensor, np.ndarray],
|
||||
output_path: Union[str, Path],
|
||||
sample_rate: int = 48000,
|
||||
format: Optional[str] = None,
|
||||
channels_first: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
Save audio data to file
|
||||
|
||||
Args:
|
||||
audio_data: Audio data, torch.Tensor [channels, samples] or numpy.ndarray
|
||||
output_path: Output file path (extension can be omitted)
|
||||
sample_rate: Sample rate
|
||||
format: Audio format ('flac', 'wav', 'mp3', 'wav32'), defaults to default_format
|
||||
channels_first: If True, tensor format is [channels, samples], else [samples, channels]
|
||||
|
||||
Returns:
|
||||
Actual saved file path
|
||||
"""
|
||||
format = (format or self.default_format).lower()
|
||||
if format not in ["flac", "wav", "mp3", "wav32"]:
|
||||
logger.warning(f"Unsupported format {format}, using {self.default_format}")
|
||||
format = self.default_format
|
||||
|
||||
# Ensure output path has correct extension
|
||||
output_path = Path(output_path)
|
||||
|
||||
# Determine extension based on format
|
||||
ext = ".wav" if format == "wav32" else f".{format}"
|
||||
|
||||
if output_path.suffix.lower() not in ['.flac', '.wav', '.mp3']:
|
||||
output_path = output_path.with_suffix(ext)
|
||||
elif format == "wav32" and output_path.suffix.lower() == ".wav32":
|
||||
# Explicitly fix .wav32 extension if present
|
||||
output_path = output_path.with_suffix(".wav")
|
||||
|
||||
# Convert to torch tensor
|
||||
if isinstance(audio_data, np.ndarray):
|
||||
if channels_first:
|
||||
# numpy already [channels, samples]
|
||||
audio_tensor = torch.from_numpy(audio_data).float()
|
||||
else:
|
||||
# numpy [samples, channels] -> tensor [samples, channels] -> [channels, samples] (if transposed)
|
||||
audio_tensor = torch.from_numpy(audio_data).float()
|
||||
if audio_tensor.dim() == 2 and audio_tensor.shape[0] > audio_tensor.shape[1]:
|
||||
# Assume [samples, channels] if dim0 > dim1 (heuristic)
|
||||
audio_tensor = audio_tensor.T
|
||||
else:
|
||||
# torch tensor
|
||||
audio_tensor = audio_data.cpu().float()
|
||||
if not channels_first and audio_tensor.dim() == 2:
|
||||
# [samples, channels] -> [channels, samples]
|
||||
if audio_tensor.shape[0] > audio_tensor.shape[1]:
|
||||
audio_tensor = audio_tensor.T
|
||||
|
||||
# Ensure memory is contiguous
|
||||
audio_tensor = audio_tensor.contiguous()
|
||||
|
||||
# Select backend and save
|
||||
try:
|
||||
if format == "mp3":
|
||||
# MP3 uses ffmpeg backend
|
||||
torchaudio.save(
|
||||
str(output_path),
|
||||
audio_tensor,
|
||||
sample_rate,
|
||||
channels_first=True,
|
||||
backend='ffmpeg',
|
||||
)
|
||||
elif format in ["flac", "wav", "wav32"]:
|
||||
# FLAC and WAV use soundfile backend (fastest)
|
||||
# handle 32-bit float wav
|
||||
if format == "wav32":
|
||||
try:
|
||||
import soundfile as sf
|
||||
|
||||
# Use soundfile directly for 32-bit float
|
||||
audio_np = audio_tensor.transpose(0, 1).numpy() # [channels, samples] -> [samples, channels]
|
||||
|
||||
# Explicitly specify format as WAV to avoid issues with extension detection or custom extensions
|
||||
sf.write(str(output_path), audio_np, sample_rate, subtype='FLOAT', format='WAV')
|
||||
logger.debug(f"[AudioSaver] Saved audio to {output_path} (wav32, {sample_rate}Hz)")
|
||||
return str(output_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save wav32: {e}, falling back to standard wav")
|
||||
format = "wav"
|
||||
# Fallthrough to standard wav saving
|
||||
|
||||
torchaudio.save(
|
||||
str(output_path),
|
||||
audio_tensor,
|
||||
sample_rate,
|
||||
channels_first=True,
|
||||
backend='soundfile',
|
||||
)
|
||||
else:
|
||||
# Other formats use default backend
|
||||
torchaudio.save(
|
||||
str(output_path),
|
||||
audio_tensor,
|
||||
sample_rate,
|
||||
channels_first=True,
|
||||
)
|
||||
|
||||
logger.debug(f"[AudioSaver] Saved audio to {output_path} ({format}, {sample_rate}Hz)")
|
||||
return str(output_path)
|
||||
|
||||
except Exception as e:
|
||||
try:
|
||||
import soundfile as sf
|
||||
audio_np = audio_tensor.transpose(0, 1).numpy() # -> [samples, channels]
|
||||
|
||||
# Handle wav32 fallback formatting
|
||||
if format == "wav32":
|
||||
sf_format = "WAV"
|
||||
subtype = "FLOAT"
|
||||
else:
|
||||
sf_format = format.upper()
|
||||
subtype = None
|
||||
|
||||
sf.write(str(output_path), audio_np, sample_rate, format=sf_format, subtype=subtype)
|
||||
logger.debug(f"[AudioSaver] Fallback soundfile Saved audio to {output_path} ({format}, {sample_rate}Hz)")
|
||||
return str(output_path)
|
||||
except Exception as inner_e:
|
||||
logger.error(f"[AudioSaver] Failed to save audio: {e} -> Fallback failed: {inner_e}")
|
||||
raise
|
||||
|
||||
def convert_audio(
|
||||
self,
|
||||
input_path: Union[str, Path],
|
||||
output_path: Union[str, Path],
|
||||
output_format: str,
|
||||
remove_input: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Convert audio format
|
||||
|
||||
Args:
|
||||
input_path: Input audio file path
|
||||
output_path: Output audio file path
|
||||
output_format: Target format ('flac', 'wav', 'mp3')
|
||||
remove_input: Whether to delete input file
|
||||
|
||||
Returns:
|
||||
Output file path
|
||||
"""
|
||||
input_path = Path(input_path)
|
||||
output_path = Path(output_path)
|
||||
|
||||
if not input_path.exists():
|
||||
raise FileNotFoundError(f"Input file not found: {input_path}")
|
||||
|
||||
# Load audio
|
||||
audio_tensor, sample_rate = torchaudio.load(str(input_path))
|
||||
|
||||
# Save as new format
|
||||
output_path = self.save_audio(
|
||||
audio_tensor,
|
||||
output_path,
|
||||
sample_rate=sample_rate,
|
||||
format=output_format,
|
||||
channels_first=True
|
||||
)
|
||||
|
||||
# Delete input file if needed
|
||||
if remove_input:
|
||||
input_path.unlink()
|
||||
logger.debug(f"[AudioSaver] Removed input file: {input_path}")
|
||||
|
||||
return output_path
|
||||
|
||||
def save_batch(
|
||||
self,
|
||||
audio_batch: Union[List[torch.Tensor], torch.Tensor],
|
||||
output_dir: Union[str, Path],
|
||||
file_prefix: str = "audio",
|
||||
sample_rate: int = 48000,
|
||||
format: Optional[str] = None,
|
||||
channels_first: bool = True,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Save audio batch
|
||||
|
||||
Args:
|
||||
audio_batch: Audio batch, List[tensor] or tensor [batch, channels, samples]
|
||||
output_dir: Output directory
|
||||
file_prefix: File prefix
|
||||
sample_rate: Sample rate
|
||||
format: Audio format
|
||||
channels_first: Tensor format flag
|
||||
|
||||
Returns:
|
||||
List of saved file paths
|
||||
"""
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Process batch
|
||||
if isinstance(audio_batch, torch.Tensor) and audio_batch.dim() == 3:
|
||||
# [batch, channels, samples]
|
||||
audio_list = [audio_batch[i] for i in range(audio_batch.shape[0])]
|
||||
elif isinstance(audio_batch, list):
|
||||
audio_list = audio_batch
|
||||
else:
|
||||
audio_list = [audio_batch]
|
||||
|
||||
saved_paths = []
|
||||
for i, audio in enumerate(audio_list):
|
||||
output_path = output_dir / f"{file_prefix}_{i:04d}"
|
||||
saved_path = self.save_audio(
|
||||
audio,
|
||||
output_path,
|
||||
sample_rate=sample_rate,
|
||||
format=format,
|
||||
channels_first=channels_first
|
||||
)
|
||||
saved_paths.append(saved_path)
|
||||
|
||||
return saved_paths
|
||||
|
||||
|
||||
def get_audio_file_hash(audio_file) -> str:
|
||||
"""
|
||||
Get hash identifier for an audio file.
|
||||
|
||||
Args:
|
||||
audio_file: Path to audio file (str) or file-like object
|
||||
|
||||
Returns:
|
||||
Hash string or empty string
|
||||
"""
|
||||
if audio_file is None:
|
||||
return ""
|
||||
|
||||
try:
|
||||
if isinstance(audio_file, str):
|
||||
if os.path.exists(audio_file):
|
||||
with open(audio_file, 'rb') as f:
|
||||
return hashlib.sha256(f.read()).hexdigest()
|
||||
return hashlib.sha256(audio_file.encode('utf-8')).hexdigest()
|
||||
elif hasattr(audio_file, 'name'):
|
||||
return hashlib.sha256(str(audio_file.name).encode('utf-8')).hexdigest()
|
||||
return hashlib.sha256(str(audio_file).encode('utf-8')).hexdigest()
|
||||
except Exception:
|
||||
return hashlib.sha256(str(audio_file).encode('utf-8')).hexdigest()
|
||||
|
||||
|
||||
def generate_uuid_from_params(params_dict) -> str:
|
||||
"""
|
||||
Generate deterministic UUID from generation parameters.
|
||||
Same parameters will always generate the same UUID.
|
||||
|
||||
Args:
|
||||
params_dict: Dictionary of parameters
|
||||
|
||||
Returns:
|
||||
UUID string
|
||||
"""
|
||||
|
||||
params_json = json.dumps(params_dict, sort_keys=True, ensure_ascii=False)
|
||||
hash_obj = hashlib.sha256(params_json.encode('utf-8'))
|
||||
hash_hex = hash_obj.hexdigest()
|
||||
uuid_str = f"{hash_hex[0:8]}-{hash_hex[8:12]}-{hash_hex[12:16]}-{hash_hex[16:20]}-{hash_hex[20:32]}"
|
||||
return uuid_str
|
||||
|
||||
|
||||
def generate_uuid_from_audio_data(
|
||||
audio_data: Union[torch.Tensor, np.ndarray],
|
||||
seed: Optional[int] = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate UUID from audio data (for caching/deduplication)
|
||||
|
||||
Args:
|
||||
audio_data: Audio data
|
||||
seed: Optional seed value
|
||||
|
||||
Returns:
|
||||
UUID string
|
||||
"""
|
||||
if isinstance(audio_data, torch.Tensor):
|
||||
# Convert to numpy and calculate hash
|
||||
audio_np = audio_data.cpu().numpy()
|
||||
else:
|
||||
audio_np = audio_data
|
||||
|
||||
# Calculate data hash
|
||||
data_hash = hashlib.sha256(audio_np.tobytes()).hexdigest()
|
||||
|
||||
if seed is not None:
|
||||
combined = f"{data_hash}_{seed}"
|
||||
return hashlib.sha256(combined.encode()).hexdigest()
|
||||
|
||||
return data_hash
|
||||
|
||||
|
||||
# Global default instance
|
||||
_default_saver = AudioSaver(default_format="flac")
|
||||
|
||||
|
||||
def save_audio(
|
||||
audio_data: Union[torch.Tensor, np.ndarray],
|
||||
output_path: Union[str, Path],
|
||||
sample_rate: int = 48000,
|
||||
format: Optional[str] = None,
|
||||
channels_first: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
Convenience function: save audio (using default configuration)
|
||||
|
||||
Args:
|
||||
audio_data: Audio data
|
||||
output_path: Output path
|
||||
sample_rate: Sample rate
|
||||
format: Format (default flac)
|
||||
channels_first: Tensor format flag
|
||||
|
||||
Returns:
|
||||
Saved file path
|
||||
"""
|
||||
return _default_saver.save_audio(
|
||||
audio_data, output_path, sample_rate, format, channels_first
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -87,6 +87,28 @@ TASK_TYPES_TURBO = ["text2music", "repaint", "cover"]
|
|||
TASK_TYPES_BASE = ["text2music", "repaint", "cover", "extract", "lego", "complete"]
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Generation Mode Constants (UI-level modes that map to task types)
|
||||
# ==============================================================================
|
||||
|
||||
# Default modes for turbo and SFT models (restricted set)
|
||||
GENERATION_MODES_TURBO = ["Simple", "Custom", "Remix", "Repaint"]
|
||||
|
||||
# Extended modes for pure base models only — adds Extract/Lego/Complete
|
||||
GENERATION_MODES_BASE = ["Simple", "Custom", "Remix", "Repaint", "Extract", "Lego", "Complete"]
|
||||
|
||||
# Mapping from generation mode to task_type value
|
||||
MODE_TO_TASK_TYPE = {
|
||||
"Simple": "text2music",
|
||||
"Custom": "text2music",
|
||||
"Remix": "cover",
|
||||
"Repaint": "repaint",
|
||||
"Extract": "extract",
|
||||
"Lego": "lego",
|
||||
"Complete": "complete",
|
||||
}
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Instruction Constants
|
||||
# ==============================================================================
|
||||
|
|
|
|||
|
|
@ -1,6 +1,29 @@
|
|||
"""Handler decomposition components."""
|
||||
|
||||
from .audio_codes import AudioCodesMixin
|
||||
from .batch_prep import BatchPrepMixin
|
||||
from .diffusion import DiffusionMixin
|
||||
from .init_service import InitServiceMixin
|
||||
from .io_audio import IoAudioMixin
|
||||
from .lora_manager import LoraManagerMixin
|
||||
from .memory_utils import MemoryUtilsMixin
|
||||
from .metadata_utils import MetadataMixin
|
||||
from .padding_utils import PaddingMixin
|
||||
from .prompt_utils import PromptMixin
|
||||
from .progress import ProgressMixin
|
||||
from .task_utils import TaskUtilsMixin
|
||||
|
||||
__all__ = ["LoraManagerMixin", "ProgressMixin"]
|
||||
__all__ = [
|
||||
"AudioCodesMixin",
|
||||
"BatchPrepMixin",
|
||||
"DiffusionMixin",
|
||||
"InitServiceMixin",
|
||||
"IoAudioMixin",
|
||||
"LoraManagerMixin",
|
||||
"MemoryUtilsMixin",
|
||||
"MetadataMixin",
|
||||
"PaddingMixin",
|
||||
"PromptMixin",
|
||||
"ProgressMixin",
|
||||
"TaskUtilsMixin",
|
||||
]
|
||||
|
|
|
|||
99
acestep/core/generation/handler/audio_codes.py
Normal file
99
acestep/core/generation/handler/audio_codes.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
"""Audio-code parsing and conversion helpers for handler decomposition."""
|
||||
|
||||
import re
|
||||
import traceback
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class AudioCodesMixin:
|
||||
"""Mixin containing audio-code parsing and latent conversion helpers.
|
||||
|
||||
Depends on host members:
|
||||
- Attributes: ``model``, ``vae``, ``device``, ``dtype``, ``silence_latent``.
|
||||
- Methods: ``_load_model_context``, ``process_src_audio``, ``is_silence``,
|
||||
``_encode_audio_to_latents``.
|
||||
"""
|
||||
|
||||
def _parse_audio_code_string(self, code_str: str) -> List[int]:
|
||||
"""Extract integer audio codes from tokens like ``<|audio_code_123|>``."""
|
||||
if not code_str:
|
||||
return []
|
||||
try:
|
||||
max_audio_code = 63999
|
||||
codes = []
|
||||
clamped_count = 0
|
||||
for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str):
|
||||
code_value = int(x)
|
||||
clamped_value = max(0, min(code_value, max_audio_code))
|
||||
if clamped_value != code_value:
|
||||
clamped_count += 1
|
||||
logger.warning(
|
||||
f"[_parse_audio_code_string] Clamped audio code value from {code_value} to {clamped_value}"
|
||||
)
|
||||
codes.append(clamped_value)
|
||||
if clamped_count > 0:
|
||||
logger.warning(
|
||||
f"[_parse_audio_code_string] Clamped {clamped_count} audio code value(s) "
|
||||
f"to valid range [0, {max_audio_code}]"
|
||||
)
|
||||
return codes
|
||||
except Exception as e:
|
||||
logger.debug(f"[_parse_audio_code_string] Failed to parse audio code string: {e}")
|
||||
return []
|
||||
|
||||
def _decode_audio_codes_to_latents(self, code_str: str) -> Optional[torch.Tensor]:
|
||||
"""Convert serialized audio-code string into 25Hz latents."""
|
||||
if self.model is None or not hasattr(self.model, "tokenizer") or not hasattr(self.model, "detokenizer"):
|
||||
return None
|
||||
|
||||
code_ids = self._parse_audio_code_string(code_str)
|
||||
if len(code_ids) == 0:
|
||||
return None
|
||||
|
||||
with self._load_model_context("model"):
|
||||
quantizer = self.model.tokenizer.quantizer
|
||||
detokenizer = self.model.detokenizer
|
||||
indices = torch.tensor(code_ids, device=self.device, dtype=torch.long)
|
||||
indices = indices.unsqueeze(0).unsqueeze(-1)
|
||||
|
||||
quantized = quantizer.get_output_from_indices(indices)
|
||||
if quantized.dtype != self.dtype:
|
||||
quantized = quantized.to(self.dtype)
|
||||
lm_hints_25hz = detokenizer(quantized)
|
||||
return lm_hints_25hz
|
||||
|
||||
def convert_src_audio_to_codes(self, audio_file) -> str:
|
||||
"""Convert uploaded source audio into serialized audio code tokens."""
|
||||
if audio_file is None:
|
||||
return "❌ Please upload source audio first"
|
||||
if self.model is None or self.vae is None:
|
||||
return "❌ Model not initialized. Please initialize the service first."
|
||||
|
||||
try:
|
||||
processed_audio = self.process_src_audio(audio_file)
|
||||
if processed_audio is None:
|
||||
return "❌ Failed to process audio file"
|
||||
|
||||
with torch.inference_mode():
|
||||
with self._load_model_context("vae"):
|
||||
if self.is_silence(processed_audio.unsqueeze(0)):
|
||||
return "❌ Audio file appears to be silent"
|
||||
latents = self._encode_audio_to_latents(processed_audio)
|
||||
|
||||
attention_mask = torch.ones(latents.shape[0], dtype=torch.bool, device=self.device)
|
||||
with self._load_model_context("model"):
|
||||
hidden_states = latents.unsqueeze(0)
|
||||
_, indices, _ = self.model.tokenize(
|
||||
hidden_states, self.silence_latent, attention_mask.unsqueeze(0)
|
||||
)
|
||||
indices_flat = indices.flatten().cpu().tolist()
|
||||
codes_string = "".join([f"<|audio_code_{idx}|>" for idx in indices_flat])
|
||||
logger.info(f"[convert_src_audio_to_codes] Generated {len(indices_flat)} audio codes")
|
||||
return codes_string
|
||||
except Exception as e:
|
||||
error_msg = f"❌ Error converting audio to codes: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.exception("[convert_src_audio_to_codes] Error converting audio to codes")
|
||||
return error_msg
|
||||
104
acestep/core/generation/handler/batch_prep.py
Normal file
104
acestep/core/generation/handler/batch_prep.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
"""Batch preparation helpers for handler decomposition."""
|
||||
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from acestep.constants import DEFAULT_DIT_INSTRUCTION
|
||||
|
||||
|
||||
class BatchPrepMixin:
|
||||
"""Mixin containing batch and input normalization helpers.
|
||||
|
||||
Depends on host members:
|
||||
- Attributes: ``device``, ``dtype``.
|
||||
- Methods: ``tiled_encode``, ``extract_caption_from_sft_format``,
|
||||
``_build_metadata_dict``.
|
||||
"""
|
||||
|
||||
def _normalize_audio_code_hints(
|
||||
self, audio_code_hints: Optional[Union[str, List[str]]], batch_size: int
|
||||
) -> List[Optional[str]]:
|
||||
"""Normalize ``audio_code_hints`` into a batch-length list."""
|
||||
if audio_code_hints is None:
|
||||
normalized: List[Optional[str]] = [None] * batch_size
|
||||
elif isinstance(audio_code_hints, str):
|
||||
normalized = [audio_code_hints] * batch_size
|
||||
elif len(audio_code_hints) == 1 and batch_size > 1:
|
||||
normalized = audio_code_hints * batch_size
|
||||
elif len(audio_code_hints) != batch_size:
|
||||
normalized = list(audio_code_hints[:batch_size])
|
||||
while len(normalized) < batch_size:
|
||||
normalized.append(None)
|
||||
else:
|
||||
normalized = list(audio_code_hints)
|
||||
return [hint if isinstance(hint, str) and hint.strip() else None for hint in normalized]
|
||||
|
||||
def _normalize_instructions(
|
||||
self,
|
||||
instructions: Optional[Union[str, List[str]]],
|
||||
batch_size: int,
|
||||
default: Optional[str] = None,
|
||||
) -> List[str]:
|
||||
"""Normalize instructions into a batch-length list."""
|
||||
if instructions is None:
|
||||
default_instruction = default or DEFAULT_DIT_INSTRUCTION
|
||||
return [default_instruction] * batch_size
|
||||
if isinstance(instructions, str):
|
||||
return [instructions] * batch_size
|
||||
if len(instructions) == 1:
|
||||
return instructions * batch_size
|
||||
if len(instructions) != batch_size:
|
||||
normalized = list(instructions[:batch_size])
|
||||
default_instruction = default or DEFAULT_DIT_INSTRUCTION
|
||||
while len(normalized) < batch_size:
|
||||
normalized.append(default_instruction)
|
||||
return normalized
|
||||
return list(instructions)
|
||||
|
||||
def _encode_audio_to_latents(self, audio: torch.Tensor) -> torch.Tensor:
|
||||
"""Encode audio to latents using tiled VAE encode path."""
|
||||
input_was_2d = audio.dim() == 2
|
||||
if input_was_2d:
|
||||
audio = audio.unsqueeze(0)
|
||||
|
||||
with torch.inference_mode():
|
||||
latents = self.tiled_encode(audio, offload_latent_to_cpu=True)
|
||||
|
||||
latents = latents.to(self.device).to(self.dtype)
|
||||
latents = latents.transpose(1, 2)
|
||||
if input_was_2d:
|
||||
latents = latents.squeeze(0)
|
||||
return latents
|
||||
|
||||
def prepare_batch_data(
|
||||
self,
|
||||
actual_batch_size,
|
||||
processed_src_audio,
|
||||
audio_duration,
|
||||
captions,
|
||||
lyrics,
|
||||
vocal_language,
|
||||
instruction,
|
||||
bpm,
|
||||
key_scale,
|
||||
time_signature,
|
||||
):
|
||||
"""Prepare repeated batch-level caption/instruction/metadata values."""
|
||||
pure_caption = self.extract_caption_from_sft_format(captions)
|
||||
captions_batch = [pure_caption] * actual_batch_size
|
||||
instructions_batch = [instruction] * actual_batch_size
|
||||
lyrics_batch = [lyrics] * actual_batch_size
|
||||
vocal_languages_batch = [vocal_language] * actual_batch_size
|
||||
|
||||
calculated_duration = None
|
||||
if processed_src_audio is not None:
|
||||
calculated_duration = processed_src_audio.shape[-1] / 48000.0
|
||||
elif audio_duration is not None and float(audio_duration) > 0:
|
||||
calculated_duration = float(audio_duration)
|
||||
|
||||
metadata_dict: Dict[str, Union[str, int]] = self._build_metadata_dict(
|
||||
bpm, key_scale, time_signature, calculated_duration
|
||||
)
|
||||
metas_batch = [metadata_dict.copy() for _ in range(actual_batch_size)]
|
||||
return captions_batch, instructions_batch, lyrics_batch, vocal_languages_batch, metas_batch
|
||||
136
acestep/core/generation/handler/diffusion.py
Normal file
136
acestep/core/generation/handler/diffusion.py
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
"""Diffusion-related handler helpers."""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from acestep.mlx_dit.generate import mlx_generate_diffusion
|
||||
|
||||
|
||||
class DiffusionMixin:
|
||||
"""Mixin containing diffusion execution helpers.
|
||||
|
||||
Required host attributes:
|
||||
- ``mlx_decoder``: MLX decoder object passed to ``mlx_generate_diffusion``.
|
||||
- ``device``: torch device string used for output tensor placement.
|
||||
- ``dtype``: torch dtype used for output tensor conversion.
|
||||
"""
|
||||
|
||||
def _mlx_run_diffusion(
|
||||
self,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
context_latents,
|
||||
src_latents,
|
||||
seed,
|
||||
infer_method: str = "ode",
|
||||
shift: float = 3.0,
|
||||
timesteps=None,
|
||||
audio_cover_strength: float = 1.0,
|
||||
encoder_hidden_states_non_cover=None,
|
||||
encoder_attention_mask_non_cover=None,
|
||||
context_latents_non_cover=None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the MLX diffusion loop and return generated latents.
|
||||
|
||||
This method accepts the same signature as the handler diffusion path for
|
||||
API compatibility. Attention-mask parameters are intentionally accepted
|
||||
but unused because the MLX generator consumes hidden states/latents only.
|
||||
|
||||
Args:
|
||||
encoder_hidden_states: Prompt conditioning tensor.
|
||||
encoder_attention_mask: Unused; accepted for API compatibility.
|
||||
context_latents: Context/reference latent tensor.
|
||||
src_latents: Source latent tensor used for shape and initialization.
|
||||
seed: Random seed used by MLX diffusion.
|
||||
infer_method: Diffusion method, one of ``"ode"`` or ``"sde"``.
|
||||
shift: Timestep shift value.
|
||||
timesteps: Optional iterable or tensor-like custom timesteps.
|
||||
audio_cover_strength: Blend factor for cover conditioning.
|
||||
encoder_hidden_states_non_cover: Optional non-cover conditioning tensor.
|
||||
encoder_attention_mask_non_cover: Unused; accepted for API compatibility.
|
||||
context_latents_non_cover: Optional non-cover context latent tensor.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: ``{"target_latents": torch.Tensor, "time_costs": dict}``.
|
||||
|
||||
Raises:
|
||||
AttributeError: If required host attributes are missing.
|
||||
ValueError: If infer method is unsupported or batch dimensions mismatch.
|
||||
TypeError: If ``timesteps`` is neither iterable nor tensor-like.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
# Kept for API compatibility with non-MLX diffusion path.
|
||||
_ = encoder_attention_mask, encoder_attention_mask_non_cover
|
||||
|
||||
for required_attr in ("mlx_decoder", "device", "dtype"):
|
||||
if not hasattr(self, required_attr):
|
||||
raise AttributeError(f"DiffusionMixin host is missing required attribute '{required_attr}'")
|
||||
|
||||
if infer_method not in {"ode", "sde"}:
|
||||
raise ValueError(f"Unsupported infer_method '{infer_method}'. Expected 'ode' or 'sde'.")
|
||||
|
||||
if timesteps is not None and not (hasattr(timesteps, "__iter__") or hasattr(timesteps, "tolist")):
|
||||
raise TypeError("timesteps must be iterable, tensor-like, or None")
|
||||
|
||||
if encoder_hidden_states.shape[0] != context_latents.shape[0]:
|
||||
raise ValueError(
|
||||
"Batch dimension mismatch: encoder_hidden_states and context_latents must share dim 0"
|
||||
)
|
||||
if encoder_hidden_states.shape[0] != src_latents.shape[0]:
|
||||
raise ValueError(
|
||||
"Batch dimension mismatch: encoder_hidden_states and src_latents must share dim 0"
|
||||
)
|
||||
if encoder_hidden_states_non_cover is not None and encoder_hidden_states_non_cover.shape[0] != encoder_hidden_states.shape[0]:
|
||||
raise ValueError(
|
||||
"Batch dimension mismatch: encoder_hidden_states_non_cover must share dim 0 with encoder_hidden_states"
|
||||
)
|
||||
if context_latents_non_cover is not None and context_latents_non_cover.shape[0] != context_latents.shape[0]:
|
||||
raise ValueError(
|
||||
"Batch dimension mismatch: context_latents_non_cover must share dim 0 with context_latents"
|
||||
)
|
||||
|
||||
# Convert inputs to numpy (float32)
|
||||
enc_np = encoder_hidden_states.detach().cpu().float().numpy()
|
||||
ctx_np = context_latents.detach().cpu().float().numpy()
|
||||
src_shape = (src_latents.shape[0], src_latents.shape[1], src_latents.shape[2])
|
||||
|
||||
enc_nc_np = (
|
||||
encoder_hidden_states_non_cover.detach().cpu().float().numpy()
|
||||
if encoder_hidden_states_non_cover is not None else None
|
||||
)
|
||||
ctx_nc_np = (
|
||||
context_latents_non_cover.detach().cpu().float().numpy()
|
||||
if context_latents_non_cover is not None else None
|
||||
)
|
||||
|
||||
# Convert timesteps tensor if present
|
||||
ts_list = None
|
||||
if timesteps is not None:
|
||||
if hasattr(timesteps, "tolist"):
|
||||
ts_list = timesteps.tolist()
|
||||
else:
|
||||
ts_list = list(timesteps)
|
||||
|
||||
result = mlx_generate_diffusion(
|
||||
mlx_decoder=self.mlx_decoder,
|
||||
encoder_hidden_states_np=enc_np,
|
||||
context_latents_np=ctx_np,
|
||||
src_latents_shape=src_shape,
|
||||
seed=seed,
|
||||
infer_method=infer_method,
|
||||
shift=shift,
|
||||
timesteps=ts_list,
|
||||
audio_cover_strength=audio_cover_strength,
|
||||
encoder_hidden_states_non_cover_np=enc_nc_np,
|
||||
context_latents_non_cover_np=ctx_nc_np,
|
||||
)
|
||||
|
||||
# Convert result latents back to PyTorch tensor on the correct device
|
||||
target_np = result["target_latents"]
|
||||
target_tensor = torch.from_numpy(target_np).to(device=self.device, dtype=self.dtype)
|
||||
|
||||
return {
|
||||
"target_latents": target_tensor,
|
||||
"time_costs": result["time_costs"],
|
||||
}
|
||||
155
acestep/core/generation/handler/diffusion_test.py
Normal file
155
acestep/core/generation/handler/diffusion_test.py
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from acestep.core.generation.handler.diffusion import DiffusionMixin
|
||||
|
||||
|
||||
class _Host(DiffusionMixin):
|
||||
def __init__(self, device: str = "cpu", dtype: torch.dtype = torch.float32):
|
||||
self.mlx_decoder = object()
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
|
||||
class _IterableTimesteps:
|
||||
def __init__(self, values):
|
||||
self._values = values
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._values)
|
||||
|
||||
|
||||
class DiffusionMixinTests(unittest.TestCase):
|
||||
def test_mlx_run_diffusion_converts_inputs_and_outputs_tensor(self):
|
||||
host = _Host(dtype=torch.float16)
|
||||
encoder_hidden_states = torch.randn(2, 4, 8, dtype=torch.float64)
|
||||
encoder_attention_mask = torch.ones(2, 4, dtype=torch.int64)
|
||||
context_latents = torch.randn(2, 16, 8, dtype=torch.float64)
|
||||
src_latents = torch.zeros(2, 3, 5, dtype=torch.float32)
|
||||
timesteps = torch.tensor([1.0, 0.5], dtype=torch.float32)
|
||||
non_cover_hidden = torch.randn(2, 4, 8, dtype=torch.float64)
|
||||
non_cover_mask = torch.ones(2, 4, dtype=torch.int64)
|
||||
non_cover_context = torch.randn(2, 16, 8, dtype=torch.float64)
|
||||
fake_target = np.ones((2, 3, 5), dtype=np.float32)
|
||||
|
||||
def _fake_generate(**kwargs):
|
||||
self.assertIs(kwargs["mlx_decoder"], host.mlx_decoder)
|
||||
self.assertEqual(kwargs["src_latents_shape"], (2, 3, 5))
|
||||
self.assertEqual(kwargs["timesteps"], [1.0, 0.5])
|
||||
self.assertEqual(kwargs["infer_method"], "sde")
|
||||
self.assertEqual(kwargs["shift"], 2.0)
|
||||
self.assertEqual(kwargs["audio_cover_strength"], 0.6)
|
||||
self.assertEqual(kwargs["encoder_hidden_states_np"].dtype, np.float32)
|
||||
self.assertEqual(kwargs["context_latents_np"].dtype, np.float32)
|
||||
self.assertEqual(kwargs["encoder_hidden_states_non_cover_np"].dtype, np.float32)
|
||||
self.assertEqual(kwargs["context_latents_non_cover_np"].dtype, np.float32)
|
||||
return {"target_latents": fake_target, "time_costs": {"diffusion_time_cost": 1.2}}
|
||||
|
||||
with patch("acestep.core.generation.handler.diffusion.mlx_generate_diffusion", side_effect=_fake_generate):
|
||||
result = host._mlx_run_diffusion(
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
context_latents=context_latents,
|
||||
src_latents=src_latents,
|
||||
seed=123,
|
||||
infer_method="sde",
|
||||
shift=2.0,
|
||||
timesteps=timesteps,
|
||||
audio_cover_strength=0.6,
|
||||
encoder_hidden_states_non_cover=non_cover_hidden,
|
||||
encoder_attention_mask_non_cover=non_cover_mask,
|
||||
context_latents_non_cover=non_cover_context,
|
||||
)
|
||||
|
||||
self.assertIn("target_latents", result)
|
||||
self.assertIn("time_costs", result)
|
||||
self.assertEqual(result["time_costs"]["diffusion_time_cost"], 1.2)
|
||||
self.assertEqual(result["target_latents"].dtype, torch.float16)
|
||||
self.assertEqual(result["target_latents"].device.type, "cpu")
|
||||
self.assertTrue(torch.allclose(result["target_latents"], torch.ones_like(result["target_latents"])))
|
||||
|
||||
def test_mlx_run_diffusion_handles_optional_and_iterable_timesteps(self):
|
||||
host = _Host(dtype=torch.float32)
|
||||
encoder_hidden_states = torch.randn(1, 2, 3, dtype=torch.float32)
|
||||
encoder_attention_mask = torch.ones(1, 2, dtype=torch.int64)
|
||||
context_latents = torch.randn(1, 4, 3, dtype=torch.float32)
|
||||
src_latents = torch.zeros(1, 2, 3, dtype=torch.float32)
|
||||
timesteps = _IterableTimesteps([0.9, 0.8, 0.7])
|
||||
|
||||
def _fake_generate(**kwargs):
|
||||
self.assertEqual(kwargs["timesteps"], [0.9, 0.8, 0.7])
|
||||
self.assertIsNone(kwargs["encoder_hidden_states_non_cover_np"])
|
||||
self.assertIsNone(kwargs["context_latents_non_cover_np"])
|
||||
return {"target_latents": np.zeros((1, 2, 3), dtype=np.float32), "time_costs": {}}
|
||||
|
||||
with patch("acestep.core.generation.handler.diffusion.mlx_generate_diffusion", side_effect=_fake_generate):
|
||||
result = host._mlx_run_diffusion(
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
context_latents=context_latents,
|
||||
src_latents=src_latents,
|
||||
seed=1,
|
||||
timesteps=timesteps,
|
||||
)
|
||||
|
||||
self.assertEqual(tuple(result["target_latents"].shape), (1, 2, 3))
|
||||
self.assertEqual(result["target_latents"].dtype, torch.float32)
|
||||
|
||||
def test_mlx_run_diffusion_rejects_invalid_infer_method(self):
|
||||
host = _Host()
|
||||
x = torch.randn(1, 2, 3)
|
||||
with self.assertRaises(ValueError):
|
||||
host._mlx_run_diffusion(
|
||||
encoder_hidden_states=x,
|
||||
encoder_attention_mask=torch.ones(1, 2, dtype=torch.int64),
|
||||
context_latents=torch.randn(1, 4, 3),
|
||||
src_latents=torch.randn(1, 2, 3),
|
||||
seed=1,
|
||||
infer_method="bad",
|
||||
)
|
||||
|
||||
def test_mlx_run_diffusion_rejects_non_iterable_timesteps(self):
|
||||
host = _Host()
|
||||
x = torch.randn(1, 2, 3)
|
||||
with self.assertRaises(TypeError):
|
||||
host._mlx_run_diffusion(
|
||||
encoder_hidden_states=x,
|
||||
encoder_attention_mask=torch.ones(1, 2, dtype=torch.int64),
|
||||
context_latents=torch.randn(1, 4, 3),
|
||||
src_latents=torch.randn(1, 2, 3),
|
||||
seed=1,
|
||||
timesteps=123,
|
||||
)
|
||||
|
||||
def test_mlx_run_diffusion_rejects_batch_mismatch(self):
|
||||
host = _Host()
|
||||
with self.assertRaises(ValueError):
|
||||
host._mlx_run_diffusion(
|
||||
encoder_hidden_states=torch.randn(2, 2, 3),
|
||||
encoder_attention_mask=torch.ones(2, 2, dtype=torch.int64),
|
||||
context_latents=torch.randn(1, 4, 3),
|
||||
src_latents=torch.randn(2, 2, 3),
|
||||
seed=1,
|
||||
)
|
||||
|
||||
def test_mlx_run_diffusion_requires_host_attributes(self):
|
||||
class _BrokenHost(DiffusionMixin):
|
||||
pass
|
||||
|
||||
host = _BrokenHost()
|
||||
x = torch.randn(1, 2, 3)
|
||||
with self.assertRaises(AttributeError):
|
||||
host._mlx_run_diffusion(
|
||||
encoder_hidden_states=x,
|
||||
encoder_attention_mask=torch.ones(1, 2, dtype=torch.int64),
|
||||
context_latents=torch.randn(1, 4, 3),
|
||||
src_latents=torch.randn(1, 2, 3),
|
||||
seed=1,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
414
acestep/core/generation/handler/init_service.py
Normal file
414
acestep/core/generation/handler/init_service.py
Normal file
|
|
@ -0,0 +1,414 @@
|
|||
"""Initialization-adjacent utility mixin for AceStepHandler."""
|
||||
|
||||
import os
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class InitServiceMixin:
|
||||
def _device_type(self) -> str:
|
||||
"""Normalize the host device value to a backend type string."""
|
||||
if isinstance(self.device, str):
|
||||
return self.device.split(":", 1)[0]
|
||||
return self.device.type
|
||||
|
||||
def get_available_checkpoints(self) -> List[str]:
|
||||
"""Return available checkpoint directory paths under the project root.
|
||||
|
||||
Uses ``self._get_project_root()`` to resolve the checkpoints directory and
|
||||
returns a single-item list when present, otherwise an empty list.
|
||||
"""
|
||||
# Get project root (handler.py is in acestep/, so go up two levels to project root)
|
||||
project_root = self._get_project_root()
|
||||
# default checkpoints
|
||||
checkpoint_dir = os.path.join(project_root, "checkpoints")
|
||||
if os.path.exists(checkpoint_dir):
|
||||
return [checkpoint_dir]
|
||||
else:
|
||||
return []
|
||||
|
||||
def get_available_acestep_v15_models(self) -> List[str]:
|
||||
"""Scan and return all model directory names starting with 'acestep-v15-'"""
|
||||
# Get project root
|
||||
project_root = self._get_project_root()
|
||||
checkpoint_dir = os.path.join(project_root, "checkpoints")
|
||||
|
||||
models = []
|
||||
if os.path.exists(checkpoint_dir):
|
||||
# Scan all directories starting with 'acestep-v15-' in checkpoints folder
|
||||
for item in os.listdir(checkpoint_dir):
|
||||
item_path = os.path.join(checkpoint_dir, item)
|
||||
if os.path.isdir(item_path) and item.startswith("acestep-v15-"):
|
||||
models.append(item)
|
||||
|
||||
# Sort by name
|
||||
models.sort()
|
||||
return models
|
||||
|
||||
def is_flash_attention_available(self, device: Optional[str] = None) -> bool:
|
||||
"""Check whether flash attention can be used on the target device."""
|
||||
target_device = str(device or self.device or "auto").split(":", 1)[0]
|
||||
if target_device == "auto":
|
||||
if not torch.cuda.is_available():
|
||||
return False
|
||||
else:
|
||||
if target_device != "cuda" or not torch.cuda.is_available():
|
||||
return False
|
||||
# FlashAttention requires Ampere (compute capability >= 8.0) or newer
|
||||
try:
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
if major < 8:
|
||||
logger.info(
|
||||
f"[is_flash_attention_available] GPU compute capability {major}.x < 8.0 "
|
||||
f"(pre-Ampere) — FlashAttention not supported, will use SDPA instead."
|
||||
)
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
try:
|
||||
import flash_attn
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
def is_turbo_model(self) -> bool:
|
||||
"""Check if the currently loaded model is a turbo model"""
|
||||
if self.config is None:
|
||||
return False
|
||||
return getattr(self.config, "is_turbo", False)
|
||||
|
||||
def _empty_cache(self):
|
||||
"""Clear accelerator memory cache (CUDA, XPU, or MPS)."""
|
||||
device_type = self._device_type()
|
||||
if device_type == "cuda" and torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
elif device_type == "xpu" and hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
torch.xpu.empty_cache()
|
||||
elif device_type == "mps" and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
torch.mps.empty_cache()
|
||||
|
||||
def _synchronize(self):
|
||||
"""Synchronize accelerator operations (CUDA, XPU, or MPS)."""
|
||||
device_type = self._device_type()
|
||||
if device_type == "cuda" and torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
elif device_type == "xpu" and hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
torch.xpu.synchronize()
|
||||
elif device_type == "mps" and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
torch.mps.synchronize()
|
||||
|
||||
def _memory_allocated(self):
|
||||
"""Get current accelerator memory usage in bytes, or 0 for unsupported backends."""
|
||||
device_type = self._device_type()
|
||||
if device_type == "cuda" and torch.cuda.is_available():
|
||||
return torch.cuda.memory_allocated()
|
||||
# MPS and XPU don't expose per-tensor memory tracking
|
||||
return 0
|
||||
|
||||
def _max_memory_allocated(self):
|
||||
"""Get peak accelerator memory usage in bytes, or 0 for unsupported backends."""
|
||||
device_type = self._device_type()
|
||||
if device_type == "cuda" and torch.cuda.is_available():
|
||||
return torch.cuda.max_memory_allocated()
|
||||
return 0
|
||||
|
||||
def _is_on_target_device(self, tensor, target_device):
|
||||
"""Check if tensor is on the target device (handles cuda vs cuda:0 comparison)."""
|
||||
if tensor is None:
|
||||
return True
|
||||
try:
|
||||
if isinstance(target_device, torch.device):
|
||||
target_type = target_device.type
|
||||
else:
|
||||
target_type = torch.device(str(target_device)).type
|
||||
except Exception:
|
||||
# Keep fallback conservative: derive backend token instead of assuming CUDA.
|
||||
target_type = str(target_device).strip().lower().split(":", 1)[0]
|
||||
if not target_type:
|
||||
logger.warning(
|
||||
"[_is_on_target_device] Malformed target device value: {!r}",
|
||||
target_device,
|
||||
)
|
||||
return False
|
||||
return tensor.device.type == target_type
|
||||
|
||||
@staticmethod
|
||||
def _get_affine_quantized_tensor_class():
|
||||
"""Return the AffineQuantizedTensor class from torchao, or None if unavailable.
|
||||
|
||||
Supports both old (torchao.quantization.affine_quantized) and new
|
||||
(torchao.dtypes.affine_quantized_tensor) import paths across torchao versions.
|
||||
"""
|
||||
try:
|
||||
from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
|
||||
return AffineQuantizedTensor
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from torchao.quantization.affine_quantized import AffineQuantizedTensor
|
||||
return AffineQuantizedTensor
|
||||
except ImportError:
|
||||
pass
|
||||
return None
|
||||
|
||||
def _is_quantized_tensor(self, t):
|
||||
"""True if t is a torchao AffineQuantizedTensor (calling .to() on it can raise NotImplementedError)."""
|
||||
if t is None:
|
||||
return False
|
||||
cls = self._get_affine_quantized_tensor_class()
|
||||
if cls is None:
|
||||
return False
|
||||
return isinstance(t, cls)
|
||||
|
||||
def _has_quantized_params(self, module):
|
||||
"""True if module (or any submodule) has at least one AffineQuantizedTensor parameter."""
|
||||
cls = self._get_affine_quantized_tensor_class()
|
||||
if cls is None:
|
||||
return False
|
||||
for _, param in module.named_parameters():
|
||||
if param is not None and isinstance(param, cls):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _ensure_silence_latent_on_device(self):
|
||||
"""Ensure silence_latent is on the correct device (self.device)."""
|
||||
if hasattr(self, "silence_latent") and self.silence_latent is not None:
|
||||
if not self._is_on_target_device(self.silence_latent, self.device):
|
||||
self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
|
||||
|
||||
def _move_module_recursive(self, module, target_device, dtype=None, visited=None):
|
||||
"""
|
||||
Recursively move a module and all its submodules to the target device.
|
||||
This handles modules that may not be properly registered.
|
||||
"""
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
module_id = id(module)
|
||||
if module_id in visited:
|
||||
return
|
||||
visited.add(module_id)
|
||||
|
||||
# Move the module itself
|
||||
module.to(target_device)
|
||||
if dtype is not None:
|
||||
module.to(dtype)
|
||||
|
||||
# Move all direct parameters
|
||||
for param_name, param in module._parameters.items():
|
||||
if param is not None and not self._is_on_target_device(param, target_device):
|
||||
if self._is_quantized_tensor(param):
|
||||
moved_param = self._move_quantized_param(param, target_device)
|
||||
else:
|
||||
moved_param = torch.nn.Parameter(
|
||||
param.data.to(target_device), requires_grad=param.requires_grad
|
||||
)
|
||||
if dtype is not None and moved_param.is_floating_point():
|
||||
moved_param = torch.nn.Parameter(
|
||||
moved_param.data.to(dtype), requires_grad=param.requires_grad
|
||||
)
|
||||
module._parameters[param_name] = moved_param
|
||||
|
||||
# Move all direct buffers
|
||||
for buf_name, buf in module._buffers.items():
|
||||
if buf is not None and not self._is_on_target_device(buf, target_device):
|
||||
module._buffers[buf_name] = buf.to(target_device)
|
||||
|
||||
# Recursively process all submodules (registered and unregistered)
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
self._move_module_recursive(child, target_device, dtype, visited)
|
||||
|
||||
# Also check for any nn.Module attributes that might not be in _modules
|
||||
for attr_name in dir(module):
|
||||
if attr_name.startswith('_'):
|
||||
continue
|
||||
try:
|
||||
attr = getattr(module, attr_name, None)
|
||||
if isinstance(attr, torch.nn.Module) and id(attr) not in visited:
|
||||
self._move_module_recursive(attr, target_device, dtype, visited)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _move_quantized_param(self, param, target_device):
|
||||
"""Move an AffineQuantizedTensor to target_device using _apply_fn_to_data.
|
||||
|
||||
This is the safe fallback for older torch versions where model.to(device) raises
|
||||
NotImplementedError on AffineQuantizedTensor (because aten._has_compatible_shallow_copy_type
|
||||
is not implemented). _apply_fn_to_data recursively applies a function to all inner
|
||||
tensors (int_data, scale, zero_point, etc.) without going through Module._apply.
|
||||
"""
|
||||
if hasattr(param, '_apply_fn_to_data'):
|
||||
return torch.nn.Parameter(
|
||||
param._apply_fn_to_data(lambda x: x.to(target_device)),
|
||||
requires_grad=param.requires_grad,
|
||||
)
|
||||
# Last resort: try direct .to() (may raise), but preserve Parameter registration.
|
||||
moved = param.to(target_device)
|
||||
return torch.nn.Parameter(moved, requires_grad=param.requires_grad)
|
||||
|
||||
def _recursive_to_device(self, model, device, dtype=None):
|
||||
"""
|
||||
Recursively move all parameters and buffers of a model to the specified device.
|
||||
This is more thorough than model.to() for some custom HuggingFace models.
|
||||
|
||||
Handles torchao AffineQuantizedTensor parameters that may raise NotImplementedError
|
||||
on model.to(device) in older torch versions (where Module._apply calls
|
||||
_has_compatible_shallow_copy_type, which is not implemented for AffineQuantizedTensor).
|
||||
In that case, falls back to moving quantized parameters individually via _apply_fn_to_data.
|
||||
"""
|
||||
target_device = torch.device(device) if isinstance(device, str) else device
|
||||
|
||||
# Method 1: Standard .to() call — works on newer torch where _apply uses swap_tensors
|
||||
try:
|
||||
model.to(target_device)
|
||||
if dtype is not None:
|
||||
model.to(dtype)
|
||||
except NotImplementedError:
|
||||
# Older torch: Module._apply calls _has_compatible_shallow_copy_type which is
|
||||
# not implemented for AffineQuantizedTensor. Move parameters manually.
|
||||
logger.info(
|
||||
"[_recursive_to_device] model.to() raised NotImplementedError "
|
||||
"(AffineQuantizedTensor on older torch). Moving parameters individually."
|
||||
)
|
||||
for module in model.modules():
|
||||
# Move non-quantized parameters and buffers directly
|
||||
for param_name, param in module._parameters.items():
|
||||
if param is None:
|
||||
continue
|
||||
if self._is_on_target_device(param, target_device):
|
||||
continue
|
||||
if self._is_quantized_tensor(param):
|
||||
module._parameters[param_name] = self._move_quantized_param(param, target_device)
|
||||
else:
|
||||
module._parameters[param_name] = torch.nn.Parameter(
|
||||
param.data.to(target_device), requires_grad=param.requires_grad
|
||||
)
|
||||
if dtype is not None:
|
||||
module._parameters[param_name] = torch.nn.Parameter(
|
||||
module._parameters[param_name].data.to(dtype),
|
||||
requires_grad=param.requires_grad,
|
||||
)
|
||||
for buf_name, buf in module._buffers.items():
|
||||
if buf is not None and not self._is_on_target_device(buf, target_device):
|
||||
module._buffers[buf_name] = buf.to(target_device)
|
||||
|
||||
# Method 2: Use our thorough recursive moving for any missed modules
|
||||
# (skip if model.to() failed — we already moved everything above)
|
||||
try:
|
||||
self._move_module_recursive(model, target_device, dtype)
|
||||
except NotImplementedError:
|
||||
pass # Already handled above
|
||||
|
||||
# Method 3: Force move via state_dict if there are still parameters on wrong device
|
||||
wrong_device_params = []
|
||||
for name, param in model.named_parameters():
|
||||
if not self._is_on_target_device(param, device):
|
||||
wrong_device_params.append(name)
|
||||
|
||||
if wrong_device_params and device != "cpu":
|
||||
logger.warning(f"[_recursive_to_device] {len(wrong_device_params)} parameters on wrong device after initial move, retrying individually")
|
||||
for module in model.modules():
|
||||
for param_name, param in module._parameters.items():
|
||||
if param is None or self._is_on_target_device(param, target_device):
|
||||
continue
|
||||
if self._is_quantized_tensor(param):
|
||||
module._parameters[param_name] = self._move_quantized_param(param, target_device)
|
||||
else:
|
||||
module._parameters[param_name] = torch.nn.Parameter(
|
||||
param.data.to(target_device), requires_grad=param.requires_grad
|
||||
)
|
||||
if dtype is not None and module._parameters[param_name].is_floating_point():
|
||||
module._parameters[param_name] = torch.nn.Parameter(
|
||||
module._parameters[param_name].data.to(dtype),
|
||||
requires_grad=param.requires_grad,
|
||||
)
|
||||
|
||||
# Synchronize accelerator to ensure all transfers are complete
|
||||
if device != "cpu":
|
||||
self._synchronize()
|
||||
|
||||
# Final verification
|
||||
if device != "cpu":
|
||||
still_wrong = []
|
||||
for name, param in model.named_parameters():
|
||||
if not self._is_on_target_device(param, device):
|
||||
still_wrong.append(f"{name} on {param.device}")
|
||||
if still_wrong:
|
||||
logger.error(f"[_recursive_to_device] CRITICAL: {len(still_wrong)} parameters still on wrong device: {still_wrong[:10]}")
|
||||
|
||||
@contextmanager
|
||||
def _load_model_context(self, model_name: str):
|
||||
"""
|
||||
Context manager to load a model to GPU and offload it back to CPU after use.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model to load ("text_encoder", "vae", "model")
|
||||
"""
|
||||
if not self.offload_to_cpu:
|
||||
yield
|
||||
return
|
||||
|
||||
# If model is DiT ("model") and offload_dit_to_cpu is False, do not offload
|
||||
if model_name == "model" and not self.offload_dit_to_cpu:
|
||||
# Ensure it's on device if not already (should be handled by init, but safe to check)
|
||||
model = getattr(self, model_name, None)
|
||||
if model is not None:
|
||||
# Check if model is on CPU, if so move to device (one-time move if it was somehow on CPU)
|
||||
# We check the first parameter's device
|
||||
try:
|
||||
param = next(model.parameters())
|
||||
if param.device.type == "cpu":
|
||||
logger.info(f"[_load_model_context] Moving {model_name} to {self.device} (persistent)")
|
||||
self._recursive_to_device(model, self.device, self.dtype)
|
||||
if hasattr(self, "silence_latent"):
|
||||
self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
|
||||
except StopIteration:
|
||||
pass
|
||||
yield
|
||||
return
|
||||
|
||||
model = getattr(self, model_name, None)
|
||||
if model is None:
|
||||
yield
|
||||
return
|
||||
|
||||
# Load to GPU
|
||||
logger.info(f"[_load_model_context] Loading {model_name} to {self.device}")
|
||||
start_time = time.time()
|
||||
if model_name == "vae":
|
||||
vae_dtype = self._get_vae_dtype()
|
||||
self._recursive_to_device(model, self.device, vae_dtype)
|
||||
else:
|
||||
self._recursive_to_device(model, self.device, self.dtype)
|
||||
|
||||
if model_name == "model" and hasattr(self, "silence_latent"):
|
||||
self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
|
||||
|
||||
load_time = time.time() - start_time
|
||||
self.current_offload_cost += load_time
|
||||
logger.info(f"[_load_model_context] Loaded {model_name} to {self.device} in {load_time:.4f}s")
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Offload to CPU
|
||||
logger.info(f"[_load_model_context] Offloading {model_name} to CPU")
|
||||
start_time = time.time()
|
||||
if model_name == "vae":
|
||||
self._recursive_to_device(model, "cpu", self._get_vae_dtype("cpu"))
|
||||
else:
|
||||
self._recursive_to_device(model, "cpu")
|
||||
|
||||
# NOTE: Do NOT offload silence_latent to CPU here!
|
||||
# silence_latent is used in many places outside of model context,
|
||||
# so it should stay on GPU to avoid device mismatch errors.
|
||||
|
||||
self._empty_cache()
|
||||
offload_time = time.time() - start_time
|
||||
self.current_offload_cost += offload_time
|
||||
logger.info(f"[_load_model_context] Offloaded {model_name} to CPU in {offload_time:.4f}s")
|
||||
179
acestep/core/generation/handler/init_service_test.py
Normal file
179
acestep/core/generation/handler/init_service_test.py
Normal file
|
|
@ -0,0 +1,179 @@
|
|||
import os
|
||||
import builtins
|
||||
import types
|
||||
import unittest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from acestep.core.generation.handler.init_service import InitServiceMixin
|
||||
|
||||
|
||||
class _Host(InitServiceMixin):
|
||||
def __init__(self, project_root: str, device: str = "cpu", config=None):
|
||||
self._project_root = project_root
|
||||
self.device = device
|
||||
self.config = config
|
||||
|
||||
def _get_project_root(self):
|
||||
return self._project_root
|
||||
|
||||
|
||||
class InitServiceMixinTests(unittest.TestCase):
|
||||
def test_device_type_normalizes_device(self):
|
||||
host = _Host(project_root="K:/fake_root", device="cuda:0")
|
||||
self.assertEqual(host._device_type(), "cuda")
|
||||
|
||||
def test_is_on_target_device_handles_device_alias(self):
|
||||
host = _Host(project_root="K:/fake_root", device="cpu")
|
||||
t = types.SimpleNamespace(device=types.SimpleNamespace(type="cuda"))
|
||||
self.assertTrue(host._is_on_target_device(t, "cuda:0"))
|
||||
self.assertFalse(host._is_on_target_device(t, "cpu"))
|
||||
|
||||
def test_is_on_target_device_fallback_does_not_assume_cuda(self):
|
||||
host = _Host(project_root="K:/fake_root", device="cpu")
|
||||
t = types.SimpleNamespace(device=types.SimpleNamespace(type="cuda"))
|
||||
self.assertFalse(host._is_on_target_device(t, "mps:0"))
|
||||
|
||||
def test_is_on_target_device_malformed_target_logs_and_returns_false(self):
|
||||
host = _Host(project_root="K:/fake_root", device="cpu")
|
||||
t = types.SimpleNamespace(device=types.SimpleNamespace(type="cuda"))
|
||||
with patch("acestep.core.generation.handler.init_service.logger.warning") as warning:
|
||||
self.assertFalse(host._is_on_target_device(t, ":0"))
|
||||
warning.assert_called_once()
|
||||
|
||||
def test_move_module_recursive_preserves_parameter_type(self):
|
||||
host = _Host(project_root="K:/fake_root", device="cpu")
|
||||
module = torch.nn.Linear(2, 2)
|
||||
with patch.object(host, "_is_on_target_device", return_value=False):
|
||||
host._move_module_recursive(module, "cpu")
|
||||
self.assertIsInstance(module.weight, torch.nn.Parameter)
|
||||
self.assertIsInstance(module.bias, torch.nn.Parameter)
|
||||
|
||||
def test_move_quantized_param_fallback_wraps_parameter(self):
|
||||
host = _Host(project_root="K:/fake_root", device="cpu")
|
||||
param = torch.nn.Parameter(torch.randn(2), requires_grad=True)
|
||||
moved = host._move_quantized_param(param, "cpu")
|
||||
self.assertIsInstance(moved, torch.nn.Parameter)
|
||||
self.assertTrue(moved.requires_grad)
|
||||
|
||||
def test_get_available_checkpoints_returns_expected_list(self):
|
||||
host = _Host(project_root="K:/fake_root")
|
||||
with patch("os.path.exists", return_value=False):
|
||||
self.assertEqual(host.get_available_checkpoints(), [])
|
||||
|
||||
with patch("os.path.exists", return_value=True):
|
||||
self.assertEqual(host.get_available_checkpoints(), [os.path.join("K:/fake_root", "checkpoints")])
|
||||
|
||||
def test_get_available_acestep_v15_models_filters_and_sorts(self):
|
||||
host = _Host(project_root="K:/fake_root")
|
||||
with patch("os.path.exists", return_value=True), patch(
|
||||
"os.listdir",
|
||||
return_value=["acestep-v15-zeta", "acestep-v15-alpha", "not-a-model", "acestep-v15-file"],
|
||||
), patch(
|
||||
"os.path.isdir",
|
||||
side_effect=lambda p: p.endswith("acestep-v15-zeta")
|
||||
or p.endswith("acestep-v15-alpha")
|
||||
or p.endswith("not-a-model"),
|
||||
):
|
||||
self.assertEqual(
|
||||
host.get_available_acestep_v15_models(),
|
||||
["acestep-v15-alpha", "acestep-v15-zeta"],
|
||||
)
|
||||
|
||||
def test_is_turbo_model_uses_config_flag(self):
|
||||
host = _Host(project_root="K:/fake_root", config=None)
|
||||
self.assertFalse(host.is_turbo_model())
|
||||
|
||||
host.config = types.SimpleNamespace(is_turbo=True)
|
||||
self.assertTrue(host.is_turbo_model())
|
||||
|
||||
def test_is_flash_attention_available_rejects_non_cuda(self):
|
||||
host = _Host(project_root="K:/fake_root", device="cpu")
|
||||
self.assertFalse(host.is_flash_attention_available())
|
||||
self.assertFalse(host.is_flash_attention_available(device="mps"))
|
||||
|
||||
def test_is_flash_attention_available_true_when_cuda_and_module_present(self):
|
||||
host = _Host(project_root="K:/fake_root", device="cuda")
|
||||
with patch("torch.cuda.is_available", return_value=True):
|
||||
with patch("torch.cuda.get_device_capability", return_value=(8, 0)):
|
||||
with patch.dict("sys.modules", {"flash_attn": types.SimpleNamespace()}):
|
||||
self.assertTrue(host.is_flash_attention_available())
|
||||
|
||||
def test_is_flash_attention_available_false_when_pre_ampere_gpu(self):
|
||||
host = _Host(project_root="K:/fake_root", device="cuda")
|
||||
with patch("torch.cuda.is_available", return_value=True):
|
||||
with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
|
||||
with patch.dict("sys.modules", {"flash_attn": types.SimpleNamespace()}):
|
||||
self.assertFalse(host.is_flash_attention_available())
|
||||
|
||||
def test_is_flash_attention_available_false_when_module_missing(self):
|
||||
host = _Host(project_root="K:/fake_root", device="cuda")
|
||||
real_import = builtins.__import__
|
||||
|
||||
def fake_import(name, globals=None, locals=None, fromlist=(), level=0):
|
||||
if name == "flash_attn":
|
||||
raise ImportError("flash_attn missing")
|
||||
return real_import(name, globals, locals, fromlist, level)
|
||||
|
||||
with patch("torch.cuda.is_available", return_value=True):
|
||||
with patch("torch.cuda.get_device_capability", return_value=(8, 0)):
|
||||
with patch("builtins.__import__", side_effect=fake_import):
|
||||
self.assertFalse(host.is_flash_attention_available())
|
||||
|
||||
def test_empty_cache_routes_to_cuda(self):
|
||||
host = _Host(project_root="K:/fake_root", device="cuda")
|
||||
with patch("torch.cuda.is_available", return_value=True), patch("torch.cuda.empty_cache") as empty_cache:
|
||||
host._empty_cache()
|
||||
empty_cache.assert_called_once()
|
||||
|
||||
def test_empty_cache_routes_to_xpu(self):
|
||||
host = _Host(project_root="K:/fake_root", device="xpu")
|
||||
empty_cache = Mock()
|
||||
xpu_stub = types.SimpleNamespace(is_available=lambda: True, empty_cache=empty_cache)
|
||||
with patch("torch.xpu", new=xpu_stub, create=True):
|
||||
host._empty_cache()
|
||||
empty_cache.assert_called_once()
|
||||
|
||||
def test_empty_cache_routes_to_mps(self):
|
||||
host = _Host(project_root="K:/fake_root", device="mps")
|
||||
with patch("torch.backends.mps.is_available", return_value=True), patch("torch.mps.empty_cache") as empty_cache:
|
||||
host._empty_cache()
|
||||
empty_cache.assert_called_once()
|
||||
|
||||
def test_synchronize_routes_to_cuda(self):
|
||||
host = _Host(project_root="K:/fake_root", device="cuda")
|
||||
with patch("torch.cuda.is_available", return_value=True), patch("torch.cuda.synchronize") as sync:
|
||||
host._synchronize()
|
||||
sync.assert_called_once()
|
||||
|
||||
def test_synchronize_routes_to_xpu(self):
|
||||
host = _Host(project_root="K:/fake_root", device="xpu")
|
||||
sync = Mock()
|
||||
xpu_stub = types.SimpleNamespace(is_available=lambda: True, synchronize=sync)
|
||||
with patch("torch.xpu", new=xpu_stub, create=True):
|
||||
host._synchronize()
|
||||
sync.assert_called_once()
|
||||
|
||||
def test_synchronize_routes_to_mps(self):
|
||||
host = _Host(project_root="K:/fake_root", device="mps")
|
||||
with patch("torch.backends.mps.is_available", return_value=True), patch("torch.mps.synchronize") as sync:
|
||||
host._synchronize()
|
||||
sync.assert_called_once()
|
||||
|
||||
def test_memory_queries_use_cuda_only(self):
|
||||
host = _Host(project_root="K:/fake_root", device="cpu")
|
||||
self.assertEqual(host._memory_allocated(), 0)
|
||||
self.assertEqual(host._max_memory_allocated(), 0)
|
||||
|
||||
host.device = "cuda"
|
||||
with patch("torch.cuda.is_available", return_value=True):
|
||||
with patch("torch.cuda.memory_allocated", return_value=123), patch(
|
||||
"torch.cuda.max_memory_allocated", return_value=456
|
||||
):
|
||||
self.assertEqual(host._memory_allocated(), 123)
|
||||
self.assertEqual(host._max_memory_allocated(), 456)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
133
acestep/core/generation/handler/io_audio.py
Normal file
133
acestep/core/generation/handler/io_audio.py
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
"""Audio IO and normalization helpers for handler decomposition."""
|
||||
|
||||
import math
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class IoAudioMixin:
|
||||
"""Mixin containing audio file loading and normalization helpers.
|
||||
|
||||
Depends on host members:
|
||||
- Method: ``is_silence`` (provided by ``MemoryUtilsMixin`` in this decomposition).
|
||||
"""
|
||||
|
||||
def _normalize_audio_to_stereo_48k(self, audio: torch.Tensor, sr: int) -> torch.Tensor:
|
||||
"""Normalize audio tensor to stereo at 48kHz.
|
||||
|
||||
Args:
|
||||
audio: Tensor in [channels, samples] or [samples] format.
|
||||
sr: Source sample rate.
|
||||
|
||||
Returns:
|
||||
Tensor in [2, samples] at 48kHz, clamped to [-1.0, 1.0].
|
||||
"""
|
||||
if audio.shape[0] == 1:
|
||||
audio = torch.cat([audio, audio], dim=0)
|
||||
|
||||
audio = audio[:2]
|
||||
|
||||
if sr != 48000:
|
||||
import torchaudio
|
||||
audio = torchaudio.transforms.Resample(sr, 48000)(audio)
|
||||
|
||||
return torch.clamp(audio, -1.0, 1.0)
|
||||
|
||||
def process_target_audio(self, audio_file: Optional[str]) -> Optional[torch.Tensor]:
|
||||
"""Load and normalize target audio file.
|
||||
|
||||
Args:
|
||||
audio_file: Path to target audio file.
|
||||
|
||||
Returns:
|
||||
Normalized stereo 48kHz tensor, or ``None`` on error/empty input.
|
||||
"""
|
||||
if audio_file is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
import soundfile as sf
|
||||
audio_np, sr = sf.read(audio_file, dtype="float32")
|
||||
if audio_np.ndim == 1:
|
||||
audio = torch.from_numpy(audio_np).unsqueeze(0)
|
||||
else:
|
||||
audio = torch.from_numpy(audio_np.T)
|
||||
return self._normalize_audio_to_stereo_48k(audio, sr)
|
||||
except (OSError, RuntimeError, ValueError):
|
||||
logger.exception("[process_target_audio] Error processing target audio")
|
||||
return None
|
||||
|
||||
def process_reference_audio(self, audio_file: Optional[str]) -> Optional[torch.Tensor]:
|
||||
"""Load and normalize reference audio, then sample 3x10s segments.
|
||||
|
||||
Args:
|
||||
audio_file: Path to reference audio file.
|
||||
|
||||
Returns:
|
||||
30-second stereo tensor from sampled front/middle/back segments,
|
||||
or ``None`` for empty/silent/error cases.
|
||||
"""
|
||||
if audio_file is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
import torchaudio
|
||||
audio, sr = torchaudio.load(audio_file)
|
||||
logger.debug(f"[process_reference_audio] Reference audio shape: {audio.shape}")
|
||||
logger.debug(f"[process_reference_audio] Reference audio sample rate: {sr}")
|
||||
logger.debug(
|
||||
f"[process_reference_audio] Reference audio duration: {audio.shape[-1] / sr:.6f} seconds"
|
||||
)
|
||||
|
||||
audio = self._normalize_audio_to_stereo_48k(audio, sr)
|
||||
if self.is_silence(audio):
|
||||
return None
|
||||
|
||||
target_frames = 30 * 48000
|
||||
segment_frames = 10 * 48000
|
||||
|
||||
if audio.shape[-1] < target_frames:
|
||||
repeat_times = math.ceil(target_frames / audio.shape[-1])
|
||||
audio = audio.repeat(1, repeat_times)
|
||||
|
||||
total_frames = audio.shape[-1]
|
||||
segment_size = total_frames // 3
|
||||
|
||||
front_start = random.randint(0, max(0, segment_size - segment_frames))
|
||||
front_audio = audio[:, front_start : front_start + segment_frames]
|
||||
|
||||
middle_start = segment_size + random.randint(0, max(0, segment_size - segment_frames))
|
||||
middle_audio = audio[:, middle_start : middle_start + segment_frames]
|
||||
|
||||
back_start = 2 * segment_size + random.randint(
|
||||
0, max(0, (total_frames - 2 * segment_size) - segment_frames)
|
||||
)
|
||||
back_audio = audio[:, back_start : back_start + segment_frames]
|
||||
|
||||
return torch.cat([front_audio, middle_audio, back_audio], dim=-1)
|
||||
except (OSError, RuntimeError, ValueError):
|
||||
logger.exception("[process_reference_audio] Error processing reference audio")
|
||||
return None
|
||||
|
||||
def process_src_audio(self, audio_file: Optional[str]) -> Optional[torch.Tensor]:
|
||||
"""Load and normalize source audio for remix/extract flows.
|
||||
|
||||
Args:
|
||||
audio_file: Path to source audio file.
|
||||
|
||||
Returns:
|
||||
Normalized stereo 48kHz tensor, or ``None`` on error/empty input.
|
||||
"""
|
||||
if audio_file is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
import torchaudio
|
||||
audio, sr = torchaudio.load(audio_file)
|
||||
return self._normalize_audio_to_stereo_48k(audio, sr)
|
||||
except (OSError, RuntimeError, ValueError):
|
||||
logger.exception("[process_src_audio] Error processing source audio")
|
||||
return None
|
||||
76
acestep/core/generation/handler/io_audio_test.py
Normal file
76
acestep/core/generation/handler/io_audio_test.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
"""Unit tests for audio IO mixin extraction."""
|
||||
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from acestep.core.generation.handler.io_audio import IoAudioMixin
|
||||
|
||||
|
||||
class _Host(IoAudioMixin):
|
||||
"""Minimal host implementing methods used by ``IoAudioMixin``."""
|
||||
|
||||
def is_silence(self, audio: torch.Tensor) -> bool:
|
||||
"""Treat near-zero tensors as silence."""
|
||||
return torch.all(audio.abs() < 1e-6).item()
|
||||
|
||||
|
||||
def _fake_torchaudio_module(load_fn):
|
||||
"""Create fake ``torchaudio`` module with minimal API used by tests."""
|
||||
module = types.ModuleType("torchaudio")
|
||||
module.load = load_fn
|
||||
module.transforms = types.SimpleNamespace(Resample=lambda *_args, **_kwargs: (lambda x: x))
|
||||
return module
|
||||
|
||||
|
||||
class IoAudioMixinTests(unittest.TestCase):
|
||||
"""Tests for normalization and audio loading helpers."""
|
||||
|
||||
def test_normalize_audio_to_stereo_48k_duplicates_mono_and_clamps(self):
|
||||
"""Mono input should duplicate to stereo and clamp values."""
|
||||
host = _Host()
|
||||
audio = torch.tensor([[2.0, -2.0, 0.5]], dtype=torch.float32)
|
||||
result = host._normalize_audio_to_stereo_48k(audio, 48000)
|
||||
|
||||
self.assertEqual(tuple(result.shape), (2, 3))
|
||||
self.assertTrue(torch.all(result <= 1.0))
|
||||
self.assertTrue(torch.all(result >= -1.0))
|
||||
|
||||
def test_process_target_audio_loads_and_normalizes(self):
|
||||
"""Target audio should be loaded and normalized through helper."""
|
||||
host = _Host()
|
||||
fake_np = np.array([0.1, -0.1, 0.2], dtype=np.float32)
|
||||
fake_sf = types.ModuleType("soundfile")
|
||||
fake_sf.read = lambda *_args, **_kwargs: (fake_np, 32000)
|
||||
|
||||
with patch.dict(sys.modules, {"soundfile": fake_sf}):
|
||||
with patch.object(host, "_normalize_audio_to_stereo_48k", return_value=torch.zeros(2, 3)) as norm:
|
||||
result = host.process_target_audio("fake.wav")
|
||||
|
||||
self.assertIsNotNone(result)
|
||||
norm.assert_called_once()
|
||||
|
||||
def test_process_src_audio_handles_load_error(self):
|
||||
"""Source audio processing should return None on load failure."""
|
||||
host = _Host()
|
||||
fake_ta = _fake_torchaudio_module(lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("bad")))
|
||||
with patch.dict(sys.modules, {"torchaudio": fake_ta}):
|
||||
result = host.process_src_audio("bad.wav")
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_process_reference_audio_returns_none_for_silence(self):
|
||||
"""Reference audio should short-circuit for silent input."""
|
||||
host = _Host()
|
||||
silent = torch.zeros(2, 16, dtype=torch.float32)
|
||||
fake_ta = _fake_torchaudio_module(lambda *_args, **_kwargs: (silent, 48000))
|
||||
with patch.dict(sys.modules, {"torchaudio": fake_ta}):
|
||||
result = host.process_reference_audio("silent.wav")
|
||||
self.assertIsNone(result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -37,20 +37,46 @@ def load_lora(self, lora_path: str) -> str:
|
|||
return "❌ PEFT library not installed. Please install with: pip install peft"
|
||||
|
||||
try:
|
||||
import copy
|
||||
|
||||
# Memory-efficient state_dict backup instead of deepcopy
|
||||
if self._base_decoder is None:
|
||||
self._base_decoder = copy.deepcopy(self.model.decoder)
|
||||
logger.info("Base decoder backed up")
|
||||
# Log memory before backup
|
||||
if hasattr(self, '_memory_allocated'):
|
||||
mem_before = self._memory_allocated() / (1024**3)
|
||||
logger.info(f"VRAM before LoRA load: {mem_before:.2f}GB")
|
||||
|
||||
# Save only the base model weights as state_dict (CPU to save VRAM)
|
||||
try:
|
||||
state_dict = self.model.decoder.state_dict()
|
||||
if not state_dict:
|
||||
raise ValueError("state_dict is empty - cannot backup decoder")
|
||||
self._base_decoder = {k: v.detach().cpu().clone() for k, v in state_dict.items()}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create state_dict backup: {e}")
|
||||
raise
|
||||
|
||||
# Calculate backup size in MB
|
||||
backup_size_mb = sum(v.numel() * v.element_size() for v in self._base_decoder.values()) / (1024**2)
|
||||
logger.info(f"Base decoder state_dict backed up to CPU ({backup_size_mb:.1f}MB)")
|
||||
else:
|
||||
self.model.decoder = copy.deepcopy(self._base_decoder)
|
||||
logger.info("Restored base decoder before loading new LoRA")
|
||||
# Restore base decoder from state_dict backup
|
||||
logger.info("Restoring base decoder from state_dict backup")
|
||||
load_result = self.model.decoder.load_state_dict(self._base_decoder, strict=False)
|
||||
if load_result.missing_keys:
|
||||
logger.warning(f"Missing keys when restoring decoder: {load_result.missing_keys[:5]}")
|
||||
if load_result.unexpected_keys:
|
||||
logger.warning(f"Unexpected keys when restoring decoder: {load_result.unexpected_keys[:5]}")
|
||||
self.model.decoder = self.model.decoder.to(self.device).to(self.dtype)
|
||||
|
||||
logger.info(f"Loading LoRA adapter from {lora_path}")
|
||||
self.model.decoder = PeftModel.from_pretrained(self.model.decoder, lora_path, is_trainable=False)
|
||||
self.model.decoder = self.model.decoder.to(self.device).to(self.dtype)
|
||||
self.model.decoder.eval()
|
||||
|
||||
# Log memory after LoRA load
|
||||
if hasattr(self, '_memory_allocated'):
|
||||
mem_after = self._memory_allocated() / (1024**3)
|
||||
logger.info(f"VRAM after LoRA load: {mem_after:.2f}GB")
|
||||
|
||||
self.lora_loaded = True
|
||||
self.use_lora = True
|
||||
self._ensure_lora_registry()
|
||||
|
|
@ -81,9 +107,35 @@ def unload_lora(self) -> str:
|
|||
return "❌ Base decoder backup not found. Cannot restore."
|
||||
|
||||
try:
|
||||
import copy
|
||||
|
||||
self.model.decoder = copy.deepcopy(self._base_decoder)
|
||||
# Log memory before unload (track before any operations)
|
||||
mem_before = None
|
||||
if hasattr(self, '_memory_allocated'):
|
||||
mem_before = self._memory_allocated() / (1024**3)
|
||||
logger.info(f"VRAM before LoRA unload: {mem_before:.2f}GB")
|
||||
|
||||
# Get the base model from the PEFT wrapper if it exists
|
||||
# This is more memory-efficient than recreating from state_dict
|
||||
from peft import PeftModel
|
||||
|
||||
if isinstance(self.model.decoder, PeftModel):
|
||||
logger.info("Extracting base model from PEFT wrapper")
|
||||
# PEFT's get_base_model() returns the underlying base model without copying
|
||||
self.model.decoder = self.model.decoder.get_base_model()
|
||||
# Restore state_dict from backup to ensure clean state
|
||||
load_result = self.model.decoder.load_state_dict(self._base_decoder, strict=False)
|
||||
if load_result.missing_keys:
|
||||
logger.warning(f"Missing keys when restoring decoder: {load_result.missing_keys[:5]}")
|
||||
if load_result.unexpected_keys:
|
||||
logger.warning(f"Unexpected keys when restoring decoder: {load_result.unexpected_keys[:5]}")
|
||||
else:
|
||||
# Fallback: restore from state_dict backup
|
||||
logger.info("Restoring base decoder from state_dict backup")
|
||||
load_result = self.model.decoder.load_state_dict(self._base_decoder, strict=False)
|
||||
if load_result.missing_keys:
|
||||
logger.warning(f"Missing keys when restoring decoder: {load_result.missing_keys[:5]}")
|
||||
if load_result.unexpected_keys:
|
||||
logger.warning(f"Unexpected keys when restoring decoder: {load_result.unexpected_keys[:5]}")
|
||||
|
||||
self.model.decoder = self.model.decoder.to(self.device).to(self.dtype)
|
||||
self.model.decoder.eval()
|
||||
|
||||
|
|
@ -99,6 +151,12 @@ def unload_lora(self) -> str:
|
|||
self._lora_active_adapter = None
|
||||
self._lora_scale_state = {}
|
||||
|
||||
# Log memory after unload
|
||||
if mem_before is not None and hasattr(self, '_memory_allocated'):
|
||||
mem_after = self._memory_allocated() / (1024**3)
|
||||
freed = mem_before - mem_after
|
||||
logger.info(f"VRAM after LoRA unload: {mem_after:.2f}GB (freed: {freed:.2f}GB)")
|
||||
|
||||
logger.info("LoRA unloaded, base decoder restored")
|
||||
return "✅ LoRA unloaded, using base model"
|
||||
except Exception as e:
|
||||
|
|
|
|||
166
acestep/core/generation/handler/memory_utils.py
Normal file
166
acestep/core/generation/handler/memory_utils.py
Normal file
|
|
@ -0,0 +1,166 @@
|
|||
"""Memory and VRAM helper methods for handler decomposition."""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from acestep.gpu_config import get_effective_free_vram_gb, get_global_gpu_config
|
||||
|
||||
|
||||
class MemoryUtilsMixin:
|
||||
"""Mixin containing memory sizing and VRAM guard helpers.
|
||||
|
||||
Depends on host members:
|
||||
- Attribute: ``device``.
|
||||
"""
|
||||
|
||||
def is_silence(self, audio: torch.Tensor) -> bool:
|
||||
"""Return True when audio is effectively silent."""
|
||||
return bool(torch.all(audio.abs() < 1e-6))
|
||||
|
||||
def _get_system_memory_gb(self) -> Optional[float]:
|
||||
"""Return total system RAM in GB when available."""
|
||||
try:
|
||||
page_size = os.sysconf("SC_PAGE_SIZE")
|
||||
page_count = os.sysconf("SC_PHYS_PAGES")
|
||||
if page_size and page_count:
|
||||
return (page_size * page_count) / (1024**3)
|
||||
except (ValueError, OSError, AttributeError):
|
||||
return None
|
||||
return None
|
||||
|
||||
def _get_effective_mps_memory_gb(self) -> Optional[float]:
|
||||
"""Best-effort MPS memory estimate (recommended max or system RAM)."""
|
||||
if hasattr(torch, "mps") and hasattr(torch.mps, "recommended_max_memory"):
|
||||
try:
|
||||
return torch.mps.recommended_max_memory() / (1024**3)
|
||||
except Exception:
|
||||
pass
|
||||
system_gb = self._get_system_memory_gb()
|
||||
if system_gb is None:
|
||||
return None
|
||||
return system_gb * 0.75
|
||||
|
||||
VAE_DECODE_MAX_CHUNK_SIZE = 512
|
||||
|
||||
def _get_auto_decode_chunk_size(self) -> int:
|
||||
"""Choose a conservative VAE decode chunk size based on available memory."""
|
||||
override = os.environ.get("ACESTEP_VAE_DECODE_CHUNK_SIZE")
|
||||
if override:
|
||||
try:
|
||||
value = int(override)
|
||||
if value > 0:
|
||||
return value
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
max_chunk = self.VAE_DECODE_MAX_CHUNK_SIZE
|
||||
|
||||
if self.device == "mps":
|
||||
mem_gb = self._get_effective_mps_memory_gb()
|
||||
if mem_gb is not None:
|
||||
if mem_gb >= 48:
|
||||
return min(1536, max_chunk)
|
||||
if mem_gb >= 24:
|
||||
return min(1024, max_chunk)
|
||||
return min(512, max_chunk)
|
||||
|
||||
if self.device == "cuda" or (isinstance(self.device, str) and self.device.startswith("cuda")):
|
||||
try:
|
||||
free_gb = get_effective_free_vram_gb()
|
||||
except Exception:
|
||||
free_gb = 0
|
||||
logger.debug(f"[_get_auto_decode_chunk_size] Effective free VRAM: {free_gb:.2f} GB")
|
||||
if free_gb >= 24.0:
|
||||
return min(512, max_chunk)
|
||||
if free_gb >= 16.0:
|
||||
return min(384, max_chunk)
|
||||
if free_gb >= 12.0:
|
||||
return min(256, max_chunk)
|
||||
return min(128, max_chunk)
|
||||
return min(256, max_chunk)
|
||||
|
||||
def _should_offload_wav_to_cpu(self) -> bool:
|
||||
"""Decide whether to offload decoded wavs to CPU for memory safety."""
|
||||
override = os.environ.get("ACESTEP_MPS_DECODE_OFFLOAD")
|
||||
if override:
|
||||
return override.lower() in ("1", "true", "yes")
|
||||
if self.device == "mps":
|
||||
mem_gb = self._get_effective_mps_memory_gb()
|
||||
if mem_gb is not None and mem_gb >= 32:
|
||||
return False
|
||||
return True
|
||||
if self.device == "cuda" or (isinstance(self.device, str) and self.device.startswith("cuda")):
|
||||
try:
|
||||
free_gb = get_effective_free_vram_gb()
|
||||
logger.debug(f"[_should_offload_wav_to_cpu] Effective free VRAM: {free_gb:.2f} GB")
|
||||
if free_gb >= 24.0:
|
||||
return False
|
||||
except Exception:
|
||||
pass
|
||||
return True
|
||||
|
||||
def _vram_guard_reduce_batch(
|
||||
self,
|
||||
batch_size: int,
|
||||
audio_duration: Optional[float] = None,
|
||||
use_lm: bool = False,
|
||||
) -> int:
|
||||
"""Auto-reduce batch_size when free VRAM is too tight."""
|
||||
if batch_size <= 1:
|
||||
return batch_size
|
||||
|
||||
device = self.device
|
||||
if device == "cpu" or device == "mps":
|
||||
return batch_size
|
||||
|
||||
if self.offload_to_cpu:
|
||||
gpu_config = get_global_gpu_config()
|
||||
if gpu_config is not None:
|
||||
tier_max = gpu_config.max_batch_size_with_lm
|
||||
if batch_size <= tier_max:
|
||||
logger.debug(
|
||||
f"[VRAM guard] offload_to_cpu=True, batch_size={batch_size} <= "
|
||||
f"tier limit {tier_max} — skipping dynamic VRAM check"
|
||||
)
|
||||
return batch_size
|
||||
|
||||
try:
|
||||
free_gb = get_effective_free_vram_gb()
|
||||
except Exception:
|
||||
return batch_size
|
||||
|
||||
duration_sec = float(audio_duration) if audio_duration and float(audio_duration) > 0 else 60.0
|
||||
per_sample_gb = 0.5 + max(0.0, 0.15 * (duration_sec - 60.0) / 60.0)
|
||||
if hasattr(self, "model") and self.model is not None:
|
||||
model_name = getattr(self, "config_path", "") or ""
|
||||
if "base" in model_name.lower():
|
||||
per_sample_gb *= 2.0
|
||||
|
||||
safety_margin_gb = 1.5
|
||||
available_for_batch = free_gb - safety_margin_gb
|
||||
if available_for_batch <= 0:
|
||||
logger.warning(f"[VRAM guard] Only {free_gb:.1f} GB free — reducing batch_size to 1")
|
||||
return 1
|
||||
|
||||
max_safe_batch = max(1, int(available_for_batch / per_sample_gb))
|
||||
if max_safe_batch < batch_size:
|
||||
logger.warning(
|
||||
f"[VRAM guard] Free VRAM {free_gb:.1f} GB can safely fit ~{max_safe_batch} samples "
|
||||
f"(requested {batch_size}). Reducing batch_size to {max_safe_batch}."
|
||||
)
|
||||
return max_safe_batch
|
||||
return batch_size
|
||||
|
||||
def _get_vae_dtype(self, device: Optional[str] = None) -> torch.dtype:
|
||||
"""Get VAE dtype based on target device and GPU tier."""
|
||||
target_device = device or self.device
|
||||
if target_device in ["cuda", "xpu"]:
|
||||
return torch.bfloat16
|
||||
if target_device == "mps":
|
||||
return torch.float16
|
||||
if target_device == "cpu":
|
||||
return torch.float32
|
||||
return self.dtype
|
||||
79
acestep/core/generation/handler/metadata_utils.py
Normal file
79
acestep/core/generation/handler/metadata_utils.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
"""Metadata formatting helpers for handler decomposition."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
|
||||
class MetadataMixin:
|
||||
"""Mixin containing metadata parsing and formatting helpers.
|
||||
|
||||
Depends on host members:
|
||||
- No cross-mixin runtime dependencies; pure formatting/parsing helpers.
|
||||
"""
|
||||
|
||||
def _create_default_meta(self) -> str:
|
||||
"""Create default metadata string."""
|
||||
return (
|
||||
"- bpm: N/A\n"
|
||||
"- timesignature: N/A\n"
|
||||
"- keyscale: N/A\n"
|
||||
"- duration: 30 seconds\n"
|
||||
)
|
||||
|
||||
def _dict_to_meta_string(self, meta_dict: Dict[str, Any]) -> str:
|
||||
"""Convert metadata dict to formatted string."""
|
||||
bpm = meta_dict.get("bpm", meta_dict.get("tempo", "N/A"))
|
||||
timesignature = meta_dict.get("timesignature", meta_dict.get("time_signature", "N/A"))
|
||||
keyscale = meta_dict.get("keyscale", meta_dict.get("key", meta_dict.get("scale", "N/A")))
|
||||
duration = meta_dict.get("duration", meta_dict.get("length", 30))
|
||||
|
||||
if isinstance(duration, (int, float)):
|
||||
duration = f"{int(duration)} seconds"
|
||||
elif not isinstance(duration, str):
|
||||
duration = "30 seconds"
|
||||
|
||||
return (
|
||||
f"- bpm: {bpm}\n"
|
||||
f"- timesignature: {timesignature}\n"
|
||||
f"- keyscale: {keyscale}\n"
|
||||
f"- duration: {duration}\n"
|
||||
)
|
||||
|
||||
def _parse_metas(self, metas: List[Union[str, Dict[str, Any]]]) -> List[str]:
|
||||
"""Parse and normalize metadata values with safe fallbacks."""
|
||||
parsed_metas = []
|
||||
for meta in metas:
|
||||
if meta is None:
|
||||
parsed_meta = self._create_default_meta()
|
||||
elif isinstance(meta, str):
|
||||
parsed_meta = meta
|
||||
elif isinstance(meta, dict):
|
||||
parsed_meta = self._dict_to_meta_string(meta)
|
||||
else:
|
||||
parsed_meta = self._create_default_meta()
|
||||
parsed_metas.append(parsed_meta)
|
||||
return parsed_metas
|
||||
|
||||
def prepare_metadata(
|
||||
self, bpm: Optional[Union[int, str]], key_scale: str, time_signature: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Build metadata dict for generation."""
|
||||
return self._build_metadata_dict(bpm, key_scale, time_signature)
|
||||
|
||||
def _build_metadata_dict(
|
||||
self,
|
||||
bpm: Optional[Union[int, str]],
|
||||
key_scale: str,
|
||||
time_signature: str,
|
||||
duration: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build metadata dictionary with defaults for missing fields."""
|
||||
metadata_dict: Dict[str, Any] = {}
|
||||
metadata_dict["bpm"] = bpm if bpm else "N/A"
|
||||
metadata_dict["keyscale"] = key_scale if key_scale.strip() else "N/A"
|
||||
if time_signature.strip() and time_signature != "N/A" and time_signature:
|
||||
metadata_dict["timesignature"] = time_signature
|
||||
else:
|
||||
metadata_dict["timesignature"] = "N/A"
|
||||
if duration is not None:
|
||||
metadata_dict["duration"] = f"{int(duration)} seconds"
|
||||
return metadata_dict
|
||||
151
acestep/core/generation/handler/padding_utils.py
Normal file
151
acestep/core/generation/handler/padding_utils.py
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
"""Padding helpers for handler batch preparation."""
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class PaddingMixin:
|
||||
"""Mixin containing repaint/lego padding helpers.
|
||||
|
||||
Depends on host members:
|
||||
- Method: ``create_target_wavs`` (provided by ``TaskUtilsMixin`` in this decomposition).
|
||||
"""
|
||||
|
||||
def prepare_padding_info(
|
||||
self,
|
||||
actual_batch_size,
|
||||
processed_src_audio,
|
||||
audio_duration,
|
||||
repainting_start,
|
||||
repainting_end,
|
||||
is_repaint_task,
|
||||
is_lego_task,
|
||||
is_cover_task,
|
||||
can_use_repainting,
|
||||
):
|
||||
"""Prepare padded target wavs and repaint coordinates for each batch item."""
|
||||
try:
|
||||
target_wavs_batch = []
|
||||
# Store padding info for each batch item to adjust repainting coordinates
|
||||
padding_info_batch = []
|
||||
for i in range(actual_batch_size):
|
||||
if processed_src_audio is not None:
|
||||
if is_cover_task:
|
||||
# Cover task: Use src_audio directly without padding
|
||||
batch_target_wavs = processed_src_audio
|
||||
padding_info_batch.append({"left_padding_duration": 0.0, "right_padding_duration": 0.0})
|
||||
elif is_repaint_task or is_lego_task:
|
||||
# Repaint/lego task: May need padding for outpainting
|
||||
src_audio_duration = processed_src_audio.shape[-1] / 48000.0
|
||||
|
||||
# Determine actual end time
|
||||
if repainting_end is None or repainting_end < 0:
|
||||
actual_end = src_audio_duration
|
||||
else:
|
||||
actual_end = repainting_end
|
||||
|
||||
left_padding_duration = max(0, -repainting_start) if repainting_start is not None else 0
|
||||
right_padding_duration = max(0, actual_end - src_audio_duration)
|
||||
|
||||
# Create padded audio
|
||||
left_padding_frames = int(left_padding_duration * 48000)
|
||||
right_padding_frames = int(right_padding_duration * 48000)
|
||||
|
||||
if left_padding_frames > 0 or right_padding_frames > 0:
|
||||
# Pad the src audio
|
||||
batch_target_wavs = torch.nn.functional.pad(
|
||||
processed_src_audio, (left_padding_frames, right_padding_frames), "constant", 0
|
||||
)
|
||||
else:
|
||||
batch_target_wavs = processed_src_audio
|
||||
|
||||
# Store padding info for coordinate adjustment
|
||||
padding_info_batch.append(
|
||||
{
|
||||
"left_padding_duration": left_padding_duration,
|
||||
"right_padding_duration": right_padding_duration,
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Other tasks: Use src_audio directly without padding
|
||||
batch_target_wavs = processed_src_audio
|
||||
padding_info_batch.append({"left_padding_duration": 0.0, "right_padding_duration": 0.0})
|
||||
else:
|
||||
padding_info_batch.append({"left_padding_duration": 0.0, "right_padding_duration": 0.0})
|
||||
if audio_duration is not None and float(audio_duration) > 0:
|
||||
batch_target_wavs = self.create_target_wavs(float(audio_duration))
|
||||
else:
|
||||
import random
|
||||
|
||||
random_duration = random.uniform(10.0, 120.0)
|
||||
batch_target_wavs = self.create_target_wavs(random_duration)
|
||||
target_wavs_batch.append(batch_target_wavs)
|
||||
|
||||
# Stack target_wavs into batch tensor
|
||||
# Ensure all tensors have the same shape by padding to max length
|
||||
max_frames = max(wav.shape[-1] for wav in target_wavs_batch)
|
||||
padded_target_wavs = []
|
||||
for wav in target_wavs_batch:
|
||||
if wav.shape[-1] < max_frames:
|
||||
pad_frames = max_frames - wav.shape[-1]
|
||||
padded_wav = torch.nn.functional.pad(wav, (0, pad_frames), "constant", 0)
|
||||
padded_target_wavs.append(padded_wav)
|
||||
else:
|
||||
padded_target_wavs.append(wav)
|
||||
|
||||
target_wavs_tensor = torch.stack(padded_target_wavs, dim=0) # [batch_size, 2, frames]
|
||||
|
||||
if can_use_repainting:
|
||||
# Repaint task: Set repainting parameters
|
||||
if repainting_start is None:
|
||||
repainting_start_batch = None
|
||||
elif isinstance(repainting_start, (int, float)):
|
||||
if processed_src_audio is not None:
|
||||
adjusted_start = repainting_start + padding_info_batch[0]["left_padding_duration"]
|
||||
repainting_start_batch = [adjusted_start] * actual_batch_size
|
||||
else:
|
||||
repainting_start_batch = [repainting_start] * actual_batch_size
|
||||
else:
|
||||
# List input - adjust each item
|
||||
repainting_start_batch = []
|
||||
for i in range(actual_batch_size):
|
||||
if processed_src_audio is not None:
|
||||
adjusted_start = repainting_start[i] + padding_info_batch[i]["left_padding_duration"]
|
||||
repainting_start_batch.append(adjusted_start)
|
||||
else:
|
||||
repainting_start_batch.append(repainting_start[i])
|
||||
|
||||
# Handle repainting_end - use src audio duration if not specified or negative
|
||||
if processed_src_audio is not None:
|
||||
# If src audio is provided, use its duration as default end
|
||||
src_audio_duration = processed_src_audio.shape[-1] / 48000.0
|
||||
if repainting_end is None or repainting_end < 0:
|
||||
# Use src audio duration (before padding), then adjust for padding
|
||||
adjusted_end = src_audio_duration + padding_info_batch[0]["left_padding_duration"]
|
||||
repainting_end_batch = [adjusted_end] * actual_batch_size
|
||||
else:
|
||||
# Adjust repainting_end to be relative to padded audio
|
||||
adjusted_end = repainting_end + padding_info_batch[0]["left_padding_duration"]
|
||||
repainting_end_batch = [adjusted_end] * actual_batch_size
|
||||
else:
|
||||
# No src audio - repainting doesn't make sense without it
|
||||
if repainting_end is None or repainting_end < 0:
|
||||
repainting_end_batch = None
|
||||
elif isinstance(repainting_end, (int, float)):
|
||||
repainting_end_batch = [repainting_end] * actual_batch_size
|
||||
else:
|
||||
# List input - adjust each item
|
||||
repainting_end_batch = []
|
||||
for i in range(actual_batch_size):
|
||||
repainting_end_batch.append(repainting_end[i])
|
||||
else:
|
||||
# All other tasks (cover, text2music, extract, complete): No repainting
|
||||
# Only repaint and lego tasks should have repainting parameters
|
||||
repainting_start_batch = None
|
||||
repainting_end_batch = None
|
||||
|
||||
return repainting_start_batch, repainting_end_batch, target_wavs_tensor
|
||||
except (TypeError, ValueError, RuntimeError, IndexError):
|
||||
logger.exception("[prepare_padding_info] Error preparing padding information")
|
||||
fallback = torch.stack([self.create_target_wavs(30.0) for _ in range(actual_batch_size)], dim=0)
|
||||
return None, None, fallback
|
||||
162
acestep/core/generation/handler/prompt_utils.py
Normal file
162
acestep/core/generation/handler/prompt_utils.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
"""Prompt and text-input helpers for handler decomposition."""
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from acestep.constants import DEFAULT_DIT_INSTRUCTION, SFT_GEN_PROMPT
|
||||
|
||||
|
||||
class PromptMixin:
|
||||
"""Mixin containing prompt formatting and text-encoder helpers.
|
||||
|
||||
Depends on host members:
|
||||
- Attributes: ``text_tokenizer``, ``text_encoder``, ``device``, ``dtype``.
|
||||
- Methods: ``_parse_metas`` (from ``MetadataMixin``), ``_load_model_context``
|
||||
(from ``InitServiceMixin``).
|
||||
"""
|
||||
|
||||
def _format_instruction(self, instruction: str) -> str:
|
||||
"""Ensure instruction ends with a colon."""
|
||||
if not instruction.endswith(":"):
|
||||
instruction = instruction + ":"
|
||||
return instruction
|
||||
|
||||
def _format_lyrics(self, lyrics: str, language: str) -> str:
|
||||
"""Format lyrics text with language header."""
|
||||
return f"# Languages\n{language}\n\n# Lyric\n{lyrics}<|endoftext|>"
|
||||
|
||||
def _pad_sequences(
|
||||
self, sequences: List[torch.Tensor], max_length: int, pad_value: int = 0
|
||||
) -> torch.Tensor:
|
||||
"""Pad sequence tensors to the same length."""
|
||||
return torch.stack(
|
||||
[
|
||||
torch.nn.functional.pad(seq, (0, max_length - len(seq)), "constant", pad_value)
|
||||
for seq in sequences
|
||||
]
|
||||
)
|
||||
|
||||
def extract_caption_from_sft_format(self, caption: str) -> str:
|
||||
"""Extract caption body from SFT-formatted prompt when present."""
|
||||
try:
|
||||
if "# Instruction" in caption and "# Caption" in caption:
|
||||
pattern = r"#\s*Caption\s*\n(.*?)(?:\n\s*#\s*Metas|$)"
|
||||
match = re.search(pattern, caption, re.DOTALL)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
return caption
|
||||
except (AttributeError, TypeError, re.error):
|
||||
logger.exception("[extract_caption_from_sft_format] Error extracting caption")
|
||||
return caption
|
||||
|
||||
def build_dit_inputs(
|
||||
self,
|
||||
task: str,
|
||||
instruction: Optional[str],
|
||||
caption: str,
|
||||
lyrics: str,
|
||||
metas: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
vocal_language: str = "en",
|
||||
) -> Tuple[str, str]:
|
||||
"""Build caption and lyric input text for DiT branches.
|
||||
|
||||
Args:
|
||||
task: Task name (currently informational; reserved for task-specific formatting).
|
||||
instruction: Instruction text; falls back to default when empty.
|
||||
caption: Caption fallback value.
|
||||
lyrics: Raw lyric text.
|
||||
metas: Optional metadata (string or dict) that may include caption/language.
|
||||
vocal_language: Fallback lyric language when not present in metadata.
|
||||
|
||||
Returns:
|
||||
Tuple of ``(caption_input, lyrics_input)`` for caption and lyric encoder branches.
|
||||
"""
|
||||
final_instruction = self._format_instruction(instruction or DEFAULT_DIT_INSTRUCTION)
|
||||
actual_caption = caption
|
||||
actual_language = vocal_language
|
||||
|
||||
if metas is not None:
|
||||
try:
|
||||
if isinstance(metas, str):
|
||||
parsed_metas = self._parse_metas([metas])
|
||||
meta_dict = parsed_metas[0] if parsed_metas and isinstance(parsed_metas[0], dict) else {}
|
||||
elif isinstance(metas, dict):
|
||||
meta_dict = metas
|
||||
else:
|
||||
meta_dict = {}
|
||||
except (TypeError, ValueError, KeyError, IndexError):
|
||||
logger.exception("[build_dit_inputs] Error parsing metas")
|
||||
meta_dict = {}
|
||||
if "caption" in meta_dict and meta_dict["caption"]:
|
||||
actual_caption = str(meta_dict["caption"])
|
||||
if "language" in meta_dict and meta_dict["language"]:
|
||||
actual_language = str(meta_dict["language"])
|
||||
|
||||
parsed_meta = self._parse_metas([metas])[0]
|
||||
caption_input = SFT_GEN_PROMPT.format(final_instruction, actual_caption, parsed_meta)
|
||||
lyrics_input = self._format_lyrics(lyrics, actual_language)
|
||||
return caption_input, lyrics_input
|
||||
|
||||
def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Get hidden states and attention mask from text encoder."""
|
||||
if self.text_tokenizer is None or self.text_encoder is None:
|
||||
raise ValueError("Text encoder not initialized")
|
||||
|
||||
try:
|
||||
with self._load_model_context("text_encoder"):
|
||||
text_inputs = self.text_tokenizer(
|
||||
text_prompt,
|
||||
padding="longest",
|
||||
truncation=True,
|
||||
max_length=256,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids.to(self.device)
|
||||
text_attention_mask = text_inputs.attention_mask.to(self.device).bool()
|
||||
|
||||
with torch.inference_mode():
|
||||
text_outputs = self.text_encoder(text_input_ids)
|
||||
if hasattr(text_outputs, "last_hidden_state"):
|
||||
text_hidden_states = text_outputs.last_hidden_state
|
||||
elif isinstance(text_outputs, tuple):
|
||||
text_hidden_states = text_outputs[0]
|
||||
else:
|
||||
text_hidden_states = text_outputs
|
||||
|
||||
text_hidden_states = text_hidden_states.to(self.dtype)
|
||||
return text_hidden_states, text_attention_mask
|
||||
except (AttributeError, RuntimeError, TypeError, ValueError):
|
||||
logger.exception("[_get_text_hidden_states] Failed to encode text prompt")
|
||||
raise
|
||||
|
||||
def _extract_caption_and_language(
|
||||
self,
|
||||
metas: List[Union[str, Dict[str, Any]]],
|
||||
captions: List[str],
|
||||
vocal_languages: List[str],
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
"""Extract caption/language values from metas with fallback values."""
|
||||
actual_captions = list(captions)
|
||||
actual_languages = list(vocal_languages)
|
||||
|
||||
for i, meta in enumerate(metas):
|
||||
if i >= len(actual_captions):
|
||||
break
|
||||
|
||||
meta_dict = None
|
||||
if isinstance(meta, str):
|
||||
parsed = self._parse_metas([meta])
|
||||
if parsed and isinstance(parsed[0], dict):
|
||||
meta_dict = parsed[0]
|
||||
elif isinstance(meta, dict):
|
||||
meta_dict = meta
|
||||
|
||||
if meta_dict:
|
||||
if "caption" in meta_dict and meta_dict["caption"]:
|
||||
actual_captions[i] = str(meta_dict["caption"])
|
||||
if "language" in meta_dict and meta_dict["language"]:
|
||||
actual_languages[i] = str(meta_dict["language"])
|
||||
return actual_captions, actual_languages
|
||||
123
acestep/core/generation/handler/task_utils.py
Normal file
123
acestep/core/generation/handler/task_utils.py
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
"""Task and seed helpers for handler decomposition."""
|
||||
|
||||
import random
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from acestep.constants import TASK_INSTRUCTIONS
|
||||
|
||||
|
||||
class TaskUtilsMixin:
|
||||
"""Mixin containing generation task and seed utility helpers.
|
||||
|
||||
Depends on host members:
|
||||
- No required cross-mixin attributes for seed/instruction helpers.
|
||||
"""
|
||||
|
||||
def prepare_seeds(
|
||||
self, actual_batch_size: int, seed, use_random_seed: bool
|
||||
) -> Tuple[List[int], str]:
|
||||
"""Prepare per-item seeds and UI seed string."""
|
||||
actual_seed_list: List[int] = []
|
||||
seed_value_for_ui = ""
|
||||
try:
|
||||
if use_random_seed:
|
||||
actual_seed_list = [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size)]
|
||||
seed_value_for_ui = ", ".join(str(s) for s in actual_seed_list)
|
||||
else:
|
||||
seed_list: List[int] = []
|
||||
if isinstance(seed, str):
|
||||
for s in [s.strip() for s in seed.split(",")]:
|
||||
if s == "-1" or s == "":
|
||||
seed_list.append(-1)
|
||||
else:
|
||||
try:
|
||||
seed_list.append(int(float(s)))
|
||||
except (ValueError, TypeError) as exc:
|
||||
logger.debug(f"[prepare_seeds] Failed to parse seed value '{s}': {exc}")
|
||||
seed_list.append(-1)
|
||||
elif seed is None or (isinstance(seed, (int, float)) and seed < 0):
|
||||
seed_list = [-1] * actual_batch_size
|
||||
elif isinstance(seed, (int, float)):
|
||||
seed_list = [int(seed)]
|
||||
else:
|
||||
seed_list = [-1] * actual_batch_size
|
||||
|
||||
has_single_non_negative_seed = len(seed_list) == 1 and seed_list[0] != -1
|
||||
for i in range(actual_batch_size):
|
||||
seed_val = seed_list[i] if i < len(seed_list) else -1
|
||||
if has_single_non_negative_seed and actual_batch_size > 1 and i > 0:
|
||||
actual_seed_list.append(random.randint(0, 2**32 - 1))
|
||||
elif seed_val == -1:
|
||||
actual_seed_list.append(random.randint(0, 2**32 - 1))
|
||||
else:
|
||||
actual_seed_list.append(int(seed_val))
|
||||
seed_value_for_ui = ", ".join(str(s) for s in actual_seed_list)
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
logger.exception("[prepare_seeds] Failed to prepare seeds")
|
||||
actual_seed_list = [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size)]
|
||||
seed_value_for_ui = ", ".join(str(s) for s in actual_seed_list)
|
||||
|
||||
return actual_seed_list, seed_value_for_ui
|
||||
|
||||
def generate_instruction(
|
||||
self,
|
||||
task_type: str,
|
||||
track_name: Optional[str] = None,
|
||||
complete_track_classes: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
"""Generate task instruction text from task type and track context."""
|
||||
if task_type == "text2music":
|
||||
return TASK_INSTRUCTIONS["text2music"]
|
||||
if task_type == "repaint":
|
||||
return TASK_INSTRUCTIONS["repaint"]
|
||||
if task_type == "cover":
|
||||
return TASK_INSTRUCTIONS["cover"]
|
||||
if task_type == "extract":
|
||||
return (
|
||||
TASK_INSTRUCTIONS["extract"].format(TRACK_NAME=track_name.upper())
|
||||
if track_name
|
||||
else TASK_INSTRUCTIONS["extract_default"]
|
||||
)
|
||||
if task_type == "lego":
|
||||
return (
|
||||
TASK_INSTRUCTIONS["lego"].format(TRACK_NAME=track_name.upper())
|
||||
if track_name
|
||||
else TASK_INSTRUCTIONS["lego_default"]
|
||||
)
|
||||
if task_type == "complete":
|
||||
if complete_track_classes and len(complete_track_classes) > 0:
|
||||
track_classes_upper = [t.upper() for t in complete_track_classes]
|
||||
return TASK_INSTRUCTIONS["complete"].format(
|
||||
TRACK_CLASSES=" | ".join(track_classes_upper)
|
||||
)
|
||||
return TASK_INSTRUCTIONS["complete_default"]
|
||||
return TASK_INSTRUCTIONS["text2music"]
|
||||
|
||||
def determine_task_type(self, task_type, audio_code_string):
|
||||
"""Compute task-mode booleans for downstream generation logic."""
|
||||
is_repaint_task = task_type == "repaint"
|
||||
is_lego_task = task_type == "lego"
|
||||
is_cover_task = task_type == "cover"
|
||||
|
||||
if isinstance(audio_code_string, list):
|
||||
has_codes = any((c or "").strip() for c in audio_code_string)
|
||||
else:
|
||||
has_codes = bool(audio_code_string and str(audio_code_string).strip())
|
||||
|
||||
if has_codes:
|
||||
is_cover_task = True
|
||||
can_use_repainting = is_repaint_task or is_lego_task
|
||||
return is_repaint_task, is_lego_task, is_cover_task, can_use_repainting
|
||||
|
||||
def create_target_wavs(self, duration_seconds: float) -> torch.Tensor:
|
||||
"""Create silent stereo target audio with safe duration handling."""
|
||||
try:
|
||||
duration_seconds = max(0.1, round(duration_seconds, 1))
|
||||
frames = int(duration_seconds * 48000)
|
||||
return torch.zeros(2, frames)
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
logger.exception("[create_target_wavs] Error creating target audio")
|
||||
return torch.zeros(2, 30 * 48000)
|
||||
|
|
@ -19,6 +19,93 @@ from acestep.constants import (
|
|||
DEBUG_GPU,
|
||||
)
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# CPU thread configuration
|
||||
# ----------------------------------------------------------------------
|
||||
# When running on CPU we want to make use of most of the available cores
|
||||
# but leave a couple free for the OS / other processes. The logic is:
|
||||
# * If the system has ≤ 2 logical CPUs, use all of them.
|
||||
# * Otherwise, use (cpu_count - 2) threads.
|
||||
# This mirrors the common “all‑but‑two” heuristic while guaranteeing at
|
||||
# least one thread.
|
||||
# The function is executed at import time so that any subsequent
|
||||
# torch operations respect the setting.
|
||||
import os
|
||||
import torch
|
||||
|
||||
def _configure_cpu_threads() -> None:
|
||||
"""Set torch's intra-op and inter-op thread counts based on available CPUs.
|
||||
|
||||
This function configures PyTorch to use most available CPU cores while
|
||||
leaving a couple free for the OS and other processes. The logic is:
|
||||
* If the system has ≤ 2 logical CPUs, use all of them.
|
||||
* Otherwise, use (cpu_count - 2) threads.
|
||||
|
||||
This mirrors the common "all-but-two" heuristic while guaranteeing at
|
||||
least one thread.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If torch.set_num_threads or torch.set_num_interop_threads
|
||||
fails (e.g., if called after threads have already been used).
|
||||
"""
|
||||
cpu_cnt = os.cpu_count() or 1
|
||||
# Ensure we never set a non-positive number of threads.
|
||||
threads = cpu_cnt - 2 if cpu_cnt > 2 else cpu_cnt
|
||||
threads = max(threads, 1)
|
||||
|
||||
try:
|
||||
torch.set_num_threads(threads)
|
||||
except RuntimeError as exc:
|
||||
raise RuntimeError(
|
||||
f"Failed to set torch intra-op threads to {threads}: {exc}"
|
||||
) from exc
|
||||
|
||||
try:
|
||||
torch.set_num_interop_threads(threads)
|
||||
except RuntimeError as exc:
|
||||
raise RuntimeError(
|
||||
f"Failed to set torch inter-op threads to {threads}: {exc}"
|
||||
) from exc
|
||||
|
||||
# Track whether CPU threads have been configured to avoid redundant calls.
|
||||
_cpu_threads_configured = False
|
||||
|
||||
|
||||
def configure_cpu_threads_if_needed() -> bool:
|
||||
"""Configure CPU threads if enabled via environment variable.
|
||||
|
||||
This function provides an opt-in mechanism for configuring PyTorch's
|
||||
thread counts. It only takes effect if the environment variable
|
||||
``ACESTEP_CONFIGURE_THREADS`` is set to a truthy value (e.g., "1", "true", "yes").
|
||||
|
||||
The configuration is applied at most once per process; subsequent calls
|
||||
are no-ops.
|
||||
|
||||
Returns:
|
||||
True if configuration was applied, False if skipped (either because
|
||||
the environment variable is not set or configuration was already done).
|
||||
|
||||
Raises:
|
||||
RuntimeError: If thread configuration fails (propagated from
|
||||
``_configure_cpu_threads``).
|
||||
"""
|
||||
global _cpu_threads_configured
|
||||
|
||||
if _cpu_threads_configured:
|
||||
return False
|
||||
|
||||
env_value = os.environ.get("ACESTEP_CONFIGURE_THREADS", "").strip().lower()
|
||||
if env_value not in ("1", "true", "yes", "on"):
|
||||
return False
|
||||
|
||||
_configure_cpu_threads()
|
||||
_cpu_threads_configured = True
|
||||
return True
|
||||
|
||||
|
||||
# Apply opt-in CPU thread configuration early so torch respects it.
|
||||
configure_cpu_threads_if_needed()
|
||||
|
||||
|
||||
def _normalize_mode(mode: str) -> str:
|
||||
return (mode or "").strip().upper()
|
||||
|
|
|
|||
|
|
@ -37,6 +37,22 @@ PYTORCH_CUDA_INSTALL_URL = "https://download.pytorch.org/whl/cu121"
|
|||
PYTORCH_ROCM_INSTALL_URL = "https://download.pytorch.org/whl/rocm6.0"
|
||||
|
||||
|
||||
def is_mps_platform() -> bool:
|
||||
"""Check if running on macOS with MPS (Apple Silicon) available.
|
||||
|
||||
This is the canonical check used across the codebase to apply
|
||||
Mac-specific configuration overrides (no compile, no quantization,
|
||||
mlx backend, no offload, etc.).
|
||||
"""
|
||||
if sys.platform != "darwin":
|
||||
return False
|
||||
try:
|
||||
import torch
|
||||
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Empirical VRAM measurements (GB) -- model weights only, bf16 precision
|
||||
# These values should be calibrated using scripts/profile_vram.py
|
||||
|
|
@ -95,7 +111,7 @@ class GPUConfig:
|
|||
recommended_lm_model: str # Recommended default LM model path (empty if LM not available)
|
||||
|
||||
# LM backend restriction
|
||||
# "all" = any backend, "pt_mlx_only" = only pt/mlx (no vllm), used for very low VRAM
|
||||
# "all" = any backend, "pt_mlx_only" = only pt/mlx (no vllm), used for MPS (vllm requires CUDA)
|
||||
lm_backend_restriction: str # "all" or "pt_mlx_only"
|
||||
recommended_backend: str # Recommended default backend: "vllm", "pt", or "mlx"
|
||||
|
||||
|
|
@ -126,8 +142,8 @@ GPU_TIER_CONFIGS = {
|
|||
"init_lm_default": False,
|
||||
"available_lm_models": [],
|
||||
"recommended_lm_model": "",
|
||||
"lm_backend_restriction": "pt_mlx_only", # vllm KV cache won't fit
|
||||
"recommended_backend": "pt",
|
||||
"lm_backend_restriction": "all",
|
||||
"recommended_backend": "vllm",
|
||||
"offload_to_cpu_default": True,
|
||||
"offload_dit_to_cpu_default": True,
|
||||
"quantization_default": True, # INT8 essential to fit DiT in ~4GB
|
||||
|
|
@ -145,8 +161,8 @@ GPU_TIER_CONFIGS = {
|
|||
"init_lm_default": False,
|
||||
"available_lm_models": [],
|
||||
"recommended_lm_model": "",
|
||||
"lm_backend_restriction": "pt_mlx_only",
|
||||
"recommended_backend": "pt",
|
||||
"lm_backend_restriction": "all",
|
||||
"recommended_backend": "vllm",
|
||||
"offload_to_cpu_default": True,
|
||||
"offload_dit_to_cpu_default": True,
|
||||
"quantization_default": True,
|
||||
|
|
@ -156,7 +172,7 @@ GPU_TIER_CONFIGS = {
|
|||
"tier3": { # 6-8GB
|
||||
# Offload mode. DiT(4.46) + context(0.5) ≈ 5.0GB.
|
||||
# ~1.5-3GB headroom allows LM 0.6B (1.2+0.6=1.8GB) and batch=2.
|
||||
# vllm KV cache is tight; pt backend is safer for 0.6B on this tier.
|
||||
# With CPU offload, DiT is offloaded before LM runs → vllm can use freed VRAM.
|
||||
"max_duration_with_lm": 480, # 8 minutes
|
||||
"max_duration_without_lm": 600, # 10 minutes (max supported)
|
||||
"max_batch_size_with_lm": 2,
|
||||
|
|
@ -164,8 +180,8 @@ GPU_TIER_CONFIGS = {
|
|||
"init_lm_default": True,
|
||||
"available_lm_models": ["acestep-5Hz-lm-0.6B"],
|
||||
"recommended_lm_model": "acestep-5Hz-lm-0.6B",
|
||||
"lm_backend_restriction": "pt_mlx_only", # vllm KV cache too greedy for <8GB
|
||||
"recommended_backend": "pt",
|
||||
"lm_backend_restriction": "all",
|
||||
"recommended_backend": "vllm",
|
||||
"offload_to_cpu_default": True,
|
||||
"offload_dit_to_cpu_default": True,
|
||||
"quantization_default": True,
|
||||
|
|
@ -511,6 +527,20 @@ def get_gpu_config(gpu_memory_gb: Optional[float] = None) -> GPUConfig:
|
|||
"""
|
||||
Get GPU configuration based on detected or provided GPU memory.
|
||||
|
||||
On macOS with MPS (Apple Silicon), several overrides are applied
|
||||
automatically regardless of the tier selected by memory size:
|
||||
|
||||
- ``compile_model_default = False`` — ``torch.compile`` is not supported
|
||||
on MPS and would error or silently fall back to eager mode.
|
||||
- ``quantization_default = False`` — torchao INT8 quantization is
|
||||
incompatible with MPS / macOS.
|
||||
- ``recommended_backend = "mlx"`` — MLX provides native Apple Silicon
|
||||
acceleration for the 5Hz LM; vllm requires CUDA.
|
||||
- ``lm_backend_restriction = "pt_mlx_only"`` — vllm cannot run on MPS.
|
||||
- ``offload_to_cpu_default = False`` — Apple Silicon uses unified memory;
|
||||
offloading to CPU provides no benefit and adds overhead.
|
||||
- ``offload_dit_to_cpu_default = False`` — same reason.
|
||||
|
||||
Args:
|
||||
gpu_memory_gb: GPU memory in GB. If None, will be auto-detected.
|
||||
|
||||
|
|
@ -523,6 +553,15 @@ def get_gpu_config(gpu_memory_gb: Optional[float] = None) -> GPUConfig:
|
|||
tier = get_gpu_tier(gpu_memory_gb)
|
||||
config = GPU_TIER_CONFIGS[tier]
|
||||
|
||||
# --- MPS (Apple Silicon) overrides ---
|
||||
_mps = is_mps_platform()
|
||||
if _mps:
|
||||
logger.info(
|
||||
f"macOS MPS detected ({gpu_memory_gb:.1f} GB unified memory, tier={tier}). "
|
||||
"Applying Apple Silicon optimizations: no compile, no quantization, "
|
||||
"mlx backend, no CPU offload."
|
||||
)
|
||||
|
||||
return GPUConfig(
|
||||
tier=tier,
|
||||
gpu_memory_gb=gpu_memory_gb,
|
||||
|
|
@ -533,12 +572,15 @@ def get_gpu_config(gpu_memory_gb: Optional[float] = None) -> GPUConfig:
|
|||
init_lm_default=config["init_lm_default"],
|
||||
available_lm_models=config["available_lm_models"],
|
||||
recommended_lm_model=config.get("recommended_lm_model", ""),
|
||||
lm_backend_restriction=config.get("lm_backend_restriction", "all"),
|
||||
recommended_backend=config.get("recommended_backend", "vllm"),
|
||||
offload_to_cpu_default=config.get("offload_to_cpu_default", True),
|
||||
offload_dit_to_cpu_default=config.get("offload_dit_to_cpu_default", True),
|
||||
quantization_default=config.get("quantization_default", True),
|
||||
compile_model_default=config.get("compile_model_default", True),
|
||||
# MPS: vllm requires CUDA, restrict to pt/mlx; prefer mlx for native acceleration
|
||||
lm_backend_restriction="pt_mlx_only" if _mps else config.get("lm_backend_restriction", "all"),
|
||||
recommended_backend="mlx" if _mps else config.get("recommended_backend", "vllm"),
|
||||
# MPS: unified memory — offloading to CPU is pointless overhead
|
||||
offload_to_cpu_default=False if _mps else config.get("offload_to_cpu_default", True),
|
||||
offload_dit_to_cpu_default=False if _mps else config.get("offload_dit_to_cpu_default", True),
|
||||
# MPS: torch.compile and torchao quantization are not supported
|
||||
quantization_default=False if _mps else config.get("quantization_default", True),
|
||||
compile_model_default=False if _mps else config.get("compile_model_default", True),
|
||||
lm_memory_gb=config["lm_memory_gb"],
|
||||
)
|
||||
|
||||
|
|
@ -1056,6 +1098,97 @@ def print_gpu_config_info(gpu_config: GPUConfig):
|
|||
logger.info(f" - Available LM Models: {gpu_config.available_lm_models or 'None'}")
|
||||
|
||||
|
||||
# Human-readable tier labels for UI display
|
||||
GPU_TIER_LABELS = {
|
||||
"tier1": "tier1 (≤4GB)",
|
||||
"tier2": "tier2 (4-6GB)",
|
||||
"tier3": "tier3 (6-8GB)",
|
||||
"tier4": "tier4 (8-12GB)",
|
||||
"tier5": "tier5 (12-16GB)",
|
||||
"tier6a": "tier6a (16-20GB)",
|
||||
"tier6b": "tier6b (20-24GB)",
|
||||
"unlimited": "unlimited (≥24GB)",
|
||||
}
|
||||
|
||||
# Ordered list of tier keys for dropdown
|
||||
GPU_TIER_CHOICES = list(GPU_TIER_LABELS.items()) # [(value, label), ...]
|
||||
|
||||
|
||||
def get_gpu_device_name() -> str:
|
||||
"""
|
||||
Get the GPU device name string.
|
||||
|
||||
Returns:
|
||||
Human-readable GPU name, e.g. "NVIDIA GeForce RTX 4060 Ti",
|
||||
"Apple M2 Pro (MPS)", "CPU only", etc.
|
||||
"""
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
return torch.cuda.get_device_name(0)
|
||||
elif hasattr(torch, 'xpu') and torch.xpu.is_available():
|
||||
props = torch.xpu.get_device_properties(0)
|
||||
return getattr(props, 'name', 'Intel XPU')
|
||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
# MPS doesn't expose a device name; use platform info
|
||||
try:
|
||||
import platform
|
||||
chip = platform.processor() or "Apple Silicon"
|
||||
return f"{chip} (MPS)"
|
||||
except Exception:
|
||||
return "Apple Silicon (MPS)"
|
||||
else:
|
||||
return "CPU only"
|
||||
except ImportError:
|
||||
return "Unknown (PyTorch not available)"
|
||||
|
||||
|
||||
def get_gpu_config_for_tier(tier: str) -> GPUConfig:
|
||||
"""
|
||||
Create a GPUConfig for a specific tier, applying platform overrides.
|
||||
|
||||
This is used when the user manually selects a different tier in the UI.
|
||||
The actual gpu_memory_gb is preserved from the real hardware detection,
|
||||
but all tier-based settings come from the selected tier's config.
|
||||
|
||||
Args:
|
||||
tier: Tier key, e.g. "tier3", "tier6a", "unlimited"
|
||||
|
||||
Returns:
|
||||
GPUConfig with the selected tier's settings
|
||||
"""
|
||||
if tier not in GPU_TIER_CONFIGS:
|
||||
logger.warning(f"Unknown tier '{tier}', falling back to auto-detected config")
|
||||
return get_gpu_config()
|
||||
|
||||
# Keep the real GPU memory for informational purposes
|
||||
real_gpu_memory = get_gpu_memory_gb()
|
||||
config = GPU_TIER_CONFIGS[tier]
|
||||
|
||||
_mps = is_mps_platform()
|
||||
if _mps:
|
||||
logger.info(f"Manual tier override to {tier} on macOS MPS — applying Apple Silicon overrides")
|
||||
|
||||
return GPUConfig(
|
||||
tier=tier,
|
||||
gpu_memory_gb=real_gpu_memory,
|
||||
max_duration_with_lm=config["max_duration_with_lm"],
|
||||
max_duration_without_lm=config["max_duration_without_lm"],
|
||||
max_batch_size_with_lm=config["max_batch_size_with_lm"],
|
||||
max_batch_size_without_lm=config["max_batch_size_without_lm"],
|
||||
init_lm_default=config["init_lm_default"],
|
||||
available_lm_models=config["available_lm_models"],
|
||||
recommended_lm_model=config.get("recommended_lm_model", ""),
|
||||
lm_backend_restriction="pt_mlx_only" if _mps else config.get("lm_backend_restriction", "all"),
|
||||
recommended_backend="mlx" if _mps else config.get("recommended_backend", "vllm"),
|
||||
offload_to_cpu_default=False if _mps else config.get("offload_to_cpu_default", True),
|
||||
offload_dit_to_cpu_default=False if _mps else config.get("offload_dit_to_cpu_default", True),
|
||||
quantization_default=False if _mps else config.get("quantization_default", True),
|
||||
compile_model_default=False if _mps else config.get("compile_model_default", True),
|
||||
lm_memory_gb=config["lm_memory_gb"],
|
||||
)
|
||||
|
||||
|
||||
# Global GPU config instance (initialized lazily)
|
||||
_global_gpu_config: Optional[GPUConfig] = None
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from typing import Any, Dict, List, Optional
|
|||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request, Depends, Header
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
# Global results directory inside project root
|
||||
|
|
@ -467,10 +468,30 @@ async def release_task(request: Request, authorization: Optional[str] = Header(N
|
|||
lm_negative_prompt=get_param("lm_negative_prompt", default="NO USER INPUT") or "NO USER INPUT",
|
||||
)
|
||||
|
||||
# Resolve seed(s) into List[int] for GenerationConfig.seeds
|
||||
use_random_seed = get_param("use_random_seed", default=True)
|
||||
resolved_seeds = None
|
||||
if not use_random_seed:
|
||||
raw_seed = get_param("seed", default=-1)
|
||||
if isinstance(raw_seed, str) and raw_seed.strip():
|
||||
resolved_seeds = []
|
||||
for s in raw_seed.split(","):
|
||||
s = s.strip()
|
||||
if s and s != "-1":
|
||||
try:
|
||||
resolved_seeds.append(int(float(s)))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
if not resolved_seeds:
|
||||
resolved_seeds = None
|
||||
elif isinstance(raw_seed, (int, float)) and int(raw_seed) >= 0:
|
||||
resolved_seeds = [int(raw_seed)]
|
||||
|
||||
config = GenerationConfig(
|
||||
batch_size=get_param("batch_size", default=2),
|
||||
use_random_seed=get_param("use_random_seed", default=True),
|
||||
audio_format=get_param("audio_format", default="mp3"),
|
||||
use_random_seed=use_random_seed,
|
||||
seeds=resolved_seeds,
|
||||
audio_format=get_param("audio_format", default="flac"),
|
||||
)
|
||||
|
||||
# Get output directory
|
||||
|
|
@ -512,6 +533,38 @@ async def release_task(request: Request, authorization: Optional[str] = Header(N
|
|||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# Origins that are expected to call the API:
|
||||
# - "null" → studio.html opened via file:// protocol
|
||||
# - http://localhost:* → local dev servers / Gradio UI
|
||||
# - http://127.0.0.1:* → same, numeric form
|
||||
_CORS_KWARGS = dict(
|
||||
allow_origins=["null", "http://localhost", "http://127.0.0.1"],
|
||||
allow_origin_regex=r"^https?://(localhost|127\.0\.0\.1)(:\d+)?$",
|
||||
allow_methods=["GET", "POST", "OPTIONS"],
|
||||
allow_headers=["Content-Type", "Authorization"],
|
||||
)
|
||||
|
||||
|
||||
def _add_cors_middleware(app):
|
||||
"""Add CORS middleware so browser-based frontends (e.g. studio.html via file://) can call the API."""
|
||||
app.add_middleware(CORSMiddleware, **_CORS_KWARGS)
|
||||
|
||||
|
||||
def _add_cors_middleware_post_launch(app):
|
||||
"""Wrap an already-started app's middleware stack with CORS.
|
||||
|
||||
``add_middleware`` raises after Starlette has started, so we patch the
|
||||
compiled middleware stack directly instead.
|
||||
"""
|
||||
from starlette.middleware.cors import CORSMiddleware as _CORSImpl
|
||||
|
||||
if app.middleware_stack is not None:
|
||||
app.middleware_stack = _CORSImpl(app=app.middleware_stack, **_CORS_KWARGS)
|
||||
else:
|
||||
# App hasn't built its stack yet – safe to use the normal path
|
||||
_add_cors_middleware(app)
|
||||
|
||||
|
||||
def setup_api_routes_to_app(app, dit_handler, llm_handler, api_key: Optional[str] = None):
|
||||
"""
|
||||
Mount API routes to a FastAPI application (for use with gr.mount_gradio_app)
|
||||
|
|
@ -523,6 +576,7 @@ def setup_api_routes_to_app(app, dit_handler, llm_handler, api_key: Optional[str
|
|||
api_key: Optional API key for authentication
|
||||
"""
|
||||
set_api_key(api_key)
|
||||
_add_cors_middleware(app)
|
||||
app.state.dit_handler = dit_handler
|
||||
app.state.llm_handler = llm_handler
|
||||
app.include_router(router)
|
||||
|
|
@ -540,6 +594,7 @@ def setup_api_routes(demo, dit_handler, llm_handler, api_key: Optional[str] = No
|
|||
"""
|
||||
set_api_key(api_key)
|
||||
app = demo.app
|
||||
_add_cors_middleware_post_launch(app)
|
||||
app.state.dit_handler = dit_handler
|
||||
app.state.llm_handler = llm_handler
|
||||
app.include_router(router)
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|||
|
||||
generation_section["config_path"].change(
|
||||
fn=gen_h.update_model_type_settings,
|
||||
inputs=[generation_section["config_path"]],
|
||||
inputs=[generation_section["config_path"], generation_section["generation_mode"]],
|
||||
outputs=[
|
||||
generation_section["inference_steps"],
|
||||
generation_section["guidance_scale"],
|
||||
|
|
@ -40,6 +40,25 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|||
generation_section["cfg_interval_start"],
|
||||
generation_section["cfg_interval_end"],
|
||||
generation_section["task_type"],
|
||||
generation_section["generation_mode"],
|
||||
]
|
||||
)
|
||||
|
||||
# ========== Tier Override ==========
|
||||
generation_section["tier_dropdown"].change(
|
||||
fn=lambda tier: gen_h.on_tier_change(tier, llm_handler),
|
||||
inputs=[generation_section["tier_dropdown"]],
|
||||
outputs=[
|
||||
generation_section["offload_to_cpu_checkbox"],
|
||||
generation_section["offload_dit_to_cpu_checkbox"],
|
||||
generation_section["compile_model_checkbox"],
|
||||
generation_section["quantization_checkbox"],
|
||||
generation_section["backend_dropdown"],
|
||||
generation_section["lm_model_path"],
|
||||
generation_section["init_llm_checkbox"],
|
||||
generation_section["batch_size_input"],
|
||||
generation_section["audio_duration"],
|
||||
generation_section["gpu_info_display"],
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -57,6 +76,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|||
generation_section["offload_dit_to_cpu_checkbox"],
|
||||
generation_section["compile_model_checkbox"],
|
||||
generation_section["quantization_checkbox"],
|
||||
generation_section["mlx_dit_checkbox"],
|
||||
generation_section["generation_mode"], # preserve current mode across init
|
||||
],
|
||||
outputs=[
|
||||
generation_section["init_status"],
|
||||
|
|
@ -70,9 +91,12 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|||
generation_section["cfg_interval_start"],
|
||||
generation_section["cfg_interval_end"],
|
||||
generation_section["task_type"],
|
||||
generation_section["generation_mode"],
|
||||
# GPU-config-aware limits (updated after initialization)
|
||||
generation_section["audio_duration"],
|
||||
generation_section["batch_size_input"],
|
||||
# Think checkbox: enable if LLM initialized
|
||||
generation_section["think_checkbox"],
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -115,24 +139,6 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|||
outputs=[generation_section["lm_negative_prompt"]]
|
||||
)
|
||||
|
||||
generation_section["init_llm_checkbox"].change(
|
||||
fn=gen_h.update_audio_cover_strength_visibility,
|
||||
inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"], generation_section["reference_audio"]],
|
||||
outputs=[generation_section["audio_cover_strength"]]
|
||||
)
|
||||
|
||||
generation_section["task_type"].change(
|
||||
fn=gen_h.update_audio_cover_strength_visibility,
|
||||
inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"], generation_section["reference_audio"]],
|
||||
outputs=[generation_section["audio_cover_strength"]]
|
||||
)
|
||||
|
||||
generation_section["reference_audio"].change(
|
||||
fn=gen_h.update_audio_cover_strength_visibility,
|
||||
inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"], generation_section["reference_audio"]],
|
||||
outputs=[generation_section["audio_cover_strength"]]
|
||||
)
|
||||
|
||||
generation_section["batch_size_input"].change(
|
||||
fn=gen_h.update_audio_components_visibility,
|
||||
inputs=[generation_section["batch_size_input"]],
|
||||
|
|
@ -149,13 +155,34 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|||
]
|
||||
)
|
||||
|
||||
# ========== Audio Conversion ==========
|
||||
# ========== Audio Conversion (LM Codes Hints accordion in Custom mode) ==========
|
||||
generation_section["convert_src_to_codes_btn"].click(
|
||||
fn=lambda src: gen_h.convert_src_audio_to_codes_wrapper(dit_handler, src),
|
||||
inputs=[generation_section["src_audio"]],
|
||||
inputs=[generation_section["lm_codes_audio_upload"]],
|
||||
outputs=[generation_section["text2music_audio_code_string"]]
|
||||
)
|
||||
|
||||
# ========== Analyze Source Audio (Remix/Repaint: convert to codes + transcribe) ==========
|
||||
generation_section["analyze_btn"].click(
|
||||
fn=lambda src, debug: gen_h.analyze_src_audio(dit_handler, llm_handler, src, debug),
|
||||
inputs=[
|
||||
generation_section["src_audio"],
|
||||
generation_section["constrained_decoding_debug"],
|
||||
],
|
||||
outputs=[
|
||||
generation_section["text2music_audio_code_string"],
|
||||
results_section["status_output"],
|
||||
generation_section["captions"],
|
||||
generation_section["lyrics"],
|
||||
generation_section["bpm"],
|
||||
generation_section["audio_duration"],
|
||||
generation_section["key_scale"],
|
||||
generation_section["vocal_language"],
|
||||
generation_section["time_signature"],
|
||||
results_section["is_format_caption_state"],
|
||||
]
|
||||
)
|
||||
|
||||
# ========== Instruction UI Updates ==========
|
||||
for trigger in [generation_section["task_type"], generation_section["track_name"], generation_section["complete_track_classes"], generation_section["reference_audio"]]:
|
||||
trigger.change(
|
||||
|
|
@ -164,7 +191,6 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|||
generation_section["task_type"],
|
||||
generation_section["track_name"],
|
||||
generation_section["complete_track_classes"],
|
||||
generation_section["text2music_audio_code_string"],
|
||||
generation_section["init_llm_checkbox"],
|
||||
generation_section["reference_audio"],
|
||||
],
|
||||
|
|
@ -172,9 +198,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|||
generation_section["instruction_display_gen"],
|
||||
generation_section["track_name"],
|
||||
generation_section["complete_track_classes"],
|
||||
generation_section["audio_cover_strength"],
|
||||
generation_section["repainting_group"],
|
||||
generation_section["text2music_audio_codes_group"],
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -233,25 +257,23 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|||
outputs=[results_section["is_format_caption_state"]]
|
||||
)
|
||||
|
||||
# ========== Audio Uploads Accordion ==========
|
||||
for trigger in [generation_section["reference_audio"], generation_section["src_audio"]]:
|
||||
trigger.change(
|
||||
fn=gen_h.update_audio_uploads_accordion,
|
||||
inputs=[generation_section["reference_audio"], generation_section["src_audio"]],
|
||||
outputs=[generation_section["audio_uploads_accordion"]]
|
||||
)
|
||||
|
||||
# ========== Instrumental Checkbox ==========
|
||||
generation_section["instrumental_checkbox"].change(
|
||||
fn=gen_h.handle_instrumental_checkbox,
|
||||
inputs=[generation_section["instrumental_checkbox"], generation_section["lyrics"]],
|
||||
outputs=[generation_section["lyrics"]]
|
||||
inputs=[
|
||||
generation_section["instrumental_checkbox"],
|
||||
generation_section["lyrics"],
|
||||
generation_section["lyrics_before_instrumental"],
|
||||
],
|
||||
outputs=[
|
||||
generation_section["lyrics"],
|
||||
generation_section["lyrics_before_instrumental"],
|
||||
]
|
||||
)
|
||||
|
||||
# ========== Format Button ==========
|
||||
# Note: cfg_scale and negative_prompt are not supported in format mode
|
||||
generation_section["format_btn"].click(
|
||||
fn=lambda caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug: gen_h.handle_format_sample(
|
||||
# ========== Format Caption Button ==========
|
||||
generation_section["format_caption_btn"].click(
|
||||
fn=lambda caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug: gen_h.handle_format_caption(
|
||||
llm_handler, caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug
|
||||
),
|
||||
inputs=[
|
||||
|
|
@ -268,6 +290,34 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|||
],
|
||||
outputs=[
|
||||
generation_section["captions"],
|
||||
generation_section["bpm"],
|
||||
generation_section["audio_duration"],
|
||||
generation_section["key_scale"],
|
||||
generation_section["vocal_language"],
|
||||
generation_section["time_signature"],
|
||||
results_section["is_format_caption_state"],
|
||||
results_section["status_output"],
|
||||
]
|
||||
)
|
||||
|
||||
# ========== Format Lyrics Button ==========
|
||||
generation_section["format_lyrics_btn"].click(
|
||||
fn=lambda caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug: gen_h.handle_format_lyrics(
|
||||
llm_handler, caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug
|
||||
),
|
||||
inputs=[
|
||||
generation_section["captions"],
|
||||
generation_section["lyrics"],
|
||||
generation_section["bpm"],
|
||||
generation_section["audio_duration"],
|
||||
generation_section["key_scale"],
|
||||
generation_section["time_signature"],
|
||||
generation_section["lm_temperature"],
|
||||
generation_section["lm_top_k"],
|
||||
generation_section["lm_top_p"],
|
||||
generation_section["constrained_decoding_debug"],
|
||||
],
|
||||
outputs=[
|
||||
generation_section["lyrics"],
|
||||
generation_section["bpm"],
|
||||
generation_section["audio_duration"],
|
||||
|
|
@ -279,17 +329,30 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|||
]
|
||||
)
|
||||
|
||||
# ========== Simple/Custom Mode Toggle ==========
|
||||
# ========== Generation Mode Change ==========
|
||||
generation_section["generation_mode"].change(
|
||||
fn=gen_h.handle_generation_mode_change,
|
||||
fn=lambda mode: gen_h.handle_generation_mode_change(mode, llm_handler),
|
||||
inputs=[generation_section["generation_mode"]],
|
||||
outputs=[
|
||||
generation_section["simple_mode_group"],
|
||||
generation_section["caption_accordion"],
|
||||
generation_section["lyrics_accordion"],
|
||||
generation_section["custom_mode_group"],
|
||||
generation_section["generate_btn"],
|
||||
generation_section["simple_sample_created"],
|
||||
generation_section["optional_params_accordion"],
|
||||
generation_section["task_type"],
|
||||
generation_section["src_audio_row"],
|
||||
generation_section["repainting_group"],
|
||||
generation_section["text2music_audio_codes_group"],
|
||||
generation_section["track_name"],
|
||||
generation_section["complete_track_classes"],
|
||||
generation_section["generate_btn_row"],
|
||||
generation_section["generation_mode"],
|
||||
generation_section["results_wrapper"],
|
||||
generation_section["think_checkbox"],
|
||||
generation_section["load_file_col"],
|
||||
generation_section["load_file"],
|
||||
generation_section["audio_cover_strength"],
|
||||
generation_section["cover_noise_strength"],
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -337,13 +400,12 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|||
generation_section["simple_vocal_language"],
|
||||
generation_section["time_signature"],
|
||||
generation_section["instrumental_checkbox"],
|
||||
generation_section["caption_accordion"],
|
||||
generation_section["lyrics_accordion"],
|
||||
generation_section["generate_btn"],
|
||||
generation_section["simple_sample_created"],
|
||||
generation_section["think_checkbox"],
|
||||
results_section["is_format_caption_state"],
|
||||
results_section["status_output"],
|
||||
generation_section["generation_mode"],
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -381,6 +443,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|||
generation_section["use_cot_caption"],
|
||||
generation_section["use_cot_language"],
|
||||
generation_section["audio_cover_strength"],
|
||||
generation_section["cover_noise_strength"],
|
||||
generation_section["think_checkbox"],
|
||||
generation_section["text2music_audio_code_string"],
|
||||
generation_section["repainting_start"],
|
||||
|
|
@ -470,26 +533,62 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|||
],
|
||||
js=download_existing_js # Run the above JS
|
||||
)
|
||||
# ========== Send to SRC Handlers ==========
|
||||
# ========== Send to Remix / Repaint Handlers ==========
|
||||
# Mode-UI outputs shared with generation_mode.change — applied atomically
|
||||
# so we don't rely on a chained .change() event for visibility/label updates.
|
||||
_mode_ui_outputs = [
|
||||
generation_section["simple_mode_group"],
|
||||
generation_section["custom_mode_group"],
|
||||
generation_section["generate_btn"],
|
||||
generation_section["simple_sample_created"],
|
||||
generation_section["optional_params_accordion"],
|
||||
generation_section["task_type"],
|
||||
generation_section["src_audio_row"],
|
||||
generation_section["repainting_group"],
|
||||
generation_section["text2music_audio_codes_group"],
|
||||
generation_section["track_name"],
|
||||
generation_section["complete_track_classes"],
|
||||
generation_section["generate_btn_row"],
|
||||
generation_section["generation_mode"],
|
||||
generation_section["results_wrapper"],
|
||||
generation_section["think_checkbox"],
|
||||
generation_section["load_file_col"],
|
||||
generation_section["load_file"],
|
||||
generation_section["audio_cover_strength"],
|
||||
generation_section["cover_noise_strength"],
|
||||
]
|
||||
for btn_idx in range(1, 9):
|
||||
results_section[f"send_to_src_btn_{btn_idx}"].click(
|
||||
fn=res_h.send_audio_to_src_with_metadata,
|
||||
results_section[f"send_to_remix_btn_{btn_idx}"].click(
|
||||
fn=lambda audio, lm, ly, cap: res_h.send_audio_to_remix(
|
||||
audio, lm, ly, cap, llm_handler),
|
||||
inputs=[
|
||||
results_section[f"generated_audio_{btn_idx}"],
|
||||
results_section["lm_metadata_state"]
|
||||
results_section["lm_metadata_state"],
|
||||
generation_section["lyrics"],
|
||||
generation_section["captions"],
|
||||
],
|
||||
outputs=[
|
||||
generation_section["src_audio"],
|
||||
generation_section["bpm"],
|
||||
generation_section["captions"],
|
||||
generation_section["generation_mode"],
|
||||
generation_section["lyrics"],
|
||||
generation_section["audio_duration"],
|
||||
generation_section["key_scale"],
|
||||
generation_section["vocal_language"],
|
||||
generation_section["time_signature"],
|
||||
results_section["is_format_caption_state"],
|
||||
generation_section["audio_uploads_accordion"]
|
||||
]
|
||||
generation_section["captions"],
|
||||
] + _mode_ui_outputs,
|
||||
)
|
||||
results_section[f"send_to_repaint_btn_{btn_idx}"].click(
|
||||
fn=lambda audio, lm, ly, cap: res_h.send_audio_to_repaint(
|
||||
audio, lm, ly, cap, llm_handler),
|
||||
inputs=[
|
||||
results_section[f"generated_audio_{btn_idx}"],
|
||||
results_section["lm_metadata_state"],
|
||||
generation_section["lyrics"],
|
||||
generation_section["captions"],
|
||||
],
|
||||
outputs=[
|
||||
generation_section["src_audio"],
|
||||
generation_section["generation_mode"],
|
||||
generation_section["lyrics"],
|
||||
generation_section["captions"],
|
||||
] + _mode_ui_outputs,
|
||||
)
|
||||
|
||||
# ========== Score Calculation Handlers ==========
|
||||
|
|
@ -539,6 +638,25 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|||
]
|
||||
)
|
||||
|
||||
# ========== Convert To Codes Handlers ==========
|
||||
for btn_idx in range(1, 9):
|
||||
results_section[f"convert_to_codes_btn_{btn_idx}"].click(
|
||||
fn=lambda audio: res_h.convert_result_audio_to_codes(dit_handler, audio),
|
||||
inputs=[results_section[f"generated_audio_{btn_idx}"]],
|
||||
outputs=[
|
||||
results_section[f"codes_display_{btn_idx}"],
|
||||
results_section[f"details_accordion_{btn_idx}"],
|
||||
]
|
||||
)
|
||||
|
||||
# ========== Save LRC Handlers ==========
|
||||
for btn_idx in range(1, 9):
|
||||
results_section[f"save_lrc_btn_{btn_idx}"].click(
|
||||
fn=res_h.save_lrc_to_file,
|
||||
inputs=[results_section[f"lrc_display_{btn_idx}"]],
|
||||
outputs=[results_section[f"lrc_download_file_{btn_idx}"]]
|
||||
)
|
||||
|
||||
def generation_wrapper(*args):
|
||||
yield from res_h.generate_with_batch_management(dit_handler, llm_handler, *args)
|
||||
# ========== Generation Handler ==========
|
||||
|
|
@ -577,6 +695,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|||
generation_section["repainting_end"],
|
||||
generation_section["instruction_display_gen"],
|
||||
generation_section["audio_cover_strength"],
|
||||
generation_section["cover_noise_strength"],
|
||||
generation_section["task_type"],
|
||||
generation_section["use_adg"],
|
||||
generation_section["cfg_interval_start"],
|
||||
|
|
@ -603,6 +722,10 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|||
generation_section["lm_batch_chunk_size"],
|
||||
generation_section["track_name"],
|
||||
generation_section["complete_track_classes"],
|
||||
generation_section["enable_normalization"],
|
||||
generation_section["normalization_db"],
|
||||
generation_section["latent_shift"],
|
||||
generation_section["latent_rescale"],
|
||||
generation_section["autogen_checkbox"],
|
||||
results_section["current_batch_index"],
|
||||
results_section["total_batches"],
|
||||
|
|
@ -765,6 +888,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|||
generation_section["repainting_end"],
|
||||
generation_section["instruction_display_gen"],
|
||||
generation_section["audio_cover_strength"],
|
||||
generation_section["cover_noise_strength"],
|
||||
generation_section["task_type"],
|
||||
generation_section["use_adg"],
|
||||
generation_section["cfg_interval_start"],
|
||||
|
|
@ -790,6 +914,10 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|||
generation_section["lm_batch_chunk_size"],
|
||||
generation_section["track_name"],
|
||||
generation_section["complete_track_classes"],
|
||||
generation_section["enable_normalization"],
|
||||
generation_section["normalization_db"],
|
||||
generation_section["latent_shift"],
|
||||
generation_section["latent_rescale"],
|
||||
],
|
||||
outputs=[results_section["generation_params_state"]]
|
||||
).then(
|
||||
|
|
@ -897,6 +1025,10 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|||
generation_section["allow_lm_batch"],
|
||||
generation_section["track_name"],
|
||||
generation_section["complete_track_classes"],
|
||||
generation_section["enable_normalization"],
|
||||
generation_section["normalization_db"],
|
||||
generation_section["latent_shift"],
|
||||
generation_section["latent_rescale"],
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -1178,11 +1310,12 @@ def setup_training_event_handlers(demo, dit_handler, llm_handler, training_secti
|
|||
|
||||
# Preprocess dataset to tensor files
|
||||
training_section["preprocess_btn"].click(
|
||||
fn=lambda output_dir, state: train_h.preprocess_dataset(
|
||||
output_dir, dit_handler, state
|
||||
fn=lambda output_dir, mode, state: train_h.preprocess_dataset(
|
||||
output_dir, mode, dit_handler, state
|
||||
),
|
||||
inputs=[
|
||||
training_section["preprocess_output_dir"],
|
||||
training_section["preprocess_mode"],
|
||||
training_section["dataset_builder_state"],
|
||||
],
|
||||
outputs=[training_section["preprocess_progress"]]
|
||||
|
|
@ -1256,3 +1389,92 @@ def setup_training_event_handlers(demo, dit_handler, llm_handler, training_secti
|
|||
],
|
||||
outputs=[training_section["export_status"]]
|
||||
)
|
||||
|
||||
# ========== LoKr Training Tab Handlers ==========
|
||||
|
||||
# Load preprocessed tensor dataset for LoKr
|
||||
training_section["lokr_load_dataset_btn"].click(
|
||||
fn=train_h.load_training_dataset,
|
||||
inputs=[training_section["lokr_training_tensor_dir"]],
|
||||
outputs=[training_section["lokr_training_dataset_info"]]
|
||||
)
|
||||
|
||||
# Start LoKr training from preprocessed tensors
|
||||
def lokr_training_wrapper(
|
||||
tensor_dir, ldim, lalpha, factor, decompose_both, use_tucker,
|
||||
use_scalar, weight_decompose, lr, ep, bs, ga, se, sh, sd, od, ts,
|
||||
):
|
||||
from loguru import logger
|
||||
if not isinstance(ts, dict):
|
||||
ts = {"is_training": False, "should_stop": False}
|
||||
try:
|
||||
for progress, log_msg, plot, state in train_h.start_lokr_training(
|
||||
tensor_dir, dit_handler,
|
||||
ldim, lalpha, factor, decompose_both, use_tucker,
|
||||
use_scalar, weight_decompose,
|
||||
lr, ep, bs, ga, se, sh, sd, od, ts,
|
||||
):
|
||||
yield progress, log_msg, plot, state
|
||||
except Exception as e:
|
||||
logger.exception("LoKr training wrapper error")
|
||||
yield f"❌ Error: {str(e)}", str(e), None, ts
|
||||
|
||||
training_section["start_lokr_training_btn"].click(
|
||||
fn=lokr_training_wrapper,
|
||||
inputs=[
|
||||
training_section["lokr_training_tensor_dir"],
|
||||
training_section["lokr_linear_dim"],
|
||||
training_section["lokr_linear_alpha"],
|
||||
training_section["lokr_factor"],
|
||||
training_section["lokr_decompose_both"],
|
||||
training_section["lokr_use_tucker"],
|
||||
training_section["lokr_use_scalar"],
|
||||
training_section["lokr_weight_decompose"],
|
||||
training_section["lokr_learning_rate"],
|
||||
training_section["lokr_train_epochs"],
|
||||
training_section["lokr_train_batch_size"],
|
||||
training_section["lokr_gradient_accumulation"],
|
||||
training_section["lokr_save_every_n_epochs"],
|
||||
training_section["lokr_training_shift"],
|
||||
training_section["lokr_training_seed"],
|
||||
training_section["lokr_output_dir"],
|
||||
training_section["training_state"],
|
||||
],
|
||||
outputs=[
|
||||
training_section["lokr_training_progress"],
|
||||
training_section["lokr_training_log"],
|
||||
training_section["lokr_training_loss_plot"],
|
||||
training_section["training_state"],
|
||||
]
|
||||
)
|
||||
|
||||
# Stop LoKr training (reuses same stop mechanism)
|
||||
training_section["stop_lokr_training_btn"].click(
|
||||
fn=train_h.stop_training,
|
||||
inputs=[training_section["training_state"]],
|
||||
outputs=[
|
||||
training_section["lokr_training_progress"],
|
||||
training_section["training_state"],
|
||||
]
|
||||
)
|
||||
|
||||
# Refresh LoKr export epochs
|
||||
training_section["refresh_lokr_export_epochs_btn"].click(
|
||||
fn=train_h.list_lokr_export_epochs,
|
||||
inputs=[training_section["lokr_output_dir"]],
|
||||
outputs=[
|
||||
training_section["lokr_export_epoch"],
|
||||
training_section["lokr_export_status"],
|
||||
]
|
||||
)
|
||||
|
||||
# Export LoKr
|
||||
training_section["export_lokr_btn"].click(
|
||||
fn=train_h.export_lokr,
|
||||
inputs=[
|
||||
training_section["lokr_export_path"],
|
||||
training_section["lokr_output_dir"],
|
||||
training_section["lokr_export_epoch"],
|
||||
],
|
||||
outputs=[training_section["lokr_export_status"]]
|
||||
)
|
||||
File diff suppressed because it is too large
Load diff
88
acestep/gradio_ui/events/generation_handlers_test.py
Normal file
88
acestep/gradio_ui/events/generation_handlers_test.py
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
"""Unit tests for generation input event handlers."""
|
||||
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
try:
|
||||
from acestep.gradio_ui.events import generation_handlers
|
||||
_IMPORT_ERROR = None
|
||||
except Exception as exc: # pragma: no cover - environment dependency guard
|
||||
generation_handlers = None
|
||||
_IMPORT_ERROR = exc
|
||||
|
||||
|
||||
class _FakeDitHandler:
|
||||
"""Minimal DiT handler stub for analyze-src-audio tests."""
|
||||
|
||||
def __init__(self, convert_result):
|
||||
self._convert_result = convert_result
|
||||
|
||||
def convert_src_audio_to_codes(self, _src_audio):
|
||||
"""Return configured conversion output."""
|
||||
return self._convert_result
|
||||
|
||||
|
||||
@unittest.skipIf(generation_handlers is None, f"generation_handlers import unavailable: {_IMPORT_ERROR}")
|
||||
class GenerationHandlersTests(unittest.TestCase):
|
||||
"""Tests for source-audio analysis validation behavior."""
|
||||
|
||||
@patch("acestep.gradio_ui.events.generation_handlers.gr.Warning")
|
||||
@patch("acestep.gradio_ui.events.generation_handlers.understand_music")
|
||||
def test_analyze_src_audio_rejects_non_audio_code_output(
|
||||
self,
|
||||
understand_music_mock,
|
||||
warning_mock,
|
||||
):
|
||||
"""Reject conversion output that has no serialized audio-code tokens."""
|
||||
dit_handler = _FakeDitHandler("ERROR: not an audio file")
|
||||
llm_handler = SimpleNamespace(llm_initialized=True)
|
||||
|
||||
result = generation_handlers.analyze_src_audio(
|
||||
dit_handler=dit_handler,
|
||||
llm_handler=llm_handler,
|
||||
src_audio="fake.mp3",
|
||||
constrained_decoding_debug=False,
|
||||
)
|
||||
|
||||
self.assertEqual(result, ("", "", "", "", None, None, "", "", "", False))
|
||||
understand_music_mock.assert_not_called()
|
||||
warning_mock.assert_called_once()
|
||||
|
||||
@patch("acestep.gradio_ui.events.generation_handlers.gr.Warning")
|
||||
@patch("acestep.gradio_ui.events.generation_handlers.understand_music")
|
||||
def test_analyze_src_audio_allows_valid_audio_code_output(
|
||||
self,
|
||||
understand_music_mock,
|
||||
warning_mock,
|
||||
):
|
||||
"""Pass valid audio codes through to LM understanding."""
|
||||
dit_handler = _FakeDitHandler("<|audio_code_123|><|audio_code_456|>")
|
||||
llm_handler = SimpleNamespace(llm_initialized=True)
|
||||
understand_music_mock.return_value = SimpleNamespace(
|
||||
success=True,
|
||||
status_message="ok",
|
||||
caption="caption",
|
||||
lyrics="lyrics",
|
||||
bpm=120,
|
||||
duration=30.0,
|
||||
keyscale="C major",
|
||||
language="en",
|
||||
timesignature="4",
|
||||
)
|
||||
|
||||
result = generation_handlers.analyze_src_audio(
|
||||
dit_handler=dit_handler,
|
||||
llm_handler=llm_handler,
|
||||
src_audio="real.mp3",
|
||||
constrained_decoding_debug=False,
|
||||
)
|
||||
|
||||
self.assertEqual(result[0], "<|audio_code_123|><|audio_code_456|>")
|
||||
self.assertEqual(result[1], "ok")
|
||||
understand_music_mock.assert_called_once()
|
||||
warning_mock.assert_not_called()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -5,14 +5,21 @@ Contains event handlers and helper functions related to result display, scoring,
|
|||
import os
|
||||
import json
|
||||
import datetime
|
||||
import math
|
||||
import re
|
||||
import tempfile
|
||||
import shutil
|
||||
import zipfile
|
||||
import time as time_module
|
||||
import sys
|
||||
from typing import Dict, Any, Optional, List
|
||||
import gradio as gr
|
||||
from loguru import logger
|
||||
from acestep.gradio_ui.i18n import t
|
||||
from acestep.gradio_ui.events.generation_handlers import parse_and_validate_timesteps
|
||||
from acestep.gradio_ui.events.generation_handlers import (
|
||||
parse_and_validate_timesteps,
|
||||
compute_mode_ui_updates,
|
||||
)
|
||||
from acestep.inference import generate_music, GenerationParams, GenerationConfig
|
||||
from acestep.audio_utils import save_audio
|
||||
from acestep.gpu_config import (
|
||||
|
|
@ -21,7 +28,7 @@ from acestep.gpu_config import (
|
|||
check_batch_size_limit,
|
||||
)
|
||||
|
||||
# Platform detection for Windows-specific fixess
|
||||
# Platform detection for Windows-specific fixes
|
||||
IS_WINDOWS = sys.platform == "win32"
|
||||
|
||||
# Global results directory inside project root
|
||||
|
|
@ -272,122 +279,59 @@ def _build_generation_info(
|
|||
seed_value: str,
|
||||
inference_steps: int,
|
||||
num_audios: int,
|
||||
audio_format: str = "flac",
|
||||
) -> str:
|
||||
"""Build generation info string from result data.
|
||||
"""Build a compact generation timing summary.
|
||||
|
||||
Args:
|
||||
lm_metadata: LM-generated metadata dictionary
|
||||
lm_metadata: LM-generated metadata dictionary (unused, kept for API compat)
|
||||
time_costs: Unified time costs dictionary
|
||||
seed_value: Seed value string
|
||||
inference_steps: Number of inference steps
|
||||
seed_value: Seed value string (unused, kept for API compat)
|
||||
inference_steps: Number of inference steps (unused, kept for API compat)
|
||||
num_audios: Number of generated audios
|
||||
audio_format: Output audio format name (e.g. "flac", "mp3", "wav32")
|
||||
|
||||
Returns:
|
||||
Formatted generation info string
|
||||
"""
|
||||
if not time_costs or num_audios <= 0:
|
||||
return ""
|
||||
|
||||
songs_label = f"({num_audios} song{'s' if num_audios > 1 else ''})"
|
||||
info_parts = []
|
||||
|
||||
# Part 1: Per-track average time (prominently displayed at the top)
|
||||
# Only count model time (LM + DiT), not post-processing like audio conversion
|
||||
if time_costs and num_audios > 0:
|
||||
lm_total = time_costs.get('lm_total_time', 0.0)
|
||||
dit_total = time_costs.get('dit_total_time_cost', 0.0)
|
||||
model_total = lm_total + dit_total
|
||||
if model_total > 0:
|
||||
avg_time_per_track = model_total / num_audios
|
||||
avg_section = f"**🎯 Average Time per Track: {avg_time_per_track:.2f}s** ({num_audios} track(s))"
|
||||
info_parts.append(avg_section)
|
||||
|
||||
# Part 2: LM-generated metadata (if available)
|
||||
if lm_metadata:
|
||||
metadata_lines = []
|
||||
if lm_metadata.get('bpm'):
|
||||
metadata_lines.append(f"- **BPM:** {lm_metadata['bpm']}")
|
||||
if lm_metadata.get('caption'):
|
||||
metadata_lines.append(f"- **Refined Caption:** {lm_metadata['caption']}")
|
||||
if lm_metadata.get('lyrics'):
|
||||
metadata_lines.append(f"- **Refined Lyrics:** {lm_metadata['lyrics']}")
|
||||
if lm_metadata.get('duration'):
|
||||
metadata_lines.append(f"- **Duration:** {lm_metadata['duration']} seconds")
|
||||
if lm_metadata.get('keyscale'):
|
||||
metadata_lines.append(f"- **Key Scale:** {lm_metadata['keyscale']}")
|
||||
if lm_metadata.get('language'):
|
||||
metadata_lines.append(f"- **Language:** {lm_metadata['language']}")
|
||||
if lm_metadata.get('timesignature'):
|
||||
metadata_lines.append(f"- **Time Signature:** {lm_metadata['timesignature']}")
|
||||
|
||||
if metadata_lines:
|
||||
metadata_section = "**🤖 LM-Generated Metadata:**\n" + "\n".join(metadata_lines)
|
||||
info_parts.append(metadata_section)
|
||||
|
||||
# Part 3: Time costs breakdown (formatted and beautified)
|
||||
if time_costs:
|
||||
time_lines = []
|
||||
|
||||
# LM time costs
|
||||
lm_phase1 = time_costs.get('lm_phase1_time', 0.0)
|
||||
lm_phase2 = time_costs.get('lm_phase2_time', 0.0)
|
||||
lm_total = time_costs.get('lm_total_time', 0.0)
|
||||
|
||||
|
||||
# --- Block 1: Generation time (LM + DiT) ---
|
||||
lm_total = time_costs.get('lm_total_time', 0.0)
|
||||
dit_total = time_costs.get('dit_total_time_cost', 0.0)
|
||||
gen_total = lm_total + dit_total
|
||||
|
||||
if gen_total > 0:
|
||||
avg = gen_total / num_audios
|
||||
lines = [f"**🎵 Total generation time {songs_label}: {gen_total:.2f}s**"]
|
||||
lines.append(f"- {avg:.2f}s per song")
|
||||
if lm_total > 0:
|
||||
time_lines.append("**🧠 LM Time:**")
|
||||
if lm_phase1 > 0:
|
||||
time_lines.append(f" - Phase 1 (CoT): {lm_phase1:.2f}s")
|
||||
if lm_phase2 > 0:
|
||||
time_lines.append(f" - Phase 2 (Codes): {lm_phase2:.2f}s")
|
||||
time_lines.append(f" - Total: {lm_total:.2f}s")
|
||||
|
||||
# DiT time costs
|
||||
dit_encoder = time_costs.get('dit_encoder_time_cost', 0.0)
|
||||
dit_model = time_costs.get('dit_model_time_cost', 0.0)
|
||||
dit_vae_decode = time_costs.get('dit_vae_decode_time_cost', 0.0)
|
||||
dit_offload = time_costs.get('dit_offload_time_cost', 0.0)
|
||||
dit_total = time_costs.get('dit_total_time_cost', 0.0)
|
||||
lines.append(f"- LM phase {songs_label}: {lm_total:.2f}s")
|
||||
if dit_total > 0:
|
||||
time_lines.append("\n**🎵 DiT Time:**")
|
||||
if dit_encoder > 0:
|
||||
time_lines.append(f" - Encoder: {dit_encoder:.2f}s")
|
||||
if dit_model > 0:
|
||||
time_lines.append(f" - Model: {dit_model:.2f}s")
|
||||
if dit_vae_decode > 0:
|
||||
time_lines.append(f" - VAE Decode: {dit_vae_decode:.2f}s")
|
||||
if dit_offload > 0:
|
||||
time_lines.append(f" - Offload: {dit_offload:.2f}s")
|
||||
time_lines.append(f" - Total: {dit_total:.2f}s")
|
||||
|
||||
# Post-processing time costs
|
||||
audio_conversion_time = time_costs.get('audio_conversion_time', 0.0)
|
||||
auto_score_time = time_costs.get('auto_score_time', 0.0)
|
||||
auto_lrc_time = time_costs.get('auto_lrc_time', 0.0)
|
||||
|
||||
if audio_conversion_time > 0 or auto_score_time > 0 or auto_lrc_time > 0:
|
||||
time_lines.append("\n**🔧 Post-processing Time:**")
|
||||
if audio_conversion_time > 0:
|
||||
time_lines.append(f" - Audio Conversion: {audio_conversion_time:.2f}s")
|
||||
if auto_score_time > 0:
|
||||
time_lines.append(f" - Auto Score: {auto_score_time:.2f}s")
|
||||
if auto_lrc_time > 0:
|
||||
time_lines.append(f" - Auto LRC: {auto_lrc_time:.2f}s")
|
||||
|
||||
if time_lines:
|
||||
time_section = "\n".join(time_lines)
|
||||
info_parts.append(time_section)
|
||||
|
||||
# Part 4: Generation summary
|
||||
summary_lines = [
|
||||
"**🎵 Generation Complete**",
|
||||
f" - **Seeds:** {seed_value}",
|
||||
f" - **Steps:** {inference_steps}",
|
||||
f" - **Audio Count:** {num_audios} audio(s)",
|
||||
]
|
||||
info_parts.append("\n".join(summary_lines))
|
||||
|
||||
# Part 5: Pipeline total time (at the end)
|
||||
pipeline_total = time_costs.get('pipeline_total_time', 0.0) if time_costs else 0.0
|
||||
if pipeline_total > 0:
|
||||
info_parts.append(f"**⏱️ Total Time: {pipeline_total:.2f}s**")
|
||||
|
||||
# Combine all parts
|
||||
lines.append(f"- DiT phase {songs_label}: {dit_total:.2f}s")
|
||||
info_parts.append("\n".join(lines))
|
||||
|
||||
# --- Block 2: Processing time (conversion + scoring + LRC) ---
|
||||
audio_conversion_time = time_costs.get('audio_conversion_time', 0.0)
|
||||
auto_score_time = time_costs.get('auto_score_time', 0.0)
|
||||
auto_lrc_time = time_costs.get('auto_lrc_time', 0.0)
|
||||
proc_total = audio_conversion_time + auto_score_time + auto_lrc_time
|
||||
|
||||
if proc_total > 0:
|
||||
fmt_label = audio_format.upper() if audio_format != "wav32" else "WAV 32-bit"
|
||||
lines = [f"**🔧 Total processing time {songs_label}: {proc_total:.2f}s**"]
|
||||
if audio_conversion_time > 0:
|
||||
lines.append(f"- to {fmt_label} {songs_label}: {audio_conversion_time:.2f}s")
|
||||
if auto_score_time > 0:
|
||||
lines.append(f"- scoring {songs_label}: {auto_score_time:.2f}s")
|
||||
if auto_lrc_time > 0:
|
||||
lines.append(f"- LRC detection {songs_label}: {auto_lrc_time:.2f}s")
|
||||
info_parts.append("\n".join(lines))
|
||||
|
||||
return "\n\n".join(info_parts)
|
||||
|
||||
|
||||
|
|
@ -450,7 +394,6 @@ def send_audio_to_src_with_metadata(audio_file, lm_metadata):
|
|||
|
||||
This function ONLY sets the src_audio field. All other metadata fields (caption, lyrics, etc.)
|
||||
are preserved by returning gr.skip() to avoid overwriting user's existing inputs.
|
||||
Also opens the audio_uploads_accordion so the user can see the uploaded audio.
|
||||
|
||||
Args:
|
||||
audio_file: Audio file path
|
||||
|
|
@ -458,7 +401,7 @@ def send_audio_to_src_with_metadata(audio_file, lm_metadata):
|
|||
|
||||
Returns:
|
||||
Tuple of (audio_file, bpm, caption, lyrics, duration, key_scale, language, time_signature, is_format_caption, audio_uploads_accordion)
|
||||
All values except audio_file and audio_uploads_accordion are gr.skip() to preserve existing UI values
|
||||
All values except audio_file and accordion are gr.skip() to preserve existing UI values
|
||||
"""
|
||||
if audio_file is None:
|
||||
# Return all skip to not modify anything
|
||||
|
|
@ -476,7 +419,106 @@ def send_audio_to_src_with_metadata(audio_file, lm_metadata):
|
|||
gr.skip(), # language - preserve existing value
|
||||
gr.skip(), # time_signature - preserve existing value
|
||||
gr.skip(), # is_format_caption - preserve existing value
|
||||
gr.Accordion(open=True), # audio_uploads_accordion - expand to show uploaded audio
|
||||
gr.Accordion(open=True), # audio_uploads_accordion - open to show the src audio
|
||||
)
|
||||
|
||||
|
||||
def _extract_metadata_for_editing(lm_metadata, current_lyrics="", current_caption=""):
|
||||
"""Extract lyrics and caption from lm_metadata for repaint/remix operations.
|
||||
|
||||
Falls back to current UI values when lm_metadata is missing or incomplete,
|
||||
so that existing user input is not overwritten with empty strings.
|
||||
|
||||
Args:
|
||||
lm_metadata: Metadata dictionary from LM generation (or None)
|
||||
current_lyrics: Current lyrics value from the UI (fallback)
|
||||
current_caption: Current caption value from the UI (fallback)
|
||||
|
||||
Returns:
|
||||
Tuple of (lyrics, caption) as strings
|
||||
"""
|
||||
lyrics = current_lyrics or ""
|
||||
caption = current_caption or ""
|
||||
if lm_metadata and isinstance(lm_metadata, dict):
|
||||
lyrics = lm_metadata.get("lyrics", lyrics)
|
||||
caption = lm_metadata.get("caption", caption)
|
||||
return lyrics, caption
|
||||
|
||||
|
||||
def send_audio_to_remix(audio_file, lm_metadata, current_lyrics, current_caption,
|
||||
llm_handler=None):
|
||||
"""Send generated audio to src_audio and switch mode to Remix.
|
||||
|
||||
Also populate lyrics and caption fields from the generated audio,
|
||||
and apply all Remix-mode UI updates atomically (visibility, labels)
|
||||
so the UI is correct without relying on a chained .change() event.
|
||||
|
||||
Args:
|
||||
audio_file: Generated audio file path
|
||||
lm_metadata: LM metadata dict (may be None)
|
||||
current_lyrics: Current lyrics text in the UI
|
||||
current_caption: Current caption text in the UI
|
||||
llm_handler: Optional LLM handler (for think-checkbox state)
|
||||
|
||||
Returns:
|
||||
Tuple of (src_audio, generation_mode, lyrics, caption,
|
||||
*mode_ui_updates) — 4 + 19 = 23 values.
|
||||
"""
|
||||
# 4 data outputs + 19 mode-UI outputs
|
||||
n_outputs = 23
|
||||
if audio_file is None:
|
||||
return (gr.skip(),) * n_outputs
|
||||
|
||||
lyrics, caption = _extract_metadata_for_editing(
|
||||
lm_metadata, current_lyrics, current_caption
|
||||
)
|
||||
|
||||
mode_updates = compute_mode_ui_updates("Remix", llm_handler)
|
||||
|
||||
return (
|
||||
audio_file, # src_audio
|
||||
gr.update(value="Remix"), # generation_mode -> Remix
|
||||
lyrics, # lyrics
|
||||
caption, # caption
|
||||
*mode_updates, # 19 mode-UI updates
|
||||
)
|
||||
|
||||
|
||||
def send_audio_to_repaint(audio_file, lm_metadata, current_lyrics, current_caption,
|
||||
llm_handler=None):
|
||||
"""Send generated audio to src_audio and switch mode to Repaint.
|
||||
|
||||
Also populate lyrics and caption fields from the generated audio,
|
||||
and apply all Repaint-mode UI updates atomically (visibility, labels)
|
||||
so the UI is correct without relying on a chained .change() event.
|
||||
|
||||
Args:
|
||||
audio_file: Generated audio file path
|
||||
lm_metadata: LM metadata dict (may be None)
|
||||
current_lyrics: Current lyrics text in the UI
|
||||
current_caption: Current caption text in the UI
|
||||
llm_handler: Optional LLM handler (for think-checkbox state)
|
||||
|
||||
Returns:
|
||||
Tuple of (src_audio, generation_mode, lyrics, caption,
|
||||
*mode_ui_updates) — 4 + 19 = 23 values.
|
||||
"""
|
||||
n_outputs = 23
|
||||
if audio_file is None:
|
||||
return (gr.skip(),) * n_outputs
|
||||
|
||||
lyrics, caption = _extract_metadata_for_editing(
|
||||
lm_metadata, current_lyrics, current_caption
|
||||
)
|
||||
|
||||
mode_updates = compute_mode_ui_updates("Repaint", llm_handler)
|
||||
|
||||
return (
|
||||
audio_file, # src_audio
|
||||
gr.update(value="Repaint"), # generation_mode -> Repaint
|
||||
lyrics, # lyrics
|
||||
caption, # caption
|
||||
*mode_updates, # 19 mode-UI updates
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -486,7 +528,7 @@ def generate_with_progress(
|
|||
inference_steps, guidance_scale, random_seed_checkbox, seed,
|
||||
reference_audio, audio_duration, batch_size_input, src_audio,
|
||||
text2music_audio_code_string, repainting_start, repainting_end,
|
||||
instruction_display_gen, audio_cover_strength, task_type,
|
||||
instruction_display_gen, audio_cover_strength, cover_noise_strength, task_type,
|
||||
use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, custom_timesteps, audio_format, lm_temperature,
|
||||
think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
||||
use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
|
||||
|
|
@ -495,9 +537,15 @@ def generate_with_progress(
|
|||
auto_score,
|
||||
auto_lrc,
|
||||
score_scale,
|
||||
|
||||
lm_batch_chunk_size,
|
||||
enable_normalization,
|
||||
normalization_db,
|
||||
latent_shift,
|
||||
latent_rescale,
|
||||
progress=gr.Progress(track_tqdm=True),
|
||||
):
|
||||
|
||||
"""Generate audio with progress tracking"""
|
||||
|
||||
# ========== GPU Memory Validation ==========
|
||||
|
|
@ -533,11 +581,14 @@ def generate_with_progress(
|
|||
logger.info("[generate_with_progress] Skipping Phase 1 metas COT: sample is already formatted (is_format_caption=True)")
|
||||
gr.Info(t("messages.skipping_metas_cot"))
|
||||
|
||||
# Parse and validate custom timesteps
|
||||
parsed_timesteps, has_timesteps_warning, _ = parse_and_validate_timesteps(custom_timesteps, inference_steps)
|
||||
actual_inference_steps = int(inference_steps) if inference_steps is not None else 8
|
||||
|
||||
# Update inference_steps if custom timesteps provided (to match UI display)
|
||||
actual_inference_steps = inference_steps
|
||||
if parsed_timesteps is not None:
|
||||
actual_inference_steps = len(parsed_timesteps) - 1
|
||||
|
||||
|
||||
# step 1: prepare inputs
|
||||
# generate_music, GenerationParams, GenerationConfig
|
||||
gen_params = GenerationParams(
|
||||
|
|
@ -565,6 +616,7 @@ def generate_with_progress(
|
|||
repainting_start=repainting_start,
|
||||
repainting_end=repainting_end,
|
||||
audio_cover_strength=audio_cover_strength,
|
||||
cover_noise_strength=cover_noise_strength,
|
||||
thinking=think_checkbox,
|
||||
lm_temperature=lm_temperature,
|
||||
lm_cfg_scale=lm_cfg_scale,
|
||||
|
|
@ -575,20 +627,18 @@ def generate_with_progress(
|
|||
use_cot_caption=use_cot_caption,
|
||||
use_cot_language=use_cot_language,
|
||||
use_constrained_decoding=True,
|
||||
enable_normalization=enable_normalization,
|
||||
normalization_db=normalization_db,
|
||||
latent_shift=latent_shift,
|
||||
latent_rescale=latent_rescale,
|
||||
)
|
||||
if isinstance(seed, (int, float)):
|
||||
seed_list = [int(seed)] if seed >= 0 else None
|
||||
elif isinstance(seed, str) and seed.strip():
|
||||
|
||||
# seed string to list
|
||||
if isinstance(seed, str) and seed.strip():
|
||||
if "," in seed:
|
||||
try:
|
||||
seed_list = [int(s.strip()) for s in seed.split(",")]
|
||||
except (ValueError, TypeError):
|
||||
seed_list = None
|
||||
seed_list = [int(s.strip()) for s in seed.split(",")]
|
||||
else:
|
||||
try:
|
||||
seed_list = [int(seed.strip())]
|
||||
except (ValueError, TypeError):
|
||||
seed_list = None
|
||||
seed_list = [int(seed.strip())]
|
||||
else:
|
||||
seed_list = None
|
||||
gen_config = GenerationConfig(
|
||||
|
|
@ -637,6 +687,7 @@ def generate_with_progress(
|
|||
seed_value=seed_value_for_ui,
|
||||
inference_steps=inference_steps,
|
||||
num_audios=len(result.audios) if result.success else 0,
|
||||
audio_format=audio_format,
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
|
|
@ -669,7 +720,7 @@ def generate_with_progress(
|
|||
# Clear lrc_display with empty string - this triggers .change() to clear subtitles
|
||||
clear_lrcs = [gr.update(value="", visible=True) for _ in range(8)]
|
||||
clear_accordions = [gr.skip() for _ in range(8)] # Don't change accordion visibility
|
||||
dump_audio = [gr.update(value=None, subtitles=None) for _ in range(8)]
|
||||
dump_audio = [gr.update(value=None, subtitles=None, playback_position=0) for _ in range(8)]
|
||||
yield (
|
||||
# Audio outputs - just skip, value will be updated in loop
|
||||
# Subtitles will be cleared via lrc_display.change()
|
||||
|
|
@ -697,29 +748,33 @@ def generate_with_progress(
|
|||
)
|
||||
time_module.sleep(0.1)
|
||||
|
||||
final_codes_display_updates = [gr.skip() for _ in range(8)]
|
||||
for i in range(8):
|
||||
if i < len(audios):
|
||||
key = audios[i]["key"]
|
||||
audio_tensor = audios[i]["tensor"]
|
||||
sample_rate = audios[i]["sample_rate"]
|
||||
audio_params = audios[i]["params"]
|
||||
is_silent = audios[i].get("silent", False)
|
||||
# Use local output directory instead of system temp
|
||||
timestamp = int(time_module.time())
|
||||
temp_dir = os.path.join(DEFAULT_RESULTS_DIR, f"batch_{timestamp}")
|
||||
temp_dir = os.path.abspath(temp_dir).replace("\\", "/")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
json_path = os.path.join(temp_dir, f"{key}.json").replace("\\", "/")
|
||||
audio_path = os.path.join(temp_dir, f"{key}.{audio_format}").replace("\\", "/")
|
||||
if not is_silent and audio_tensor is not None:
|
||||
save_audio(audio_data=audio_tensor, output_path=audio_path, sample_rate=sample_rate, format=audio_format, channels_first=True)
|
||||
with open(json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(audio_params, f, indent=2, ensure_ascii=False)
|
||||
audio_outputs[i] = audio_path
|
||||
all_audio_paths.append(audio_path)
|
||||
all_audio_paths.append(json_path)
|
||||
else:
|
||||
audio_outputs[i] = None
|
||||
|
||||
# Handle wav32 extension
|
||||
ext = "wav" if audio_format == "wav32" else audio_format
|
||||
audio_path = os.path.join(temp_dir, f"{key}.{ext}").replace("\\", "/")
|
||||
|
||||
# Save audio and capture actual saved path
|
||||
saved_path = save_audio(audio_data=audio_tensor, output_path=audio_path, sample_rate=sample_rate, format=audio_format, channels_first=True)
|
||||
if saved_path:
|
||||
audio_path = saved_path.replace("\\", "/")
|
||||
|
||||
with open(json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(audio_params, f, indent=2, ensure_ascii=False)
|
||||
audio_outputs[i] = audio_path
|
||||
all_audio_paths.append(audio_path)
|
||||
all_audio_paths.append(json_path)
|
||||
|
||||
code_str = audio_params.get("audio_codes", "")
|
||||
final_codes_list[i] = code_str
|
||||
|
|
@ -832,7 +887,6 @@ def generate_with_progress(
|
|||
|
||||
codes_display_updates = [gr.skip() for _ in range(8)]
|
||||
codes_display_updates[i] = gr.update(value=code_str, visible=True) # Keep visible=True
|
||||
final_codes_display_updates[i] = gr.update(value=code_str, visible=True) # Keep visible=True
|
||||
|
||||
details_accordion_updates = [gr.skip() for _ in range(8)]
|
||||
# Don't change accordion visibility - keep it always expandable
|
||||
|
|
@ -927,18 +981,28 @@ def generate_with_progress(
|
|||
seed_value=seed_value_for_ui,
|
||||
inference_steps=inference_steps,
|
||||
num_audios=len(result.audios),
|
||||
audio_format=audio_format,
|
||||
)
|
||||
|
||||
# Build final codes display, LRC display, accordion visibility updates
|
||||
final_codes_display_updates = [gr.skip() for _ in range(8)]
|
||||
# final_lrc_display_updates = [gr.skip() for _ in range(8)]
|
||||
final_accordion_updates = [gr.skip() for _ in range(8)]
|
||||
|
||||
# Audio was already sent in loop yields, just reset playback position to 0
|
||||
# This resets the playback cursor to the beginning without reloading the audio
|
||||
audio_playback_updates = [gr.update(playback_position=0) for _ in range(8)]
|
||||
# On Windows, progressive yields are disabled, so we must return actual audio paths
|
||||
# On other platforms, audio was already sent in loop yields, just reset playback position
|
||||
# Use gr.update() to force Gradio to update the audio component (Issue #113)
|
||||
audio_playback_updates = []
|
||||
for idx in range(8):
|
||||
path = audio_outputs[idx]
|
||||
if path:
|
||||
audio_playback_updates.append(gr.update(value=path, label=f"Sample {idx+1} (Ready)", interactive=False))
|
||||
logger.info(f"[generate_with_progress] Audio {idx+1} path: {path}")
|
||||
else:
|
||||
audio_playback_updates.append(gr.update(value=None, label="None", interactive=False))
|
||||
|
||||
yield (
|
||||
# Audio outputs - reset playback position to beginning
|
||||
# Audio outputs - use gr.update() to force component refresh
|
||||
audio_playback_updates[0], audio_playback_updates[1], audio_playback_updates[2], audio_playback_updates[3],
|
||||
audio_playback_updates[4], audio_playback_updates[5], audio_playback_updates[6], audio_playback_updates[7],
|
||||
all_audio_paths,
|
||||
|
|
@ -1292,7 +1356,8 @@ def generate_lrc_handler(dit_handler, sample_idx, current_batch_index, batch_que
|
|||
Tuple of (lrc_display_update, details_accordion_update, batch_queue)
|
||||
Note: No audio_update - subtitles updated via lrc_display.change()
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
if current_batch_index not in batch_queue:
|
||||
return gr.skip(), gr.skip(), batch_queue
|
||||
|
||||
|
|
@ -1420,13 +1485,17 @@ def capture_current_params(
|
|||
inference_steps, guidance_scale, random_seed_checkbox, seed,
|
||||
reference_audio, audio_duration, batch_size_input, src_audio,
|
||||
text2music_audio_code_string, repainting_start, repainting_end,
|
||||
instruction_display_gen, audio_cover_strength, task_type,
|
||||
instruction_display_gen, audio_cover_strength, cover_noise_strength, task_type,
|
||||
use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, custom_timesteps, audio_format, lm_temperature,
|
||||
think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
||||
use_cot_metas, use_cot_caption, use_cot_language,
|
||||
constrained_decoding_debug, allow_lm_batch, auto_score, auto_lrc, score_scale, lm_batch_chunk_size,
|
||||
track_name, complete_track_classes
|
||||
|
||||
track_name, complete_track_classes,
|
||||
enable_normalization, normalization_db,
|
||||
latent_shift, latent_rescale
|
||||
):
|
||||
|
||||
"""Capture current UI parameters for next batch generation
|
||||
|
||||
IMPORTANT: For AutoGen batches, we clear audio codes to ensure:
|
||||
|
|
@ -1453,6 +1522,7 @@ def capture_current_params(
|
|||
"repainting_end": repainting_end,
|
||||
"instruction_display_gen": instruction_display_gen,
|
||||
"audio_cover_strength": audio_cover_strength,
|
||||
"cover_noise_strength": cover_noise_strength,
|
||||
"task_type": task_type,
|
||||
"use_adg": use_adg,
|
||||
"cfg_interval_start": cfg_interval_start,
|
||||
|
|
@ -1478,16 +1548,22 @@ def capture_current_params(
|
|||
"lm_batch_chunk_size": lm_batch_chunk_size,
|
||||
"track_name": track_name,
|
||||
"complete_track_classes": complete_track_classes,
|
||||
|
||||
"enable_normalization": enable_normalization,
|
||||
"normalization_db": normalization_db,
|
||||
"latent_shift": latent_shift,
|
||||
"latent_rescale": latent_rescale,
|
||||
}
|
||||
|
||||
|
||||
|
||||
def generate_with_batch_management(
|
||||
dit_handler, llm_handler,
|
||||
captions, lyrics, bpm, key_scale, time_signature, vocal_language,
|
||||
inference_steps, guidance_scale, random_seed_checkbox, seed,
|
||||
reference_audio, audio_duration, batch_size_input, src_audio,
|
||||
text2music_audio_code_string, repainting_start, repainting_end,
|
||||
instruction_display_gen, audio_cover_strength, task_type,
|
||||
instruction_display_gen, audio_cover_strength, cover_noise_strength, task_type,
|
||||
use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, custom_timesteps, audio_format, lm_temperature,
|
||||
think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
||||
use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
|
||||
|
|
@ -1499,7 +1575,13 @@ def generate_with_batch_management(
|
|||
lm_batch_chunk_size,
|
||||
track_name,
|
||||
complete_track_classes,
|
||||
|
||||
enable_normalization,
|
||||
normalization_db,
|
||||
latent_shift,
|
||||
latent_rescale,
|
||||
autogen_checkbox,
|
||||
|
||||
current_batch_index,
|
||||
total_batches,
|
||||
batch_queue,
|
||||
|
|
@ -1516,7 +1598,7 @@ def generate_with_batch_management(
|
|||
inference_steps, guidance_scale, random_seed_checkbox, seed,
|
||||
reference_audio, audio_duration, batch_size_input, src_audio,
|
||||
text2music_audio_code_string, repainting_start, repainting_end,
|
||||
instruction_display_gen, audio_cover_strength, task_type,
|
||||
instruction_display_gen, audio_cover_strength, cover_noise_strength, task_type,
|
||||
use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method, custom_timesteps, audio_format, lm_temperature,
|
||||
think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
||||
use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
|
||||
|
|
@ -1526,18 +1608,28 @@ def generate_with_batch_management(
|
|||
auto_lrc,
|
||||
score_scale,
|
||||
lm_batch_chunk_size,
|
||||
|
||||
enable_normalization,
|
||||
normalization_db,
|
||||
latent_shift,
|
||||
latent_rescale,
|
||||
progress
|
||||
)
|
||||
|
||||
final_result_from_inner = None
|
||||
for partial_result in generator:
|
||||
final_result_from_inner = partial_result
|
||||
# Forward intermediate yields to UI for progressive streaming updates
|
||||
# (audio outputs appear one-by-one as they are ready)
|
||||
# Pad with gr.skip() for the extra batch management outputs
|
||||
yield tuple(partial_result[:46]) + (
|
||||
gr.skip(), gr.skip(), gr.skip(), gr.skip(),
|
||||
gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(),
|
||||
)
|
||||
# Progressive yields disabled on Windows to prevent UI freeze
|
||||
# On other platforms, yield progress updates normally
|
||||
if not IS_WINDOWS:
|
||||
# current_batch_index, total_batches, batch_queue, next_params,
|
||||
# batch_indicator_text, prev_btn, next_btn, next_status, restore_btn
|
||||
# Slice off extra_outputs and raw_codes_list (last 2 items) before re-yielding to UI
|
||||
ui_result = partial_result[:-2] if len(partial_result) > 47 else (partial_result[:-1] if len(partial_result) > 46 else partial_result)
|
||||
yield ui_result + (
|
||||
gr.skip(), gr.skip(), gr.skip(), gr.skip(),
|
||||
gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
|
||||
)
|
||||
result = final_result_from_inner
|
||||
all_audio_paths = result[8]
|
||||
|
||||
|
|
@ -1591,6 +1683,7 @@ def generate_with_batch_management(
|
|||
"repainting_end": repainting_end,
|
||||
"instruction_display_gen": instruction_display_gen,
|
||||
"audio_cover_strength": audio_cover_strength,
|
||||
"cover_noise_strength": cover_noise_strength,
|
||||
"task_type": task_type,
|
||||
"use_adg": use_adg,
|
||||
"cfg_interval_start": cfg_interval_start,
|
||||
|
|
@ -1615,7 +1708,13 @@ def generate_with_batch_management(
|
|||
"lm_batch_chunk_size": lm_batch_chunk_size,
|
||||
"track_name": track_name,
|
||||
"complete_track_classes": complete_track_classes,
|
||||
|
||||
"enable_normalization": enable_normalization,
|
||||
"normalization_db": normalization_db,
|
||||
"latent_shift": latent_shift,
|
||||
"latent_rescale": latent_rescale,
|
||||
}
|
||||
|
||||
|
||||
# Next batch parameters (with cleared codes & random seed)
|
||||
# Next batch parameters
|
||||
|
|
@ -1771,6 +1870,7 @@ def generate_next_batch_background(
|
|||
params.setdefault("repainting_end", -1)
|
||||
params.setdefault("instruction_display_gen", "")
|
||||
params.setdefault("audio_cover_strength", 1.0)
|
||||
params.setdefault("cover_noise_strength", 0.0)
|
||||
params.setdefault("task_type", "text2music")
|
||||
params.setdefault("use_adg", False)
|
||||
params.setdefault("cfg_interval_start", 0.0)
|
||||
|
|
@ -1778,7 +1878,7 @@ def generate_next_batch_background(
|
|||
params.setdefault("shift", 1.0)
|
||||
params.setdefault("infer_method", "ode")
|
||||
params.setdefault("custom_timesteps", "")
|
||||
params.setdefault("audio_format", "mp3")
|
||||
params.setdefault("audio_format", "flac")
|
||||
params.setdefault("lm_temperature", 0.85)
|
||||
params.setdefault("think_checkbox", True)
|
||||
params.setdefault("lm_cfg_scale", 2.0)
|
||||
|
|
@ -1796,6 +1896,10 @@ def generate_next_batch_background(
|
|||
params.setdefault("lm_batch_chunk_size", 8)
|
||||
params.setdefault("track_name", None)
|
||||
params.setdefault("complete_track_classes", [])
|
||||
params.setdefault("enable_normalization", True)
|
||||
params.setdefault("normalization_db", -1.0)
|
||||
params.setdefault("latent_shift", 0.0)
|
||||
params.setdefault("latent_rescale", 1.0)
|
||||
|
||||
# Call generate_with_progress with the saved parameters
|
||||
# Note: generate_with_progress is a generator, need to iterate through it
|
||||
|
|
@ -1823,6 +1927,7 @@ def generate_next_batch_background(
|
|||
repainting_end=params.get("repainting_end"),
|
||||
instruction_display_gen=params.get("instruction_display_gen"),
|
||||
audio_cover_strength=params.get("audio_cover_strength"),
|
||||
cover_noise_strength=params.get("cover_noise_strength", 0.0),
|
||||
task_type=params.get("task_type"),
|
||||
use_adg=params.get("use_adg"),
|
||||
cfg_interval_start=params.get("cfg_interval_start"),
|
||||
|
|
@ -1847,6 +1952,10 @@ def generate_next_batch_background(
|
|||
auto_lrc=params.get("auto_lrc"),
|
||||
score_scale=params.get("score_scale"),
|
||||
lm_batch_chunk_size=params.get("lm_batch_chunk_size"),
|
||||
enable_normalization=params.get("enable_normalization"),
|
||||
normalization_db=params.get("normalization_db"),
|
||||
latent_shift=params.get("latent_shift", 0.0),
|
||||
latent_rescale=params.get("latent_rescale", 1.0),
|
||||
progress=progress
|
||||
)
|
||||
|
||||
|
|
@ -2005,7 +2114,6 @@ def navigate_to_previous_batch(current_batch_index, batch_queue):
|
|||
for idx in range(8):
|
||||
if idx < len(real_audio_paths):
|
||||
audio_path = real_audio_paths[idx].replace("\\", "/") # Normalize path
|
||||
# Pass path directly; Gradio Audio component with type="filepath" expects a string path
|
||||
audio_updates.append(gr.update(value=audio_path))
|
||||
else:
|
||||
audio_updates.append(gr.update(value=None))
|
||||
|
|
@ -2129,7 +2237,6 @@ def navigate_to_next_batch(autogen_enabled, current_batch_index, total_batches,
|
|||
for idx in range(8):
|
||||
if idx < len(real_audio_paths):
|
||||
audio_path = real_audio_paths[idx].replace("\\", "/") # Normalize path
|
||||
# Pass path directly; Gradio Audio component with type="filepath" expects a string path
|
||||
audio_updates.append(gr.update(value=audio_path))
|
||||
else:
|
||||
audio_updates.append(gr.update(value=None))
|
||||
|
|
@ -2257,6 +2364,15 @@ def restore_batch_parameters(current_batch_index, batch_queue):
|
|||
track_name = params.get("track_name", None)
|
||||
complete_track_classes = params.get("complete_track_classes", [])
|
||||
|
||||
# Audio Normalization
|
||||
enable_normalization = params.get("enable_normalization", True)
|
||||
normalization_db = params.get("normalization_db", -1.0)
|
||||
|
||||
# Latent Shift / Rescale
|
||||
latent_shift = params.get("latent_shift", 0.0)
|
||||
latent_rescale = params.get("latent_rescale", 1.0)
|
||||
|
||||
|
||||
# Extract codes - only restore to single input
|
||||
stored_codes = batch_data.get("codes", "")
|
||||
if stored_codes:
|
||||
|
|
@ -2276,5 +2392,65 @@ def restore_batch_parameters(current_batch_index, batch_queue):
|
|||
vocal_language, audio_duration, batch_size_input, inference_steps,
|
||||
lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, think_checkbox,
|
||||
use_cot_caption, use_cot_language, allow_lm_batch,
|
||||
track_name, complete_track_classes
|
||||
)
|
||||
track_name, complete_track_classes,
|
||||
enable_normalization, normalization_db,
|
||||
latent_shift, latent_rescale
|
||||
)
|
||||
|
||||
|
||||
def convert_result_audio_to_codes(dit_handler, generated_audio):
|
||||
"""Convert a generated audio sample to LM audio codes.
|
||||
|
||||
Args:
|
||||
dit_handler: DiT handler instance with convert_src_audio_to_codes method
|
||||
generated_audio: File path to the generated audio
|
||||
|
||||
Returns:
|
||||
Tuple of (codes_display_update, details_accordion_update)
|
||||
"""
|
||||
if not generated_audio:
|
||||
gr.Warning("No audio to convert.")
|
||||
return gr.skip(), gr.skip()
|
||||
|
||||
if not dit_handler or dit_handler.model is None:
|
||||
gr.Warning(t("messages.service_not_initialized"))
|
||||
return gr.skip(), gr.skip()
|
||||
|
||||
try:
|
||||
codes_string = dit_handler.convert_src_audio_to_codes(generated_audio)
|
||||
if not codes_string or codes_string.startswith("❌"):
|
||||
gr.Warning(f"Failed to convert audio to codes: {codes_string}")
|
||||
return gr.skip(), gr.skip()
|
||||
|
||||
gr.Info("Audio converted to codes successfully.")
|
||||
return gr.update(value=codes_string), gr.update(open=True)
|
||||
except Exception as e:
|
||||
gr.Warning(f"Error converting audio to codes: {e}")
|
||||
return gr.skip(), gr.skip()
|
||||
|
||||
|
||||
def save_lrc_to_file(lrc_text):
|
||||
"""Save LRC text to a downloadable .lrc file.
|
||||
|
||||
Args:
|
||||
lrc_text: The LRC text content to save
|
||||
|
||||
Returns:
|
||||
gr.update for the File component with the .lrc file path
|
||||
"""
|
||||
if not lrc_text or not lrc_text.strip():
|
||||
gr.Warning("No LRC content to save.")
|
||||
return gr.skip()
|
||||
|
||||
try:
|
||||
# Create a temporary file with .lrc extension
|
||||
tmp_dir = tempfile.mkdtemp()
|
||||
lrc_path = os.path.join(tmp_dir, "lyrics.lrc")
|
||||
with open(lrc_path, "w", encoding="utf-8") as f:
|
||||
f.write(lrc_text)
|
||||
gr.Info("LRC file ready for download.")
|
||||
return gr.update(value=lrc_path, visible=True)
|
||||
except Exception as e:
|
||||
gr.Warning(f"Error saving LRC file: {e}")
|
||||
return gr.skip()
|
||||
|
||||
|
|
|
|||
|
|
@ -9,11 +9,19 @@ import json
|
|||
from typing import Any, Dict, List, Tuple, Optional
|
||||
from loguru import logger
|
||||
import gradio as gr
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from acestep.training.dataset_builder import DatasetBuilder, AudioSample
|
||||
from acestep.debug_utils import debug_log_for, debug_start_for, debug_end_for
|
||||
from acestep.gpu_config import get_global_gpu_config
|
||||
|
||||
# Define a safe root directory for all training-related filesystem operations.
|
||||
# This limits user-supplied paths (for checkpoints, exports, etc.) to stay
|
||||
# within the server's working tree, preventing directory traversal outside it.
|
||||
SAFE_TRAINING_ROOT = os.path.abspath(os.getcwd())
|
||||
|
||||
|
||||
def create_dataset_builder() -> DatasetBuilder:
|
||||
"""Create a new DatasetBuilder instance."""
|
||||
|
|
@ -42,7 +50,7 @@ def scan_directory(
|
|||
Tuple of (table_data, status, slider_update, builder_state)
|
||||
"""
|
||||
if not audio_dir or not audio_dir.strip():
|
||||
return [], "<EFBFBD> Please enter a directory path", _safe_slider(0, value=0, visible=False), builder_state
|
||||
return [], "❌ Please enter a directory path", _safe_slider(0, value=0, visible=False), builder_state
|
||||
|
||||
# Create or use existing builder
|
||||
builder = builder_state if builder_state else DatasetBuilder()
|
||||
|
|
@ -99,17 +107,17 @@ def auto_label_all(
|
|||
Tuple of (table_data, status, builder_state)
|
||||
"""
|
||||
if builder_state is None:
|
||||
return [], "<EFBFBD> Please scan a directory first", builder_state
|
||||
return [], "❌ Please scan a directory first", builder_state
|
||||
|
||||
if not builder_state.samples:
|
||||
return [], "<EFBFBD> No samples to label. Please scan a directory first.", builder_state
|
||||
return [], "❌ No samples to label. Please scan a directory first.", builder_state
|
||||
|
||||
# Check if handlers are initialized
|
||||
if dit_handler is None or dit_handler.model is None:
|
||||
return builder_state.get_samples_dataframe_data(), "<EFBFBD> Model not initialized. Please initialize the service first.", builder_state
|
||||
return builder_state.get_samples_dataframe_data(), "❌ Model not initialized. Please initialize the service first.", builder_state
|
||||
|
||||
if llm_handler is None or not llm_handler.llm_initialized:
|
||||
return builder_state.get_samples_dataframe_data(), "<EFBFBD> LLM not initialized. Please initialize the service with LLM enabled.", builder_state
|
||||
return builder_state.get_samples_dataframe_data(), "❌ LLM not initialized. Please initialize the service with LLM enabled.", builder_state
|
||||
|
||||
def progress_callback(msg):
|
||||
if progress:
|
||||
|
|
@ -210,7 +218,7 @@ def save_sample_edit(
|
|||
Tuple of (table_data, status, builder_state)
|
||||
"""
|
||||
if builder_state is None:
|
||||
return [], "<EFBFBD> No dataset loaded", builder_state
|
||||
return [], "❌ No dataset loaded", builder_state
|
||||
|
||||
idx = int(sample_idx)
|
||||
|
||||
|
|
@ -281,13 +289,13 @@ def save_dataset(
|
|||
Status message
|
||||
"""
|
||||
if builder_state is None:
|
||||
return "<EFBFBD> No dataset to save. Please scan a directory first.", gr.update()
|
||||
return "❌ No dataset to save. Please scan a directory first.", gr.update()
|
||||
|
||||
if not builder_state.samples:
|
||||
return "<EFBFBD> No samples in dataset.", gr.update()
|
||||
return "❌ No samples in dataset.", gr.update()
|
||||
|
||||
if not save_path or not save_path.strip():
|
||||
return "<EFBFBD> Please enter a save path.", gr.update()
|
||||
return "❌ Please enter a save path.", gr.update()
|
||||
|
||||
save_path = save_path.strip()
|
||||
if not save_path.lower().endswith(".json"):
|
||||
|
|
@ -296,7 +304,7 @@ def save_dataset(
|
|||
# Check if any samples are labeled
|
||||
labeled_count = builder_state.get_labeled_count()
|
||||
if labeled_count == 0:
|
||||
return "<EFBFBD>️ Warning: No samples have been labeled. Consider auto-labeling first.\nSaving anyway...", gr.update(value=save_path)
|
||||
return "⚠️ Warning: No samples have been labeled. Consider auto-labeling first.\nSaving anyway...", gr.update(value=save_path)
|
||||
|
||||
return builder_state.save_dataset(save_path, dataset_name), gr.update(value=save_path)
|
||||
|
||||
|
|
@ -321,14 +329,14 @@ def load_existing_dataset_for_preprocess(
|
|||
|
||||
if not dataset_path or not dataset_path.strip():
|
||||
updates = (gr.update(), gr.update(), gr.update(), gr.update(), gr.update())
|
||||
return ("<EFBFBD> Please enter a dataset path", [], _safe_slider(0, value=0, visible=False), builder_state) + empty_preview + updates
|
||||
return ("❌ Please enter a dataset path", [], _safe_slider(0, value=0, visible=False), builder_state) + empty_preview + updates
|
||||
|
||||
dataset_path = dataset_path.strip()
|
||||
debug_log_for("dataset", f"UI load_existing_dataset_for_preprocess: path='{dataset_path}'")
|
||||
|
||||
if not os.path.exists(dataset_path):
|
||||
updates = (gr.update(), gr.update(), gr.update(), gr.update(), gr.update())
|
||||
return (f"<EFBFBD> Dataset not found: {dataset_path}", [], _safe_slider(0, value=0, visible=False), builder_state) + empty_preview + updates
|
||||
return (f"❌ Dataset not found: {dataset_path}", [], _safe_slider(0, value=0, visible=False), builder_state) + empty_preview + updates
|
||||
|
||||
# Create new builder (don't reuse old state when loading a file)
|
||||
builder = DatasetBuilder()
|
||||
|
|
@ -350,12 +358,12 @@ def load_existing_dataset_for_preprocess(
|
|||
|
||||
# Create info text
|
||||
labeled_count = builder.get_labeled_count()
|
||||
info = f"<EFBFBD> Loaded dataset: {builder.metadata.name}\n"
|
||||
info += f"<EFBFBD> Samples: {len(samples)} ({labeled_count} labeled)\n"
|
||||
info += f"<EFBFBD><EFBFBD><EFBFBD>️ Custom Tag: {builder.metadata.custom_tag or '(none)'}\n"
|
||||
info += "<EFBFBD> Ready for preprocessing! You can also edit samples below."
|
||||
info = f"📂 Loaded dataset: {builder.metadata.name}\n"
|
||||
info += f"🔢 Samples: {len(samples)} ({labeled_count} labeled)\n"
|
||||
info += f"🏷️ Custom Tag: {builder.metadata.custom_tag or '(none)'}\n"
|
||||
info += "✅ Ready for preprocessing! You can also edit samples below."
|
||||
if any((s.formatted_lyrics and not s.lyrics) for s in builder.samples):
|
||||
info += "\n<EFBFBD>️ Showing formatted lyrics where lyrics are empty."
|
||||
info += "\nℹ️ Showing formatted lyrics where lyrics are empty."
|
||||
|
||||
# Get first sample preview
|
||||
first_sample = builder.samples[0]
|
||||
|
|
@ -401,6 +409,7 @@ def load_existing_dataset_for_preprocess(
|
|||
|
||||
def preprocess_dataset(
|
||||
output_dir: str,
|
||||
preprocess_mode: str,
|
||||
dit_handler,
|
||||
builder_state: Optional[DatasetBuilder],
|
||||
progress=None,
|
||||
|
|
@ -413,20 +422,20 @@ def preprocess_dataset(
|
|||
Status message
|
||||
"""
|
||||
if builder_state is None:
|
||||
return "<EFBFBD> No dataset loaded. Please scan a directory first."
|
||||
return "❌ No dataset loaded. Please scan a directory first."
|
||||
|
||||
if not builder_state.samples:
|
||||
return "<EFBFBD> No samples in dataset."
|
||||
return "❌ No samples in dataset."
|
||||
|
||||
labeled_count = builder_state.get_labeled_count()
|
||||
if labeled_count == 0:
|
||||
return "<EFBFBD> No labeled samples. Please auto-label or manually label samples first."
|
||||
return "❌ No labeled samples. Please auto-label or manually label samples first."
|
||||
|
||||
if not output_dir or not output_dir.strip():
|
||||
return "<EFBFBD> Please enter an output directory."
|
||||
return "❌ Please enter an output directory."
|
||||
|
||||
if dit_handler is None or dit_handler.model is None:
|
||||
return "<EFBFBD> Model not initialized. Please initialize the service first."
|
||||
return "❌ Model not initialized. Please initialize the service first."
|
||||
|
||||
def progress_callback(msg):
|
||||
if progress:
|
||||
|
|
@ -436,10 +445,15 @@ def preprocess_dataset(
|
|||
pass
|
||||
|
||||
# Run preprocessing
|
||||
mode = str(preprocess_mode or "lora").strip().lower()
|
||||
if mode not in {"lora", "lokr"}:
|
||||
mode = "lora"
|
||||
|
||||
t0 = debug_start_for("dataset", "preprocess_to_tensors")
|
||||
output_paths, status = builder_state.preprocess_to_tensors(
|
||||
dit_handler=dit_handler,
|
||||
output_dir=output_dir.strip(),
|
||||
preprocess_mode=mode,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
debug_end_for("dataset", "preprocess_to_tensors", t0)
|
||||
|
|
@ -456,15 +470,15 @@ def load_training_dataset(
|
|||
Info text about the dataset
|
||||
"""
|
||||
if not tensor_dir or not tensor_dir.strip():
|
||||
return "<EFBFBD> Please enter a tensor directory path"
|
||||
return "❌ Please enter a tensor directory path"
|
||||
|
||||
tensor_dir = tensor_dir.strip()
|
||||
|
||||
if not os.path.exists(tensor_dir):
|
||||
return f"<EFBFBD> Directory not found: {tensor_dir}"
|
||||
return f"❌ Directory not found: {tensor_dir}"
|
||||
|
||||
if not os.path.isdir(tensor_dir):
|
||||
return f"<EFBFBD> Not a directory: {tensor_dir}"
|
||||
return f"❌ Not a directory: {tensor_dir}"
|
||||
|
||||
# Check for manifest
|
||||
manifest_path = os.path.join(tensor_dir, "manifest.json")
|
||||
|
|
@ -478,9 +492,9 @@ def load_training_dataset(
|
|||
name = metadata.get("name", "Unknown")
|
||||
custom_tag = metadata.get("custom_tag", "")
|
||||
|
||||
info = f"<EFBFBD> Loaded preprocessed dataset: {name}\n"
|
||||
info += f"<EFBFBD> Samples: {num_samples} preprocessed tensors\n"
|
||||
info += f"<EFBFBD><EFBFBD><EFBFBD>️ Custom Tag: {custom_tag or '(none)'}"
|
||||
info = f"📂 Loaded preprocessed dataset: {name}\n"
|
||||
info += f"🔢 Samples: {num_samples} preprocessed tensors\n"
|
||||
info += f"🏷️ Custom Tag: {custom_tag or '(none)'}"
|
||||
|
||||
return info
|
||||
except Exception as e:
|
||||
|
|
@ -490,10 +504,10 @@ def load_training_dataset(
|
|||
pt_files = [f for f in os.listdir(tensor_dir) if f.endswith('.pt')]
|
||||
|
||||
if not pt_files:
|
||||
return f"<EFBFBD> No .pt tensor files found in {tensor_dir}"
|
||||
return f"❌ No .pt tensor files found in {tensor_dir}"
|
||||
|
||||
info = f"<EFBFBD> Found {len(pt_files)} tensor files in {tensor_dir}\n"
|
||||
info += "<EFBFBD>️ No manifest.json found - using all .pt files"
|
||||
info = f"📂 Found {len(pt_files)} tensor files in {tensor_dir}\n"
|
||||
info += "ℹ️ No manifest.json found - using all .pt files"
|
||||
|
||||
return info
|
||||
|
||||
|
|
@ -515,6 +529,43 @@ def _format_duration(seconds):
|
|||
return f"{seconds // 3600}h {(seconds % 3600) // 60}m"
|
||||
|
||||
|
||||
def _training_loss_figure(
|
||||
training_state: Dict,
|
||||
step_list: List[int],
|
||||
loss_list: List[float],
|
||||
) -> Optional[Any]:
|
||||
"""Build a training/validation loss plot (matplotlib Figure) for gr.Plot."""
|
||||
steps = training_state.get("plot_steps") or step_list
|
||||
loss = training_state.get("plot_loss") or loss_list
|
||||
if not steps or not loss:
|
||||
fig, ax = plt.subplots(figsize=(6, 3))
|
||||
ax.set_xlabel("Step")
|
||||
ax.set_ylabel("Loss")
|
||||
ax.set_title("Training loss")
|
||||
fig.tight_layout()
|
||||
return fig
|
||||
ema = training_state.get("plot_ema")
|
||||
val_steps = training_state.get("plot_val_steps") or []
|
||||
val_loss = training_state.get("plot_val_loss") or []
|
||||
best_step = training_state.get("plot_best_step")
|
||||
|
||||
fig, ax = plt.subplots(figsize=(6, 3))
|
||||
ax.plot(steps, loss, color="tab:blue", alpha=0.35, label="Loss (raw)", linewidth=1)
|
||||
if ema and len(ema) == len(steps):
|
||||
ax.plot(steps, ema, color="tab:blue", alpha=1.0, label="Loss (smoothed)", linewidth=1.5)
|
||||
if val_steps and val_loss:
|
||||
ax.scatter(val_steps, val_loss, color="tab:orange", s=24, zorder=5, label="Validation")
|
||||
if best_step is not None:
|
||||
ax.axvline(x=best_step, color="tab:green", linestyle="--", alpha=0.8, label="Best checkpoint")
|
||||
ax.set_xlabel("Step")
|
||||
ax.set_ylabel("Loss")
|
||||
ax.set_title("Training loss")
|
||||
ax.legend(loc="upper right", fontsize=8)
|
||||
ax.grid(True, alpha=0.3)
|
||||
fig.tight_layout()
|
||||
return fig
|
||||
|
||||
|
||||
def start_training(
|
||||
tensor_dir: str,
|
||||
dit_handler,
|
||||
|
|
@ -538,17 +589,17 @@ def start_training(
|
|||
This is a generator function that yields progress updates.
|
||||
"""
|
||||
if not tensor_dir or not tensor_dir.strip():
|
||||
yield "<EFBFBD> Please enter a tensor directory path", "", None, training_state
|
||||
yield "❌ Please enter a tensor directory path", "", None, training_state
|
||||
return
|
||||
|
||||
tensor_dir = tensor_dir.strip()
|
||||
|
||||
if not os.path.exists(tensor_dir):
|
||||
yield f"<EFBFBD> Tensor directory not found: {tensor_dir}", "", None, training_state
|
||||
yield f"❌ Tensor directory not found: {tensor_dir}", "", None, training_state
|
||||
return
|
||||
|
||||
if dit_handler is None or dit_handler.model is None:
|
||||
yield "<EFBFBD> Model not initialized. Please initialize the service first.", "", None, training_state
|
||||
yield "❌ Model not initialized. Please initialize the service first.", "", None, training_state
|
||||
return
|
||||
|
||||
# Training preset: LoRA training must run on non-quantized DiT.
|
||||
|
|
@ -575,11 +626,11 @@ def start_training(
|
|||
if hasattr(dit_handler, "switch_to_training_preset"):
|
||||
switch_status, switched = dit_handler.switch_to_training_preset()
|
||||
if not switched:
|
||||
yield f"� {switch_status}", "", None, training_state
|
||||
yield f"❌ {switch_status}", "", None, training_state
|
||||
return
|
||||
yield f"� {switch_status}", "", None, training_state
|
||||
yield f"✅ {switch_status}", "", None, training_state
|
||||
else:
|
||||
yield "� Training requires non-quantized DiT, and auto-switch is unavailable in this build.", "", None, training_state
|
||||
yield "❌ Training requires non-quantized DiT, and auto-switch is unavailable in this build.", "", None, training_state
|
||||
return
|
||||
|
||||
# Check for required training dependencies
|
||||
|
|
@ -587,7 +638,7 @@ def start_training(
|
|||
from lightning.fabric import Fabric
|
||||
from peft import get_peft_model, LoraConfig
|
||||
except ImportError as e:
|
||||
yield f"<EFBFBD> Missing required packages: {e}\nPlease install: pip install peft lightning", "", None, training_state
|
||||
yield f"❌ Missing required packages: {e}\nPlease install: pip install peft lightning", "", None, training_state
|
||||
return
|
||||
|
||||
training_state["is_training"] = True
|
||||
|
|
@ -623,14 +674,14 @@ def start_training(
|
|||
pin_memory = True
|
||||
prefetch_factor = 2
|
||||
persistent_workers = True
|
||||
pin_memory_device = None
|
||||
pin_memory_device = ""
|
||||
mixed_precision = "bf16"
|
||||
elif device_type == "mps":
|
||||
num_workers = 0
|
||||
pin_memory = False
|
||||
prefetch_factor = 2
|
||||
persistent_workers = False
|
||||
pin_memory_device = None
|
||||
pin_memory_device = ""
|
||||
mixed_precision = "fp16"
|
||||
else:
|
||||
cpu_count = os.cpu_count() or 2
|
||||
|
|
@ -638,7 +689,7 @@ def start_training(
|
|||
pin_memory = False
|
||||
prefetch_factor = 2
|
||||
persistent_workers = num_workers > 0
|
||||
pin_memory_device = None
|
||||
pin_memory_device = ""
|
||||
mixed_precision = "fp32"
|
||||
|
||||
logger.info(
|
||||
|
|
@ -663,16 +714,16 @@ def start_training(
|
|||
mixed_precision=mixed_precision,
|
||||
)
|
||||
|
||||
import pandas as pd
|
||||
|
||||
# Initialize training log and loss history
|
||||
log_lines = []
|
||||
loss_data = pd.DataFrame({"step": [0], "loss": [0.0]})
|
||||
step_list = []
|
||||
loss_list = []
|
||||
initial_plot = _training_loss_figure(training_state, step_list, loss_list)
|
||||
|
||||
# Start timer
|
||||
start_time = time.time()
|
||||
|
||||
yield f"<EFBFBD> Starting training from {tensor_dir}...", "", loss_data, training_state
|
||||
yield f"🚀 Starting training from {tensor_dir}...", "", initial_plot, training_state
|
||||
|
||||
# Create trainer
|
||||
trainer = LoRATrainer(
|
||||
|
|
@ -681,9 +732,6 @@ def start_training(
|
|||
training_config=training_config,
|
||||
)
|
||||
|
||||
# Collect loss history
|
||||
step_list = []
|
||||
loss_list = []
|
||||
training_failed = False
|
||||
failure_message = ""
|
||||
|
||||
|
|
@ -731,37 +779,64 @@ def start_training(
|
|||
if step > 0 and loss is not None and loss == loss: # Check for NaN
|
||||
step_list.append(step)
|
||||
loss_list.append(float(loss))
|
||||
loss_data = pd.DataFrame({"step": step_list, "loss": loss_list})
|
||||
|
||||
yield display_status, log_text, loss_data, training_state
|
||||
plot_figure = _training_loss_figure(training_state, step_list, loss_list)
|
||||
yield display_status, log_text, plot_figure, training_state
|
||||
|
||||
if training_state.get("should_stop", False):
|
||||
logger.info("⏹️ Training stopped by user")
|
||||
log_lines.append("⏹️ Training stopped by user")
|
||||
yield f"⏹️ Stopped ({time_info})", "\n".join(log_lines[-15:]), loss_data, training_state
|
||||
yield f"⏹️ Stopped ({time_info})", "\n".join(log_lines[-15:]), plot_figure, training_state
|
||||
break
|
||||
|
||||
total_time = time.time() - start_time
|
||||
training_state["is_training"] = False
|
||||
final_plot = _training_loss_figure(training_state, step_list, loss_list)
|
||||
if training_failed:
|
||||
final_msg = f"{failure_message}\nElapsed: {_format_duration(total_time)}"
|
||||
logger.warning(final_msg)
|
||||
log_lines.append(failure_message)
|
||||
yield final_msg, "\n".join(log_lines[-15:]), loss_data, training_state
|
||||
yield final_msg, "\n".join(log_lines[-15:]), final_plot, training_state
|
||||
return
|
||||
completion_msg = f"<EFBFBD> Training completed! Total time: {_format_duration(total_time)}"
|
||||
completion_msg = f"✅ Training completed! Total time: {_format_duration(total_time)}"
|
||||
|
||||
logger.info(completion_msg)
|
||||
log_lines.append(completion_msg)
|
||||
|
||||
yield completion_msg, "\n".join(log_lines[-15:]), loss_data, training_state
|
||||
yield completion_msg, "\n".join(log_lines[-15:]), final_plot, training_state
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Training error")
|
||||
training_state["is_training"] = False
|
||||
import pandas as pd
|
||||
empty_df = pd.DataFrame({"step": [], "loss": []})
|
||||
yield f"<EFBFBD> Error: {str(e)}", str(e), empty_df, training_state
|
||||
yield f"❌ Error: {str(e)}", str(e), _training_loss_figure({}, [], []), training_state
|
||||
|
||||
def _safe_join(base_root: str, user_path: str) -> Optional[str]:
|
||||
"""Safely join a user-supplied path to a base root, preventing escapes.
|
||||
|
||||
Returns an absolute path within base_root, or None if the path is invalid.
|
||||
"""
|
||||
if not user_path:
|
||||
return None
|
||||
# Normalize whitespace
|
||||
candidate = user_path.strip()
|
||||
if not candidate:
|
||||
return None
|
||||
# Disallow absolute paths outright; force everything under base_root
|
||||
if os.path.isabs(candidate):
|
||||
return None
|
||||
# Join with base_root and normalize to an absolute path
|
||||
abs_root = os.path.abspath(base_root)
|
||||
joined = os.path.abspath(os.path.join(abs_root, candidate))
|
||||
try:
|
||||
common = os.path.commonpath([abs_root, joined])
|
||||
except ValueError:
|
||||
# Different drives on Windows or other path issues
|
||||
return None
|
||||
if common != abs_root:
|
||||
# Attempted to escape the allowed root
|
||||
return None
|
||||
return joined
|
||||
|
||||
|
||||
|
||||
def stop_training(training_state: Dict) -> Tuple[str, Dict]:
|
||||
|
|
@ -771,7 +846,7 @@ def stop_training(training_state: Dict) -> Tuple[str, Dict]:
|
|||
Tuple of (status, training_state)
|
||||
"""
|
||||
if not training_state.get("is_training", False):
|
||||
return "<EFBFBD>️ No training in progress", training_state
|
||||
return "ℹ️ No training in progress", training_state
|
||||
|
||||
training_state["should_stop"] = True
|
||||
return "⏹️ Stopping training...", training_state
|
||||
|
|
@ -787,11 +862,16 @@ def export_lora(
|
|||
Status message
|
||||
"""
|
||||
if not export_path or not export_path.strip():
|
||||
return "<EFBFBD> Please enter an export path"
|
||||
return "❌ Please enter an export path"
|
||||
# Resolve and validate the base LoRA output directory within the safe root
|
||||
safe_lora_dir = _safe_join(SAFE_TRAINING_ROOT, lora_output_dir)
|
||||
if safe_lora_dir is None:
|
||||
return "❌ Invalid LoRA output directory"
|
||||
|
||||
|
||||
# Check if there's a trained model to export
|
||||
final_dir = os.path.join(lora_output_dir, "final")
|
||||
checkpoint_dir = os.path.join(lora_output_dir, "checkpoints")
|
||||
final_dir = os.path.join(safe_lora_dir, "final")
|
||||
checkpoint_dir = os.path.join(safe_lora_dir, "checkpoints")
|
||||
|
||||
# Prefer final, fallback to checkpoints
|
||||
if os.path.exists(final_dir):
|
||||
|
|
@ -800,30 +880,337 @@ def export_lora(
|
|||
# Find the latest checkpoint
|
||||
checkpoints = [d for d in os.listdir(checkpoint_dir) if d.startswith("epoch_")]
|
||||
if not checkpoints:
|
||||
return "<EFBFBD> No checkpoints found"
|
||||
return "❌ No checkpoints found"
|
||||
|
||||
checkpoints.sort(key=lambda x: int(x.split("_")[1]))
|
||||
latest = checkpoints[-1]
|
||||
source_path = os.path.join(checkpoint_dir, latest)
|
||||
else:
|
||||
return f"<EFBFBD> No trained model found in {lora_output_dir}"
|
||||
return f"❌ No trained model found in {lora_output_dir}"
|
||||
|
||||
# Resolve and validate the export destination within the safe root
|
||||
safe_export_path = _safe_join(SAFE_TRAINING_ROOT, export_path)
|
||||
if safe_export_path is None:
|
||||
return "❌ Invalid export path"
|
||||
|
||||
try:
|
||||
import shutil
|
||||
|
||||
export_path = export_path.strip()
|
||||
os.makedirs(os.path.dirname(export_path) if os.path.dirname(export_path) else ".", exist_ok=True)
|
||||
# Ensure parent directory exists
|
||||
parent_dir = os.path.dirname(safe_export_path) or "."
|
||||
os.makedirs(parent_dir, exist_ok=True)
|
||||
|
||||
if os.path.exists(export_path):
|
||||
shutil.rmtree(export_path)
|
||||
if os.path.exists(safe_export_path):
|
||||
shutil.rmtree(safe_export_path)
|
||||
|
||||
shutil.copytree(source_path, export_path)
|
||||
shutil.copytree(source_path, safe_export_path)
|
||||
|
||||
return f"<EFBFBD> LoRA exported to {export_path}"
|
||||
return f"✅ LoRA exported to {safe_export_path}"
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Export error")
|
||||
return f"<EFBFBD> Export failed: {str(e)}"
|
||||
return f"❌ Export failed: {str(e)}"
|
||||
|
||||
|
||||
def start_lokr_training(
|
||||
tensor_dir: str,
|
||||
dit_handler,
|
||||
lokr_linear_dim: int,
|
||||
lokr_linear_alpha: int,
|
||||
lokr_factor: int,
|
||||
lokr_decompose_both: bool,
|
||||
lokr_use_tucker: bool,
|
||||
lokr_use_scalar: bool,
|
||||
lokr_weight_decompose: bool,
|
||||
learning_rate: float,
|
||||
train_epochs: int,
|
||||
train_batch_size: int,
|
||||
gradient_accumulation: int,
|
||||
save_every_n_epochs: int,
|
||||
training_shift: float,
|
||||
training_seed: int,
|
||||
lokr_output_dir: str,
|
||||
training_state: Dict,
|
||||
progress=None,
|
||||
):
|
||||
"""Start LoKr training from preprocessed tensors."""
|
||||
if not tensor_dir or not tensor_dir.strip():
|
||||
yield "❌ Please enter a tensor directory path", "", None, training_state
|
||||
return
|
||||
|
||||
tensor_dir = tensor_dir.strip()
|
||||
if not os.path.exists(tensor_dir):
|
||||
yield f"❌ Tensor directory not found: {tensor_dir}", "", None, training_state
|
||||
return
|
||||
|
||||
if dit_handler is None or dit_handler.model is None:
|
||||
yield "❌ Model not initialized. Please initialize the service first.", "", None, training_state
|
||||
return
|
||||
|
||||
if getattr(dit_handler, "quantization", None) is not None:
|
||||
yield "Switching model to training preset (disable quantization)...", "", None, training_state
|
||||
if hasattr(dit_handler, "switch_to_training_preset"):
|
||||
switch_status, switched = dit_handler.switch_to_training_preset()
|
||||
if not switched:
|
||||
yield f"❌ {switch_status}", "", None, training_state
|
||||
return
|
||||
yield f"✅ {switch_status}", "", None, training_state
|
||||
else:
|
||||
yield "❌ Training requires non-quantized DiT, and auto-switch is unavailable in this build.", "", None, training_state
|
||||
return
|
||||
|
||||
try:
|
||||
from lightning.fabric import Fabric # noqa: F401
|
||||
except ImportError as e:
|
||||
yield f"❌ Missing required packages: {e}\nPlease install: pip install lightning lycoris-lora", "", None, training_state
|
||||
return
|
||||
|
||||
training_state["is_training"] = True
|
||||
training_state["should_stop"] = False
|
||||
training_state["adapter_type"] = "lokr"
|
||||
|
||||
try:
|
||||
from acestep.training.configs import LoKRConfig as LoKRConfigClass, TrainingConfig
|
||||
from acestep.training.trainer import LoKRTrainer
|
||||
|
||||
device_attr = getattr(dit_handler, "device", "")
|
||||
if hasattr(device_attr, "type"):
|
||||
device_type = str(device_attr.type).lower()
|
||||
else:
|
||||
device_type = str(device_attr).split(":", 1)[0].lower()
|
||||
|
||||
if device_type == "cuda":
|
||||
num_workers = 4
|
||||
pin_memory = True
|
||||
prefetch_factor = 2
|
||||
persistent_workers = True
|
||||
pin_memory_device = "cuda"
|
||||
mixed_precision = "bf16"
|
||||
elif device_type == "xpu":
|
||||
num_workers = 4
|
||||
pin_memory = True
|
||||
prefetch_factor = 2
|
||||
persistent_workers = True
|
||||
pin_memory_device = ""
|
||||
mixed_precision = "bf16"
|
||||
elif device_type == "mps":
|
||||
num_workers = 0
|
||||
pin_memory = False
|
||||
prefetch_factor = 2
|
||||
persistent_workers = False
|
||||
pin_memory_device = ""
|
||||
mixed_precision = "fp16"
|
||||
else:
|
||||
num_workers = 0
|
||||
pin_memory = False
|
||||
prefetch_factor = 2
|
||||
persistent_workers = False
|
||||
pin_memory_device = ""
|
||||
mixed_precision = "fp32"
|
||||
|
||||
lokr_config = LoKRConfigClass(
|
||||
linear_dim=lokr_linear_dim,
|
||||
linear_alpha=lokr_linear_alpha,
|
||||
factor=lokr_factor,
|
||||
decompose_both=lokr_decompose_both,
|
||||
use_tucker=lokr_use_tucker,
|
||||
use_scalar=lokr_use_scalar,
|
||||
weight_decompose=lokr_weight_decompose,
|
||||
)
|
||||
training_config = TrainingConfig(
|
||||
shift=training_shift,
|
||||
learning_rate=learning_rate,
|
||||
batch_size=train_batch_size,
|
||||
gradient_accumulation_steps=gradient_accumulation,
|
||||
max_epochs=train_epochs,
|
||||
save_every_n_epochs=save_every_n_epochs,
|
||||
seed=training_seed,
|
||||
output_dir=lokr_output_dir,
|
||||
num_workers=num_workers,
|
||||
pin_memory=pin_memory,
|
||||
prefetch_factor=prefetch_factor,
|
||||
persistent_workers=persistent_workers,
|
||||
pin_memory_device=pin_memory_device,
|
||||
mixed_precision=mixed_precision,
|
||||
)
|
||||
|
||||
log_lines = []
|
||||
step_list = []
|
||||
loss_list = []
|
||||
initial_plot = _training_loss_figure(training_state, step_list, loss_list)
|
||||
start_time = time.time()
|
||||
yield f"🚀 Starting LoKr training from {tensor_dir}...", "", initial_plot, training_state
|
||||
|
||||
trainer = LoKRTrainer(
|
||||
dit_handler=dit_handler,
|
||||
lokr_config=lokr_config,
|
||||
training_config=training_config,
|
||||
)
|
||||
|
||||
training_failed = False
|
||||
failure_message = ""
|
||||
|
||||
for step, loss, status in trainer.train_from_preprocessed(tensor_dir, training_state):
|
||||
status_text = str(status)
|
||||
status_lower = status_text.lower()
|
||||
if (
|
||||
status_text.startswith("❌")
|
||||
or "training failed" in status_lower
|
||||
or "error:" in status_lower
|
||||
or "module not found" in status_lower
|
||||
):
|
||||
training_failed = True
|
||||
failure_message = status_text
|
||||
|
||||
elapsed_seconds = time.time() - start_time
|
||||
time_info = f"⏱️ Elapsed: {_format_duration(elapsed_seconds)}"
|
||||
match = re.search(r"Epoch\s+(\d+)/(\d+)", status_text)
|
||||
if match:
|
||||
current_ep = int(match.group(1))
|
||||
total_ep = int(match.group(2))
|
||||
if current_ep > 0:
|
||||
eta_seconds = (elapsed_seconds / current_ep) * (total_ep - current_ep)
|
||||
time_info += f" | ETA: ~{_format_duration(eta_seconds)}"
|
||||
|
||||
display_status = f"{status_text}\n{time_info}"
|
||||
log_lines.append(status_text)
|
||||
if len(log_lines) > 15:
|
||||
log_lines = log_lines[-15:]
|
||||
log_text = "\n".join(log_lines)
|
||||
|
||||
if step > 0 and loss is not None and loss == loss:
|
||||
step_list.append(step)
|
||||
loss_list.append(float(loss))
|
||||
|
||||
plot_figure = _training_loss_figure(training_state, step_list, loss_list)
|
||||
yield display_status, log_text, plot_figure, training_state
|
||||
|
||||
if training_state.get("should_stop", False):
|
||||
log_lines.append("⏹️ Training stopped by user")
|
||||
yield f"⏹️ Stopped ({time_info})", "\n".join(log_lines[-15:]), plot_figure, training_state
|
||||
break
|
||||
|
||||
total_time = time.time() - start_time
|
||||
training_state["is_training"] = False
|
||||
final_plot = _training_loss_figure(training_state, step_list, loss_list)
|
||||
if training_failed:
|
||||
final_msg = f"{failure_message}\nElapsed: {_format_duration(total_time)}"
|
||||
log_lines.append(failure_message)
|
||||
yield final_msg, "\n".join(log_lines[-15:]), final_plot, training_state
|
||||
return
|
||||
|
||||
completion_msg = f"✅ LoKr training completed! Total time: {_format_duration(total_time)}"
|
||||
log_lines.append(completion_msg)
|
||||
yield completion_msg, "\n".join(log_lines[-15:]), final_plot, training_state
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("LoKr training error")
|
||||
training_state["is_training"] = False
|
||||
yield f"❌ Error: {str(e)}", str(e), _training_loss_figure({}, [], []), training_state
|
||||
|
||||
|
||||
def list_lokr_export_epochs(lokr_output_dir: str) -> Tuple[Any, str]:
|
||||
"""List available LoKr checkpoint epochs for export dropdown."""
|
||||
default_choice = "Latest (auto)"
|
||||
if not lokr_output_dir or not lokr_output_dir.strip():
|
||||
return gr.update(choices=[default_choice], value=default_choice), "⚠️ Enter LoKr output directory first"
|
||||
|
||||
checkpoint_dir = os.path.join(lokr_output_dir.strip(), "checkpoints")
|
||||
if not os.path.isdir(checkpoint_dir):
|
||||
return gr.update(choices=[default_choice], value=default_choice), "ℹ️ No checkpoints found; export will use latest available weights"
|
||||
|
||||
checkpoints = []
|
||||
for d in os.listdir(checkpoint_dir):
|
||||
if not d.startswith("epoch_"):
|
||||
continue
|
||||
weight_file = os.path.join(checkpoint_dir, d, "lokr_weights.safetensors")
|
||||
if not os.path.exists(weight_file):
|
||||
continue
|
||||
try:
|
||||
epoch_num = int(d.split("_")[1])
|
||||
except Exception:
|
||||
continue
|
||||
checkpoints.append((epoch_num, d))
|
||||
|
||||
if not checkpoints:
|
||||
return gr.update(choices=[default_choice], value=default_choice), "ℹ️ No exportable epoch checkpoints found"
|
||||
|
||||
checkpoints.sort(key=lambda x: x[0], reverse=True)
|
||||
choices = [default_choice] + [d for _, d in checkpoints]
|
||||
return gr.update(choices=choices, value=default_choice), f"✅ Found {len(checkpoints)} LoKr checkpoints"
|
||||
|
||||
|
||||
def export_lokr(
|
||||
export_path: str,
|
||||
lokr_output_dir: str,
|
||||
selected_epoch: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Export trained LoKr weights.
|
||||
|
||||
Returns:
|
||||
Status message
|
||||
"""
|
||||
if not export_path or not export_path.strip():
|
||||
return "❌ Please enter an export path"
|
||||
|
||||
final_dir = os.path.join(lokr_output_dir, "final")
|
||||
checkpoint_dir = os.path.join(lokr_output_dir, "checkpoints")
|
||||
default_epoch_choice = "Latest (auto)"
|
||||
|
||||
chosen_epoch = (selected_epoch or "").strip()
|
||||
if not chosen_epoch:
|
||||
chosen_epoch = default_epoch_choice
|
||||
|
||||
checkpoint_names: List[str] = []
|
||||
if os.path.isdir(checkpoint_dir):
|
||||
for d in os.listdir(checkpoint_dir):
|
||||
if not d.startswith("epoch_"):
|
||||
continue
|
||||
try:
|
||||
int(d.split("_")[1])
|
||||
except Exception:
|
||||
continue
|
||||
checkpoint_names.append(d)
|
||||
checkpoint_names.sort(key=lambda x: int(x.split("_")[1]))
|
||||
|
||||
# Determine source
|
||||
explicit_epoch = chosen_epoch not in {default_epoch_choice, "latest", "Latest", "auto", "Auto"}
|
||||
if explicit_epoch:
|
||||
requested = chosen_epoch
|
||||
if requested.isdigit():
|
||||
requested = f"epoch_{requested}"
|
||||
if requested not in checkpoint_names:
|
||||
return (
|
||||
f"❌ Selected epoch not found: {chosen_epoch}. "
|
||||
f"Available: {', '.join(checkpoint_names) if checkpoint_names else '(none)'}"
|
||||
)
|
||||
source_file = os.path.join(checkpoint_dir, requested, "lokr_weights.safetensors")
|
||||
if not os.path.exists(source_file):
|
||||
return f"❌ No LoKr weights found for selected epoch: {requested}"
|
||||
elif os.path.exists(os.path.join(final_dir, "lokr_weights.safetensors")):
|
||||
source_file = os.path.join(final_dir, "lokr_weights.safetensors")
|
||||
elif checkpoint_names:
|
||||
latest_checkpoint = checkpoint_names[-1]
|
||||
source_file = os.path.join(checkpoint_dir, latest_checkpoint, "lokr_weights.safetensors")
|
||||
if not os.path.exists(source_file):
|
||||
return f"❌ No LoKr weights found in latest checkpoint: {latest_checkpoint}"
|
||||
else:
|
||||
return f"❌ No trained LoKr weights found in {lokr_output_dir}"
|
||||
|
||||
try:
|
||||
import shutil
|
||||
|
||||
export_path = export_path.strip()
|
||||
if export_path.lower().endswith(".safetensors"):
|
||||
os.makedirs(os.path.dirname(export_path) if os.path.dirname(export_path) else ".", exist_ok=True)
|
||||
shutil.copy2(source_file, export_path)
|
||||
return f"✅ LoKr exported to {export_path}"
|
||||
|
||||
os.makedirs(export_path, exist_ok=True)
|
||||
dst_file = os.path.join(export_path, "lokr_weights.safetensors")
|
||||
shutil.copy2(source_file, dst_file)
|
||||
return f"✅ LoKr exported to {dst_file}"
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("LoKr export error")
|
||||
return f"❌ Export failed: {str(e)}"
|
||||
|
|
@ -53,12 +53,19 @@
|
|||
"compile_model_info": "Use torch.compile to optimize model (required for quantization)",
|
||||
"quantization_label": "INT8 Quantization",
|
||||
"quantization_info": "Enable INT8 weight-only quantization to reduce VRAM usage (requires Compile Model)",
|
||||
"mlx_dit_label": "MLX DiT (Apple Silicon)",
|
||||
"mlx_dit_info_enabled": "Use native MLX for DiT diffusion on Apple Silicon (faster than MPS)",
|
||||
"mlx_dit_info_disabled": "MLX not available (requires macOS + Apple Silicon + mlx package)",
|
||||
"init_btn": "Initialize Service",
|
||||
"status_label": "Status",
|
||||
"language_label": "UI Language",
|
||||
"language_info": "Select interface language"
|
||||
"language_info": "Select interface language",
|
||||
"gpu_auto_tier": "Auto-detected Tier",
|
||||
"tier_label": "GPU Tier Override",
|
||||
"tier_info": "Manually select GPU tier to adjust optimization defaults (offload, quantization, backend, etc.)"
|
||||
},
|
||||
"generation": {
|
||||
"tab_title": "🎵 Generation",
|
||||
"required_inputs": "📝 Required Inputs",
|
||||
"task_type_label": "Task Type",
|
||||
"task_type_info": "Select the task type for generation",
|
||||
|
|
@ -71,8 +78,11 @@
|
|||
"track_classes_info": "Select multiple track classes for complete task",
|
||||
"audio_uploads": "🎵 Audio Uploads",
|
||||
"reference_audio": "Reference Audio (optional)",
|
||||
"source_audio": "Source Audio (optional)",
|
||||
"source_audio": "Source Audio",
|
||||
"convert_codes_btn": "Convert to Codes",
|
||||
"analyze_btn": "🔍 Analyze",
|
||||
"sample_btn": "🎲 Click Me",
|
||||
"load_btn": "📂 Load",
|
||||
"lm_codes_hints": "🎼 LM Codes Hints",
|
||||
"lm_codes_label": "LM Codes Hints",
|
||||
"lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
|
||||
|
|
@ -84,13 +94,20 @@
|
|||
"repainting_start": "Repainting Start",
|
||||
"repainting_end": "Repainting End",
|
||||
"mode_label": "Generation Mode",
|
||||
"mode_info": "Simple: describe music in natural language. Custom: full control over caption and lyrics.",
|
||||
"mode_info": "Select a generation mode to get started.",
|
||||
"mode_info_simple": "Describe your music in natural language. AI will generate caption, lyrics and metadata for you.",
|
||||
"mode_info_custom": "Full control over caption, lyrics and all parameters.",
|
||||
"mode_info_remix": "Upload source audio to create a remix version with your caption and lyrics.",
|
||||
"mode_info_repaint": "Upload source audio and repaint a specific time range.",
|
||||
"mode_info_extract": "Extract a specific track (vocals, drums, etc.) from source audio.",
|
||||
"mode_info_lego": "Reassemble tracks: replace a specific track in the source audio.",
|
||||
"mode_info_complete": "Complete missing tracks in the source audio.",
|
||||
"mode_simple": "Simple",
|
||||
"mode_custom": "Custom",
|
||||
"simple_query_label": "Song Description",
|
||||
"simple_query_placeholder": "Describe the music you want to create, e.g., 'a soft Bengali love song for a quiet evening'. Leave empty for a random sample.",
|
||||
"simple_query_info": "Enter a natural language description of the music you want to generate",
|
||||
"simple_vocal_language_label": "Vocal Language (optional)",
|
||||
"simple_vocal_language_label": "Vocal Language",
|
||||
"simple_vocal_language_info": "Select preferred language(s) for lyrics. Use 'unknown' for any language.",
|
||||
"create_sample_btn": "Create Sample",
|
||||
"caption_title": "📝 Music Caption",
|
||||
|
|
@ -103,12 +120,20 @@
|
|||
"lyrics_info": "Song lyrics with structure",
|
||||
"instrumental_label": "Instrumental",
|
||||
"format_btn": "Format",
|
||||
"format_caption_btn": "Enhance Caption",
|
||||
"format_lyrics_btn": "Enhance Lyrics",
|
||||
"optional_params": "⚙️ Optional Parameters",
|
||||
"optional_music_props": "🎵 Music Properties",
|
||||
"optional_gen_settings": "📐 Generation Settings",
|
||||
"advanced_dit_section": "🎛️ DiT Diffusion",
|
||||
"advanced_lm_section": "🤖 LM Generation",
|
||||
"advanced_output_section": "🔊 Audio Output & Post-processing",
|
||||
"advanced_automation_section": "⚡ Automation & Batch",
|
||||
"vocal_language_label": "Vocal Language (optional)",
|
||||
"vocal_language_info": "use `unknown` for inst",
|
||||
"vocal_language_info": "'unknown' = instrumental / auto",
|
||||
"bpm_label": "BPM (optional)",
|
||||
"bpm_info": "leave empty for N/A",
|
||||
"keyscale_label": "KeyScale (optional)",
|
||||
"keyscale_label": "Key (optional)",
|
||||
"keyscale_placeholder": "Leave empty for N/A",
|
||||
"keyscale_info": "A-G, #/♭, major/minor",
|
||||
"timesig_label": "Time Signature (optional)",
|
||||
|
|
@ -117,7 +142,7 @@
|
|||
"duration_info": "Use -1 for random",
|
||||
"batch_size_label": "Batch Size",
|
||||
"batch_size_info": "Number of audio to generate (max 8)",
|
||||
"advanced_settings": "🔧 Advanced Settings",
|
||||
"advanced_settings": "⚙️ Settings",
|
||||
"inference_steps_label": "DiT Inference Steps",
|
||||
"inference_steps_info": "Turbo: max 8, Base: max 200",
|
||||
"guidance_scale_label": "DiT Guidance Scale (Only support for base model)",
|
||||
|
|
@ -150,6 +175,7 @@
|
|||
"lm_negative_prompt_label": "LM Negative Prompt",
|
||||
"lm_negative_prompt_placeholder": "Enter negative prompt for CFG (default: NO USER INPUT)",
|
||||
"lm_negative_prompt_info": "Negative prompt (use when LM CFG Scale > 1.0)",
|
||||
"advanced_dit_params": "Advanced DiT Parameters",
|
||||
"cot_metas_label": "CoT Metas",
|
||||
"cot_metas_info": "Use LM to generate CoT metadata (uncheck to skip LM CoT generation)",
|
||||
"cot_language_label": "CoT Language",
|
||||
|
|
@ -164,25 +190,32 @@
|
|||
"lm_batch_chunk_info": "Max items per LM batch chunk (default: 8, limited by GPU memory)",
|
||||
"codes_strength_label": "LM Codes Strength",
|
||||
"codes_strength_info": "Control how many denoising steps use LM-generated codes",
|
||||
"similarity_denoise_label": "Similarity / Denoise",
|
||||
"similarity_denoise_info": "Controls how closely the output follows the reference audio. Higher values preserve more structure.",
|
||||
"cover_strength_label": "Audio Cover Strength",
|
||||
"cover_strength_info": "Control how many denoising steps use cover mode",
|
||||
"remix_strength_label": "Remix Strength",
|
||||
"remix_strength_info": "Control how much the remix deviates from the source audio (lower = closer to original)",
|
||||
"cover_noise_strength_label": "Cover Strength",
|
||||
"cover_noise_strength_info": "Controls melody retention in Remix mode. Recommended: use the SFT model with a value of 0.1–0.25. A small increase restores the melody, but style transfer may require additional prompt tuning. (0 = pure noise/no cover, 1 = closest to original audio)",
|
||||
"score_sensitivity_label": "Quality Score Sensitivity",
|
||||
"score_sensitivity_info": "Lower = more sensitive (default: 1.0). Adjusts how PMI maps to [0,1]",
|
||||
"think_label": "Think",
|
||||
"parallel_thinking_label": "ParallelThinking",
|
||||
"parallel_thinking_info": "Process batch samples in parallel for faster generation",
|
||||
"generate_btn": "🎵 Generate Music",
|
||||
"autogen_label": "AutoGen",
|
||||
"caption_rewrite_label": "CaptionRewrite"
|
||||
"caption_rewrite_label": "CaptionRewrite",
|
||||
"caption_rewrite_info": "Use LM to rewrite caption before generation"
|
||||
},
|
||||
"results": {
|
||||
"title": "🎵 Results",
|
||||
"generated_music": "🎵 Generated Music (Sample {n})",
|
||||
"send_to_src_btn": "🔗 Send To Src Audio",
|
||||
"send_to_remix_btn": "🔗 Send To Remix",
|
||||
"send_to_repaint_btn": "🔗 Send To Repaint",
|
||||
"save_btn": "💾 Save",
|
||||
"score_btn": "📊 Score",
|
||||
"lrc_btn": "🎵 LRC",
|
||||
"score_btn": "📊 Get Score",
|
||||
"lrc_btn": "🎵 Get LRC",
|
||||
"save_lrc_btn": "💾 Save LRC",
|
||||
"convert_to_codes_btn": "🔄 Convert To Codes",
|
||||
"quality_score_label": "Quality Score (Sample {n})",
|
||||
"quality_score_placeholder": "Click 'Score' to calculate perplexity-based quality score",
|
||||
"codes_label": "LM Codes (Sample {n})",
|
||||
|
|
@ -196,7 +229,7 @@
|
|||
"prev_btn": "◀ Previous",
|
||||
"next_btn": "Next ▶",
|
||||
"restore_params_btn": "↙️ Apply These Settings to UI (Restore Batch Parameters)",
|
||||
"batch_results_title": "👇 Click here to view batch results & generation details",
|
||||
"batch_results_title": "📁 Batch Results & Generation Details",
|
||||
"all_files_label": "📁 All Generated Files (Download)",
|
||||
"generation_details": "Generation Details"
|
||||
},
|
||||
|
|
@ -214,6 +247,7 @@
|
|||
"lm_generated": "🤖 Generated example using LM",
|
||||
"lm_fallback": "Failed to generate example using LM, falling back to examples directory",
|
||||
"lm_not_initialized": "❌ 5Hz LM not initialized. Please initialize it first.",
|
||||
"think_requires_lm": "⚠️ 'Think' requires 5Hz LM to be initialized. Think has been disabled — generation will proceed without LM thinking.",
|
||||
"autogen_enabled": "🔄 AutoGen enabled - next batch will generate after this",
|
||||
"batch_ready": "✅ Batch {n} ready! Click 'Next' to view.",
|
||||
"batch_generating": "🔄 Starting background generation for Batch {n}...",
|
||||
|
|
@ -263,57 +297,57 @@
|
|||
"dataset_name": "Dataset Name",
|
||||
"dataset_name_placeholder": "Enter dataset name",
|
||||
"dataset_settings_header": "Dataset Settings",
|
||||
"tag_prepend": "Prepend (tag, caption)",
|
||||
"tag_append": "Append (caption, tag)",
|
||||
"tag_replace": "Replace caption",
|
||||
"step2_title": "Step 2: Auto-Label with AI",
|
||||
"step2_instruction": "Click the button below to automatically generate metadata for all audio files using AI:\n- **Caption**: Music style, genre, mood description\n- **BPM**: Beats per minute\n- **Key**: Musical key (e.g., C Major, Am)\n- **Time Signature**: 4/4, 3/4, etc.",
|
||||
"tag_prepend": "Prepend (Tag, Caption)",
|
||||
"tag_append": "Append (Caption, Tag)",
|
||||
"tag_replace": "Replace Caption",
|
||||
"step2_title": "Step 2: AI Auto-Labeling",
|
||||
"step2_instruction": "Click the button below to use AI to automatically generate metadata for all audio files:\n- **Caption**: Music style, genre, mood description\n- **BPM**: Beats per minute\n- **Key**: Music key (e.g. C Major, Am)\n- **Time Signature**: 4/4, 3/4 etc.",
|
||||
"step3_title": "Step 3: Preview & Edit",
|
||||
"step4_title": "Step 4: Save Dataset",
|
||||
"step5_title": "Step 5: Preprocess to Tensors",
|
||||
"step5_intro": "**Preprocessing converts your dataset to pre-computed tensors for fast training.**\n\nYou can either:\n- Use the dataset from Steps 1-4 above, **OR**\n- Load an existing dataset JSON file (if you've already saved one)",
|
||||
"step5_details": "This step:\n- Encodes audio to VAE latents\n- Encodes captions and lyrics to text embeddings\n- Runs the condition encoder\n- Saves all tensors to `.pt` files\n\n⚠️ **This requires the model to be loaded and may take a few minutes.**",
|
||||
"train_tensor_selection_desc": "Select the directory containing preprocessed tensor files (`.pt` files).\nThese are created in the \"Dataset Builder\" tab using the \"Preprocess\" button.",
|
||||
"step5_intro": "**Preprocessing converts your dataset into pre-computed tensors for fast training.**\n\nYou can:\n- Use the dataset from steps 1-4 above, **OR**\n- Load an existing dataset JSON file (if you have one saved)",
|
||||
"step5_details": "This step will:\n- Encode audio to VAE latents\n- Encode captions and lyrics to text embeddings\n- Run condition encoders\n- Save all tensors to `.pt` files\n\n⚠️ **This requires loading models and may take a few minutes.**",
|
||||
"train_tensor_selection_desc": "Select the directory containing preprocessed tensor files (`.pt` files).\nThese are created using the 'Preprocess' button in the 'Dataset Builder' tab.",
|
||||
"all_instrumental": "All Instrumental",
|
||||
"all_instrumental_info": "Check if all tracks are instrumental (no vocals)",
|
||||
"custom_tag": "Custom Activation Tag",
|
||||
"custom_tag_info": "Unique tag to activate this LoRA's style",
|
||||
"custom_tag": "Custom Trigger Tag",
|
||||
"custom_tag_info": "Unique tag to activate this LoRA style",
|
||||
"tag_position": "Tag Position",
|
||||
"tag_position_info": "Where to place the custom tag in the caption",
|
||||
"genre_ratio": "Genre Ratio (%)",
|
||||
"genre_ratio_info": "0%=all Caption, 100%=all Genre. Per-sample override takes priority.",
|
||||
"skip_metas": "Skip BPM/Key/Time Signature",
|
||||
"skip_metas_info": "Skip BPM/Key/Time Signature generation. Caption and Genre are still generated by LLM.",
|
||||
"genre_ratio_info": "0%=All Caption, 100%=All Genre. Single sample override takes precedence.",
|
||||
"skip_metas": "Skip BPM/Key/TimeSig",
|
||||
"skip_metas_info": "Skip BPM/Key/TimeSig generation. Captions and genres are still generated by LM.",
|
||||
"only_unlabeled": "Only Unlabeled",
|
||||
"only_unlabeled_info": "Only label samples without caption (useful for resuming failed labeling)",
|
||||
"only_unlabeled_info": "Only label samples with no caption (for continuing failed labeling)",
|
||||
"auto_label_btn": "🏷️ Auto-Label All",
|
||||
"label_progress": "Labeling Progress",
|
||||
"select_sample": "Select Sample #",
|
||||
"select_sample_info": "Choose a sample to preview and edit",
|
||||
"select_sample_info": "Select sample to preview and edit",
|
||||
"audio_preview": "Audio Preview",
|
||||
"filename": "Filename",
|
||||
"caption": "Caption",
|
||||
"genre": "Genre",
|
||||
"prompt_override_label": "Prompt Override (this sample)",
|
||||
"prompt_override_label": "Prompt Override (This Sample)",
|
||||
"prompt_override_info": "Override global ratio for this sample",
|
||||
"lyrics_editable_label": "Lyrics (editable, used for training)",
|
||||
"lyrics_editable_label": "Lyrics (Editable for Training)",
|
||||
"raw_lyrics_label": "Raw Lyrics (from .txt file)",
|
||||
"no_lyrics_placeholder": "(no .txt lyrics file)",
|
||||
"no_lyrics_placeholder": "(No .txt lyrics file)",
|
||||
"bpm": "BPM",
|
||||
"key_label": "Key",
|
||||
"key_placeholder": "C Major",
|
||||
"time_sig": "Time Signature",
|
||||
"time_sig": "Time Sig",
|
||||
"duration_s": "Duration (s)",
|
||||
"language": "Language",
|
||||
"instrumental": "Instrumental",
|
||||
"save_changes_btn": "💾 Save Changes",
|
||||
"edit_status": "Edit Status",
|
||||
"save_path": "Save Path",
|
||||
"save_path_info": "Path where the dataset JSON will be saved",
|
||||
"save_path_info": "Path to save dataset JSON",
|
||||
"save_dataset_btn": "💾 Save Dataset",
|
||||
"save_status": "Save Status",
|
||||
"load_existing_label": "Load Existing Dataset (Optional)",
|
||||
"load_existing_info": "Path to a previously saved dataset JSON file",
|
||||
"load_existing_info": "Path to previously saved dataset JSON file",
|
||||
"load_dataset_btn": "📂 Load Dataset",
|
||||
"tensor_output_dir": "Tensor Output Directory",
|
||||
"tensor_output_info": "Directory to save preprocessed tensor files",
|
||||
|
|
@ -321,25 +355,22 @@
|
|||
"preprocess_progress": "Preprocessing Progress",
|
||||
"preprocessed_tensors_dir": "Preprocessed Tensors Directory",
|
||||
"preprocessed_tensors_info": "Directory containing preprocessed .pt tensor files",
|
||||
"train_section_tensors": "Preprocessed Dataset Selection",
|
||||
"train_section_lora": "LoRA Settings",
|
||||
"train_section_params": "Training Parameters",
|
||||
"dataset_info": "Dataset Info",
|
||||
"lora_rank": "LoRA Rank (r)",
|
||||
"lora_rank_info": "Higher = more capacity, more memory",
|
||||
"lora_rank_info": "Higher capacity but more VRAM",
|
||||
"lora_alpha": "LoRA Alpha",
|
||||
"lora_alpha_info": "Scaling factor (typically 2x rank)",
|
||||
"lora_alpha_info": "Scaling factor (usually 2x Rank)",
|
||||
"lora_dropout": "LoRA Dropout",
|
||||
"learning_rate": "Learning Rate",
|
||||
"learning_rate_info": "Start with 3e-4, adjust if needed",
|
||||
"learning_rate_info": "Start with 3e-4, adjust as needed",
|
||||
"max_epochs": "Max Epochs",
|
||||
"batch_size": "Batch Size",
|
||||
"batch_size_info": "Increase if you have enough VRAM",
|
||||
"batch_size_info": "Increase if VRAM allows",
|
||||
"gradient_accumulation": "Gradient Accumulation",
|
||||
"gradient_accumulation_info": "Effective batch = batch_size × accumulation",
|
||||
"gradient_accumulation_info": "Effective Batch Size = batch_size * accum_steps",
|
||||
"save_every_n_epochs": "Save Every N Epochs",
|
||||
"shift": "Shift",
|
||||
"shift_info": "Timestep shift for turbo model",
|
||||
"shift_info": "Timestep shift for Turbo models",
|
||||
"seed": "Seed",
|
||||
"output_dir": "Output Directory",
|
||||
"output_dir_info": "Directory to save trained LoRA weights",
|
||||
|
|
@ -354,5 +385,15 @@
|
|||
"export_path": "Export Path",
|
||||
"export_lora_btn": "📦 Export LoRA",
|
||||
"export_status": "Export Status"
|
||||
},
|
||||
"gen": {
|
||||
"enable_normalization": "Enable Normalization",
|
||||
"enable_normalization_info": "Normalize audio volume to target peak level to prevent clipping or ensure consistent loudness.",
|
||||
"normalization_db": "Target Peak (dB)",
|
||||
"normalization_db_info": "Target peak level in decibels. -1.0 dB is standard safe peak. -0.1 dB is max.",
|
||||
"latent_shift": "Latent Shift",
|
||||
"latent_shift_info": "Shift applied to DiT latents before VAE decode. Default 0 (no shift). Negative values (e.g. -0.04) can reduce clipping.",
|
||||
"latent_rescale": "Latent Rescale",
|
||||
"latent_rescale_info": "Rescale factor for DiT latents before VAE decode. Default 1.0 (no rescale). Values < 1.0 (e.g. 0.91) can reduce clipping."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,356 +1,391 @@
|
|||
{
|
||||
"app": {
|
||||
"title": "🎛️ סביבת העבודה ACE-Step V1.5 Playground💡",
|
||||
"subtitle": "פורצים את גבולות יצירת המוזיקה בקוד פתוח"
|
||||
},
|
||||
"dataset": {
|
||||
"title": "📊 סייר מערכי נתונים (Dataset Explorer)",
|
||||
"dataset_label": "מערך נתונים",
|
||||
"dataset_info": "בחר מערך נתונים לחקירה",
|
||||
"import_btn": "📥 ייבוא מערך נתונים",
|
||||
"search_type_label": "סוג חיפוש",
|
||||
"search_type_info": "כיצד למצוא פריטים",
|
||||
"search_value_label": "ערך חיפוש",
|
||||
"search_value_placeholder": "הזן מפתחות או אינדקס (השאר ריק לבחירה אקראית)",
|
||||
"search_value_info": "מפתחות: התאמה מדויקת, אינדקס: 0 עד גודל המערך פחות 1",
|
||||
"instruction_label": "📝 הנחיה (Instruction)",
|
||||
"instruction_placeholder": "אין הנחיה זמינה",
|
||||
"metadata_title": "📋 מטא-דאטה של הפריט (JSON)",
|
||||
"metadata_label": "מידע מלא על הפריט",
|
||||
"source_audio": "אודיו מקור",
|
||||
"target_audio": "אודיו יעד",
|
||||
"reference_audio": "אודיו ייחוס",
|
||||
"get_item_btn": "🔍 קבל פריט",
|
||||
"use_src_checkbox": "השתמש באודיו מקור ממערך הנתונים",
|
||||
"use_src_info": "סמן כדי להשתמש באודיו המקור מתוך מערך הנתונים",
|
||||
"data_status_label": "📊 מצב נתונים",
|
||||
"data_status_default": "❌ לא יובא מערך נתונים",
|
||||
"autofill_btn": "📋 מילוי אוטומטי של טופס היצירה"
|
||||
},
|
||||
"service": {
|
||||
"title": "🔧 הגדרות שירות",
|
||||
"checkpoint_label": "קובץ נקודת ביקורת (Checkpoint)",
|
||||
"checkpoint_info": "בחר קובץ נקודת ביקורת של מודל מאומן (נתיב מלא או שם קובץ)",
|
||||
"refresh_btn": "🔄 רענון",
|
||||
"model_path_label": "נתיב מודל ראשי",
|
||||
"model_path_info": "בחר את ספריית הגדרות המודל (נסרק אוטומטית מנקודות הביקורת)",
|
||||
"device_label": "מכשיר (Device)",
|
||||
"device_info": "מכשיר עיבוד (מומלץ זיהוי אוטומטי)",
|
||||
"lm_model_path_label": "נתיב מודל 5Hz LM",
|
||||
"lm_model_path_info": "בחר את קובץ נקודת הביקורת של מודל ה-5Hz LM",
|
||||
"backend_label": "מנוע (Backend) 5Hz LM",
|
||||
"backend_info": "בחר מנוע עבור 5Hz LM: vllm (מהיר יותר) או pt (PyTorch, תואם יותר)",
|
||||
"init_llm_label": "אתחול 5Hz LM",
|
||||
"init_llm_info": "סמן כדי לאתחל את ה-5Hz LM במהלך אתחול השירות",
|
||||
"flash_attention_label": "השתמש ב-Flash Attention",
|
||||
"flash_attention_info_enabled": "הפעל Flash Attention להסקה מהירה יותר (דורש חבילת flash_attn)",
|
||||
"flash_attention_info_disabled": "Flash Attention אינו זמין (חבילת flash_attn לא מותקנת)",
|
||||
"offload_cpu_label": "העברה ל-CPU (Offload)",
|
||||
"offload_cpu_info": "העבר מודלים ל-CPU כשאינם בשימוש כדי לחסוך בזיכרון גרפי (VRAM)",
|
||||
"offload_dit_cpu_label": "העברת DiT ל-CPU",
|
||||
"offload_dit_cpu_info": "העבר DiT ל-CPU (דורש 'העברה ל-CPU')",
|
||||
"compile_model_label": "הידור מודל (Compile)",
|
||||
"compile_model_info": "השתמש ב-torch.compile לאופטימיזציה של המודל (נדרש עבור קוונטיזציה)",
|
||||
"quantization_label": "קוונטיזציה INT8",
|
||||
"quantization_info": "הפעל קוונטיזציה של משקולות בלבד (INT8) להפחתת שימוש ב-VRAM (דורש הידור מודל)",
|
||||
"init_btn": "אתחול שירות",
|
||||
"status_label": "מצב",
|
||||
"language_label": "שפת ממשק",
|
||||
"language_info": "בחר את שפת הממשק"
|
||||
},
|
||||
"generation": {
|
||||
"required_inputs": "📝 קלטים נדרשים",
|
||||
"task_type_label": "סוג משימה",
|
||||
"task_type_info": "בחר את סוג המשימה ליצירה",
|
||||
"instruction_label": "הנחיה",
|
||||
"instruction_info": "ההנחיה נוצרת אוטומטית בהתאם לסוג המשימה",
|
||||
"load_btn": "טעינה",
|
||||
"track_name_label": "שם רצועה",
|
||||
"track_name_info": "בחר שם רצועה עבור משימות lego/extract",
|
||||
"track_classes_label": "שמות רצועות",
|
||||
"track_classes_info": "בחר מספר מחלקות רצועה עבור משימה מלאה",
|
||||
"audio_uploads": "🎵 העלאות אודיו",
|
||||
"reference_audio": "אודיו ייחוס (אופציונלי)",
|
||||
"source_audio": "אודיו מקור (אופציונלי)",
|
||||
"convert_codes_btn": "המר לקודים",
|
||||
"lm_codes_hints": "🎼 רמזי קודי LM",
|
||||
"lm_codes_label": "רמזי קודי LM",
|
||||
"lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
|
||||
"lm_codes_info": "הדבק רמזי קודי LM עבור יצירת טקסט למוזיקה (text2music)",
|
||||
"lm_codes_sample": "רמזי קודי LM (דגימה {n})",
|
||||
"lm_codes_sample_info": "קודים עבור דגימה {n}",
|
||||
"transcribe_btn": "תמלול",
|
||||
"repainting_controls": "🎨 בקרת צביעה מחדש (בשניות)",
|
||||
"repainting_start": "תחילת צביעה מחדש",
|
||||
"repainting_end": "סיום צביעה מחדש",
|
||||
"mode_label": "מצב יצירה",
|
||||
"mode_info": "פשוט: תאר מוזיקה בשפה טבעית. מותאם אישית: שליטה מלאה בתיאור ומילים.",
|
||||
"mode_simple": "פשוט",
|
||||
"mode_custom": "מותאם אישית",
|
||||
"simple_query_label": "תיאור השיר",
|
||||
"simple_query_placeholder": "תאר את המוזיקה שברצונך ליצור, למשל: 'שיר אהבה אקוסטי שקט לערב רגוע'. השאר ריק לדגימה אקראית.",
|
||||
"simple_query_info": "הזן תיאור בשפה טבעית של המוזיקה שברצונך ליצור",
|
||||
"simple_vocal_language_label": "שפת שירה (אופציונלי)",
|
||||
"simple_vocal_language_info": "בחר שפות מועדפות למילים. השתמש ב-'unknown' לכל שפה.",
|
||||
"create_sample_btn": "צור דגימה",
|
||||
"caption_title": "📝 תיאור מוזיקלי (Caption)",
|
||||
"caption_label": "תיאור מוזיקלי (אופציונלי)",
|
||||
"caption_placeholder": "מנגינת גיטרה אקוסטית שלווה עם שירה רכה...",
|
||||
"caption_info": "תאר את הסגנון, הז'אנר, הכלים והאווירה",
|
||||
"lyrics_title": "📝 מילים",
|
||||
"lyrics_label": "מילים (אופציונלי)",
|
||||
"lyrics_placeholder": "[בית 1]\\nתחת שמי הלילה...\\nאני מרגיש חי...",
|
||||
"lyrics_info": "מילות השיר עם מבנה",
|
||||
"instrumental_label": "אינסטרומנטלי (ללא שירה)",
|
||||
"format_btn": "פרמוט",
|
||||
"optional_params": "⚙️ פרמטרים אופציונליים",
|
||||
"vocal_language_label": "שפת שירה (אופציונלי)",
|
||||
"vocal_language_info": "השתמש ב-`unknown` לקטעים כליים",
|
||||
"bpm_label": "קצב (BPM) (אופציונלי)",
|
||||
"bpm_info": "השאר ריק אם לא ידוע",
|
||||
"keyscale_label": "סולם (KeyScale) (אופציונלי)",
|
||||
"keyscale_placeholder": "השאר ריק אם לא ידוע",
|
||||
"keyscale_info": "A-G, #/♭, מז'ור/מינור",
|
||||
"timesig_label": "משקל מוזיקלי (אופציונלי)",
|
||||
"timesig_info": "2/4, 3/4, 4/4...",
|
||||
"duration_label": "אורך אודיו (שניות)",
|
||||
"duration_info": "השתמש ב-1- לאקראי",
|
||||
"batch_size_label": "גודל מנה (Batch Size)",
|
||||
"batch_size_info": "מספר קטעי אודיו ליצירה (מקסימום 8)",
|
||||
"advanced_settings": "🔧 הגדרות מתקדמות",
|
||||
"inference_steps_label": "צעדי הסקה של DiT",
|
||||
"inference_steps_info": "Turbo: מקסימום 8, Base: מקסימום 200",
|
||||
"guidance_scale_label": "קנה מידה להנחיה (רק למודל base)",
|
||||
"guidance_scale_info": "ערכים גבוהים יותר נצמדים יותר לטקסט",
|
||||
"seed_label": "גרעין (Seed)",
|
||||
"seed_info": "השתמש בערכים מופרדים בפסיקים עבור מנות",
|
||||
"random_seed_label": "גרעין אקראי",
|
||||
"random_seed_info": "אפשר ליצירה אוטומטית של גרעינים",
|
||||
"audio_format_label": "פורמט אודיו",
|
||||
"audio_format_info": "פורמט האודיו עבור הקבצים שיישמרו",
|
||||
"use_adg_label": "השתמש ב-ADG",
|
||||
"use_adg_info": "הפעל Angle Domain Guidance",
|
||||
"shift_label": "Shift",
|
||||
"shift_info": "פקטור הסטת צעדי זמן למודלי base (טווח 1.0~5.0, ברירת מחדל 3.0). לא משפיע על מודלי turbo.",
|
||||
"infer_method_label": "שיטת הסקה",
|
||||
"infer_method_info": "שיטת הסקת הדיפוזיה. ODE (Euler) מהירה יותר, SDE (stochastic) עשויה להפיק תוצאות שונות.",
|
||||
"custom_timesteps_label": "צעדי זמן מותאמים אישית",
|
||||
"custom_timesteps_info": "אופציונלי: ערכים מופרדים בפסיקים מ-1.0 עד 0.0. דורס את צעדי ההסקה וה-shift.",
|
||||
"cfg_interval_start": "תחילת מרווח CFG",
|
||||
"cfg_interval_end": "סיום מרווח CFG",
|
||||
"lm_params_title": "🤖 פרמטרי יצירת LM",
|
||||
"lm_temperature_label": "טמפרטורת LM",
|
||||
"lm_temperature_info": "טמפרטורת 5Hz LM (גבוה יותר = אקראי יותר)",
|
||||
"lm_cfg_scale_label": "קנה מידה LM CFG",
|
||||
"lm_cfg_scale_info": "5Hz LM CFG (1.0 = ללא CFG)",
|
||||
"lm_top_k_label": "LM Top-K",
|
||||
"lm_top_k_info": "Top-K (0 = מושבת)",
|
||||
"lm_top_p_label": "LM Top-P",
|
||||
"lm_top_p_info": "Top-P (1.0 = מושבת)",
|
||||
"lm_negative_prompt_label": "הנחיה שלילית ל-LM",
|
||||
"lm_negative_prompt_placeholder": "הזן הנחיה שלילית עבור CFG",
|
||||
"lm_negative_prompt_info": "הנחיה שלילית (בשימוש כאשר LM CFG Scale > 1.0)",
|
||||
"cot_metas_label": "CoT Metas",
|
||||
"cot_metas_info": "השתמש ב-LM ליצירת מטא-דאטה CoT (בטל סימון כדי לדלג)",
|
||||
"cot_language_label": "שפת CoT",
|
||||
"cot_language_info": "יצירת שפה ב-CoT (שרשרת מחשבה)",
|
||||
"constrained_debug_label": "ניקוי באגים של פענוח מוגבל",
|
||||
"constrained_debug_info": "הפעל לוגים של ניקוי באגים עבור פענוח מוגבל",
|
||||
"auto_score_label": "דירוג אוטומטי",
|
||||
"auto_score_info": "חשב אוטומטית ציוני איכות לכל קטעי האודיו שנוצרו",
|
||||
"auto_lrc_label": "LRC אוטומטי",
|
||||
"auto_lrc_info": "צור אוטומטית חותמות זמן למילים (LRC) לכל קטעי האודיו",
|
||||
"lm_batch_chunk_label": "גודל מקטע מנת LM",
|
||||
"lm_batch_chunk_info": "מקסימום פריטים למקטע מנת LM (ברירת מחדל: 8, מוגבל ע\"י זיכרון ה-GPU)",
|
||||
"codes_strength_label": "חוזק קודי LM",
|
||||
"codes_strength_info": "שליטה בכמות צעדי הניקוי מרעש המשתמשים בקודים שנוצרו ע\"י ה-LM",
|
||||
"cover_strength_label": "חוזק כיסוי אודיו (Audio Cover)",
|
||||
"cover_strength_info": "שליטה בכמות צעדי הניקוי מרעש המשתמשים במצב כיסוי",
|
||||
"score_sensitivity_label": "רגישות ציון איכות",
|
||||
"score_sensitivity_info": "נמוך יותר = רגיש יותר (ברירת מחדל: 1.0)",
|
||||
"think_label": "חשיבה (Think)",
|
||||
"parallel_thinking_label": "חשיבה מקבילית",
|
||||
"generate_btn": "🎵 צור מוזיקה",
|
||||
"autogen_label": "יצירה אוטומטית",
|
||||
"caption_rewrite_label": "שכתוב תיאור"
|
||||
},
|
||||
"results": {
|
||||
"title": "🎵 תוצאות",
|
||||
"generated_music": "🎵 מוזיקה שנוצרה (דגימה {n})",
|
||||
"send_to_src_btn": "🔗 שלח לאודיו מקור",
|
||||
"save_btn": "💾 שמירה",
|
||||
"score_btn": "📊 דירוג",
|
||||
"lrc_btn": "🎵 LRC",
|
||||
"quality_score_label": "ציון איכות (דגימה {n})",
|
||||
"quality_score_placeholder": "לחץ על 'דירוג' לחישוב ציון איכות מבוסס מורכבות (Perplexity)",
|
||||
"codes_label": "קודי LM (דגימה {n})",
|
||||
"lrc_label": "חותמות זמן למילים (דגימה {n})",
|
||||
"lrc_placeholder": "לחץ על 'LRC' ליצירת חותמות זמן",
|
||||
"details_accordion": "📊 דירוג, LRC וקודי LM",
|
||||
"generation_status": "מצב יצירה",
|
||||
"current_batch": "מנה נוכחית",
|
||||
"batch_indicator": "מנה {current} / {total}",
|
||||
"next_batch_status": "מצב המנה הבאה",
|
||||
"prev_btn": "◀ הקודם",
|
||||
"next_btn": "הבא ▶",
|
||||
"restore_params_btn": "↙️ החל הגדרות אלו על הממשק (שחזור פרמטרי מנה)",
|
||||
"batch_results_title": "📁 תוצאות המנה ופרטי יצירה",
|
||||
"all_files_label": "📁 כל הקבצים שנוצרו (הורדה)",
|
||||
"generation_details": "פרטי יצירה"
|
||||
},
|
||||
"messages": {
|
||||
"no_audio_to_save": "❌ אין אודיו לשמירה",
|
||||
"save_success": "✅ האודיו והמטא-דאטה נשמרו ב-{filename}",
|
||||
"save_failed": "❌ השמירה נכשלה: {error}",
|
||||
"no_file_selected": "⚠️ לא נבחר קובץ",
|
||||
"params_loaded": "✅ הפרמטרים נטענו מ-{filename}",
|
||||
"invalid_json": "❌ קובץ JSON לא תקין: {error}",
|
||||
"load_error": "❌ שגיאה בטעינת הקובץ: {error}",
|
||||
"example_loaded": "📁 נטען דגם מ-{filename}",
|
||||
"example_failed": "נכשל ניתוח קובץ ה-JSON ב-{filename}: {error}",
|
||||
"example_error": "שגיאה בטעינת הדגם: {error}",
|
||||
"lm_generated": "🤖 נוצר דגם באמצעות ה-LM",
|
||||
"lm_fallback": "יצירת דגם באמצעות ה-LM נכשלה, חוזר לשימוש בספריית הדגמים",
|
||||
"lm_not_initialized": "❌ 5Hz LM לא מאותחל. נא לאתחל אותו תחילה.",
|
||||
"autogen_enabled": "🔄 יצירה אוטומטית הופעלה - המנה הבאה תיווצר לאחר מכן",
|
||||
"batch_ready": "✅ מנה {n} מוכנה! לחץ על 'הבא' לצפייה.",
|
||||
"batch_generating": "🔄 מתחיל יצירת רקע עבור מנה {n}...",
|
||||
"batch_failed": "❌ יצירת הרקע נכשלה: {error}",
|
||||
"viewing_batch": "✅ צופה במנה {n}",
|
||||
"at_first_batch": "נמצא כבר במנה הראשונה",
|
||||
"at_last_batch": "אין מנה באה זמינה",
|
||||
"batch_not_found": "מנה {n} לא נמצאה בתור",
|
||||
"no_batch_data": "לא נמצאו נתוני מנה לשחזור.",
|
||||
"params_restored": "✅ פרמטרי הממשק שוחזרו ממנה {n}",
|
||||
"scoring_failed": "❌ שגיאה: נתוני המנה לא נמצאו",
|
||||
"no_codes": "❌ אין קודי אודיו זמינים. נא ליצור מוזיקה תחילה.",
|
||||
"score_failed": "❌ הדירוג נכשל: {error}",
|
||||
"score_error": "❌ שגיאה בחישוב הציון: {error}",
|
||||
"lrc_no_batch_data": "❌ לא נמצאו נתוני מנה. נא ליצור מוזיקה תחילה.",
|
||||
"lrc_no_extra_outputs": "❌ לא נמצאו פלטים נוספים. טנזורי התניה אינם זמינים.",
|
||||
"lrc_missing_tensors": "❌ חסרים טנזורים נדרשים ליצירת LRC.",
|
||||
"lrc_sample_not_exist": "❌ הדגימה אינה קיימת במנה הנוכחית.",
|
||||
"lrc_empty_result": "⚠️ יצירת ה-LRC הפיקה תוצאה ריקה.",
|
||||
"empty_query": "⚠️ נא להזין תיאור מוזיקלי.",
|
||||
"sample_creation_failed": "❌ יצירת הדגימה נכשלה. נא לנסות שוב.",
|
||||
"sample_created": "✅ הדגימה נוצרה! בדוק את התיאור והמילים, ולאחר מכן לחץ על 'צור מוזיקה'.",
|
||||
"simple_examples_not_found": "⚠️ ספריית הדגמים של המצב הפשוט לא נמצאה.",
|
||||
"simple_examples_empty": "⚠️ לא נמצאו קבצי דוגמה במצב פשוט.",
|
||||
"simple_example_loaded": "🎲 נטענה דוגמה אקראית מ-{filename}",
|
||||
"format_success": "✅ התיאור והמילים פורמטו בהצלחה",
|
||||
"format_failed": "❌ הפירמוט נכשל: {error}",
|
||||
"skipping_metas_cot": "⚡ מדלג על שלב 1 של מטא-דאטה COT (הדגימה כבר מפורמטת)",
|
||||
"invalid_timesteps_format": "⚠️ פורמט צעדי זמן לא תקין. משתמש בלוח זמנים כברירת מחדל.",
|
||||
"timesteps_out_of_range": "⚠️ צעדי הזמן חייבים להיות בטווח [0, 1]. משתמש בלוח זמנים כברירת מחדל.",
|
||||
"timesteps_count_mismatch": "⚠️ מספר צעדי הזמן ({actual}) שונה מצעדי ההסקה ({expected}). משתמש במספר צעדי הזמן."
|
||||
},
|
||||
"training": {
|
||||
"tab_title": "🎓 אימון LoRA",
|
||||
"tab_dataset_builder": "📁 בונה מערך נתונים",
|
||||
"tab_train_lora": "🚀 אימון LoRA",
|
||||
"quick_start_title": "🚀 התחלה מהירה",
|
||||
"load_dataset_label": "נתיב קובץ JSON של מערך הנתונים",
|
||||
"load_dataset_info": "טעינת מערך נתונים שנשמר בעבר",
|
||||
"load_btn": "📂 טעינה",
|
||||
"load_status": "מצב טעינה",
|
||||
"scan_label": "נתיב ספריית אודיו",
|
||||
"scan_info": "סריקה אחר קבצי אודיו (wav, mp3, flac, ogg, opus)",
|
||||
"scan_btn": "🔍 סריקה",
|
||||
"scan_status": "מצב סריקה",
|
||||
"found_audio_files": "קבצי אודיו שנמצאו",
|
||||
"dataset_name": "שם מערך הנתונים",
|
||||
"dataset_name_placeholder": "הזן שם למערך הנתונים",
|
||||
"dataset_settings_header": "הגדרות מערך נתונים",
|
||||
"tag_prepend": "הוספה בהתחלה (תגית, תיאור)",
|
||||
"tag_append": "הוספה בסוף (תיאור, תגית)",
|
||||
"tag_replace": "החלפת התיאור",
|
||||
"step2_title": "שלב 2: תיוג אוטומטי באמצעות AI",
|
||||
"step2_instruction": "לחץ על הכפתור למטה כדי ליצור אוטומטית מטא-נתונים עבור כל קבצי האודיו באמצעות AI:\n• **תיאור**: סגנון מוזיקה, ז'אנר, תיאור מצב רוח\n• **BPM**: פעימות לדקה\n• **מפתח**: מפתח מוזיקלי (לדוגמה, C Major, Am)\n• **חתימת זמן**: 4/4, 3/4, וכו'",
|
||||
"step3_title": "שלב 3: תצוגה מקדימה ועריכה",
|
||||
"step4_title": "שלב 4: שמירת מערך הנתונים",
|
||||
"step5_title": "שלב 5: עיבוד מקדים לטנזורים (Tensors)",
|
||||
"step5_intro": "**העיבוד המקדים ממיר את מערך הנתונים שלך לטנזורים מחושבים מראש לאימון מהיר.**\n\nאתה יכול:\n• להשתמש במערך הנתונים משלבים 1-4 למעלה, **או**\n• לטעון קובץ JSON של מערך נתונים קיים (אם כבר שמרת אחד)",
|
||||
"step5_details": "שלב זה:\n• מקודד אודיו ל-latents של VAE\n• מקודד תיאורים ומילים ל-embeddings טקסט\n• מפעיל את מקודד התנאי\n• שומר את כל הטנזורים לקבצי `.pt`\n\n⚠️ **זה דורש טעינת המודל ועשוי לקחת מספר דקות.**",
|
||||
"train_tensor_selection_desc": "בחר את הספרייה המכילה קבצי טנזור מעובדים מראש (קבצי `.pt`).\nאלה נוצרים בכרטיסייה \"בונה מערך נתונים\" באמצעות כפתור \"עיבוד מקדים\".",
|
||||
"all_instrumental": "הכל אינסטרומנטלי",
|
||||
"all_instrumental_info": "סמן אם כל הרצועות הן כליות (ללא שירה)",
|
||||
"custom_tag": "תגית הפעלה מותאמת אישית",
|
||||
"custom_tag_info": "תגית ייחודית להפעלת הסגנון של LoRA זו",
|
||||
"tag_position": "מיקום התגית",
|
||||
"tag_position_info": "היכן למקם את התגית המותאמת אישית בתוך התיאור",
|
||||
"genre_ratio": "יחס ז'אנר (%)",
|
||||
"genre_ratio_info": "0% = הכל תיאור, 100% = הכל ז'אנר. הגדרה פר-דגימה קודמת להגדרת הכלל.",
|
||||
"skip_metas": "דלג על BPM/סולם/משקל",
|
||||
"skip_metas_info": "דלג על יצירת BPM/סולם/משקל. התיאור והז'אנר עדיין ייווצרו על ידי ה-LLM.",
|
||||
"only_unlabeled": "רק כאלו ללא תיוג",
|
||||
"only_unlabeled_info": "תייג רק דגימות ללא תיאור (שימושי להמשך תיוג שנכשל)",
|
||||
"auto_label_btn": "🏷️ תיוג אוטומטי של הכל",
|
||||
"label_progress": "התקדמות התיוג",
|
||||
"select_sample": "בחר דגימה #",
|
||||
"select_sample_info": "בחר דגימה לצפייה ועריכה",
|
||||
"audio_preview": "תצוגה מקדימה של אודיו",
|
||||
"filename": "שם קובץ",
|
||||
"caption": "תיאור",
|
||||
"genre": "ז'אנר",
|
||||
"prompt_override_label": "דריסת פרומפט (לדגימה זו)",
|
||||
"prompt_override_info": "דריסת היחס הכללי עבור דגימה זו",
|
||||
"lyrics_editable_label": "מילים (ניתן לעריכה, משמש לאימון)",
|
||||
"raw_lyrics_label": "מילים גולמיות (מתוך קובץ .txt)",
|
||||
"no_lyrics_placeholder": "(אין קובץ מילים .txt)",
|
||||
"bpm": "BPM",
|
||||
"key_label": "סולם (Key)",
|
||||
"key_placeholder": "C Major",
|
||||
"time_sig": "משקל מוזיקלי",
|
||||
"duration_s": "משך (שניות)",
|
||||
"language": "שפה",
|
||||
"instrumental": "אינסטרומנטלי",
|
||||
"save_changes_btn": "💾 שמירת שינויים",
|
||||
"edit_status": "מצב עריכה",
|
||||
"save_path": "נתיב שמירה",
|
||||
"save_path_info": "הנתיב שבו יישמר קובץ ה-JSON של מערך הנתונים",
|
||||
"save_dataset_btn": "💾 שמירת מערך נתונים",
|
||||
"save_status": "מצב שמירה",
|
||||
"load_existing_label": "טעינת מערך נתונים קיים (אופציונלי)",
|
||||
"load_existing_info": "נתיב לקובץ JSON של מערך נתונים שנשמר בעבר",
|
||||
"load_dataset_btn": "📂 טעינת מערך נתונים",
|
||||
"tensor_output_dir": "ספריית פלט של טנזורים",
|
||||
"tensor_output_info": "הספרייה לשמירת קבצי טנזור שעברו עיבוד מקדים",
|
||||
"preprocess_btn": "⚡ עיבוד מקדים",
|
||||
"preprocess_progress": "התקדמות עיבוד מקדים",
|
||||
"preprocessed_tensors_dir": "ספריית טנזורים מעובדים",
|
||||
"preprocessed_tensors_info": "ספרייה המכילה קבצי .pt של טנזורים מעובדים",
|
||||
"train_section_tensors": "בחירת מערך נתונים מעובד",
|
||||
"train_section_lora": "הגדרות LoRA",
|
||||
"train_section_params": "פרמטרי אימון",
|
||||
"dataset_info": "מידע על מערך הנתונים",
|
||||
"lora_rank": "דרגת LoRA (Rank)",
|
||||
"lora_rank_info": "גבוה יותר = יותר קיבולת, יותר זיכרון",
|
||||
"lora_alpha": "LoRA Alpha",
|
||||
"lora_alpha_info": "פקטור קנה מידה (בדרך כלל פי 2 מה-Rank)",
|
||||
"lora_dropout": "LoRA Dropout",
|
||||
"learning_rate": "קצב למידה (Learning Rate)",
|
||||
"learning_rate_info": "התחל עם 3e-4, שנה במידת הצורך",
|
||||
"max_epochs": "מקסימום תקופות (Epochs)",
|
||||
"batch_size": "גודל מנה (Batch Size)",
|
||||
"batch_size_info": "הגדל אם יש לך מספיק זיכרון גרפי (VRAM)",
|
||||
"gradient_accumulation": "צבירת גרדיאנטים (Accumulation)",
|
||||
"gradient_accumulation_info": "גודל מנה אפקטיבי = גודל מנה × צבירה",
|
||||
"save_every_n_epochs": "שמור כל N תקופות (Epochs)",
|
||||
"shift": "Shift (הסטה)",
|
||||
"shift_info": "הסטת צעדי זמן עבור מודל turbo",
|
||||
"seed": "גרעין (Seed)",
|
||||
"output_dir": "ספריית פלט",
|
||||
"output_dir_info": "ספרייה לשמירת משקולות ה-LoRA המאומנות",
|
||||
"start_training_btn": "🚀 התחלת אימון",
|
||||
"stop_training_btn": "⏹️ עצירת אימון",
|
||||
"training_progress": "התקדמות האימון",
|
||||
"training_log": "יומן אימון",
|
||||
"training_loss_title": "הפסד אימון (Training Loss)",
|
||||
"step": "צעד",
|
||||
"loss": "הפסד (Loss)",
|
||||
"export_header": "ייצוא LoRA",
|
||||
"export_path": "נתיב ייצוא",
|
||||
"export_lora_btn": "📦 ייצוא LoRA",
|
||||
"export_status": "מצב ייצוא"
|
||||
}
|
||||
}
|
||||
{
|
||||
"app": {
|
||||
"title": "🎛️ סביבת העבודה ACE-Step V1.5 Playground💡",
|
||||
"subtitle": "פורצים את גבולות יצירת המוזיקה בקוד פתוח"
|
||||
},
|
||||
"dataset": {
|
||||
"title": "📊 סייר מערכי נתונים (Dataset Explorer)",
|
||||
"dataset_label": "מערך נתונים",
|
||||
"dataset_info": "בחר מערך נתונים לחקירה",
|
||||
"import_btn": "📥 ייבוא מערך נתונים",
|
||||
"search_type_label": "סוג חיפוש",
|
||||
"search_type_info": "כיצד למצוא פריטים",
|
||||
"search_value_label": "ערך חיפוש",
|
||||
"search_value_placeholder": "הזן מפתחות או אינדקס (השאר ריק לבחירה אקראית)",
|
||||
"search_value_info": "מפתחות: התאמה מדויקת, אינדקס: 0 עד גודל המערך פחות 1",
|
||||
"instruction_label": "📝 הנחיה (Instruction)",
|
||||
"instruction_placeholder": "אין הנחיה זמינה",
|
||||
"metadata_title": "📋 מטא-דאטה של הפריט (JSON)",
|
||||
"metadata_label": "מידע מלא על הפריט",
|
||||
"source_audio": "אודיו מקור",
|
||||
"target_audio": "אודיו יעד",
|
||||
"reference_audio": "אודיו ייחוס",
|
||||
"get_item_btn": "🔍 קבל פריט",
|
||||
"use_src_checkbox": "השתמש באודיו מקור ממערך הנתונים",
|
||||
"use_src_info": "סמן כדי להשתמש באודיו המקור מתוך מערך הנתונים",
|
||||
"data_status_label": "📊 מצב נתונים",
|
||||
"data_status_default": "❌ לא יובא מערך נתונים",
|
||||
"autofill_btn": "📋 מילוי אוטומטי של טופס היצירה"
|
||||
},
|
||||
"service": {
|
||||
"title": "🔧 הגדרות שירות",
|
||||
"checkpoint_label": "קובץ נקודת ביקורת (Checkpoint)",
|
||||
"checkpoint_info": "בחר קובץ נקודת ביקורת של מודל מאומן (נתיב מלא או שם קובץ)",
|
||||
"refresh_btn": "🔄 רענון",
|
||||
"model_path_label": "נתיב מודל ראשי",
|
||||
"model_path_info": "בחר את ספריית הגדרות המודל (נסרק אוטומטית מנקודות הביקורת)",
|
||||
"device_label": "מכשיר (Device)",
|
||||
"device_info": "מכשיר עיבוד (מומלץ זיהוי אוטומטי)",
|
||||
"lm_model_path_label": "נתיב מודל 5Hz LM",
|
||||
"lm_model_path_info": "בחר את קובץ נקודת הביקורת של מודל ה-5Hz LM",
|
||||
"backend_label": "מנוע (Backend) 5Hz LM",
|
||||
"backend_info": "בחר מנוע עבור 5Hz LM: vllm (מהיר יותר) או pt (PyTorch, תואם יותר)",
|
||||
"init_llm_label": "אתחול 5Hz LM",
|
||||
"init_llm_info": "סמן כדי לאתחל את ה-5Hz LM במהלך אתחול השירות",
|
||||
"flash_attention_label": "השתמש ב-Flash Attention",
|
||||
"flash_attention_info_enabled": "הפעל Flash Attention להסקה מהירה יותר (דורש חבילת flash_attn)",
|
||||
"flash_attention_info_disabled": "Flash Attention אינו זמין (חבילת flash_attn לא מותקנת)",
|
||||
"offload_cpu_label": "העברה ל-CPU (Offload)",
|
||||
"offload_cpu_info": "העבר מודלים ל-CPU כשאינם בשימוש כדי לחסוך בזיכרון גרפי (VRAM)",
|
||||
"offload_dit_cpu_label": "העברת DiT ל-CPU",
|
||||
"offload_dit_cpu_info": "העבר DiT ל-CPU (דורש 'העברה ל-CPU')",
|
||||
"compile_model_label": "הידור מודל (Compile)",
|
||||
"compile_model_info": "השתמש ב-torch.compile לאופטימיזציה של המודל (נדרש עבור קוונטיזציה)",
|
||||
"quantization_label": "קוונטיזציה INT8",
|
||||
"quantization_info": "הפעל קוונטיזציה של משקולות בלבד (INT8) להפחתת שימוש ב-VRAM (דורש הידור מודל)",
|
||||
"mlx_dit_label": "MLX DiT (Apple Silicon)",
|
||||
"mlx_dit_info_enabled": "השתמש ב-MLX מקורי להפצת DiT על Apple Silicon (מהיר יותר מ-MPS)",
|
||||
"mlx_dit_info_disabled": "MLX לא זמין (דורש macOS + Apple Silicon + חבילת mlx)",
|
||||
"init_btn": "אתחול שירות",
|
||||
"status_label": "מצב",
|
||||
"language_label": "שפת ממשק",
|
||||
"language_info": "בחר את שפת הממשק",
|
||||
"gpu_auto_tier": "שכבה שזוהתה אוטומטית",
|
||||
"tier_label": "דריסת שכבת GPU",
|
||||
"tier_info": "בחר שכבת GPU באופן ידני כדי להתאים ברירות מחדל של אופטימיזציה (העברה, קוונטיזציה, מנוע וכו')"
|
||||
},
|
||||
"generation": {
|
||||
"tab_title": "🎵 יצירה",
|
||||
"required_inputs": "📝 קלטים נדרשים",
|
||||
"task_type_label": "סוג משימה",
|
||||
"task_type_info": "בחר את סוג המשימה ליצירה",
|
||||
"instruction_label": "הנחיה",
|
||||
"instruction_info": "ההנחיה נוצרת אוטומטית בהתאם לסוג המשימה",
|
||||
"load_btn": "טעינה",
|
||||
"track_name_label": "שם רצועה",
|
||||
"track_name_info": "בחר שם רצועה עבור משימות lego/extract",
|
||||
"track_classes_label": "שמות רצועות",
|
||||
"track_classes_info": "בחר מספר מחלקות רצועה עבור משימה מלאה",
|
||||
"audio_uploads": "🎵 העלאות אודיו",
|
||||
"reference_audio": "אודיו ייחוס (אופציונלי)",
|
||||
"source_audio": "אודיו מקור",
|
||||
"convert_codes_btn": "המר לקודים",
|
||||
"analyze_btn": "🔍 ניתוח",
|
||||
"sample_btn": "🎲 לחץ כאן",
|
||||
"load_btn": "📂 טעינה",
|
||||
"lm_codes_hints": "🎼 רמזי קודי LM",
|
||||
"lm_codes_label": "רמזי קודי LM",
|
||||
"lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
|
||||
"lm_codes_info": "הדבק רמזי קודי LM עבור יצירת טקסט למוזיקה (text2music)",
|
||||
"lm_codes_sample": "רמזי קודי LM (דגימה {n})",
|
||||
"lm_codes_sample_info": "קודים עבור דגימה {n}",
|
||||
"transcribe_btn": "תמלול",
|
||||
"repainting_controls": "🎨 בקרת צביעה מחדש (בשניות)",
|
||||
"repainting_start": "תחילת צביעה מחדש",
|
||||
"repainting_end": "סיום צביעה מחדש",
|
||||
"mode_label": "מצב יצירה",
|
||||
"mode_info": "בחר מצב יצירה כדי להתחיל.",
|
||||
"mode_info_simple": "תאר את המוזיקה שלך בשפה טבעית. הבינה המלאכותית תיצור תיאור, מילים ומטא-נתונים עבורך.",
|
||||
"mode_info_custom": "שליטה מלאה בתיאור, מילים וכל הפרמטרים.",
|
||||
"mode_info_remix": "העלה שמע מקור כדי ליצור גרסת רמיקס עם התיאור והמילים שלך.",
|
||||
"mode_info_repaint": "העלה שמע מקור וצבע מחדש טווח זמן מסוים.",
|
||||
"mode_info_extract": "חלץ רצועה מסוימת (שירה, תופים וכו') משמע המקור.",
|
||||
"mode_info_lego": "הרכב מחדש רצועות: החלף רצועה מסוימת בשמע המקור.",
|
||||
"mode_info_complete": "השלם רצועות חסרות בשמע המקור.",
|
||||
"mode_simple": "פשוט",
|
||||
"mode_custom": "מותאם אישית",
|
||||
"simple_query_label": "תיאור השיר",
|
||||
"simple_query_placeholder": "תאר את המוזיקה שברצונך ליצור, למשל: 'שיר אהבה אקוסטי שקט לערב רגוע'. השאר ריק לדגימה אקראית.",
|
||||
"simple_query_info": "הזן תיאור בשפה טבעית של המוזיקה שברצונך ליצור",
|
||||
"simple_vocal_language_label": "שפת שירה",
|
||||
"simple_vocal_language_info": "בחר שפות מועדפות למילים. השתמש ב-'unknown' לכל שפה.",
|
||||
"create_sample_btn": "צור דגימה",
|
||||
"caption_title": "📝 תיאור מוזיקלי (Caption)",
|
||||
"caption_label": "תיאור מוזיקלי (אופציונלי)",
|
||||
"caption_placeholder": "מנגינת גיטרה אקוסטית שלווה עם שירה רכה...",
|
||||
"caption_info": "תאר את הסגנון, הז'אנר, הכלים והאווירה",
|
||||
"lyrics_title": "📝 מילים",
|
||||
"lyrics_label": "מילים (אופציונלי)",
|
||||
"lyrics_placeholder": "[בית 1]\\nתחת שמי הלילה...\\nאני מרגיש חי...",
|
||||
"lyrics_info": "מילות השיר עם מבנה",
|
||||
"instrumental_label": "אינסטרומנטלי (ללא שירה)",
|
||||
"format_btn": "פרמוט",
|
||||
"format_caption_btn": "שפר תיאור",
|
||||
"format_lyrics_btn": "שפר מילים",
|
||||
"optional_params": "⚙️ פרמטרים אופציונליים",
|
||||
"optional_music_props": "🎵 מאפייני מוזיקה",
|
||||
"optional_gen_settings": "📐 הגדרות יצירה",
|
||||
"advanced_dit_section": "🎛️ פרמטרי DiT",
|
||||
"advanced_lm_section": "🤖 פרמטרי LM",
|
||||
"advanced_output_section": "🔊 פלט שמע ועיבוד",
|
||||
"advanced_automation_section": "⚡ אוטומציה ואצוות",
|
||||
"vocal_language_label": "שפת שירה (אופציונלי)",
|
||||
"vocal_language_info": "'unknown' = כלי/אוטומטי",
|
||||
"bpm_label": "קצב (BPM) (אופציונלי)",
|
||||
"bpm_info": "השאר ריק אם לא ידוע",
|
||||
"keyscale_label": "מפתח (אופציונלי)",
|
||||
"keyscale_placeholder": "השאר ריק אם לא ידוע",
|
||||
"keyscale_info": "A-G, #/♭, מז'ור/מינור",
|
||||
"timesig_label": "משקל מוזיקלי (אופציונלי)",
|
||||
"timesig_info": "2/4, 3/4, 4/4...",
|
||||
"duration_label": "אורך אודיו (שניות)",
|
||||
"duration_info": "השתמש ב-1- לאקראי",
|
||||
"batch_size_label": "גודל מנה (Batch Size)",
|
||||
"batch_size_info": "מספר קטעי אודיו ליצירה (מקסימום 8)",
|
||||
"advanced_settings": "⚙️ הגדרות",
|
||||
"inference_steps_label": "צעדי הסקה של DiT",
|
||||
"inference_steps_info": "Turbo: מקסימום 8, Base: מקסימום 200",
|
||||
"guidance_scale_label": "קנה מידה להנחיה (רק למודל base)",
|
||||
"guidance_scale_info": "ערכים גבוהים יותר נצמדים יותר לטקסט",
|
||||
"seed_label": "גרעין (Seed)",
|
||||
"seed_info": "השתמש בערכים מופרדים בפסיקים עבור מנות",
|
||||
"random_seed_label": "גרעין אקראי",
|
||||
"random_seed_info": "אפשר ליצירה אוטומטית של גרעינים",
|
||||
"audio_format_label": "פורמט אודיו",
|
||||
"audio_format_info": "פורמט האודיו עבור הקבצים שיישמרו",
|
||||
"use_adg_label": "השתמש ב-ADG",
|
||||
"use_adg_info": "הפעל Angle Domain Guidance",
|
||||
"shift_label": "Shift",
|
||||
"shift_info": "פקטור הסטת צעדי זמן למודלי base (טווח 1.0~5.0, ברירת מחדל 3.0). לא משפיע על מודלי turbo.",
|
||||
"infer_method_label": "שיטת הסקה",
|
||||
"infer_method_info": "שיטת הסקת הדיפוזיה. ODE (Euler) מהירה יותר, SDE (stochastic) עשויה להפיק תוצאות שונות.",
|
||||
"custom_timesteps_label": "צעדי זמן מותאמים אישית",
|
||||
"custom_timesteps_info": "אופציונלי: ערכים מופרדים בפסיקים מ-1.0 עד 0.0. דורס את צעדי ההסקה וה-shift.",
|
||||
"cfg_interval_start": "תחילת מרווח CFG",
|
||||
"cfg_interval_end": "סיום מרווח CFG",
|
||||
"lm_params_title": "🤖 פרמטרי יצירת LM",
|
||||
"lm_temperature_label": "טמפרטורת LM",
|
||||
"lm_temperature_info": "טמפרטורת 5Hz LM (גבוה יותר = אקראי יותר)",
|
||||
"lm_cfg_scale_label": "קנה מידה LM CFG",
|
||||
"lm_cfg_scale_info": "5Hz LM CFG (1.0 = ללא CFG)",
|
||||
"lm_top_k_label": "LM Top-K",
|
||||
"lm_top_k_info": "Top-K (0 = מושבת)",
|
||||
"lm_top_p_label": "LM Top-P",
|
||||
"lm_top_p_info": "Top-P (1.0 = מושבת)",
|
||||
"lm_negative_prompt_label": "הנחיה שלילית ל-LM",
|
||||
"lm_negative_prompt_placeholder": "הזן הנחיה שלילית עבור CFG",
|
||||
"lm_negative_prompt_info": "הנחיה שלילית (בשימוש כאשר LM CFG Scale > 1.0)",
|
||||
"cot_metas_label": "CoT Metas",
|
||||
"cot_metas_info": "השתמש ב-LM ליצירת מטא-דאטה CoT (בטל סימון כדי לדלג)",
|
||||
"cot_language_label": "שפת CoT",
|
||||
"cot_language_info": "יצירת שפה ב-CoT (שרשרת מחשבה)",
|
||||
"constrained_debug_label": "ניקוי באגים של פענוח מוגבל",
|
||||
"constrained_debug_info": "הפעל לוגים של ניקוי באגים עבור פענוח מוגבל",
|
||||
"auto_score_label": "דירוג אוטומטי",
|
||||
"auto_score_info": "חשב אוטומטית ציוני איכות לכל קטעי האודיו שנוצרו",
|
||||
"auto_lrc_label": "LRC אוטומטי",
|
||||
"auto_lrc_info": "צור אוטומטית חותמות זמן למילים (LRC) לכל קטעי האודיו",
|
||||
"lm_batch_chunk_label": "גודל מקטע מנת LM",
|
||||
"lm_batch_chunk_info": "מקסימום פריטים למקטע מנת LM (ברירת מחדל: 8, מוגבל ע\"י זיכרון ה-GPU)",
|
||||
"codes_strength_label": "חוזק קודי LM",
|
||||
"codes_strength_info": "שליטה בכמות צעדי הניקוי מרעש המשתמשים בקודים שנוצרו ע\"י ה-LM",
|
||||
"cover_strength_label": "חוזק כיסוי אודיו (Audio Cover)",
|
||||
"cover_strength_info": "שליטה בכמות צעדי הניקוי מרעש המשתמשים במצב כיסוי",
|
||||
"remix_strength_label": "חוזק רמיקס",
|
||||
"remix_strength_info": "שליטה בכמה הרמיקס סוטה מהאודיו המקורי (נמוך = קרוב יותר למקור)",
|
||||
"cover_noise_strength_label": "חוזק כיסוי",
|
||||
"cover_noise_strength_info": "שליטה בשחזור המלודיה במצב Remix. מומלץ: השתמשו במודל SFT עם ערך של 0.1–0.25. העלאה קלה משחזרת את המלודיה, אך העברת סגנון עשויה לדרוש כוונון פרומפט נוסף. (0 = רעש טהור/ללא כיסוי, 1 = הכי קרוב לאודיו המקורי)",
|
||||
"score_sensitivity_label": "רגישות ציון איכות",
|
||||
"score_sensitivity_info": "נמוך יותר = רגיש יותר (ברירת מחדל: 1.0)",
|
||||
"think_label": "חשיבה (Think)",
|
||||
"parallel_thinking_label": "חשיבה מקבילית",
|
||||
"parallel_thinking_info": "עיבוד מקבילי של דגימות אצווה ליצירה מהירה",
|
||||
"generate_btn": "🎵 צור מוזיקה",
|
||||
"autogen_label": "יצירה אוטומטית",
|
||||
"caption_rewrite_label": "שכתוב תיאור",
|
||||
"caption_rewrite_info": "שימוש ב-LM לשכתוב תיאור לפני יצירה"
|
||||
},
|
||||
"results": {
|
||||
"title": "🎵 תוצאות",
|
||||
"generated_music": "🎵 מוזיקה שנוצרה (דגימה {n})",
|
||||
"send_to_remix_btn": "🔗 שלח לרמיקס",
|
||||
"send_to_repaint_btn": "🔗 שלח לצביעה מחדש",
|
||||
"save_btn": "💾 שמירה",
|
||||
"score_btn": "📊 קבל דירוג",
|
||||
"lrc_btn": "🎵 קבל LRC",
|
||||
"save_lrc_btn": "💾 שמור LRC",
|
||||
"convert_to_codes_btn": "🔄 המר לקודים",
|
||||
"quality_score_label": "ציון איכות (דגימה {n})",
|
||||
"quality_score_placeholder": "לחץ על 'דירוג' לחישוב ציון איכות מבוסס מורכבות (Perplexity)",
|
||||
"codes_label": "קודי LM (דגימה {n})",
|
||||
"lrc_label": "חותמות זמן למילים (דגימה {n})",
|
||||
"lrc_placeholder": "לחץ על 'LRC' ליצירת חותמות זמן",
|
||||
"details_accordion": "📊 דירוג, LRC וקודי LM",
|
||||
"generation_status": "מצב יצירה",
|
||||
"current_batch": "מנה נוכחית",
|
||||
"batch_indicator": "מנה {current} / {total}",
|
||||
"next_batch_status": "מצב המנה הבאה",
|
||||
"prev_btn": "◀ הקודם",
|
||||
"next_btn": "הבא ▶",
|
||||
"restore_params_btn": "↙️ החל הגדרות אלו על הממשק (שחזור פרמטרי מנה)",
|
||||
"batch_results_title": "📁 תוצאות המנה ופרטי יצירה",
|
||||
"all_files_label": "📁 כל הקבצים שנוצרו (הורדה)",
|
||||
"generation_details": "פרטי יצירה"
|
||||
},
|
||||
"messages": {
|
||||
"no_audio_to_save": "❌ אין אודיו לשמירה",
|
||||
"save_success": "✅ האודיו והמטא-דאטה נשמרו ב-{filename}",
|
||||
"save_failed": "❌ השמירה נכשלה: {error}",
|
||||
"no_file_selected": "⚠️ לא נבחר קובץ",
|
||||
"params_loaded": "✅ הפרמטרים נטענו מ-{filename}",
|
||||
"invalid_json": "❌ קובץ JSON לא תקין: {error}",
|
||||
"load_error": "❌ שגיאה בטעינת הקובץ: {error}",
|
||||
"example_loaded": "📁 נטען דגם מ-{filename}",
|
||||
"example_failed": "נכשל ניתוח קובץ ה-JSON ב-{filename}: {error}",
|
||||
"example_error": "שגיאה בטעינת הדגם: {error}",
|
||||
"lm_generated": "🤖 נוצר דגם באמצעות ה-LM",
|
||||
"lm_fallback": "יצירת דגם באמצעות ה-LM נכשלה, חוזר לשימוש בספריית הדגמים",
|
||||
"lm_not_initialized": "❌ 5Hz LM לא מאותחל. נא לאתחל אותו תחילה.",
|
||||
"think_requires_lm": "⚠️ 'Think' דורש שה-5Hz LM יהיה מאותחל. Think הושבת — היצירה תמשיך ללא חשיבת LM.",
|
||||
"autogen_enabled": "🔄 יצירה אוטומטית הופעלה - המנה הבאה תיווצר לאחר מכן",
|
||||
"batch_ready": "✅ מנה {n} מוכנה! לחץ על 'הבא' לצפייה.",
|
||||
"batch_generating": "🔄 מתחיל יצירת רקע עבור מנה {n}...",
|
||||
"batch_failed": "❌ יצירת הרקע נכשלה: {error}",
|
||||
"viewing_batch": "✅ צופה במנה {n}",
|
||||
"at_first_batch": "נמצא כבר במנה הראשונה",
|
||||
"at_last_batch": "אין מנה באה זמינה",
|
||||
"batch_not_found": "מנה {n} לא נמצאה בתור",
|
||||
"no_batch_data": "לא נמצאו נתוני מנה לשחזור.",
|
||||
"params_restored": "✅ פרמטרי הממשק שוחזרו ממנה {n}",
|
||||
"scoring_failed": "❌ שגיאה: נתוני המנה לא נמצאו",
|
||||
"no_codes": "❌ אין קודי אודיו זמינים. נא ליצור מוזיקה תחילה.",
|
||||
"score_failed": "❌ הדירוג נכשל: {error}",
|
||||
"score_error": "❌ שגיאה בחישוב הציון: {error}",
|
||||
"lrc_no_batch_data": "❌ לא נמצאו נתוני מנה. נא ליצור מוזיקה תחילה.",
|
||||
"lrc_no_extra_outputs": "❌ לא נמצאו פלטים נוספים. טנזורי התניה אינם זמינים.",
|
||||
"lrc_missing_tensors": "❌ חסרים טנזורים נדרשים ליצירת LRC.",
|
||||
"lrc_sample_not_exist": "❌ הדגימה אינה קיימת במנה הנוכחית.",
|
||||
"lrc_empty_result": "⚠️ יצירת ה-LRC הפיקה תוצאה ריקה.",
|
||||
"empty_query": "⚠️ נא להזין תיאור מוזיקלי.",
|
||||
"sample_creation_failed": "❌ יצירת הדגימה נכשלה. נא לנסות שוב.",
|
||||
"sample_created": "✅ הדגימה נוצרה! בדוק את התיאור והמילים, ולאחר מכן לחץ על 'צור מוזיקה'.",
|
||||
"simple_examples_not_found": "⚠️ ספריית הדגמים של המצב הפשוט לא נמצאה.",
|
||||
"simple_examples_empty": "⚠️ לא נמצאו קבצי דוגמה במצב פשוט.",
|
||||
"simple_example_loaded": "🎲 נטענה דוגמה אקראית מ-{filename}",
|
||||
"format_success": "✅ התיאור והמילים פורמטו בהצלחה",
|
||||
"format_failed": "❌ הפירמוט נכשל: {error}",
|
||||
"skipping_metas_cot": "⚡ מדלג על שלב 1 של מטא-דאטה COT (הדגימה כבר מפורמטת)",
|
||||
"invalid_timesteps_format": "⚠️ פורמט צעדי זמן לא תקין. משתמש בלוח זמנים כברירת מחדל.",
|
||||
"timesteps_out_of_range": "⚠️ צעדי הזמן חייבים להיות בטווח [0, 1]. משתמש בלוח זמנים כברירת מחדל.",
|
||||
"timesteps_count_mismatch": "⚠️ מספר צעדי הזמן ({actual}) שונה מצעדי ההסקה ({expected}). משתמש במספר צעדי הזמן."
|
||||
},
|
||||
"training": {
|
||||
"tab_title": "🎓 אימון LoRA",
|
||||
"tab_dataset_builder": "📁 בונה מערך נתונים",
|
||||
"tab_train_lora": "🚀 אימון LoRA",
|
||||
"quick_start_title": "🚀 התחלה מהירה",
|
||||
"load_dataset_label": "נתיב קובץ JSON של מערך הנתונים",
|
||||
"load_dataset_info": "טעינת מערך נתונים שנשמר בעבר",
|
||||
"load_btn": "📂 טעינה",
|
||||
"load_status": "מצב טעינה",
|
||||
"scan_label": "נתיב ספריית אודיו",
|
||||
"scan_info": "סריקה אחר קבצי אודיו (wav, mp3, flac, ogg, opus)",
|
||||
"scan_btn": "🔍 סריקה",
|
||||
"scan_status": "מצב סריקה",
|
||||
"found_audio_files": "קבצי אודיו שנמצאו",
|
||||
"dataset_name": "שם מערך הנתונים",
|
||||
"dataset_name_placeholder": "הזן שם למערך הנתונים",
|
||||
"dataset_settings_header": "הגדרות מערך נתונים",
|
||||
"tag_prepend": "הוספה בהתחלה (תגית, תיאור)",
|
||||
"tag_append": "הוספה בסוף (תיאור, תגית)",
|
||||
"tag_replace": "החלפת התיאור",
|
||||
"step2_title": "שלב 2: תיוג אוטומטי באמצעות AI",
|
||||
"step2_instruction": "לחץ על הכפתור למטה כדי ליצור אוטומטית מטא-נתונים עבור כל קבצי האודיו באמצעות AI:\n• **תיאור**: סגנון מוזיקה, ז'אנר, תיאור מצב רוח\n• **BPM**: פעימות לדקה\n• **מפתח**: מפתח מוזיקלי (לדוגמה, C Major, Am)\n• **חתימת זמן**: 4/4, 3/4, וכו'",
|
||||
"step3_title": "שלב 3: תצוגה מקדימה ועריכה",
|
||||
"step4_title": "שלב 4: שמירת מערך הנתונים",
|
||||
"step5_title": "שלב 5: עיבוד מקדים לטנזורים (Tensors)",
|
||||
"step5_intro": "**העיבוד המקדים ממיר את מערך הנתונים שלך לטנזורים מחושבים מראש לאימון מהיר.**\n\nאתה יכול:\n• להשתמש במערך הנתונים משלבים 1-4 למעלה, **או**\n• לטעון קובץ JSON של מערך נתונים קיים (אם כבר שמרת אחד)",
|
||||
"step5_details": "שלב זה:\n• מקודד אודיו ל-latents של VAE\n• מקודד תיאורים ומילים ל-embeddings טקסט\n• מפעיל את מקודד התנאי\n• שומר את כל הטנזורים לקבצי `.pt`\n\n⚠️ **זה דורש טעינת המודל ועשוי לקחת מספר דקות.**",
|
||||
"train_tensor_selection_desc": "בחר את הספרייה המכילה קבצי טנזור מעובדים מראש (קבצי `.pt`).\nאלה נוצרים בכרטיסייה \"בונה מערך נתונים\" באמצעות כפתור \"עיבוד מקדים\".",
|
||||
"all_instrumental": "הכל אינסטרומנטלי",
|
||||
"all_instrumental_info": "סמן אם כל הרצועות הן כליות (ללא שירה)",
|
||||
"custom_tag": "תגית הפעלה מותאמת אישית",
|
||||
"custom_tag_info": "תגית ייחודית להפעלת הסגנון של LoRA זו",
|
||||
"tag_position": "מיקום התגית",
|
||||
"tag_position_info": "היכן למקם את התגית המותאמת אישית בתוך התיאור",
|
||||
"genre_ratio": "יחס ז'אנר (%)",
|
||||
"genre_ratio_info": "0% = הכל תיאור, 100% = הכל ז'אנר. הגדרה פר-דגימה קודמת להגדרת הכלל.",
|
||||
"skip_metas": "דלג על BPM/סולם/משקל",
|
||||
"skip_metas_info": "דלג על יצירת BPM/סולם/משקל. התיאור והז'אנר עדיין ייווצרו על ידי ה-LLM.",
|
||||
"only_unlabeled": "רק כאלו ללא תיוג",
|
||||
"only_unlabeled_info": "תייג רק דגימות ללא תיאור (שימושי להמשך תיוג שנכשל)",
|
||||
"auto_label_btn": "🏷️ תיוג אוטומטי של הכל",
|
||||
"label_progress": "התקדמות התיוג",
|
||||
"select_sample": "בחר דגימה #",
|
||||
"select_sample_info": "בחר דגימה לצפייה ועריכה",
|
||||
"audio_preview": "תצוגה מקדימה של אודיו",
|
||||
"filename": "שם קובץ",
|
||||
"caption": "תיאור",
|
||||
"genre": "ז'אנר",
|
||||
"prompt_override_label": "דריסת פרומפט (לדגימה זו)",
|
||||
"prompt_override_info": "דריסת היחס הכללי עבור דגימה זו",
|
||||
"lyrics_editable_label": "מילים (ניתן לעריכה, משמש לאימון)",
|
||||
"raw_lyrics_label": "מילים גולמיות (מתוך קובץ .txt)",
|
||||
"no_lyrics_placeholder": "(אין קובץ מילים .txt)",
|
||||
"bpm": "BPM",
|
||||
"key_label": "סולם (Key)",
|
||||
"key_placeholder": "C Major",
|
||||
"time_sig": "משקל מוזיקלי",
|
||||
"duration_s": "משך (שניות)",
|
||||
"language": "שפה",
|
||||
"instrumental": "אינסטרומנטלי",
|
||||
"save_changes_btn": "💾 שמירת שינויים",
|
||||
"edit_status": "מצב עריכה",
|
||||
"save_path": "נתיב שמירה",
|
||||
"save_path_info": "הנתיב שבו יישמר קובץ ה-JSON של מערך הנתונים",
|
||||
"save_dataset_btn": "💾 שמירת מערך נתונים",
|
||||
"save_status": "מצב שמירה",
|
||||
"load_existing_label": "טעינת מערך נתונים קיים (אופציונלי)",
|
||||
"load_existing_info": "נתיב לקובץ JSON של מערך נתונים שנשמר בעבר",
|
||||
"load_dataset_btn": "📂 טעינת מערך נתונים",
|
||||
"tensor_output_dir": "ספריית פלט של טנזורים",
|
||||
"tensor_output_info": "הספרייה לשמירת קבצי טנזור שעברו עיבוד מקדים",
|
||||
"preprocess_btn": "⚡ עיבוד מקדים",
|
||||
"preprocess_progress": "התקדמות עיבוד מקדים",
|
||||
"preprocessed_tensors_dir": "ספריית טנזורים מעובדים",
|
||||
"preprocessed_tensors_info": "ספרייה המכילה קבצי .pt של טנזורים מעובדים",
|
||||
"train_section_tensors": "בחירת מערך נתונים מעובד",
|
||||
"train_section_lora": "הגדרות LoRA",
|
||||
"train_section_params": "פרמטרי אימון",
|
||||
"dataset_info": "מידע על מערך הנתונים",
|
||||
"lora_rank": "דרגת LoRA (Rank)",
|
||||
"lora_rank_info": "גבוה יותר = יותר קיבולת, יותר זיכרון",
|
||||
"lora_alpha": "LoRA Alpha",
|
||||
"lora_alpha_info": "פקטור קנה מידה (בדרך כלל פי 2 מה-Rank)",
|
||||
"lora_dropout": "LoRA Dropout",
|
||||
"learning_rate": "קצב למידה (Learning Rate)",
|
||||
"learning_rate_info": "התחל עם 3e-4, שנה במידת הצורך",
|
||||
"max_epochs": "מקסימום תקופות (Epochs)",
|
||||
"batch_size": "גודל מנה (Batch Size)",
|
||||
"batch_size_info": "הגדל אם יש לך מספיק זיכרון גרפי (VRAM)",
|
||||
"gradient_accumulation": "צבירת גרדיאנטים (Accumulation)",
|
||||
"gradient_accumulation_info": "גודל מנה אפקטיבי = גודל מנה × צבירה",
|
||||
"save_every_n_epochs": "שמור כל N תקופות (Epochs)",
|
||||
"shift": "Shift (הסטה)",
|
||||
"shift_info": "הסטת צעדי זמן עבור מודל turbo",
|
||||
"seed": "גרעין (Seed)",
|
||||
"output_dir": "ספריית פלט",
|
||||
"output_dir_info": "ספרייה לשמירת משקולות ה-LoRA המאומנות",
|
||||
"start_training_btn": "🚀 התחלת אימון",
|
||||
"stop_training_btn": "⏹️ עצירת אימון",
|
||||
"training_progress": "התקדמות האימון",
|
||||
"training_log": "יומן אימון",
|
||||
"training_loss_title": "הפסד אימון (Training Loss)",
|
||||
"step": "צעד",
|
||||
"loss": "הפסד (Loss)",
|
||||
"export_header": "ייצוא LoRA",
|
||||
"export_path": "נתיב ייצוא",
|
||||
"export_lora_btn": "📦 ייצוא LoRA",
|
||||
"export_status": "מצב ייצוא"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -53,12 +53,19 @@
|
|||
"compile_model_info": "torch.compileでモデルを最適化(量子化に必要)",
|
||||
"quantization_label": "INT8 量子化",
|
||||
"quantization_info": "INT8重み量子化を有効にしてVRAMを節約(モデルのコンパイルが必要)",
|
||||
"mlx_dit_label": "MLX DiT (Apple Silicon)",
|
||||
"mlx_dit_info_enabled": "Apple SiliconでMLXネイティブDiT拡散を使用(MPSより高速)",
|
||||
"mlx_dit_info_disabled": "MLXは利用不可(macOS + Apple Silicon + mlxパッケージが必要)",
|
||||
"init_btn": "サービスを初期化",
|
||||
"status_label": "ステータス",
|
||||
"language_label": "UI言語",
|
||||
"language_info": "インターフェース言語を選択"
|
||||
"language_info": "インターフェース言語を選択",
|
||||
"gpu_auto_tier": "自動検出ティア",
|
||||
"tier_label": "GPU ティアの手動選択",
|
||||
"tier_info": "GPUティアを手動で選択して最適化のデフォルト(オフロード、量子化、バックエンドなど)を調整します"
|
||||
},
|
||||
"generation": {
|
||||
"tab_title": "🎵 生成",
|
||||
"required_inputs": "📝 必須入力",
|
||||
"task_type_label": "タスクタイプ",
|
||||
"task_type_info": "生成のタスクタイプを選択",
|
||||
|
|
@ -71,8 +78,11 @@
|
|||
"track_classes_info": "completeタスクの複数のトラッククラスを選択",
|
||||
"audio_uploads": "🎵 オーディオアップロード",
|
||||
"reference_audio": "リファレンスオーディオ(オプション)",
|
||||
"source_audio": "ソースオーディオ(オプション)",
|
||||
"source_audio": "ソースオーディオ",
|
||||
"convert_codes_btn": "コードに変換",
|
||||
"analyze_btn": "🔍 分析",
|
||||
"sample_btn": "🎲 お試し",
|
||||
"load_btn": "📂 読込",
|
||||
"lm_codes_hints": "🎼 LM コードヒント",
|
||||
"lm_codes_label": "LM コードヒント",
|
||||
"lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
|
||||
|
|
@ -84,13 +94,20 @@
|
|||
"repainting_start": "再描画開始",
|
||||
"repainting_end": "再描画終了",
|
||||
"mode_label": "生成モード",
|
||||
"mode_info": "シンプル:自然言語で音楽を説明。カスタム:キャプションと歌詞を完全にコントロール。",
|
||||
"mode_info": "生成モードを選択して開始します。",
|
||||
"mode_info_simple": "自然言語で音楽を説明すると、AIがキャプション、歌詞、メタデータを自動生成します。",
|
||||
"mode_info_custom": "キャプション、歌詞、すべてのパラメータを完全にコントロール。",
|
||||
"mode_info_remix": "ソースオーディオをアップロードして、キャプションと歌詞でリミックスバージョンを作成。",
|
||||
"mode_info_repaint": "ソースオーディオをアップロードして、指定した時間範囲を再描画。",
|
||||
"mode_info_extract": "ソースオーディオから特定のトラック(ボーカル、ドラムなど)を抽出。",
|
||||
"mode_info_lego": "トラックの再構成:ソースオーディオの特定のトラックを置き換え。",
|
||||
"mode_info_complete": "ソースオーディオの欠落したトラックを補完。",
|
||||
"mode_simple": "シンプル",
|
||||
"mode_custom": "カスタム",
|
||||
"simple_query_label": "曲の説明",
|
||||
"simple_query_placeholder": "作成したい音楽を説明してください。例:'静かな夜のための優しいベンガルのラブソング'。空欄の場合はランダムなサンプルが生成されます。",
|
||||
"simple_query_info": "生成したい音楽の自然言語の説明を入力",
|
||||
"simple_vocal_language_label": "ボーカル言語(オプション)",
|
||||
"simple_vocal_language_label": "ボーカル言語",
|
||||
"simple_vocal_language_info": "歌詞の希望言語を選択。任意の言語の場合は'unknown'を使用。",
|
||||
"create_sample_btn": "サンプル作成",
|
||||
"caption_title": "📝 音楽キャプション",
|
||||
|
|
@ -103,12 +120,20 @@
|
|||
"lyrics_info": "構造を持つ曲の歌詞",
|
||||
"instrumental_label": "インストゥルメンタル",
|
||||
"format_btn": "フォーマット",
|
||||
"format_caption_btn": "キャプションを強化",
|
||||
"format_lyrics_btn": "歌詞を強化",
|
||||
"optional_params": "⚙️ オプションパラメータ",
|
||||
"optional_music_props": "🎵 音楽プロパティ",
|
||||
"optional_gen_settings": "📐 生成設定",
|
||||
"advanced_dit_section": "🎛️ DiT 拡散パラメータ",
|
||||
"advanced_lm_section": "🤖 LM 生成パラメータ",
|
||||
"advanced_output_section": "🔊 オーディオ出力と後処理",
|
||||
"advanced_automation_section": "⚡ 自動化とバッチ",
|
||||
"vocal_language_label": "ボーカル言語(オプション)",
|
||||
"vocal_language_info": "インストには`unknown`を使用",
|
||||
"vocal_language_info": "'unknown' = インスト/自動",
|
||||
"bpm_label": "BPM(オプション)",
|
||||
"bpm_info": "空白の場合はN/A",
|
||||
"keyscale_label": "キースケール(オプション)",
|
||||
"keyscale_label": "キー(オプション)",
|
||||
"keyscale_placeholder": "空白の場合はN/A",
|
||||
"keyscale_info": "A-G, #/♭, メジャー/マイナー",
|
||||
"timesig_label": "拍子記号(オプション)",
|
||||
|
|
@ -117,7 +142,7 @@
|
|||
"duration_info": "ランダムの場合は-1を使用",
|
||||
"batch_size_label": "バッチサイズ",
|
||||
"batch_size_info": "生成するオーディオの数(最大8)",
|
||||
"advanced_settings": "🔧 詳細設定",
|
||||
"advanced_settings": "⚙️ 設定",
|
||||
"inference_steps_label": "DiT 推論ステップ",
|
||||
"inference_steps_info": "Turbo: 最大8、Base: 最大200",
|
||||
"guidance_scale_label": "DiT ガイダンススケール(baseモデルのみサポート)",
|
||||
|
|
@ -168,21 +193,31 @@
|
|||
"similarity_denoise_info": "出力が参照オーディオにどれだけ忠実かを制御します。高い値ほど構造を保持します。",
|
||||
"cover_strength_label": "オーディオカバー強度",
|
||||
"cover_strength_info": "カバーモードを使用するデノイジングステップ数を制御",
|
||||
"remix_strength_label": "リミックス強度",
|
||||
"remix_strength_info": "リミックスがソースオーディオからどれだけ逸脱するかを制御(低い = オリジナルに近い)",
|
||||
"cover_noise_strength_label": "カバー強度",
|
||||
"cover_noise_strength_info": "Remixモードでのメロディ復元を制御します。推奨:SFTモデルを使用し、値を0.1〜0.25に設定。わずかに上げるとメロディが復元されますが、スタイル変換には追加のプロンプト調整が必要な場合があります。(0 = 純粋なノイズ/カバーなし、1 = オリジナルに最も近い)",
|
||||
"score_sensitivity_label": "品質スコア感度",
|
||||
"score_sensitivity_info": "低い = より敏感(デフォルト: 1.0)。PMIが[0,1]にマッピングする方法を調整",
|
||||
"think_label": "思考",
|
||||
"parallel_thinking_label": "並列思考",
|
||||
"parallel_thinking_info": "バッチサンプルを並列処理して高速生成",
|
||||
"generate_btn": "🎵 音楽を生成",
|
||||
"autogen_label": "自動生成",
|
||||
"caption_rewrite_label": "キャプション書き換え"
|
||||
"caption_rewrite_label": "キャプション書き換え",
|
||||
"caption_rewrite_info": "生成前にLMでキャプションを書き換え",
|
||||
"advanced_dit_params": "DiT 詳細パラメータ"
|
||||
},
|
||||
"results": {
|
||||
"title": "🎵 結果",
|
||||
"generated_music": "🎵 生成された音楽(サンプル {n})",
|
||||
"send_to_src_btn": "🔗 ソースオーディオに送信",
|
||||
"send_to_remix_btn": "🔗 リミックスに送信",
|
||||
"send_to_repaint_btn": "🔗 リペイントに送信",
|
||||
"save_btn": "💾 保存",
|
||||
"score_btn": "📊 スコア",
|
||||
"lrc_btn": "🎵 LRC",
|
||||
"score_btn": "📊 スコア取得",
|
||||
"lrc_btn": "🎵 LRC 取得",
|
||||
"save_lrc_btn": "💾 LRC 保存",
|
||||
"convert_to_codes_btn": "🔄 コードに変換",
|
||||
"quality_score_label": "品質スコア(サンプル {n})",
|
||||
"quality_score_placeholder": "'スコア'をクリックしてパープレキシティベースの品質スコアを計算",
|
||||
"codes_label": "LM コード(サンプル {n})",
|
||||
|
|
@ -214,6 +249,7 @@
|
|||
"lm_generated": "🤖 LMを使用してサンプルを生成しました",
|
||||
"lm_fallback": "LMを使用したサンプル生成に失敗、サンプルディレクトリにフォールバック",
|
||||
"lm_not_initialized": "❌ 5Hz LMが初期化されていません。最初に初期化してください。",
|
||||
"think_requires_lm": "⚠️ 「Think」機能には5Hz LMの初期化が必要です。Thinkは無効化されました — LM思考なしで生成を続行します。",
|
||||
"autogen_enabled": "🔄 自動生成が有効 - このあと次のバッチを生成します",
|
||||
"batch_ready": "✅ バッチ {n} の準備完了!'次へ'をクリックして表示。",
|
||||
"batch_generating": "🔄 バッチ {n} のバックグラウンド生成を開始...",
|
||||
|
|
@ -354,5 +390,16 @@
|
|||
"export_path": "エクスポートパス",
|
||||
"export_lora_btn": "📦 LoRA をエクスポート",
|
||||
"export_status": "エクスポート状態"
|
||||
},
|
||||
"gen": {
|
||||
"enable_normalization": "ノーマライズを有効にする",
|
||||
"enable_normalization_info": "クリッピングを防ぎ、ラウドネスを一定にするために、オーディオ音量をターゲットピークレベルに正規化します。",
|
||||
"normalization_db": "ターゲットピーク (dB)",
|
||||
"normalization_db_info": "デシベル単位のターゲットピークレベル。-1.0 dBが標準的な安全ピークです。-0.1 dBが最大です。",
|
||||
"advanced_dit_params": "DiT 詳細パラメータ",
|
||||
"latent_shift": "Latent シフト",
|
||||
"latent_shift_info": "VAEデコード前にDiT latentに適用するシフト量。デフォルト0(シフトなし)。負の値(例: -0.04)でクリッピングを軽減できます。",
|
||||
"latent_rescale": "Latent リスケール",
|
||||
"latent_rescale_info": "VAEデコード前にDiT latentに適用するスケール係数。デフォルト1.0(スケーリングなし)。1.0未満の値(例: 0.91)でクリッピングを軽減できます。"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -53,12 +53,19 @@
|
|||
"compile_model_info": "使用 torch.compile 优化模型(量化必需)",
|
||||
"quantization_label": "INT8 量化",
|
||||
"quantization_info": "启用 INT8 仅权重量化以减少显存占用(需要启用编译模型)",
|
||||
"mlx_dit_label": "MLX DiT (Apple Silicon)",
|
||||
"mlx_dit_info_enabled": "使用原生 MLX 加速 DiT 扩散推理(比 MPS 更快)",
|
||||
"mlx_dit_info_disabled": "MLX 不可用(需要 macOS + Apple Silicon + mlx 包)",
|
||||
"init_btn": "初始化服务",
|
||||
"status_label": "状态",
|
||||
"language_label": "界面语言",
|
||||
"language_info": "选择界面语言"
|
||||
"language_info": "选择界面语言",
|
||||
"gpu_auto_tier": "自动检测层级",
|
||||
"tier_label": "GPU 层级覆盖",
|
||||
"tier_info": "手动选择 GPU 层级以调整优化默认值(卸载、量化、后端等)"
|
||||
},
|
||||
"generation": {
|
||||
"tab_title": "🎵 生成",
|
||||
"required_inputs": "📝 必需输入",
|
||||
"task_type_label": "任务类型",
|
||||
"task_type_info": "选择生成的任务类型",
|
||||
|
|
@ -71,8 +78,11 @@
|
|||
"track_classes_info": "为complete任务选择多个音轨类别",
|
||||
"audio_uploads": "🎵 音频上传",
|
||||
"reference_audio": "参考音频(可选)",
|
||||
"source_audio": "源音频(可选)",
|
||||
"source_audio": "源音频",
|
||||
"convert_codes_btn": "转换为代码",
|
||||
"analyze_btn": "🔍 分析",
|
||||
"sample_btn": "🎲 试试看",
|
||||
"load_btn": "📂 加载",
|
||||
"lm_codes_hints": "🎼 LM 代码提示",
|
||||
"lm_codes_label": "LM 代码提示",
|
||||
"lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
|
||||
|
|
@ -84,13 +94,20 @@
|
|||
"repainting_start": "重绘开始",
|
||||
"repainting_end": "重绘结束",
|
||||
"mode_label": "生成模式",
|
||||
"mode_info": "简单模式:用自然语言描述音乐。自定义模式:完全控制描述和歌词。",
|
||||
"mode_info": "选择一种生成模式开始创作。",
|
||||
"mode_info_simple": "用自然语言描述你想要的音乐,AI 将自动生成描述、歌词和元数据。",
|
||||
"mode_info_custom": "完全控制描述、歌词和所有参数。",
|
||||
"mode_info_remix": "上传源音频,用你的描述和歌词创建混音版本。",
|
||||
"mode_info_repaint": "上传源音频,重绘指定时间范围的内容。",
|
||||
"mode_info_extract": "从源音频中提取特定音轨(人声、鼓等)。",
|
||||
"mode_info_lego": "重新组装音轨:替换源音频中的特定音轨。",
|
||||
"mode_info_complete": "补全源音频中缺失的音轨。",
|
||||
"mode_simple": "简单",
|
||||
"mode_custom": "自定义",
|
||||
"simple_query_label": "歌曲描述",
|
||||
"simple_query_placeholder": "描述你想创作的音乐,例如:'给我生成一首暗黑的戏剧古风,歌词要华丽'。留空则随机生成样本。",
|
||||
"simple_query_info": "输入你想生成的音乐的自然语言描述",
|
||||
"simple_vocal_language_label": "人声语言(可选)",
|
||||
"simple_vocal_language_label": "人声语言",
|
||||
"simple_vocal_language_info": "选择歌词的首选语言。使用 'unknown' 表示任意语言。",
|
||||
"create_sample_btn": "创建样本",
|
||||
"caption_title": "📝 音乐描述",
|
||||
|
|
@ -103,12 +120,20 @@
|
|||
"lyrics_info": "带有结构的歌曲歌词",
|
||||
"instrumental_label": "纯音乐",
|
||||
"format_btn": "格式化",
|
||||
"format_caption_btn": "增强描述",
|
||||
"format_lyrics_btn": "增强歌词",
|
||||
"optional_params": "⚙️ 可选参数",
|
||||
"optional_music_props": "🎵 音乐属性",
|
||||
"optional_gen_settings": "📐 生成设置",
|
||||
"advanced_dit_section": "🎛️ DiT 扩散参数",
|
||||
"advanced_lm_section": "🤖 LM 生成参数",
|
||||
"advanced_output_section": "🔊 音频输出与后处理",
|
||||
"advanced_automation_section": "⚡ 自动化与批量",
|
||||
"vocal_language_label": "人声语言(可选)",
|
||||
"vocal_language_info": "纯音乐使用 `unknown`",
|
||||
"vocal_language_info": "'unknown' = 纯音乐/自动",
|
||||
"bpm_label": "BPM(可选)",
|
||||
"bpm_info": "留空表示N/A",
|
||||
"keyscale_label": "调性(可选)",
|
||||
"keyscale_label": "调 (可选)",
|
||||
"keyscale_placeholder": "留空表示N/A",
|
||||
"keyscale_info": "A-G, #/♭, 大调/小调",
|
||||
"timesig_label": "拍号(可选)",
|
||||
|
|
@ -117,7 +142,7 @@
|
|||
"duration_info": "使用-1表示随机",
|
||||
"batch_size_label": "批量大小",
|
||||
"batch_size_info": "要生成的音频数量(最多8个)",
|
||||
"advanced_settings": "🔧 高级设置",
|
||||
"advanced_settings": "⚙️ 设置",
|
||||
"inference_steps_label": "DiT 推理步数",
|
||||
"inference_steps_info": "Turbo: 最多8, Base: 最多200",
|
||||
"guidance_scale_label": "DiT 引导比例(仅支持base模型)",
|
||||
|
|
@ -150,8 +175,8 @@
|
|||
"lm_negative_prompt_label": "LM 负面提示",
|
||||
"lm_negative_prompt_placeholder": "输入CFG的负面提示(默认: NO USER INPUT)",
|
||||
"lm_negative_prompt_info": "负面提示(当LM CFG比例 > 1.0时使用)",
|
||||
"cot_metas_label": "CoT 元数据",
|
||||
"cot_metas_info": "使用LM生成CoT元数据(取消勾选以跳过LM CoT生成)",
|
||||
"cot_metas_label": "CoT メタデータ",
|
||||
"cot_metas_info": "LMを使用してCoTメタデータを生成(チェックを外すとLM CoT生成をスキップ)",
|
||||
"cot_language_label": "CoT 语言",
|
||||
"cot_language_info": "在CoT中生成语言(思维链)",
|
||||
"constrained_debug_label": "约束解码调试",
|
||||
|
|
@ -168,21 +193,31 @@
|
|||
"similarity_denoise_info": "控制输出与参考音频的贴合程度。数值越高保留越多结构。",
|
||||
"cover_strength_label": "音频覆盖强度",
|
||||
"cover_strength_info": "控制使用覆盖模式的去噪步骤数量",
|
||||
"remix_strength_label": "混音强度",
|
||||
"remix_strength_info": "控制混音与原始音频的偏离程度(越低越接近原始音频)",
|
||||
"cover_noise_strength_label": "翻唱强度",
|
||||
"cover_noise_strength_info": "控制 Remix 模式下的旋律还原程度。推荐:使用 SFT 模型,值设为 0.1–0.25。稍微提升即可还原旋律,但风格迁移可能需要额外的 prompt 调参。(0 = 纯噪声/无翻唱效果,1 = 最接近原始音频)",
|
||||
"score_sensitivity_label": "质量评分敏感度",
|
||||
"score_sensitivity_info": "更低 = 更敏感(默认: 1.0). 调整PMI如何映射到[0,1]",
|
||||
"think_label": "思考",
|
||||
"parallel_thinking_label": "并行思考",
|
||||
"parallel_thinking_info": "并行处理批量样本以加速生成",
|
||||
"generate_btn": "🎵 生成音乐",
|
||||
"autogen_label": "自动生成",
|
||||
"caption_rewrite_label": "描述重写"
|
||||
"caption_rewrite_label": "描述重写",
|
||||
"caption_rewrite_info": "生成前使用LM重写描述",
|
||||
"advanced_dit_params": "DiT 高级参数"
|
||||
},
|
||||
"results": {
|
||||
"title": "🎵 结果",
|
||||
"generated_music": "🎵 生成的音乐(样本 {n})",
|
||||
"send_to_src_btn": "🔗 发送到源音频",
|
||||
"send_to_remix_btn": "🔗 发送到混音",
|
||||
"send_to_repaint_btn": "🔗 发送到重绘",
|
||||
"save_btn": "💾 保存",
|
||||
"score_btn": "📊 评分",
|
||||
"lrc_btn": "🎵 LRC",
|
||||
"score_btn": "📊 获取评分",
|
||||
"lrc_btn": "🎵 获取 LRC",
|
||||
"save_lrc_btn": "💾 保存 LRC",
|
||||
"convert_to_codes_btn": "🔄 转换为代码",
|
||||
"quality_score_label": "质量分数(样本 {n})",
|
||||
"quality_score_placeholder": "点击'评分'以计算基于困惑度的质量分数",
|
||||
"codes_label": "LM 代码(样本 {n})",
|
||||
|
|
@ -214,6 +249,7 @@
|
|||
"lm_generated": "🤖 使用LM生成的示例",
|
||||
"lm_fallback": "使用LM生成示例失败,回退到示例目录",
|
||||
"lm_not_initialized": "❌ 5Hz LM未初始化。请先初始化它。",
|
||||
"think_requires_lm": "⚠️ \"Think\"功能需要先初始化 5Hz LM。Think 已自动关闭,生成将不使用 LM 思考。",
|
||||
"autogen_enabled": "🔄 已启用自动生成 - 下一批次将在此之后生成",
|
||||
"batch_ready": "✅ 批次 {n} 就绪!点击'下一个'查看。",
|
||||
"batch_generating": "🔄 开始为批次 {n} 进行后台生成...",
|
||||
|
|
@ -252,6 +288,7 @@
|
|||
"tab_train_lora": "🚀 训练 LoRA",
|
||||
"quick_start_title": "🚀 快速开始",
|
||||
"load_dataset_label": "数据集 JSON 路径",
|
||||
"load_dataset_info": "加载之前保存的数据集",
|
||||
"load_btn": "📂 加载",
|
||||
"load_status": "加载状态",
|
||||
"scan_label": "音频目录路径",
|
||||
|
|
@ -272,6 +309,9 @@
|
|||
"step5_title": "步骤 5:预处理为张量",
|
||||
"step5_intro": "**预处理将您的数据集转换为预计算的张量,以便快速训练。**\n\n您可以:\n- 使用上述步骤 1-4 的数据集,**或者**\n- 加载现有的数据集 JSON 文件(如果您已经保存了一个)",
|
||||
"step5_details": "此步骤:\n- 将音频编码为 VAE 潜变量\n- 将描述和歌词编码为文本嵌入\n- 运行条件编码器\n- 将所有张量保存到 `.pt` 文件\n\n⚠️ **这需要加载模型,可能需要几分钟。**",
|
||||
"train_section_tensors": "预处理数据集选择",
|
||||
"train_section_lora": "LoRA 设置",
|
||||
"train_section_params": "训练参数",
|
||||
"train_tensor_selection_desc": "选择包含预处理张量文件(`.pt` 文件)的目录。\n这些文件在「数据集构建」标签页中使用「预处理」按钮创建。",
|
||||
"all_instrumental": "全部为纯音乐",
|
||||
"all_instrumental_info": "勾选表示所有曲目均为纯音乐(无人声)",
|
||||
|
|
@ -350,5 +390,15 @@
|
|||
"export_path": "导出路径",
|
||||
"export_lora_btn": "📦 导出 LoRA",
|
||||
"export_status": "导出状态"
|
||||
},
|
||||
"gen": {
|
||||
"enable_normalization": "启用归一化",
|
||||
"enable_normalization_info": "将音频音量归一化到目标峰值电平,以防止削波或确保一致的响度。",
|
||||
"normalization_db": "目标峰值 (dB)",
|
||||
"normalization_db_info": "分贝单位的目标峰值电平。-1.0 dB 是标准安全峰值。-0.1 dB 是最大值。",
|
||||
"latent_shift": "Latent 偏移",
|
||||
"latent_shift_info": "VAE 解码前对 DiT latent 施加的偏移量。默认 0(不偏移)。负值(如 -0.04)可减少爆音。",
|
||||
"latent_rescale": "Latent 缩放",
|
||||
"latent_rescale_info": "VAE 解码前对 DiT latent 施加的缩放因子。默认 1.0(不缩放)。小于 1.0 的值(如 0.91)可减少爆音。"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,11 +1,33 @@
|
|||
"""
|
||||
Gradio UI Components Module
|
||||
Contains all Gradio interface component definitions and layouts
|
||||
|
||||
Layout:
|
||||
┌──────────────────────────────────────┐
|
||||
│ Header │
|
||||
├──────────────────────────────────────┤
|
||||
│ Dataset Explorer (hidden accordion) │
|
||||
├──────────────────────────────────────┤
|
||||
│ Settings (accordion, collapsed) │
|
||||
│ ├─ Service Configuration │
|
||||
│ ├─ DiT Parameters │
|
||||
│ ├─ LM Parameters │
|
||||
│ └─ Output / Automation │
|
||||
├──────────────────────────────────────┤
|
||||
│ ┌─ Generation ─┬─ Training ──────┐ │
|
||||
│ │ Mode Radio │ Dataset/LoRA │ │
|
||||
│ │ Inputs │ │ │
|
||||
│ │ Results │ │ │
|
||||
│ └───────────────┴────────────────┘ │
|
||||
└──────────────────────────────────────┘
|
||||
"""
|
||||
import gradio as gr
|
||||
from acestep.gradio_ui.i18n import get_i18n, t
|
||||
from acestep.gradio_ui.interfaces.dataset import create_dataset_section
|
||||
from acestep.gradio_ui.interfaces.generation import create_generation_section
|
||||
from acestep.gradio_ui.interfaces.generation import (
|
||||
create_advanced_settings_section,
|
||||
create_generation_tab_section,
|
||||
)
|
||||
from acestep.gradio_ui.interfaces.result import create_results_section
|
||||
from acestep.gradio_ui.interfaces.training import create_training_section
|
||||
from acestep.gradio_ui.events import setup_event_handlers, setup_training_event_handlers
|
||||
|
|
@ -29,6 +51,9 @@ def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_para
|
|||
# Initialize i18n with selected language
|
||||
i18n = get_i18n(language)
|
||||
|
||||
# Check if running in service mode (hide training tab)
|
||||
service_mode = init_params is not None and init_params.get('service_mode', False)
|
||||
|
||||
with gr.Blocks(
|
||||
title=t("app.title"),
|
||||
theme=gr.themes.Soft(),
|
||||
|
|
@ -62,6 +87,34 @@ def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_para
|
|||
.component-wrapper > .timestamps {
|
||||
transform: translateY(15px);
|
||||
}
|
||||
/* Equal-height row for instrumental checkbox + enhance lyrics button */
|
||||
.instrumental-row {
|
||||
align-items: stretch !important;
|
||||
}
|
||||
.instrumental-row > div {
|
||||
display: flex !important;
|
||||
align-items: stretch !important;
|
||||
}
|
||||
.instrumental-row > div > div {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
.instrumental-row button {
|
||||
height: 100% !important;
|
||||
min-height: 42px;
|
||||
}
|
||||
/* Ensure buttons in instrumental-row fill height */
|
||||
.instrumental-row > div > button {
|
||||
height: 100% !important;
|
||||
min-height: 42px;
|
||||
}
|
||||
/* Two-line icon buttons: emoji on top, text below */
|
||||
.icon-btn-wrap button, .icon-btn-wrap > button {
|
||||
word-spacing: 100vw;
|
||||
text-align: center;
|
||||
line-height: 1.4;
|
||||
}
|
||||
""",
|
||||
) as demo:
|
||||
|
||||
|
|
@ -72,21 +125,52 @@ def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_para
|
|||
</div>
|
||||
""")
|
||||
|
||||
# Dataset Explorer Section
|
||||
# Dataset Explorer Section (hidden)
|
||||
dataset_section = create_dataset_section(dataset_handler)
|
||||
|
||||
# Generation Section (pass init_params and language to support pre-initialization)
|
||||
generation_section = create_generation_section(dit_handler, llm_handler, init_params=init_params, language=language)
|
||||
# ═══════════════════════════════════════════
|
||||
# Top-level: Settings (contains Service Config + Advanced Settings)
|
||||
# ═══════════════════════════════════════════
|
||||
settings_section = create_advanced_settings_section(
|
||||
dit_handler, llm_handler, init_params=init_params, language=language
|
||||
)
|
||||
|
||||
# Results Section
|
||||
results_section = create_results_section(dit_handler)
|
||||
# ═══════════════════════════════════════════
|
||||
# Tabs: Generation | Training
|
||||
# ═══════════════════════════════════════════
|
||||
with gr.Tabs():
|
||||
# --- Generation Tab ---
|
||||
with gr.Tab(t("generation.tab_title")):
|
||||
gen_section = create_generation_tab_section(
|
||||
dit_handler, llm_handler, init_params=init_params, language=language
|
||||
)
|
||||
|
||||
# Results Section (inside the Generation tab, wrapped for visibility control)
|
||||
with gr.Column(visible=True) as results_wrapper:
|
||||
results_section = create_results_section(dit_handler)
|
||||
# Store the wrapper in gen_section so event handlers can toggle it
|
||||
gen_section["results_wrapper"] = results_wrapper
|
||||
|
||||
# --- Training Tab ---
|
||||
with gr.Tab(t("training.tab_title"), visible=not service_mode):
|
||||
training_section = create_training_section(
|
||||
dit_handler, llm_handler, init_params=init_params
|
||||
)
|
||||
|
||||
# Training Section (LoRA training and dataset builder)
|
||||
# Pass init_params to support hiding in service mode
|
||||
training_section = create_training_section(dit_handler, llm_handler, init_params=init_params)
|
||||
# ═══════════════════════════════════════════
|
||||
# Merge all generation-related component dicts for event wiring
|
||||
# ═══════════════════════════════════════════
|
||||
# The event handlers expect a single "generation_section" dict with all
|
||||
# components from settings (service config + advanced) and generation tab.
|
||||
generation_section = {}
|
||||
generation_section.update(settings_section)
|
||||
generation_section.update(gen_section)
|
||||
|
||||
# Connect event handlers
|
||||
setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section)
|
||||
setup_event_handlers(
|
||||
demo, dit_handler, llm_handler, dataset_handler,
|
||||
dataset_section, generation_section, results_section
|
||||
)
|
||||
|
||||
# Connect training event handlers
|
||||
setup_training_event_handlers(demo, dit_handler, llm_handler, training_section)
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -6,6 +6,93 @@ import gradio as gr
|
|||
from acestep.gradio_ui.i18n import t
|
||||
|
||||
|
||||
def _create_audio_column(n, visible=True):
|
||||
"""Create a single audio sample column with all its sub-components.
|
||||
|
||||
Layout:
|
||||
Audio player
|
||||
Row: [Send To Cover] [Send To Repaint] [Save]
|
||||
Accordion (Score & LRC & LM Codes):
|
||||
codes_display
|
||||
Row: score_display + score_btn
|
||||
Row: lrc_display + lrc_btn
|
||||
"""
|
||||
with gr.Column(visible=visible) as audio_col:
|
||||
generated_audio = gr.Audio(
|
||||
label=t("results.generated_music", n=n),
|
||||
type="filepath",
|
||||
interactive=False,
|
||||
buttons=[]
|
||||
)
|
||||
with gr.Row(equal_height=True):
|
||||
send_to_remix_btn = gr.Button(
|
||||
t("results.send_to_remix_btn"),
|
||||
variant="secondary", size="sm", scale=1
|
||||
)
|
||||
send_to_repaint_btn = gr.Button(
|
||||
t("results.send_to_repaint_btn"),
|
||||
variant="secondary", size="sm", scale=1
|
||||
)
|
||||
save_btn = gr.Button(
|
||||
t("results.save_btn"),
|
||||
variant="primary", size="sm", scale=1
|
||||
)
|
||||
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion:
|
||||
codes_display = gr.Textbox(
|
||||
label=t("results.codes_label", n=n),
|
||||
interactive=False, buttons=["copy"],
|
||||
lines=4, max_lines=4, visible=True
|
||||
)
|
||||
convert_to_codes_btn = gr.Button(
|
||||
t("results.convert_to_codes_btn"),
|
||||
variant="secondary", size="sm"
|
||||
)
|
||||
score_display = gr.Textbox(
|
||||
label=t("results.quality_score_label", n=n),
|
||||
interactive=False, buttons=["copy"],
|
||||
lines=6, max_lines=6, visible=True
|
||||
)
|
||||
score_btn = gr.Button(
|
||||
t("results.score_btn"),
|
||||
variant="secondary", size="sm"
|
||||
)
|
||||
lrc_display = gr.Textbox(
|
||||
label=t("results.lrc_label", n=n),
|
||||
interactive=True, buttons=["copy"],
|
||||
lines=8, max_lines=8, visible=True
|
||||
)
|
||||
with gr.Row(equal_height=True):
|
||||
lrc_btn = gr.Button(
|
||||
t("results.lrc_btn"),
|
||||
variant="secondary", size="sm"
|
||||
)
|
||||
save_lrc_btn = gr.Button(
|
||||
t("results.save_lrc_btn"),
|
||||
variant="secondary", size="sm"
|
||||
)
|
||||
lrc_download_file = gr.File(
|
||||
label="LRC Download",
|
||||
visible=False,
|
||||
interactive=False,
|
||||
)
|
||||
return {
|
||||
"audio_col": audio_col,
|
||||
"generated_audio": generated_audio,
|
||||
"send_to_remix_btn": send_to_remix_btn,
|
||||
"send_to_repaint_btn": send_to_repaint_btn,
|
||||
"save_btn": save_btn,
|
||||
"details_accordion": details_accordion,
|
||||
"codes_display": codes_display,
|
||||
"convert_to_codes_btn": convert_to_codes_btn,
|
||||
"score_display": score_display,
|
||||
"score_btn": score_btn,
|
||||
"lrc_display": lrc_display,
|
||||
"lrc_btn": lrc_btn,
|
||||
"save_lrc_btn": save_lrc_btn,
|
||||
"lrc_download_file": lrc_download_file,
|
||||
}
|
||||
|
||||
|
||||
def create_results_section(dit_handler) -> dict:
|
||||
"""Create results display section"""
|
||||
with gr.Accordion(t("results.title"), open=True):
|
||||
|
|
@ -16,393 +103,25 @@ def create_results_section(dit_handler) -> dict:
|
|||
is_format_caption_state = gr.State(value=False)
|
||||
|
||||
# Batch management states
|
||||
current_batch_index = gr.State(value=0) # Currently displayed batch index
|
||||
total_batches = gr.State(value=1) # Total number of batches generated
|
||||
batch_queue = gr.State(value={}) # Dictionary storing all batch data
|
||||
generation_params_state = gr.State(value={}) # Store generation parameters for next batches
|
||||
is_generating_background = gr.State(value=False) # Background generation flag
|
||||
current_batch_index = gr.State(value=0)
|
||||
total_batches = gr.State(value=1)
|
||||
batch_queue = gr.State(value={})
|
||||
generation_params_state = gr.State(value={})
|
||||
is_generating_background = gr.State(value=False)
|
||||
|
||||
# All audio components in one row with dynamic visibility
|
||||
# Row 1: samples 1-4
|
||||
with gr.Row():
|
||||
with gr.Column(visible=True) as audio_col_1:
|
||||
generated_audio_1 = gr.Audio(
|
||||
label=t("results.generated_music", n=1),
|
||||
type="filepath",
|
||||
interactive=False,
|
||||
buttons=[]
|
||||
)
|
||||
with gr.Row(equal_height=True):
|
||||
send_to_src_btn_1 = gr.Button(
|
||||
t("results.send_to_src_btn"),
|
||||
variant="secondary",
|
||||
size="sm",
|
||||
scale=1
|
||||
)
|
||||
save_btn_1 = gr.Button(
|
||||
t("results.save_btn"),
|
||||
variant="primary",
|
||||
size="sm",
|
||||
scale=1
|
||||
)
|
||||
score_btn_1 = gr.Button(
|
||||
t("results.score_btn"),
|
||||
variant="secondary",
|
||||
size="sm",
|
||||
scale=1
|
||||
)
|
||||
lrc_btn_1 = gr.Button(
|
||||
t("results.lrc_btn"),
|
||||
variant="secondary",
|
||||
size="sm",
|
||||
scale=1
|
||||
)
|
||||
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_1:
|
||||
codes_display_1 = gr.Textbox(
|
||||
label=t("results.codes_label", n=1),
|
||||
interactive=False,
|
||||
buttons=["copy"],
|
||||
lines=4,
|
||||
max_lines=4,
|
||||
visible=True
|
||||
)
|
||||
score_display_1 = gr.Textbox(
|
||||
label=t("results.quality_score_label", n=1),
|
||||
interactive=False,
|
||||
buttons=["copy"],
|
||||
lines=6,
|
||||
max_lines=6,
|
||||
visible=True
|
||||
)
|
||||
lrc_display_1 = gr.Textbox(
|
||||
label=t("results.lrc_label", n=1),
|
||||
interactive=True,
|
||||
buttons=["copy"],
|
||||
lines=8,
|
||||
max_lines=8,
|
||||
visible=True
|
||||
)
|
||||
with gr.Column(visible=True) as audio_col_2:
|
||||
generated_audio_2 = gr.Audio(
|
||||
label=t("results.generated_music", n=2),
|
||||
type="filepath",
|
||||
interactive=False,
|
||||
buttons=[]
|
||||
)
|
||||
with gr.Row(equal_height=True):
|
||||
send_to_src_btn_2 = gr.Button(
|
||||
t("results.send_to_src_btn"),
|
||||
variant="secondary",
|
||||
size="sm",
|
||||
scale=1
|
||||
)
|
||||
save_btn_2 = gr.Button(
|
||||
t("results.save_btn"),
|
||||
variant="primary",
|
||||
size="sm",
|
||||
scale=1
|
||||
)
|
||||
score_btn_2 = gr.Button(
|
||||
t("results.score_btn"),
|
||||
variant="secondary",
|
||||
size="sm",
|
||||
scale=1
|
||||
)
|
||||
lrc_btn_2 = gr.Button(
|
||||
t("results.lrc_btn"),
|
||||
variant="secondary",
|
||||
size="sm",
|
||||
scale=1
|
||||
)
|
||||
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_2:
|
||||
codes_display_2 = gr.Textbox(
|
||||
label=t("results.codes_label", n=2),
|
||||
interactive=False,
|
||||
buttons=["copy"],
|
||||
lines=4,
|
||||
max_lines=4,
|
||||
visible=True
|
||||
)
|
||||
score_display_2 = gr.Textbox(
|
||||
label=t("results.quality_score_label", n=2),
|
||||
interactive=False,
|
||||
buttons=["copy"],
|
||||
lines=6,
|
||||
max_lines=6,
|
||||
visible=True
|
||||
)
|
||||
lrc_display_2 = gr.Textbox(
|
||||
label=t("results.lrc_label", n=2),
|
||||
interactive=True,
|
||||
buttons=["copy"],
|
||||
lines=8,
|
||||
max_lines=8,
|
||||
visible=True
|
||||
)
|
||||
with gr.Column(visible=False) as audio_col_3:
|
||||
generated_audio_3 = gr.Audio(
|
||||
label=t("results.generated_music", n=3),
|
||||
type="filepath",
|
||||
interactive=False,
|
||||
buttons=[]
|
||||
)
|
||||
with gr.Row(equal_height=True):
|
||||
send_to_src_btn_3 = gr.Button(
|
||||
t("results.send_to_src_btn"),
|
||||
variant="secondary",
|
||||
size="sm",
|
||||
scale=1
|
||||
)
|
||||
save_btn_3 = gr.Button(
|
||||
t("results.save_btn"),
|
||||
variant="primary",
|
||||
size="sm",
|
||||
scale=1
|
||||
)
|
||||
score_btn_3 = gr.Button(
|
||||
t("results.score_btn"),
|
||||
variant="secondary",
|
||||
size="sm",
|
||||
scale=1
|
||||
)
|
||||
lrc_btn_3 = gr.Button(
|
||||
t("results.lrc_btn"),
|
||||
variant="secondary",
|
||||
size="sm",
|
||||
scale=1
|
||||
)
|
||||
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_3:
|
||||
codes_display_3 = gr.Textbox(
|
||||
label=t("results.codes_label", n=3),
|
||||
interactive=False,
|
||||
buttons=["copy"],
|
||||
lines=4,
|
||||
max_lines=4,
|
||||
visible=True
|
||||
)
|
||||
score_display_3 = gr.Textbox(
|
||||
label=t("results.quality_score_label", n=3),
|
||||
interactive=False,
|
||||
buttons=["copy"],
|
||||
lines=6,
|
||||
max_lines=6,
|
||||
visible=True
|
||||
)
|
||||
lrc_display_3 = gr.Textbox(
|
||||
label=t("results.lrc_label", n=3),
|
||||
interactive=True,
|
||||
buttons=["copy"],
|
||||
lines=8,
|
||||
max_lines=8,
|
||||
visible=True
|
||||
)
|
||||
with gr.Column(visible=False) as audio_col_4:
|
||||
generated_audio_4 = gr.Audio(
|
||||
label=t("results.generated_music", n=4),
|
||||
type="filepath",
|
||||
interactive=False,
|
||||
buttons=[]
|
||||
)
|
||||
with gr.Row(equal_height=True):
|
||||
send_to_src_btn_4 = gr.Button(
|
||||
t("results.send_to_src_btn"),
|
||||
variant="secondary",
|
||||
size="sm",
|
||||
scale=1
|
||||
)
|
||||
save_btn_4 = gr.Button(
|
||||
t("results.save_btn"),
|
||||
variant="primary",
|
||||
size="sm",
|
||||
scale=1
|
||||
)
|
||||
score_btn_4 = gr.Button(
|
||||
t("results.score_btn"),
|
||||
variant="secondary",
|
||||
size="sm",
|
||||
scale=1
|
||||
)
|
||||
lrc_btn_4 = gr.Button(
|
||||
t("results.lrc_btn"),
|
||||
variant="secondary",
|
||||
size="sm",
|
||||
scale=1
|
||||
)
|
||||
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_4:
|
||||
codes_display_4 = gr.Textbox(
|
||||
label=t("results.codes_label", n=4),
|
||||
interactive=False,
|
||||
buttons=["copy"],
|
||||
lines=4,
|
||||
max_lines=4,
|
||||
visible=True
|
||||
)
|
||||
score_display_4 = gr.Textbox(
|
||||
label=t("results.quality_score_label", n=4),
|
||||
interactive=False,
|
||||
buttons=["copy"],
|
||||
lines=6,
|
||||
max_lines=6,
|
||||
visible=True
|
||||
)
|
||||
lrc_display_4 = gr.Textbox(
|
||||
label=t("results.lrc_label", n=4),
|
||||
interactive=True,
|
||||
buttons=["copy"],
|
||||
lines=8,
|
||||
max_lines=8,
|
||||
visible=True
|
||||
)
|
||||
cols_1_4 = []
|
||||
for i in range(1, 5):
|
||||
cols_1_4.append(_create_audio_column(i, visible=(i <= 2)))
|
||||
|
||||
# Second row for batch size 5-8 (initially hidden)
|
||||
# Row 2: samples 5-8 (initially hidden)
|
||||
with gr.Row(visible=False) as audio_row_5_8:
|
||||
with gr.Column() as audio_col_5:
|
||||
generated_audio_5 = gr.Audio(
|
||||
label=t("results.generated_music", n=5),
|
||||
type="filepath",
|
||||
interactive=False,
|
||||
buttons=[]
|
||||
)
|
||||
with gr.Row(equal_height=True):
|
||||
send_to_src_btn_5 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
|
||||
save_btn_5 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
||||
score_btn_5 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
|
||||
lrc_btn_5 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
|
||||
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_5:
|
||||
codes_display_5 = gr.Textbox(
|
||||
label=t("results.codes_label", n=5),
|
||||
interactive=False,
|
||||
buttons=["copy"],
|
||||
lines=4,
|
||||
max_lines=4,
|
||||
visible=True
|
||||
)
|
||||
score_display_5 = gr.Textbox(
|
||||
label=t("results.quality_score_label", n=5),
|
||||
interactive=False,
|
||||
buttons=["copy"],
|
||||
lines=6,
|
||||
max_lines=6,
|
||||
visible=True
|
||||
)
|
||||
lrc_display_5 = gr.Textbox(
|
||||
label=t("results.lrc_label", n=5),
|
||||
interactive=True,
|
||||
buttons=["copy"],
|
||||
lines=8,
|
||||
max_lines=8,
|
||||
visible=True
|
||||
)
|
||||
with gr.Column() as audio_col_6:
|
||||
generated_audio_6 = gr.Audio(
|
||||
label=t("results.generated_music", n=6),
|
||||
type="filepath",
|
||||
interactive=False,
|
||||
buttons=[]
|
||||
)
|
||||
with gr.Row(equal_height=True):
|
||||
send_to_src_btn_6 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
|
||||
save_btn_6 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
||||
score_btn_6 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
|
||||
lrc_btn_6 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
|
||||
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_6:
|
||||
codes_display_6 = gr.Textbox(
|
||||
label=t("results.codes_label", n=6),
|
||||
interactive=False,
|
||||
buttons=["copy"],
|
||||
lines=4,
|
||||
max_lines=4,
|
||||
visible=True
|
||||
)
|
||||
score_display_6 = gr.Textbox(
|
||||
label=t("results.quality_score_label", n=6),
|
||||
interactive=False,
|
||||
buttons=["copy"],
|
||||
lines=6,
|
||||
max_lines=6,
|
||||
visible=True
|
||||
)
|
||||
lrc_display_6 = gr.Textbox(
|
||||
label=t("results.lrc_label", n=6),
|
||||
interactive=True,
|
||||
buttons=["copy"],
|
||||
lines=8,
|
||||
max_lines=8,
|
||||
visible=True
|
||||
)
|
||||
with gr.Column() as audio_col_7:
|
||||
generated_audio_7 = gr.Audio(
|
||||
label=t("results.generated_music", n=7),
|
||||
type="filepath",
|
||||
interactive=False,
|
||||
buttons=[]
|
||||
)
|
||||
with gr.Row(equal_height=True):
|
||||
send_to_src_btn_7 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
|
||||
save_btn_7 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
||||
score_btn_7 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
|
||||
lrc_btn_7 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
|
||||
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_7:
|
||||
codes_display_7 = gr.Textbox(
|
||||
label=t("results.codes_label", n=7),
|
||||
interactive=False,
|
||||
buttons=["copy"],
|
||||
lines=4,
|
||||
max_lines=4,
|
||||
visible=True
|
||||
)
|
||||
score_display_7 = gr.Textbox(
|
||||
label=t("results.quality_score_label", n=7),
|
||||
interactive=False,
|
||||
buttons=["copy"],
|
||||
lines=6,
|
||||
max_lines=6,
|
||||
visible=True
|
||||
)
|
||||
lrc_display_7 = gr.Textbox(
|
||||
label=t("results.lrc_label", n=7),
|
||||
interactive=True,
|
||||
buttons=["copy"],
|
||||
lines=8,
|
||||
max_lines=8,
|
||||
visible=True
|
||||
)
|
||||
with gr.Column() as audio_col_8:
|
||||
generated_audio_8 = gr.Audio(
|
||||
label=t("results.generated_music", n=8),
|
||||
type="filepath",
|
||||
interactive=False,
|
||||
buttons=[]
|
||||
)
|
||||
with gr.Row(equal_height=True):
|
||||
send_to_src_btn_8 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
|
||||
save_btn_8 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
||||
score_btn_8 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
|
||||
lrc_btn_8 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
|
||||
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_8:
|
||||
codes_display_8 = gr.Textbox(
|
||||
label=t("results.codes_label", n=8),
|
||||
interactive=False,
|
||||
buttons=["copy"],
|
||||
lines=4,
|
||||
max_lines=4,
|
||||
visible=True
|
||||
)
|
||||
score_display_8 = gr.Textbox(
|
||||
label=t("results.quality_score_label", n=8),
|
||||
interactive=False,
|
||||
buttons=["copy"],
|
||||
lines=6,
|
||||
max_lines=6,
|
||||
visible=True
|
||||
)
|
||||
lrc_display_8 = gr.Textbox(
|
||||
label=t("results.lrc_label", n=8),
|
||||
interactive=True,
|
||||
buttons=["copy"],
|
||||
lines=8,
|
||||
max_lines=8,
|
||||
visible=True
|
||||
)
|
||||
cols_5_8 = []
|
||||
for i in range(5, 9):
|
||||
cols_5_8.append(_create_audio_column(i, visible=True))
|
||||
|
||||
all_cols = cols_1_4 + cols_5_8
|
||||
|
||||
status_output = gr.Textbox(label=t("results.generation_status"), interactive=False)
|
||||
|
||||
|
|
@ -410,48 +129,37 @@ def create_results_section(dit_handler) -> dict:
|
|||
with gr.Row(equal_height=True):
|
||||
prev_batch_btn = gr.Button(
|
||||
t("results.prev_btn"),
|
||||
variant="secondary",
|
||||
interactive=False,
|
||||
scale=1,
|
||||
size="sm"
|
||||
variant="secondary", interactive=False, scale=1, size="sm"
|
||||
)
|
||||
batch_indicator = gr.Textbox(
|
||||
label=t("results.current_batch"),
|
||||
value=t("results.batch_indicator", current=1, total=1),
|
||||
interactive=False,
|
||||
scale=3
|
||||
interactive=False, scale=3
|
||||
)
|
||||
next_batch_status = gr.Textbox(
|
||||
label=t("results.next_batch_status"),
|
||||
value="",
|
||||
interactive=False,
|
||||
scale=3
|
||||
value="", interactive=False, scale=3
|
||||
)
|
||||
next_batch_btn = gr.Button(
|
||||
t("results.next_btn"),
|
||||
variant="primary",
|
||||
interactive=False,
|
||||
scale=1,
|
||||
size="sm"
|
||||
variant="primary", interactive=False, scale=1, size="sm"
|
||||
)
|
||||
|
||||
# One-click restore parameters button
|
||||
restore_params_btn = gr.Button(
|
||||
t("results.restore_params_btn"),
|
||||
variant="secondary",
|
||||
interactive=False, # Initially disabled, enabled after generation
|
||||
size="sm"
|
||||
variant="secondary", interactive=False, size="sm"
|
||||
)
|
||||
|
||||
with gr.Accordion(t("results.batch_results_title"), open=True):
|
||||
generated_audio_batch = gr.File(
|
||||
label=t("results.all_files_label"),
|
||||
file_count="multiple",
|
||||
interactive=False
|
||||
file_count="multiple", interactive=False
|
||||
)
|
||||
generation_info = gr.Markdown(label=t("results.generation_details"))
|
||||
|
||||
return {
|
||||
# Build return dict from all_cols
|
||||
result = {
|
||||
"lm_metadata_state": lm_metadata_state,
|
||||
"is_format_caption_state": is_format_caption_state,
|
||||
"current_batch_index": current_batch_index,
|
||||
|
|
@ -465,88 +173,25 @@ def create_results_section(dit_handler) -> dict:
|
|||
"next_batch_btn": next_batch_btn,
|
||||
"next_batch_status": next_batch_status,
|
||||
"restore_params_btn": restore_params_btn,
|
||||
"generated_audio_1": generated_audio_1,
|
||||
"generated_audio_2": generated_audio_2,
|
||||
"generated_audio_3": generated_audio_3,
|
||||
"generated_audio_4": generated_audio_4,
|
||||
"generated_audio_5": generated_audio_5,
|
||||
"generated_audio_6": generated_audio_6,
|
||||
"generated_audio_7": generated_audio_7,
|
||||
"generated_audio_8": generated_audio_8,
|
||||
"audio_row_5_8": audio_row_5_8,
|
||||
"audio_col_1": audio_col_1,
|
||||
"audio_col_2": audio_col_2,
|
||||
"audio_col_3": audio_col_3,
|
||||
"audio_col_4": audio_col_4,
|
||||
"audio_col_5": audio_col_5,
|
||||
"audio_col_6": audio_col_6,
|
||||
"audio_col_7": audio_col_7,
|
||||
"audio_col_8": audio_col_8,
|
||||
"send_to_src_btn_1": send_to_src_btn_1,
|
||||
"send_to_src_btn_2": send_to_src_btn_2,
|
||||
"send_to_src_btn_3": send_to_src_btn_3,
|
||||
"send_to_src_btn_4": send_to_src_btn_4,
|
||||
"send_to_src_btn_5": send_to_src_btn_5,
|
||||
"send_to_src_btn_6": send_to_src_btn_6,
|
||||
"send_to_src_btn_7": send_to_src_btn_7,
|
||||
"send_to_src_btn_8": send_to_src_btn_8,
|
||||
"save_btn_1": save_btn_1,
|
||||
"save_btn_2": save_btn_2,
|
||||
"save_btn_3": save_btn_3,
|
||||
"save_btn_4": save_btn_4,
|
||||
"save_btn_5": save_btn_5,
|
||||
"save_btn_6": save_btn_6,
|
||||
"save_btn_7": save_btn_7,
|
||||
"save_btn_8": save_btn_8,
|
||||
"score_btn_1": score_btn_1,
|
||||
"score_btn_2": score_btn_2,
|
||||
"score_btn_3": score_btn_3,
|
||||
"score_btn_4": score_btn_4,
|
||||
"score_btn_5": score_btn_5,
|
||||
"score_btn_6": score_btn_6,
|
||||
"score_btn_7": score_btn_7,
|
||||
"score_btn_8": score_btn_8,
|
||||
"score_display_1": score_display_1,
|
||||
"score_display_2": score_display_2,
|
||||
"score_display_3": score_display_3,
|
||||
"score_display_4": score_display_4,
|
||||
"score_display_5": score_display_5,
|
||||
"score_display_6": score_display_6,
|
||||
"score_display_7": score_display_7,
|
||||
"score_display_8": score_display_8,
|
||||
"codes_display_1": codes_display_1,
|
||||
"codes_display_2": codes_display_2,
|
||||
"codes_display_3": codes_display_3,
|
||||
"codes_display_4": codes_display_4,
|
||||
"codes_display_5": codes_display_5,
|
||||
"codes_display_6": codes_display_6,
|
||||
"codes_display_7": codes_display_7,
|
||||
"codes_display_8": codes_display_8,
|
||||
"lrc_btn_1": lrc_btn_1,
|
||||
"lrc_btn_2": lrc_btn_2,
|
||||
"lrc_btn_3": lrc_btn_3,
|
||||
"lrc_btn_4": lrc_btn_4,
|
||||
"lrc_btn_5": lrc_btn_5,
|
||||
"lrc_btn_6": lrc_btn_6,
|
||||
"lrc_btn_7": lrc_btn_7,
|
||||
"lrc_btn_8": lrc_btn_8,
|
||||
"lrc_display_1": lrc_display_1,
|
||||
"lrc_display_2": lrc_display_2,
|
||||
"lrc_display_3": lrc_display_3,
|
||||
"lrc_display_4": lrc_display_4,
|
||||
"lrc_display_5": lrc_display_5,
|
||||
"lrc_display_6": lrc_display_6,
|
||||
"lrc_display_7": lrc_display_7,
|
||||
"lrc_display_8": lrc_display_8,
|
||||
"details_accordion_1": details_accordion_1,
|
||||
"details_accordion_2": details_accordion_2,
|
||||
"details_accordion_3": details_accordion_3,
|
||||
"details_accordion_4": details_accordion_4,
|
||||
"details_accordion_5": details_accordion_5,
|
||||
"details_accordion_6": details_accordion_6,
|
||||
"details_accordion_7": details_accordion_7,
|
||||
"details_accordion_8": details_accordion_8,
|
||||
"generated_audio_batch": generated_audio_batch,
|
||||
"generation_info": generation_info,
|
||||
}
|
||||
|
||||
|
||||
for idx, col_data in enumerate(all_cols, start=1):
|
||||
result[f"generated_audio_{idx}"] = col_data["generated_audio"]
|
||||
result[f"audio_col_{idx}"] = col_data["audio_col"]
|
||||
result[f"send_to_remix_btn_{idx}"] = col_data["send_to_remix_btn"]
|
||||
result[f"send_to_repaint_btn_{idx}"] = col_data["send_to_repaint_btn"]
|
||||
result[f"save_btn_{idx}"] = col_data["save_btn"]
|
||||
result[f"score_btn_{idx}"] = col_data["score_btn"]
|
||||
result[f"score_display_{idx}"] = col_data["score_display"]
|
||||
result[f"codes_display_{idx}"] = col_data["codes_display"]
|
||||
result[f"convert_to_codes_btn_{idx}"] = col_data["convert_to_codes_btn"]
|
||||
result[f"lrc_btn_{idx}"] = col_data["lrc_btn"]
|
||||
result[f"lrc_display_{idx}"] = col_data["lrc_display"]
|
||||
result[f"save_lrc_btn_{idx}"] = col_data["save_lrc_btn"]
|
||||
result[f"lrc_download_file_{idx}"] = col_data["lrc_download_file"]
|
||||
result[f"details_accordion_{idx}"] = col_data["details_accordion"]
|
||||
|
||||
return result
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
2137
acestep/handler.py
2137
acestep/handler.py
File diff suppressed because it is too large
Load diff
2580
acestep/inference.py
2580
acestep/inference.py
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
33
acestep/mlx_dit/__init__.py
Normal file
33
acestep/mlx_dit/__init__.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
# Native MLX implementation of the AceStep DiT decoder for Apple Silicon.
|
||||
# Provides pure MLX inference with graceful fallback to PyTorch.
|
||||
|
||||
import logging
|
||||
import platform
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_mlx_available() -> bool:
|
||||
"""Check if MLX is available on this platform (macOS + Apple Silicon)."""
|
||||
if platform.system() != "Darwin":
|
||||
return False
|
||||
try:
|
||||
import mlx.core as mx
|
||||
import mlx.nn
|
||||
# Verify we can actually create arrays (Metal backend works)
|
||||
_ = mx.array([1.0])
|
||||
mx.eval(_)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
_MLX_AVAILABLE = None
|
||||
|
||||
|
||||
def mlx_available() -> bool:
|
||||
"""Cached check for MLX availability."""
|
||||
global _MLX_AVAILABLE
|
||||
if _MLX_AVAILABLE is None:
|
||||
_MLX_AVAILABLE = is_mlx_available()
|
||||
return _MLX_AVAILABLE
|
||||
84
acestep/mlx_dit/convert.py
Normal file
84
acestep/mlx_dit/convert.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
# Weight conversion from PyTorch AceStep DiT decoder to native MLX format.
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def convert_decoder_weights(
|
||||
pytorch_model,
|
||||
) -> List[Tuple[str, "mx.array"]]:
|
||||
"""Convert PyTorch decoder weights to a list of (name, mx.array) pairs
|
||||
suitable for ``mlx_decoder.load_weights()``.
|
||||
|
||||
The function extracts weights from
|
||||
``pytorch_model.decoder`` (``AceStepDiTModel``) and converts them to MLX
|
||||
format, handling:
|
||||
- Conv1d weight layout: PT ``[out, in, K]`` -> MLX ``[out, K, in]``
|
||||
- ConvTranspose1d layout: PT ``[in, out, K]`` -> MLX ``[out, K, in]``
|
||||
- nn.Sequential index remapping (Lambda wrappers removed in MLX)
|
||||
- All other weights are transferred as-is
|
||||
|
||||
Args:
|
||||
pytorch_model: The full ``AceStepConditionGenerationModel`` (PyTorch).
|
||||
|
||||
Returns:
|
||||
List of (param_name, mx.array) pairs ready for ``model.load_weights()``.
|
||||
"""
|
||||
import mlx.core as mx
|
||||
|
||||
decoder = pytorch_model.decoder
|
||||
state_dict = decoder.state_dict()
|
||||
|
||||
weights: List[Tuple[str, "mx.array"]] = []
|
||||
|
||||
for key, value in state_dict.items():
|
||||
np_val = value.detach().cpu().float().numpy()
|
||||
new_key = key
|
||||
|
||||
# PyTorch proj_in is Sequential(Lambda, Conv1d, Lambda)
|
||||
# The Conv1d is at index 1. In MLX we use a bare Conv1d.
|
||||
if key.startswith("proj_in.1."):
|
||||
new_key = key.replace("proj_in.1.", "proj_in.")
|
||||
if new_key.endswith(".weight"):
|
||||
# PT Conv1d weight: [out, in, K] -> MLX: [out, K, in]
|
||||
np_val = np_val.swapaxes(1, 2)
|
||||
|
||||
# PyTorch proj_out is Sequential(Lambda, ConvTranspose1d, Lambda)
|
||||
elif key.startswith("proj_out.1."):
|
||||
new_key = key.replace("proj_out.1.", "proj_out.")
|
||||
if new_key.endswith(".weight"):
|
||||
# PT ConvTranspose1d weight: [in, out, K] -> MLX: [out, K, in]
|
||||
np_val = np_val.transpose(1, 2, 0)
|
||||
|
||||
# Skip rotary embedding buffers (recomputed in MLX)
|
||||
elif "rotary_emb" in key:
|
||||
continue
|
||||
|
||||
weights.append((new_key, mx.array(np_val)))
|
||||
|
||||
logger.info(
|
||||
"[MLX-DiT] Converted %d decoder parameters to MLX format.", len(weights)
|
||||
)
|
||||
return weights
|
||||
|
||||
|
||||
def convert_and_load(
|
||||
pytorch_model,
|
||||
mlx_decoder: "MLXDiTDecoder",
|
||||
) -> None:
|
||||
"""Convert PyTorch decoder weights and load them into an MLX decoder.
|
||||
|
||||
Args:
|
||||
pytorch_model: The full AceStepConditionGenerationModel (PyTorch).
|
||||
mlx_decoder: An instance of ``MLXDiTDecoder`` (already constructed).
|
||||
"""
|
||||
import mlx.core as mx
|
||||
|
||||
weights = convert_decoder_weights(pytorch_model)
|
||||
mlx_decoder.load_weights(weights)
|
||||
mx.eval(mlx_decoder.parameters())
|
||||
logger.info("[MLX-DiT] Weights loaded and evaluated successfully.")
|
||||
213
acestep/mlx_dit/generate.py
Normal file
213
acestep/mlx_dit/generate.py
Normal file
|
|
@ -0,0 +1,213 @@
|
|||
# MLX diffusion generation loop for AceStep DiT decoder.
|
||||
#
|
||||
# Replicates the timestep scheduling and ODE/SDE stepping from
|
||||
# ``AceStepConditionGenerationModel.generate_audio`` using pure MLX arrays.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Pre-defined timestep schedules (from modeling_acestep_v15_turbo.py)
|
||||
VALID_SHIFTS = [1.0, 2.0, 3.0]
|
||||
|
||||
VALID_TIMESTEPS = [
|
||||
1.0, 0.9545454545454546, 0.9333333333333333, 0.9, 0.875,
|
||||
0.8571428571428571, 0.8333333333333334, 0.7692307692307693, 0.75,
|
||||
0.6666666666666666, 0.6428571428571429, 0.625, 0.5454545454545454,
|
||||
0.5, 0.4, 0.375, 0.3, 0.25, 0.2222222222222222, 0.125,
|
||||
]
|
||||
|
||||
SHIFT_TIMESTEPS = {
|
||||
1.0: [1.0, 0.875, 0.75, 0.625, 0.5, 0.375, 0.25, 0.125],
|
||||
2.0: [1.0, 0.9333333333333333, 0.8571428571428571, 0.7692307692307693,
|
||||
0.6666666666666666, 0.5454545454545454, 0.4, 0.2222222222222222],
|
||||
3.0: [1.0, 0.9545454545454546, 0.9, 0.8333333333333334, 0.75,
|
||||
0.6428571428571429, 0.5, 0.3],
|
||||
}
|
||||
|
||||
|
||||
def get_timestep_schedule(
|
||||
shift: float = 3.0,
|
||||
timesteps: Optional[list] = None,
|
||||
) -> List[float]:
|
||||
"""Compute the timestep schedule for diffusion sampling.
|
||||
|
||||
Replicates the logic from the turbo model's ``generate_audio`` method.
|
||||
|
||||
Args:
|
||||
shift: Diffusion timestep shift (1, 2, or 3).
|
||||
timesteps: Optional custom list of timesteps.
|
||||
|
||||
Returns:
|
||||
List of timestep values (descending, without trailing 0).
|
||||
"""
|
||||
t_schedule_list = None
|
||||
|
||||
if timesteps is not None:
|
||||
ts_list = list(timesteps)
|
||||
# Remove trailing zeros
|
||||
while ts_list and ts_list[-1] == 0:
|
||||
ts_list.pop()
|
||||
if len(ts_list) < 1:
|
||||
logger.warning("timesteps empty after removing zeros; using default shift=%s", shift)
|
||||
else:
|
||||
if len(ts_list) > 20:
|
||||
logger.warning("timesteps length=%d > 20; truncating", len(ts_list))
|
||||
ts_list = ts_list[:20]
|
||||
# Map each timestep to the nearest valid value
|
||||
mapped = [min(VALID_TIMESTEPS, key=lambda x, t=t: abs(x - t)) for t in ts_list]
|
||||
t_schedule_list = mapped
|
||||
|
||||
if t_schedule_list is None:
|
||||
original_shift = shift
|
||||
shift = min(VALID_SHIFTS, key=lambda x: abs(x - shift))
|
||||
if original_shift != shift:
|
||||
logger.warning("shift=%.2f rounded to nearest valid shift=%.1f", original_shift, shift)
|
||||
t_schedule_list = SHIFT_TIMESTEPS[shift]
|
||||
|
||||
return t_schedule_list
|
||||
|
||||
|
||||
def mlx_generate_diffusion(
|
||||
mlx_decoder,
|
||||
encoder_hidden_states_np: np.ndarray,
|
||||
context_latents_np: np.ndarray,
|
||||
src_latents_shape: Tuple[int, ...],
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
infer_method: str = "ode",
|
||||
shift: float = 3.0,
|
||||
timesteps: Optional[list] = None,
|
||||
audio_cover_strength: float = 1.0,
|
||||
encoder_hidden_states_non_cover_np: Optional[np.ndarray] = None,
|
||||
context_latents_non_cover_np: Optional[np.ndarray] = None,
|
||||
) -> Dict[str, object]:
|
||||
"""Run the complete MLX diffusion loop.
|
||||
|
||||
This is the core generation function. It accepts numpy arrays (converted
|
||||
from PyTorch tensors by the handler) and returns numpy arrays that the
|
||||
handler converts back to PyTorch.
|
||||
|
||||
Args:
|
||||
mlx_decoder: ``MLXDiTDecoder`` instance with loaded weights.
|
||||
encoder_hidden_states_np: [B, enc_L, D] from prepare_condition (numpy).
|
||||
context_latents_np: [B, T, C] from prepare_condition (numpy).
|
||||
src_latents_shape: shape tuple [B, T, 64] for noise generation.
|
||||
seed: random seed (int, list[int], or None).
|
||||
infer_method: "ode" or "sde".
|
||||
shift: timestep shift factor.
|
||||
timesteps: optional custom timestep list.
|
||||
audio_cover_strength: cover strength (0-1).
|
||||
encoder_hidden_states_non_cover_np: optional [B, enc_L, D] for non-cover.
|
||||
context_latents_non_cover_np: optional [B, T, C] for non-cover.
|
||||
|
||||
Returns:
|
||||
Dict with ``"target_latents"`` (numpy) and ``"time_costs"`` dict.
|
||||
"""
|
||||
import mlx.core as mx
|
||||
from .model import MLXCrossAttentionCache
|
||||
|
||||
time_costs = {}
|
||||
total_start = time.time()
|
||||
|
||||
# Convert numpy arrays to MLX
|
||||
enc_hs = mx.array(encoder_hidden_states_np)
|
||||
ctx = mx.array(context_latents_np)
|
||||
|
||||
enc_hs_nc = mx.array(encoder_hidden_states_non_cover_np) if encoder_hidden_states_non_cover_np is not None else None
|
||||
ctx_nc = mx.array(context_latents_non_cover_np) if context_latents_non_cover_np is not None else None
|
||||
|
||||
bsz = src_latents_shape[0]
|
||||
T = src_latents_shape[1]
|
||||
C = src_latents_shape[2]
|
||||
|
||||
# ---- Noise preparation ----
|
||||
if seed is None:
|
||||
noise = mx.random.normal((bsz, T, C))
|
||||
elif isinstance(seed, list):
|
||||
parts = []
|
||||
for s in seed:
|
||||
if s is None or s < 0:
|
||||
parts.append(mx.random.normal((1, T, C)))
|
||||
else:
|
||||
key = mx.random.key(int(s))
|
||||
parts.append(mx.random.normal((1, T, C), key=key))
|
||||
noise = mx.concatenate(parts, axis=0)
|
||||
else:
|
||||
key = mx.random.key(int(seed))
|
||||
noise = mx.random.normal((bsz, T, C), key=key)
|
||||
|
||||
# ---- Timestep schedule ----
|
||||
t_schedule_list = get_timestep_schedule(shift, timesteps)
|
||||
num_steps = len(t_schedule_list)
|
||||
|
||||
cover_steps = int(num_steps * audio_cover_strength)
|
||||
|
||||
# ---- Diffusion loop ----
|
||||
cache = MLXCrossAttentionCache()
|
||||
xt = noise
|
||||
|
||||
diff_start = time.time()
|
||||
|
||||
for step_idx in range(num_steps):
|
||||
current_t = t_schedule_list[step_idx]
|
||||
t_curr = mx.full((bsz,), current_t)
|
||||
|
||||
# Switch to non-cover conditions when appropriate
|
||||
if step_idx >= cover_steps and enc_hs_nc is not None:
|
||||
enc_hs = enc_hs_nc
|
||||
ctx = ctx_nc
|
||||
cache = MLXCrossAttentionCache()
|
||||
|
||||
vt, cache = mlx_decoder(
|
||||
hidden_states=xt,
|
||||
timestep=t_curr,
|
||||
timestep_r=t_curr,
|
||||
encoder_hidden_states=enc_hs,
|
||||
context_latents=ctx,
|
||||
cache=cache,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
# Evaluate to ensure computation is complete before next step
|
||||
mx.eval(vt)
|
||||
|
||||
# Final step: compute x0
|
||||
if step_idx == num_steps - 1:
|
||||
t_unsq = mx.expand_dims(mx.expand_dims(t_curr, axis=-1), axis=-1)
|
||||
xt = xt - vt * t_unsq
|
||||
mx.eval(xt)
|
||||
break
|
||||
|
||||
# ODE / SDE update
|
||||
next_t = t_schedule_list[step_idx + 1]
|
||||
if infer_method == "sde":
|
||||
t_unsq = mx.expand_dims(mx.expand_dims(t_curr, axis=-1), axis=-1)
|
||||
pred_clean = xt - vt * t_unsq
|
||||
# Re-noise with next timestep
|
||||
new_noise = mx.random.normal(xt.shape)
|
||||
xt = next_t * new_noise + (1.0 - next_t) * pred_clean
|
||||
else:
|
||||
# ODE Euler step: x_{t+1} = x_t - v_t * dt
|
||||
dt = current_t - next_t
|
||||
dt_arr = mx.full((bsz, 1, 1), dt)
|
||||
xt = xt - vt * dt_arr
|
||||
|
||||
mx.eval(xt)
|
||||
|
||||
diff_end = time.time()
|
||||
total_end = time.time()
|
||||
|
||||
time_costs["diffusion_time_cost"] = diff_end - diff_start
|
||||
time_costs["diffusion_per_step_time_cost"] = time_costs["diffusion_time_cost"] / max(num_steps, 1)
|
||||
time_costs["total_time_cost"] = total_end - total_start
|
||||
|
||||
# Convert result back to numpy
|
||||
result_np = np.array(xt)
|
||||
return {
|
||||
"target_latents": result_np,
|
||||
"time_costs": time_costs,
|
||||
}
|
||||
629
acestep/mlx_dit/model.py
Normal file
629
acestep/mlx_dit/model.py
Normal file
|
|
@ -0,0 +1,629 @@
|
|||
# This module re-implements the diffusion transformer decoder from
|
||||
# modeling_acestep_v15_turbo.py using pure MLX operations for optimal
|
||||
# performance on Apple Silicon.
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Utility helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _rotate_half(x: mx.array) -> mx.array:
|
||||
"""Rotate the last dimension by splitting in half and swapping with negation."""
|
||||
half = x.shape[-1] // 2
|
||||
x1 = x[..., :half]
|
||||
x2 = x[..., half:]
|
||||
return mx.concatenate([-x2, x1], axis=-1)
|
||||
|
||||
|
||||
def _apply_rotary_pos_emb(
|
||||
q: mx.array, k: mx.array, cos: mx.array, sin: mx.array
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Apply rotary position embeddings to query and key tensors.
|
||||
|
||||
Args:
|
||||
q, k: [B, n_heads, L, head_dim]
|
||||
cos, sin: [1, 1, L, head_dim]
|
||||
"""
|
||||
q_embed = (q * cos) + (_rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (_rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def _create_sliding_window_mask(
|
||||
seq_len: int, window_size: int, dtype: mx.Dtype = mx.float32
|
||||
) -> mx.array:
|
||||
"""Create a bidirectional sliding-window additive attention mask.
|
||||
|
||||
Positions within ``window_size`` of each other get ``0``; all others
|
||||
receive a large negative value (``-1e9``).
|
||||
|
||||
Returns:
|
||||
[1, 1, seq_len, seq_len]
|
||||
"""
|
||||
indices = mx.arange(seq_len)
|
||||
# diff[i, j] = |i - j|
|
||||
diff = mx.abs(indices[:, None] - indices[None, :])
|
||||
zeros = mx.zeros(diff.shape, dtype=dtype)
|
||||
neginf = mx.full(diff.shape, -1e9, dtype=dtype)
|
||||
mask = mx.where(diff <= window_size, zeros, neginf)
|
||||
return mask[None, None, :, :] # [1, 1, L, L]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rotary Position Embedding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MLXRotaryEmbedding(nn.Module):
|
||||
"""Pre-computes and caches cos/sin tables for rotary position embeddings."""
|
||||
|
||||
def __init__(self, head_dim: int, max_len: int = 32768, base: float = 1_000_000.0):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.max_len = max_len
|
||||
self.base = base
|
||||
|
||||
inv_freq = 1.0 / (
|
||||
base ** (mx.arange(0, head_dim, 2).astype(mx.float32) / head_dim)
|
||||
)
|
||||
positions = mx.arange(max_len).astype(mx.float32)
|
||||
freqs = positions[:, None] * inv_freq[None, :] # [max_len, head_dim//2]
|
||||
freqs = mx.concatenate([freqs, freqs], axis=-1) # [max_len, head_dim]
|
||||
self._cos = mx.cos(freqs) # [max_len, head_dim]
|
||||
self._sin = mx.sin(freqs) # [max_len, head_dim]
|
||||
|
||||
def __call__(self, seq_len: int) -> Tuple[mx.array, mx.array]:
|
||||
"""Return (cos, sin) each shaped [1, 1, seq_len, head_dim]."""
|
||||
cos = self._cos[:seq_len][None, None, :, :]
|
||||
sin = self._sin[:seq_len][None, None, :, :]
|
||||
return cos, sin
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cross-Attention KV Cache
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MLXCrossAttentionCache:
|
||||
"""Simple KV cache for cross-attention layers.
|
||||
|
||||
Cross-attention K/V are computed from encoder hidden states once on the
|
||||
first diffusion step and re-used for all subsequent steps.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._keys: dict[int, mx.array] = {}
|
||||
self._values: dict[int, mx.array] = {}
|
||||
self._updated: set[int] = set()
|
||||
|
||||
def update(self, key: mx.array, value: mx.array, layer_idx: int):
|
||||
self._keys[layer_idx] = key
|
||||
self._values[layer_idx] = value
|
||||
self._updated.add(layer_idx)
|
||||
|
||||
def is_updated(self, layer_idx: int) -> bool:
|
||||
return layer_idx in self._updated
|
||||
|
||||
def get(self, layer_idx: int) -> Tuple[mx.array, mx.array]:
|
||||
return self._keys[layer_idx], self._values[layer_idx]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core Layers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MLXSwiGLUMLP(nn.Module):
|
||||
"""SwiGLU MLP (equivalent to Qwen3MLP): gate * silu(gate_proj) * up_proj."""
|
||||
|
||||
def __init__(self, hidden_size: int, intermediate_size: int):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class MLXAttention(nn.Module):
|
||||
"""Multi-head attention with QK-RMSNorm for the AceStep DiT.
|
||||
|
||||
Supports both self-attention (with RoPE) and cross-attention (with
|
||||
optional KV caching).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
head_dim: int,
|
||||
rms_norm_eps: float,
|
||||
attention_bias: bool,
|
||||
layer_idx: int,
|
||||
is_cross_attention: bool = False,
|
||||
sliding_window: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_attention_heads
|
||||
self.num_kv_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.n_rep = num_attention_heads // num_key_value_heads
|
||||
self.scale = head_dim ** -0.5
|
||||
self.layer_idx = layer_idx
|
||||
self.is_cross_attention = is_cross_attention
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
self.q_proj = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=attention_bias)
|
||||
self.k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
|
||||
self.v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
|
||||
self.o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size, bias=attention_bias)
|
||||
|
||||
self.q_norm = nn.RMSNorm(head_dim, eps=rms_norm_eps)
|
||||
self.k_norm = nn.RMSNorm(head_dim, eps=rms_norm_eps)
|
||||
|
||||
@staticmethod
|
||||
def _repeat_kv(x: mx.array, n_rep: int) -> mx.array:
|
||||
"""Repeat KV heads for GQA: [B, n_kv, L, D] -> [B, n_kv*n_rep, L, D]."""
|
||||
if n_rep == 1:
|
||||
return x
|
||||
B, n_kv, L, D = x.shape
|
||||
x = mx.expand_dims(x, axis=2) # [B, n_kv, 1, L, D]
|
||||
x = mx.broadcast_to(x, (B, n_kv, n_rep, L, D))
|
||||
return x.reshape(B, n_kv * n_rep, L, D)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
position_cos_sin: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
encoder_hidden_states: Optional[mx.array] = None,
|
||||
cache: Optional[MLXCrossAttentionCache] = None,
|
||||
use_cache: bool = False,
|
||||
) -> mx.array:
|
||||
B, L, _ = hidden_states.shape
|
||||
|
||||
# Project queries (always from hidden_states)
|
||||
q = self.q_proj(hidden_states)
|
||||
q = self.q_norm(q.reshape(B, L, self.num_heads, self.head_dim))
|
||||
q = q.transpose(0, 2, 1, 3) # [B, n_heads, L, D]
|
||||
|
||||
if self.is_cross_attention and encoder_hidden_states is not None:
|
||||
# Cross-attention: K,V come from encoder
|
||||
if cache is not None and cache.is_updated(self.layer_idx):
|
||||
k, v = cache.get(self.layer_idx)
|
||||
else:
|
||||
enc_L = encoder_hidden_states.shape[1]
|
||||
k = self.k_proj(encoder_hidden_states)
|
||||
k = self.k_norm(k.reshape(B, enc_L, self.num_kv_heads, self.head_dim))
|
||||
k = k.transpose(0, 2, 1, 3)
|
||||
v = self.v_proj(encoder_hidden_states).reshape(
|
||||
B, enc_L, self.num_kv_heads, self.head_dim
|
||||
).transpose(0, 2, 1, 3)
|
||||
if cache is not None and use_cache:
|
||||
cache.update(k, v, self.layer_idx)
|
||||
else:
|
||||
# Self-attention: K,V come from hidden_states
|
||||
k = self.k_proj(hidden_states)
|
||||
k = self.k_norm(k.reshape(B, L, self.num_kv_heads, self.head_dim))
|
||||
k = k.transpose(0, 2, 1, 3)
|
||||
v = self.v_proj(hidden_states).reshape(
|
||||
B, L, self.num_kv_heads, self.head_dim
|
||||
).transpose(0, 2, 1, 3)
|
||||
|
||||
# Apply RoPE to self-attention Q,K
|
||||
if position_cos_sin is not None:
|
||||
cos, sin = position_cos_sin
|
||||
q, k = _apply_rotary_pos_emb(q, k, cos, sin)
|
||||
|
||||
# GQA: repeat KV heads to match Q heads
|
||||
k = self._repeat_kv(k, self.n_rep)
|
||||
v = self._repeat_kv(v, self.n_rep)
|
||||
|
||||
# Scaled dot-product attention
|
||||
attn_out = mx.fast.scaled_dot_product_attention(
|
||||
q, k, v, scale=self.scale, mask=attention_mask
|
||||
)
|
||||
|
||||
# Merge heads and project output: [B, n_heads, L, D] -> [B, L, hidden]
|
||||
attn_out = attn_out.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(attn_out)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DiT Layer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MLXDiTLayer(nn.Module):
|
||||
"""A single DiT transformer layer with AdaLN modulation.
|
||||
|
||||
Implements:
|
||||
1. Self-attention with adaptive layer norm (AdaLN)
|
||||
2. Cross-attention to encoder hidden states
|
||||
3. Feed-forward MLP with adaptive layer norm
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
head_dim: int,
|
||||
rms_norm_eps: float,
|
||||
attention_bias: bool,
|
||||
layer_idx: int,
|
||||
layer_type: str,
|
||||
sliding_window: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_type = layer_type
|
||||
sw = sliding_window if layer_type == "sliding_attention" else None
|
||||
|
||||
# 1. Self-attention
|
||||
self.self_attn_norm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.self_attn = MLXAttention(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
layer_idx=layer_idx,
|
||||
is_cross_attention=False,
|
||||
sliding_window=sw,
|
||||
)
|
||||
|
||||
# 2. Cross-attention
|
||||
self.cross_attn_norm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.cross_attn = MLXAttention(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
layer_idx=layer_idx,
|
||||
is_cross_attention=True,
|
||||
)
|
||||
|
||||
# 3. MLP
|
||||
self.mlp_norm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.mlp = MLXSwiGLUMLP(hidden_size, intermediate_size)
|
||||
|
||||
# AdaLN modulation table (6 values: shift/scale/gate for self-attn & MLP)
|
||||
self.scale_shift_table = mx.zeros((1, 6, hidden_size))
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
position_cos_sin: Tuple[mx.array, mx.array],
|
||||
temb: mx.array,
|
||||
self_attn_mask: Optional[mx.array],
|
||||
encoder_hidden_states: Optional[mx.array],
|
||||
encoder_attention_mask: Optional[mx.array],
|
||||
cache: Optional[MLXCrossAttentionCache] = None,
|
||||
use_cache: bool = False,
|
||||
) -> mx.array:
|
||||
# AdaLN modulation from timestep embeddings
|
||||
# scale_shift_table: [1, 6, D], temb: [B, 6, D]
|
||||
modulation = self.scale_shift_table + temb # [B, 6, D]
|
||||
parts = mx.split(modulation, 6, axis=1)
|
||||
# Each part: [B, 1, D]
|
||||
shift_msa, scale_msa, gate_msa = parts[0], parts[1], parts[2]
|
||||
c_shift_msa, c_scale_msa, c_gate_msa = parts[3], parts[4], parts[5]
|
||||
|
||||
# Step 1: Self-attention with AdaLN
|
||||
normed = self.self_attn_norm(hidden_states)
|
||||
normed = normed * (1.0 + scale_msa) + shift_msa
|
||||
attn_out = self.self_attn(
|
||||
normed,
|
||||
position_cos_sin=position_cos_sin,
|
||||
attention_mask=self_attn_mask,
|
||||
)
|
||||
hidden_states = hidden_states + attn_out * gate_msa
|
||||
|
||||
# Step 2: Cross-attention
|
||||
normed = self.cross_attn_norm(hidden_states)
|
||||
cross_out = self.cross_attn(
|
||||
normed,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
cache=cache,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
hidden_states = hidden_states + cross_out
|
||||
|
||||
# Step 3: MLP with AdaLN
|
||||
normed = self.mlp_norm(hidden_states)
|
||||
normed = normed * (1.0 + c_scale_msa) + c_shift_msa
|
||||
ff_out = self.mlp(normed)
|
||||
hidden_states = hidden_states + ff_out * c_gate_msa
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Timestep Embedding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MLXTimestepEmbedding(nn.Module):
|
||||
"""Sinusoidal timestep embedding followed by MLP."""
|
||||
|
||||
def __init__(self, in_channels: int = 256, time_embed_dim: int = 2048, scale: float = 1000.0):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.scale = scale
|
||||
|
||||
self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
|
||||
self.act1 = nn.SiLU()
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, bias=True)
|
||||
self.act2 = nn.SiLU()
|
||||
self.time_proj = nn.Linear(time_embed_dim, time_embed_dim * 6, bias=True)
|
||||
|
||||
def _sinusoidal_embedding(self, t: mx.array, dim: int, max_period: int = 10000) -> mx.array:
|
||||
"""Create sinusoidal timestep embeddings.
|
||||
|
||||
Args:
|
||||
t: 1-D array of shape [N]
|
||||
dim: embedding dimension
|
||||
Returns:
|
||||
[N, dim]
|
||||
"""
|
||||
t = t * self.scale
|
||||
half = dim // 2
|
||||
freqs = mx.exp(
|
||||
-math.log(max_period)
|
||||
* mx.arange(half).astype(mx.float32) / half
|
||||
)
|
||||
args = t[:, None].astype(mx.float32) * freqs[None, :]
|
||||
embedding = mx.concatenate([mx.cos(args), mx.sin(args)], axis=-1)
|
||||
if dim % 2:
|
||||
embedding = mx.concatenate(
|
||||
[embedding, mx.zeros_like(embedding[:, :1])], axis=-1
|
||||
)
|
||||
return embedding
|
||||
|
||||
def __call__(self, t: mx.array) -> Tuple[mx.array, mx.array]:
|
||||
"""
|
||||
Args:
|
||||
t: [B] timestep values
|
||||
Returns:
|
||||
temb: [B, D]
|
||||
timestep_proj: [B, 6, D]
|
||||
"""
|
||||
t_freq = self._sinusoidal_embedding(t, self.in_channels)
|
||||
temb = self.linear_1(t_freq.astype(t.dtype))
|
||||
temb = self.act1(temb)
|
||||
temb = self.linear_2(temb)
|
||||
proj = self.time_proj(self.act2(temb)) # [B, D*6]
|
||||
timestep_proj = proj.reshape(proj.shape[0], 6, -1) # [B, 6, D]
|
||||
return temb, timestep_proj
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Full DiT Decoder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MLXDiTDecoder(nn.Module):
|
||||
"""Native MLX implementation of AceStepDiTModel (the diffusion transformer decoder).
|
||||
|
||||
Mirrors the PyTorch ``AceStepDiTModel`` class exactly:
|
||||
- Patch-based input projection (Conv1d)
|
||||
- Timestep conditioning via dual TimestepEmbedding
|
||||
- N DiT transformer layers with self/cross-attention and AdaLN
|
||||
- Patch-based output projection (ConvTranspose1d)
|
||||
- Adaptive output layer norm
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 2048,
|
||||
intermediate_size: int = 6144,
|
||||
num_hidden_layers: int = 24,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int = 8,
|
||||
head_dim: int = 128,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
attention_bias: bool = False,
|
||||
in_channels: int = 192,
|
||||
audio_acoustic_hidden_dim: int = 64,
|
||||
patch_size: int = 2,
|
||||
sliding_window: int = 128,
|
||||
layer_types: Optional[list] = None,
|
||||
rope_theta: float = 1_000_000.0,
|
||||
max_position_embeddings: int = 32768,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.patch_size = patch_size
|
||||
inner_dim = hidden_size
|
||||
|
||||
if layer_types is None:
|
||||
layer_types = [
|
||||
"sliding_attention" if bool((i + 1) % 2) else "full_attention"
|
||||
for i in range(num_hidden_layers)
|
||||
]
|
||||
|
||||
# Rotary position embeddings
|
||||
self.rotary_emb = MLXRotaryEmbedding(
|
||||
head_dim, max_len=max_position_embeddings, base=rope_theta
|
||||
)
|
||||
|
||||
# Input projection: Conv1d patch embedding
|
||||
# MLX Conv1d uses channels-last: [B, L, C] -> [B, L//stride, out_C]
|
||||
self.proj_in = nn.Conv1d(
|
||||
in_channels=in_channels,
|
||||
out_channels=inner_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
padding=0,
|
||||
)
|
||||
|
||||
# Timestep embeddings (two: t and t-r)
|
||||
self.time_embed = MLXTimestepEmbedding(in_channels=256, time_embed_dim=inner_dim)
|
||||
self.time_embed_r = MLXTimestepEmbedding(in_channels=256, time_embed_dim=inner_dim)
|
||||
|
||||
# Condition embedder
|
||||
self.condition_embedder = nn.Linear(inner_dim, inner_dim, bias=True)
|
||||
|
||||
# Transformer layers
|
||||
self.layers = [
|
||||
MLXDiTLayer(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
layer_idx=i,
|
||||
layer_type=layer_types[i],
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
for i in range(num_hidden_layers)
|
||||
]
|
||||
|
||||
# Output
|
||||
self.norm_out = nn.RMSNorm(inner_dim, eps=rms_norm_eps)
|
||||
self.proj_out = nn.ConvTranspose1d(
|
||||
in_channels=inner_dim,
|
||||
out_channels=audio_acoustic_hidden_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
padding=0,
|
||||
)
|
||||
|
||||
# Output adaptive layer norm modulation (2 values: shift, scale)
|
||||
self.scale_shift_table = mx.zeros((1, 2, inner_dim))
|
||||
|
||||
# Pre-compute sliding window mask (will be set on first forward)
|
||||
self._sliding_masks: dict[int, mx.array] = {}
|
||||
self._sliding_window = sliding_window
|
||||
self._layer_types = layer_types
|
||||
|
||||
def _get_sliding_mask(self, seq_len: int, dtype: mx.Dtype) -> mx.array:
|
||||
if seq_len not in self._sliding_masks:
|
||||
self._sliding_masks[seq_len] = _create_sliding_window_mask(
|
||||
seq_len, self._sliding_window, dtype
|
||||
)
|
||||
return self._sliding_masks[seq_len]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
timestep: mx.array,
|
||||
timestep_r: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
context_latents: mx.array,
|
||||
cache: Optional[MLXCrossAttentionCache] = None,
|
||||
use_cache: bool = True,
|
||||
) -> Tuple[mx.array, Optional[MLXCrossAttentionCache]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states: noisy latents [B, T, 64]
|
||||
timestep: [B] current timestep
|
||||
timestep_r: [B] reference timestep
|
||||
encoder_hidden_states: [B, enc_L, D] from condition encoder
|
||||
context_latents: [B, T, C_ctx] (src_latents + chunk_masks)
|
||||
cache: cross-attention KV cache
|
||||
use_cache: whether to cache cross-attention KV
|
||||
|
||||
Returns:
|
||||
(output_hidden_states, cache)
|
||||
"""
|
||||
# Timestep embeddings
|
||||
temb_t, proj_t = self.time_embed(timestep)
|
||||
temb_r, proj_r = self.time_embed_r(timestep - timestep_r)
|
||||
temb = temb_t + temb_r # [B, D]
|
||||
timestep_proj = proj_t + proj_r # [B, 6, D]
|
||||
|
||||
# Concatenate context with hidden states: [B, T, C_ctx + 64] -> [B, T, in_channels]
|
||||
hidden_states = mx.concatenate([context_latents, hidden_states], axis=-1)
|
||||
|
||||
original_seq_len = hidden_states.shape[1]
|
||||
|
||||
# Pad to multiple of patch_size
|
||||
pad_length = 0
|
||||
if hidden_states.shape[1] % self.patch_size != 0:
|
||||
pad_length = self.patch_size - (hidden_states.shape[1] % self.patch_size)
|
||||
# Pad along time dimension
|
||||
padding = mx.zeros(
|
||||
(hidden_states.shape[0], pad_length, hidden_states.shape[2]),
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
hidden_states = mx.concatenate([hidden_states, padding], axis=1)
|
||||
|
||||
# Patch embedding: [B, T, in_ch] -> [B, T//patch, D]
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
# Project encoder states
|
||||
encoder_hidden_states = self.condition_embedder(encoder_hidden_states)
|
||||
|
||||
seq_len = hidden_states.shape[1]
|
||||
dtype = hidden_states.dtype
|
||||
|
||||
# Position embeddings (RoPE)
|
||||
cos, sin = self.rotary_emb(seq_len)
|
||||
|
||||
# Attention masks
|
||||
# Self-attention: full layers get None; sliding layers get windowed mask
|
||||
# Cross-attention: always None (no masking)
|
||||
sliding_mask = None
|
||||
has_sliding = any(lt == "sliding_attention" for lt in self._layer_types)
|
||||
if has_sliding:
|
||||
sliding_mask = self._get_sliding_mask(seq_len, dtype)
|
||||
|
||||
# Process through transformer layers
|
||||
for layer in self.layers:
|
||||
self_attn_mask = sliding_mask if layer.layer_type == "sliding_attention" else None
|
||||
hidden_states = layer(
|
||||
hidden_states,
|
||||
position_cos_sin=(cos, sin),
|
||||
temb=timestep_proj,
|
||||
self_attn_mask=self_attn_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=None,
|
||||
cache=cache,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
# Output adaptive layer norm
|
||||
shift, scale = mx.split(
|
||||
self.scale_shift_table + mx.expand_dims(temb, axis=1), 2, axis=1
|
||||
)
|
||||
hidden_states = self.norm_out(hidden_states) * (1.0 + scale) + shift
|
||||
|
||||
# De-patchify: [B, T//patch, D] -> [B, T, out_channels]
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# Crop back to original sequence length
|
||||
hidden_states = hidden_states[:, :original_seq_len, :]
|
||||
|
||||
return hidden_states, cache
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config) -> "MLXDiTDecoder":
|
||||
"""Construct from an AceStepConfig (transformers PretrainedConfig)."""
|
||||
return cls(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
num_hidden_layers=config.num_hidden_layers,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
num_key_value_heads=config.num_key_value_heads,
|
||||
head_dim=getattr(config, "head_dim", config.hidden_size // config.num_attention_heads),
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
attention_bias=config.attention_bias,
|
||||
in_channels=config.in_channels,
|
||||
audio_acoustic_hidden_dim=config.audio_acoustic_hidden_dim,
|
||||
patch_size=config.patch_size,
|
||||
sliding_window=config.sliding_window if config.sliding_window else 128,
|
||||
layer_types=config.layer_types,
|
||||
rope_theta=config.rope_theta,
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
)
|
||||
33
acestep/mlx_vae/__init__.py
Normal file
33
acestep/mlx_vae/__init__.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
# Native MLX implementation of the Oobleck VAE (AutoencoderOobleck) for Apple Silicon.
|
||||
# Provides pure MLX encode/decode with graceful fallback to PyTorch.
|
||||
|
||||
import logging
|
||||
import platform
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_mlx_available() -> bool:
|
||||
"""Check if MLX is available on this platform (macOS + Apple Silicon)."""
|
||||
if platform.system() != "Darwin":
|
||||
return False
|
||||
try:
|
||||
import mlx.core as mx
|
||||
import mlx.nn
|
||||
# Verify we can actually create arrays (Metal backend works)
|
||||
_ = mx.array([1.0])
|
||||
mx.eval(_)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
_MLX_AVAILABLE = None
|
||||
|
||||
|
||||
def mlx_available() -> bool:
|
||||
"""Cached check for MLX availability."""
|
||||
global _MLX_AVAILABLE
|
||||
if _MLX_AVAILABLE is None:
|
||||
_MLX_AVAILABLE = is_mlx_available()
|
||||
return _MLX_AVAILABLE
|
||||
150
acestep/mlx_vae/convert.py
Normal file
150
acestep/mlx_vae/convert.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
# Weight conversion from PyTorch AutoencoderOobleck to native MLX format.
|
||||
#
|
||||
# Handles:
|
||||
# - weight_norm fusion: weight_g + weight_v → fused weight
|
||||
# - Conv1d axis swap: PT [out, in, K] → MLX [out, K, in]
|
||||
# - ConvTranspose1d: PT [in, out, K] → MLX [out, K, in]
|
||||
# - Snake1d parameters: PT [1, C, 1] → MLX [C]
|
||||
# - Bias: no change
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _fuse_weight_norm(
|
||||
weight_g: np.ndarray,
|
||||
weight_v: np.ndarray,
|
||||
eps: float = 1e-9,
|
||||
) -> np.ndarray:
|
||||
"""Fuse weight_norm parameters into a single weight tensor.
|
||||
|
||||
weight_norm decomposes ``w = g * v / ||v||`` where:
|
||||
weight_g: per-output-channel scale [out, 1, 1] (Conv1d)
|
||||
or [in, 1, 1] for ConvTranspose1d
|
||||
weight_v: direction tensor with same shape as the original weight
|
||||
|
||||
Returns the fused weight in the *original PyTorch shape* (before axis swap).
|
||||
"""
|
||||
v_flat = weight_v.reshape(weight_v.shape[0], -1)
|
||||
norm = np.linalg.norm(v_flat, axis=1).reshape(weight_g.shape)
|
||||
return weight_g * weight_v / (norm + eps)
|
||||
|
||||
|
||||
def convert_vae_weights(
|
||||
pytorch_vae,
|
||||
) -> List[Tuple[str, "mx.array"]]:
|
||||
"""Convert PyTorch AutoencoderOobleck weights to MLX format.
|
||||
|
||||
The function extracts the state dict from ``pytorch_vae`` and converts
|
||||
each parameter to the format expected by ``MLXAutoEncoderOobleck``,
|
||||
handling weight_norm fusion, Conv axis reordering, and Snake1d
|
||||
parameter reshaping.
|
||||
|
||||
Args:
|
||||
pytorch_vae: A ``diffusers.AutoencoderOobleck`` instance.
|
||||
|
||||
Returns:
|
||||
List of ``(param_name, mx.array)`` pairs for ``model.load_weights()``.
|
||||
"""
|
||||
import mlx.core as mx
|
||||
|
||||
state_dict = pytorch_vae.state_dict()
|
||||
weights: List[Tuple[str, "mx.array"]] = []
|
||||
processed: set = set()
|
||||
|
||||
# Sort keys so we process weight_v / weight_g pairs together
|
||||
all_keys = sorted(state_dict.keys())
|
||||
|
||||
for key in all_keys:
|
||||
if key in processed:
|
||||
continue
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 1) weight_norm fusion: *_weight_g + *_weight_v → *.weight
|
||||
# ------------------------------------------------------------------
|
||||
if key.endswith(".weight_g"):
|
||||
# The companion weight_v key
|
||||
base = key[: -len(".weight_g")]
|
||||
v_key = base + ".weight_v"
|
||||
|
||||
if v_key not in state_dict:
|
||||
logger.warning(
|
||||
"[MLX-VAE] weight_g without weight_v: %s — skipping", key
|
||||
)
|
||||
processed.add(key)
|
||||
continue
|
||||
|
||||
g = state_dict[key].detach().cpu().float().numpy()
|
||||
v = state_dict[v_key].detach().cpu().float().numpy()
|
||||
w = _fuse_weight_norm(g, v)
|
||||
|
||||
# Determine layer type and swap axes
|
||||
if "conv_t1" in base:
|
||||
# ConvTranspose1d: PT [in, out, K] → MLX [out, K, in]
|
||||
w = w.transpose(1, 2, 0)
|
||||
else:
|
||||
# Conv1d: PT [out, in, K] → MLX [out, K, in]
|
||||
w = w.swapaxes(1, 2)
|
||||
|
||||
new_key = base + ".weight"
|
||||
weights.append((new_key, mx.array(w)))
|
||||
processed.add(key)
|
||||
processed.add(v_key)
|
||||
continue
|
||||
|
||||
if key.endswith(".weight_v"):
|
||||
# Will be / was handled together with weight_g above
|
||||
continue
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 2) Snake1d parameters: PT [1, C, 1] → MLX [C]
|
||||
# ------------------------------------------------------------------
|
||||
if key.endswith(".alpha") or key.endswith(".beta"):
|
||||
val = state_dict[key].detach().cpu().float().numpy().squeeze()
|
||||
weights.append((key, mx.array(val)))
|
||||
processed.add(key)
|
||||
continue
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 3) Bias: shape [C] — no transformation needed
|
||||
# ------------------------------------------------------------------
|
||||
if key.endswith(".bias"):
|
||||
val = state_dict[key].detach().cpu().float().numpy()
|
||||
weights.append((key, mx.array(val)))
|
||||
processed.add(key)
|
||||
continue
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 4) Catch-all for any remaining parameters (unlikely in this model)
|
||||
# ------------------------------------------------------------------
|
||||
val = state_dict[key].detach().cpu().float().numpy()
|
||||
logger.debug("[MLX-VAE] Pass-through key: %s shape=%s", key, val.shape)
|
||||
weights.append((key, mx.array(val)))
|
||||
processed.add(key)
|
||||
|
||||
logger.info(
|
||||
"[MLX-VAE] Converted %d parameters to MLX format.", len(weights)
|
||||
)
|
||||
return weights
|
||||
|
||||
|
||||
def convert_and_load(
|
||||
pytorch_vae,
|
||||
mlx_vae: "MLXAutoEncoderOobleck",
|
||||
) -> None:
|
||||
"""Convert PyTorch VAE weights and load them into an MLX VAE.
|
||||
|
||||
Args:
|
||||
pytorch_vae: ``diffusers.AutoencoderOobleck`` instance (PyTorch).
|
||||
mlx_vae: An ``MLXAutoEncoderOobleck`` instance (already constructed).
|
||||
"""
|
||||
import mlx.core as mx
|
||||
|
||||
weights = convert_vae_weights(pytorch_vae)
|
||||
mlx_vae.load_weights(weights)
|
||||
mx.eval(mlx_vae.parameters())
|
||||
logger.info("[MLX-VAE] Weights loaded and evaluated successfully.")
|
||||
336
acestep/mlx_vae/model.py
Normal file
336
acestep/mlx_vae/model.py
Normal file
|
|
@ -0,0 +1,336 @@
|
|||
# Pure MLX re-implementation of diffusers' AutoencoderOobleck for Apple Silicon.
|
||||
#
|
||||
# Architecture mirrors the PyTorch version exactly:
|
||||
# Snake1d -> OobleckResidualUnit -> EncoderBlock / DecoderBlock
|
||||
# -> OobleckEncoder / OobleckDecoder -> MLXAutoEncoderOobleck
|
||||
#
|
||||
# All operations use MLX channels-last (NLC) convention internally.
|
||||
# The public encode/decode API accepts and returns NLC arrays.
|
||||
|
||||
import math
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Snake1d Activation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MLXSnake1d(nn.Module):
|
||||
"""Snake activation: x + (1/beta) * sin(alpha * x)^2.
|
||||
|
||||
Parameters ``alpha`` and ``beta`` are stored as 1-D vectors of shape [C]
|
||||
and broadcast over (B, L) automatically. When ``logscale=True`` (default)
|
||||
the actual scale is ``exp(alpha)`` / ``exp(beta)``.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_dim: int, logscale: bool = True):
|
||||
super().__init__()
|
||||
self.alpha = mx.zeros(hidden_dim)
|
||||
self.beta = mx.zeros(hidden_dim)
|
||||
self.logscale = logscale
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# x: [B, L, C] (NLC)
|
||||
# NOTE: Upcast to float32 for exp/sin/power to prevent overflow with float16
|
||||
# weights (exp overflows float16 at alpha > ~11). This is only a problem
|
||||
# if the weights are in float16. The surrounding
|
||||
# Conv1d layers still run in the caller's dtype (float16) for speed.
|
||||
|
||||
# This is the original code that works with float16 weights, if we end up needing to
|
||||
# use float16 weights. please use this code instead
|
||||
# alpha = mx.exp(self.alpha.astype(mx.float32)) if self.logscale else self.alpha
|
||||
# beta = mx.exp(self.beta.astype(mx.float32)) if self.logscale else self.beta
|
||||
# x_f32 = x.astype(mx.float32)
|
||||
# result = x_f32 + mx.reciprocal(beta + 1e-9) * mx.power(mx.sin(alpha * x_f32), 2)
|
||||
# return result.astype(x.dtype)
|
||||
alpha = mx.exp(self.alpha) if self.logscale else self.alpha
|
||||
beta = mx.exp(self.beta) if self.logscale else self.beta
|
||||
# All ops broadcast [C] over [B, L, C]
|
||||
return x + mx.reciprocal(beta + 1e-9) * mx.power(mx.sin(alpha * x), 2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Residual Unit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MLXOobleckResidualUnit(nn.Module):
|
||||
"""Two weight-normalised Conv1d layers (k=7 dilated + k=1) wrapped with
|
||||
Snake1d activations and a residual skip connection."""
|
||||
|
||||
def __init__(self, dimension: int = 16, dilation: int = 1):
|
||||
super().__init__()
|
||||
pad = ((7 - 1) * dilation) // 2
|
||||
|
||||
self.snake1 = MLXSnake1d(dimension)
|
||||
self.conv1 = nn.Conv1d(
|
||||
dimension, dimension, kernel_size=7, dilation=dilation, padding=pad
|
||||
)
|
||||
self.snake2 = MLXSnake1d(dimension)
|
||||
self.conv2 = nn.Conv1d(dimension, dimension, kernel_size=1)
|
||||
|
||||
def __call__(self, hidden_state: mx.array) -> mx.array:
|
||||
# hidden_state: [B, L, C]
|
||||
output = self.conv1(self.snake1(hidden_state))
|
||||
output = self.conv2(self.snake2(output))
|
||||
|
||||
# Safety trim (should be no-op with correct padding)
|
||||
padding = (hidden_state.shape[1] - output.shape[1]) // 2
|
||||
if padding > 0:
|
||||
hidden_state = hidden_state[:, padding:-padding, :]
|
||||
|
||||
return hidden_state + output
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Encoder / Decoder Blocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MLXOobleckEncoderBlock(nn.Module):
|
||||
"""3 residual units (dilations 1, 3, 9) -> Snake -> strided Conv1d down."""
|
||||
|
||||
def __init__(self, input_dim: int, output_dim: int, stride: int = 1):
|
||||
super().__init__()
|
||||
self.res_unit1 = MLXOobleckResidualUnit(input_dim, dilation=1)
|
||||
self.res_unit2 = MLXOobleckResidualUnit(input_dim, dilation=3)
|
||||
self.res_unit3 = MLXOobleckResidualUnit(input_dim, dilation=9)
|
||||
self.snake1 = MLXSnake1d(input_dim)
|
||||
self.conv1 = nn.Conv1d(
|
||||
input_dim,
|
||||
output_dim,
|
||||
kernel_size=2 * stride,
|
||||
stride=stride,
|
||||
padding=math.ceil(stride / 2),
|
||||
)
|
||||
|
||||
def __call__(self, hidden_state: mx.array) -> mx.array:
|
||||
hidden_state = self.res_unit1(hidden_state)
|
||||
hidden_state = self.res_unit2(hidden_state)
|
||||
hidden_state = self.snake1(self.res_unit3(hidden_state))
|
||||
hidden_state = self.conv1(hidden_state)
|
||||
return hidden_state
|
||||
|
||||
|
||||
class MLXOobleckDecoderBlock(nn.Module):
|
||||
"""Snake -> strided ConvTranspose1d up -> 3 residual units (dilations 1, 3, 9)."""
|
||||
|
||||
def __init__(self, input_dim: int, output_dim: int, stride: int = 1):
|
||||
super().__init__()
|
||||
self.snake1 = MLXSnake1d(input_dim)
|
||||
self.conv_t1 = nn.ConvTranspose1d(
|
||||
input_dim,
|
||||
output_dim,
|
||||
kernel_size=2 * stride,
|
||||
stride=stride,
|
||||
padding=math.ceil(stride / 2),
|
||||
)
|
||||
self.res_unit1 = MLXOobleckResidualUnit(output_dim, dilation=1)
|
||||
self.res_unit2 = MLXOobleckResidualUnit(output_dim, dilation=3)
|
||||
self.res_unit3 = MLXOobleckResidualUnit(output_dim, dilation=9)
|
||||
|
||||
def __call__(self, hidden_state: mx.array) -> mx.array:
|
||||
hidden_state = self.snake1(hidden_state)
|
||||
hidden_state = self.conv_t1(hidden_state)
|
||||
hidden_state = self.res_unit1(hidden_state)
|
||||
hidden_state = self.res_unit2(hidden_state)
|
||||
hidden_state = self.res_unit3(hidden_state)
|
||||
return hidden_state
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Encoder / Decoder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MLXOobleckEncoder(nn.Module):
|
||||
"""Oobleck Encoder: Conv1d -> N encoder blocks -> Snake -> Conv1d."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_hidden_size: int,
|
||||
audio_channels: int,
|
||||
downsampling_ratios: List[int],
|
||||
channel_multiples: List[int],
|
||||
):
|
||||
super().__init__()
|
||||
strides = downsampling_ratios
|
||||
cm = [1] + list(channel_multiples)
|
||||
|
||||
self.conv1 = nn.Conv1d(
|
||||
audio_channels, encoder_hidden_size, kernel_size=7, padding=3
|
||||
)
|
||||
|
||||
self.block = []
|
||||
for i, stride in enumerate(strides):
|
||||
self.block.append(
|
||||
MLXOobleckEncoderBlock(
|
||||
input_dim=encoder_hidden_size * cm[i],
|
||||
output_dim=encoder_hidden_size * cm[i + 1],
|
||||
stride=stride,
|
||||
)
|
||||
)
|
||||
|
||||
d_model = encoder_hidden_size * cm[-1]
|
||||
self.snake1 = MLXSnake1d(d_model)
|
||||
self.conv2 = nn.Conv1d(d_model, encoder_hidden_size, kernel_size=3, padding=1)
|
||||
|
||||
def __call__(self, hidden_state: mx.array) -> mx.array:
|
||||
hidden_state = self.conv1(hidden_state)
|
||||
for module in self.block:
|
||||
hidden_state = module(hidden_state)
|
||||
hidden_state = self.snake1(hidden_state)
|
||||
hidden_state = self.conv2(hidden_state)
|
||||
return hidden_state
|
||||
|
||||
|
||||
class MLXOobleckDecoder(nn.Module):
|
||||
"""Oobleck Decoder: Conv1d -> N decoder blocks -> Snake -> Conv1d."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
input_channels: int,
|
||||
audio_channels: int,
|
||||
upsampling_ratios: List[int],
|
||||
channel_multiples: List[int],
|
||||
):
|
||||
super().__init__()
|
||||
strides = upsampling_ratios
|
||||
cm = [1] + list(channel_multiples)
|
||||
|
||||
self.conv1 = nn.Conv1d(
|
||||
input_channels, channels * cm[-1], kernel_size=7, padding=3
|
||||
)
|
||||
|
||||
self.block = []
|
||||
for i, stride in enumerate(strides):
|
||||
self.block.append(
|
||||
MLXOobleckDecoderBlock(
|
||||
input_dim=channels * cm[len(strides) - i],
|
||||
output_dim=channels * cm[len(strides) - i - 1],
|
||||
stride=stride,
|
||||
)
|
||||
)
|
||||
|
||||
self.snake1 = MLXSnake1d(channels)
|
||||
self.conv2 = nn.Conv1d(
|
||||
channels, audio_channels, kernel_size=7, padding=3, bias=False
|
||||
)
|
||||
|
||||
def __call__(self, hidden_state: mx.array) -> mx.array:
|
||||
hidden_state = self.conv1(hidden_state)
|
||||
for layer in self.block:
|
||||
hidden_state = layer(hidden_state)
|
||||
hidden_state = self.snake1(hidden_state)
|
||||
hidden_state = self.conv2(hidden_state)
|
||||
return hidden_state
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Full VAE
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MLXAutoEncoderOobleck(nn.Module):
|
||||
"""Pure-MLX re-implementation of ``diffusers.AutoencoderOobleck``.
|
||||
|
||||
Default configuration matches the Stable Audio / ACE-Step VAE:
|
||||
encoder_hidden_size = 128
|
||||
downsampling_ratios = [2, 4, 4, 8, 8] (hop_length = 2048)
|
||||
channel_multiples = [1, 2, 4, 8, 16]
|
||||
decoder_channels = 128
|
||||
decoder_input_channels = 64 (latent dim)
|
||||
audio_channels = 2 (stereo)
|
||||
|
||||
Data flows in NLC (batch, length, channels) format throughout.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_hidden_size: int = 128,
|
||||
downsampling_ratios: Optional[List[int]] = None,
|
||||
channel_multiples: Optional[List[int]] = None,
|
||||
decoder_channels: int = 128,
|
||||
decoder_input_channels: int = 64,
|
||||
audio_channels: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
if downsampling_ratios is None:
|
||||
downsampling_ratios = [2, 4, 4, 8, 8]
|
||||
if channel_multiples is None:
|
||||
channel_multiples = [1, 2, 4, 8, 16]
|
||||
|
||||
self.encoder_hidden_size = encoder_hidden_size
|
||||
self.decoder_input_channels = decoder_input_channels
|
||||
|
||||
self.encoder = MLXOobleckEncoder(
|
||||
encoder_hidden_size=encoder_hidden_size,
|
||||
audio_channels=audio_channels,
|
||||
downsampling_ratios=downsampling_ratios,
|
||||
channel_multiples=channel_multiples,
|
||||
)
|
||||
self.decoder = MLXOobleckDecoder(
|
||||
channels=decoder_channels,
|
||||
input_channels=decoder_input_channels,
|
||||
audio_channels=audio_channels,
|
||||
upsampling_ratios=downsampling_ratios[::-1],
|
||||
channel_multiples=channel_multiples,
|
||||
)
|
||||
|
||||
# -- public API ---------------------------------------------------------
|
||||
|
||||
def encode_and_sample(self, audio_nlc: mx.array) -> mx.array:
|
||||
"""Encode audio -> sample latent.
|
||||
|
||||
Args:
|
||||
audio_nlc: [B, L_audio, C_audio] in NLC format.
|
||||
|
||||
Returns:
|
||||
z: [B, L_latent, C_latent] sampled latent.
|
||||
"""
|
||||
h = self.encoder(audio_nlc) # [B, L', encoder_hidden_size]
|
||||
|
||||
# Diagonal Gaussian: split into mean + log-scale
|
||||
mean, scale = mx.split(h, 2, axis=-1)
|
||||
|
||||
# softplus(scale) + epsilon (numerically stable)
|
||||
std = mx.where(scale > 20.0, scale, mx.log(1.0 + mx.exp(scale))) + 1e-4
|
||||
|
||||
noise = mx.random.normal(mean.shape)
|
||||
z = mean + std * noise
|
||||
return z
|
||||
|
||||
def encode_mean(self, audio_nlc: mx.array) -> mx.array:
|
||||
"""Encode audio -> return mean (no sampling noise)."""
|
||||
h = self.encoder(audio_nlc)
|
||||
mean, _scale = mx.split(h, 2, axis=-1)
|
||||
return mean
|
||||
|
||||
def decode(self, latents_nlc: mx.array) -> mx.array:
|
||||
"""Decode latents -> audio.
|
||||
|
||||
Args:
|
||||
latents_nlc: [B, L_latent, C_latent] in NLC format.
|
||||
|
||||
Returns:
|
||||
audio: [B, L_audio, C_audio] in NLC format.
|
||||
"""
|
||||
return self.decoder(latents_nlc)
|
||||
|
||||
# -- construction helpers -----------------------------------------------
|
||||
|
||||
@classmethod
|
||||
def from_pytorch_config(cls, pt_vae) -> "MLXAutoEncoderOobleck":
|
||||
"""Construct from a PyTorch ``AutoencoderOobleck`` instance's config."""
|
||||
cfg = pt_vae.config
|
||||
return cls(
|
||||
encoder_hidden_size=cfg.encoder_hidden_size,
|
||||
downsampling_ratios=list(cfg.downsampling_ratios),
|
||||
channel_multiples=list(cfg.channel_multiples),
|
||||
decoder_channels=cfg.decoder_channels,
|
||||
decoder_input_channels=cfg.decoder_input_channels,
|
||||
audio_channels=cfg.audio_channels,
|
||||
)
|
||||
|
|
@ -8,6 +8,8 @@ with intelligent fallback between download sources.
|
|||
|
||||
import os
|
||||
import sys
|
||||
import hashlib
|
||||
import shutil
|
||||
import argparse
|
||||
from typing import Optional, List, Dict, Tuple
|
||||
from pathlib import Path
|
||||
|
|
@ -15,6 +17,118 @@ from pathlib import Path
|
|||
from loguru import logger
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Model Code File Sync (GitHub repo -> checkpoint directories)
|
||||
# =============================================================================
|
||||
|
||||
# Mapping from checkpoint directory name to source model variant in acestep/models/
|
||||
_CHECKPOINT_TO_VARIANT: Dict[str, str] = {
|
||||
"acestep-v15-turbo": "turbo",
|
||||
"acestep-v15-sft": "sft",
|
||||
"acestep-v15-base": "base",
|
||||
# SFT variants (base-SFT uses the same model code as SFT)
|
||||
"acestep-v15-base-sft-fix-inst": "sft",
|
||||
# Turbo variants all share the turbo model code
|
||||
"acestep-v15-turbo-shift1": "turbo",
|
||||
"acestep-v15-turbo-shift3": "turbo",
|
||||
"acestep-v15-turbo-continuous": "turbo",
|
||||
"acestep-v15-turbo-fix-inst-shift3": "turbo",
|
||||
"acestep-v15-turbo-fix-inst-shift-continous": "turbo",
|
||||
"acestep-v15-turbo-fix-inst-shift-dynamic": "turbo",
|
||||
"acestep-v15-turbo-rl": "turbo",
|
||||
}
|
||||
|
||||
|
||||
def _get_models_source_dir() -> Path:
|
||||
"""Get the acestep/models/ directory (authoritative source for model code)."""
|
||||
return Path(__file__).resolve().parent / "models"
|
||||
|
||||
|
||||
def _file_hash(filepath: Path) -> str:
|
||||
"""Compute SHA-256 hash of a file's contents."""
|
||||
h = hashlib.sha256()
|
||||
with open(filepath, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(8192), b""):
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
def _check_code_mismatch(model_name: str, checkpoints_dir) -> List[str]:
|
||||
"""
|
||||
Compare .py files in acestep/models/{variant}/ with those in the checkpoint directory.
|
||||
|
||||
Args:
|
||||
model_name: Checkpoint directory name (e.g. "acestep-v15-turbo")
|
||||
checkpoints_dir: Path to the checkpoints root directory
|
||||
|
||||
Returns:
|
||||
List of filenames that differ (empty list if all match or model_name is unknown)
|
||||
"""
|
||||
variant = _CHECKPOINT_TO_VARIANT.get(model_name)
|
||||
if variant is None:
|
||||
return []
|
||||
|
||||
source_dir = _get_models_source_dir() / variant
|
||||
if not source_dir.exists():
|
||||
return []
|
||||
|
||||
if isinstance(checkpoints_dir, str):
|
||||
checkpoints_dir = Path(checkpoints_dir)
|
||||
target_dir = checkpoints_dir / model_name
|
||||
|
||||
mismatched = []
|
||||
for src_file in source_dir.glob("*.py"):
|
||||
if src_file.name == "__init__.py":
|
||||
continue
|
||||
dst_file = target_dir / src_file.name
|
||||
if not dst_file.exists():
|
||||
mismatched.append(src_file.name)
|
||||
elif _file_hash(src_file) != _file_hash(dst_file):
|
||||
mismatched.append(src_file.name)
|
||||
|
||||
return mismatched
|
||||
|
||||
|
||||
def _sync_model_code_files(model_name: str, checkpoints_dir) -> List[str]:
|
||||
"""
|
||||
Copy .py files from acestep/models/{variant}/ into the checkpoint directory,
|
||||
overwriting the HuggingFace-downloaded versions.
|
||||
|
||||
Args:
|
||||
model_name: Checkpoint directory name (e.g. "acestep-v15-turbo")
|
||||
checkpoints_dir: Path to the checkpoints root directory
|
||||
|
||||
Returns:
|
||||
List of filenames that were synced (empty if model_name is unknown or no source)
|
||||
"""
|
||||
variant = _CHECKPOINT_TO_VARIANT.get(model_name)
|
||||
if variant is None:
|
||||
return []
|
||||
|
||||
source_dir = _get_models_source_dir() / variant
|
||||
if not source_dir.exists():
|
||||
logger.warning(f"[Model Sync] Source directory not found: {source_dir}")
|
||||
return []
|
||||
|
||||
if isinstance(checkpoints_dir, str):
|
||||
checkpoints_dir = Path(checkpoints_dir)
|
||||
target_dir = checkpoints_dir / model_name
|
||||
if not target_dir.exists():
|
||||
logger.warning(f"[Model Sync] Target directory not found: {target_dir}")
|
||||
return []
|
||||
|
||||
synced = []
|
||||
for src_file in source_dir.glob("*.py"):
|
||||
if src_file.name == "__init__.py":
|
||||
continue
|
||||
dst_file = target_dir / src_file.name
|
||||
shutil.copy2(src_file, dst_file)
|
||||
synced.append(src_file.name)
|
||||
logger.debug(f"[Model Sync] Synced {src_file.name} -> {dst_file}")
|
||||
|
||||
return synced
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Network Detection & Smart Download
|
||||
# =============================================================================
|
||||
|
|
@ -302,7 +416,15 @@ def download_main_model(
|
|||
print("This may take a while depending on your internet connection...")
|
||||
|
||||
# Use smart download with automatic fallback
|
||||
return _smart_download(MAIN_MODEL_REPO, checkpoints_dir, token, prefer_source)
|
||||
success, msg = _smart_download(MAIN_MODEL_REPO, checkpoints_dir, token, prefer_source)
|
||||
if success:
|
||||
# Sync model code files for all DiT components in the main model
|
||||
for component in MAIN_MODEL_COMPONENTS:
|
||||
if component in _CHECKPOINT_TO_VARIANT:
|
||||
synced = _sync_model_code_files(component, checkpoints_dir)
|
||||
if synced:
|
||||
logger.info(f"[Model Download] Synced code files for {component}: {synced}")
|
||||
return success, msg
|
||||
|
||||
|
||||
def download_submodel(
|
||||
|
|
@ -348,7 +470,13 @@ def download_submodel(
|
|||
print(f"Destination: {model_path}")
|
||||
|
||||
# Use smart download with automatic fallback
|
||||
return _smart_download(repo_id, model_path, token, prefer_source)
|
||||
success, msg = _smart_download(repo_id, model_path, token, prefer_source)
|
||||
if success and model_name in _CHECKPOINT_TO_VARIANT:
|
||||
# Sync model code files after successful download
|
||||
synced = _sync_model_code_files(model_name, checkpoints_dir)
|
||||
if synced:
|
||||
logger.info(f"[Model Download] Synced code files for {model_name}: {synced}")
|
||||
return success, msg
|
||||
|
||||
|
||||
def download_all_models(
|
||||
|
|
|
|||
3
acestep/models/__init__.py
Normal file
3
acestep/models/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# ACE-Step model definitions
|
||||
# These files are the authoritative source for model code.
|
||||
# They are auto-synced to checkpoint directories on startup.
|
||||
0
acestep/models/base/__init__.py
Normal file
0
acestep/models/base/__init__.py
Normal file
220
acestep/models/base/apg_guidance.py
Normal file
220
acestep/models/base/apg_guidance.py
Normal file
|
|
@ -0,0 +1,220 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class MomentumBuffer:
|
||||
|
||||
def __init__(self, momentum: float = -0.75):
|
||||
self.momentum = momentum
|
||||
self.running_average = 0
|
||||
|
||||
def update(self, update_value: torch.Tensor):
|
||||
new_average = self.momentum * self.running_average
|
||||
self.running_average = update_value + new_average
|
||||
|
||||
|
||||
def project(
|
||||
v0: torch.Tensor, # [B, C, T]
|
||||
v1: torch.Tensor, # [B, C, T]
|
||||
dims=[-1],
|
||||
):
|
||||
dtype = v0.dtype
|
||||
device_type = v0.device.type
|
||||
if device_type == "mps":
|
||||
v0, v1 = v0.cpu(), v1.cpu()
|
||||
|
||||
v0, v1 = v0.double(), v1.double()
|
||||
v1 = torch.nn.functional.normalize(v1, dim=dims)
|
||||
v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
|
||||
v0_orthogonal = v0 - v0_parallel
|
||||
return v0_parallel.to(dtype).to(device_type), v0_orthogonal.to(dtype).to(device_type)
|
||||
|
||||
|
||||
def apg_forward(
|
||||
pred_cond: torch.Tensor, # [B, C, T]
|
||||
pred_uncond: torch.Tensor, # [B, C, T]
|
||||
guidance_scale: float,
|
||||
momentum_buffer: MomentumBuffer = None,
|
||||
eta: float = 0.0,
|
||||
norm_threshold: float = 2.5,
|
||||
dims=[-1],
|
||||
):
|
||||
diff = pred_cond - pred_uncond
|
||||
if momentum_buffer is not None:
|
||||
momentum_buffer.update(diff)
|
||||
diff = momentum_buffer.running_average
|
||||
|
||||
if norm_threshold > 0:
|
||||
ones = torch.ones_like(diff)
|
||||
diff_norm = diff.norm(p=2, dim=dims, keepdim=True)
|
||||
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
|
||||
diff = diff * scale_factor
|
||||
|
||||
diff_parallel, diff_orthogonal = project(diff, pred_cond, dims)
|
||||
normalized_update = diff_orthogonal + eta * diff_parallel
|
||||
pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
|
||||
return pred_guided
|
||||
|
||||
|
||||
def cfg_forward(cond_output, uncond_output, cfg_strength):
|
||||
return uncond_output + cfg_strength * (cond_output - uncond_output)
|
||||
|
||||
|
||||
def call_cos_tensor(tensor1, tensor2):
|
||||
"""
|
||||
Calculate cosine similarity between two normalized tensors.
|
||||
|
||||
Args:
|
||||
tensor1: First tensor [B, ...]
|
||||
tensor2: Second tensor [B, ...]
|
||||
|
||||
Returns:
|
||||
Cosine similarity value [B, 1]
|
||||
"""
|
||||
tensor1 = tensor1 / torch.linalg.norm(tensor1, dim=1, keepdim=True)
|
||||
tensor2 = tensor2 / torch.linalg.norm(tensor2, dim=1, keepdim=True)
|
||||
cosvalue = torch.sum(tensor1 * tensor2, dim=1, keepdim=True)
|
||||
return cosvalue
|
||||
|
||||
|
||||
def compute_perpendicular_component(latent_diff, latent_hat_uncond):
|
||||
"""
|
||||
Decompose latent_diff into parallel and perpendicular components relative to latent_hat_uncond.
|
||||
|
||||
Args:
|
||||
latent_diff: Difference tensor [B, C, ...]
|
||||
latent_hat_uncond: Unconditional prediction tensor [B, C, ...]
|
||||
|
||||
Returns:
|
||||
projection: Parallel component
|
||||
perpendicular_component: Perpendicular component
|
||||
"""
|
||||
n, t, c = latent_diff.shape
|
||||
latent_diff = latent_diff.view(n * t, c).float()
|
||||
latent_hat_uncond = latent_hat_uncond.view(n * t, c).float()
|
||||
|
||||
if latent_diff.size() != latent_hat_uncond.size():
|
||||
raise ValueError("latent_diff and latent_hat_uncond must have the same shape [n, d].")
|
||||
|
||||
dot_product = torch.sum(latent_diff * latent_hat_uncond, dim=1, keepdim=True) # [n, 1]
|
||||
norm_square = torch.sum(latent_hat_uncond * latent_hat_uncond, dim=1, keepdim=True) # [n, 1]
|
||||
projection = (dot_product / (norm_square + 1e-8)) * latent_hat_uncond
|
||||
perpendicular_component = latent_diff - projection
|
||||
|
||||
return projection.view(n, t, c), perpendicular_component.reshape(n, t, c)
|
||||
|
||||
|
||||
def adg_forward(
|
||||
latents: torch.Tensor,
|
||||
noise_pred_cond: torch.Tensor,
|
||||
noise_pred_uncond: torch.Tensor,
|
||||
sigma: torch.Tensor,
|
||||
guidance_scale: float,
|
||||
angle_clip: float = 3.14 / 6, # pi/6 by default
|
||||
apply_norm: bool = False,
|
||||
apply_clip: bool = True,
|
||||
):
|
||||
"""
|
||||
ADG (Angle-based Dynamic Guidance) forward pass for Flow Matching.
|
||||
|
||||
In flow matching (including SD3), sigma represents the current timestep t_curr.
|
||||
The predictions are velocity fields v(x_t, t).
|
||||
|
||||
Args:
|
||||
latents: Current state x_t [N, T, d] where d=64
|
||||
noise_pred_cond: Conditional velocity prediction v_cond [N, T, d]
|
||||
noise_pred_uncond: Unconditional velocity prediction v_uncond [N, T, d]
|
||||
sigma: Current timestep t_curr (not t_prev!)
|
||||
guidance_scale: Guidance strength
|
||||
angle_clip: Maximum angle for clipping (default: pi/6)
|
||||
apply_norm: Whether to normalize the result (ADG_w_norm variant)
|
||||
apply_clip: Whether to clip the angle (ADG_wo_clip when False)
|
||||
|
||||
Returns:
|
||||
Guided velocity prediction [N, T, d]
|
||||
"""
|
||||
# Get batch size
|
||||
n = noise_pred_cond.shape[0]
|
||||
noise_pred_text = noise_pred_cond
|
||||
n, t, c = noise_pred_text.shape
|
||||
|
||||
# Ensure sigma/t has the right shape for broadcasting [N, 1, 1]
|
||||
if isinstance(sigma, (int, float)):
|
||||
sigma = torch.tensor(sigma, device=latents.device, dtype=latents.dtype)
|
||||
sigma = sigma.view(1, 1, 1).expand(n, 1, 1)
|
||||
elif torch.is_tensor(sigma):
|
||||
if sigma.numel() == 1:
|
||||
sigma = sigma.view(1, 1, 1).expand(n, 1, 1)
|
||||
elif sigma.numel() == n:
|
||||
sigma = sigma.view(n, 1, 1)
|
||||
else:
|
||||
raise ValueError(f"sigma has incompatible shape. Expected scalar or size {n}, got {sigma.shape}")
|
||||
else:
|
||||
raise TypeError(f"sigma must be a number or tensor, got {type(sigma)}")
|
||||
|
||||
# Adjust guidance weight
|
||||
weight = guidance_scale - 1
|
||||
weight = weight * (weight > 0) + 1e-3
|
||||
|
||||
latent_hat_text = latents - sigma * noise_pred_text
|
||||
latent_hat_uncond = latents - sigma * noise_pred_uncond
|
||||
latent_diff = latent_hat_text - latent_hat_uncond
|
||||
|
||||
# Calculate angle between conditional and unconditional predicted data
|
||||
latent_theta = torch.acos(
|
||||
call_cos_tensor(latent_hat_text.view(-1, c).to(float),
|
||||
latent_hat_uncond.reshape(-1, c).contiguous().to(float)))
|
||||
latent_theta_new = torch.clip(weight * latent_theta, -angle_clip, angle_clip) if apply_clip else weight * latent_theta
|
||||
proj, perp = compute_perpendicular_component(latent_diff, latent_hat_uncond)
|
||||
latent_v_new = torch.cos(latent_theta_new) * latent_hat_text
|
||||
|
||||
latent_p_new = perp * torch.sin(latent_theta_new) / torch.sin(latent_theta) * (
|
||||
torch.sin(latent_theta) > 1e-3) + perp * weight * (torch.sin(latent_theta) <= 1e-3)
|
||||
latent_new = latent_v_new + latent_p_new
|
||||
if apply_norm:
|
||||
latent_new = latent_new * torch.linalg.norm(latent_hat_text, dim=1, keepdim=True) / torch.linalg.norm(
|
||||
latent_new, dim=1, keepdim=True)
|
||||
|
||||
noise_pred = (latents - latent_new) / sigma
|
||||
noise_pred = noise_pred.reshape(n, t, c).to(latents.dtype)
|
||||
return noise_pred
|
||||
|
||||
|
||||
def adg_w_norm_forward(
|
||||
latents: torch.Tensor,
|
||||
noise_pred_cond: torch.Tensor,
|
||||
noise_pred_uncond: torch.Tensor,
|
||||
sigma: float,
|
||||
guidance_scale: float,
|
||||
angle_clip: float = 3.14 / 3,
|
||||
):
|
||||
"""
|
||||
ADG with normalization - preserves the magnitude of latent predictions.
|
||||
|
||||
This variant normalizes the final latent to maintain the same norm as the
|
||||
conditional prediction, which can help preserve image quality.
|
||||
"""
|
||||
return adg_forward(latents,
|
||||
noise_pred_cond,
|
||||
noise_pred_uncond,
|
||||
sigma,
|
||||
guidance_scale,
|
||||
angle_clip=angle_clip,
|
||||
apply_norm=True,
|
||||
apply_clip=True)
|
||||
|
||||
|
||||
def adg_wo_clip_forward(
|
||||
latents: torch.Tensor,
|
||||
noise_pred_cond: torch.Tensor,
|
||||
noise_pred_uncond: torch.Tensor,
|
||||
sigma: float,
|
||||
guidance_scale: float,
|
||||
):
|
||||
"""
|
||||
ADG without angle clipping - allows unbounded angle adjustments.
|
||||
|
||||
This variant doesn't clip the angle, which may result in more aggressive
|
||||
guidance but could be less stable.
|
||||
"""
|
||||
return adg_forward(latents, noise_pred_cond, noise_pred_uncond, sigma, guidance_scale, apply_norm=False, apply_clip=False)
|
||||
263
acestep/models/base/configuration_acestep_v15.py
Normal file
263
acestep/models/base/configuration_acestep_v15.py
Normal file
|
|
@ -0,0 +1,263 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""AceStep model configuration"""
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig, layer_type_validation
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class AceStepConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`AceStepModel`]. It is used to instantiate an
|
||||
AceStep model according to the specified arguments, defining the model architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 64003):
|
||||
Vocabulary size of the AceStep model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling the model.
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 22016):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 32):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details, check out [this
|
||||
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
|
||||
head_dim (`int`, *optional*, defaults to 128):
|
||||
The attention head dimension.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`list[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`list[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
use_sliding_window (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use sliding window attention.
|
||||
sliding_window (`int`, *optional*, defaults to 4096):
|
||||
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
||||
layer_types (`list`, *optional*):
|
||||
Attention pattern for each layer.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
|
||||
```python
|
||||
>>> from acestep.models import AceStepConfig
|
||||
|
||||
>>> # Initializing an AceStep configuration
|
||||
>>> configuration = AceStepConfig()
|
||||
|
||||
>>> # Initializing a model from the configuration
|
||||
>>> model = AceStepConditionGenerationModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "acestep"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
# Default tensor parallel plan for the base model
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=64003,
|
||||
fsq_dim=2048,
|
||||
fsq_input_levels=[8, 8, 8, 5, 5, 5],
|
||||
fsq_input_num_quantizers=1,
|
||||
hidden_size=2048,
|
||||
intermediate_size=6144,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=8,
|
||||
head_dim=128,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=32768,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=True,
|
||||
rope_theta=1000000,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
use_sliding_window=True,
|
||||
sliding_window=128,
|
||||
layer_types=None,
|
||||
attention_dropout=0.0,
|
||||
num_lyric_encoder_hidden_layers=8,
|
||||
audio_acoustic_hidden_dim=64,
|
||||
pool_window_size=5,
|
||||
text_hidden_dim=1024,
|
||||
in_channels=192,
|
||||
data_proportion=0.5,
|
||||
timestep_mu=-0.4,
|
||||
timestep_sigma=1.0,
|
||||
timbre_hidden_dim=64,
|
||||
num_timbre_encoder_hidden_layers=4,
|
||||
timbre_fix_frame=750,
|
||||
patch_size=2,
|
||||
num_attention_pooler_hidden_layers=2,
|
||||
num_audio_decoder_hidden_layers=24,
|
||||
model_version="turbo",
|
||||
**kwargs,
|
||||
):
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.sliding_window = sliding_window if self.use_sliding_window else None
|
||||
|
||||
# Text encoder configuration
|
||||
self.text_hidden_dim = text_hidden_dim
|
||||
|
||||
# Lyric encoder configuration
|
||||
self.num_lyric_encoder_hidden_layers = num_lyric_encoder_hidden_layers
|
||||
self.patch_size = patch_size
|
||||
|
||||
# Audio semantic token generation configuration
|
||||
self.audio_acoustic_hidden_dim = audio_acoustic_hidden_dim
|
||||
self.pool_window_size = pool_window_size
|
||||
self.in_channels = in_channels
|
||||
self.data_proportion = data_proportion
|
||||
self.timestep_mu = timestep_mu
|
||||
self.timestep_sigma = timestep_sigma
|
||||
|
||||
# FSQ (Finite Scalar Quantization) configuration
|
||||
self.fsq_dim = fsq_dim
|
||||
self.fsq_input_levels = fsq_input_levels
|
||||
self.fsq_input_num_quantizers = fsq_input_num_quantizers
|
||||
|
||||
# Timbre encoder configuration
|
||||
self.timbre_hidden_dim = timbre_hidden_dim
|
||||
self.num_timbre_encoder_hidden_layers = num_timbre_encoder_hidden_layers
|
||||
self.timbre_fix_frame = timbre_fix_frame
|
||||
self.num_attention_pooler_hidden_layers = num_attention_pooler_hidden_layers
|
||||
self.num_audio_decoder_hidden_layers = num_audio_decoder_hidden_layers
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
# Backward compatibility: ensure num_key_value_heads is set
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.model_version = model_version
|
||||
|
||||
# Validate rotary position embeddings parameters
|
||||
# Backward compatibility: if there is a 'type' field, move it to 'rope_type'
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
self.layer_types = layer_types
|
||||
|
||||
# Set default layer types if not specified
|
||||
if self.layer_types is None:
|
||||
self.layer_types = [
|
||||
"sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
|
||||
]
|
||||
layer_type_validation(self.layer_types)
|
||||
|
||||
super().__init__(
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["AceStepConfig"]
|
||||
2144
acestep/models/base/modeling_acestep_v15_base.py
Normal file
2144
acestep/models/base/modeling_acestep_v15_base.py
Normal file
File diff suppressed because it is too large
Load diff
0
acestep/models/sft/__init__.py
Normal file
0
acestep/models/sft/__init__.py
Normal file
220
acestep/models/sft/apg_guidance.py
Normal file
220
acestep/models/sft/apg_guidance.py
Normal file
|
|
@ -0,0 +1,220 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class MomentumBuffer:
|
||||
|
||||
def __init__(self, momentum: float = -0.75):
|
||||
self.momentum = momentum
|
||||
self.running_average = 0
|
||||
|
||||
def update(self, update_value: torch.Tensor):
|
||||
new_average = self.momentum * self.running_average
|
||||
self.running_average = update_value + new_average
|
||||
|
||||
|
||||
def project(
|
||||
v0: torch.Tensor, # [B, C, T]
|
||||
v1: torch.Tensor, # [B, C, T]
|
||||
dims=[-1],
|
||||
):
|
||||
dtype = v0.dtype
|
||||
device_type = v0.device.type
|
||||
if device_type == "mps":
|
||||
v0, v1 = v0.cpu(), v1.cpu()
|
||||
|
||||
v0, v1 = v0.double(), v1.double()
|
||||
v1 = torch.nn.functional.normalize(v1, dim=dims)
|
||||
v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
|
||||
v0_orthogonal = v0 - v0_parallel
|
||||
return v0_parallel.to(dtype).to(device_type), v0_orthogonal.to(dtype).to(device_type)
|
||||
|
||||
|
||||
def apg_forward(
|
||||
pred_cond: torch.Tensor, # [B, C, T]
|
||||
pred_uncond: torch.Tensor, # [B, C, T]
|
||||
guidance_scale: float,
|
||||
momentum_buffer: MomentumBuffer = None,
|
||||
eta: float = 0.0,
|
||||
norm_threshold: float = 2.5,
|
||||
dims=[-1],
|
||||
):
|
||||
diff = pred_cond - pred_uncond
|
||||
if momentum_buffer is not None:
|
||||
momentum_buffer.update(diff)
|
||||
diff = momentum_buffer.running_average
|
||||
|
||||
if norm_threshold > 0:
|
||||
ones = torch.ones_like(diff)
|
||||
diff_norm = diff.norm(p=2, dim=dims, keepdim=True)
|
||||
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
|
||||
diff = diff * scale_factor
|
||||
|
||||
diff_parallel, diff_orthogonal = project(diff, pred_cond, dims)
|
||||
normalized_update = diff_orthogonal + eta * diff_parallel
|
||||
pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
|
||||
return pred_guided
|
||||
|
||||
|
||||
def cfg_forward(cond_output, uncond_output, cfg_strength):
|
||||
return uncond_output + cfg_strength * (cond_output - uncond_output)
|
||||
|
||||
|
||||
def call_cos_tensor(tensor1, tensor2):
|
||||
"""
|
||||
Calculate cosine similarity between two normalized tensors.
|
||||
|
||||
Args:
|
||||
tensor1: First tensor [B, ...]
|
||||
tensor2: Second tensor [B, ...]
|
||||
|
||||
Returns:
|
||||
Cosine similarity value [B, 1]
|
||||
"""
|
||||
tensor1 = tensor1 / torch.linalg.norm(tensor1, dim=1, keepdim=True)
|
||||
tensor2 = tensor2 / torch.linalg.norm(tensor2, dim=1, keepdim=True)
|
||||
cosvalue = torch.sum(tensor1 * tensor2, dim=1, keepdim=True)
|
||||
return cosvalue
|
||||
|
||||
|
||||
def compute_perpendicular_component(latent_diff, latent_hat_uncond):
|
||||
"""
|
||||
Decompose latent_diff into parallel and perpendicular components relative to latent_hat_uncond.
|
||||
|
||||
Args:
|
||||
latent_diff: Difference tensor [B, C, ...]
|
||||
latent_hat_uncond: Unconditional prediction tensor [B, C, ...]
|
||||
|
||||
Returns:
|
||||
projection: Parallel component
|
||||
perpendicular_component: Perpendicular component
|
||||
"""
|
||||
n, t, c = latent_diff.shape
|
||||
latent_diff = latent_diff.view(n * t, c).float()
|
||||
latent_hat_uncond = latent_hat_uncond.view(n * t, c).float()
|
||||
|
||||
if latent_diff.size() != latent_hat_uncond.size():
|
||||
raise ValueError("latent_diff and latent_hat_uncond must have the same shape [n, d].")
|
||||
|
||||
dot_product = torch.sum(latent_diff * latent_hat_uncond, dim=1, keepdim=True) # [n, 1]
|
||||
norm_square = torch.sum(latent_hat_uncond * latent_hat_uncond, dim=1, keepdim=True) # [n, 1]
|
||||
projection = (dot_product / (norm_square + 1e-8)) * latent_hat_uncond
|
||||
perpendicular_component = latent_diff - projection
|
||||
|
||||
return projection.view(n, t, c), perpendicular_component.reshape(n, t, c)
|
||||
|
||||
|
||||
def adg_forward(
|
||||
latents: torch.Tensor,
|
||||
noise_pred_cond: torch.Tensor,
|
||||
noise_pred_uncond: torch.Tensor,
|
||||
sigma: torch.Tensor,
|
||||
guidance_scale: float,
|
||||
angle_clip: float = 3.14 / 6, # pi/6 by default
|
||||
apply_norm: bool = False,
|
||||
apply_clip: bool = True,
|
||||
):
|
||||
"""
|
||||
ADG (Angle-based Dynamic Guidance) forward pass for Flow Matching.
|
||||
|
||||
In flow matching (including SD3), sigma represents the current timestep t_curr.
|
||||
The predictions are velocity fields v(x_t, t).
|
||||
|
||||
Args:
|
||||
latents: Current state x_t [N, T, d] where d=64
|
||||
noise_pred_cond: Conditional velocity prediction v_cond [N, T, d]
|
||||
noise_pred_uncond: Unconditional velocity prediction v_uncond [N, T, d]
|
||||
sigma: Current timestep t_curr (not t_prev!)
|
||||
guidance_scale: Guidance strength
|
||||
angle_clip: Maximum angle for clipping (default: pi/6)
|
||||
apply_norm: Whether to normalize the result (ADG_w_norm variant)
|
||||
apply_clip: Whether to clip the angle (ADG_wo_clip when False)
|
||||
|
||||
Returns:
|
||||
Guided velocity prediction [N, T, d]
|
||||
"""
|
||||
# Get batch size
|
||||
n = noise_pred_cond.shape[0]
|
||||
noise_pred_text = noise_pred_cond
|
||||
n, t, c = noise_pred_text.shape
|
||||
|
||||
# Ensure sigma/t has the right shape for broadcasting [N, 1, 1]
|
||||
if isinstance(sigma, (int, float)):
|
||||
sigma = torch.tensor(sigma, device=latents.device, dtype=latents.dtype)
|
||||
sigma = sigma.view(1, 1, 1).expand(n, 1, 1)
|
||||
elif torch.is_tensor(sigma):
|
||||
if sigma.numel() == 1:
|
||||
sigma = sigma.view(1, 1, 1).expand(n, 1, 1)
|
||||
elif sigma.numel() == n:
|
||||
sigma = sigma.view(n, 1, 1)
|
||||
else:
|
||||
raise ValueError(f"sigma has incompatible shape. Expected scalar or size {n}, got {sigma.shape}")
|
||||
else:
|
||||
raise TypeError(f"sigma must be a number or tensor, got {type(sigma)}")
|
||||
|
||||
# Adjust guidance weight
|
||||
weight = guidance_scale - 1
|
||||
weight = weight * (weight > 0) + 1e-3
|
||||
|
||||
latent_hat_text = latents - sigma * noise_pred_text
|
||||
latent_hat_uncond = latents - sigma * noise_pred_uncond
|
||||
latent_diff = latent_hat_text - latent_hat_uncond
|
||||
|
||||
# Calculate angle between conditional and unconditional predicted data
|
||||
latent_theta = torch.acos(
|
||||
call_cos_tensor(latent_hat_text.view(-1, c).to(float),
|
||||
latent_hat_uncond.reshape(-1, c).contiguous().to(float)))
|
||||
latent_theta_new = torch.clip(weight * latent_theta, -angle_clip, angle_clip) if apply_clip else weight * latent_theta
|
||||
proj, perp = compute_perpendicular_component(latent_diff, latent_hat_uncond)
|
||||
latent_v_new = torch.cos(latent_theta_new) * latent_hat_text
|
||||
|
||||
latent_p_new = perp * torch.sin(latent_theta_new) / torch.sin(latent_theta) * (
|
||||
torch.sin(latent_theta) > 1e-3) + perp * weight * (torch.sin(latent_theta) <= 1e-3)
|
||||
latent_new = latent_v_new + latent_p_new
|
||||
if apply_norm:
|
||||
latent_new = latent_new * torch.linalg.norm(latent_hat_text, dim=1, keepdim=True) / torch.linalg.norm(
|
||||
latent_new, dim=1, keepdim=True)
|
||||
|
||||
noise_pred = (latents - latent_new) / sigma
|
||||
noise_pred = noise_pred.reshape(n, t, c).to(latents.dtype)
|
||||
return noise_pred
|
||||
|
||||
|
||||
def adg_w_norm_forward(
|
||||
latents: torch.Tensor,
|
||||
noise_pred_cond: torch.Tensor,
|
||||
noise_pred_uncond: torch.Tensor,
|
||||
sigma: float,
|
||||
guidance_scale: float,
|
||||
angle_clip: float = 3.14 / 3,
|
||||
):
|
||||
"""
|
||||
ADG with normalization - preserves the magnitude of latent predictions.
|
||||
|
||||
This variant normalizes the final latent to maintain the same norm as the
|
||||
conditional prediction, which can help preserve image quality.
|
||||
"""
|
||||
return adg_forward(latents,
|
||||
noise_pred_cond,
|
||||
noise_pred_uncond,
|
||||
sigma,
|
||||
guidance_scale,
|
||||
angle_clip=angle_clip,
|
||||
apply_norm=True,
|
||||
apply_clip=True)
|
||||
|
||||
|
||||
def adg_wo_clip_forward(
|
||||
latents: torch.Tensor,
|
||||
noise_pred_cond: torch.Tensor,
|
||||
noise_pred_uncond: torch.Tensor,
|
||||
sigma: float,
|
||||
guidance_scale: float,
|
||||
):
|
||||
"""
|
||||
ADG without angle clipping - allows unbounded angle adjustments.
|
||||
|
||||
This variant doesn't clip the angle, which may result in more aggressive
|
||||
guidance but could be less stable.
|
||||
"""
|
||||
return adg_forward(latents, noise_pred_cond, noise_pred_uncond, sigma, guidance_scale, apply_norm=False, apply_clip=False)
|
||||
263
acestep/models/sft/configuration_acestep_v15.py
Normal file
263
acestep/models/sft/configuration_acestep_v15.py
Normal file
|
|
@ -0,0 +1,263 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""AceStep model configuration"""
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig, layer_type_validation
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class AceStepConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`AceStepModel`]. It is used to instantiate an
|
||||
AceStep model according to the specified arguments, defining the model architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 64003):
|
||||
Vocabulary size of the AceStep model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling the model.
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 22016):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 32):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details, check out [this
|
||||
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
|
||||
head_dim (`int`, *optional*, defaults to 128):
|
||||
The attention head dimension.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`list[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`list[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
use_sliding_window (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use sliding window attention.
|
||||
sliding_window (`int`, *optional*, defaults to 4096):
|
||||
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
||||
layer_types (`list`, *optional*):
|
||||
Attention pattern for each layer.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
|
||||
```python
|
||||
>>> from acestep.models import AceStepConfig
|
||||
|
||||
>>> # Initializing an AceStep configuration
|
||||
>>> configuration = AceStepConfig()
|
||||
|
||||
>>> # Initializing a model from the configuration
|
||||
>>> model = AceStepConditionGenerationModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "acestep"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
# Default tensor parallel plan for the base model
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=64003,
|
||||
fsq_dim=2048,
|
||||
fsq_input_levels=[8, 8, 8, 5, 5, 5],
|
||||
fsq_input_num_quantizers=1,
|
||||
hidden_size=2048,
|
||||
intermediate_size=6144,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=8,
|
||||
head_dim=128,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=32768,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=True,
|
||||
rope_theta=1000000,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
use_sliding_window=True,
|
||||
sliding_window=128,
|
||||
layer_types=None,
|
||||
attention_dropout=0.0,
|
||||
num_lyric_encoder_hidden_layers=8,
|
||||
audio_acoustic_hidden_dim=64,
|
||||
pool_window_size=5,
|
||||
text_hidden_dim=1024,
|
||||
in_channels=192,
|
||||
data_proportion=0.5,
|
||||
timestep_mu=-0.4,
|
||||
timestep_sigma=1.0,
|
||||
timbre_hidden_dim=64,
|
||||
num_timbre_encoder_hidden_layers=4,
|
||||
timbre_fix_frame=750,
|
||||
patch_size=2,
|
||||
num_attention_pooler_hidden_layers=2,
|
||||
num_audio_decoder_hidden_layers=24,
|
||||
model_version="turbo",
|
||||
**kwargs,
|
||||
):
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.sliding_window = sliding_window if self.use_sliding_window else None
|
||||
|
||||
# Text encoder configuration
|
||||
self.text_hidden_dim = text_hidden_dim
|
||||
|
||||
# Lyric encoder configuration
|
||||
self.num_lyric_encoder_hidden_layers = num_lyric_encoder_hidden_layers
|
||||
self.patch_size = patch_size
|
||||
|
||||
# Audio semantic token generation configuration
|
||||
self.audio_acoustic_hidden_dim = audio_acoustic_hidden_dim
|
||||
self.pool_window_size = pool_window_size
|
||||
self.in_channels = in_channels
|
||||
self.data_proportion = data_proportion
|
||||
self.timestep_mu = timestep_mu
|
||||
self.timestep_sigma = timestep_sigma
|
||||
|
||||
# FSQ (Finite Scalar Quantization) configuration
|
||||
self.fsq_dim = fsq_dim
|
||||
self.fsq_input_levels = fsq_input_levels
|
||||
self.fsq_input_num_quantizers = fsq_input_num_quantizers
|
||||
|
||||
# Timbre encoder configuration
|
||||
self.timbre_hidden_dim = timbre_hidden_dim
|
||||
self.num_timbre_encoder_hidden_layers = num_timbre_encoder_hidden_layers
|
||||
self.timbre_fix_frame = timbre_fix_frame
|
||||
self.num_attention_pooler_hidden_layers = num_attention_pooler_hidden_layers
|
||||
self.num_audio_decoder_hidden_layers = num_audio_decoder_hidden_layers
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
# Backward compatibility: ensure num_key_value_heads is set
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.model_version = model_version
|
||||
|
||||
# Validate rotary position embeddings parameters
|
||||
# Backward compatibility: if there is a 'type' field, move it to 'rope_type'
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
self.layer_types = layer_types
|
||||
|
||||
# Set default layer types if not specified
|
||||
if self.layer_types is None:
|
||||
self.layer_types = [
|
||||
"sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
|
||||
]
|
||||
layer_type_validation(self.layer_types)
|
||||
|
||||
super().__init__(
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["AceStepConfig"]
|
||||
2152
acestep/models/sft/modeling_acestep_v15_base.py
Normal file
2152
acestep/models/sft/modeling_acestep_v15_base.py
Normal file
File diff suppressed because it is too large
Load diff
0
acestep/models/turbo/__init__.py
Normal file
0
acestep/models/turbo/__init__.py
Normal file
263
acestep/models/turbo/configuration_acestep_v15.py
Normal file
263
acestep/models/turbo/configuration_acestep_v15.py
Normal file
|
|
@ -0,0 +1,263 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""AceStep model configuration"""
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig, layer_type_validation
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class AceStepConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`AceStepModel`]. It is used to instantiate an
|
||||
AceStep model according to the specified arguments, defining the model architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 64003):
|
||||
Vocabulary size of the AceStep model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling the model.
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 22016):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 32):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details, check out [this
|
||||
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
|
||||
head_dim (`int`, *optional*, defaults to 128):
|
||||
The attention head dimension.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`list[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`list[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
use_sliding_window (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use sliding window attention.
|
||||
sliding_window (`int`, *optional*, defaults to 4096):
|
||||
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
||||
layer_types (`list`, *optional*):
|
||||
Attention pattern for each layer.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
|
||||
```python
|
||||
>>> from acestep.models import AceStepConfig
|
||||
|
||||
>>> # Initializing an AceStep configuration
|
||||
>>> configuration = AceStepConfig()
|
||||
|
||||
>>> # Initializing a model from the configuration
|
||||
>>> model = AceStepConditionGenerationModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "acestep"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
# Default tensor parallel plan for the base model
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=64003,
|
||||
fsq_dim=2048,
|
||||
fsq_input_levels=[8, 8, 8, 5, 5, 5],
|
||||
fsq_input_num_quantizers=1,
|
||||
hidden_size=2048,
|
||||
intermediate_size=6144,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=8,
|
||||
head_dim=128,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=32768,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=True,
|
||||
rope_theta=1000000,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
use_sliding_window=True,
|
||||
sliding_window=128,
|
||||
layer_types=None,
|
||||
attention_dropout=0.0,
|
||||
num_lyric_encoder_hidden_layers=8,
|
||||
audio_acoustic_hidden_dim=64,
|
||||
pool_window_size=5,
|
||||
text_hidden_dim=1024,
|
||||
in_channels=192,
|
||||
data_proportion=0.5,
|
||||
timestep_mu=-0.4,
|
||||
timestep_sigma=1.0,
|
||||
timbre_hidden_dim=64,
|
||||
num_timbre_encoder_hidden_layers=4,
|
||||
timbre_fix_frame=750,
|
||||
patch_size=2,
|
||||
num_attention_pooler_hidden_layers=2,
|
||||
num_audio_decoder_hidden_layers=24,
|
||||
model_version="turbo",
|
||||
**kwargs,
|
||||
):
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.sliding_window = sliding_window if self.use_sliding_window else None
|
||||
|
||||
# Text encoder configuration
|
||||
self.text_hidden_dim = text_hidden_dim
|
||||
|
||||
# Lyric encoder configuration
|
||||
self.num_lyric_encoder_hidden_layers = num_lyric_encoder_hidden_layers
|
||||
self.patch_size = patch_size
|
||||
|
||||
# Audio semantic token generation configuration
|
||||
self.audio_acoustic_hidden_dim = audio_acoustic_hidden_dim
|
||||
self.pool_window_size = pool_window_size
|
||||
self.in_channels = in_channels
|
||||
self.data_proportion = data_proportion
|
||||
self.timestep_mu = timestep_mu
|
||||
self.timestep_sigma = timestep_sigma
|
||||
|
||||
# FSQ (Finite Scalar Quantization) configuration
|
||||
self.fsq_dim = fsq_dim
|
||||
self.fsq_input_levels = fsq_input_levels
|
||||
self.fsq_input_num_quantizers = fsq_input_num_quantizers
|
||||
|
||||
# Timbre encoder configuration
|
||||
self.timbre_hidden_dim = timbre_hidden_dim
|
||||
self.num_timbre_encoder_hidden_layers = num_timbre_encoder_hidden_layers
|
||||
self.timbre_fix_frame = timbre_fix_frame
|
||||
self.num_attention_pooler_hidden_layers = num_attention_pooler_hidden_layers
|
||||
self.num_audio_decoder_hidden_layers = num_audio_decoder_hidden_layers
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
# Backward compatibility: ensure num_key_value_heads is set
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.model_version = model_version
|
||||
|
||||
# Validate rotary position embeddings parameters
|
||||
# Backward compatibility: if there is a 'type' field, move it to 'rope_type'
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
self.layer_types = layer_types
|
||||
|
||||
# Set default layer types if not specified
|
||||
if self.layer_types is None:
|
||||
self.layer_types = [
|
||||
"sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
|
||||
]
|
||||
layer_type_validation(self.layer_types)
|
||||
|
||||
super().__init__(
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["AceStepConfig"]
|
||||
2156
acestep/models/turbo/modeling_acestep_v15_turbo.py
Normal file
2156
acestep/models/turbo/modeling_acestep_v15_turbo.py
Normal file
File diff suppressed because it is too large
Load diff
99
acestep/third_parts/nano-vllm/nanovllm/distributed.py
Normal file
99
acestep/third_parts/nano-vllm/nanovllm/distributed.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
"""
|
||||
Distributed utilities for nano-vllm.
|
||||
|
||||
This module provides wrapper functions for torch.distributed that gracefully
|
||||
handle single-GPU mode (world_size == 1) without requiring distributed initialization.
|
||||
"""
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
# Global flag to track if distributed is actually initialized
|
||||
_distributed_initialized = False
|
||||
|
||||
|
||||
def initialize_distributed(backend: str, init_method: str, world_size: int, rank: int) -> bool:
|
||||
"""
|
||||
Initialize distributed process group only if world_size > 1.
|
||||
|
||||
Args:
|
||||
backend: Distributed backend (e.g., "nccl" or "gloo")
|
||||
init_method: Initialization method (e.g., "tcp://127.0.0.1:2333")
|
||||
world_size: Total number of processes
|
||||
rank: Rank of current process
|
||||
|
||||
Returns:
|
||||
True if distributed was initialized, False otherwise
|
||||
"""
|
||||
global _distributed_initialized
|
||||
|
||||
if world_size == 1:
|
||||
# Single GPU mode - no distributed needed
|
||||
_distributed_initialized = False
|
||||
return False
|
||||
|
||||
# Multi-GPU mode - initialize distributed
|
||||
dist.init_process_group(backend, init_method, world_size=world_size, rank=rank)
|
||||
_distributed_initialized = True
|
||||
return True
|
||||
|
||||
|
||||
def is_initialized() -> bool:
|
||||
"""Check if distributed is initialized."""
|
||||
return _distributed_initialized
|
||||
|
||||
|
||||
def get_rank() -> int:
|
||||
"""Get current process rank. Returns 0 if distributed is not initialized."""
|
||||
if _distributed_initialized:
|
||||
return dist.get_rank()
|
||||
return 0
|
||||
|
||||
|
||||
def get_world_size() -> int:
|
||||
"""Get world size. Returns 1 if distributed is not initialized."""
|
||||
if _distributed_initialized:
|
||||
return dist.get_world_size()
|
||||
return 1
|
||||
|
||||
|
||||
def barrier():
|
||||
"""Synchronize all processes. No-op if distributed is not initialized."""
|
||||
if _distributed_initialized:
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def all_reduce(tensor, op=None):
|
||||
"""
|
||||
All-reduce operation. No-op if distributed is not initialized.
|
||||
|
||||
Args:
|
||||
tensor: Tensor to reduce
|
||||
op: Reduce operation (default: SUM)
|
||||
"""
|
||||
if _distributed_initialized:
|
||||
if op is None:
|
||||
op = dist.ReduceOp.SUM
|
||||
dist.all_reduce(tensor, op)
|
||||
|
||||
|
||||
def gather(tensor, gather_list=None, dst=0):
|
||||
"""
|
||||
Gather tensors from all processes. No-op if distributed is not initialized.
|
||||
|
||||
Args:
|
||||
tensor: Tensor to gather
|
||||
gather_list: List to gather into (only used on dst rank)
|
||||
dst: Destination rank
|
||||
"""
|
||||
if _distributed_initialized:
|
||||
dist.gather(tensor, gather_list, dst)
|
||||
|
||||
|
||||
def destroy_process_group():
|
||||
"""Destroy process group. No-op if distributed is not initialized."""
|
||||
global _distributed_initialized
|
||||
|
||||
if _distributed_initialized:
|
||||
dist.destroy_process_group()
|
||||
_distributed_initialized = False
|
||||
|
|
@ -8,6 +8,7 @@ import sys
|
|||
|
||||
from nanovllm.config import Config
|
||||
from acestep.debug_utils import debug_start, debug_end
|
||||
from nanovllm import distributed as dist_utils
|
||||
|
||||
# Debug logging - enable with NANOVLLM_DEBUG=1
|
||||
_DEBUG = os.environ.get("NANOVLLM_DEBUG", "0") == "1"
|
||||
|
|
@ -64,11 +65,15 @@ class ModelRunner:
|
|||
self.world_size = config.tensor_parallel_size
|
||||
self.rank = rank
|
||||
self.event = event
|
||||
dist_port = find_available_port()
|
||||
print(f"[debug]dist_port: {dist_port}")
|
||||
# Use gloo backend on Windows, nccl on Linux/other platforms
|
||||
backend = "gloo" if sys.platform == "win32" else "nccl"
|
||||
dist.init_process_group(backend, f"tcp://127.0.0.1:{dist_port}", world_size=self.world_size, rank=rank)
|
||||
|
||||
# Only initialize distributed if world_size > 1
|
||||
if self.world_size > 1:
|
||||
dist_port = find_available_port()
|
||||
print(f"[debug]dist_port: {dist_port}")
|
||||
# Use gloo backend on Windows, nccl on Linux/other platforms
|
||||
backend = "gloo" if sys.platform == "win32" else "nccl"
|
||||
dist_utils.initialize_distributed(backend, f"tcp://127.0.0.1:{dist_port}", world_size=self.world_size, rank=rank)
|
||||
|
||||
torch.cuda.set_device(rank)
|
||||
default_dtype = torch.get_default_dtype()
|
||||
# Use dtype instead of deprecated torch_dtype
|
||||
|
|
@ -119,9 +124,9 @@ class ModelRunner:
|
|||
if self.world_size > 1:
|
||||
if rank == 0:
|
||||
self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
|
||||
dist.barrier()
|
||||
dist_utils.barrier()
|
||||
else:
|
||||
dist.barrier()
|
||||
dist_utils.barrier()
|
||||
self.shm = SharedMemory(name="nanovllm")
|
||||
self.loop()
|
||||
|
||||
|
|
@ -163,13 +168,13 @@ class ModelRunner:
|
|||
def exit(self):
|
||||
if self.world_size > 1:
|
||||
self.shm.close()
|
||||
dist.barrier()
|
||||
dist_utils.barrier()
|
||||
if self.rank == 0:
|
||||
self.shm.unlink()
|
||||
if not self.enforce_eager:
|
||||
del self.graphs, self.graph_pool
|
||||
torch.cuda.synchronize()
|
||||
dist.destroy_process_group()
|
||||
dist_utils.destroy_process_group()
|
||||
|
||||
def loop(self):
|
||||
while True:
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import torch.nn.functional as F
|
|||
import torch.distributed as dist
|
||||
|
||||
from nanovllm.utils.context import get_context
|
||||
from nanovllm import distributed as dist_utils
|
||||
|
||||
|
||||
class VocabParallelEmbedding(nn.Module):
|
||||
|
|
@ -14,8 +15,8 @@ class VocabParallelEmbedding(nn.Module):
|
|||
embedding_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_rank = dist.get_rank()
|
||||
self.tp_size = dist.get_world_size()
|
||||
self.tp_rank = dist_utils.get_rank()
|
||||
self.tp_size = dist_utils.get_world_size()
|
||||
assert num_embeddings % self.tp_size == 0
|
||||
self.num_embeddings = num_embeddings
|
||||
self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
|
||||
|
|
@ -38,7 +39,7 @@ class VocabParallelEmbedding(nn.Module):
|
|||
y = F.embedding(x, self.weight)
|
||||
if self.tp_size > 1:
|
||||
y = mask.unsqueeze(1) * y
|
||||
dist.all_reduce(y)
|
||||
dist_utils.all_reduce(y)
|
||||
return y
|
||||
|
||||
|
||||
|
|
@ -59,8 +60,10 @@ class ParallelLMHead(VocabParallelEmbedding):
|
|||
last_indices = context.cu_seqlens_q[1:] - 1
|
||||
x = x[last_indices].contiguous()
|
||||
logits = F.linear(x, self.weight)
|
||||
# In multi-GPU mode, gather logits from all ranks and concatenate
|
||||
# In single-GPU mode (tp_size=1), skip gathering and return logits directly
|
||||
if self.tp_size > 1:
|
||||
all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
|
||||
dist.gather(logits, all_logits, 0)
|
||||
dist_utils.gather(logits, all_logits, 0)
|
||||
logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
|
||||
return logits
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ from torch import nn
|
|||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
|
||||
from nanovllm import distributed as dist_utils
|
||||
|
||||
|
||||
def divide(numerator, denominator):
|
||||
assert numerator % denominator == 0
|
||||
|
|
@ -20,8 +22,8 @@ class LinearBase(nn.Module):
|
|||
):
|
||||
super().__init__()
|
||||
self.tp_dim = tp_dim
|
||||
self.tp_rank = dist.get_rank()
|
||||
self.tp_size = dist.get_world_size()
|
||||
self.tp_rank = dist_utils.get_rank()
|
||||
self.tp_size = dist_utils.get_world_size()
|
||||
self.weight = nn.Parameter(torch.empty(output_size, input_size))
|
||||
self.weight.weight_loader = self.weight_loader
|
||||
if bias:
|
||||
|
|
@ -59,7 +61,7 @@ class ColumnParallelLinear(LinearBase):
|
|||
output_size: int,
|
||||
bias: bool = False,
|
||||
):
|
||||
tp_size = dist.get_world_size()
|
||||
tp_size = dist_utils.get_world_size()
|
||||
super().__init__(input_size, divide(output_size, tp_size), bias, 0)
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||
|
|
@ -103,7 +105,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||
total_num_kv_heads: int | None = None,
|
||||
bias: bool = False,
|
||||
):
|
||||
tp_size = dist.get_world_size()
|
||||
tp_size = dist_utils.get_world_size()
|
||||
total_num_kv_heads = total_num_kv_heads or total_num_heads
|
||||
self.head_size = head_size
|
||||
self.num_heads = divide(total_num_heads, tp_size)
|
||||
|
|
@ -136,7 +138,7 @@ class RowParallelLinear(LinearBase):
|
|||
output_size: int,
|
||||
bias: bool = False,
|
||||
):
|
||||
tp_size = dist.get_world_size()
|
||||
tp_size = dist_utils.get_world_size()
|
||||
super().__init__(divide(input_size, tp_size), output_size, bias, 1)
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||
|
|
@ -149,5 +151,5 @@ class RowParallelLinear(LinearBase):
|
|||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
|
||||
if self.tp_size > 1:
|
||||
dist.all_reduce(y)
|
||||
dist_utils.all_reduce(y)
|
||||
return y
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from nanovllm.layers.layernorm import RMSNorm
|
|||
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
|
||||
from nanovllm.layers.rotary_embedding import get_rope
|
||||
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
|
||||
from nanovllm import distributed as dist_utils
|
||||
|
||||
|
||||
class Qwen3Attention(nn.Module):
|
||||
|
|
@ -26,7 +27,7 @@ class Qwen3Attention(nn.Module):
|
|||
rope_scaling: tuple | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
tp_size = dist.get_world_size()
|
||||
tp_size = dist_utils.get_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ including dataset building, audio labeling, and training utilities.
|
|||
"""
|
||||
|
||||
from acestep.training.dataset_builder import DatasetBuilder, AudioSample
|
||||
from acestep.training.configs import LoRAConfig, TrainingConfig
|
||||
from acestep.training.configs import LoRAConfig, LoKRConfig, TrainingConfig
|
||||
from acestep.training.lora_utils import (
|
||||
inject_lora_into_dit,
|
||||
save_lora_weights,
|
||||
|
|
@ -14,6 +14,12 @@ from acestep.training.lora_utils import (
|
|||
merge_lora_weights,
|
||||
check_peft_available,
|
||||
)
|
||||
from acestep.training.lokr_utils import (
|
||||
inject_lokr_into_dit,
|
||||
save_lokr_weights,
|
||||
load_lokr_weights,
|
||||
check_lycoris_available,
|
||||
)
|
||||
from acestep.training.data_module import (
|
||||
# Preprocessed (recommended)
|
||||
PreprocessedTensorDataset,
|
||||
|
|
@ -25,7 +31,13 @@ from acestep.training.data_module import (
|
|||
collate_training_batch,
|
||||
load_dataset_from_json,
|
||||
)
|
||||
from acestep.training.trainer import LoRATrainer, PreprocessedLoRAModule, LIGHTNING_AVAILABLE
|
||||
from acestep.training.trainer import (
|
||||
LoRATrainer,
|
||||
LoKRTrainer,
|
||||
PreprocessedLoRAModule,
|
||||
PreprocessedLoKRModule,
|
||||
LIGHTNING_AVAILABLE,
|
||||
)
|
||||
|
||||
def check_lightning_available():
|
||||
"""Check if Lightning Fabric is available."""
|
||||
|
|
@ -37,6 +49,7 @@ __all__ = [
|
|||
"AudioSample",
|
||||
# Configs
|
||||
"LoRAConfig",
|
||||
"LoKRConfig",
|
||||
"TrainingConfig",
|
||||
# LoRA Utils
|
||||
"inject_lora_into_dit",
|
||||
|
|
@ -44,6 +57,11 @@ __all__ = [
|
|||
"load_lora_weights",
|
||||
"merge_lora_weights",
|
||||
"check_peft_available",
|
||||
# LoKr Utils
|
||||
"inject_lokr_into_dit",
|
||||
"save_lokr_weights",
|
||||
"load_lokr_weights",
|
||||
"check_lycoris_available",
|
||||
# Data Module (Preprocessed - Recommended)
|
||||
"PreprocessedTensorDataset",
|
||||
"PreprocessedDataModule",
|
||||
|
|
@ -55,7 +73,9 @@ __all__ = [
|
|||
"load_dataset_from_json",
|
||||
# Trainer
|
||||
"LoRATrainer",
|
||||
"LoKRTrainer",
|
||||
"PreprocessedLoRAModule",
|
||||
"PreprocessedLoKRModule",
|
||||
"check_lightning_available",
|
||||
"LIGHTNING_AVAILABLE",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ Contains dataclasses for LoRA and training configurations.
|
|||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -38,6 +38,43 @@ class LoRAConfig:
|
|||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoKRConfig:
|
||||
"""Configuration for LoKr (Low-Rank Kronecker) training."""
|
||||
|
||||
linear_dim: int = 64
|
||||
linear_alpha: int = 128
|
||||
factor: int = -1
|
||||
decompose_both: bool = False
|
||||
use_tucker: bool = False
|
||||
use_scalar: bool = False
|
||||
weight_decompose: bool = False
|
||||
target_modules: List[str] = field(default_factory=lambda: [
|
||||
"q_proj", "k_proj", "v_proj", "o_proj"
|
||||
])
|
||||
full_matrix: bool = False
|
||||
bypass_mode: bool = False
|
||||
rs_lora: bool = False
|
||||
unbalanced_factorization: bool = False
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert to dictionary for LyCORIS config."""
|
||||
return {
|
||||
"linear_dim": self.linear_dim,
|
||||
"linear_alpha": self.linear_alpha,
|
||||
"factor": self.factor,
|
||||
"decompose_both": self.decompose_both,
|
||||
"use_tucker": self.use_tucker,
|
||||
"use_scalar": self.use_scalar,
|
||||
"weight_decompose": self.weight_decompose,
|
||||
"target_modules": self.target_modules,
|
||||
"full_matrix": self.full_matrix,
|
||||
"bypass_mode": self.bypass_mode,
|
||||
"rs_lora": self.rs_lora,
|
||||
"unbalanced_factorization": self.unbalanced_factorization,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingConfig:
|
||||
"""Configuration for LoRA training process.
|
||||
|
|
@ -83,11 +120,18 @@ class TrainingConfig:
|
|||
pin_memory: bool = True
|
||||
prefetch_factor: int = 2
|
||||
persistent_workers: bool = True
|
||||
pin_memory_device: Optional[str] = None
|
||||
pin_memory_device: str = ""
|
||||
|
||||
# Logging
|
||||
log_every_n_steps: int = 10
|
||||
|
||||
|
||||
# Validation (for loss curve and best-checkpoint tracking)
|
||||
val_split: float = 0.0
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not 0.0 <= self.val_split < 1.0:
|
||||
raise ValueError("val_split must be in [0.0, 1.0).")
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
|
|
@ -110,4 +154,5 @@ class TrainingConfig:
|
|||
"persistent_workers": self.persistent_workers,
|
||||
"pin_memory_device": self.pin_memory_device,
|
||||
"log_every_n_steps": self.log_every_n_steps,
|
||||
"val_split": self.val_split,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -177,7 +177,7 @@ class PreprocessedDataModule(LightningDataModule if LIGHTNING_AVAILABLE else obj
|
|||
pin_memory: bool = True,
|
||||
prefetch_factor: int = 2,
|
||||
persistent_workers: bool = True,
|
||||
pin_memory_device: Optional[str] = None,
|
||||
pin_memory_device: str = "",
|
||||
val_split: float = 0.0,
|
||||
):
|
||||
"""Initialize the data module.
|
||||
|
|
@ -226,7 +226,7 @@ class PreprocessedDataModule(LightningDataModule if LIGHTNING_AVAILABLE else obj
|
|||
"""Create training dataloader."""
|
||||
prefetch_factor = None if self.num_workers == 0 else self.prefetch_factor
|
||||
persistent_workers = False if self.num_workers == 0 else self.persistent_workers
|
||||
pin_memory_device = self.pin_memory_device if self.pin_memory else None
|
||||
pin_memory_device = self.pin_memory_device if self.pin_memory else ""
|
||||
return DataLoader(
|
||||
self.train_dataset,
|
||||
batch_size=self.batch_size,
|
||||
|
|
@ -246,7 +246,7 @@ class PreprocessedDataModule(LightningDataModule if LIGHTNING_AVAILABLE else obj
|
|||
return None
|
||||
prefetch_factor = None if self.num_workers == 0 else self.prefetch_factor
|
||||
persistent_workers = False if self.num_workers == 0 else self.persistent_workers
|
||||
pin_memory_device = self.pin_memory_device if self.pin_memory else None
|
||||
pin_memory_device = self.pin_memory_device if self.pin_memory else ""
|
||||
return DataLoader(
|
||||
self.val_dataset,
|
||||
batch_size=self.batch_size,
|
||||
|
|
|
|||
|
|
@ -1,39 +1,62 @@
|
|||
import os
|
||||
from typing import Tuple
|
||||
|
||||
import torchaudio
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def load_lyrics_file(audio_path: str) -> Tuple[str, bool]:
|
||||
"""Load lyrics from a .txt file with the same name as the audio file."""
|
||||
def _read_text_file(path: str) -> Tuple[str, bool]:
|
||||
"""Read a text file; return (content.strip(), True) if present and non-empty."""
|
||||
if not os.path.exists(path):
|
||||
return "", False
|
||||
try:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
content = f.read().strip()
|
||||
if content:
|
||||
return content, True
|
||||
return "", False
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read {path}: {e}")
|
||||
return "", False
|
||||
|
||||
|
||||
def load_caption_file(audio_path: str) -> Tuple[str, bool]:
|
||||
"""Load caption from <basename>.caption.txt (explicit convention)."""
|
||||
base_path = os.path.splitext(audio_path)[0]
|
||||
lyrics_path = base_path + ".txt"
|
||||
caption_path = base_path + ".caption.txt"
|
||||
content, ok = _read_text_file(caption_path)
|
||||
if ok:
|
||||
logger.debug(f"Loaded caption from {caption_path}")
|
||||
return content, ok
|
||||
|
||||
if os.path.exists(lyrics_path):
|
||||
try:
|
||||
with open(lyrics_path, "r", encoding="utf-8") as f:
|
||||
lyrics_content = f.read().strip()
|
||||
|
||||
if lyrics_content:
|
||||
logger.info(f"Loaded lyrics from {lyrics_path}")
|
||||
return lyrics_content, True
|
||||
logger.warning(f"Lyrics file is empty: {lyrics_path}")
|
||||
return "", False
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read lyrics file {lyrics_path}: {e}")
|
||||
return "", False
|
||||
|
||||
def load_lyrics_file(audio_path: str) -> Tuple[str, bool]:
|
||||
"""Load lyrics from <basename>.lyrics.txt, then fallback to <basename>.txt for backward compat."""
|
||||
base_path = os.path.splitext(audio_path)[0]
|
||||
for suffix in (".lyrics.txt", ".txt"):
|
||||
path = base_path + suffix
|
||||
content, ok = _read_text_file(path)
|
||||
if ok:
|
||||
if suffix == ".lyrics.txt":
|
||||
logger.debug(f"Loaded lyrics from {path}")
|
||||
else:
|
||||
logger.debug(f"Loaded lyrics from {path} (legacy .txt)")
|
||||
return content, True
|
||||
return "", False
|
||||
|
||||
|
||||
def get_audio_duration(audio_path: str) -> int:
|
||||
"""Get the duration of an audio file in seconds."""
|
||||
# Primary: torchcodec (ships with torchaudio >=2.9, supports all ffmpeg formats)
|
||||
# Note: torchcodec is optional on ROCM/Intel platforms due to CUDA dependencies
|
||||
try:
|
||||
info = torchaudio.info(audio_path)
|
||||
return int(info.num_frames / info.sample_rate)
|
||||
from torchcodec.decoders import AudioDecoder
|
||||
decoder = AudioDecoder(audio_path)
|
||||
return int(decoder.metadata.duration_seconds)
|
||||
except ImportError:
|
||||
logger.debug("torchcodec not available (expected on ROCM/Intel platforms), using soundfile fallback")
|
||||
except Exception as e:
|
||||
logger.warning(f"torchaudio failed for {audio_path}: {e}, trying soundfile")
|
||||
logger.debug(f"torchcodec failed for {audio_path}: {e}, trying soundfile")
|
||||
# Fallback: soundfile (fast for wav/flac/ogg, works on all platforms)
|
||||
try:
|
||||
import soundfile as sf
|
||||
info = sf.info(audio_path)
|
||||
|
|
|
|||
|
|
@ -29,10 +29,18 @@ class PreprocessMixin:
|
|||
dit_handler,
|
||||
output_dir: str,
|
||||
max_duration: float = 240.0,
|
||||
preprocess_mode: str = "lora",
|
||||
progress_callback=None,
|
||||
) -> Tuple[List[str], str]:
|
||||
"""Preprocess all labeled samples to tensor files for efficient training."""
|
||||
debug_log_for("dataset", f"preprocess_to_tensors: output_dir='{output_dir}', max_duration={max_duration}")
|
||||
mode = str(preprocess_mode or "lora").strip().lower()
|
||||
if mode not in {"lora", "lokr"}:
|
||||
mode = "lora"
|
||||
|
||||
debug_log_for(
|
||||
"dataset",
|
||||
f"preprocess_to_tensors: output_dir='{output_dir}', max_duration={max_duration}, mode={mode}",
|
||||
)
|
||||
if not self.samples:
|
||||
return [], "❌ No samples to preprocess"
|
||||
|
||||
|
|
@ -145,6 +153,17 @@ class PreprocessMixin:
|
|||
if lyric_hidden_states.dtype != model_dtype:
|
||||
lyric_hidden_states = lyric_hidden_states.to(model_dtype)
|
||||
|
||||
refer_audio_hidden = None
|
||||
refer_audio_order_mask_val = None
|
||||
if mode == "lokr":
|
||||
# LoKr mode uses per-sample audio latents as reference-audio conditioning.
|
||||
refer_audio_hidden = target_latents.to(device=model_device, dtype=model_dtype)
|
||||
refer_audio_order_mask_val = torch.zeros(
|
||||
refer_audio_hidden.shape[0],
|
||||
device=model_device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
encoder_hidden_states, encoder_attention_mask = run_encoder(
|
||||
model,
|
||||
text_hidden_states=text_hidden_states,
|
||||
|
|
@ -153,6 +172,8 @@ class PreprocessMixin:
|
|||
lyric_attention_mask=lyric_attention_mask,
|
||||
device=model_device,
|
||||
dtype=model_dtype,
|
||||
refer_audio_hidden_states_packed=refer_audio_hidden,
|
||||
refer_audio_order_mask=refer_audio_order_mask_val,
|
||||
)
|
||||
debug_end_verbose_for("dataset", f"run_encoder[{i}]", t0)
|
||||
debug_log_verbose_for(
|
||||
|
|
@ -162,7 +183,14 @@ class PreprocessMixin:
|
|||
)
|
||||
|
||||
t0 = debug_start_verbose_for("dataset", f"build_context_latents[{i}]")
|
||||
context_latents = build_context_latents(silence_latent, latent_length, device, dtype)
|
||||
context_src = target_latents if mode == "lokr" else None
|
||||
context_latents = build_context_latents(
|
||||
silence_latent,
|
||||
latent_length,
|
||||
device,
|
||||
dtype,
|
||||
src_latents=context_src,
|
||||
)
|
||||
debug_end_verbose_for("dataset", f"build_context_latents[{i}]", t0)
|
||||
|
||||
output_data = {
|
||||
|
|
@ -182,6 +210,7 @@ class PreprocessMixin:
|
|||
"timesignature": sample.timesignature,
|
||||
"language": sample.language,
|
||||
"is_instrumental": sample.is_instrumental,
|
||||
"preprocess_mode": mode,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,20 @@
|
|||
import torch
|
||||
|
||||
|
||||
def build_context_latents(silence_latent, latent_length: int, device, dtype):
|
||||
def build_context_latents(
|
||||
silence_latent,
|
||||
latent_length: int,
|
||||
device,
|
||||
dtype,
|
||||
src_latents: torch.Tensor = None,
|
||||
):
|
||||
"""Build context latents for text2music."""
|
||||
src_latents = silence_latent[:, :latent_length, :].to(dtype)
|
||||
if src_latents is None:
|
||||
src_latents = silence_latent[:, :latent_length, :].to(dtype)
|
||||
else:
|
||||
if src_latents.dim() == 2:
|
||||
src_latents = src_latents.unsqueeze(0)
|
||||
src_latents = src_latents.to(device=device, dtype=dtype)
|
||||
if src_latents.shape[0] < 1:
|
||||
src_latents = src_latents.expand(1, -1, -1)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,10 +9,26 @@ def run_encoder(
|
|||
lyric_attention_mask,
|
||||
device,
|
||||
dtype,
|
||||
refer_audio_hidden_states_packed=None,
|
||||
refer_audio_order_mask=None,
|
||||
):
|
||||
"""Run model encoder to get hidden states and attention mask."""
|
||||
refer_audio_hidden = torch.zeros(1, 1, 64, device=device, dtype=dtype)
|
||||
refer_audio_order_mask = torch.zeros(1, device=device, dtype=torch.long)
|
||||
if refer_audio_hidden_states_packed is None:
|
||||
refer_audio_hidden = torch.zeros(1, 1, 64, device=device, dtype=dtype)
|
||||
else:
|
||||
refer_audio_hidden = refer_audio_hidden_states_packed
|
||||
if refer_audio_hidden.dim() == 2:
|
||||
refer_audio_hidden = refer_audio_hidden.unsqueeze(0)
|
||||
refer_audio_hidden = refer_audio_hidden.to(device=device, dtype=dtype)
|
||||
|
||||
if refer_audio_order_mask is None:
|
||||
refer_audio_order_mask = torch.zeros(
|
||||
refer_audio_hidden.shape[0],
|
||||
device=device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
else:
|
||||
refer_audio_order_mask = refer_audio_order_mask.to(device=device, dtype=torch.long)
|
||||
|
||||
with torch.no_grad():
|
||||
encoder_hidden_states, encoder_attention_mask = model.encoder(
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import List, Tuple
|
|||
|
||||
from loguru import logger
|
||||
|
||||
from .audio_io import get_audio_duration, load_lyrics_file
|
||||
from .audio_io import get_audio_duration, load_caption_file, load_lyrics_file
|
||||
from .csv_metadata import load_csv_metadata
|
||||
from .models import AudioSample, SUPPORTED_AUDIO_FORMATS
|
||||
|
||||
|
|
@ -39,17 +39,23 @@ class ScanMixin:
|
|||
|
||||
csv_metadata = load_csv_metadata(directory)
|
||||
csv_count = 0
|
||||
caption_count = 0
|
||||
lyrics_count = 0
|
||||
|
||||
for audio_path in audio_files:
|
||||
try:
|
||||
duration = get_audio_duration(audio_path)
|
||||
caption_content, has_caption_file = load_caption_file(audio_path)
|
||||
lyrics_content, has_lyrics_file = load_lyrics_file(audio_path)
|
||||
|
||||
if has_caption_file:
|
||||
caption_count += 1
|
||||
if has_lyrics_file:
|
||||
lyrics_count += 1
|
||||
|
||||
is_instrumental = self.metadata.all_instrumental
|
||||
if has_lyrics_file:
|
||||
is_instrumental = False
|
||||
lyrics_count += 1
|
||||
|
||||
sample = AudioSample(
|
||||
audio_path=audio_path,
|
||||
|
|
@ -57,9 +63,12 @@ class ScanMixin:
|
|||
duration=duration,
|
||||
is_instrumental=is_instrumental,
|
||||
custom_tag=self.metadata.custom_tag,
|
||||
caption=caption_content if has_caption_file else "",
|
||||
lyrics=lyrics_content if has_lyrics_file else "[Instrumental]",
|
||||
raw_lyrics=lyrics_content if has_lyrics_file else "",
|
||||
)
|
||||
if has_caption_file:
|
||||
sample.labeled = True
|
||||
|
||||
if csv_metadata and sample.filename in csv_metadata:
|
||||
meta = csv_metadata[sample.filename]
|
||||
|
|
@ -79,8 +88,10 @@ class ScanMixin:
|
|||
self.metadata.num_samples = len(self.samples)
|
||||
|
||||
status = f"✅ Found {len(self.samples)} audio files in {directory}"
|
||||
if caption_count > 0:
|
||||
status += f"\n 📋 Detected {caption_count} captions (.caption.txt)"
|
||||
if lyrics_count > 0:
|
||||
status += f"\n 📝 {lyrics_count} files have accompanying lyrics (.txt)"
|
||||
status += f"\n 📝 Detected {lyrics_count} lyrics (.lyrics.txt / .txt)"
|
||||
if csv_count > 0:
|
||||
status += f"\n 📊 {csv_count} files have metadata from CSV"
|
||||
|
||||
|
|
|
|||
253
acestep/training/lokr_utils.py
Normal file
253
acestep/training/lokr_utils.py
Normal file
|
|
@ -0,0 +1,253 @@
|
|||
"""
|
||||
LoKr utilities for ACE-Step training and inference.
|
||||
|
||||
This module integrates LyCORIS LoKr adapters with the ACE-Step decoder.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from acestep.training.configs import LoKRConfig
|
||||
|
||||
try:
|
||||
from lycoris import LycorisNetwork, create_lycoris
|
||||
|
||||
LYCORIS_AVAILABLE = True
|
||||
except ImportError:
|
||||
LYCORIS_AVAILABLE = False
|
||||
LycorisNetwork = Any # type: ignore[assignment,misc]
|
||||
logger.warning(
|
||||
"LyCORIS library not installed. LoKr training/inference unavailable. "
|
||||
"Install with: pip install lycoris-lora"
|
||||
)
|
||||
|
||||
|
||||
def check_lycoris_available() -> bool:
|
||||
"""Check if LyCORIS is importable."""
|
||||
return LYCORIS_AVAILABLE
|
||||
|
||||
|
||||
def _matches_target_module_name(module_name: str, target_modules) -> bool:
|
||||
"""Return True if a LyCORIS module name maps to one of target module suffixes."""
|
||||
if not module_name:
|
||||
return False
|
||||
name = module_name.lower()
|
||||
for target in target_modules or []:
|
||||
t = str(target).strip().lower()
|
||||
if not t:
|
||||
continue
|
||||
if name.endswith(t) or f"_{t}" in name or f".{t}" in name:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def inject_lokr_into_dit(
|
||||
model,
|
||||
lokr_config: LoKRConfig,
|
||||
multiplier: float = 1.0,
|
||||
) -> Tuple[Any, "LycorisNetwork", Dict[str, Any]]:
|
||||
"""
|
||||
Inject LoKr adapters into the decoder.
|
||||
|
||||
Returns:
|
||||
Tuple: (model, lycoris_network, info_dict)
|
||||
"""
|
||||
if not LYCORIS_AVAILABLE:
|
||||
raise ImportError(
|
||||
"LyCORIS library is required for LoKr training. "
|
||||
"Install with: pip install lycoris-lora"
|
||||
)
|
||||
|
||||
decoder = model.decoder
|
||||
|
||||
# Freeze all existing params before creating adapter params.
|
||||
for _, param in model.named_parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
LycorisNetwork.apply_preset(
|
||||
{
|
||||
"unet_target_name": lokr_config.target_modules,
|
||||
"target_name": lokr_config.target_modules,
|
||||
}
|
||||
)
|
||||
|
||||
lycoris_net = create_lycoris(
|
||||
decoder,
|
||||
multiplier,
|
||||
linear_dim=lokr_config.linear_dim,
|
||||
linear_alpha=lokr_config.linear_alpha,
|
||||
algo="lokr",
|
||||
factor=lokr_config.factor,
|
||||
decompose_both=lokr_config.decompose_both,
|
||||
use_tucker=lokr_config.use_tucker,
|
||||
use_scalar=lokr_config.use_scalar,
|
||||
full_matrix=lokr_config.full_matrix,
|
||||
bypass_mode=lokr_config.bypass_mode,
|
||||
rs_lora=lokr_config.rs_lora,
|
||||
unbalanced_factorization=lokr_config.unbalanced_factorization,
|
||||
)
|
||||
|
||||
if lokr_config.weight_decompose:
|
||||
try:
|
||||
lycoris_net = create_lycoris(
|
||||
decoder,
|
||||
multiplier,
|
||||
linear_dim=lokr_config.linear_dim,
|
||||
linear_alpha=lokr_config.linear_alpha,
|
||||
algo="lokr",
|
||||
factor=lokr_config.factor,
|
||||
decompose_both=lokr_config.decompose_both,
|
||||
use_tucker=lokr_config.use_tucker,
|
||||
use_scalar=lokr_config.use_scalar,
|
||||
full_matrix=lokr_config.full_matrix,
|
||||
bypass_mode=lokr_config.bypass_mode,
|
||||
rs_lora=lokr_config.rs_lora,
|
||||
unbalanced_factorization=lokr_config.unbalanced_factorization,
|
||||
dora_wd=True,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(f"DoRA mode not supported in current LyCORIS build: {exc}")
|
||||
|
||||
lycoris_net.apply_to()
|
||||
|
||||
# Keep a reference on decoder so it stays discoverable after wrappers.
|
||||
# Always refresh this reference to avoid stale nets from earlier runs.
|
||||
decoder._lycoris_net = lycoris_net
|
||||
|
||||
lokr_param_list = []
|
||||
enabled_module_count = 0
|
||||
disabled_module_count = 0
|
||||
disabled_examples = []
|
||||
|
||||
for idx, module in enumerate(getattr(lycoris_net, "loras", []) or []):
|
||||
module_name = (
|
||||
getattr(module, "lora_name", None)
|
||||
or getattr(module, "name", None)
|
||||
or f"{module.__class__.__name__}#{idx}"
|
||||
)
|
||||
enabled = _matches_target_module_name(module_name, lokr_config.target_modules)
|
||||
|
||||
if enabled:
|
||||
enabled_module_count += 1
|
||||
else:
|
||||
disabled_module_count += 1
|
||||
if len(disabled_examples) < 8:
|
||||
disabled_examples.append(module_name)
|
||||
|
||||
for param in module.parameters():
|
||||
param.requires_grad = enabled
|
||||
if enabled:
|
||||
lokr_param_list.append(param)
|
||||
|
||||
logger.info(
|
||||
f"LoKr target filter: enabled {enabled_module_count} LyCORIS modules "
|
||||
f"(disabled {disabled_module_count}) for targets={lokr_config.target_modules}"
|
||||
)
|
||||
if disabled_examples:
|
||||
logger.info("LoKr disabled non-target modules (sample): " + ", ".join(disabled_examples))
|
||||
|
||||
if not lokr_param_list:
|
||||
for param in lycoris_net.parameters():
|
||||
param.requires_grad = True
|
||||
lokr_param_list.append(param)
|
||||
|
||||
# De-duplicate possible shared params.
|
||||
unique_params = {id(p): p for p in lokr_param_list}
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
lokr_params = sum(p.numel() for p in unique_params.values())
|
||||
trainable_params = sum(p.numel() for p in unique_params.values() if p.requires_grad)
|
||||
|
||||
info = {
|
||||
"total_params": total_params,
|
||||
"lokr_params": lokr_params,
|
||||
"trainable_params": trainable_params,
|
||||
"trainable_ratio": trainable_params / total_params if total_params > 0 else 0.0,
|
||||
"linear_dim": lokr_config.linear_dim,
|
||||
"linear_alpha": lokr_config.linear_alpha,
|
||||
"factor": lokr_config.factor,
|
||||
"algo": "lokr",
|
||||
"target_modules": lokr_config.target_modules,
|
||||
}
|
||||
|
||||
logger.info("LoKr injected into decoder")
|
||||
logger.info(
|
||||
f"LoKr trainable params: {trainable_params:,}/{total_params:,} "
|
||||
f"({info['trainable_ratio']:.2%})"
|
||||
)
|
||||
return model, lycoris_net, info
|
||||
|
||||
|
||||
def save_lokr_weights(
|
||||
lycoris_net: "LycorisNetwork",
|
||||
output_dir: str,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
metadata: Optional[Dict[str, str]] = None,
|
||||
) -> str:
|
||||
"""Save LoKr weights to safetensors."""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
weights_path = os.path.join(output_dir, "lokr_weights.safetensors")
|
||||
|
||||
save_metadata: Dict[str, str] = {"algo": "lokr", "format": "lycoris"}
|
||||
if metadata:
|
||||
for key, value in metadata.items():
|
||||
if value is None:
|
||||
continue
|
||||
if isinstance(value, str):
|
||||
save_metadata[key] = value
|
||||
else:
|
||||
save_metadata[key] = json.dumps(value, ensure_ascii=True)
|
||||
|
||||
lycoris_net.save_weights(weights_path, dtype=dtype, metadata=save_metadata)
|
||||
logger.info(f"LoKr weights saved to {weights_path}")
|
||||
return weights_path
|
||||
|
||||
|
||||
def load_lokr_weights(lycoris_net: "LycorisNetwork", weights_path: str) -> Dict[str, Any]:
|
||||
"""Load LoKr weights into an injected LyCORIS network."""
|
||||
if not os.path.exists(weights_path):
|
||||
raise FileNotFoundError(f"LoKr weights not found: {weights_path}")
|
||||
result = lycoris_net.load_weights(weights_path)
|
||||
logger.info(f"LoKr weights loaded from {weights_path}")
|
||||
return result
|
||||
|
||||
|
||||
def save_lokr_training_checkpoint(
|
||||
lycoris_net: "LycorisNetwork",
|
||||
optimizer,
|
||||
scheduler,
|
||||
epoch: int,
|
||||
global_step: int,
|
||||
output_dir: str,
|
||||
lokr_config: Optional[LoKRConfig] = None,
|
||||
run_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""Save LoKr weights plus optimizer/scheduler state."""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
metadata: Dict[str, Any] = {}
|
||||
if lokr_config is not None:
|
||||
metadata["lokr_config"] = lokr_config.to_dict()
|
||||
if run_metadata is not None:
|
||||
metadata["run_metadata"] = run_metadata
|
||||
metadata = metadata or None
|
||||
save_lokr_weights(lycoris_net, output_dir, metadata=metadata)
|
||||
|
||||
state = {
|
||||
"epoch": epoch,
|
||||
"global_step": global_step,
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"scheduler_state_dict": scheduler.state_dict(),
|
||||
}
|
||||
if lokr_config is not None:
|
||||
state["lokr_config"] = lokr_config.to_dict()
|
||||
if run_metadata is not None:
|
||||
state["run_metadata"] = run_metadata
|
||||
|
||||
state_path = os.path.join(output_dir, "training_state.pt")
|
||||
torch.save(state, state_path)
|
||||
logger.info(f"LoKr checkpoint saved to {output_dir} (epoch={epoch}, step={global_step})")
|
||||
return output_dir
|
||||
|
|
@ -8,6 +8,7 @@ Uses PEFT (Parameter-Efficient Fine-Tuning) library for LoRA implementation.
|
|||
import os
|
||||
from typing import Optional, List, Dict, Any, Tuple
|
||||
from loguru import logger
|
||||
import types
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
|
@ -95,8 +96,61 @@ def inject_lora_into_dit(
|
|||
if not PEFT_AVAILABLE:
|
||||
raise ImportError("PEFT library is required for LoRA training. Install with: pip install peft")
|
||||
|
||||
# Get the decoder (DiT model)
|
||||
# Get the decoder (DiT model). Previous failed training runs may leave
|
||||
# Fabric/PEFT wrappers attached; unwrap to a clean base module first.
|
||||
decoder = model.decoder
|
||||
while hasattr(decoder, "_forward_module"):
|
||||
decoder = decoder._forward_module
|
||||
if hasattr(decoder, "base_model"):
|
||||
base_model = decoder.base_model
|
||||
if hasattr(base_model, "model"):
|
||||
decoder = base_model.model
|
||||
else:
|
||||
decoder = base_model
|
||||
if hasattr(decoder, "model") and isinstance(decoder.model, nn.Module):
|
||||
decoder = decoder.model
|
||||
model.decoder = decoder
|
||||
|
||||
# PEFT may call enable_input_require_grads() when is_gradient_checkpointing
|
||||
# is true. AceStepDiTModel doesn't implement get_input_embeddings, so the
|
||||
# default implementation raises NotImplementedError. Guard this path.
|
||||
if hasattr(decoder, "enable_input_require_grads"):
|
||||
orig_enable_input_require_grads = decoder.enable_input_require_grads
|
||||
|
||||
def _safe_enable_input_require_grads(self):
|
||||
try:
|
||||
result = orig_enable_input_require_grads()
|
||||
try:
|
||||
self._acestep_input_grads_hook_enabled = True
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
except NotImplementedError:
|
||||
try:
|
||||
self._acestep_input_grads_hook_enabled = False
|
||||
except Exception:
|
||||
pass
|
||||
if not getattr(self, "_acestep_input_grads_warning_emitted", False):
|
||||
logger.info(
|
||||
"Skipping enable_input_require_grads for decoder: "
|
||||
"get_input_embeddings is not implemented (expected for DiT)"
|
||||
)
|
||||
try:
|
||||
self._acestep_input_grads_warning_emitted = True
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
decoder.enable_input_require_grads = types.MethodType(
|
||||
_safe_enable_input_require_grads, decoder
|
||||
)
|
||||
|
||||
# Avoid PEFT auto-prep path on non-embedding diffusion decoder.
|
||||
if hasattr(decoder, "is_gradient_checkpointing"):
|
||||
try:
|
||||
decoder.is_gradient_checkpointing = False
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Create PEFT LoRA config
|
||||
peft_lora_config = LoraConfig(
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
15
acestep/training_v2/__init__.py
Normal file
15
acestep/training_v2/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
"""
|
||||
ACE-Step Training V2 -- Corrected LoRA Fine-Tuning CLI
|
||||
|
||||
Non-destructive parallel module providing corrected training procedures
|
||||
that match each model variant's own forward() training logic.
|
||||
|
||||
Subcommands:
|
||||
vanilla -- Reproduce existing (bugged) training for backward compatibility
|
||||
fixed -- Corrected training: continuous timesteps + CFG dropout
|
||||
selective -- Fixed + dataset-specific layer/module selection
|
||||
estimate -- Gradient sensitivity analysis (no training)
|
||||
compare-configs -- Compare module configs across datasets
|
||||
"""
|
||||
|
||||
__version__ = "2.0.0"
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue