Tighten launcher and CUDA compat guards

This commit is contained in:
Rich 2026-03-13 19:16:59 +00:00
parent de512c0477
commit 88b90f9192
6 changed files with 29 additions and 4 deletions

View file

@ -22,9 +22,11 @@ class InitServiceLoaderMixin:
_ = torch.tensor([True, False], device=target_device).argsort()
return True
except RuntimeError as exc:
if "bool dtype" in str(exc):
return False
return True
logger.debug(
"[_cuda_supports_bool_argsort] Treating CUDA bool argsort probe failure as unsupported: {}",
exc,
)
return False
def _apply_cuda_bool_argsort_workaround(self) -> None:
"""Patch dynamic model helpers when bool argsort is unsupported on CUDA."""

View file

@ -512,6 +512,16 @@ class InitServiceMixinTests(unittest.TestCase):
self.assertIsNot(dummy_module.pack_sequences, _pack_sequences)
self.assertTrue(getattr(dummy_module.pack_sequences, "__acestep_bool_argsort_patched__", False))
def test_cuda_supports_bool_argsort_returns_false_for_unexpected_runtime_error(self):
    """It reports bool argsort as unsupported when the probe raises an unrelated RuntimeError."""
    host = _Host(project_root="K:/fake_root", device="cuda")
    # Pretend CUDA is present so the probe path is taken.
    cuda_available = patch("torch.cuda.is_available", return_value=True)
    # Make the probe's tensor construction fail with a message that does
    # not mention "bool dtype".
    failing_tensor = patch(
        "torch.tensor",
        side_effect=RuntimeError("unexpected argsort failure"),
    )
    with cuda_available, failing_tensor:
        supported = host._cuda_supports_bool_argsort()
        self.assertFalse(supported)
def test_validate_quantization_setup_raises_import_error_when_torchao_missing(self):
"""It raises ImportError with guidance when torchao is unavailable."""
host = _Host(project_root="K:/fake_root", device="cpu")

View file

@ -28,6 +28,8 @@ class LauncherLegacyTorchFixTests(unittest.TestCase):
content = self._read("start_api_server.sh")
self.assertIn("ACESTEP_SKIP_LEGACY_TORCH_FIX", content)
self.assertIn("legacy_torch_fix_probe_exit_code", content)
self.assertIn("legacy NVIDIA compatibility probe failed with exit code $compat_status", content)
self.assertIn("return 1", content)
self.assertIn("torch==2.5.1+cu121", content)
def test_windows_gradio_launcher_calls_shared_probe(self) -> None:
@ -36,6 +38,8 @@ class LauncherLegacyTorchFixTests(unittest.TestCase):
self.assertIn('if /i "%ACESTEP_SKIP_LEGACY_TORCH_FIX%"=="true"', content)
self.assertIn("legacy_torch_fix_probe_exit_code", content)
self.assertIn("torch==2.5.1+cu121", content)
self.assertEqual(content.count("call :EnsureLegacyNvidiaTorchCompat"), 1)
self.assertEqual(content.count("if !ERRORLEVEL! NEQ 0 exit /b !ERRORLEVEL!"), 1)
def test_windows_api_launcher_calls_shared_probe(self) -> None:
"""Windows API launcher should call shared Python compatibility probe."""
@ -43,6 +47,8 @@ class LauncherLegacyTorchFixTests(unittest.TestCase):
self.assertIn('if /i "%ACESTEP_SKIP_LEGACY_TORCH_FIX%"=="true"', content)
self.assertIn("legacy_torch_fix_probe_exit_code", content)
self.assertIn("torch==2.5.1+cu121", content)
self.assertEqual(content.count("call :EnsureLegacyNvidiaTorchCompat"), 1)
self.assertEqual(content.count("if !ERRORLEVEL! NEQ 0 exit /b !ERRORLEVEL!"), 1)
if __name__ == "__main__":

View file

@ -271,6 +271,7 @@ if exist "%~dp0python_embedded\python.exe" (
)
call :EnsureLegacyNvidiaTorchCompat
if !ERRORLEVEL! NEQ 0 exit /b !ERRORLEVEL!
echo Starting ACE-Step API Server...
echo.

View file

@ -180,9 +180,13 @@ _ensure_legacy_nvidia_torch_compat() {
compat_status=$?
fi
if [[ "$compat_status" -ne 42 ]]; then
if [[ "$compat_status" -eq 0 ]]; then
return 0
fi
if [[ "$compat_status" -ne 42 ]]; then
echo "[Compatibility] Error: legacy NVIDIA compatibility probe failed with exit code $compat_status."
return "$compat_status"
fi
echo "[Compatibility] Applying legacy NVIDIA torch build (CUDA 12.1, supports sm_61)..."
if (cd "$SCRIPT_DIR" && uv pip install --python .venv/bin/python --force-reinstall \
@ -200,6 +204,7 @@ _ensure_legacy_nvidia_torch_compat() {
echo "[Compatibility] Warning: failed to install legacy torch automatically."
echo "[Compatibility] Run manually:"
echo " uv pip install --python .venv/bin/python --force-reinstall --index-url https://download.pytorch.org/whl/cu121 torch==2.5.1+cu121 torchvision==0.20.1+cu121 torchaudio==2.5.1+cu121"
return 1
fi
}

View file

@ -297,6 +297,7 @@ if exist "%~dp0python_embedded\python.exe" (
)
call :EnsureLegacyNvidiaTorchCompat
if !ERRORLEVEL! NEQ 0 exit /b !ERRORLEVEL!
echo Starting ACE-Step Gradio UI...
echo.