diff --git a/tests/app/services/bulk_download/test_bulk_download.py b/tests/app/services/bulk_download/test_bulk_download.py new file mode 100644 index 0000000000..4f476c21be --- /dev/null +++ b/tests/app/services/bulk_download/test_bulk_download.py @@ -0,0 +1,319 @@ +import os +from pathlib import Path +from typing import Any +from zipfile import ZipFile + +import pytest + +from invokeai.app.services.board_records.board_records_common import BoardRecord, BoardRecordNotFoundException +from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage +from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadTargetException +from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService +from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.image_records.image_records_common import ( + ImageCategory, + ImageRecordNotFoundException, + ResourceOrigin, +) +from invokeai.app.services.images.images_common import ImageDTO +from invokeai.app.services.images.images_default import ImageService +from invokeai.app.services.invocation_services import InvocationServices +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.shared.pagination import OffsetPaginatedResults +from invokeai.backend.util.logging import InvokeAILogger +from tests.fixtures.event_service import DummyEventService, mock_event_service # noqa: F401,F811 +from tests.fixtures.sqlite_database import create_mock_sqlite_database + + +@pytest.fixture +def mock_image_dto() -> ImageDTO: + """Create a mock ImageDTO.""" + return ImageDTO( + image_name="mock_image.png", + board_id="12345", + image_url="None", + width=100, + height=100, + thumbnail_url="None", + image_origin=ResourceOrigin.INTERNAL, + image_category=ImageCategory.GENERAL, + created_at="None", + updated_at="None", + starred=False, + has_workflow=False, + is_intermediate=False, + ) + + +@pytest.fixture +def mock_services(mock_event_service: DummyEventService) -> InvocationServices: + configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) + logger = InvokeAILogger.get_logger() + db = create_mock_sqlite_database(configuration, logger) + + return InvocationServices( + board_image_records=None, # type: ignore + board_images=None, # type: ignore + board_records=SqliteBoardRecordStorage(db=db), + boards=None, # type: ignore + bulk_download=None, # type: ignore + configuration=None, # type: ignore + events=mock_event_service, + graph_execution_manager=None, # type: ignore + image_files=None, # type: ignore + image_records=None, # type: ignore + images=ImageService(), + invocation_cache=None, # type: ignore + latents=None, # type: ignore + logger=logger, + model_manager=None, # type: ignore + model_records=None, # type: ignore + download_queue=None, # type: ignore + model_install=None, # type: ignore + names=None, # type: ignore + performance_statistics=None, # type: ignore + processor=None, # type: ignore + queue=None, # type: ignore + session_processor=None, # type: ignore + session_queue=None, # type: ignore + urls=None, # type: ignore + workflow_records=None, # type: ignore + ) + + +@pytest.fixture() +def mock_invoker(mock_services: InvocationServices) -> Invoker: + return Invoker(services=mock_services) + + +def test_get_path_when_file_exists(tmp_path: Path) -> None: + """Test get_path when the file exists.""" + + # Create a directory at tmp_path/bulk_downloads + test_bulk_downloads_dir: Path = tmp_path / "bulk_downloads" + test_bulk_downloads_dir.mkdir(parents=True, exist_ok=True) + + # Create a file at tmp_path/bulk_downloads/test.zip + test_file_path: Path = test_bulk_downloads_dir / "test.zip" + test_file_path.touch() + + bulk_download_service = BulkDownloadService(tmp_path) + assert bulk_download_service.get_path("test.zip") == str(test_file_path) + + +def test_get_path_when_file_does_not_exist(tmp_path: Path) -> None: + """Test get_path when the file does not exist.""" + + bulk_download_service = BulkDownloadService(tmp_path) + with pytest.raises(BulkDownloadTargetException): + bulk_download_service.get_path("test") + + +def test_bulk_downloads_dir_created_at_start(tmp_path: Path) -> None: + """Test that the bulk_downloads directory is created at start.""" + + BulkDownloadService(tmp_path) + assert (tmp_path / "bulk_downloads").exists() + + +def test_handler_image_names(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): + """Test that the handler creates the zip file correctly when given a list of image names.""" + + expected_zip_path, expected_image_path, mock_image_contents = prepare_handler_test( + tmp_path, monkeypatch, mock_image_dto, mock_invoker + ) + + bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service.start(mock_invoker) + bulk_download_service.handler([mock_image_dto.image_name], None) + + assert_handler_success( + expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events + ) + + +def test_handler_board_id(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): + """Test that the handler creates the zip file correctly when given a board id.""" + + expected_zip_path, expected_image_path, mock_image_contents = prepare_handler_test( + tmp_path, monkeypatch, mock_image_dto, mock_invoker + ) + + def mock_board_get(*args, **kwargs): + return BoardRecord(board_id="12345", board_name="test", created_at="None", updated_at="None") + + monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_board_get) + + def mock_get_many(*args, **kwargs): + return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto]) + + monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many) + + bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service.start(mock_invoker) + bulk_download_service.handler([], "test") + + assert_handler_success( + expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events + ) + + +def test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): + """Test that the handler creates the zip file correctly when given a board id.""" + + _, expected_image_path, mock_image_contents = prepare_handler_test( + tmp_path, monkeypatch, mock_image_dto, mock_invoker + ) + expected_zip_path: Path = tmp_path / "bulk_downloads" / "Uncategorized.zip" + + def mock_get_many(*args, **kwargs): + return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto]) + + monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many) + + bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service.start(mock_invoker) + bulk_download_service.handler([], "none") + + assert_handler_success( + expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events + ) + + +def prepare_handler_test(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): + """Prepare the test for the handler tests.""" + + def mock_uuid_string(): + return "test" + + # You have to patch the function within the module it's being imported into. This is strange, but it works. + # See http://www.gregreda.com/2021/06/28/mocking-imported-module-function-python/ + monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", mock_uuid_string) + + expected_zip_path: Path = tmp_path / "bulk_downloads" / "test.zip" + expected_image_path: Path = ( + tmp_path / "bulk_downloads" / mock_image_dto.image_category.value / mock_image_dto.image_name + ) + + # Mock the get_dto method so that when the image dto needs to be retrieved it is returned + def mock_get_dto(*args, **kwargs): + return mock_image_dto + + monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_dto) + + # Create a mock image file so that the contents of the zip file are not empty + mock_image_path: Path = tmp_path / mock_image_dto.image_name + mock_image_contents: str = "Totally an image" + mock_image_path.write_text(mock_image_contents) + + def mock_get_path(*args, **kwargs): + return str(mock_image_path) + + monkeypatch.setattr(mock_invoker.services.images, "get_path", mock_get_path) + + return expected_zip_path, expected_image_path, mock_image_contents + + +def assert_handler_success( + expected_zip_path: Path, + expected_image_path: Path, + mock_image_contents: str, + tmp_path: Path, + event_bus: DummyEventService, +): + """Assert that the handler was successful.""" + # Check that the zip file was created + assert expected_zip_path.exists() + assert expected_zip_path.is_file() + assert expected_zip_path.stat().st_size > 0 + + # Check that the zip contents are expected + with ZipFile(expected_zip_path, "r") as zip_file: + zip_file.extractall(tmp_path / "bulk_downloads") + assert expected_image_path.exists() + assert expected_image_path.is_file() + assert expected_image_path.stat().st_size > 0 + assert expected_image_path.read_text() == mock_image_contents + + # Check that the correct events were emitted + assert len(event_bus.events) == 2 + assert event_bus.events[0].event_name == "bulk_download_started" + assert event_bus.events[1].event_name == "bulk_download_completed" + assert event_bus.events[1].payload["bulk_download_item_name"] == os.path.basename(expected_zip_path) + + +def test_stop(tmp_path: Path) -> None: + """Test that the stop method removes the bulk_downloads directory.""" + + bulk_download_service = BulkDownloadService(tmp_path) + + mock_file: Path = tmp_path / "bulk_downloads" / "test.zip" + mock_file.write_text("contents") + + bulk_download_service.stop() + + assert (tmp_path / "bulk_downloads").exists() + assert len(os.listdir(tmp_path / "bulk_downloads")) == 0 + + +def test_handler_on_image_not_found(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): + """Test that the handler emits an error event when the image is not found.""" + exception: Exception = ImageRecordNotFoundException("Image not found") + + def mock_get_dto(*args, **kwargs): + raise exception + + monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_dto) + + execute_handler_test_on_error(tmp_path, monkeypatch, mock_image_dto, mock_invoker, exception) + + +def test_handler_on_board_not_found(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): + """Test that the handler emits an error event when the image is not found.""" + + exception: Exception = BoardRecordNotFoundException("Image not found") + + def mock_get_board_name(*args, **kwargs): + raise exception + + monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_board_name) + + execute_handler_test_on_error(tmp_path, monkeypatch, mock_image_dto, mock_invoker, exception) + + +def test_handler_on_generic_exception( + tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker +): + """Test that the handler emits an error event when the image is not found.""" + + exception: Exception = Exception("Generic exception") + + def mock_get_board_name(*args, **kwargs): + raise exception + + monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_board_name) + + with pytest.raises(Exception): + execute_handler_test_on_error(tmp_path, monkeypatch, mock_image_dto, mock_invoker, exception) + + event_bus: DummyEventService = mock_invoker.services.events + + assert len(event_bus.events) == 2 + assert event_bus.events[0].event_name == "bulk_download_started" + assert event_bus.events[1].event_name == "bulk_download_failed" + assert event_bus.events[1].payload["error"] == exception.__str__() + + +def execute_handler_test_on_error( + tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker, error: Exception +): + bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service.start(mock_invoker) + bulk_download_service.handler([mock_image_dto.image_name], None) + + event_bus: DummyEventService = mock_invoker.services.events + + assert len(event_bus.events) == 2 + assert event_bus.events[0].event_name == "bulk_download_started" + assert event_bus.events[1].event_name == "bulk_download_failed" + assert event_bus.events[1].payload["error"] == error.__str__() diff --git a/tests/fixtures/event_service.py b/tests/fixtures/event_service.py index 71262be3f9..0a09fa0d64 100644 --- a/tests/fixtures/event_service.py +++ b/tests/fixtures/event_service.py @@ -29,6 +29,6 @@ class DummyEventService(EventServiceBase): @pytest.fixture -def mock_event_service() -> EventServiceBase: +def mock_event_service() -> DummyEventService: """Create a dummy event service.""" return DummyEventService()