mirror of
https://github.com/ace-step/ACE-Step-1.5.git
synced 2026-07-02 16:37:04 +00:00
Tighten launcher and CUDA compat guards
This commit is contained in:
parent
de512c0477
commit
88b90f9192
6 changed files with 29 additions and 4 deletions
|
|
@ -22,9 +22,11 @@ class InitServiceLoaderMixin:
|
|||
_ = torch.tensor([True, False], device=target_device).argsort()
|
||||
return True
|
||||
except RuntimeError as exc:
|
||||
if "bool dtype" in str(exc):
|
||||
return False
|
||||
return True
|
||||
logger.debug(
|
||||
"[_cuda_supports_bool_argsort] Treating CUDA bool argsort probe failure as unsupported: {}",
|
||||
exc,
|
||||
)
|
||||
return False
|
||||
|
||||
def _apply_cuda_bool_argsort_workaround(self) -> None:
|
||||
"""Patch dynamic model helpers when bool argsort is unsupported on CUDA."""
|
||||
|
|
|
|||
|
|
@ -512,6 +512,16 @@ class InitServiceMixinTests(unittest.TestCase):
|
|||
self.assertIsNot(dummy_module.pack_sequences, _pack_sequences)
|
||||
self.assertTrue(getattr(dummy_module.pack_sequences, "__acestep_bool_argsort_patched__", False))
|
||||
|
||||
def test_cuda_supports_bool_argsort_returns_false_for_unexpected_runtime_error(self):
|
||||
"""It treats any CUDA bool argsort RuntimeError as unsupported."""
|
||||
host = _Host(project_root="K:/fake_root", device="cuda")
|
||||
|
||||
with patch("torch.cuda.is_available", return_value=True), patch(
|
||||
"torch.tensor",
|
||||
side_effect=RuntimeError("unexpected argsort failure"),
|
||||
):
|
||||
self.assertFalse(host._cuda_supports_bool_argsort())
|
||||
|
||||
def test_validate_quantization_setup_raises_import_error_when_torchao_missing(self):
|
||||
"""It raises ImportError with guidance when torchao is unavailable."""
|
||||
host = _Host(project_root="K:/fake_root", device="cpu")
|
||||
|
|
|
|||
|
|
@ -28,6 +28,8 @@ class LauncherLegacyTorchFixTests(unittest.TestCase):
|
|||
content = self._read("start_api_server.sh")
|
||||
self.assertIn("ACESTEP_SKIP_LEGACY_TORCH_FIX", content)
|
||||
self.assertIn("legacy_torch_fix_probe_exit_code", content)
|
||||
self.assertIn("legacy NVIDIA compatibility probe failed with exit code $compat_status", content)
|
||||
self.assertIn("return 1", content)
|
||||
self.assertIn("torch==2.5.1+cu121", content)
|
||||
|
||||
def test_windows_gradio_launcher_calls_shared_probe(self) -> None:
|
||||
|
|
@ -36,6 +38,8 @@ class LauncherLegacyTorchFixTests(unittest.TestCase):
|
|||
self.assertIn('if /i "%ACESTEP_SKIP_LEGACY_TORCH_FIX%"=="true"', content)
|
||||
self.assertIn("legacy_torch_fix_probe_exit_code", content)
|
||||
self.assertIn("torch==2.5.1+cu121", content)
|
||||
self.assertEqual(content.count("call :EnsureLegacyNvidiaTorchCompat"), 1)
|
||||
self.assertEqual(content.count("if !ERRORLEVEL! NEQ 0 exit /b !ERRORLEVEL!"), 1)
|
||||
|
||||
def test_windows_api_launcher_calls_shared_probe(self) -> None:
|
||||
"""Windows API launcher should call shared Python compatibility probe."""
|
||||
|
|
@ -43,6 +47,8 @@ class LauncherLegacyTorchFixTests(unittest.TestCase):
|
|||
self.assertIn('if /i "%ACESTEP_SKIP_LEGACY_TORCH_FIX%"=="true"', content)
|
||||
self.assertIn("legacy_torch_fix_probe_exit_code", content)
|
||||
self.assertIn("torch==2.5.1+cu121", content)
|
||||
self.assertEqual(content.count("call :EnsureLegacyNvidiaTorchCompat"), 1)
|
||||
self.assertEqual(content.count("if !ERRORLEVEL! NEQ 0 exit /b !ERRORLEVEL!"), 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -271,6 +271,7 @@ if exist "%~dp0python_embedded\python.exe" (
|
|||
)
|
||||
|
||||
call :EnsureLegacyNvidiaTorchCompat
|
||||
if !ERRORLEVEL! NEQ 0 exit /b !ERRORLEVEL!
|
||||
|
||||
echo Starting ACE-Step API Server...
|
||||
echo.
|
||||
|
|
|
|||
|
|
@ -180,9 +180,13 @@ _ensure_legacy_nvidia_torch_compat() {
|
|||
compat_status=$?
|
||||
fi
|
||||
|
||||
if [[ "$compat_status" -ne 42 ]]; then
|
||||
if [[ "$compat_status" -eq 0 ]]; then
|
||||
return 0
|
||||
fi
|
||||
if [[ "$compat_status" -ne 42 ]]; then
|
||||
echo "[Compatibility] Error: legacy NVIDIA compatibility probe failed with exit code $compat_status."
|
||||
return "$compat_status"
|
||||
fi
|
||||
|
||||
echo "[Compatibility] Applying legacy NVIDIA torch build (CUDA 12.1, supports sm_61)..."
|
||||
if (cd "$SCRIPT_DIR" && uv pip install --python .venv/bin/python --force-reinstall \
|
||||
|
|
@ -200,6 +204,7 @@ _ensure_legacy_nvidia_torch_compat() {
|
|||
echo "[Compatibility] Warning: failed to install legacy torch automatically."
|
||||
echo "[Compatibility] Run manually:"
|
||||
echo " uv pip install --python .venv/bin/python --force-reinstall --index-url https://download.pytorch.org/whl/cu121 torch==2.5.1+cu121 torchvision==0.20.1+cu121 torchaudio==2.5.1+cu121"
|
||||
return 1
|
||||
fi
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -297,6 +297,7 @@ if exist "%~dp0python_embedded\python.exe" (
|
|||
)
|
||||
|
||||
call :EnsureLegacyNvidiaTorchCompat
|
||||
if !ERRORLEVEL! NEQ 0 exit /b !ERRORLEVEL!
|
||||
|
||||
echo Starting ACE-Step Gradio UI...
|
||||
echo.
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue