diff --git a/tests/app/services/model_load/test_load_api.py b/tests/app/services/model_load/test_load_api.py index 9671c8c6c3..8dd948692d 100644 --- a/tests/app/services/model_load/test_load_api.py +++ b/tests/app/services/model_load/test_load_api.py @@ -57,11 +57,13 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) - assert isinstance(loaded_model_3.model, dict) assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"]) + def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None: loaded_model = mock_context.models.load_and_cache_model(vae_directory) assert isinstance(loaded_model, LoadedModelWithoutConfig) assert isinstance(loaded_model.model, AutoencoderTiny) + def test_download_and_load(mock_context: InvocationContext) -> None: loaded_model_1 = mock_context.models.load_and_cache_model( "https://www.test.foo/download/test_embedding.safetensors" diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index ee66c459b8..9ce272fc42 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -60,6 +60,7 @@ def mm2_model_files(tmp_path_factory) -> Path: def embedding_file(mm2_model_files: Path) -> Path: return mm2_model_files / "test_embedding.safetensors" + @pytest.fixture def vae_directory(mm2_model_files: Path) -> Path: return mm2_model_files / "taesdxl"