mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): add list_images endpoint
- add `list_images` endpoint at `GET api/v1/images` - extend `ImageStorageBase` with `list()` method, implemented it for `DiskImageStorage` - add `ImageReponse` class to for image responses, which includes urls, metadata - add `ImageMetadata` class (basically a stub at the moment) - uploaded images now named `"{uuid}_{timestamp}.png"` - add `models` modules. besides separating concerns more clearly, this helps to mitigate circular dependencies - improve thumbnail handling
This commit is contained in:
parent
54d9833db0
commit
34402cc46a
14
invokeai/app/api/models/images.py
Normal file
14
invokeai/app/api/models/images.py
Normal file
@ -0,0 +1,14 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.models.image import ImageType
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
|
||||
|
||||
class ImageResponse(BaseModel):
|
||||
"""The response type for images"""
|
||||
|
||||
image_type: ImageType = Field(description="The type of the image")
|
||||
image_name: str = Field(description="The name of the image")
|
||||
image_url: str = Field(description="The url of the image")
|
||||
thumbnail_url: str = Field(description="The url of the image's thumbnail")
|
||||
metadata: ImageMetadata = Field(description="The image's metadata")
|
@ -1,18 +1,20 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
|
||||
from fastapi import Path, Request, UploadFile
|
||||
from fastapi import Path, Query, Request, UploadFile
|
||||
from fastapi.responses import FileResponse, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from invokeai.app.api.models.images import ImageResponse
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
|
||||
from ...services.image_storage import ImageType
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
||||
|
||||
|
||||
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
|
||||
async def get_image(
|
||||
image_type: ImageType = Path(description="The type of image to get"),
|
||||
@ -53,14 +55,30 @@ async def upload_image(file: UploadFile, request: Request):
|
||||
# Error opening the image
|
||||
return Response(status_code=415)
|
||||
|
||||
filename = f"{str(int(datetime.now(timezone.utc).timestamp()))}.png"
|
||||
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
|
||||
ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)
|
||||
|
||||
return Response(
|
||||
status_code=201,
|
||||
headers={
|
||||
"Location": request.url_for(
|
||||
"get_image", image_type=ImageType.UPLOAD, image_name=filename
|
||||
"get_image", image_type=ImageType.UPLOAD.value, image_name=filename
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@images_router.get(
|
||||
"/",
|
||||
operation_id="list_images",
|
||||
responses={200: {"model": PaginatedResults[ImageResponse]}},
|
||||
)
|
||||
async def list_images(
|
||||
image_type: ImageType = Query(default=ImageType.RESULT, description="The type of images to get"),
|
||||
page: int = Query(default=0, description="The page of images to get"),
|
||||
per_page: int = Query(default=10, description="The number of images per page"),
|
||||
) -> PaginatedResults[ImageResponse]:
|
||||
"""Gets a list of images"""
|
||||
result = ApiDependencies.invoker.services.images.list(
|
||||
image_type, page, per_page
|
||||
)
|
||||
return result
|
||||
|
@ -6,7 +6,8 @@ from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_t
|
||||
from pydantic import BaseModel, Field
|
||||
import networkx as nx
|
||||
import matplotlib.pyplot as plt
|
||||
from ..invocations.image import ImageField
|
||||
|
||||
from ..models.image import ImageField
|
||||
from ..services.graph import GraphExecutionState
|
||||
from ..services.invoker import Invoker
|
||||
|
||||
|
@ -7,9 +7,9 @@ import numpy
|
||||
from PIL import Image, ImageOps
|
||||
from pydantic import Field
|
||||
|
||||
from ..services.image_storage import ImageType
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
from .image import ImageOutput
|
||||
|
||||
|
||||
class CvInpaintInvocation(BaseInvocation):
|
||||
|
@ -8,12 +8,13 @@ from torch import Tensor
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from ..services.image_storage import ImageType
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
from .image import ImageOutput
|
||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ..util.util import diffusers_step_callback_adapter, CanceledException
|
||||
from ..models.exceptions import CanceledException
|
||||
from ..util.step_callback import diffusers_step_callback_adapter
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[
|
||||
tuple(InvokeAIGenerator.schedulers())
|
||||
|
@ -7,20 +7,10 @@ import numpy
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..services.image_storage import ImageType
|
||||
from ..models.image import ImageField, ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image field used for passing image objects between invocations"""
|
||||
|
||||
image_type: str = Field(
|
||||
default=ImageType.RESULT, description="The type of the image"
|
||||
)
|
||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
|
||||
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
#fmt: off
|
||||
|
@ -3,10 +3,10 @@ from typing import Literal, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from ..services.image_storage import ImageType
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
from .image import ImageOutput
|
||||
|
||||
class RestoreFaceInvocation(BaseInvocation):
|
||||
"""Restores faces in an image."""
|
||||
|
@ -5,10 +5,10 @@ from typing import Literal, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from ..services.image_storage import ImageType
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
from .image import ImageOutput
|
||||
|
||||
|
||||
class UpscaleInvocation(BaseInvocation):
|
||||
|
0
invokeai/app/models/__init__.py
Normal file
0
invokeai/app/models/__init__.py
Normal file
3
invokeai/app/models/exceptions.py
Normal file
3
invokeai/app/models/exceptions.py
Normal file
@ -0,0 +1,3 @@
|
||||
class CanceledException(Exception):
|
||||
"""Execution canceled by user."""
|
||||
pass
|
26
invokeai/app/models/image.py
Normal file
26
invokeai/app/models/image.py
Normal file
@ -0,0 +1,26 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ImageType(str, Enum):
|
||||
RESULT = "results"
|
||||
INTERMEDIATE = "intermediates"
|
||||
UPLOAD = "uploads"
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image field used for passing image objects between invocations"""
|
||||
|
||||
image_type: str = Field(
|
||||
default=ImageType.RESULT, description="The type of the image"
|
||||
)
|
||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"image_type",
|
||||
"image_name",
|
||||
]
|
||||
}
|
11
invokeai/app/models/metadata.py
Normal file
11
invokeai/app/models/metadata.py
Normal file
@ -0,0 +1,11 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class ImageMetadata(BaseModel):
|
||||
"""An image's metadata"""
|
||||
|
||||
timestamp: float = Field(description="The creation timestamp of the image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
# TODO: figure out metadata
|
||||
sd_metadata: Optional[dict] = Field(default={}, description="The image's SD-specific metadata")
|
@ -2,24 +2,25 @@
|
||||
|
||||
import datetime
|
||||
import os
|
||||
from glob import glob
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict
|
||||
from typing import Callable, Dict, List
|
||||
|
||||
from PIL.Image import Image
|
||||
import PIL.Image as PILImage
|
||||
from pydantic import BaseModel
|
||||
from invokeai.app.api.models.images import ImageResponse
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
from invokeai.app.util.save_thumbnail import save_thumbnail
|
||||
|
||||
from invokeai.backend.image_util import PngWriter
|
||||
|
||||
|
||||
class ImageType(str, Enum):
|
||||
RESULT = "results"
|
||||
INTERMEDIATE = "intermediates"
|
||||
UPLOAD = "uploads"
|
||||
|
||||
|
||||
class ImageStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving images."""
|
||||
|
||||
@ -27,9 +28,17 @@ class ImageStorageBase(ABC):
|
||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list(
|
||||
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[ImageResponse]:
|
||||
pass
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
@abstractmethod
|
||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||
) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -71,19 +80,74 @@ class DiskImageStorage(ImageStorageBase):
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
|
||||
def list(
|
||||
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[ImageResponse]:
|
||||
dir_path = os.path.join(self.__output_folder, image_type)
|
||||
image_paths = glob(f"{dir_path}/*.png")
|
||||
count = len(image_paths)
|
||||
|
||||
sorted_image_paths = sorted(
|
||||
glob(f"{dir_path}/*.png"), key=os.path.getctime, reverse=True
|
||||
)
|
||||
|
||||
page_of_image_paths = sorted_image_paths[
|
||||
page * per_page : (page + 1) * per_page
|
||||
]
|
||||
|
||||
page_of_images: List[ImageResponse] = []
|
||||
|
||||
for path in page_of_image_paths:
|
||||
filename = os.path.basename(path)
|
||||
img = PILImage.open(path)
|
||||
page_of_images.append(
|
||||
ImageResponse(
|
||||
image_type=image_type.value,
|
||||
image_name=filename,
|
||||
# TODO: DiskImageStorage should not be building URLs...?
|
||||
image_url=f"api/v1/images/{image_type.value}/{filename}",
|
||||
thumbnail_url=f"api/v1/images/{image_type.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
|
||||
# TODO: Creation of this object should happen elsewhere, just making it fit here so it works
|
||||
metadata=ImageMetadata(
|
||||
timestamp=os.path.getctime(path),
|
||||
width=img.width,
|
||||
height=img.height,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
page_count_trunc = int(count / per_page)
|
||||
page_count_mod = count % per_page
|
||||
page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1
|
||||
|
||||
return PaginatedResults[ImageResponse](
|
||||
items=page_of_images,
|
||||
page=page,
|
||||
pages=page_count,
|
||||
per_page=per_page,
|
||||
total=count,
|
||||
)
|
||||
|
||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
cache_item = self.__get_cache(image_path)
|
||||
if cache_item:
|
||||
return cache_item
|
||||
|
||||
image = Image.open(image_path)
|
||||
image = PILImage.open(image_path)
|
||||
self.__set_cache(image_path, image)
|
||||
return image
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
||||
path = os.path.join(self.__output_folder, image_type, image_name)
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||
) -> str:
|
||||
if is_thumbnail:
|
||||
path = os.path.join(
|
||||
self.__output_folder, image_type, "thumbnails", image_name
|
||||
)
|
||||
else:
|
||||
path = os.path.join(self.__output_folder, image_type, image_name)
|
||||
return path
|
||||
|
||||
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
|
||||
@ -101,12 +165,19 @@ class DiskImageStorage(ImageStorageBase):
|
||||
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
thumbnail_path = self.get_path(image_type, image_name, True)
|
||||
if os.path.exists(image_path):
|
||||
os.remove(image_path)
|
||||
|
||||
if image_path in self.__cache:
|
||||
del self.__cache[image_path]
|
||||
|
||||
if os.path.exists(thumbnail_path):
|
||||
os.remove(thumbnail_path)
|
||||
|
||||
if thumbnail_path in self.__cache:
|
||||
del self.__cache[thumbnail_path]
|
||||
|
||||
def __get_cache(self, image_name: str) -> Image:
|
||||
return None if image_name not in self.__cache else self.__cache[image_name]
|
||||
|
||||
|
@ -4,7 +4,7 @@ from threading import Event, Thread
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from .invocation_queue import InvocationQueueItem
|
||||
from .invoker import InvocationProcessorABC, Invoker
|
||||
from ..util.util import CanceledException
|
||||
from ..models.exceptions import CanceledException
|
||||
|
||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
__invoker_thread: Thread
|
||||
|
0
invokeai/app/util/__init__.py
Normal file
0
invokeai/app/util/__init__.py
Normal file
@ -1,14 +1,16 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
from ...backend.generator.base import Generator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
|
||||
class CanceledException(Exception):
|
||||
pass
|
||||
|
||||
def fast_latents_step_callback(sample: torch.Tensor, step: int, steps: int, id: str, context: InvocationContext, ):
|
||||
def fast_latents_step_callback(
|
||||
sample: torch.Tensor,
|
||||
step: int,
|
||||
steps: int,
|
||||
id: str,
|
||||
context: InvocationContext,
|
||||
):
|
||||
# TODO: only output a preview image when requested
|
||||
image = Generator.sample_to_lowres_estimated_image(sample)
|
||||
|
||||
@ -21,15 +23,12 @@ def fast_latents_step_callback(sample: torch.Tensor, step: int, steps: int, id:
|
||||
context.services.events.emit_generator_progress(
|
||||
context.graph_execution_state_id,
|
||||
id,
|
||||
{
|
||||
"width": width,
|
||||
"height": height,
|
||||
"dataURL": dataURL
|
||||
},
|
||||
{"width": width, "height": height, "dataURL": dataURL},
|
||||
step,
|
||||
steps,
|
||||
)
|
||||
|
||||
|
||||
def diffusers_step_callback_adapter(*cb_args, **kwargs):
|
||||
"""
|
||||
txt2img gives us a Tensor in the step_callbak, while img2img gives us a PipelineIntermediateState.
|
||||
@ -37,6 +36,8 @@ def diffusers_step_callback_adapter(*cb_args, **kwargs):
|
||||
"""
|
||||
if isinstance(cb_args[0], PipelineIntermediateState):
|
||||
progress_state: PipelineIntermediateState = cb_args[0]
|
||||
return fast_latents_step_callback(progress_state.latents, progress_state.step, **kwargs)
|
||||
return fast_latents_step_callback(
|
||||
progress_state.latents, progress_state.step, **kwargs
|
||||
)
|
||||
else:
|
||||
return fast_latents_step_callback(*cb_args, **kwargs)
|
Loading…
Reference in New Issue
Block a user