feat(nodes): it works

This commit is contained in:
psychedelicious 2023-05-21 22:15:44 +10:00 committed by Kent Keirsey
parent 22c34c343a
commit 5bf9891553
11 changed files with 302 additions and 481 deletions

View File

@ -63,9 +63,7 @@ class ApiDependencies:
urls = LocalUrlService() urls = LocalUrlService()
image_file_storage = DiskImageFileStorage( image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
f"{output_folder}/images", metadata_service=metadata
)
# TODO: build a file/path manager? # TODO: build a file/path manager?
db_location = os.path.join(output_folder, "invokeai.db") db_location = os.path.join(output_folder, "invokeai.db")

View File

@ -1,165 +1,47 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import io from fastapi import HTTPException, Path
from datetime import datetime, timezone from fastapi.responses import FileResponse
import json
import os
from typing import Any
import uuid
from fastapi import Body, HTTPException, Path, Query, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from PIL import Image
from invokeai.app.api.models.images import (
ImageResponse,
ImageResponseMetadata,
)
from invokeai.app.models.image import ImageType from invokeai.app.models.image import ImageType
from invokeai.app.services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
images_router = APIRouter(prefix="/v1/files/images", tags=["images", "files"]) image_files_router = APIRouter(prefix="/v1/files/images", tags=["images", "files"])
# @images_router.get("/{image_type}/{image_name}", operation_id="get_image") @image_files_router.get("/{image_type}/{image_name}", operation_id="get_image")
# async def get_image(
# image_type: ImageType = Path(description="The type of image to get"),
# image_name: str = Path(description="The name of the image to get"),
# ) -> FileResponse:
# """Gets an image"""
# path = ApiDependencies.invoker.services.images.get_path(
# image_type=image_type, image_name=image_name
# )
# if ApiDependencies.invoker.services.images.validate_path(path):
# return FileResponse(path)
# else:
# raise HTTPException(status_code=404)
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
async def get_image( async def get_image(
image_type: ImageType = Path(description="The type of the image to get"), image_type: ImageType = Path(description="The type of the image to get"),
image_name: str = Path(description="The id of the image to get"), image_name: str = Path(description="The id of the image to get"),
) -> FileResponse: ) -> FileResponse:
"""Gets an image""" """Gets an image"""
path = ApiDependencies.invoker.services.images.get_path( try:
path = ApiDependencies.invoker.services.images_new.get_path(
image_type=image_type, image_name=image_name image_type=image_type, image_name=image_name
) )
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path) return FileResponse(path)
else: except Exception as e:
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image") @image_files_router.get(
async def delete_image( "/{image_type}/{image_name}/thumbnail", operation_id="get_thumbnail"
image_type: ImageType = Path(description="The type of the image to delete"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image and its thumbnail"""
ApiDependencies.invoker.services.images.delete(
image_type=image_type, image_name=image_name
)
@images_router.get(
"/{image_type}/thumbnails/{thumbnail_id}", operation_id="get_thumbnail"
) )
async def get_thumbnail( async def get_thumbnail(
image_type: ImageType = Path(description="The type of the thumbnail to get"), image_type: ImageType = Path(
thumbnail_id: str = Path(description="The id of the thumbnail to get"), description="The type of the image whose thumbnail to get"
) -> FileResponse | Response: ),
image_name: str = Path(description="The id of the image whose thumbnail to get"),
) -> FileResponse:
"""Gets a thumbnail""" """Gets a thumbnail"""
path = ApiDependencies.invoker.services.images.get_path(
image_type=image_type, image_name=thumbnail_id, is_thumbnail=True
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
raise HTTPException(status_code=404)
@images_router.post(
"/uploads/",
operation_id="upload_image",
responses={
201: {
"description": "The image was uploaded successfully",
"model": ImageResponse,
},
415: {"description": "Image upload failed"},
},
status_code=201,
)
async def upload_image(
file: UploadFile, image_type: ImageType, request: Request, response: Response
) -> ImageResponse:
if not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await file.read()
try: try:
img = Image.open(io.BytesIO(contents)) path = ApiDependencies.invoker.services.images_new.get_path(
except: image_type=image_type, image_name=image_name, thumbnail=True
# Error opening the image
raise HTTPException(status_code=415, detail="Failed to read image")
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
saved_image = ApiDependencies.invoker.services.images.save(
image_type, filename, img
) )
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img) return FileResponse(path)
except Exception as e:
image_url = ApiDependencies.invoker.services.images.get_uri( raise HTTPException(status_code=404)
image_type, saved_image.image_name
)
thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
image_type, saved_image.image_name, True
)
res = ImageResponse(
image_type=image_type,
image_name=saved_image.image_name,
image_url=image_url,
thumbnail_url=thumbnail_url,
metadata=ImageResponseMetadata(
created=saved_image.created,
width=img.width,
height=img.height,
invokeai=invokeai_metadata,
),
)
response.status_code = 201
response.headers["Location"] = image_url
return res
@images_router.get(
"/",
operation_id="list_images",
responses={200: {"model": PaginatedResults[ImageResponse]}},
)
async def list_images(
image_type: ImageType = Query(
default=ImageType.RESULT, description="The type of images to get"
),
page: int = Query(default=0, description="The page of images to get"),
per_page: int = Query(default=10, description="The number of images per page"),
) -> PaginatedResults[ImageResponse]:
"""Gets a list of images"""
result = ApiDependencies.invoker.services.images.list(image_type, page, per_page)
return result

View File

@ -71,7 +71,7 @@ async def shutdown_event():
app.include_router(sessions.session_router, prefix="/api") app.include_router(sessions.session_router, prefix="/api")
app.include_router(image_files.images_router, prefix="/api") app.include_router(image_files.image_files_router, prefix="/api")
app.include_router(models.models_router, prefix="/api") app.include_router(models.models_router, prefix="/api")

View File

@ -93,34 +93,42 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# each time it is called. We only need the first one. # each time it is called. We only need the first one.
generate_output = next(outputs) generate_output = next(outputs)
# Results are image and seed, unwrap for now and ignore the seed image_dto = context.services.images_new.create(
# TODO: pre-seed? image=generate_output.image,
# TODO: can this return multiple results? Should it?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(
image_type, image_name, generate_output.image, metadata
)
context.services.images_db.set(
id=image_name,
image_type=ImageType.RESULT, image_type=ImageType.RESULT,
image_category=ImageCategory.IMAGE, image_category=ImageCategory.IMAGE,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
node_id=self.id, node_id=self.id,
metadata=GeneratedImageOrLatentsMetadata(),
) )
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
# TODO: can this return multiple results? Should it?
# image_type = ImageType.RESULT
# image_name = context.services.images.create_name(
# context.graph_execution_state_id, self.id
# )
# metadata = context.services.metadata.build_metadata(
# session_id=context.graph_execution_state_id, node=self
# )
# context.services.images.save(
# image_type, image_name, generate_output.image, metadata
# )
# context.services.images_db.set(
# id=image_name,
# image_type=ImageType.RESULT,
# image_category=ImageCategory.IMAGE,
# session_id=context.graph_execution_state_id,
# node_id=self.id,
# metadata=GeneratedImageOrLatentsMetadata(),
# )
return build_image_output( return build_image_output(
image_type=image_type, image_type=image_dto.image_type,
image_name=image_name, image_name=image_dto.image_name,
image=generate_output.image, image=generate_output.image,
) )

View File

@ -2,8 +2,11 @@ from typing import Optional
from pydantic import BaseModel, Field, StrictFloat, StrictInt, StrictStr from pydantic import BaseModel, Field, StrictFloat, StrictInt, StrictStr
class GeneratedImageOrLatentsMetadata(BaseModel): class ImageMetadata(BaseModel):
"""Core generation metadata for an image/tensor generated in InvokeAI. """
Core generation metadata for an image/tensor generated in InvokeAI.
Also includes any metadata from the image's PNG tEXt chunks.
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node. Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node.
@ -51,20 +54,6 @@ class GeneratedImageOrLatentsMetadata(BaseModel):
# vae: Optional[str] = Field(default=None,description="The VAE used for decoding.") # vae: Optional[str] = Field(default=None,description="The VAE used for decoding.")
# unet: Optional[str] = Field(default=None,description="The UNet used dor inference.") # unet: Optional[str] = Field(default=None,description="The UNet used dor inference.")
# clip: Optional[str] = Field(default=None,description="The CLIP Encoder used for conditioning.") # clip: Optional[str] = Field(default=None,description="The CLIP Encoder used for conditioning.")
class UploadedImageOrLatentsMetadata(BaseModel):
"""Limited metadata for an uploaded image/tensor."""
width: Optional[StrictInt] = Field(
default=None, description="Width of the image/tensor in pixels."
)
height: Optional[StrictInt] = Field(
default=None, description="Height of the image/tensor in pixels."
)
# The extra field will be the contents of the PNG file's tEXt chunk. It may have come
# from another SD application or InvokeAI, so it needs to be flexible.
# If the upload is a not an image or `image_latents` tensor, this will be omitted.
extra: Optional[StrictStr] = Field( extra: Optional[StrictStr] = Field(
default=None, description="Extra metadata, extracted from the PNG tEXt chunk." default=None, description="Extra metadata, extracted from the PNG tEXt chunk."
) )

View File

@ -1,28 +1,16 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import os import os
from glob import glob
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from queue import Queue from queue import Queue
from typing import Dict, List from typing import Dict, Optional
from PIL.Image import Image from PIL.Image import Image as PILImageType
import PIL.Image as PILImage from PIL import Image
from PIL.PngImagePlugin import PngInfo
from send2trash import send2trash from send2trash import send2trash
from invokeai.app.api.models.images import (
ImageResponse,
ImageResponseMetadata,
SavedImage,
)
from invokeai.app.models.image import ImageType from invokeai.app.models.image import ImageType
from invokeai.app.services.metadata import (
InvokeAIMetadata,
MetadataServiceBase,
build_invokeai_metadata_pnginfo,
)
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.util.misc import get_timestamp
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
@ -48,61 +36,27 @@ class ImageFileStorageBase(ABC):
super().__init__(message) super().__init__(message)
@abstractmethod @abstractmethod
def get(self, image_type: ImageType, image_name: str) -> Image: def get(self, image_type: ImageType, image_name: str) -> PILImageType:
"""Retrieves an image as PIL Image.""" """Retrieves an image as PIL Image."""
pass pass
@abstractmethod # # TODO: make this a bit more flexible for e.g. cloud storage
def list(
self, image_type: ImageType, page: int = 0, per_page: int = 10
) -> PaginatedResults[ImageResponse]:
"""Gets a paginated list of images."""
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod @abstractmethod
def get_path( def get_path(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str: ) -> str:
"""Gets the internal path to an image or its thumbnail.""" """Gets the internal path to an image or thumbnail."""
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def get_uri(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
"""Gets the external URI to an image or its thumbnail."""
pass
# @abstractmethod
# def get_image_location(
# self, image_type: ImageType, image_name: str
# ) -> str:
# """Gets the location of an image."""
# pass
# @abstractmethod
# def get_thumbnail_location(
# self, image_type: ImageType, image_name: str
# ) -> str:
# """Gets the location of an image's thumbnail."""
# pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def validate_path(self, path: str) -> bool:
"""Validates an image path."""
pass pass
@abstractmethod @abstractmethod
def save( def save(
self, self,
image: PILImageType,
image_type: ImageType, image_type: ImageType,
image_name: str, image_name: str,
image: Image, pnginfo: Optional[PngInfo] = None,
metadata: InvokeAIMetadata | None = None, thumbnail_size: int = 256,
) -> SavedImage: ) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp.""" """Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
pass pass
@ -111,26 +65,20 @@ class ImageFileStorageBase(ABC):
"""Deletes an image and its thumbnail (if one exists).""" """Deletes an image and its thumbnail (if one exists)."""
pass pass
def create_name(self, context_id: str, node_id: str) -> str:
"""Creates a unique contextual image filename."""
return f"{context_id}_{node_id}_{str(get_timestamp())}.png"
class DiskImageFileStorage(ImageFileStorageBase): class DiskImageFileStorage(ImageFileStorageBase):
"""Stores images on disk""" """Stores images on disk"""
__output_folder: str __output_folder: str
__cache_ids: Queue # TODO: this is an incredibly naive cache __cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[str, Image] __cache: Dict[str, PILImageType]
__max_cache_size: int __max_cache_size: int
__metadata_service: MetadataServiceBase
def __init__(self, output_folder: str, metadata_service: MetadataServiceBase): def __init__(self, output_folder: str):
self.__output_folder = output_folder self.__output_folder = output_folder
self.__cache = dict() self.__cache = dict()
self.__cache_ids = Queue() self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config self.__max_cache_size = 10 # TODO: get this from config
self.__metadata_service = metadata_service
Path(output_folder).mkdir(parents=True, exist_ok=True) Path(output_folder).mkdir(parents=True, exist_ok=True)
@ -143,144 +91,38 @@ class DiskImageFileStorage(ImageFileStorageBase):
parents=True, exist_ok=True parents=True, exist_ok=True
) )
def list( def get(self, image_type: ImageType, image_name: str) -> PILImageType:
self, image_type: ImageType, page: int = 0, per_page: int = 10
) -> PaginatedResults[ImageResponse]:
dir_path = os.path.join(self.__output_folder, image_type)
image_paths = glob(f"{dir_path}/*.png")
count = len(image_paths)
sorted_image_paths = sorted(
glob(f"{dir_path}/*.png"), key=os.path.getctime, reverse=True
)
page_of_image_paths = sorted_image_paths[
page * per_page : (page + 1) * per_page
]
page_of_images: List[ImageResponse] = []
for path in page_of_image_paths:
filename = os.path.basename(path)
img = PILImage.open(path)
invokeai_metadata = self.__metadata_service.get_metadata(img)
page_of_images.append(
ImageResponse(
image_type=image_type,
image_name=filename,
# TODO: DiskImageStorage should not be building URLs...?
image_url=self.get_uri(image_type, filename),
thumbnail_url=self.get_uri(image_type, filename, True),
# TODO: Creation of this object should happen elsewhere (?), just making it fit here so it works
metadata=ImageResponseMetadata(
created=int(os.path.getctime(path)),
width=img.width,
height=img.height,
invokeai=invokeai_metadata,
),
)
)
page_count_trunc = int(count / per_page)
page_count_mod = count % per_page
page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1
return PaginatedResults[ImageResponse](
items=page_of_images,
page=page,
pages=page_count,
per_page=per_page,
total=count,
)
def get(self, image_type: ImageType, image_name: str) -> Image:
try: try:
image_path = self.get_path(image_type, image_name) image_path = self.get_path(image_type, image_name)
cache_item = self.__get_cache(image_path) cache_item = self.__get_cache(image_path)
if cache_item: if cache_item:
return cache_item return cache_item
image = PILImage.open(image_path) image = Image.open(image_path)
self.__set_cache(image_path, image) self.__set_cache(image_path, image)
return image return image
except Exception as e: except FileNotFoundError as e:
raise ImageFileStorageBase.ImageFileNotFoundException from e raise ImageFileStorageBase.ImageFileNotFoundException from e
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
if is_thumbnail:
path = os.path.join(
self.__output_folder, image_type, "thumbnails", basename
)
else:
path = os.path.join(self.__output_folder, image_type, basename)
abspath = os.path.abspath(path)
return abspath
def get_uri(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
if is_thumbnail:
thumbnail_basename = get_thumbnail_name(basename)
uri = f"api/v1/images/{image_type.value}/thumbnails/{thumbnail_basename}"
else:
uri = f"api/v1/images/{image_type.value}/{basename}"
return uri
def validate_path(self, path: str) -> bool:
try:
os.stat(path)
return True
except FileNotFoundError:
return False
except Exception as e:
raise e
def save( def save(
self, self,
image: PILImageType,
image_type: ImageType, image_type: ImageType,
image_name: str, image_name: str,
image: Image, pnginfo: Optional[PngInfo] = None,
metadata: InvokeAIMetadata | None = None, thumbnail_size: int = 256,
) -> SavedImage: ) -> None:
try: try:
image_path = self.get_path(image_type, image_name) image_path = self.get_path(image_type, image_name)
# TODO: Reading the image and then saving it strips the metadata...
if metadata:
pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata)
image.save(image_path, "PNG", pnginfo=pnginfo) image.save(image_path, "PNG", pnginfo=pnginfo)
else:
image.save(image_path) # this saved image has an empty info
thumbnail_name = get_thumbnail_name(image_name) thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path( thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True)
image_type, thumbnail_name, is_thumbnail=True thumbnail_image = make_thumbnail(image, thumbnail_size)
)
thumbnail_image = make_thumbnail(image)
thumbnail_image.save(thumbnail_path) thumbnail_image.save(thumbnail_path)
self.__set_cache(image_path, image) self.__set_cache(image_path, image)
self.__set_cache(thumbnail_path, thumbnail_image) self.__set_cache(thumbnail_path, thumbnail_image)
return SavedImage(
image_name=image_name,
thumbnail_name=thumbnail_name,
created=int(os.path.getctime(image_path)),
)
except Exception as e: except Exception as e:
raise ImageFileStorageBase.ImageFileSaveException from e raise ImageFileStorageBase.ImageFileSaveException from e
@ -304,10 +146,29 @@ class DiskImageFileStorage(ImageFileStorageBase):
except Exception as e: except Exception as e:
raise ImageFileStorageBase.ImageFileDeleteException from e raise ImageFileStorageBase.ImageFileDeleteException from e
def __get_cache(self, image_name: str) -> Image | None: # TODO: make this a bit more flexible for e.g. cloud storage
def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
if thumbnail:
thumbnail_name = get_thumbnail_name(basename)
path = os.path.join(
self.__output_folder, image_type, "thumbnails", thumbnail_name
)
else:
path = os.path.join(self.__output_folder, image_type, basename)
abspath = os.path.abspath(path)
return abspath
def __get_cache(self, image_name: str) -> PILImageType | None:
return None if image_name not in self.__cache else self.__cache[image_name] return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: str, image: Image): def __set_cache(self, image_name: str, image: PILImageType):
if not image_name in self.__cache: if not image_name in self.__cache:
self.__cache[image_name] = image self.__cache[image_name] = image
self.__cache_ids.put( self.__cache_ids.put(

View File

@ -1,25 +1,18 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import datetime import datetime
from typing import Optional from typing import Optional
from invokeai.app.models.metadata import (
GeneratedImageOrLatentsMetadata,
UploadedImageOrLatentsMetadata,
)
import sqlite3 import sqlite3
import threading import threading
from typing import Optional, Union from typing import Optional, Union
from invokeai.app.models.metadata import (
GeneratedImageOrLatentsMetadata, from invokeai.app.models.metadata import ImageMetadata
UploadedImageOrLatentsMetadata,
)
from invokeai.app.models.image import ( from invokeai.app.models.image import (
ImageCategory, ImageCategory,
ImageType, ImageType,
) )
from invokeai.app.services.util.create_enum_table import create_enum_table from invokeai.app.services.util.create_enum_table import create_enum_table
from invokeai.app.services.models.image_record import ImageRecord from invokeai.app.services.models.image_record import (
from invokeai.app.services.util.deserialize_image_record import ( ImageRecord,
deserialize_image_record, deserialize_image_record,
) )
@ -76,9 +69,7 @@ class ImageRecordStorageBase(ABC):
image_category: ImageCategory, image_category: ImageCategory,
session_id: Optional[str], session_id: Optional[str],
node_id: Optional[str], node_id: Optional[str],
metadata: Optional[ metadata: Optional[ImageMetadata],
GeneratedImageOrLatentsMetadata | UploadedImageOrLatentsMetadata
],
created_at: str = datetime.datetime.utcnow().isoformat(), created_at: str = datetime.datetime.utcnow().isoformat(),
) -> None: ) -> None:
"""Saves an image record.""" """Saves an image record."""
@ -288,9 +279,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
image_category: ImageCategory, image_category: ImageCategory,
session_id: Optional[str], session_id: Optional[str],
node_id: Optional[str], node_id: Optional[str],
metadata: Union[ metadata: Optional[ImageMetadata],
GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata, None
],
created_at: str, created_at: str,
) -> None: ) -> None:
try: try:
@ -306,7 +295,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
image_category, image_category,
node_id, node_id,
session_id, session_id,
metadata metadata,
created_at created_at
) )
VALUES (?, ?, ?, ?, ?, ?, ?); VALUES (?, ?, ?, ?, ?, ?, ?);

View File

@ -1,12 +1,13 @@
from abc import ABC, abstractmethod
import json
from logging import Logger from logging import Logger
from typing import Optional, Union from typing import Optional, Union
import uuid import uuid
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
from PIL import PngImagePlugin
from invokeai.app.models.image import ImageCategory, ImageType from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.models.metadata import ( from invokeai.app.models.metadata import ImageMetadata
GeneratedImageOrLatentsMetadata,
UploadedImageOrLatentsMetadata,
)
from invokeai.app.services.image_record_storage import ( from invokeai.app.services.image_record_storage import (
ImageRecordStorageBase, ImageRecordStorageBase,
) )
@ -22,8 +23,95 @@ from invokeai.app.services.urls import UrlServiceBase
from invokeai.app.util.misc import get_iso_timestamp from invokeai.app.util.misc import get_iso_timestamp
class ImageServiceABC(ABC):
"""
High-level service for image management.
Provides methods for creating, retrieving, and deleting images.
"""
@abstractmethod
def create(
self,
image: PILImageType,
image_type: ImageType,
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
metadata: Optional[ImageMetadata] = None,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@abstractmethod
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
"""Gets an image as a PIL image."""
pass
@abstractmethod
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
"""Gets an image record."""
pass
@abstractmethod
def get_path(self, image_type: ImageType, image_name: str) -> str:
"""Gets an image's path"""
pass
@abstractmethod
def get_image_url(self, image_type: ImageType, image_name: str) -> str:
"""Gets an image's URL"""
pass
@abstractmethod
def get_thumbnail_url(self, image_type: ImageType, image_name: str) -> str:
"""Gets an image's URL"""
pass
@abstractmethod
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
"""Gets an image DTO."""
pass
@abstractmethod
def get_many(
self,
image_type: ImageType,
image_category: ImageCategory,
page: int = 0,
per_page: int = 10,
) -> PaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs."""
pass
@abstractmethod
def delete(self, image_type: ImageType, image_name: str):
"""Deletes an image."""
pass
@abstractmethod
def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
"""Adds a tag to an image."""
pass
@abstractmethod
def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
"""Removes a tag from an image."""
pass
@abstractmethod
def favorite(self, image_type: ImageType, image_id: str) -> None:
"""Favorites an image."""
pass
@abstractmethod
def unfavorite(self, image_type: ImageType, image_id: str) -> None:
"""Unfavorites an image."""
pass
class ImageServiceDependencies: class ImageServiceDependencies:
"""Service dependencies for the ImageManagementService.""" """Service dependencies for the ImageService."""
records: ImageRecordStorageBase records: ImageRecordStorageBase
files: ImageFileStorageBase files: ImageFileStorageBase
@ -46,9 +134,7 @@ class ImageServiceDependencies:
self.logger = logger self.logger = logger
class ImageService: class ImageService(ImageServiceABC):
"""High-level service for image management."""
_services: ImageServiceDependencies _services: ImageServiceDependencies
def __init__( def __init__(
@ -67,21 +153,6 @@ class ImageService:
logger=logger, logger=logger,
) )
def _create_image_name(
self,
image_type: ImageType,
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> str:
"""Creates an image name."""
uuid_str = str(uuid.uuid4())
if node_id is not None and session_id is not None:
return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png"
return f"{image_type.value}_{image_category.value}_{uuid_str}.png"
def create( def create(
self, self,
image: PILImageType, image: PILImageType,
@ -89,11 +160,8 @@ class ImageService:
image_category: ImageCategory, image_category: ImageCategory,
node_id: Optional[str] = None, node_id: Optional[str] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
metadata: Optional[ metadata: Optional[ImageMetadata] = None,
Union[GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata]
] = None,
) -> ImageDTO: ) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
image_name = self._create_image_name( image_name = self._create_image_name(
image_type=image_type, image_type=image_type,
image_category=image_category, image_category=image_category,
@ -103,13 +171,19 @@ class ImageService:
timestamp = get_iso_timestamp() timestamp = get_iso_timestamp()
if metadata is not None:
pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text("invokeai", json.dumps(metadata))
else:
pnginfo = None
try: try:
# TODO: Consider using a transaction here to ensure consistency between storage and database # TODO: Consider using a transaction here to ensure consistency between storage and database
self._services.files.save( self._services.files.save(
image_type=image_type, image_type=image_type,
image_name=image_name, image_name=image_name,
image=image, image=image,
metadata=metadata, pnginfo=pnginfo,
) )
self._services.records.save( self._services.records.save(
@ -144,25 +218,40 @@ class ImageService:
except ImageFileStorageBase.ImageFileSaveException: except ImageFileStorageBase.ImageFileSaveException:
self._services.logger.error("Failed to save image file") self._services.logger.error("Failed to save image file")
raise raise
except Exception as e:
self._services.logger.error("Problem saving image record and file")
raise e
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
"""Gets an image as a PIL image."""
try: try:
return self._services.files.get(image_type, image_name) return self._services.files.get(image_type, image_name)
except ImageFileStorageBase.ImageFileNotFoundException: except ImageFileStorageBase.ImageFileNotFoundException:
self._services.logger.error("Failed to get image file") self._services.logger.error("Failed to get image file")
raise raise
except Exception as e:
self._services.logger.error("Problem getting image file")
raise e
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord: def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
"""Gets an image record."""
try: try:
return self._services.records.get(image_type, image_name) return self._services.records.get(image_type, image_name)
except ImageRecordStorageBase.ImageRecordNotFoundException: except ImageRecordStorageBase.ImageRecordNotFoundException:
self._services.logger.error("Failed to get image record") self._services.logger.error("Image record not found")
raise raise
except Exception as e:
self._services.logger.error("Problem getting image record")
raise e
def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
try:
return self._services.files.get_path(image_type, image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO: def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
"""Gets an image DTO."""
try: try:
image_record = self._services.records.get(image_type, image_name) image_record = self._services.records.get(image_type, image_name)
@ -174,21 +263,11 @@ class ImageService:
return image_dto return image_dto
except ImageRecordStorageBase.ImageRecordNotFoundException: except ImageRecordStorageBase.ImageRecordNotFoundException:
self._services.logger.error("Failed to get image DTO") self._services.logger.error("Image record not found")
raise
def delete(self, image_type: ImageType, image_name: str):
"""Deletes an image."""
# TODO: Consider using a transaction here to ensure consistency between storage and database
try:
self._services.files.delete(image_type, image_name)
self._services.records.delete(image_type, image_name)
except ImageRecordStorageBase.ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record")
raise
except ImageFileStorageBase.ImageFileDeleteException:
self._services.logger.error(f"Failed to delete image file")
raise raise
except Exception as e:
self._services.logger.error("Problem getting image DTO")
raise e
def get_many( def get_many(
self, self,
@ -197,7 +276,6 @@ class ImageService:
page: int = 0, page: int = 0,
per_page: int = 10, per_page: int = 10,
) -> PaginatedResults[ImageDTO]: ) -> PaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs."""
try: try:
results = self._services.records.get_many( results = self._services.records.get_many(
image_type, image_type,
@ -225,21 +303,47 @@ class ImageService:
total=results.total, total=results.total,
) )
except Exception as e: except Exception as e:
self._services.logger.error("Failed to get paginated image DTOs") self._services.logger.error("Problem getting paginated image DTOs")
raise e
def delete(self, image_type: ImageType, image_name: str):
# TODO: Consider using a transaction here to ensure consistency between storage and database
try:
self._services.files.delete(image_type, image_name)
self._services.records.delete(image_type, image_name)
except ImageRecordStorageBase.ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record")
raise
except ImageFileStorageBase.ImageFileDeleteException:
self._services.logger.error(f"Failed to delete image file")
raise
except Exception as e:
self._services.logger.error("Problem deleting image record and file")
raise e raise e
def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None: def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
"""Adds a tag to an image."""
raise NotImplementedError("The 'add_tag' method is not implemented yet.") raise NotImplementedError("The 'add_tag' method is not implemented yet.")
def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None: def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
"""Removes a tag from an image."""
raise NotImplementedError("The 'remove_tag' method is not implemented yet.") raise NotImplementedError("The 'remove_tag' method is not implemented yet.")
def favorite(self, image_type: ImageType, image_id: str) -> None: def favorite(self, image_type: ImageType, image_id: str) -> None:
"""Favorites an image."""
raise NotImplementedError("The 'favorite' method is not implemented yet.") raise NotImplementedError("The 'favorite' method is not implemented yet.")
def unfavorite(self, image_type: ImageType, image_id: str) -> None: def unfavorite(self, image_type: ImageType, image_id: str) -> None:
"""Unfavorites an image."""
raise NotImplementedError("The 'unfavorite' method is not implemented yet.") raise NotImplementedError("The 'unfavorite' method is not implemented yet.")
def _create_image_name(
self,
image_type: ImageType,
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> str:
"""Create a unique image name."""
uuid_str = str(uuid.uuid4())
if node_id is not None and session_id is not None:
return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png"
return f"{image_type.value}_{image_category.value}_{uuid_str}.png"

View File

@ -1,11 +1,10 @@
import datetime import datetime
import sqlite3
from typing import Optional, Union from typing import Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.models.metadata import (
GeneratedImageOrLatentsMetadata,
UploadedImageOrLatentsMetadata,
)
from invokeai.app.models.image import ImageCategory, ImageType from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.misc import get_iso_timestamp
class ImageRecord(BaseModel): class ImageRecord(BaseModel):
@ -19,9 +18,9 @@ class ImageRecord(BaseModel):
) )
session_id: Optional[str] = Field(default=None, description="The session ID.") session_id: Optional[str] = Field(default=None, description="The session ID.")
node_id: Optional[str] = Field(default=None, description="The node ID.") node_id: Optional[str] = Field(default=None, description="The node ID.")
metadata: Optional[ metadata: Optional[ImageMetadata] = Field(
Union[GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata] default=None, description="The image's metadata."
] = Field(default=None, description="The image's metadata.") )
class ImageDTO(ImageRecord): class ImageDTO(ImageRecord):
@ -46,3 +45,27 @@ def image_record_to_dto(
image_url=image_url, image_url=image_url,
thumbnail_url=thumbnail_url, thumbnail_url=thumbnail_url,
) )
def deserialize_image_record(image_row: sqlite3.Row) -> ImageRecord:
"""Deserializes an image record."""
image_dict = dict(image_row)
image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value))
raw_metadata = image_dict.get("metadata", "{}")
metadata = ImageMetadata.parse_raw(raw_metadata)
return ImageRecord(
image_name=image_dict.get("id", "unknown"),
session_id=image_dict.get("session_id", None),
node_id=image_dict.get("node_id", None),
metadata=metadata,
image_type=image_type,
image_category=ImageCategory(
image_dict.get("image_category", ImageCategory.IMAGE.value)
),
created_at=image_dict.get("created_at", get_iso_timestamp()),
)

