fix(nodes): fix bugs with serving images

When returning a `FileResponse`, we must provide a valid path, else an exception is raised outside the route handler.

Add the `validate_path` method back to the service so we can validate paths before returning the file.

I don't like this but apparently this is just how `starlette` and `fastapi` works with `FileResponse`.
This commit is contained in:
psychedelicious 2023-05-23 22:57:29 +10:00 committed by Kent Keirsey
parent 4c331a5d7e
commit 23d9d58c08
5 changed files with 53 additions and 11 deletions

View File

@ -93,7 +93,7 @@ async def get_image_metadata(
@images_router.get(
"/{image_type}/{image_name}/full",
"/{image_type}/{image_name}",
operation_id="get_image_full",
response_class=Response,
responses={
@ -117,7 +117,15 @@ async def get_image_full(
image_type, image_name
)
return FileResponse(path, media_type="image/png")
if not ApiDependencies.invoker.services.images_new.validate_path(path):
raise HTTPException(status_code=404)
return FileResponse(
path,
media_type="image/png",
filename=image_name,
content_disposition_type="inline",
)
except Exception as e:
raise HTTPException(status_code=404)
@ -144,8 +152,12 @@ async def get_image_thumbnail(
path = ApiDependencies.invoker.services.images_new.get_path(
image_type, image_name, thumbnail=True
)
if not ApiDependencies.invoker.services.images_new.validate_path(path):
raise HTTPException(status_code=404)
return FileResponse(path, media_type="image/webp")
return FileResponse(
path, media_type="image/webp", content_disposition_type="inline"
)
except Exception as e:
raise HTTPException(status_code=404)

View File

@ -3,8 +3,7 @@ import asyncio
from inspect import signature
import uvicorn
from invokeai.app.models import resources
import invokeai.backend.util.logging as logger
from invokeai.backend.util.logging import InvokeAILogger
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
@ -20,6 +19,7 @@ from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation
from .services.config import InvokeAIAppConfig
logger = InvokeAILogger.getLogger()
# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?

View File

@ -44,7 +44,6 @@ class ImageFileStorageBase(ABC):
"""Retrieves an image as PIL Image."""
pass
# # TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
@ -52,6 +51,13 @@ class ImageFileStorageBase(ABC):
"""Gets the internal path to an image or thumbnail."""
pass
# TODO: We need to validate paths before starlette makes the FileResponse, else we get a
# 500 internal server error. I don't like having this method on the service.
@abstractmethod
def validate_path(self, path: str) -> bool:
"""Validates the path given for an image or thumbnail."""
pass
@abstractmethod
def save(
self,
@ -175,6 +181,14 @@ class DiskImageFileStorage(ImageFileStorageBase):
return abspath
def validate_path(self, path: str) -> bool:
"""Validates the path given for an image or thumbnail."""
try:
os.stat(path)
return True
except:
return False
def __get_cache(self, image_name: str) -> PILImageType | None:
return None if image_name not in self.__cache else self.__cache[image_name]

View File

@ -70,14 +70,19 @@ class ImageServiceABC(ABC):
@abstractmethod
def get_path(self, image_type: ImageType, image_name: str) -> str:
"""Gets an image's path"""
"""Gets an image's path."""
pass
@abstractmethod
def validate_path(self, path: str) -> bool:
"""Validates an image's path."""
pass
@abstractmethod
def get_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
"""Gets an image's or thumbnail's URL"""
"""Gets an image's or thumbnail's URL."""
pass
@abstractmethod
@ -273,6 +278,13 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image path")
raise e
def validate_path(self, path: str) -> bool:
try:
return self._services.files.validate_path(path)
except Exception as e:
self._services.logger.error("Problem validating image path")
raise e
def get_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:

View File

@ -24,7 +24,11 @@ class LocalUrlService(UrlServiceBase):
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
image_basename = os.path.basename(image_name)
if thumbnail:
return f"{self._base_url}/images/{image_type.value}/{image_basename}/thumbnail"
return f"{self._base_url}/images/{image_type.value}/{image_basename}/full"
# These paths are determined by the routes in invokeai/app/api/routers/images.py
if thumbnail:
return (
f"{self._base_url}/images/{image_type.value}/{image_basename}/thumbnail"
)
return f"{self._base_url}/images/{image_type.value}/{image_basename}"