From 5bf9891553eb423c0d39539bc66c2ad551f37323 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 21 May 2023 22:15:44 +1000 Subject: [PATCH] feat(nodes): it works --- invokeai/app/api/dependencies.py | 4 +- invokeai/app/api/routers/image_files.py | 164 ++---------- invokeai/app/api_app.py | 2 +- invokeai/app/invocations/generate.py | 50 ++-- invokeai/app/models/metadata.py | 21 +- invokeai/app/services/image_file_storage.py | 233 ++++-------------- invokeai/app/services/image_record_storage.py | 25 +- invokeai/app/services/images.py | 208 ++++++++++++---- invokeai/app/services/models/image_record.py | 37 ++- invokeai/app/services/urls.py | 6 +- .../services/util/deserialize_image_record.py | 33 --- 11 files changed, 302 insertions(+), 481 deletions(-) delete mode 100644 invokeai/app/services/util/deserialize_image_record.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 09be2daecc..1ad53f31ca 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -63,9 +63,7 @@ class ApiDependencies: urls = LocalUrlService() - image_file_storage = DiskImageFileStorage( - f"{output_folder}/images", metadata_service=metadata - ) + image_file_storage = DiskImageFileStorage(f"{output_folder}/images") # TODO: build a file/path manager? db_location = os.path.join(output_folder, "invokeai.db") diff --git a/invokeai/app/api/routers/image_files.py b/invokeai/app/api/routers/image_files.py index a42b2a1e63..2694df5b19 100644 --- a/invokeai/app/api/routers/image_files.py +++ b/invokeai/app/api/routers/image_files.py @@ -1,165 +1,47 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -import io -from datetime import datetime, timezone -import json -import os -from typing import Any -import uuid - -from fastapi import Body, HTTPException, Path, Query, Request, UploadFile -from fastapi.responses import FileResponse, Response +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team +from fastapi import HTTPException, Path +from fastapi.responses import FileResponse from fastapi.routing import APIRouter -from PIL import Image -from invokeai.app.api.models.images import ( - ImageResponse, - ImageResponseMetadata, -) from invokeai.app.models.image import ImageType -from invokeai.app.services.item_storage import PaginatedResults from ..dependencies import ApiDependencies -images_router = APIRouter(prefix="/v1/files/images", tags=["images", "files"]) +image_files_router = APIRouter(prefix="/v1/files/images", tags=["images", "files"]) -# @images_router.get("/{image_type}/{image_name}", operation_id="get_image") -# async def get_image( -# image_type: ImageType = Path(description="The type of image to get"), -# image_name: str = Path(description="The name of the image to get"), -# ) -> FileResponse: -# """Gets an image""" - -# path = ApiDependencies.invoker.services.images.get_path( -# image_type=image_type, image_name=image_name -# ) - -# if ApiDependencies.invoker.services.images.validate_path(path): -# return FileResponse(path) -# else: -# raise HTTPException(status_code=404) - - -@images_router.get("/{image_type}/{image_name}", operation_id="get_image") +@image_files_router.get("/{image_type}/{image_name}", operation_id="get_image") async def get_image( image_type: ImageType = Path(description="The type of the image to get"), image_name: str = Path(description="The id of the image to get"), ) -> FileResponse: """Gets an image""" - path = ApiDependencies.invoker.services.images.get_path( - image_type=image_type, image_name=image_name - ) + try: + path = ApiDependencies.invoker.services.images_new.get_path( + image_type=image_type, image_name=image_name + ) - if ApiDependencies.invoker.services.images.validate_path(path): return FileResponse(path) - else: + except Exception as e: raise HTTPException(status_code=404) -@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image") -async def delete_image( - image_type: ImageType = Path(description="The type of the image to delete"), - image_name: str = Path(description="The name of the image to delete"), -) -> None: - """Deletes an image and its thumbnail""" - - ApiDependencies.invoker.services.images.delete( - image_type=image_type, image_name=image_name - ) - - -@images_router.get( - "/{image_type}/thumbnails/{thumbnail_id}", operation_id="get_thumbnail" +@image_files_router.get( + "/{image_type}/{image_name}/thumbnail", operation_id="get_thumbnail" ) async def get_thumbnail( - image_type: ImageType = Path(description="The type of the thumbnail to get"), - thumbnail_id: str = Path(description="The id of the thumbnail to get"), -) -> FileResponse | Response: + image_type: ImageType = Path( + description="The type of the image whose thumbnail to get" + ), + image_name: str = Path(description="The id of the image whose thumbnail to get"), +) -> FileResponse: """Gets a thumbnail""" - path = ApiDependencies.invoker.services.images.get_path( - image_type=image_type, image_name=thumbnail_id, is_thumbnail=True - ) - - if ApiDependencies.invoker.services.images.validate_path(path): - return FileResponse(path) - else: - raise HTTPException(status_code=404) - - -@images_router.post( - "/uploads/", - operation_id="upload_image", - responses={ - 201: { - "description": "The image was uploaded successfully", - "model": ImageResponse, - }, - 415: {"description": "Image upload failed"}, - }, - status_code=201, -) -async def upload_image( - file: UploadFile, image_type: ImageType, request: Request, response: Response -) -> ImageResponse: - if not file.content_type.startswith("image"): - raise HTTPException(status_code=415, detail="Not an image") - - contents = await file.read() - try: - img = Image.open(io.BytesIO(contents)) - except: - # Error opening the image - raise HTTPException(status_code=415, detail="Failed to read image") + path = ApiDependencies.invoker.services.images_new.get_path( + image_type=image_type, image_name=image_name, thumbnail=True + ) - filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png" - - saved_image = ApiDependencies.invoker.services.images.save( - image_type, filename, img - ) - - invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img) - - image_url = ApiDependencies.invoker.services.images.get_uri( - image_type, saved_image.image_name - ) - - thumbnail_url = ApiDependencies.invoker.services.images.get_uri( - image_type, saved_image.image_name, True - ) - - res = ImageResponse( - image_type=image_type, - image_name=saved_image.image_name, - image_url=image_url, - thumbnail_url=thumbnail_url, - metadata=ImageResponseMetadata( - created=saved_image.created, - width=img.width, - height=img.height, - invokeai=invokeai_metadata, - ), - ) - - response.status_code = 201 - response.headers["Location"] = image_url - - return res - - -@images_router.get( - "/", - operation_id="list_images", - responses={200: {"model": PaginatedResults[ImageResponse]}}, -) -async def list_images( - image_type: ImageType = Query( - default=ImageType.RESULT, description="The type of images to get" - ), - page: int = Query(default=0, description="The page of images to get"), - per_page: int = Query(default=10, description="The number of images per page"), -) -> PaginatedResults[ImageResponse]: - """Gets a list of images""" - result = ApiDependencies.invoker.services.images.list(image_type, page, per_page) - return result + return FileResponse(path) + except Exception as e: + raise HTTPException(status_code=404) diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 9720474109..dffb2ec139 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -71,7 +71,7 @@ async def shutdown_event(): app.include_router(sessions.session_router, prefix="/api") -app.include_router(image_files.images_router, prefix="/api") +app.include_router(image_files.image_files_router, prefix="/api") app.include_router(models.models_router, prefix="/api") diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index 525be128e4..a27027dfe4 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -93,34 +93,42 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation): # each time it is called. We only need the first one. generate_output = next(outputs) - # Results are image and seed, unwrap for now and ignore the seed - # TODO: pre-seed? - # TODO: can this return multiple results? Should it? - image_type = ImageType.RESULT - image_name = context.services.images.create_name( - context.graph_execution_state_id, self.id - ) - - metadata = context.services.metadata.build_metadata( - session_id=context.graph_execution_state_id, node=self - ) - - context.services.images.save( - image_type, image_name, generate_output.image, metadata - ) - - context.services.images_db.set( - id=image_name, + image_dto = context.services.images_new.create( + image=generate_output.image, image_type=ImageType.RESULT, image_category=ImageCategory.IMAGE, session_id=context.graph_execution_state_id, node_id=self.id, - metadata=GeneratedImageOrLatentsMetadata(), ) + # Results are image and seed, unwrap for now and ignore the seed + # TODO: pre-seed? + # TODO: can this return multiple results? Should it? + # image_type = ImageType.RESULT + # image_name = context.services.images.create_name( + # context.graph_execution_state_id, self.id + # ) + + # metadata = context.services.metadata.build_metadata( + # session_id=context.graph_execution_state_id, node=self + # ) + + # context.services.images.save( + # image_type, image_name, generate_output.image, metadata + # ) + + # context.services.images_db.set( + # id=image_name, + # image_type=ImageType.RESULT, + # image_category=ImageCategory.IMAGE, + # session_id=context.graph_execution_state_id, + # node_id=self.id, + # metadata=GeneratedImageOrLatentsMetadata(), + # ) + return build_image_output( - image_type=image_type, - image_name=image_name, + image_type=image_dto.image_type, + image_name=image_dto.image_name, image=generate_output.image, ) diff --git a/invokeai/app/models/metadata.py b/invokeai/app/models/metadata.py index aae3337266..35998fa27e 100644 --- a/invokeai/app/models/metadata.py +++ b/invokeai/app/models/metadata.py @@ -2,8 +2,11 @@ from typing import Optional from pydantic import BaseModel, Field, StrictFloat, StrictInt, StrictStr -class GeneratedImageOrLatentsMetadata(BaseModel): - """Core generation metadata for an image/tensor generated in InvokeAI. +class ImageMetadata(BaseModel): + """ + Core generation metadata for an image/tensor generated in InvokeAI. + + Also includes any metadata from the image's PNG tEXt chunks. Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node. @@ -51,20 +54,6 @@ class GeneratedImageOrLatentsMetadata(BaseModel): # vae: Optional[str] = Field(default=None,description="The VAE used for decoding.") # unet: Optional[str] = Field(default=None,description="The UNet used dor inference.") # clip: Optional[str] = Field(default=None,description="The CLIP Encoder used for conditioning.") - - -class UploadedImageOrLatentsMetadata(BaseModel): - """Limited metadata for an uploaded image/tensor.""" - - width: Optional[StrictInt] = Field( - default=None, description="Width of the image/tensor in pixels." - ) - height: Optional[StrictInt] = Field( - default=None, description="Height of the image/tensor in pixels." - ) - # The extra field will be the contents of the PNG file's tEXt chunk. It may have come - # from another SD application or InvokeAI, so it needs to be flexible. - # If the upload is a not an image or `image_latents` tensor, this will be omitted. extra: Optional[StrictStr] = Field( default=None, description="Extra metadata, extracted from the PNG tEXt chunk." ) diff --git a/invokeai/app/services/image_file_storage.py b/invokeai/app/services/image_file_storage.py index ff3640011a..3a99940068 100644 --- a/invokeai/app/services/image_file_storage.py +++ b/invokeai/app/services/image_file_storage.py @@ -1,28 +1,16 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team import os -from glob import glob from abc import ABC, abstractmethod from pathlib import Path from queue import Queue -from typing import Dict, List +from typing import Dict, Optional -from PIL.Image import Image -import PIL.Image as PILImage +from PIL.Image import Image as PILImageType +from PIL import Image +from PIL.PngImagePlugin import PngInfo from send2trash import send2trash -from invokeai.app.api.models.images import ( - ImageResponse, - ImageResponseMetadata, - SavedImage, -) + from invokeai.app.models.image import ImageType -from invokeai.app.services.metadata import ( - InvokeAIMetadata, - MetadataServiceBase, - build_invokeai_metadata_pnginfo, -) -from invokeai.app.services.item_storage import PaginatedResults -from invokeai.app.util.misc import get_timestamp from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail @@ -48,61 +36,27 @@ class ImageFileStorageBase(ABC): super().__init__(message) @abstractmethod - def get(self, image_type: ImageType, image_name: str) -> Image: + def get(self, image_type: ImageType, image_name: str) -> PILImageType: """Retrieves an image as PIL Image.""" pass - @abstractmethod - def list( - self, image_type: ImageType, page: int = 0, per_page: int = 10 - ) -> PaginatedResults[ImageResponse]: - """Gets a paginated list of images.""" - pass - - # TODO: make this a bit more flexible for e.g. cloud storage + # # TODO: make this a bit more flexible for e.g. cloud storage @abstractmethod def get_path( - self, image_type: ImageType, image_name: str, is_thumbnail: bool = False + self, image_type: ImageType, image_name: str, thumbnail: bool = False ) -> str: - """Gets the internal path to an image or its thumbnail.""" - pass - - # TODO: make this a bit more flexible for e.g. cloud storage - @abstractmethod - def get_uri( - self, image_type: ImageType, image_name: str, is_thumbnail: bool = False - ) -> str: - """Gets the external URI to an image or its thumbnail.""" - pass - - # @abstractmethod - # def get_image_location( - # self, image_type: ImageType, image_name: str - # ) -> str: - # """Gets the location of an image.""" - # pass - - # @abstractmethod - # def get_thumbnail_location( - # self, image_type: ImageType, image_name: str - # ) -> str: - # """Gets the location of an image's thumbnail.""" - # pass - - # TODO: make this a bit more flexible for e.g. cloud storage - @abstractmethod - def validate_path(self, path: str) -> bool: - """Validates an image path.""" + """Gets the internal path to an image or thumbnail.""" pass @abstractmethod def save( self, + image: PILImageType, image_type: ImageType, image_name: str, - image: Image, - metadata: InvokeAIMetadata | None = None, - ) -> SavedImage: + pnginfo: Optional[PngInfo] = None, + thumbnail_size: int = 256, + ) -> None: """Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp.""" pass @@ -111,26 +65,20 @@ class ImageFileStorageBase(ABC): """Deletes an image and its thumbnail (if one exists).""" pass - def create_name(self, context_id: str, node_id: str) -> str: - """Creates a unique contextual image filename.""" - return f"{context_id}_{node_id}_{str(get_timestamp())}.png" - class DiskImageFileStorage(ImageFileStorageBase): """Stores images on disk""" __output_folder: str __cache_ids: Queue # TODO: this is an incredibly naive cache - __cache: Dict[str, Image] + __cache: Dict[str, PILImageType] __max_cache_size: int - __metadata_service: MetadataServiceBase - def __init__(self, output_folder: str, metadata_service: MetadataServiceBase): + def __init__(self, output_folder: str): self.__output_folder = output_folder self.__cache = dict() self.__cache_ids = Queue() self.__max_cache_size = 10 # TODO: get this from config - self.__metadata_service = metadata_service Path(output_folder).mkdir(parents=True, exist_ok=True) @@ -143,144 +91,38 @@ class DiskImageFileStorage(ImageFileStorageBase): parents=True, exist_ok=True ) - def list( - self, image_type: ImageType, page: int = 0, per_page: int = 10 - ) -> PaginatedResults[ImageResponse]: - dir_path = os.path.join(self.__output_folder, image_type) - image_paths = glob(f"{dir_path}/*.png") - count = len(image_paths) - - sorted_image_paths = sorted( - glob(f"{dir_path}/*.png"), key=os.path.getctime, reverse=True - ) - - page_of_image_paths = sorted_image_paths[ - page * per_page : (page + 1) * per_page - ] - - page_of_images: List[ImageResponse] = [] - - for path in page_of_image_paths: - filename = os.path.basename(path) - img = PILImage.open(path) - - invokeai_metadata = self.__metadata_service.get_metadata(img) - - page_of_images.append( - ImageResponse( - image_type=image_type, - image_name=filename, - # TODO: DiskImageStorage should not be building URLs...? - image_url=self.get_uri(image_type, filename), - thumbnail_url=self.get_uri(image_type, filename, True), - # TODO: Creation of this object should happen elsewhere (?), just making it fit here so it works - metadata=ImageResponseMetadata( - created=int(os.path.getctime(path)), - width=img.width, - height=img.height, - invokeai=invokeai_metadata, - ), - ) - ) - - page_count_trunc = int(count / per_page) - page_count_mod = count % per_page - page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1 - - return PaginatedResults[ImageResponse]( - items=page_of_images, - page=page, - pages=page_count, - per_page=per_page, - total=count, - ) - - def get(self, image_type: ImageType, image_name: str) -> Image: + def get(self, image_type: ImageType, image_name: str) -> PILImageType: try: image_path = self.get_path(image_type, image_name) cache_item = self.__get_cache(image_path) if cache_item: return cache_item - image = PILImage.open(image_path) + image = Image.open(image_path) self.__set_cache(image_path, image) return image - except Exception as e: + except FileNotFoundError as e: raise ImageFileStorageBase.ImageFileNotFoundException from e - # TODO: make this a bit more flexible for e.g. cloud storage - def get_path( - self, image_type: ImageType, image_name: str, is_thumbnail: bool = False - ) -> str: - # strip out any relative path shenanigans - basename = os.path.basename(image_name) - - if is_thumbnail: - path = os.path.join( - self.__output_folder, image_type, "thumbnails", basename - ) - else: - path = os.path.join(self.__output_folder, image_type, basename) - - abspath = os.path.abspath(path) - - return abspath - - def get_uri( - self, image_type: ImageType, image_name: str, is_thumbnail: bool = False - ) -> str: - # strip out any relative path shenanigans - basename = os.path.basename(image_name) - - if is_thumbnail: - thumbnail_basename = get_thumbnail_name(basename) - uri = f"api/v1/images/{image_type.value}/thumbnails/{thumbnail_basename}" - else: - uri = f"api/v1/images/{image_type.value}/{basename}" - - return uri - - def validate_path(self, path: str) -> bool: - try: - os.stat(path) - return True - except FileNotFoundError: - return False - except Exception as e: - raise e - def save( self, + image: PILImageType, image_type: ImageType, image_name: str, - image: Image, - metadata: InvokeAIMetadata | None = None, - ) -> SavedImage: + pnginfo: Optional[PngInfo] = None, + thumbnail_size: int = 256, + ) -> None: try: image_path = self.get_path(image_type, image_name) - - # TODO: Reading the image and then saving it strips the metadata... - if metadata: - pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata) - image.save(image_path, "PNG", pnginfo=pnginfo) - else: - image.save(image_path) # this saved image has an empty info + image.save(image_path, "PNG", pnginfo=pnginfo) thumbnail_name = get_thumbnail_name(image_name) - thumbnail_path = self.get_path( - image_type, thumbnail_name, is_thumbnail=True - ) - thumbnail_image = make_thumbnail(image) + thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True) + thumbnail_image = make_thumbnail(image, thumbnail_size) thumbnail_image.save(thumbnail_path) self.__set_cache(image_path, image) self.__set_cache(thumbnail_path, thumbnail_image) - - return SavedImage( - image_name=image_name, - thumbnail_name=thumbnail_name, - created=int(os.path.getctime(image_path)), - ) except Exception as e: raise ImageFileStorageBase.ImageFileSaveException from e @@ -304,10 +146,29 @@ class DiskImageFileStorage(ImageFileStorageBase): except Exception as e: raise ImageFileStorageBase.ImageFileDeleteException from e - def __get_cache(self, image_name: str) -> Image | None: + # TODO: make this a bit more flexible for e.g. cloud storage + def get_path( + self, image_type: ImageType, image_name: str, thumbnail: bool = False + ) -> str: + # strip out any relative path shenanigans + basename = os.path.basename(image_name) + + if thumbnail: + thumbnail_name = get_thumbnail_name(basename) + path = os.path.join( + self.__output_folder, image_type, "thumbnails", thumbnail_name + ) + else: + path = os.path.join(self.__output_folder, image_type, basename) + + abspath = os.path.abspath(path) + + return abspath + + def __get_cache(self, image_name: str) -> PILImageType | None: return None if image_name not in self.__cache else self.__cache[image_name] - def __set_cache(self, image_name: str, image: Image): + def __set_cache(self, image_name: str, image: PILImageType): if not image_name in self.__cache: self.__cache[image_name] = image self.__cache_ids.put( diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py index 6d2d9dab68..7c79cf7a34 100644 --- a/invokeai/app/services/image_record_storage.py +++ b/invokeai/app/services/image_record_storage.py @@ -1,25 +1,18 @@ from abc import ABC, abstractmethod import datetime from typing import Optional -from invokeai.app.models.metadata import ( - GeneratedImageOrLatentsMetadata, - UploadedImageOrLatentsMetadata, -) - import sqlite3 import threading from typing import Optional, Union -from invokeai.app.models.metadata import ( - GeneratedImageOrLatentsMetadata, - UploadedImageOrLatentsMetadata, -) + +from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.image import ( ImageCategory, ImageType, ) from invokeai.app.services.util.create_enum_table import create_enum_table -from invokeai.app.services.models.image_record import ImageRecord -from invokeai.app.services.util.deserialize_image_record import ( +from invokeai.app.services.models.image_record import ( + ImageRecord, deserialize_image_record, ) @@ -76,9 +69,7 @@ class ImageRecordStorageBase(ABC): image_category: ImageCategory, session_id: Optional[str], node_id: Optional[str], - metadata: Optional[ - GeneratedImageOrLatentsMetadata | UploadedImageOrLatentsMetadata - ], + metadata: Optional[ImageMetadata], created_at: str = datetime.datetime.utcnow().isoformat(), ) -> None: """Saves an image record.""" @@ -288,9 +279,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): image_category: ImageCategory, session_id: Optional[str], node_id: Optional[str], - metadata: Union[ - GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata, None - ], + metadata: Optional[ImageMetadata], created_at: str, ) -> None: try: @@ -306,7 +295,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): image_category, node_id, session_id, - metadata + metadata, created_at ) VALUES (?, ?, ?, ?, ?, ?, ?); diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index 1559e518b4..53f6a756d6 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -1,12 +1,13 @@ +from abc import ABC, abstractmethod +import json from logging import Logger from typing import Optional, Union import uuid from PIL.Image import Image as PILImageType +from PIL import PngImagePlugin + from invokeai.app.models.image import ImageCategory, ImageType -from invokeai.app.models.metadata import ( - GeneratedImageOrLatentsMetadata, - UploadedImageOrLatentsMetadata, -) +from invokeai.app.models.metadata import ImageMetadata from invokeai.app.services.image_record_storage import ( ImageRecordStorageBase, ) @@ -22,8 +23,95 @@ from invokeai.app.services.urls import UrlServiceBase from invokeai.app.util.misc import get_iso_timestamp +class ImageServiceABC(ABC): + """ + High-level service for image management. + + Provides methods for creating, retrieving, and deleting images. + """ + + @abstractmethod + def create( + self, + image: PILImageType, + image_type: ImageType, + image_category: ImageCategory, + node_id: Optional[str] = None, + session_id: Optional[str] = None, + metadata: Optional[ImageMetadata] = None, + ) -> ImageDTO: + """Creates an image, storing the file and its metadata.""" + pass + + @abstractmethod + def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: + """Gets an image as a PIL image.""" + pass + + @abstractmethod + def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord: + """Gets an image record.""" + pass + + @abstractmethod + def get_path(self, image_type: ImageType, image_name: str) -> str: + """Gets an image's path""" + pass + + @abstractmethod + def get_image_url(self, image_type: ImageType, image_name: str) -> str: + """Gets an image's URL""" + pass + + @abstractmethod + def get_thumbnail_url(self, image_type: ImageType, image_name: str) -> str: + """Gets an image's URL""" + pass + + @abstractmethod + def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO: + """Gets an image DTO.""" + pass + + @abstractmethod + def get_many( + self, + image_type: ImageType, + image_category: ImageCategory, + page: int = 0, + per_page: int = 10, + ) -> PaginatedResults[ImageDTO]: + """Gets a paginated list of image DTOs.""" + pass + + @abstractmethod + def delete(self, image_type: ImageType, image_name: str): + """Deletes an image.""" + pass + + @abstractmethod + def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None: + """Adds a tag to an image.""" + pass + + @abstractmethod + def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None: + """Removes a tag from an image.""" + pass + + @abstractmethod + def favorite(self, image_type: ImageType, image_id: str) -> None: + """Favorites an image.""" + pass + + @abstractmethod + def unfavorite(self, image_type: ImageType, image_id: str) -> None: + """Unfavorites an image.""" + pass + + class ImageServiceDependencies: - """Service dependencies for the ImageManagementService.""" + """Service dependencies for the ImageService.""" records: ImageRecordStorageBase files: ImageFileStorageBase @@ -46,9 +134,7 @@ class ImageServiceDependencies: self.logger = logger -class ImageService: - """High-level service for image management.""" - +class ImageService(ImageServiceABC): _services: ImageServiceDependencies def __init__( @@ -67,21 +153,6 @@ class ImageService: logger=logger, ) - def _create_image_name( - self, - image_type: ImageType, - image_category: ImageCategory, - node_id: Optional[str] = None, - session_id: Optional[str] = None, - ) -> str: - """Creates an image name.""" - uuid_str = str(uuid.uuid4()) - - if node_id is not None and session_id is not None: - return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png" - - return f"{image_type.value}_{image_category.value}_{uuid_str}.png" - def create( self, image: PILImageType, @@ -89,11 +160,8 @@ class ImageService: image_category: ImageCategory, node_id: Optional[str] = None, session_id: Optional[str] = None, - metadata: Optional[ - Union[GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata] - ] = None, + metadata: Optional[ImageMetadata] = None, ) -> ImageDTO: - """Creates an image, storing the file and its metadata.""" image_name = self._create_image_name( image_type=image_type, image_category=image_category, @@ -103,13 +171,19 @@ class ImageService: timestamp = get_iso_timestamp() + if metadata is not None: + pnginfo = PngImagePlugin.PngInfo() + pnginfo.add_text("invokeai", json.dumps(metadata)) + else: + pnginfo = None + try: # TODO: Consider using a transaction here to ensure consistency between storage and database self._services.files.save( image_type=image_type, image_name=image_name, image=image, - metadata=metadata, + pnginfo=pnginfo, ) self._services.records.save( @@ -144,25 +218,40 @@ class ImageService: except ImageFileStorageBase.ImageFileSaveException: self._services.logger.error("Failed to save image file") raise + except Exception as e: + self._services.logger.error("Problem saving image record and file") + raise e def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: - """Gets an image as a PIL image.""" try: return self._services.files.get(image_type, image_name) except ImageFileStorageBase.ImageFileNotFoundException: self._services.logger.error("Failed to get image file") raise + except Exception as e: + self._services.logger.error("Problem getting image file") + raise e def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord: - """Gets an image record.""" try: return self._services.records.get(image_type, image_name) except ImageRecordStorageBase.ImageRecordNotFoundException: - self._services.logger.error("Failed to get image record") + self._services.logger.error("Image record not found") raise + except Exception as e: + self._services.logger.error("Problem getting image record") + raise e + + def get_path( + self, image_type: ImageType, image_name: str, thumbnail: bool = False + ) -> str: + try: + return self._services.files.get_path(image_type, image_name, thumbnail) + except Exception as e: + self._services.logger.error("Problem getting image path") + raise e def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO: - """Gets an image DTO.""" try: image_record = self._services.records.get(image_type, image_name) @@ -174,21 +263,11 @@ class ImageService: return image_dto except ImageRecordStorageBase.ImageRecordNotFoundException: - self._services.logger.error("Failed to get image DTO") - raise - - def delete(self, image_type: ImageType, image_name: str): - """Deletes an image.""" - # TODO: Consider using a transaction here to ensure consistency between storage and database - try: - self._services.files.delete(image_type, image_name) - self._services.records.delete(image_type, image_name) - except ImageRecordStorageBase.ImageRecordDeleteException: - self._services.logger.error(f"Failed to delete image record") - raise - except ImageFileStorageBase.ImageFileDeleteException: - self._services.logger.error(f"Failed to delete image file") + self._services.logger.error("Image record not found") raise + except Exception as e: + self._services.logger.error("Problem getting image DTO") + raise e def get_many( self, @@ -197,7 +276,6 @@ class ImageService: page: int = 0, per_page: int = 10, ) -> PaginatedResults[ImageDTO]: - """Gets a paginated list of image DTOs.""" try: results = self._services.records.get_many( image_type, @@ -225,21 +303,47 @@ class ImageService: total=results.total, ) except Exception as e: - self._services.logger.error("Failed to get paginated image DTOs") + self._services.logger.error("Problem getting paginated image DTOs") + raise e + + def delete(self, image_type: ImageType, image_name: str): + # TODO: Consider using a transaction here to ensure consistency between storage and database + try: + self._services.files.delete(image_type, image_name) + self._services.records.delete(image_type, image_name) + except ImageRecordStorageBase.ImageRecordDeleteException: + self._services.logger.error(f"Failed to delete image record") + raise + except ImageFileStorageBase.ImageFileDeleteException: + self._services.logger.error(f"Failed to delete image file") + raise + except Exception as e: + self._services.logger.error("Problem deleting image record and file") raise e def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None: - """Adds a tag to an image.""" raise NotImplementedError("The 'add_tag' method is not implemented yet.") def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None: - """Removes a tag from an image.""" raise NotImplementedError("The 'remove_tag' method is not implemented yet.") def favorite(self, image_type: ImageType, image_id: str) -> None: - """Favorites an image.""" raise NotImplementedError("The 'favorite' method is not implemented yet.") def unfavorite(self, image_type: ImageType, image_id: str) -> None: - """Unfavorites an image.""" raise NotImplementedError("The 'unfavorite' method is not implemented yet.") + + def _create_image_name( + self, + image_type: ImageType, + image_category: ImageCategory, + node_id: Optional[str] = None, + session_id: Optional[str] = None, + ) -> str: + """Create a unique image name.""" + uuid_str = str(uuid.uuid4()) + + if node_id is not None and session_id is not None: + return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png" + + return f"{image_type.value}_{image_category.value}_{uuid_str}.png" diff --git a/invokeai/app/services/models/image_record.py b/invokeai/app/services/models/image_record.py index 731e42e132..cd2f3aacbc 100644 --- a/invokeai/app/services/models/image_record.py +++ b/invokeai/app/services/models/image_record.py @@ -1,11 +1,10 @@ import datetime +import sqlite3 from typing import Optional, Union from pydantic import BaseModel, Field -from invokeai.app.models.metadata import ( - GeneratedImageOrLatentsMetadata, - UploadedImageOrLatentsMetadata, -) from invokeai.app.models.image import ImageCategory, ImageType +from invokeai.app.models.metadata import ImageMetadata +from invokeai.app.util.misc import get_iso_timestamp class ImageRecord(BaseModel): @@ -19,9 +18,9 @@ class ImageRecord(BaseModel): ) session_id: Optional[str] = Field(default=None, description="The session ID.") node_id: Optional[str] = Field(default=None, description="The node ID.") - metadata: Optional[ - Union[GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata] - ] = Field(default=None, description="The image's metadata.") + metadata: Optional[ImageMetadata] = Field( + default=None, description="The image's metadata." + ) class ImageDTO(ImageRecord): @@ -46,3 +45,27 @@ def image_record_to_dto( image_url=image_url, thumbnail_url=thumbnail_url, ) + + +def deserialize_image_record(image_row: sqlite3.Row) -> ImageRecord: + """Deserializes an image record.""" + + image_dict = dict(image_row) + + image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value)) + + raw_metadata = image_dict.get("metadata", "{}") + + metadata = ImageMetadata.parse_raw(raw_metadata) + + return ImageRecord( + image_name=image_dict.get("id", "unknown"), + session_id=image_dict.get("session_id", None), + node_id=image_dict.get("node_id", None), + metadata=metadata, + image_type=image_type, + image_category=ImageCategory( + image_dict.get("image_category", ImageCategory.IMAGE.value) + ), + created_at=image_dict.get("created_at", get_iso_timestamp()), + ) diff --git a/invokeai/app/services/urls.py b/invokeai/app/services/urls.py index 16f8fc7494..8b2d53d2af 100644 --- a/invokeai/app/services/urls.py +++ b/invokeai/app/services/urls.py @@ -25,8 +25,8 @@ class LocalUrlService(UrlServiceBase): def get_image_url(self, image_type: ImageType, image_name: str) -> str: image_basename = os.path.basename(image_name) - return f"{self._base_url}/images/{image_type.value}/{image_basename}" + return f"{self._base_url}/files/images/{image_type.value}/{image_basename}" def get_thumbnail_url(self, image_type: ImageType, image_name: str) -> str: - thumbnail_basename = get_thumbnail_name(os.path.basename(image_name)) - return f"{self._base_url}/images/{image_type.value}/thumbnails/{thumbnail_basename}" + image_basename = os.path.basename(image_name) + return f"{self._base_url}/files/images/{image_type.value}/{image_basename}/thumbnail" diff --git a/invokeai/app/services/util/deserialize_image_record.py b/invokeai/app/services/util/deserialize_image_record.py deleted file mode 100644 index 52014b78c5..0000000000 --- a/invokeai/app/services/util/deserialize_image_record.py +++ /dev/null @@ -1,33 +0,0 @@ -from invokeai.app.models.metadata import ( - GeneratedImageOrLatentsMetadata, - UploadedImageOrLatentsMetadata, -) -from invokeai.app.models.image import ImageCategory, ImageType -from invokeai.app.services.models.image_record import ImageRecord -from invokeai.app.util.misc import get_iso_timestamp - - -def deserialize_image_record(image: dict) -> ImageRecord: - """Deserializes an image record.""" - - # All values *should* be present, except `session_id` and `node_id`, but provide some defaults just in case - - image_type = ImageType(image.get("image_type", ImageType.RESULT.value)) - raw_metadata = image.get("metadata", {}) - - if image_type == ImageType.UPLOAD: - metadata = UploadedImageOrLatentsMetadata.parse_obj(raw_metadata) - else: - metadata = GeneratedImageOrLatentsMetadata.parse_obj(raw_metadata) - - return ImageRecord( - image_name=image.get("id", "unknown"), - session_id=image.get("session_id", None), - node_id=image.get("node_id", None), - metadata=metadata, - image_type=image_type, - image_category=ImageCategory( - image.get("image_category", ImageCategory.IMAGE.value) - ), - created_at=image.get("created_at", get_iso_timestamp()), - )