View File

@ -25,8 +25,8 @@ class LocalUrlService(UrlServiceBase):
def get_image_url(self, image_type: ImageType, image_name: str) -> str: def get_image_url(self, image_type: ImageType, image_name: str) -> str:
image_basename = os.path.basename(image_name) image_basename = os.path.basename(image_name)
return f"{self._base_url}/images/{image_type.value}/{image_basename}" return f"{self._base_url}/files/images/{image_type.value}/{image_basename}"
def get_thumbnail_url(self, image_type: ImageType, image_name: str) -> str: def get_thumbnail_url(self, image_type: ImageType, image_name: str) -> str:
thumbnail_basename = get_thumbnail_name(os.path.basename(image_name)) image_basename = os.path.basename(image_name)
return f"{self._base_url}/images/{image_type.value}/thumbnails/{thumbnail_basename}" return f"{self._base_url}/files/images/{image_type.value}/{image_basename}/thumbnail"

View File

@ -1,33 +0,0 @@
from invokeai.app.models.metadata import (
GeneratedImageOrLatentsMetadata,
UploadedImageOrLatentsMetadata,
)
from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.services.models.image_record import ImageRecord
from invokeai.app.util.misc import get_iso_timestamp
def deserialize_image_record(image: dict) -> ImageRecord:
"""Deserializes an image record."""
# All values *should* be present, except `session_id` and `node_id`, but provide some defaults just in case
image_type = ImageType(image.get("image_type", ImageType.RESULT.value))
raw_metadata = image.get("metadata", {})
if image_type == ImageType.UPLOAD:
metadata = UploadedImageOrLatentsMetadata.parse_obj(raw_metadata)
else:
metadata = GeneratedImageOrLatentsMetadata.parse_obj(raw_metadata)
return ImageRecord(
image_name=image.get("id", "unknown"),
session_id=image.get("session_id", None),
node_id=image.get("node_id", None),
metadata=metadata,
image_type=image_type,
image_category=ImageCategory(
image.get("image_category", ImageCategory.IMAGE.value)
),
created_at=image.get("created_at", get_iso_timestamp()),
)