From 23d9d58c08974749593ca150b8fe0b5e037c9d27 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 23 May 2023 22:57:29 +1000 Subject: [PATCH] fix(nodes): fix bugs with serving images When returning a `FileResponse`, we must provide a valid path, else an exception is raised outside the route handler. Add the `validate_path` method back to the service so we can validate paths before returning the file. I don't like this but apparently this is just how `starlette` and `fastapi` works with `FileResponse`. --- invokeai/app/api/routers/images.py | 18 +++++++++++++++--- invokeai/app/api_app.py | 4 ++-- invokeai/app/services/image_file_storage.py | 16 +++++++++++++++- invokeai/app/services/images.py | 16 ++++++++++++++-- invokeai/app/services/urls.py | 10 +++++++--- 5 files changed, 53 insertions(+), 11 deletions(-) diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 123774b721..602b539da1 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -93,7 +93,7 @@ async def get_image_metadata( @images_router.get( - "/{image_type}/{image_name}/full", + "/{image_type}/{image_name}", operation_id="get_image_full", response_class=Response, responses={ @@ -117,7 +117,15 @@ async def get_image_full( image_type, image_name ) - return FileResponse(path, media_type="image/png") + if not ApiDependencies.invoker.services.images_new.validate_path(path): + raise HTTPException(status_code=404) + + return FileResponse( + path, + media_type="image/png", + filename=image_name, + content_disposition_type="inline", + ) except Exception as e: raise HTTPException(status_code=404) @@ -144,8 +152,12 @@ async def get_image_thumbnail( path = ApiDependencies.invoker.services.images_new.get_path( image_type, image_name, thumbnail=True ) + if not ApiDependencies.invoker.services.images_new.validate_path(path): + raise HTTPException(status_code=404) - return FileResponse(path, media_type="image/webp") + return FileResponse( + path, media_type="image/webp", content_disposition_type="inline" + ) except Exception as e: raise HTTPException(status_code=404) diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 964202786a..69d322578d 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -3,8 +3,7 @@ import asyncio from inspect import signature import uvicorn -from invokeai.app.models import resources -import invokeai.backend.util.logging as logger +from invokeai.backend.util.logging import InvokeAILogger from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html @@ -20,6 +19,7 @@ from .api.sockets import SocketIO from .invocations.baseinvocation import BaseInvocation from .services.config import InvokeAIAppConfig +logger = InvokeAILogger.getLogger() # Create the app # TODO: create this all in a method so configuration/etc. can be passed in? diff --git a/invokeai/app/services/image_file_storage.py b/invokeai/app/services/image_file_storage.py index 1b4466e06e..46070b3bf2 100644 --- a/invokeai/app/services/image_file_storage.py +++ b/invokeai/app/services/image_file_storage.py @@ -44,7 +44,6 @@ class ImageFileStorageBase(ABC): """Retrieves an image as PIL Image.""" pass - # # TODO: make this a bit more flexible for e.g. cloud storage @abstractmethod def get_path( self, image_type: ImageType, image_name: str, thumbnail: bool = False @@ -52,6 +51,13 @@ class ImageFileStorageBase(ABC): """Gets the internal path to an image or thumbnail.""" pass + # TODO: We need to validate paths before starlette makes the FileResponse, else we get a + # 500 internal server error. I don't like having this method on the service. + @abstractmethod + def validate_path(self, path: str) -> bool: + """Validates the path given for an image or thumbnail.""" + pass + @abstractmethod def save( self, @@ -175,6 +181,14 @@ class DiskImageFileStorage(ImageFileStorageBase): return abspath + def validate_path(self, path: str) -> bool: + """Validates the path given for an image or thumbnail.""" + try: + os.stat(path) + return True + except: + return False + def __get_cache(self, image_name: str) -> PILImageType | None: return None if image_name not in self.__cache else self.__cache[image_name] diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index 2b2322085d..914dd3b6d3 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -70,14 +70,19 @@ class ImageServiceABC(ABC): @abstractmethod def get_path(self, image_type: ImageType, image_name: str) -> str: - """Gets an image's path""" + """Gets an image's path.""" + pass + + @abstractmethod + def validate_path(self, path: str) -> bool: + """Validates an image's path.""" pass @abstractmethod def get_url( self, image_type: ImageType, image_name: str, thumbnail: bool = False ) -> str: - """Gets an image's or thumbnail's URL""" + """Gets an image's or thumbnail's URL.""" pass @abstractmethod @@ -273,6 +278,13 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem getting image path") raise e + def validate_path(self, path: str) -> bool: + try: + return self._services.files.validate_path(path) + except Exception as e: + self._services.logger.error("Problem validating image path") + raise e + def get_url( self, image_type: ImageType, image_name: str, thumbnail: bool = False ) -> str: diff --git a/invokeai/app/services/urls.py b/invokeai/app/services/urls.py index cfc4b34012..2716da60ad 100644 --- a/invokeai/app/services/urls.py +++ b/invokeai/app/services/urls.py @@ -24,7 +24,11 @@ class LocalUrlService(UrlServiceBase): self, image_type: ImageType, image_name: str, thumbnail: bool = False ) -> str: image_basename = os.path.basename(image_name) - if thumbnail: - return f"{self._base_url}/images/{image_type.value}/{image_basename}/thumbnail" - return f"{self._base_url}/images/{image_type.value}/{image_basename}/full" + # These paths are determined by the routes in invokeai/app/api/routers/images.py + if thumbnail: + return ( + f"{self._base_url}/images/{image_type.value}/{image_basename}/thumbnail" + ) + + return f"{self._base_url}/images/{image_type.value}/{image_basename}"