fix(nodes): fix bugs with serving images

When returning a `FileResponse`, we must provide a valid path, else an exception is raised outside the route handler.

Add the `validate_path` method back to the service so we can validate paths before returning the file.

I don't like this but apparently this is just how `starlette` and `fastapi` works with `FileResponse`.
This commit is contained in:
psychedelicious 2023-05-23 22:57:29 +10:00 committed by Kent Keirsey
parent 4c331a5d7e
commit 23d9d58c08
5 changed files with 53 additions and 11 deletions

View File

@ -93,7 +93,7 @@ async def get_image_metadata(
@images_router.get( @images_router.get(
"/{image_type}/{image_name}/full", "/{image_type}/{image_name}",
operation_id="get_image_full", operation_id="get_image_full",
response_class=Response, response_class=Response,
responses={ responses={
@ -117,7 +117,15 @@ async def get_image_full(
image_type, image_name image_type, image_name
) )
return FileResponse(path, media_type="image/png") if not ApiDependencies.invoker.services.images_new.validate_path(path):
raise HTTPException(status_code=404)
return FileResponse(
path,
media_type="image/png",
filename=image_name,
content_disposition_type="inline",
)
except Exception as e: except Exception as e:
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@ -144,8 +152,12 @@ async def get_image_thumbnail(
path = ApiDependencies.invoker.services.images_new.get_path( path = ApiDependencies.invoker.services.images_new.get_path(
image_type, image_name, thumbnail=True image_type, image_name, thumbnail=True
) )
if not ApiDependencies.invoker.services.images_new.validate_path(path):
raise HTTPException(status_code=404)
return FileResponse(path, media_type="image/webp") return FileResponse(
path, media_type="image/webp", content_disposition_type="inline"
)
except Exception as e: except Exception as e:
raise HTTPException(status_code=404) raise HTTPException(status_code=404)

View File

@ -3,8 +3,7 @@ import asyncio
from inspect import signature from inspect import signature
import uvicorn import uvicorn
from invokeai.app.models import resources from invokeai.backend.util.logging import InvokeAILogger
import invokeai.backend.util.logging as logger
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
@ -20,6 +19,7 @@ from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation from .invocations.baseinvocation import BaseInvocation
from .services.config import InvokeAIAppConfig from .services.config import InvokeAIAppConfig
logger = InvokeAILogger.getLogger()
# Create the app # Create the app
# TODO: create this all in a method so configuration/etc. can be passed in? # TODO: create this all in a method so configuration/etc. can be passed in?

View File

@ -44,7 +44,6 @@ class ImageFileStorageBase(ABC):
"""Retrieves an image as PIL Image.""" """Retrieves an image as PIL Image."""
pass pass
# # TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod @abstractmethod
def get_path( def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_type: ImageType, image_name: str, thumbnail: bool = False
@ -52,6 +51,13 @@ class ImageFileStorageBase(ABC):
"""Gets the internal path to an image or thumbnail.""" """Gets the internal path to an image or thumbnail."""
pass pass
# TODO: We need to validate paths before starlette makes the FileResponse, else we get a
# 500 internal server error. I don't like having this method on the service.
@abstractmethod
def validate_path(self, path: str) -> bool:
"""Validates the path given for an image or thumbnail."""
pass
@abstractmethod @abstractmethod
def save( def save(
self, self,
@ -175,6 +181,14 @@ class DiskImageFileStorage(ImageFileStorageBase):
return abspath return abspath
def validate_path(self, path: str) -> bool:
"""Validates the path given for an image or thumbnail."""
try:
os.stat(path)
return True
except:
return False
def __get_cache(self, image_name: str) -> PILImageType | None: def __get_cache(self, image_name: str) -> PILImageType | None:
return None if image_name not in self.__cache else self.__cache[image_name] return None if image_name not in self.__cache else self.__cache[image_name]

View File

@ -70,14 +70,19 @@ class ImageServiceABC(ABC):
@abstractmethod @abstractmethod
def get_path(self, image_type: ImageType, image_name: str) -> str: def get_path(self, image_type: ImageType, image_name: str) -> str:
"""Gets an image's path""" """Gets an image's path."""
pass
@abstractmethod
def validate_path(self, path: str) -> bool:
"""Validates an image's path."""
pass pass
@abstractmethod @abstractmethod
def get_url( def get_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str: ) -> str:
"""Gets an image's or thumbnail's URL""" """Gets an image's or thumbnail's URL."""
pass pass
@abstractmethod @abstractmethod
@ -273,6 +278,13 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image path") self._services.logger.error("Problem getting image path")
raise e raise e
def validate_path(self, path: str) -> bool:
try:
return self._services.files.validate_path(path)
except Exception as e:
self._services.logger.error("Problem validating image path")
raise e
def get_url( def get_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str: ) -> str:

View File

@ -24,7 +24,11 @@ class LocalUrlService(UrlServiceBase):
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str: ) -> str:
image_basename = os.path.basename(image_name) image_basename = os.path.basename(image_name)
if thumbnail:
return f"{self._base_url}/images/{image_type.value}/{image_basename}/thumbnail"
return f"{self._base_url}/images/{image_type.value}/{image_basename}/full" # These paths are determined by the routes in invokeai/app/api/routers/images.py
if thumbnail:
return (
f"{self._base_url}/images/{image_type.value}/{image_basename}/thumbnail"
)
return f"{self._base_url}/images/{image_type.value}/{image_basename}"