mirror of
https://github.com/index-tts/index-tts.git
synced 2026-07-02 16:37:05 +00:00
fix: fix indextts2 model resource checks (#707)
Some checks failed
CI / Lightweight Python checks (push) Has been cancelled
Some checks failed
CI / Lightweight Python checks (push) Has been cancelled
* fork/PR_CLI: 修复CLI误将pinyin.vocab作为模型文件进行检查而报错的问题 * fork/PR_CLI: indextts2 check 指令增加 hf_cache 相关模型的检查
This commit is contained in:
parent
b154a1b21e
commit
7264ce2a9a
5 changed files with 144 additions and 16 deletions
|
|
@ -15,13 +15,21 @@ REQUIRED_MODEL_FILES = [
|
|||
"gpt.pth",
|
||||
"s2mel.pth",
|
||||
"wav2vec2bert_stats.pt",
|
||||
"pinyin.vocab",
|
||||
"feat1.pt",
|
||||
"feat2.pt",
|
||||
]
|
||||
REQUIRED_MODEL_DIRS = [
|
||||
"qwen0.6bemo4-merge",
|
||||
]
|
||||
AUX_MODEL_FILES = [
|
||||
"hf_cache/semantic_codec_model.safetensors",
|
||||
"hf_cache/campplus_cn_common.bin",
|
||||
"hf_cache/bigvgan/config.json",
|
||||
"hf_cache/bigvgan/bigvgan_generator.pt",
|
||||
]
|
||||
AUX_MODEL_DIRS = [
|
||||
"hf_cache/w2v-bert-2.0",
|
||||
]
|
||||
|
||||
|
||||
def make_model_dir(base_dir):
|
||||
|
|
@ -31,6 +39,14 @@ def make_model_dir(base_dir):
|
|||
(model_dir / filename).write_text("placeholder", encoding="utf-8")
|
||||
for dirname in REQUIRED_MODEL_DIRS:
|
||||
(model_dir / dirname).mkdir()
|
||||
for filename in AUX_MODEL_FILES:
|
||||
target = model_dir / filename
|
||||
target.parent.mkdir(parents=True, exist_ok=True)
|
||||
target.write_text("placeholder", encoding="utf-8")
|
||||
for dirname in AUX_MODEL_DIRS:
|
||||
target = model_dir / dirname
|
||||
target.mkdir(parents=True, exist_ok=True)
|
||||
(target / "config.json").write_text("placeholder", encoding="utf-8")
|
||||
return model_dir
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -16,25 +16,46 @@ REQUIRED_MODEL_FILES = [
|
|||
"gpt.pth",
|
||||
"s2mel.pth",
|
||||
"wav2vec2bert_stats.pt",
|
||||
"pinyin.vocab",
|
||||
"feat1.pt",
|
||||
"feat2.pt",
|
||||
]
|
||||
REQUIRED_MODEL_DIRS = [
|
||||
"qwen0.6bemo4-merge",
|
||||
]
|
||||
AUX_MODEL_FILES = [
|
||||
"hf_cache/semantic_codec_model.safetensors",
|
||||
"hf_cache/campplus_cn_common.bin",
|
||||
"hf_cache/bigvgan/config.json",
|
||||
"hf_cache/bigvgan/bigvgan_generator.pt",
|
||||
]
|
||||
AUX_MODEL_DIRS = [
|
||||
"hf_cache/w2v-bert-2.0",
|
||||
]
|
||||
|
||||
|
||||
def make_model_dir(base_dir):
|
||||
def make_model_dir(base_dir, include_aux=True):
|
||||
model_dir = base_dir / "checkpoints"
|
||||
model_dir.mkdir()
|
||||
for filename in REQUIRED_MODEL_FILES:
|
||||
(model_dir / filename).write_text("placeholder", encoding="utf-8")
|
||||
for dirname in REQUIRED_MODEL_DIRS:
|
||||
(model_dir / dirname).mkdir()
|
||||
if include_aux:
|
||||
make_aux_model_cache(model_dir)
|
||||
return model_dir
|
||||
|
||||
|
||||
def make_aux_model_cache(model_dir):
|
||||
for filename in AUX_MODEL_FILES:
|
||||
target = model_dir / filename
|
||||
target.parent.mkdir(parents=True, exist_ok=True)
|
||||
target.write_text("placeholder", encoding="utf-8")
|
||||
for dirname in AUX_MODEL_DIRS:
|
||||
target = model_dir / dirname
|
||||
target.mkdir(parents=True, exist_ok=True)
|
||||
(target / "config.json").write_text("placeholder", encoding="utf-8")
|
||||
|
||||
|
||||
def assert_model_resource_help(test_case, stderr, model_dir):
|
||||
test_case.assertIn(f"Model directory: {model_dir}", stderr)
|
||||
test_case.assertIn("Missing resources:", stderr)
|
||||
|
|
@ -189,11 +210,30 @@ class CheckCommandTests(unittest.TestCase):
|
|||
|
||||
self.assertEqual(exit_code, 2)
|
||||
self.assertEqual(stdout.getvalue(), "")
|
||||
self.assertIn("pinyin.vocab", stderr.getvalue())
|
||||
self.assertIn("feat1.pt", stderr.getvalue())
|
||||
self.assertIn("feat2.pt", stderr.getvalue())
|
||||
self.assertIn("qwen0.6bemo4-merge", stderr.getvalue())
|
||||
|
||||
def test_check_requires_the_auxiliary_model_cache_resources(self):
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
model_dir = make_model_dir(Path(temp_dir), include_aux=False)
|
||||
|
||||
from indextts.cli_v2 import main
|
||||
|
||||
stdout = io.StringIO()
|
||||
stderr = io.StringIO()
|
||||
with contextlib.redirect_stdout(stdout), contextlib.redirect_stderr(stderr):
|
||||
exit_code = main(["check", "--model-dir", str(model_dir)])
|
||||
|
||||
self.assertEqual(exit_code, 2)
|
||||
self.assertEqual(stdout.getvalue(), "")
|
||||
self.assertIn("ERROR: missing required model files", stderr.getvalue())
|
||||
self.assertIn("hf_cache/w2v-bert-2.0", stderr.getvalue())
|
||||
self.assertIn("hf_cache/semantic_codec_model.safetensors", stderr.getvalue())
|
||||
self.assertIn("hf_cache/campplus_cn_common.bin", stderr.getvalue())
|
||||
self.assertIn("hf_cache/bigvgan/config.json", stderr.getvalue())
|
||||
self.assertIn("hf_cache/bigvgan/bigvgan_generator.pt", stderr.getvalue())
|
||||
|
||||
def test_check_requires_file_resources_and_directory_resources(self):
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
model_dir = Path(temp_dir) / "checkpoints"
|
||||
|
|
@ -1114,14 +1154,9 @@ class SynthCommandTests(unittest.TestCase):
|
|||
def test_synth_maps_runtime_options_to_indextts2(self):
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
model_dir = temp_path / "models"
|
||||
model_dir = make_model_dir(temp_path)
|
||||
voice_path = temp_path / "voice.wav"
|
||||
output_path = temp_path / "out.wav"
|
||||
model_dir.mkdir()
|
||||
for filename in REQUIRED_MODEL_FILES:
|
||||
(model_dir / filename).write_text("placeholder", encoding="utf-8")
|
||||
for dirname in REQUIRED_MODEL_DIRS:
|
||||
(model_dir / dirname).mkdir()
|
||||
voice_path.write_bytes(b"voice")
|
||||
|
||||
exit_code, stdout, stderr, calls = self.run_synth(
|
||||
|
|
|
|||
|
|
@ -16,13 +16,21 @@ REQUIRED_MODEL_FILES = [
|
|||
"gpt.pth",
|
||||
"s2mel.pth",
|
||||
"wav2vec2bert_stats.pt",
|
||||
"pinyin.vocab",
|
||||
"feat1.pt",
|
||||
"feat2.pt",
|
||||
]
|
||||
REQUIRED_MODEL_DIRS = [
|
||||
"qwen0.6bemo4-merge",
|
||||
]
|
||||
AUX_MODEL_FILES = [
|
||||
"hf_cache/semantic_codec_model.safetensors",
|
||||
"hf_cache/campplus_cn_common.bin",
|
||||
"hf_cache/bigvgan/config.json",
|
||||
"hf_cache/bigvgan/bigvgan_generator.pt",
|
||||
]
|
||||
AUX_MODEL_DIRS = [
|
||||
"hf_cache/w2v-bert-2.0",
|
||||
]
|
||||
|
||||
|
||||
def make_model_dir(path):
|
||||
|
|
@ -31,6 +39,14 @@ def make_model_dir(path):
|
|||
(path / filename).write_text("placeholder", encoding="utf-8")
|
||||
for dirname in REQUIRED_MODEL_DIRS:
|
||||
(path / dirname).mkdir()
|
||||
for filename in AUX_MODEL_FILES:
|
||||
target = path / filename
|
||||
target.parent.mkdir(parents=True, exist_ok=True)
|
||||
target.write_text("placeholder", encoding="utf-8")
|
||||
for dirname in AUX_MODEL_DIRS:
|
||||
target = path / dirname
|
||||
target.mkdir(parents=True, exist_ok=True)
|
||||
(target / "config.json").write_text("placeholder", encoding="utf-8")
|
||||
|
||||
|
||||
def fake_torch():
|
||||
|
|
|
|||
|
|
@ -14,21 +14,42 @@ REQUIRED_MODEL_FILES = [
|
|||
"gpt.pth",
|
||||
"s2mel.pth",
|
||||
"wav2vec2bert_stats.pt",
|
||||
"pinyin.vocab",
|
||||
"feat1.pt",
|
||||
"feat2.pt",
|
||||
]
|
||||
REQUIRED_MODEL_DIRS = [
|
||||
"qwen0.6bemo4-merge",
|
||||
]
|
||||
AUX_MODEL_FILES = [
|
||||
"hf_cache/semantic_codec_model.safetensors",
|
||||
"hf_cache/campplus_cn_common.bin",
|
||||
"hf_cache/bigvgan/config.json",
|
||||
"hf_cache/bigvgan/bigvgan_generator.pt",
|
||||
]
|
||||
AUX_MODEL_DIRS = [
|
||||
"hf_cache/w2v-bert-2.0",
|
||||
]
|
||||
|
||||
|
||||
def make_model_dir(path):
|
||||
def make_model_dir(path, include_aux=True):
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
for filename in REQUIRED_MODEL_FILES:
|
||||
(path / filename).write_text("placeholder", encoding="utf-8")
|
||||
for dirname in REQUIRED_MODEL_DIRS:
|
||||
(path / dirname).mkdir(exist_ok=True)
|
||||
if include_aux:
|
||||
make_aux_model_cache(path)
|
||||
|
||||
|
||||
def make_aux_model_cache(path):
|
||||
for filename in AUX_MODEL_FILES:
|
||||
target = path / filename
|
||||
target.parent.mkdir(parents=True, exist_ok=True)
|
||||
target.write_text("placeholder", encoding="utf-8")
|
||||
for dirname in AUX_MODEL_DIRS:
|
||||
target = path / dirname
|
||||
target.mkdir(parents=True, exist_ok=True)
|
||||
(target / "config.json").write_text("placeholder", encoding="utf-8")
|
||||
|
||||
|
||||
def user_state_paths(temp_path):
|
||||
|
|
@ -72,19 +93,29 @@ class DownloadCommandTests(unittest.TestCase):
|
|||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
state = user_state_paths(Path(temp_dir).resolve())
|
||||
calls = []
|
||||
aux_calls = []
|
||||
|
||||
def fake_snapshot_download(repo_id, local_dir, **kwargs):
|
||||
calls.append((repo_id, Path(local_dir)))
|
||||
make_model_dir(Path(local_dir))
|
||||
make_model_dir(Path(local_dir), include_aux=False)
|
||||
return str(local_dir)
|
||||
|
||||
def fake_ensure_models_available(model_dir):
|
||||
aux_calls.append(Path(model_dir))
|
||||
make_aux_model_cache(Path(model_dir))
|
||||
|
||||
with mock.patch.dict(os.environ, state["env"], clear=False):
|
||||
with mock.patch("indextts.utils.model_download.snapshot_download", side_effect=fake_snapshot_download):
|
||||
exit_code, stdout, stderr = self.run_cli(["download"])
|
||||
with mock.patch(
|
||||
"indextts.utils.model_download.ensure_models_available",
|
||||
side_effect=fake_ensure_models_available,
|
||||
):
|
||||
exit_code, stdout, stderr = self.run_cli(["download"])
|
||||
config_exists = state["config_path"].exists()
|
||||
|
||||
self.assertEqual(exit_code, 0)
|
||||
self.assertEqual(calls, [("IndexTeam/IndexTTS-2", state["model_dir"])])
|
||||
self.assertEqual(aux_calls, [state["model_dir"]])
|
||||
self.assertIn(f"Downloaded model resources to: {state['model_dir']}", stdout)
|
||||
self.assertEqual(stderr, "")
|
||||
self.assertFalse(config_exists)
|
||||
|
|
|
|||
|
|
@ -25,13 +25,21 @@ REQUIRED_MODEL_FILES = (
|
|||
"gpt.pth",
|
||||
"s2mel.pth",
|
||||
"wav2vec2bert_stats.pt",
|
||||
"pinyin.vocab",
|
||||
"feat1.pt",
|
||||
"feat2.pt",
|
||||
)
|
||||
REQUIRED_MODEL_DIRS = (
|
||||
"qwen0.6bemo4-merge",
|
||||
)
|
||||
REQUIRED_AUX_MODEL_FILES = (
|
||||
"hf_cache/semantic_codec_model.safetensors",
|
||||
"hf_cache/campplus_cn_common.bin",
|
||||
"hf_cache/bigvgan/config.json",
|
||||
"hf_cache/bigvgan/bigvgan_generator.pt",
|
||||
)
|
||||
REQUIRED_AUX_MODEL_DIRS = (
|
||||
"hf_cache/w2v-bert-2.0",
|
||||
)
|
||||
MODEL_REPO_ID = "IndexTeam/IndexTTS-2"
|
||||
REQUIRED_PACKAGES = ("torch", "torchaudio", "indextts")
|
||||
PERSISTED_CONFIG_KEYS = (
|
||||
|
|
@ -322,6 +330,13 @@ def _download_model_resources(source, model_dir):
|
|||
from huggingface_hub import snapshot_download
|
||||
snapshot_download(repo_id=MODEL_REPO_ID, local_dir=str(model_dir))
|
||||
|
||||
if _missing_primary_model_resources(model_dir):
|
||||
return
|
||||
|
||||
from indextts.utils.model_download import ensure_models_available
|
||||
|
||||
ensure_models_available(str(model_dir))
|
||||
|
||||
|
||||
def _download_support_package(source):
|
||||
if source == "auto":
|
||||
|
|
@ -1517,11 +1532,26 @@ def _print_model_resource_help(model_dir, missing_summary):
|
|||
def _missing_model_files(model_dir):
|
||||
if not model_dir.is_dir():
|
||||
return None
|
||||
missing_files = _missing_primary_model_resources(model_dir)
|
||||
missing_aux_files = [
|
||||
filename for filename in REQUIRED_AUX_MODEL_FILES if not _model_resource_path(model_dir, filename).is_file()
|
||||
]
|
||||
missing_aux_dirs = [
|
||||
dirname for dirname in REQUIRED_AUX_MODEL_DIRS if not _model_resource_path(model_dir, dirname).is_dir()
|
||||
]
|
||||
return missing_files + missing_aux_files + missing_aux_dirs
|
||||
|
||||
|
||||
def _missing_primary_model_resources(model_dir):
|
||||
missing_files = [filename for filename in REQUIRED_MODEL_FILES if not (model_dir / filename).is_file()]
|
||||
missing_dirs = [dirname for dirname in REQUIRED_MODEL_DIRS if not (model_dir / dirname).is_dir()]
|
||||
return missing_files + missing_dirs
|
||||
|
||||
|
||||
def _model_resource_path(model_dir, relative_path):
|
||||
return model_dir.joinpath(*relative_path.split("/"))
|
||||
|
||||
|
||||
def _import_required_packages():
|
||||
missing = []
|
||||
imported = {}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue