mirror of
https://github.com/index-tts/index-tts.git
synced 2026-07-02 16:37:05 +00:00
fix: fix model download redirect issue, dependency placement, and test efficiency (#706)
Some checks failed
CI / Lightweight Python checks (push) Has been cancelled
Some checks failed
CI / Lightweight Python checks (push) Has been cancelled
* fix: fix model download redirect issue and add test * fix: address review comments on requests dep, connection leak, and test download cap * fix: fix inconsistency of testcase and code * fix: address remaining PR review feedback * fix: cli download uses auto network detection, matching webui behavior * fix: use ModelScope local_dir directly, remove redundant tmp download and file shuffling * fix: update cli download tests to match new auto/sdk-based download paths * Fix ModelScope single-file local path handling --------- Co-authored-by: nanaoto <10inspiral@gmail.com> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
This commit is contained in:
parent
326f751e04
commit
5ad3403636
9 changed files with 183 additions and 154 deletions
|
|
@ -1,12 +1,10 @@
|
|||
import contextlib
|
||||
import importlib
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest import mock
|
||||
|
||||
|
||||
|
|
@ -72,21 +70,16 @@ class DownloadCommandTests(unittest.TestCase):
|
|||
|
||||
def test_download_defaults_to_huggingface_source_and_checks_downloaded_resources(self):
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
state = user_state_paths(Path(temp_dir))
|
||||
state = user_state_paths(Path(temp_dir).resolve())
|
||||
calls = []
|
||||
real_import_module = importlib.import_module
|
||||
|
||||
def snapshot_download(*, repo_id, local_dir):
|
||||
def fake_snapshot_download(repo_id, local_dir, **kwargs):
|
||||
calls.append((repo_id, Path(local_dir)))
|
||||
make_model_dir(Path(local_dir))
|
||||
|
||||
def import_module(name, package=None):
|
||||
if name == "huggingface_hub":
|
||||
return SimpleNamespace(snapshot_download=snapshot_download)
|
||||
return real_import_module(name, package)
|
||||
return str(local_dir)
|
||||
|
||||
with mock.patch.dict(os.environ, state["env"], clear=False):
|
||||
with mock.patch("importlib.import_module", side_effect=import_module):
|
||||
with mock.patch("indextts.utils.model_download.snapshot_download", side_effect=fake_snapshot_download):
|
||||
exit_code, stdout, stderr = self.run_cli(["download"])
|
||||
config_exists = state["config_path"].exists()
|
||||
|
||||
|
|
@ -98,23 +91,18 @@ class DownloadCommandTests(unittest.TestCase):
|
|||
|
||||
def test_download_from_modelscope_to_model_dir_persists_successful_target_directory(self):
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
temp_path = Path(temp_dir).resolve()
|
||||
state = user_state_paths(temp_path)
|
||||
model_dir = temp_path / "custom-models"
|
||||
calls = []
|
||||
real_import_module = importlib.import_module
|
||||
|
||||
def snapshot_download(model_id, *, local_dir):
|
||||
def fake_snapshot(model_id, local_dir, **kwargs):
|
||||
calls.append((model_id, Path(local_dir)))
|
||||
make_model_dir(Path(local_dir))
|
||||
|
||||
def import_module(name, package=None):
|
||||
if name == "modelscope":
|
||||
return SimpleNamespace(snapshot_download=snapshot_download)
|
||||
return real_import_module(name, package)
|
||||
return str(local_dir)
|
||||
|
||||
with mock.patch.dict(os.environ, state["env"], clear=False):
|
||||
with mock.patch("importlib.import_module", side_effect=import_module):
|
||||
with mock.patch("indextts.utils.model_download._snapshot_from_modelscope", side_effect=fake_snapshot):
|
||||
exit_code, stdout, stderr = self.run_cli(
|
||||
["download", "--source", "modelscope", "--model-dir", str(model_dir)]
|
||||
)
|
||||
|
|
@ -128,27 +116,21 @@ class DownloadCommandTests(unittest.TestCase):
|
|||
|
||||
def test_download_from_huggingface_preserves_existing_files_in_model_dir(self):
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
temp_path = Path(temp_dir).resolve()
|
||||
state = user_state_paths(temp_path)
|
||||
model_dir = temp_path / "models"
|
||||
sentinel = model_dir / "keep.txt"
|
||||
calls = []
|
||||
model_dir.mkdir()
|
||||
sentinel.write_text("keep", encoding="utf-8")
|
||||
real_import_module = importlib.import_module
|
||||
|
||||
def snapshot_download(*, repo_id, local_dir):
|
||||
def fake_snapshot_download(*, repo_id, local_dir):
|
||||
target = Path(local_dir)
|
||||
calls.append((repo_id, target, sentinel.exists()))
|
||||
make_model_dir(target)
|
||||
|
||||
def import_module(name, package=None):
|
||||
if name == "huggingface_hub":
|
||||
return SimpleNamespace(snapshot_download=snapshot_download)
|
||||
return real_import_module(name, package)
|
||||
|
||||
with mock.patch.dict(os.environ, state["env"], clear=False):
|
||||
with mock.patch("importlib.import_module", side_effect=import_module):
|
||||
with mock.patch("huggingface_hub.snapshot_download", side_effect=fake_snapshot_download):
|
||||
exit_code, stdout, stderr = self.run_cli(
|
||||
["download", "--source", "huggingface", "--model-dir", str(model_dir)]
|
||||
)
|
||||
|
|
@ -162,21 +144,16 @@ class DownloadCommandTests(unittest.TestCase):
|
|||
|
||||
def test_download_with_no_save_does_not_persist_model_dir_override(self):
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
temp_path = Path(temp_dir).resolve()
|
||||
state = user_state_paths(temp_path)
|
||||
model_dir = temp_path / "temporary-models"
|
||||
real_import_module = importlib.import_module
|
||||
|
||||
def snapshot_download(*, repo_id, local_dir):
|
||||
def fake_snapshot_download(repo_id, local_dir, **kwargs):
|
||||
make_model_dir(Path(local_dir))
|
||||
|
||||
def import_module(name, package=None):
|
||||
if name == "huggingface_hub":
|
||||
return SimpleNamespace(snapshot_download=snapshot_download)
|
||||
return real_import_module(name, package)
|
||||
return str(local_dir)
|
||||
|
||||
with mock.patch.dict(os.environ, state["env"], clear=False):
|
||||
with mock.patch("importlib.import_module", side_effect=import_module):
|
||||
with mock.patch("indextts.utils.model_download.snapshot_download", side_effect=fake_snapshot_download):
|
||||
exit_code, stdout, stderr = self.run_cli(
|
||||
["download", "--model-dir", str(model_dir), "--no-save"]
|
||||
)
|
||||
|
|
@ -189,38 +166,31 @@ class DownloadCommandTests(unittest.TestCase):
|
|||
|
||||
def test_download_returns_runtime_unavailable_when_source_package_is_missing(self):
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
state = user_state_paths(Path(temp_dir))
|
||||
real_import_module = importlib.import_module
|
||||
state = user_state_paths(Path(temp_dir).resolve())
|
||||
|
||||
def import_module(name, package=None):
|
||||
if name == "huggingface_hub":
|
||||
raise ImportError("No module named huggingface_hub")
|
||||
return real_import_module(name, package)
|
||||
def raise_import(*args, **kwargs):
|
||||
raise ImportError("No module named huggingface_hub")
|
||||
|
||||
with mock.patch.dict(os.environ, state["env"], clear=False):
|
||||
with mock.patch("importlib.import_module", side_effect=import_module):
|
||||
with mock.patch("indextts.utils.model_download.snapshot_download", side_effect=raise_import):
|
||||
exit_code, stdout, stderr = self.run_cli(["download"])
|
||||
config_exists = state["config_path"].exists()
|
||||
|
||||
self.assertEqual(exit_code, 3)
|
||||
self.assertEqual(stdout, "")
|
||||
self.assertIn("ERROR: runtime unavailable for huggingface download source", stderr)
|
||||
self.assertIn("huggingface_hub", stderr)
|
||||
self.assertIn("pip install huggingface_hub", stderr)
|
||||
self.assertIn("ERROR: runtime unavailable for auto download source", stderr)
|
||||
self.assertIn("pip install huggingface_hub modelscope", stderr)
|
||||
self.assertFalse(config_exists)
|
||||
|
||||
def test_download_from_modelscope_returns_runtime_unavailable_when_source_package_is_missing(self):
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
state = user_state_paths(Path(temp_dir))
|
||||
real_import_module = importlib.import_module
|
||||
state = user_state_paths(Path(temp_dir).resolve())
|
||||
|
||||
def import_module(name, package=None):
|
||||
if name == "modelscope":
|
||||
raise ImportError("No module named modelscope")
|
||||
return real_import_module(name, package)
|
||||
def raise_import(*args, **kwargs):
|
||||
raise ImportError("No module named modelscope")
|
||||
|
||||
with mock.patch.dict(os.environ, state["env"], clear=False):
|
||||
with mock.patch("importlib.import_module", side_effect=import_module):
|
||||
with mock.patch("indextts.utils.model_download._snapshot_from_modelscope", side_effect=raise_import):
|
||||
exit_code, stdout, stderr = self.run_cli(["download", "--source", "modelscope"])
|
||||
config_exists = state["config_path"].exists()
|
||||
|
||||
|
|
@ -228,28 +198,22 @@ class DownloadCommandTests(unittest.TestCase):
|
|||
self.assertEqual(stdout, "")
|
||||
self.assertIn("ERROR: runtime unavailable for modelscope download source", stderr)
|
||||
self.assertIn("modelscope", stderr)
|
||||
self.assertIn("pip install modelscope", stderr)
|
||||
self.assertFalse(config_exists)
|
||||
|
||||
def test_download_validates_downloaded_resources_before_persisting_model_dir(self):
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
temp_path = Path(temp_dir).resolve()
|
||||
state = user_state_paths(temp_path)
|
||||
model_dir = temp_path / "incomplete-models"
|
||||
real_import_module = importlib.import_module
|
||||
|
||||
def snapshot_download(*, repo_id, local_dir):
|
||||
def fake_snapshot_download(repo_id, local_dir, **kwargs):
|
||||
target = Path(local_dir)
|
||||
target.mkdir(parents=True, exist_ok=True)
|
||||
(target / "config.yaml").write_text("placeholder", encoding="utf-8")
|
||||
|
||||
def import_module(name, package=None):
|
||||
if name == "huggingface_hub":
|
||||
return SimpleNamespace(snapshot_download=snapshot_download)
|
||||
return real_import_module(name, package)
|
||||
return str(local_dir)
|
||||
|
||||
with mock.patch.dict(os.environ, state["env"], clear=False):
|
||||
with mock.patch("importlib.import_module", side_effect=import_module):
|
||||
with mock.patch("indextts.utils.model_download.snapshot_download", side_effect=fake_snapshot_download):
|
||||
exit_code, stdout, stderr = self.run_cli(["download", "--model-dir", str(model_dir)])
|
||||
config_exists = state["config_path"].exists()
|
||||
|
||||
|
|
|
|||
|
|
@ -118,9 +118,9 @@ def _build_parser():
|
|||
)
|
||||
download.add_argument(
|
||||
"--source",
|
||||
choices=("huggingface", "modelscope"),
|
||||
default="huggingface",
|
||||
help="Model download source",
|
||||
choices=("huggingface", "modelscope", "auto"),
|
||||
default="auto",
|
||||
help="Model download source (default: auto-detect based on network)",
|
||||
)
|
||||
download.add_argument(
|
||||
"--model-dir",
|
||||
|
|
@ -312,15 +312,20 @@ def _run_download(args):
|
|||
|
||||
|
||||
def _download_model_resources(source, model_dir):
|
||||
if source == "huggingface":
|
||||
huggingface_hub = importlib.import_module("huggingface_hub")
|
||||
huggingface_hub.snapshot_download(repo_id=MODEL_REPO_ID, local_dir=str(model_dir))
|
||||
return
|
||||
modelscope = importlib.import_module("modelscope")
|
||||
modelscope.snapshot_download(MODEL_REPO_ID, local_dir=str(model_dir))
|
||||
if source == "auto":
|
||||
from indextts.utils.model_download import snapshot_download
|
||||
snapshot_download(MODEL_REPO_ID, local_dir=str(model_dir))
|
||||
elif source == "modelscope":
|
||||
from indextts.utils.model_download import _snapshot_from_modelscope
|
||||
_snapshot_from_modelscope(MODEL_REPO_ID, str(model_dir))
|
||||
else:
|
||||
from huggingface_hub import snapshot_download
|
||||
snapshot_download(repo_id=MODEL_REPO_ID, local_dir=str(model_dir))
|
||||
|
||||
|
||||
def _download_support_package(source):
|
||||
if source == "auto":
|
||||
return "huggingface_hub modelscope"
|
||||
if source == "huggingface":
|
||||
return "huggingface_hub"
|
||||
return "modelscope"
|
||||
|
|
|
|||
|
|
@ -12,11 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
|
|
@ -26,7 +23,7 @@ def load_checkpoint(model: torch.nn.Module, model_pth: str) -> dict:
|
|||
checkpoint = torch.load(model_pth, map_location='cpu')
|
||||
checkpoint = checkpoint['model'] if 'model' in checkpoint else checkpoint
|
||||
model.load_state_dict(checkpoint, strict=True)
|
||||
info_path = re.sub('.pth$', '.yaml', model_pth)
|
||||
info_path = str(Path(model_pth).with_suffix('.yaml'))
|
||||
configs = {}
|
||||
if os.path.exists(info_path):
|
||||
with open(info_path, 'r') as fin:
|
||||
|
|
|
|||
|
|
@ -13,7 +13,8 @@ import json
|
|||
import logging
|
||||
import os
|
||||
from typing import List, Set
|
||||
from urllib.request import urlopen, Request
|
||||
|
||||
import requests
|
||||
|
||||
from indextts.utils.network_detection import need_proxy
|
||||
|
||||
|
|
@ -32,19 +33,29 @@ _EXTRA_FILES = [
|
|||
"voice_01.wav", # used in infer.py and infer_v2.py __main__ blocks
|
||||
]
|
||||
|
||||
_SESSION = requests.Session()
|
||||
_SESSION.headers.update({"User-Agent": "IndexTTS/2.0"})
|
||||
|
||||
def _download_file(url: str, local_path: str, timeout: int = 60, min_size: int = 0) -> None:
|
||||
|
||||
def _download_file(
|
||||
url: str,
|
||||
local_path: str,
|
||||
timeout: int = 60,
|
||||
min_size: int = 0,
|
||||
max_bytes: int = 0,
|
||||
) -> None:
|
||||
"""
|
||||
Download a file from a URL to a local path with validation.
|
||||
|
||||
Raises RuntimeError if the server returns an error or non-binary content.
|
||||
When *max_bytes* > 0 the download stops after that many bytes (useful for
|
||||
reachability checks that don't need the full file).
|
||||
"""
|
||||
req = Request(url, headers={"User-Agent": "IndexTTS/2.0"})
|
||||
with urlopen(req, timeout=timeout) as response:
|
||||
status = response.status
|
||||
if status < 200 or status >= 300:
|
||||
raise RuntimeError(f"Server returned HTTP {status} for {url}")
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
resp = _SESSION.get(url, timeout=timeout, stream=True)
|
||||
try:
|
||||
if resp.status_code < 200 or resp.status_code >= 300:
|
||||
raise RuntimeError(f"Server returned HTTP {resp.status_code} for {url}")
|
||||
content_type = resp.headers.get("Content-Type", "")
|
||||
if "text/html" in content_type:
|
||||
raise RuntimeError(
|
||||
f"Server returned HTML instead of binary file for {url} "
|
||||
|
|
@ -55,11 +66,12 @@ def _download_file(url: str, local_path: str, timeout: int = 60, min_size: int =
|
|||
tmp_path = local_path + ".tmp"
|
||||
try:
|
||||
with open(tmp_path, "wb") as f:
|
||||
while True:
|
||||
chunk = response.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
received = 0
|
||||
for chunk in resp.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
received += len(chunk)
|
||||
if max_bytes and received >= max_bytes:
|
||||
break
|
||||
if min_size and os.path.getsize(tmp_path) < min_size:
|
||||
raise RuntimeError(
|
||||
f"Downloaded file is suspiciously small "
|
||||
|
|
@ -69,6 +81,8 @@ def _download_file(url: str, local_path: str, timeout: int = 60, min_size: int =
|
|||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
resp.close()
|
||||
|
||||
|
||||
def get_required_files() -> List[str]:
|
||||
|
|
@ -118,7 +132,7 @@ def ensure_examples_available(force: bool = False) -> None:
|
|||
continue
|
||||
url = f"{base_url}/examples/{filename}"
|
||||
try:
|
||||
_download_file(url, local_path, min_size=100)
|
||||
_download_file(url, local_path, min_size=100, timeout=120)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to download {filename}: {e}")
|
||||
|
||||
|
|
@ -139,7 +153,7 @@ def download_test_sample(force: bool = False) -> str:
|
|||
base_url = _MS_RAW_URL if need_proxy() else _HF_RAW_URL
|
||||
url = f"{base_url}/examples/voice_01.wav"
|
||||
|
||||
_download_file(url, local_path, min_size=100)
|
||||
_download_file(url, local_path, min_size=100, timeout=120)
|
||||
return local_path
|
||||
|
||||
# Alias for backward compatibility (used by tests)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ via ``ensure_models_available()``, so no downloads happen during inference.
|
|||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -36,30 +35,48 @@ _BIGVGAN_REPO = "nvidia/bigvgan_v2_22khz_80band_256x"
|
|||
|
||||
def _download_single_file(repo_id: str, filename: str, local_path: str) -> str:
|
||||
"""Download a single file from a HF/ModelScope repo to a specific local path."""
|
||||
from indextts.utils.examples_downloader import _download_file
|
||||
|
||||
os.makedirs(os.path.dirname(local_path), exist_ok=True)
|
||||
local_dir = os.path.dirname(local_path)
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
|
||||
if _get_using_modelscope():
|
||||
ms_model_id = HF_TO_MODELSCOPE_REPO_MAP.get(repo_id, repo_id)
|
||||
# Try ModelScope file_download first
|
||||
# Try ModelScope SDK first
|
||||
try:
|
||||
from modelscope.hub.file_download import model_file_download
|
||||
tmp = model_file_download(model_id=ms_model_id, file_path=filename)
|
||||
shutil.copy2(tmp, local_path)
|
||||
downloaded_path = model_file_download(
|
||||
model_id=ms_model_id, file_path=filename, local_dir=local_dir,
|
||||
)
|
||||
if not downloaded_path or not os.path.isfile(downloaded_path):
|
||||
downloaded_path = os.path.join(local_dir, filename)
|
||||
if os.path.abspath(downloaded_path) != os.path.abspath(local_path):
|
||||
shutil.copy2(downloaded_path, local_path)
|
||||
if not os.path.isfile(local_path):
|
||||
raise RuntimeError(f"Downloaded file not found at expected path: {local_path}")
|
||||
return local_path
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"ModelScope download failed for {ms_model_id}/{filename}: {e}. Falling back to hf-mirror.",
|
||||
exc_info=True,
|
||||
)
|
||||
# Fallback to hf-mirror.com
|
||||
# Fallback to hf-mirror.com (only path that needs manual download)
|
||||
from indextts.utils.examples_downloader import _download_file
|
||||
url = f"https://hf-mirror.com/{repo_id}/resolve/main/{filename}"
|
||||
logger.info(f"Downloading {repo_id}/{filename} from hf-mirror -> {local_path}")
|
||||
_download_file(url, local_path, timeout=300)
|
||||
else:
|
||||
url = f"https://huggingface.co/{repo_id}/resolve/main/{filename}"
|
||||
# Use HuggingFace Hub SDK
|
||||
from huggingface_hub import hf_hub_download
|
||||
logger.info(f"Downloading {repo_id}/{filename} -> {local_path}")
|
||||
downloaded_path = hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
|
||||
if downloaded_path and os.path.abspath(downloaded_path) != os.path.abspath(local_path):
|
||||
shutil.copy2(downloaded_path, local_path)
|
||||
elif not os.path.isfile(local_path):
|
||||
fallback_path = os.path.join(local_dir, filename)
|
||||
if os.path.isfile(fallback_path):
|
||||
shutil.copy2(fallback_path, local_path)
|
||||
if not os.path.isfile(local_path):
|
||||
raise RuntimeError(f"Downloaded file not found at expected path: {local_path}")
|
||||
|
||||
logger.info(f"Downloading {repo_id}/{filename} -> {local_path}")
|
||||
_download_file(url, local_path, timeout=300)
|
||||
return local_path
|
||||
|
||||
|
||||
|
|
@ -230,29 +247,6 @@ def _snapshot_from_modelscope(model_id: str, local_dir: str, revision=None) -> s
|
|||
from modelscope.hub.snapshot_download import snapshot_download as _ms_snapshot
|
||||
logger.info(f"Downloading repo from ModelScope: {ms_model_id}")
|
||||
|
||||
# Check if files exist in a subdirectory from a previous download
|
||||
existing_subdir = os.path.join(local_dir, ms_model_id)
|
||||
if os.path.isdir(existing_subdir) and os.listdir(existing_subdir):
|
||||
for item in os.listdir(existing_subdir):
|
||||
src = os.path.join(existing_subdir, item)
|
||||
dst = os.path.join(local_dir, item)
|
||||
if not os.path.exists(dst):
|
||||
shutil.move(src, dst)
|
||||
shutil.rmtree(existing_subdir, ignore_errors=True)
|
||||
return local_dir
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
_ms_snapshot(model_id=ms_model_id, cache_dir=tmpdir, revision=revision)
|
||||
downloaded = os.path.join(tmpdir, ms_model_id)
|
||||
if not os.path.isdir(downloaded):
|
||||
for root, dirs, files in os.walk(tmpdir):
|
||||
if files and root != tmpdir:
|
||||
downloaded = root
|
||||
break
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
for item in os.listdir(downloaded):
|
||||
src = os.path.join(downloaded, item)
|
||||
dst = os.path.join(local_dir, item)
|
||||
if not os.path.exists(dst):
|
||||
shutil.move(src, dst)
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
_ms_snapshot(model_id=ms_model_id, local_dir=local_dir, revision=revision)
|
||||
return local_dir
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ dependencies = [
|
|||
"tokenizers==0.21.0",
|
||||
"torch==2.8.*",
|
||||
"torchaudio==2.8.*",
|
||||
"requests>=2.28",
|
||||
"tqdm>=4.67.1",
|
||||
"transformers==4.52.1",
|
||||
|
||||
|
|
@ -73,7 +74,6 @@ deepspeed = [
|
|||
]
|
||||
test = [
|
||||
"pytest>=7.0",
|
||||
"requests>=2.28",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
|
|
|||
|
|
@ -1,2 +1,15 @@
|
|||
# padding_test.py is a manual script, not a pytest test module.
|
||||
collect_ignore = ["padding_test.py"]
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--use-modelscope", action="store_true", default=False,
|
||||
help="Force download tests to use ModelScope path (for CN network testing)",
|
||||
)
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
if config.getoption("--use-modelscope"):
|
||||
import indextts.utils.model_download as md
|
||||
md._USING_MODELSCOPE = True
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ CI only (no GPU):
|
|||
uv run --extra test pytest tests/test_v2.py -v -m "not gpu"
|
||||
"""
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
|
@ -41,35 +43,39 @@ def prompt_wav():
|
|||
|
||||
# -- Download URL checks (no GPU) ---------------------------------------------
|
||||
|
||||
def _get_download_urls():
|
||||
"""Return URLs that match the actual network environment (CN vs intl)."""
|
||||
from indextts.utils.network_detection import need_proxy
|
||||
if need_proxy():
|
||||
# China: test ModelScope endpoints
|
||||
return [
|
||||
"https://modelscope.cn/studio/IndexTeam/IndexTTS-2-Demo/resolve/master/examples/voice_01.wav",
|
||||
"https://modelscope.cn/models/AI-ModelScope/w2v-bert-2.0/resolve/master/config.json",
|
||||
]
|
||||
else:
|
||||
# International: test HuggingFace endpoints
|
||||
return [
|
||||
"https://huggingface.co/spaces/IndexTeam/IndexTTS-2-Demo/resolve/main/examples/voice_01.wav",
|
||||
"https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_256x/resolve/main/config.json",
|
||||
]
|
||||
# Each auxiliary model: (test_id, repo_id, probe_file)
|
||||
# probe_file: a small file in the repo to verify download works end-to-end.
|
||||
_MODEL_PROBES = [
|
||||
("bigvgan", "nvidia/bigvgan_v2_22khz_80band_256x", "config.json"),
|
||||
("w2v-bert-2.0", "facebook/w2v-bert-2.0", "config.json"),
|
||||
("campplus", "funasr/campplus", "campplus_cn_common.bin"),
|
||||
("MaskGCT", "amphion/MaskGCT", "README.md"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("url", _get_download_urls())
|
||||
def test_download_url_reachable(url, tmp_path):
|
||||
"""Use the actual _download_file function to verify URLs are reachable."""
|
||||
@pytest.mark.parametrize("name,repo_id,filename", _MODEL_PROBES, ids=[m[0] for m in _MODEL_PROBES])
|
||||
def test_model_download_reachable(name, repo_id, filename, tmp_path):
|
||||
"""Each auxiliary model must be downloadable via the real download path."""
|
||||
from indextts.utils.examples_downloader import _download_file
|
||||
from indextts.utils.network_detection import need_proxy
|
||||
|
||||
dest = tmp_path / "downloaded"
|
||||
# Real download, same code path as production. No min_size so we don't
|
||||
# need to pull a full large file — just verify it doesn't raise.
|
||||
_download_file(url, str(dest), timeout=30)
|
||||
base_url = "https://hf-mirror.com" if need_proxy() else "https://huggingface.co"
|
||||
url = f"{base_url}/{repo_id}/resolve/main/{filename}"
|
||||
dest = tmp_path / filename
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
_download_file(url, str(dest), max_bytes=8192)
|
||||
assert dest.exists() and dest.stat().st_size > 0
|
||||
|
||||
|
||||
def test_example_download_reachable(tmp_path, monkeypatch):
|
||||
"""Example audio must be downloadable via the real download path."""
|
||||
from indextts.utils import examples_downloader
|
||||
|
||||
monkeypatch.setattr(examples_downloader, "_TESTS_DIR", str(tmp_path))
|
||||
path = examples_downloader.download_test_sample(force=True)
|
||||
assert Path(path).exists() and Path(path).stat().st_size > 0
|
||||
|
||||
|
||||
# -- Model download logic (no GPU) --------------------------------------------
|
||||
|
||||
def test_legacy_cache_compatibility(tmp_path, monkeypatch):
|
||||
|
|
@ -117,6 +123,42 @@ def test_legacy_cache_compatibility(tmp_path, monkeypatch):
|
|||
assert len(download_calls) == 0, f"Unexpected downloads: {download_calls}"
|
||||
|
||||
|
||||
def test_modelscope_single_file_download_matches_local_path(tmp_path, monkeypatch):
|
||||
"""ModelScope single-file download must produce the exact requested local_path."""
|
||||
from indextts.utils import model_download
|
||||
|
||||
local_path = tmp_path / "hf_cache" / "semantic_codec_model.safetensors"
|
||||
expected_bytes = b"fake_semantic"
|
||||
|
||||
def fake_model_file_download(model_id, file_path, local_dir):
|
||||
downloaded = Path(local_dir) / file_path
|
||||
downloaded.parent.mkdir(parents=True, exist_ok=True)
|
||||
downloaded.write_bytes(expected_bytes)
|
||||
return str(downloaded)
|
||||
|
||||
fake_modelscope = types.ModuleType("modelscope")
|
||||
fake_hub = types.ModuleType("modelscope.hub")
|
||||
fake_file_download = types.ModuleType("modelscope.hub.file_download")
|
||||
fake_file_download.model_file_download = fake_model_file_download
|
||||
fake_hub.file_download = fake_file_download
|
||||
fake_modelscope.hub = fake_hub
|
||||
|
||||
monkeypatch.setitem(sys.modules, "modelscope", fake_modelscope)
|
||||
monkeypatch.setitem(sys.modules, "modelscope.hub", fake_hub)
|
||||
monkeypatch.setitem(sys.modules, "modelscope.hub.file_download", fake_file_download)
|
||||
monkeypatch.setattr(model_download, "_get_using_modelscope", lambda: True)
|
||||
|
||||
got = model_download._download_single_file(
|
||||
repo_id="amphion/MaskGCT",
|
||||
filename="semantic_codec/model.safetensors",
|
||||
local_path=str(local_path),
|
||||
)
|
||||
|
||||
assert got == str(local_path)
|
||||
assert local_path.exists()
|
||||
assert local_path.read_bytes() == expected_bytes
|
||||
|
||||
|
||||
# -- Inference (GPU required) --------------------------------------------------
|
||||
|
||||
INFER_TEXTS = [
|
||||
|
|
|
|||
4
uv.lock
generated
4
uv.lock
generated
|
|
@ -1183,6 +1183,7 @@ dependencies = [
|
|||
{ name = "omegaconf" },
|
||||
{ name = "opencv-python" },
|
||||
{ name = "pandas" },
|
||||
{ name = "requests" },
|
||||
{ name = "safetensors" },
|
||||
{ name = "sentencepiece" },
|
||||
{ name = "tensorboard" },
|
||||
|
|
@ -1204,7 +1205,6 @@ deepspeed = [
|
|||
]
|
||||
test = [
|
||||
{ name = "pytest" },
|
||||
{ name = "requests" },
|
||||
]
|
||||
webui = [
|
||||
{ name = "gradio" },
|
||||
|
|
@ -1234,7 +1234,7 @@ requires-dist = [
|
|||
{ name = "opencv-python", specifier = "==4.9.0.80" },
|
||||
{ name = "pandas", specifier = "==2.3.2" },
|
||||
{ name = "pytest", marker = "extra == 'test'", specifier = ">=7.0" },
|
||||
{ name = "requests", marker = "extra == 'test'", specifier = ">=2.28" },
|
||||
{ name = "requests", specifier = ">=2.28" },
|
||||
{ name = "safetensors", specifier = "==0.5.2" },
|
||||
{ name = "sentencepiece", specifier = ">=0.2.1" },
|
||||
{ name = "tensorboard", specifier = "==2.9.1" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue