fix: fix indextts2 model resource checks (#707)
Some checks failed
CI / Lightweight Python checks (push) Has been cancelled

* fork/PR_CLI: 修复CLI误将pinyin.vocab作为模型文件进行检查而报错的问题

* fork/PR_CLI: indextts2 check 指令增加 hf_cache 相关模型的检查
This commit is contained in:
From_Abyss 2026-06-23 15:58:20 +08:00 committed by GitHub
parent b154a1b21e
commit 7264ce2a9a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 144 additions and 16 deletions

View file

@ -15,13 +15,21 @@ REQUIRED_MODEL_FILES = [
"gpt.pth",
"s2mel.pth",
"wav2vec2bert_stats.pt",
"pinyin.vocab",
"feat1.pt",
"feat2.pt",
]
REQUIRED_MODEL_DIRS = [
"qwen0.6bemo4-merge",
]
AUX_MODEL_FILES = [
"hf_cache/semantic_codec_model.safetensors",
"hf_cache/campplus_cn_common.bin",
"hf_cache/bigvgan/config.json",
"hf_cache/bigvgan/bigvgan_generator.pt",
]
AUX_MODEL_DIRS = [
"hf_cache/w2v-bert-2.0",
]
def make_model_dir(base_dir):
@ -31,6 +39,14 @@ def make_model_dir(base_dir):
(model_dir / filename).write_text("placeholder", encoding="utf-8")
for dirname in REQUIRED_MODEL_DIRS:
(model_dir / dirname).mkdir()
for filename in AUX_MODEL_FILES:
target = model_dir / filename
target.parent.mkdir(parents=True, exist_ok=True)
target.write_text("placeholder", encoding="utf-8")
for dirname in AUX_MODEL_DIRS:
target = model_dir / dirname
target.mkdir(parents=True, exist_ok=True)
(target / "config.json").write_text("placeholder", encoding="utf-8")
return model_dir

View file

@ -16,25 +16,46 @@ REQUIRED_MODEL_FILES = [
"gpt.pth",
"s2mel.pth",
"wav2vec2bert_stats.pt",
"pinyin.vocab",
"feat1.pt",
"feat2.pt",
]
REQUIRED_MODEL_DIRS = [
"qwen0.6bemo4-merge",
]
AUX_MODEL_FILES = [
"hf_cache/semantic_codec_model.safetensors",
"hf_cache/campplus_cn_common.bin",
"hf_cache/bigvgan/config.json",
"hf_cache/bigvgan/bigvgan_generator.pt",
]
AUX_MODEL_DIRS = [
"hf_cache/w2v-bert-2.0",
]
def make_model_dir(base_dir):
def make_model_dir(base_dir, include_aux=True):
model_dir = base_dir / "checkpoints"
model_dir.mkdir()
for filename in REQUIRED_MODEL_FILES:
(model_dir / filename).write_text("placeholder", encoding="utf-8")
for dirname in REQUIRED_MODEL_DIRS:
(model_dir / dirname).mkdir()
if include_aux:
make_aux_model_cache(model_dir)
return model_dir
def make_aux_model_cache(model_dir):
for filename in AUX_MODEL_FILES:
target = model_dir / filename
target.parent.mkdir(parents=True, exist_ok=True)
target.write_text("placeholder", encoding="utf-8")
for dirname in AUX_MODEL_DIRS:
target = model_dir / dirname
target.mkdir(parents=True, exist_ok=True)
(target / "config.json").write_text("placeholder", encoding="utf-8")
def assert_model_resource_help(test_case, stderr, model_dir):
test_case.assertIn(f"Model directory: {model_dir}", stderr)
test_case.assertIn("Missing resources:", stderr)
@ -189,11 +210,30 @@ class CheckCommandTests(unittest.TestCase):
self.assertEqual(exit_code, 2)
self.assertEqual(stdout.getvalue(), "")
self.assertIn("pinyin.vocab", stderr.getvalue())
self.assertIn("feat1.pt", stderr.getvalue())
self.assertIn("feat2.pt", stderr.getvalue())
self.assertIn("qwen0.6bemo4-merge", stderr.getvalue())
def test_check_requires_the_auxiliary_model_cache_resources(self):
with tempfile.TemporaryDirectory() as temp_dir:
model_dir = make_model_dir(Path(temp_dir), include_aux=False)
from indextts.cli_v2 import main
stdout = io.StringIO()
stderr = io.StringIO()
with contextlib.redirect_stdout(stdout), contextlib.redirect_stderr(stderr):
exit_code = main(["check", "--model-dir", str(model_dir)])
self.assertEqual(exit_code, 2)
self.assertEqual(stdout.getvalue(), "")
self.assertIn("ERROR: missing required model files", stderr.getvalue())
self.assertIn("hf_cache/w2v-bert-2.0", stderr.getvalue())
self.assertIn("hf_cache/semantic_codec_model.safetensors", stderr.getvalue())
self.assertIn("hf_cache/campplus_cn_common.bin", stderr.getvalue())
self.assertIn("hf_cache/bigvgan/config.json", stderr.getvalue())
self.assertIn("hf_cache/bigvgan/bigvgan_generator.pt", stderr.getvalue())
def test_check_requires_file_resources_and_directory_resources(self):
with tempfile.TemporaryDirectory() as temp_dir:
model_dir = Path(temp_dir) / "checkpoints"
@ -1114,14 +1154,9 @@ class SynthCommandTests(unittest.TestCase):
def test_synth_maps_runtime_options_to_indextts2(self):
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
model_dir = temp_path / "models"
model_dir = make_model_dir(temp_path)
voice_path = temp_path / "voice.wav"
output_path = temp_path / "out.wav"
model_dir.mkdir()
for filename in REQUIRED_MODEL_FILES:
(model_dir / filename).write_text("placeholder", encoding="utf-8")
for dirname in REQUIRED_MODEL_DIRS:
(model_dir / dirname).mkdir()
voice_path.write_bytes(b"voice")
exit_code, stdout, stderr, calls = self.run_synth(

View file

@ -16,13 +16,21 @@ REQUIRED_MODEL_FILES = [
"gpt.pth",
"s2mel.pth",
"wav2vec2bert_stats.pt",
"pinyin.vocab",
"feat1.pt",
"feat2.pt",
]
REQUIRED_MODEL_DIRS = [
"qwen0.6bemo4-merge",
]
AUX_MODEL_FILES = [
"hf_cache/semantic_codec_model.safetensors",
"hf_cache/campplus_cn_common.bin",
"hf_cache/bigvgan/config.json",
"hf_cache/bigvgan/bigvgan_generator.pt",
]
AUX_MODEL_DIRS = [
"hf_cache/w2v-bert-2.0",
]
def make_model_dir(path):
@ -31,6 +39,14 @@ def make_model_dir(path):
(path / filename).write_text("placeholder", encoding="utf-8")
for dirname in REQUIRED_MODEL_DIRS:
(path / dirname).mkdir()
for filename in AUX_MODEL_FILES:
target = path / filename
target.parent.mkdir(parents=True, exist_ok=True)
target.write_text("placeholder", encoding="utf-8")
for dirname in AUX_MODEL_DIRS:
target = path / dirname
target.mkdir(parents=True, exist_ok=True)
(target / "config.json").write_text("placeholder", encoding="utf-8")
def fake_torch():

View file

@ -14,21 +14,42 @@ REQUIRED_MODEL_FILES = [
"gpt.pth",
"s2mel.pth",
"wav2vec2bert_stats.pt",
"pinyin.vocab",
"feat1.pt",
"feat2.pt",
]
REQUIRED_MODEL_DIRS = [
"qwen0.6bemo4-merge",
]
AUX_MODEL_FILES = [
"hf_cache/semantic_codec_model.safetensors",
"hf_cache/campplus_cn_common.bin",
"hf_cache/bigvgan/config.json",
"hf_cache/bigvgan/bigvgan_generator.pt",
]
AUX_MODEL_DIRS = [
"hf_cache/w2v-bert-2.0",
]
def make_model_dir(path):
def make_model_dir(path, include_aux=True):
path.mkdir(parents=True, exist_ok=True)
for filename in REQUIRED_MODEL_FILES:
(path / filename).write_text("placeholder", encoding="utf-8")
for dirname in REQUIRED_MODEL_DIRS:
(path / dirname).mkdir(exist_ok=True)
if include_aux:
make_aux_model_cache(path)
def make_aux_model_cache(path):
for filename in AUX_MODEL_FILES:
target = path / filename
target.parent.mkdir(parents=True, exist_ok=True)
target.write_text("placeholder", encoding="utf-8")
for dirname in AUX_MODEL_DIRS:
target = path / dirname
target.mkdir(parents=True, exist_ok=True)
(target / "config.json").write_text("placeholder", encoding="utf-8")
def user_state_paths(temp_path):
@ -72,19 +93,29 @@ class DownloadCommandTests(unittest.TestCase):
with tempfile.TemporaryDirectory() as temp_dir:
state = user_state_paths(Path(temp_dir).resolve())
calls = []
aux_calls = []
def fake_snapshot_download(repo_id, local_dir, **kwargs):
calls.append((repo_id, Path(local_dir)))
make_model_dir(Path(local_dir))
make_model_dir(Path(local_dir), include_aux=False)
return str(local_dir)
def fake_ensure_models_available(model_dir):
aux_calls.append(Path(model_dir))
make_aux_model_cache(Path(model_dir))
with mock.patch.dict(os.environ, state["env"], clear=False):
with mock.patch("indextts.utils.model_download.snapshot_download", side_effect=fake_snapshot_download):
exit_code, stdout, stderr = self.run_cli(["download"])
with mock.patch(
"indextts.utils.model_download.ensure_models_available",
side_effect=fake_ensure_models_available,
):
exit_code, stdout, stderr = self.run_cli(["download"])
config_exists = state["config_path"].exists()
self.assertEqual(exit_code, 0)
self.assertEqual(calls, [("IndexTeam/IndexTTS-2", state["model_dir"])])
self.assertEqual(aux_calls, [state["model_dir"]])
self.assertIn(f"Downloaded model resources to: {state['model_dir']}", stdout)
self.assertEqual(stderr, "")
self.assertFalse(config_exists)

View file

@ -25,13 +25,21 @@ REQUIRED_MODEL_FILES = (
"gpt.pth",
"s2mel.pth",
"wav2vec2bert_stats.pt",
"pinyin.vocab",
"feat1.pt",
"feat2.pt",
)
REQUIRED_MODEL_DIRS = (
"qwen0.6bemo4-merge",
)
REQUIRED_AUX_MODEL_FILES = (
"hf_cache/semantic_codec_model.safetensors",
"hf_cache/campplus_cn_common.bin",
"hf_cache/bigvgan/config.json",
"hf_cache/bigvgan/bigvgan_generator.pt",
)
REQUIRED_AUX_MODEL_DIRS = (
"hf_cache/w2v-bert-2.0",
)
MODEL_REPO_ID = "IndexTeam/IndexTTS-2"
REQUIRED_PACKAGES = ("torch", "torchaudio", "indextts")
PERSISTED_CONFIG_KEYS = (
@ -322,6 +330,13 @@ def _download_model_resources(source, model_dir):
from huggingface_hub import snapshot_download
snapshot_download(repo_id=MODEL_REPO_ID, local_dir=str(model_dir))
if _missing_primary_model_resources(model_dir):
return
from indextts.utils.model_download import ensure_models_available
ensure_models_available(str(model_dir))
def _download_support_package(source):
if source == "auto":
@ -1517,11 +1532,26 @@ def _print_model_resource_help(model_dir, missing_summary):
def _missing_model_files(model_dir):
if not model_dir.is_dir():
return None
missing_files = _missing_primary_model_resources(model_dir)
missing_aux_files = [
filename for filename in REQUIRED_AUX_MODEL_FILES if not _model_resource_path(model_dir, filename).is_file()
]
missing_aux_dirs = [
dirname for dirname in REQUIRED_AUX_MODEL_DIRS if not _model_resource_path(model_dir, dirname).is_dir()
]
return missing_files + missing_aux_files + missing_aux_dirs
def _missing_primary_model_resources(model_dir):
missing_files = [filename for filename in REQUIRED_MODEL_FILES if not (model_dir / filename).is_file()]
missing_dirs = [dirname for dirname in REQUIRED_MODEL_DIRS if not (model_dir / dirname).is_dir()]
return missing_files + missing_dirs
def _model_resource_path(model_dir, relative_path):
return model_dir.joinpath(*relative_path.split("/"))
def _import_required_packages():
missing = []
imported = {}