mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
cleaning up bulk download zip after the response is complete
This commit is contained in:
parent
ca1c96e8f5
commit
4dfa1e3d03
@ -409,10 +409,10 @@ async def download_images_from_list(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def get_bulk_download_item(
|
async def get_bulk_download_item(
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
bulk_download_item_name: str = Path(description="The bulk_download_item_id of the bulk download item to get"),
|
bulk_download_item_name: str = Path(description="The bulk_download_item_id of the bulk download item to get"),
|
||||||
) -> FileResponse:
|
) -> FileResponse:
|
||||||
"""Gets a bulk download zip file"""
|
"""Gets a bulk download zip file"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
path = ApiDependencies.invoker.services.bulk_download.get_path(bulk_download_item_name)
|
path = ApiDependencies.invoker.services.bulk_download.get_path(bulk_download_item_name)
|
||||||
|
|
||||||
@ -423,6 +423,7 @@ async def get_bulk_download_item(
|
|||||||
content_disposition_type="inline",
|
content_disposition_type="inline",
|
||||||
)
|
)
|
||||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||||
|
background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.delete, bulk_download_item_name)
|
||||||
return response
|
return response
|
||||||
except Exception:
|
except Exception:
|
||||||
raise HTTPException(status_code=404)
|
raise HTTPException(status_code=404)
|
||||||
|
@ -59,3 +59,11 @@ class BulkDownloadBase(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, bulk_download_item_name: str) -> None:
|
||||||
|
"""
|
||||||
|
Delete the bulk download file.
|
||||||
|
|
||||||
|
:param bulk_download_item_name: The name of the bulk download item.
|
||||||
|
"""
|
||||||
|
@ -38,23 +38,6 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
self.__bulk_downloads_folder = self.__output_folder / "bulk_downloads"
|
self.__bulk_downloads_folder = self.__output_folder / "bulk_downloads"
|
||||||
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
|
self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
def get_path(self, bulk_download_item_name: str) -> str:
|
|
||||||
"""
|
|
||||||
Get the path to the bulk download file.
|
|
||||||
|
|
||||||
:param bulk_download_item_name: The name of the bulk download item.
|
|
||||||
:return: The path to the bulk download file.
|
|
||||||
"""
|
|
||||||
path = str(self.__bulk_downloads_folder / bulk_download_item_name)
|
|
||||||
if not self.validate_path(path):
|
|
||||||
raise BulkDownloadTargetException()
|
|
||||||
return path
|
|
||||||
|
|
||||||
def validate_path(self, path: Union[str, Path]) -> bool:
|
|
||||||
"""Validates the path given for a bulk download."""
|
|
||||||
path = path if isinstance(path, Path) else Path(path)
|
|
||||||
return path.exists()
|
|
||||||
|
|
||||||
def handler(self, image_names: list[str], board_id: Optional[str]) -> None:
|
def handler(self, image_names: list[str], board_id: Optional[str]) -> None:
|
||||||
"""
|
"""
|
||||||
Create a zip file containing the images specified by the given image names or board id.
|
Create a zip file containing the images specified by the given image names or board id.
|
||||||
@ -166,3 +149,29 @@ class BulkDownloadService(BulkDownloadBase):
|
|||||||
# Delete all the files
|
# Delete all the files
|
||||||
for file in files:
|
for file in files:
|
||||||
file.unlink()
|
file.unlink()
|
||||||
|
|
||||||
|
def delete(self, bulk_download_item_name: str) -> None:
|
||||||
|
"""
|
||||||
|
Delete the bulk download file.
|
||||||
|
|
||||||
|
:param bulk_download_item_name: The name of the bulk download item.
|
||||||
|
"""
|
||||||
|
path = self.get_path(bulk_download_item_name)
|
||||||
|
Path(path).unlink()
|
||||||
|
|
||||||
|
def get_path(self, bulk_download_item_name: str) -> str:
|
||||||
|
"""
|
||||||
|
Get the path to the bulk download file.
|
||||||
|
|
||||||
|
:param bulk_download_item_name: The name of the bulk download item.
|
||||||
|
:return: The path to the bulk download file.
|
||||||
|
"""
|
||||||
|
path = str(self.__bulk_downloads_folder / bulk_download_item_name)
|
||||||
|
if not self.validate_path(path):
|
||||||
|
raise BulkDownloadTargetException()
|
||||||
|
return path
|
||||||
|
|
||||||
|
def validate_path(self, path: Union[str, Path]) -> bool:
|
||||||
|
"""Validates the path given for a bulk download."""
|
||||||
|
path = path if isinstance(path, Path) else Path(path)
|
||||||
|
return path.exists()
|
||||||
|
@ -293,7 +293,7 @@ def test_handler_on_generic_exception(
|
|||||||
|
|
||||||
monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_board_name)
|
monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_board_name)
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception): # noqa: B017
|
||||||
execute_handler_test_on_error(tmp_path, monkeypatch, mock_image_dto, mock_invoker, exception)
|
execute_handler_test_on_error(tmp_path, monkeypatch, mock_image_dto, mock_invoker, exception)
|
||||||
|
|
||||||
event_bus: DummyEventService = mock_invoker.services.events
|
event_bus: DummyEventService = mock_invoker.services.events
|
||||||
@ -317,3 +317,17 @@ def execute_handler_test_on_error(
|
|||||||
assert event_bus.events[0].event_name == "bulk_download_started"
|
assert event_bus.events[0].event_name == "bulk_download_started"
|
||||||
assert event_bus.events[1].event_name == "bulk_download_failed"
|
assert event_bus.events[1].event_name == "bulk_download_failed"
|
||||||
assert event_bus.events[1].payload["error"] == error.__str__()
|
assert event_bus.events[1].payload["error"] == error.__str__()
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
|
||||||
|
"""Test that the delete method removes the bulk download file."""
|
||||||
|
|
||||||
|
bulk_download_service = BulkDownloadService(tmp_path)
|
||||||
|
|
||||||
|
mock_file: Path = tmp_path / "bulk_downloads" / "test.zip"
|
||||||
|
mock_file.write_text("contents")
|
||||||
|
|
||||||
|
bulk_download_service.delete("test.zip")
|
||||||
|
|
||||||
|
assert (tmp_path / "bulk_downloads").exists()
|
||||||
|
assert len(os.listdir(tmp_path / "bulk_downloads")) == 0
|
||||||
|
Loading…
Reference in New Issue
Block a user