mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
cleaning up bulk download zip after the response is complete
This commit is contained in:
parent
7544b350f3
commit
79eb871683
@ -409,10 +409,10 @@ async def download_images_from_list(
|
||||
},
|
||||
)
|
||||
async def get_bulk_download_item(
|
||||
background_tasks: BackgroundTasks,
|
||||
bulk_download_item_name: str = Path(description="The bulk_download_item_id of the bulk download item to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets a bulk download zip file"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.bulk_download.get_path(bulk_download_item_name)
|
||||
|
||||
@ -423,6 +423,7 @@ async def get_bulk_download_item(
|
||||
content_disposition_type="inline",
|
||||
)
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.delete, bulk_download_item_name)
|
||||
return response
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
@ -59,3 +59,11 @@ class BulkDownloadBase(ABC):
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, bulk_download_item_name: str) -> None:
|
||||
"""
|
||||
Delete the bulk download file.
|
||||
|
||||
:param bulk_download_item_name: The name of the bulk download item.
|
||||
"""
|
||||
|
@ -38,23 +38,6 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
self.__bulk_downloads_folder = self.__output_folder / "bulk_downloads"
|
||||
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def get_path(self, bulk_download_item_name: str) -> str:
|
||||
"""
|
||||
Get the path to the bulk download file.
|
||||
|
||||
:param bulk_download_item_name: The name of the bulk download item.
|
||||
:return: The path to the bulk download file.
|
||||
"""
|
||||
path = str(self.__bulk_downloads_folder / bulk_download_item_name)
|
||||
if not self.validate_path(path):
|
||||
raise BulkDownloadTargetException()
|
||||
return path
|
||||
|
||||
def validate_path(self, path: Union[str, Path]) -> bool:
|
||||
"""Validates the path given for a bulk download."""
|
||||
path = path if isinstance(path, Path) else Path(path)
|
||||
return path.exists()
|
||||
|
||||
def handler(self, image_names: list[str], board_id: Optional[str]) -> None:
|
||||
"""
|
||||
Create a zip file containing the images specified by the given image names or board id.
|
||||
@ -166,3 +149,29 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
# Delete all the files
|
||||
for file in files:
|
||||
file.unlink()
|
||||
|
||||
def delete(self, bulk_download_item_name: str) -> None:
|
||||
"""
|
||||
Delete the bulk download file.
|
||||
|
||||
:param bulk_download_item_name: The name of the bulk download item.
|
||||
"""
|
||||
path = self.get_path(bulk_download_item_name)
|
||||
Path(path).unlink()
|
||||
|
||||
def get_path(self, bulk_download_item_name: str) -> str:
|
||||
"""
|
||||
Get the path to the bulk download file.
|
||||
|
||||
:param bulk_download_item_name: The name of the bulk download item.
|
||||
:return: The path to the bulk download file.
|
||||
"""
|
||||
path = str(self.__bulk_downloads_folder / bulk_download_item_name)
|
||||
if not self.validate_path(path):
|
||||
raise BulkDownloadTargetException()
|
||||
return path
|
||||
|
||||
def validate_path(self, path: Union[str, Path]) -> bool:
|
||||
"""Validates the path given for a bulk download."""
|
||||
path = path if isinstance(path, Path) else Path(path)
|
||||
return path.exists()
|
||||
|
@ -293,7 +293,7 @@ def test_handler_on_generic_exception(
|
||||
|
||||
monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_board_name)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
execute_handler_test_on_error(tmp_path, monkeypatch, mock_image_dto, mock_invoker, exception)
|
||||
|
||||
event_bus: DummyEventService = mock_invoker.services.events
|
||||
@ -317,3 +317,17 @@ def execute_handler_test_on_error(
|
||||
assert event_bus.events[0].event_name == "bulk_download_started"
|
||||
assert event_bus.events[1].event_name == "bulk_download_failed"
|
||||
assert event_bus.events[1].payload["error"] == error.__str__()
|
||||
|
||||
|
||||
def test_delete(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
|
||||
"""Test that the delete method removes the bulk download file."""
|
||||
|
||||
bulk_download_service = BulkDownloadService(tmp_path)
|
||||
|
||||
mock_file: Path = tmp_path / "bulk_downloads" / "test.zip"
|
||||
mock_file.write_text("contents")
|
||||
|
||||
bulk_download_service.delete("test.zip")
|
||||
|
||||
assert (tmp_path / "bulk_downloads").exists()
|
||||
assert len(os.listdir(tmp_path / "bulk_downloads")) == 0
|
||||
|
Loading…
Reference in New Issue
Block a user