diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 517e174b68..7494d24324 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -1,9 +1,13 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) import os +from types import ModuleType +from invokeai.app.services.database.images.sqlite_images_db_service import ( + SqliteImageDb, +) +from invokeai.app.services.urls import LocalUrlService import invokeai.backend.util.logging as logger -from typing import types from ..services.default_graphs import create_system_graphs from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage @@ -17,6 +21,7 @@ from ..services.invoker import Invoker from ..services.processor import DefaultInvocationProcessor from ..services.sqlite import SqliteItemStorage from ..services.metadata import PngMetadataService +from ..services.results import SqliteResultsService from .events import FastAPIEventService @@ -50,28 +55,41 @@ class ApiDependencies: os.path.join(os.path.dirname(__file__), "../../../../outputs") ) - latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')) + latents = ForwardCacheLatentsStorage( + DiskLatentsStorage(f"{output_folder}/latents") + ) metadata = PngMetadataService() - images = DiskImageStorage(f'{output_folder}/images', metadata_service=metadata) + urls = LocalUrlService() + + images = DiskImageStorage(f"{output_folder}/images", metadata_service=metadata) # TODO: build a file/path manager? db_location = os.path.join(output_folder, "invokeai.db") + graph_execution_manager = SqliteItemStorage[GraphExecutionState]( + filename=db_location, table_name="graph_executions" + ) + + images_db = SqliteImageDb(filename=db_location) + + # register event handler to update the `results` table when a graph execution state is inserted or updated + # graph_execution_manager.on_changed(results.handle_graph_execution_state_change) + services = InvocationServices( - model_manager=get_model_manager(config,logger), + model_manager=get_model_manager(config, logger), events=events, latents=latents, images=images, metadata=metadata, + images_db=images_db, + urls=urls, queue=MemoryInvocationQueue(), graph_library=SqliteItemStorage[LibraryGraph]( filename=db_location, table_name="graphs" ), - graph_execution_manager=SqliteItemStorage[GraphExecutionState]( - filename=db_location, table_name="graph_executions" - ), + graph_execution_manager=graph_execution_manager, processor=DefaultInvocationProcessor(), restoration=RestorationServices(config,logger), configuration=config, diff --git a/invokeai/app/api/routers/image_resources.py b/invokeai/app/api/routers/image_resources.py new file mode 100644 index 0000000000..56fcdcb2d1 --- /dev/null +++ b/invokeai/app/api/routers/image_resources.py @@ -0,0 +1,74 @@ +from fastapi import HTTPException, Path, Query +from fastapi.routing import APIRouter +from invokeai.app.models.image import ( + ImageCategory, + ImageType, +) +from invokeai.app.services.image_db import ImageRecordServiceBase +from invokeai.app.services.image_storage import ImageStorageBase +from invokeai.app.services.models.image_record import ImageRecord +from invokeai.app.services.item_storage import PaginatedResults + +from ..dependencies import ApiDependencies + +image_records_router = APIRouter(prefix="/v1/records/images", tags=["records"]) + + +@image_records_router.get("/{image_type}/{image_name}", operation_id="get_image_record") +async def get_image_record( + image_type: ImageType = Path(description="The type of the image record to get"), + image_name: str = Path(description="The id of the image record to get"), +) -> ImageRecord: + """Gets an image record by id""" + + try: + return ApiDependencies.invoker.services.images_new.get_record( + image_type=image_type, image_name=image_name + ) + except ImageRecordServiceBase.ImageRecordNotFoundException: + raise HTTPException(status_code=404) + + +@image_records_router.get( + "/", + operation_id="list_image_records", +) +async def list_image_records( + image_type: ImageType = Query(description="The type of image records to get"), + image_category: ImageCategory = Query( + description="The kind of image records to get" + ), + page: int = Query(default=0, description="The page of image records to get"), + per_page: int = Query( + default=10, description="The number of image records per page" + ), +) -> PaginatedResults[ImageRecord]: + """Gets a list of image records by type and category""" + + images = ApiDependencies.invoker.services.images_new.get_many( + image_type=image_type, + image_category=image_category, + page=page, + per_page=per_page, + ) + + return images + + +@image_records_router.delete("/{image_type}/{image_name}", operation_id="delete_image") +async def delete_image_record( + image_type: ImageType = Query(description="The type of image records to get"), + image_name: str = Path(description="The name of the image to delete"), +) -> None: + """Deletes an image record""" + + try: + ApiDependencies.invoker.services.images_new.delete( + image_type=image_type, image_name=image_name + ) + except ImageStorageBase.ImageFileDeleteException: + # TODO: log this + pass + except ImageRecordServiceBase.ImageRecordDeleteException: + # TODO: log this + pass diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 0b7891e0f2..41ba00ef7a 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -14,23 +14,39 @@ from invokeai.app.api.models.images import ( ImageResponse, ImageResponseMetadata, ) +from invokeai.app.models.image import ImageType from invokeai.app.services.item_storage import PaginatedResults -from ...services.image_storage import ImageType from ..dependencies import ApiDependencies images_router = APIRouter(prefix="/v1/images", tags=["images"]) -@images_router.get("/{image_type}/{image_name}", operation_id="get_image") +# @images_router.get("/{image_type}/{image_name}", operation_id="get_image") +# async def get_image( +# image_type: ImageType = Path(description="The type of image to get"), +# image_name: str = Path(description="The name of the image to get"), +# ) -> FileResponse: +# """Gets an image""" + +# path = ApiDependencies.invoker.services.images.get_path( +# image_type=image_type, image_name=image_name +# ) + +# if ApiDependencies.invoker.services.images.validate_path(path): +# return FileResponse(path) +# else: +# raise HTTPException(status_code=404) + +@images_router.get("/{image_type}/{image_id}", operation_id="get_image") async def get_image( - image_type: ImageType = Path(description="The type of image to get"), - image_name: str = Path(description="The name of the image to get"), + image_type: ImageType = Path(description="The type of the image to get"), + image_id: str = Path(description="The id of the image to get"), ) -> FileResponse: """Gets an image""" path = ApiDependencies.invoker.services.images.get_path( - image_type=image_type, image_name=image_name + image_type=image_type, image_id=image_id ) if ApiDependencies.invoker.services.images.validate_path(path): @@ -41,7 +57,7 @@ async def get_image( @images_router.delete("/{image_type}/{image_name}", operation_id="delete_image") async def delete_image( - image_type: ImageType = Path(description="The type of image to delete"), + image_type: ImageType = Path(description="The type of the image to delete"), image_name: str = Path(description="The name of the image to delete"), ) -> None: """Deletes an image and its thumbnail""" @@ -52,16 +68,16 @@ async def delete_image( @images_router.get( - "/{thumbnail_type}/thumbnails/{thumbnail_name}", operation_id="get_thumbnail" + "/{image_type}/thumbnails/{thumbnail_id}", operation_id="get_thumbnail" ) async def get_thumbnail( - thumbnail_type: ImageType = Path(description="The type of thumbnail to get"), - thumbnail_name: str = Path(description="The name of the thumbnail to get"), + image_type: ImageType = Path(description="The type of the thumbnail to get"), + thumbnail_id: str = Path(description="The id of the thumbnail to get"), ) -> FileResponse | Response: """Gets a thumbnail""" path = ApiDependencies.invoker.services.images.get_path( - image_type=thumbnail_type, image_name=thumbnail_name, is_thumbnail=True + image_type=image_type, image_id=thumbnail_id, is_thumbnail=True ) if ApiDependencies.invoker.services.images.validate_path(path): diff --git a/invokeai/app/api/routers/results.py b/invokeai/app/api/routers/results.py new file mode 100644 index 0000000000..4190e5bd27 --- /dev/null +++ b/invokeai/app/api/routers/results.py @@ -0,0 +1,42 @@ +from fastapi import HTTPException, Path, Query +from fastapi.routing import APIRouter +from invokeai.app.services.results import ResultType, ResultWithSession +from invokeai.app.services.item_storage import PaginatedResults + +from ..dependencies import ApiDependencies + +results_router = APIRouter(prefix="/v1/results", tags=["results"]) + + +@results_router.get("/{result_type}/{result_name}", operation_id="get_result") +async def get_result( + result_type: ResultType = Path(description="The type of result to get"), + result_name: str = Path(description="The name of the result to get"), +) -> ResultWithSession: + """Gets a result""" + + result = ApiDependencies.invoker.services.results.get( + result_id=result_name, result_type=result_type + ) + + if result is not None: + return result + else: + raise HTTPException(status_code=404) + + +@results_router.get( + "/", + operation_id="list_results", + responses={200: {"model": PaginatedResults[ResultWithSession]}}, +) +async def list_results( + result_type: ResultType = Query(description="The type of results to get"), + page: int = Query(default=0, description="The page of results to get"), + per_page: int = Query(default=10, description="The number of results per page"), +) -> PaginatedResults[ResultWithSession]: + """Gets a list of results""" + results = ApiDependencies.invoker.services.results.get_many( + result_type=result_type, page=page, per_page=per_page + ) + return results diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 33714f1057..a67f36edd3 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -3,6 +3,7 @@ import asyncio from inspect import signature import uvicorn +from invokeai.app.models import resources import invokeai.backend.util.logging as logger from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware @@ -14,11 +15,12 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware from pydantic.schema import schema from .api.dependencies import ApiDependencies -from .api.routers import images, sessions, models +from .api.routers import image_resources, images, sessions, models from .api.sockets import SocketIO from .invocations.baseinvocation import BaseInvocation from .services.config import InvokeAIAppConfig + # Create the app # TODO: create this all in a method so configuration/etc. can be passed in? app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None) @@ -73,6 +75,8 @@ app.include_router(images.images_router, prefix="/api") app.include_router(models.models_router, prefix="/api") +app.include_router(image_resources.image_resources_router, prefix="/api") + # Build a custom OpenAPI to include all outputs # TODO: can outputs be included on metadata of invocation schemas somehow? @@ -121,6 +125,7 @@ app.openapi = custom_openapi # Override API doc favicons app.mount("/static", StaticFiles(directory="static/dream_web"), name="static") + @app.get("/docs", include_in_schema=False) def overridden_swagger(): return get_swagger_ui_html( @@ -138,8 +143,12 @@ def overridden_redoc(): redoc_favicon_url="/static/favicon.ico", ) + # Must mount *after* the other routes else it borks em -app.mount("/", StaticFiles(directory="invokeai/frontend/web/dist", html=True), name="ui") +app.mount( + "/", StaticFiles(directory="invokeai/frontend/web/dist", html=True), name="ui" +) + def invoke_api(): # Start our own event loop for eventing usage diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 7daaa588b1..da61641105 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -1,12 +1,15 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) +from __future__ import annotations + from abc import ABC, abstractmethod from inspect import signature -from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict +from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING from pydantic import BaseModel, Field -from ..services.invocation_services import InvocationServices +if TYPE_CHECKING: + from ..services.invocation_services import InvocationServices class InvocationContext: diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index bc72bbe2b3..525be128e4 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -10,6 +10,8 @@ from pydantic import BaseModel, Field from invokeai.app.models.image import ColorField, ImageField, ImageType from invokeai.app.invocations.util.choose_model import choose_model +from invokeai.app.models.metadata import GeneratedImageOrLatentsMetadata +from invokeai.app.models.image import ImageCategory, ImageType from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.backend.generator.inpaint import infill_methods from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig @@ -106,6 +108,16 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation): context.services.images.save( image_type, image_name, generate_output.image, metadata ) + + context.services.images_db.set( + id=image_name, + image_type=ImageType.RESULT, + image_category=ImageCategory.IMAGE, + session_id=context.graph_execution_state_id, + node_id=self.id, + metadata=GeneratedImageOrLatentsMetadata(), + ) + return build_image_output( image_type=image_type, image_name=image_name, diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 8b4163c4c6..56141cbb0e 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -31,7 +31,7 @@ class ImageOutput(BaseInvocationOutput): """Base class for invocations that output an image""" # fmt: off - type: Literal["image"] = "image" + type: Literal["image_output"] = "image_output" image: ImageField = Field(default=None, description="The output image") width: int = Field(description="The width of the image in pixels") height: int = Field(description="The height of the image in pixels") diff --git a/invokeai/app/models/image.py b/invokeai/app/models/image.py index f6813c6d96..f364abdb71 100644 --- a/invokeai/app/models/image.py +++ b/invokeai/app/models/image.py @@ -2,11 +2,23 @@ from enum import Enum from typing import Optional, Tuple from pydantic import BaseModel, Field +from invokeai.app.util.enum import MetaEnum + + +class ImageType(str, Enum, metaclass=MetaEnum): + """The type of an image.""" -class ImageType(str, Enum): RESULT = "results" - INTERMEDIATE = "intermediates" UPLOAD = "uploads" + INTERMEDIATE = "intermediates" + + +class ImageCategory(str, Enum, metaclass=MetaEnum): + """The category of an image. Use ImageCategory.OTHER for non-default categories.""" + + IMAGE = "image" + CONTROL_IMAGE = "control_image" + OTHER = "other" def is_image_type(obj): diff --git a/invokeai/app/models/metadata.py b/invokeai/app/models/metadata.py new file mode 100644 index 0000000000..aae3337266 --- /dev/null +++ b/invokeai/app/models/metadata.py @@ -0,0 +1,70 @@ +from typing import Optional +from pydantic import BaseModel, Field, StrictFloat, StrictInt, StrictStr + + +class GeneratedImageOrLatentsMetadata(BaseModel): + """Core generation metadata for an image/tensor generated in InvokeAI. + + Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node. + + Full metadata may be accessed by querying for the session in the `graph_executions` table. + """ + + positive_conditioning: Optional[StrictStr] = Field( + default=None, description="The positive conditioning." + ) + negative_conditioning: Optional[StrictStr] = Field( + default=None, description="The negative conditioning." + ) + width: Optional[StrictInt] = Field( + default=None, description="Width of the image/tensor in pixels." + ) + height: Optional[StrictInt] = Field( + default=None, description="Height of the image/tensor in pixels." + ) + seed: Optional[StrictInt] = Field( + default=None, description="The seed used for noise generation." + ) + cfg_scale: Optional[StrictFloat] = Field( + default=None, description="The classifier-free guidance scale." + ) + steps: Optional[StrictInt] = Field( + default=None, description="The number of steps used for inference." + ) + scheduler: Optional[StrictStr] = Field( + default=None, description="The scheduler used for inference." + ) + model: Optional[StrictStr] = Field( + default=None, description="The model used for inference." + ) + strength: Optional[StrictFloat] = Field( + default=None, + description="The strength used for image-to-image/tensor-to-tensor.", + ) + image: Optional[StrictStr] = Field( + default=None, description="The ID of the initial image." + ) + tensor: Optional[StrictStr] = Field( + default=None, description="The ID of the initial tensor." + ) + # Pending model refactor: + # vae: Optional[str] = Field(default=None,description="The VAE used for decoding.") + # unet: Optional[str] = Field(default=None,description="The UNet used dor inference.") + # clip: Optional[str] = Field(default=None,description="The CLIP Encoder used for conditioning.") + + +class UploadedImageOrLatentsMetadata(BaseModel): + """Limited metadata for an uploaded image/tensor.""" + + width: Optional[StrictInt] = Field( + default=None, description="Width of the image/tensor in pixels." + ) + height: Optional[StrictInt] = Field( + default=None, description="Height of the image/tensor in pixels." + ) + # The extra field will be the contents of the PNG file's tEXt chunk. It may have come + # from another SD application or InvokeAI, so it needs to be flexible. + # If the upload is a not an image or `image_latents` tensor, this will be omitted. + extra: Optional[StrictStr] = Field( + default=None, description="Extra metadata, extracted from the PNG tEXt chunk." + ) diff --git a/invokeai/app/models/resources.py b/invokeai/app/models/resources.py new file mode 100644 index 0000000000..1cd22e4550 --- /dev/null +++ b/invokeai/app/models/resources.py @@ -0,0 +1,28 @@ +# TODO: Make a new model for this +from enum import Enum + +from invokeai.app.util.enum import MetaEnum + + +class ResourceType(str, Enum, metaclass=MetaEnum): + """The type of a resource.""" + + IMAGES = "images" + TENSORS = "tensors" + + +# class ResourceOrigin(str, Enum, metaclass=MetaEnum): +# """The origin of a resource (eg image or tensor).""" + +# RESULTS = "results" +# UPLOADS = "uploads" +# INTERMEDIATES = "intermediates" + + + +class TensorKind(str, Enum, metaclass=MetaEnum): + """The kind of a tensor. Use TensorKind.OTHER for non-default kinds.""" + + IMAGE_LATENTS = "image_latents" + CONDITIONING = "conditioning" + OTHER = "other" diff --git a/invokeai/app/services/db.ipynb b/invokeai/app/services/db.ipynb new file mode 100644 index 0000000000..67dfe22128 --- /dev/null +++ b/invokeai/app/services/db.ipynb @@ -0,0 +1,578 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "from abc import ABC, abstractmethod\n", + "from enum import Enum\n", + "import enum\n", + "import sqlite3\n", + "import threading\n", + "from typing import Optional, Type, TypeVar, Union\n", + "from PIL.Image import Image as PILImage\n", + "from pydantic import BaseModel, Field\n", + "from torch import Tensor" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "class ResourceOrigin(str, Enum):\n", + " \"\"\"The origin of a resource (eg image or tensor).\"\"\"\n", + "\n", + " RESULTS = \"results\"\n", + " UPLOADS = \"uploads\"\n", + " INTERMEDIATES = \"intermediates\"\n", + "\n", + "\n", + "class ImageKind(str, Enum):\n", + " \"\"\"The kind of an image. Use ImageKind.OTHER for non-default kinds.\"\"\"\n", + "\n", + " IMAGE = \"image\"\n", + " CONTROL_IMAGE = \"control_image\"\n", + " OTHER = \"other\"\n", + "\n", + "\n", + "class TensorKind(str, Enum):\n", + " \"\"\"The kind of a tensor. Use TensorKind.OTHER for non-default kinds.\"\"\"\n", + "\n", + " IMAGE_LATENTS = \"image_latents\"\n", + " CONDITIONING = \"conditioning\"\n", + " OTHER = \"other\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def create_sql_values_string_from_string_enum(enum: Type[Enum]):\n", + " \"\"\"\n", + " Creates a string of the form \"('value1'), ('value2'), ..., ('valueN')\" from a StrEnum.\n", + " \"\"\"\n", + "\n", + " delimiter = \", \"\n", + " values = [f\"('{e.value}')\" for e in enum]\n", + " return delimiter.join(values)\n", + "\n", + "\n", + "def create_sql_table_from_enum(\n", + " enum: Type[Enum],\n", + " table_name: str,\n", + " primary_key_name: str,\n", + " conn: sqlite3.Connection,\n", + " cursor: sqlite3.Cursor,\n", + " lock: threading.Lock,\n", + "):\n", + " \"\"\"\n", + " Creates and populates a table to be used as a functional enum.\n", + " \"\"\"\n", + "\n", + " try:\n", + " lock.acquire()\n", + "\n", + " values_string = create_sql_values_string_from_string_enum(enum)\n", + "\n", + " cursor.execute(\n", + " f\"\"\"--sql\n", + " CREATE TABLE IF NOT EXISTS {table_name} (\n", + " {primary_key_name} TEXT PRIMARY KEY\n", + " );\n", + " \"\"\"\n", + " )\n", + " cursor.execute(\n", + " f\"\"\"--sql\n", + " INSERT OR IGNORE INTO {table_name} ({primary_key_name}) VALUES {values_string};\n", + " \"\"\"\n", + " )\n", + " conn.commit()\n", + " finally:\n", + " lock.release()\n", + "\n", + "\n", + "\"\"\"\n", + "`resource_origins` functions as an enum for the ResourceOrigin model.\n", + "\"\"\"\n", + "\n", + "\n", + "# def create_resource_origins_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n", + "# create_sql_table_from_enum(\n", + "# enum=ResourceOrigin,\n", + "# table_name=\"resource_origins\",\n", + "# primary_key_name=\"origin_name\",\n", + "# conn=conn,\n", + "# cursor=cursor,\n", + "# lock=lock,\n", + "# )\n", + "\n", + "\n", + "\"\"\"\n", + "`image_kinds` functions as an enum for the ImageType model.\n", + "\"\"\"\n", + "\n", + "\n", + "# def create_image_kinds_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n", + " # create_sql_table_from_enum(\n", + " # enum=ImageKind,\n", + " # table_name=\"image_kinds\",\n", + " # primary_key_name=\"kind_name\",\n", + " # conn=conn,\n", + " # cursor=cursor,\n", + " # lock=lock,\n", + " # )\n", + "\n", + "\n", + "\"\"\"\n", + "`tensor_kinds` functions as an enum for the TensorType model.\n", + "\"\"\"\n", + "\n", + "\n", + "# def create_tensor_kinds_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n", + " # create_sql_table_from_enum(\n", + " # enum=TensorKind,\n", + " # table_name=\"tensor_kinds\",\n", + " # primary_key_name=\"kind_name\",\n", + " # conn=conn,\n", + " # cursor=cursor,\n", + " # lock=lock,\n", + " # )\n", + "\n", + "\n", + "\"\"\"\n", + "`images` stores all images, regardless of type\n", + "\"\"\"\n", + "\n", + "\n", + "def create_images_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n", + " try:\n", + " lock.acquire()\n", + "\n", + " cursor.execute(\n", + " \"\"\"--sql\n", + " CREATE TABLE IF NOT EXISTS images (\n", + " id TEXT PRIMARY KEY,\n", + " origin TEXT,\n", + " image_kind TEXT,\n", + " created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\n", + " FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),\n", + " FOREIGN KEY(image_kind) REFERENCES image_kinds(kind_name)\n", + " );\n", + " \"\"\"\n", + " )\n", + " cursor.execute(\n", + " \"\"\"--sql\n", + " CREATE UNIQUE INDEX IF NOT EXISTS idx_images_id ON images(id);\n", + " \"\"\"\n", + " )\n", + " cursor.execute(\n", + " \"\"\"--sql\n", + " CREATE INDEX IF NOT EXISTS idx_images_origin ON images(origin);\n", + " \"\"\"\n", + " )\n", + " cursor.execute(\n", + " \"\"\"--sql\n", + " CREATE INDEX IF NOT EXISTS idx_images_image_kind ON images(image_kind);\n", + " \"\"\"\n", + " )\n", + " conn.commit()\n", + " finally:\n", + " lock.release()\n", + "\n", + "\n", + "\"\"\"\n", + "`images_results` stores additional data specific to `results` images.\n", + "\"\"\"\n", + "\n", + "\n", + "def create_images_results_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n", + " try:\n", + " lock.acquire()\n", + "\n", + " cursor.execute(\n", + " \"\"\"--sql\n", + " CREATE TABLE IF NOT EXISTS images_results (\n", + " images_id TEXT PRIMARY KEY,\n", + " session_id TEXT NOT NULL,\n", + " node_id TEXT NOT NULL,\n", + " FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n", + " );\n", + " \"\"\"\n", + " )\n", + " cursor.execute(\n", + " \"\"\"--sql\n", + " CREATE UNIQUE INDEX IF NOT EXISTS idx_images_results_images_id ON images_results(images_id);\n", + " \"\"\"\n", + " )\n", + " conn.commit()\n", + " finally:\n", + " lock.release()\n", + "\n", + "\n", + "\"\"\"\n", + "`images_intermediates` stores additional data specific to `intermediates` images\n", + "\"\"\"\n", + "\n", + "\n", + "def create_images_intermediates_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n", + " try:\n", + " lock.acquire()\n", + "\n", + " cursor.execute(\n", + " \"\"\"--sql\n", + " CREATE TABLE IF NOT EXISTS images_intermediates (\n", + " images_id TEXT PRIMARY KEY,\n", + " session_id TEXT NOT NULL,\n", + " node_id TEXT NOT NULL,\n", + " FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n", + " );\n", + " \"\"\"\n", + " )\n", + " cursor.execute(\n", + " \"\"\"--sql\n", + " CREATE UNIQUE INDEX IF NOT EXISTS idx_images_intermediates_images_id ON images_intermediates(images_id);\n", + " \"\"\"\n", + " )\n", + " conn.commit()\n", + " finally:\n", + " lock.release()\n", + "\n", + "\n", + "\"\"\"\n", + "`images_metadata` stores basic metadata for any image type\n", + "\"\"\"\n", + "\n", + "\n", + "def create_images_metadata_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n", + " try:\n", + " lock.acquire()\n", + "\n", + " cursor.execute(\n", + " \"\"\"--sql\n", + " CREATE TABLE IF NOT EXISTS images_metadata (\n", + " images_id TEXT PRIMARY KEY,\n", + " metadata TEXT,\n", + " FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n", + " );\n", + " \"\"\"\n", + " )\n", + " cursor.execute(\n", + " \"\"\"--sql\n", + " CREATE UNIQUE INDEX IF NOT EXISTS idx_images_metadata_images_id ON images_metadata(images_id);\n", + " \"\"\"\n", + " )\n", + " conn.commit()\n", + " finally:\n", + " lock.release()\n", + "\n", + "\n", + "# `tensors` table: stores references to tensor\n", + "\n", + "\n", + "def create_tensors_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n", + " try:\n", + " lock.acquire()\n", + "\n", + " cursor.execute(\n", + " \"\"\"--sql\n", + " CREATE TABLE IF NOT EXISTS tensors (\n", + " id TEXT PRIMARY KEY,\n", + " origin TEXT,\n", + " tensor_kind TEXT,\n", + " created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\n", + " FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),\n", + " FOREIGN KEY(tensor_kind) REFERENCES tensor_kinds(kind_name)\n", + " );\n", + " \"\"\"\n", + " )\n", + " cursor.execute(\n", + " \"\"\"--sql\n", + " CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_id ON tensors(id);\n", + " \"\"\"\n", + " )\n", + " cursor.execute(\n", + " \"\"\"--sql\n", + " CREATE INDEX IF NOT EXISTS idx_tensors_origin ON tensors(origin);\n", + " \"\"\"\n", + " )\n", + " cursor.execute(\n", + " \"\"\"--sql\n", + " CREATE INDEX IF NOT EXISTS idx_tensors_tensor_kind ON tensors(tensor_kind);\n", + " \"\"\"\n", + " )\n", + " conn.commit()\n", + " finally:\n", + " lock.release()\n", + "\n", + "\n", + "# `tensors_results` stores additional data specific to `result` tensor\n", + "\n", + "\n", + "def create_tensors_results_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n", + " try:\n", + " lock.acquire()\n", + "\n", + " cursor.execute(\n", + " \"\"\"--sql\n", + " CREATE TABLE IF NOT EXISTS tensors_results (\n", + " tensors_id TEXT PRIMARY KEY,\n", + " session_id TEXT NOT NULL,\n", + " node_id TEXT NOT NULL,\n", + " FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n", + " );\n", + " \"\"\"\n", + " )\n", + " cursor.execute(\n", + " \"\"\"--sql\n", + " CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_results_tensors_id ON tensors_results(tensors_id);\n", + " \"\"\"\n", + " )\n", + " conn.commit()\n", + " finally:\n", + " lock.release()\n", + "\n", + "\n", + "# `tensors_intermediates` stores additional data specific to `intermediate` tensor\n", + "\n", + "\n", + "def create_tensors_intermediates_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n", + " try:\n", + " lock.acquire()\n", + "\n", + " cursor.execute(\n", + " \"\"\"--sql\n", + " CREATE TABLE IF NOT EXISTS tensors_intermediates (\n", + " tensors_id TEXT PRIMARY KEY,\n", + " session_id TEXT NOT NULL,\n", + " node_id TEXT NOT NULL,\n", + " FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n", + " );\n", + " \"\"\"\n", + " )\n", + " cursor.execute(\n", + " \"\"\"--sql\n", + " CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_intermediates_tensors_id ON tensors_intermediates(tensors_id);\n", + " \"\"\"\n", + " )\n", + " conn.commit()\n", + " finally:\n", + " lock.release()\n", + "\n", + "\n", + "# `tensors_metadata` table: stores generated/transformed metadata for tensor\n", + "\n", + "\n", + "def create_tensors_metadata_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n", + " try:\n", + " lock.acquire()\n", + "\n", + " cursor.execute(\n", + " \"\"\"--sql\n", + " CREATE TABLE IF NOT EXISTS tensors_metadata (\n", + " tensors_id TEXT PRIMARY KEY,\n", + " metadata TEXT,\n", + " FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n", + " );\n", + " \"\"\"\n", + " )\n", + " cursor.execute(\n", + " \"\"\"--sql\n", + " CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_metadata_tensors_id ON tensors_metadata(tensors_id);\n", + " \"\"\"\n", + " )\n", + " conn.commit()\n", + " finally:\n", + " lock.release()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "db_path = '/home/bat/Documents/Code/outputs/test.db'\n", + "if (os.path.exists(db_path)):\n", + " os.remove(db_path)\n", + "\n", + "conn = sqlite3.connect(\n", + " db_path, check_same_thread=False\n", + ")\n", + "cursor = conn.cursor()\n", + "lock = threading.Lock()" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [], + "source": [ + "create_sql_table_from_enum(\n", + " enum=ResourceOrigin,\n", + " table_name=\"resource_origins\",\n", + " primary_key_name=\"origin_name\",\n", + " conn=conn,\n", + " cursor=cursor,\n", + " lock=lock,\n", + ")\n", + "\n", + "create_sql_table_from_enum(\n", + " enum=ImageKind,\n", + " table_name=\"image_kinds\",\n", + " primary_key_name=\"kind_name\",\n", + " conn=conn,\n", + " cursor=cursor,\n", + " lock=lock,\n", + ")\n", + "\n", + "create_sql_table_from_enum(\n", + " enum=TensorKind,\n", + " table_name=\"tensor_kinds\",\n", + " primary_key_name=\"kind_name\",\n", + " conn=conn,\n", + " cursor=cursor,\n", + " lock=lock,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [], + "source": [ + "create_images_table(conn, cursor, lock)\n", + "create_images_results_table(conn, cursor, lock)\n", + "create_images_intermediates_table(conn, cursor, lock)\n", + "create_images_metadata_table(conn, cursor, lock)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [], + "source": [ + "create_tensors_table(conn, cursor, lock)\n", + "create_tensors_results_table(conn, cursor, lock)\n", + "create_tensors_intermediates_table(conn, cursor, lock)\n", + "create_tensors_metadata_table(conn, cursor, lock)" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "from pydantic import StrictStr\n", + "\n", + "\n", + "class GeneratedImageOrLatentsMetadata(BaseModel):\n", + " \"\"\"Core generation metadata for an image/tensor generated in InvokeAI.\n", + "\n", + " Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node.\n", + "\n", + " Full metadata may be accessed by querying for the session in the `graph_executions` table.\n", + " \"\"\"\n", + "\n", + " positive_conditioning: Optional[StrictStr] = Field(\n", + " default=None, description=\"The positive conditioning.\"\n", + " )\n", + " negative_conditioning: Optional[str] = Field(\n", + " default=None, description=\"The negative conditioning.\"\n", + " )\n", + " width: Optional[int] = Field(\n", + " default=None, description=\"Width of the image/tensor in pixels.\"\n", + " )\n", + " height: Optional[int] = Field(\n", + " default=None, description=\"Height of the image/tensor in pixels.\"\n", + " )\n", + " seed: Optional[int] = Field(\n", + " default=None, description=\"The seed used for noise generation.\"\n", + " )\n", + " cfg_scale: Optional[float] = Field(\n", + " default=None, description=\"The classifier-free guidance scale.\"\n", + " )\n", + " steps: Optional[int] = Field(\n", + " default=None, description=\"The number of steps used for inference.\"\n", + " )\n", + " scheduler: Optional[str] = Field(\n", + " default=None, description=\"The scheduler used for inference.\"\n", + " )\n", + " model: Optional[str] = Field(\n", + " default=None, description=\"The model used for inference.\"\n", + " )\n", + " strength: Optional[float] = Field(\n", + " default=None,\n", + " description=\"The strength used for image-to-image/tensor-to-tensor.\",\n", + " )\n", + " image: Optional[str] = Field(\n", + " default=None, description=\"The ID of the initial image.\"\n", + " )\n", + " tensor: Optional[str] = Field(\n", + " default=None, description=\"The ID of the initial tensor.\"\n", + " )\n", + " # Pending model refactor:\n", + " # vae: Optional[str] = Field(default=None,description=\"The VAE used for decoding.\")\n", + " # unet: Optional[str] = Field(default=None,description=\"The UNet used dor inference.\")\n", + " # clip: Optional[str] = Field(default=None,description=\"The CLIP Encoder used for conditioning.\")\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "GeneratedImageOrLatentsMetadata(positive_conditioning='123', negative_conditioning=None, width=None, height=None, seed=None, cfg_scale=None, steps=None, scheduler=None, model=None, strength=None, image=None, tensor=None)" + ] + }, + "execution_count": 61, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "GeneratedImageOrLatentsMetadata(positive_conditioning='123')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/invokeai/app/services/image_db.py b/invokeai/app/services/image_db.py new file mode 100644 index 0000000000..73984c6685 --- /dev/null +++ b/invokeai/app/services/image_db.py @@ -0,0 +1,329 @@ +from abc import ABC, abstractmethod +import datetime +from typing import Optional +from invokeai.app.models.metadata import ( + GeneratedImageOrLatentsMetadata, + UploadedImageOrLatentsMetadata, +) + +import sqlite3 +import threading +from typing import Optional, Union +from invokeai.app.models.metadata import ( + GeneratedImageOrLatentsMetadata, + UploadedImageOrLatentsMetadata, +) +from invokeai.app.models.image import ( + ImageCategory, + ImageType, +) +from invokeai.app.services.util.create_enum_table import create_enum_table +from invokeai.app.services.models.image_record import ImageRecord +from invokeai.app.services.util.deserialize_image_record import ( + deserialize_image_record, +) + +from invokeai.app.services.item_storage import PaginatedResults + + +class ImageRecordServiceBase(ABC): + """Low-level service responsible for interfacing with the image record store.""" + + class ImageRecordNotFoundException(Exception): + """Raised when an image record is not found.""" + + def __init__(self, message="Image record not found"): + super().__init__(message) + + class ImageRecordSaveException(Exception): + """Raised when an image record cannot be saved.""" + + def __init__(self, message="Image record not saved"): + super().__init__(message) + + class ImageRecordDeleteException(Exception): + """Raised when an image record cannot be deleted.""" + + def __init__(self, message="Image record not deleted"): + super().__init__(message) + + @abstractmethod + def get(self, image_type: ImageType, image_name: str) -> ImageRecord: + """Gets an image record.""" + pass + + @abstractmethod + def get_many( + self, + image_type: ImageType, + image_category: ImageCategory, + page: int = 0, + per_page: int = 10, + ) -> PaginatedResults[ImageRecord]: + """Gets a page of image records.""" + pass + + @abstractmethod + def delete(self, image_type: ImageType, image_name: str) -> None: + """Deletes an image record.""" + pass + + @abstractmethod + def save( + self, + image_name: str, + image_type: ImageType, + image_category: ImageCategory, + session_id: Optional[str], + node_id: Optional[str], + metadata: Optional[ + GeneratedImageOrLatentsMetadata | UploadedImageOrLatentsMetadata + ], + created_at: str = datetime.datetime.utcnow().isoformat(), + ) -> None: + """Saves an image record.""" + pass + + +class SqliteImageRecordService(ImageRecordServiceBase): + _filename: str + _conn: sqlite3.Connection + _cursor: sqlite3.Cursor + _lock: threading.Lock + + def __init__(self, filename: str) -> None: + super().__init__() + + self._filename = filename + self._conn = sqlite3.connect(filename, check_same_thread=False) + # Enable row factory to get rows as dictionaries (must be done before making the cursor!) + self._conn.row_factory = sqlite3.Row + self._cursor = self._conn.cursor() + self._lock = threading.Lock() + + try: + self._lock.acquire() + # Enable foreign keys + self._conn.execute("PRAGMA foreign_keys = ON;") + self._create_tables() + self._conn.commit() + finally: + self._lock.release() + + def _create_tables(self) -> None: + """Creates the tables for the `images` database.""" + + # Create the `images` table. + self._cursor.execute( + f"""--sql + CREATE TABLE IF NOT EXISTS images ( + id TEXT PRIMARY KEY, + image_type TEXT, -- non-nullable via foreign key constraint + image_category TEXT, -- non-nullable via foreign key constraint + session_id TEXT, -- nullable + node_id TEXT, -- nullable + metadata TEXT, -- nullable + created_at TEXT NOT NULL, + FOREIGN KEY(image_type) REFERENCES image_types(type_name), + FOREIGN KEY(image_category) REFERENCES image_categories(category_name) + ); + """ + ) + + # Create the `images` table indices. + self._cursor.execute( + """--sql + CREATE UNIQUE INDEX IF NOT EXISTS idx_images_id ON images(id); + """ + ) + self._cursor.execute( + """--sql + CREATE INDEX IF NOT EXISTS idx_images_image_type ON images(image_type); + """ + ) + self._cursor.execute( + """--sql + CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category); + """ + ) + + # Create the tables for image-related enums + create_enum_table( + enum=ImageType, + table_name="image_types", + primary_key_name="type_name", + cursor=self._cursor, + ) + + create_enum_table( + enum=ImageCategory, + table_name="image_categories", + primary_key_name="category_name", + cursor=self._cursor, + ) + + # Create the `tags` table. TODO: do this elsewhere, shouldn't be in images db service + self._cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS tags ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + tag_name TEXT UNIQUE NOT NULL + ); + """ + ) + + # Create the `images_tags` junction table. + self._cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS images_tags ( + image_id TEXT, + tag_id INTEGER, + PRIMARY KEY (image_id, tag_id), + FOREIGN KEY(image_id) REFERENCES images(id) ON DELETE CASCADE, + FOREIGN KEY(tag_id) REFERENCES tags(id) ON DELETE CASCADE + ); + """ + ) + + # Create the `images_favorites` table. + self._cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS images_favorites ( + image_id TEXT PRIMARY KEY, + favorited_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY(image_id) REFERENCES images(id) ON DELETE CASCADE + ); + """ + ) + + def get(self, image_type: ImageType, image_name: str) -> Union[ImageRecord, None]: + try: + self._lock.acquire() + + self._cursor.execute( + f"""--sql + SELECT * FROM images + WHERE id = ?; + """, + (image_name,), + ) + + result = self._cursor.fetchone() + except sqlite3.Error as e: + self._conn.rollback() + raise self.ImageRecordNotFoundException from e + finally: + self._lock.release() + + if not result: + raise self.ImageRecordNotFoundException + + return deserialize_image_record(result) + + def get_many( + self, + image_type: ImageType, + image_category: ImageCategory, + page: int = 0, + per_page: int = 10, + ) -> PaginatedResults[ImageRecord]: + try: + self._lock.acquire() + + self._cursor.execute( + f"""--sql + SELECT * FROM images + WHERE image_type = ? AND image_category = ? + LIMIT ? OFFSET ?; + """, + (image_type.value, image_category.value, per_page, page * per_page), + ) + + result = self._cursor.fetchall() + + images = list(map(lambda r: deserialize_image_record(r), result)) + + self._cursor.execute( + """--sql + SELECT count(*) FROM images + WHERE image_type = ? AND image_category = ? + """, + (image_type.value, image_category.value), + ) + + count = self._cursor.fetchone()[0] + except sqlite3.Error as e: + self._conn.rollback() + raise e + finally: + self._lock.release() + + pageCount = int(count / per_page) + 1 + + return PaginatedResults( + items=images, page=page, pages=pageCount, per_page=per_page, total=count + ) + + def delete(self, image_type: ImageType, image_name: str) -> None: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + DELETE FROM images + WHERE id = ?; + """, + (image_name,), + ) + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise ImageRecordServiceBase.ImageRecordDeleteException from e + finally: + self._lock.release() + + def save( + self, + image_name: str, + image_type: ImageType, + image_category: ImageCategory, + session_id: Optional[str], + node_id: Optional[str], + metadata: Union[ + GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata, None + ], + created_at: str, + ) -> None: + try: + metadata_json = ( + None if metadata is None else metadata.json(exclude_none=True) + ) + self._lock.acquire() + self._cursor.execute( + """--sql + INSERT OR IGNORE INTO images ( + id, + image_type, + image_category, + node_id, + session_id, + metadata + created_at + ) + VALUES (?, ?, ?, ?, ?, ?, ?); + """, + ( + image_name, + image_type.value, + image_category.value, + node_id, + session_id, + metadata_json, + created_at, + ), + ) + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise ImageRecordServiceBase.ImageRecordNotFoundException from e + finally: + self._lock.release() diff --git a/invokeai/app/services/image_storage.py b/invokeai/app/services/image_storage.py index e2593dd473..7610ac62bf 100644 --- a/invokeai/app/services/image_storage.py +++ b/invokeai/app/services/image_storage.py @@ -27,7 +27,25 @@ from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail class ImageStorageBase(ABC): - """Responsible for storing and retrieving images.""" + """Low-level service responsible for storing and retrieving images.""" + + class ImageFileNotFoundException(Exception): + """Raised when an image file is not found in storage.""" + + def __init__(self, message="Image file not found"): + super().__init__(message) + + class ImageFileSaveException(Exception): + """Raised when an image cannot be saved.""" + + def __init__(self, message="Image file not saved"): + super().__init__(message) + + class ImageFileDeleteException(Exception): + """Raised when an image cannot be deleted.""" + + def __init__(self, message="Image file not deleted"): + super().__init__(message) @abstractmethod def get(self, image_type: ImageType, image_name: str) -> Image: @@ -136,7 +154,7 @@ class DiskImageStorage(ImageStorageBase): page_of_images.append( ImageResponse( - image_type=image_type.value, + image_type=image_type, image_name=filename, # TODO: DiskImageStorage should not be building URLs...? image_url=self.get_uri(image_type, filename), @@ -164,14 +182,17 @@ class DiskImageStorage(ImageStorageBase): ) def get(self, image_type: ImageType, image_name: str) -> Image: - image_path = self.get_path(image_type, image_name) - cache_item = self.__get_cache(image_path) - if cache_item: - return cache_item + try: + image_path = self.get_path(image_type, image_name) + cache_item = self.__get_cache(image_path) + if cache_item: + return cache_item - image = PILImage.open(image_path) - self.__set_cache(image_path, image) - return image + image = PILImage.open(image_path) + self.__set_cache(image_path, image) + return image + except Exception as e: + raise ImageStorageBase.ImageFileNotFoundException from e # TODO: make this a bit more flexible for e.g. cloud storage def get_path( @@ -209,8 +230,10 @@ class DiskImageStorage(ImageStorageBase): try: os.stat(path) return True - except Exception: + except FileNotFoundError: return False + except Exception as e: + raise e def save( self, @@ -219,45 +242,53 @@ class DiskImageStorage(ImageStorageBase): image: Image, metadata: InvokeAIMetadata | None = None, ) -> SavedImage: - image_path = self.get_path(image_type, image_name) + try: + image_path = self.get_path(image_type, image_name) - # TODO: Reading the image and then saving it strips the metadata... - if metadata: - pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata) - image.save(image_path, "PNG", pnginfo=pnginfo) - else: - image.save(image_path) # this saved image has an empty info + # TODO: Reading the image and then saving it strips the metadata... + if metadata: + pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata) + image.save(image_path, "PNG", pnginfo=pnginfo) + else: + image.save(image_path) # this saved image has an empty info - thumbnail_name = get_thumbnail_name(image_name) - thumbnail_path = self.get_path(image_type, thumbnail_name, is_thumbnail=True) - thumbnail_image = make_thumbnail(image) - thumbnail_image.save(thumbnail_path) + thumbnail_name = get_thumbnail_name(image_name) + thumbnail_path = self.get_path( + image_type, thumbnail_name, is_thumbnail=True + ) + thumbnail_image = make_thumbnail(image) + thumbnail_image.save(thumbnail_path) - self.__set_cache(image_path, image) - self.__set_cache(thumbnail_path, thumbnail_image) + self.__set_cache(image_path, image) + self.__set_cache(thumbnail_path, thumbnail_image) - return SavedImage( - image_name=image_name, - thumbnail_name=thumbnail_name, - created=int(os.path.getctime(image_path)), - ) + return SavedImage( + image_name=image_name, + thumbnail_name=thumbnail_name, + created=int(os.path.getctime(image_path)), + ) + except Exception as e: + raise ImageStorageBase.ImageFileSaveException from e def delete(self, image_type: ImageType, image_name: str) -> None: - basename = os.path.basename(image_name) - image_path = self.get_path(image_type, basename) + try: + basename = os.path.basename(image_name) + image_path = self.get_path(image_type, basename) - if os.path.exists(image_path): - send2trash(image_path) - if image_path in self.__cache: - del self.__cache[image_path] + if os.path.exists(image_path): + send2trash(image_path) + if image_path in self.__cache: + del self.__cache[image_path] - thumbnail_name = get_thumbnail_name(image_name) - thumbnail_path = self.get_path(image_type, thumbnail_name, True) + thumbnail_name = get_thumbnail_name(image_name) + thumbnail_path = self.get_path(image_type, thumbnail_name, True) - if os.path.exists(thumbnail_path): - send2trash(thumbnail_path) - if thumbnail_path in self.__cache: - del self.__cache[thumbnail_path] + if os.path.exists(thumbnail_path): + send2trash(thumbnail_path) + if thumbnail_path in self.__cache: + del self.__cache[thumbnail_path] + except Exception as e: + raise ImageStorageBase.ImageFileDeleteException from e def __get_cache(self, image_name: str) -> Image | None: return None if image_name not in self.__cache else self.__cache[image_name] diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py new file mode 100644 index 0000000000..190ddaa8d6 --- /dev/null +++ b/invokeai/app/services/images.py @@ -0,0 +1,219 @@ +from typing import Union +import uuid +from PIL.Image import Image as PILImageType +from invokeai.app.models.image import ImageCategory, ImageType +from invokeai.app.models.metadata import ( + GeneratedImageOrLatentsMetadata, + UploadedImageOrLatentsMetadata, +) +from invokeai.app.services.image_db import ( + ImageRecordServiceBase, +) +from invokeai.app.services.models.image_record import ImageRecord +from invokeai.app.services.image_storage import ImageStorageBase +from invokeai.app.services.item_storage import PaginatedResults +from invokeai.app.services.metadata import MetadataServiceBase +from invokeai.app.services.urls import UrlServiceBase +from invokeai.app.util.misc import get_iso_timestamp + + +class ImageServiceDependencies: + """Service dependencies for the ImageManagementService.""" + + db: ImageRecordServiceBase + storage: ImageStorageBase + metadata: MetadataServiceBase + urls: UrlServiceBase + + def __init__( + self, + image_db_service: ImageRecordServiceBase, + image_storage_service: ImageStorageBase, + image_metadata_service: MetadataServiceBase, + url_service: UrlServiceBase, + ): + self.db = image_db_service + self.storage = image_storage_service + self.metadata = image_metadata_service + self.url = url_service + + +class ImageService: + """High-level service for image management.""" + + _services: ImageServiceDependencies + + def __init__( + self, + image_db_service: ImageRecordServiceBase, + image_storage_service: ImageStorageBase, + image_metadata_service: MetadataServiceBase, + url_service: UrlServiceBase, + ): + self._services = ImageServiceDependencies( + image_db_service=image_db_service, + image_storage_service=image_storage_service, + image_metadata_service=image_metadata_service, + url_service=url_service, + ) + + def _create_image_name( + self, + image_type: ImageType, + image_category: ImageCategory, + node_id: Union[str, None], + session_id: Union[str, None], + ) -> str: + """Creates an image name.""" + uuid_str = str(uuid.uuid4()) + + if node_id is not None and session_id is not None: + return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png" + + return f"{image_type.value}_{image_category.value}_{uuid_str}.png" + + def create( + self, + image: PILImageType, + image_type: ImageType, + image_category: ImageCategory, + node_id: Union[str, None], + session_id: Union[str, None], + metadata: Union[ + GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata, None + ], + ) -> ImageRecord: + """Creates an image, storing the file and its metadata.""" + image_name = self._create_image_name( + image_type=image_type, + image_category=image_category, + node_id=node_id, + session_id=session_id, + ) + + timestamp = get_iso_timestamp() + + try: + # TODO: Consider using a transaction here to ensure consistency between storage and database + self._services.storage.save( + image_type=image_type, + image_name=image_name, + image=image, + metadata=metadata, + ) + + self._services.db.save( + image_name=image_name, + image_type=image_type, + image_category=image_category, + node_id=node_id, + session_id=session_id, + metadata=metadata, + created_at=timestamp, + ) + + image_url = self._services.url.get_image_url( + image_type=image_type, image_name=image_name + ) + + thumbnail_url = self._services.url.get_thumbnail_url( + image_type=image_type, image_name=image_name + ) + + return ImageRecord( + image_name=image_name, + image_type=image_type, + image_category=image_category, + node_id=node_id, + session_id=session_id, + metadata=metadata, + created_at=timestamp, + image_url=image_url, + thumbnail_url=thumbnail_url, + ) + except ImageRecordServiceBase.ImageRecordSaveException: + # TODO: log this + raise + except ImageStorageBase.ImageFileSaveException: + # TODO: log this + raise + + def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: + """Gets an image as a PIL image.""" + try: + pil_image = self._services.storage.get( + image_type=image_type, image_name=image_name + ) + return pil_image + except ImageStorageBase.ImageFileNotFoundException: + # TODO: log this + raise + + def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord: + """Gets an image record.""" + try: + image_record = self._services.db.get( + image_type=image_type, image_name=image_name + ) + return image_record + except ImageRecordServiceBase.ImageRecordNotFoundException: + # TODO: log this + raise + + def delete(self, image_type: ImageType, image_name: str): + """Deletes an image.""" + # TODO: Consider using a transaction here to ensure consistency between storage and database + try: + self._services.storage.delete(image_type=image_type, image_name=image_name) + self._services.db.delete(image_type=image_type, image_name=image_name) + except ImageRecordServiceBase.ImageRecordDeleteException: + # TODO: log this + raise + except ImageStorageBase.ImageFileDeleteException: + # TODO: log this + raise + + def get_many( + self, + image_type: ImageType, + image_category: ImageCategory, + page: int = 0, + per_page: int = 10, + ) -> PaginatedResults[ImageRecord]: + """Gets a paginated list of image records.""" + try: + results = self._services.db.get_many( + image_type=image_type, + image_category=image_category, + page=page, + per_page=per_page, + ) + + for r in results.items: + r.image_url = self._services.url.get_image_url( + image_type=image_type, image_name=r.image_name + ) + + r.thumbnail_url = self._services.url.get_thumbnail_url( + image_type=image_type, image_name=r.image_name + ) + + return results + except Exception as e: + raise e + + def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None: + """Adds a tag to an image.""" + raise NotImplementedError("The 'add_tag' method is not implemented yet.") + + def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None: + """Removes a tag from an image.""" + raise NotImplementedError("The 'remove_tag' method is not implemented yet.") + + def favorite(self, image_type: ImageType, image_id: str) -> None: + """Favorites an image.""" + raise NotImplementedError("The 'favorite' method is not implemented yet.") + + def unfavorite(self, image_type: ImageType, image_id: str) -> None: + """Unfavorites an image.""" + raise NotImplementedError("The 'unfavorite' method is not implemented yet.") diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index d4c0c06b65..74fb7accff 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -1,7 +1,12 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team -from typing import types +from types import ModuleType +from invokeai.app.services.image_db import ( + ImageRecordServiceBase, +) +from invokeai.app.services.images import ImageService from invokeai.app.services.metadata import MetadataServiceBase +from invokeai.app.services.urls import UrlServiceBase from invokeai.backend import ModelManager from .events import EventServiceBase @@ -12,6 +17,7 @@ from .invocation_queue import InvocationQueueABC from .item_storage import ItemStorageABC from .config import InvokeAISettings + class InvocationServices: """Services that can be used by invocations""" @@ -23,26 +29,32 @@ class InvocationServices: model_manager: ModelManager restoration: RestorationServices configuration: InvokeAISettings - + images_db: ImageRecordServiceBase + urls: UrlServiceBase + images_new: ImageService + # NOTE: we must forward-declare any types that include invocations, since invocations can use services graph_library: ItemStorageABC["LibraryGraph"] graph_execution_manager: ItemStorageABC["GraphExecutionState"] processor: "InvocationProcessorABC" def __init__( - self, - model_manager: ModelManager, - events: EventServiceBase, - logger: types.ModuleType, - latents: LatentsStorageBase, - images: ImageStorageBase, - metadata: MetadataServiceBase, - queue: InvocationQueueABC, - graph_library: ItemStorageABC["LibraryGraph"], - graph_execution_manager: ItemStorageABC["GraphExecutionState"], - processor: "InvocationProcessorABC", - restoration: RestorationServices, - configuration: InvokeAISettings=None, + self, + model_manager: ModelManager, + events: EventServiceBase, + logger: ModuleType, + latents: LatentsStorageBase, + images: ImageStorageBase, + metadata: MetadataServiceBase, + queue: InvocationQueueABC, + images_db: ImageRecordServiceBase, + images_new: ImageService, + urls: UrlServiceBase, + graph_library: ItemStorageABC["LibraryGraph"], + graph_execution_manager: ItemStorageABC["GraphExecutionState"], + processor: "InvocationProcessorABC", + restoration: RestorationServices, + configuration: InvokeAISettings=None, ): self.model_manager = model_manager self.events = events @@ -51,8 +63,13 @@ class InvocationServices: self.images = images self.metadata = metadata self.queue = queue + self.images_db = images_db + self.images_new = images_new + self.urls = urls self.graph_library = graph_library self.graph_execution_manager = graph_execution_manager self.processor = processor self.restoration = restoration self.configuration = configuration + + diff --git a/invokeai/app/services/metadata.py b/invokeai/app/services/metadata.py index a7f2378ab1..bc1cfdb063 100644 --- a/invokeai/app/services/metadata.py +++ b/invokeai/app/services/metadata.py @@ -22,16 +22,24 @@ class MetadataLatentsField(TypedDict): class MetadataColorField(TypedDict): """Pydantic-less ColorField, used for metadata parsing""" + r: int g: int b: int a: int - # TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports NodeMetadata = Dict[ - str, None | str | int | float | bool | MetadataImageField | MetadataLatentsField | MetadataColorField + str, + None + | str + | int + | float + | bool + | MetadataImageField + | MetadataLatentsField + | MetadataColorField, ] @@ -67,6 +75,11 @@ class MetadataServiceBase(ABC): """Builds an InvokeAIMetadata object""" pass + @abstractmethod + def create_metadata(self, session_id: str, node_id: str) -> dict: + """Creates metadata for a result""" + pass + class PngMetadataService(MetadataServiceBase): """Handles loading and building metadata for images.""" diff --git a/invokeai/app/services/models/image_record.py b/invokeai/app/services/models/image_record.py new file mode 100644 index 0000000000..600508e57f --- /dev/null +++ b/invokeai/app/services/models/image_record.py @@ -0,0 +1,29 @@ +import datetime +from typing import Literal, Optional, Union +from pydantic import BaseModel, Field +from invokeai.app.models.metadata import ( + GeneratedImageOrLatentsMetadata, + UploadedImageOrLatentsMetadata, +) +from invokeai.app.models.image import ImageCategory, ImageType +from invokeai.app.models.resources import ResourceType + + +class ImageRecord(BaseModel): + """Deserialized image record.""" + + image_name: str = Field(description="The name of the image.") + image_type: ImageType = Field(description="The type of the image.") + image_category: ImageCategory = Field(description="The category of the image.") + created_at: Union[datetime.datetime, str] = Field( + description="The created timestamp of the image." + ) + session_id: Optional[str] = Field(default=None, description="The session ID.") + node_id: Optional[str] = Field(default=None, description="The node ID.") + metadata: Optional[ + Union[GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata] + ] = Field(default=None, description="The image's metadata.") + image_url: Optional[str] = Field(default=None, description="The URL of the image.") + thumbnail_url: Optional[str] = Field( + default=None, description="The thumbnail URL of the image." + ) diff --git a/invokeai/app/services/proposeddesign.py b/invokeai/app/services/proposeddesign.py new file mode 100644 index 0000000000..712d7224e9 --- /dev/null +++ b/invokeai/app/services/proposeddesign.py @@ -0,0 +1,657 @@ +from abc import ABC, abstractmethod +from enum import Enum +import enum +import sqlite3 +import threading +from typing import Optional, Type, TypeVar, Union +from PIL.Image import Image as PILImage +from pydantic import BaseModel, Field +from torch import Tensor + +from invokeai.app.services.item_storage import PaginatedResults + + +""" +Substantial proposed changes to the management of images and tensor. + +tl;dr: +With the upcoming move to latents-only nodes, we need to handle metadata differently. After struggling with this unsuccessfully - trying to smoosh it in to the existing setup - I believe we need to expand the scope of the refactor to include the management of images and latents - and make `latents` a special case of `tensor`. + +full story: +The consensus for tensor-only nodes' metadata was to traverse the execution graph and grab the core parameters to write to the image. This was straightforward, and I've written functions to find the nearest t2l/l2l, noise, and compel nodes and build the metadata from those. + +But struggling to integrate this and the associated edge cases this brought up a number of issues deeper in the system (some of which I had previously implemented). The ImageStorageService is doing way too much, and we have a need to be able to retrieve sessions the session given image/latents id, which is not currently feasible due to SQLite's JSON parsing performance. + +I made a new ResultsService and `results` table in the db to facilitate this. This first attempt failed because it doesn't handle uploads and leaves the codebase messy. + +So I've spent the day trying to figure out to handle this in a sane way and think I've got something decent. I've described some changes to service bases and the database below. + +The gist of it is to store the core parameters for an image in its metadata when the image is saved, but never to read from it. Instead, the same metadata is stored in the database, which will be set up for efficient access. So when a page of images is requested, the metadata comes from the db instead of a filesystem operation. + +The URL generation responsibilities have been split off the image storage service in to a URL service. New database services/tables for images and tensor are added. These services will provide paginated images/tensors for the API to serve. This also paves the way for handling tensors as first-class outputs. +""" + + +# TODO: Make a new model for this +class ResourceOrigin(str, Enum): + """The origin of a resource (eg image or tensor).""" + + RESULTS = "results" + UPLOADS = "uploads" + INTERMEDIATES = "intermediates" + + +class ImageKind(str, Enum): + """The kind of an image.""" + + IMAGE = "image" + CONTROL_IMAGE = "control_image" + + +class TensorKind(str, Enum): + """The kind of a tensor.""" + + IMAGE_TENSOR = "tensor" + CONDITIONING = "conditioning" + + +""" +Core Generation Metadata Pydantic Model + +I've already implemented the code to traverse a session to build this object. +""" + + +class CoreGenerationMetadata(BaseModel): + """Core generation metadata for an image/tensor generated in InvokeAI. + + Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node. + + Full metadata may be accessed by querying for the session in the `graph_executions` table. + """ + + positive_conditioning: Optional[str] = Field( + description="The positive conditioning." + ) + negative_conditioning: Optional[str] = Field( + description="The negative conditioning." + ) + width: Optional[int] = Field(description="Width of the image/tensor in pixels.") + height: Optional[int] = Field(description="Height of the image/tensor in pixels.") + seed: Optional[int] = Field(description="The seed used for noise generation.") + cfg_scale: Optional[float] = Field( + description="The classifier-free guidance scale." + ) + steps: Optional[int] = Field(description="The number of steps used for inference.") + scheduler: Optional[str] = Field(description="The scheduler used for inference.") + model: Optional[str] = Field(description="The model used for inference.") + strength: Optional[float] = Field( + description="The strength used for image-to-image/tensor-to-tensor." + ) + image: Optional[str] = Field(description="The ID of the initial image.") + tensor: Optional[str] = Field(description="The ID of the initial tensor.") + # Pending model refactor: + # vae: Optional[str] = Field(description="The VAE used for decoding.") + # unet: Optional[str] = Field(description="The UNet used dor inference.") + # clip: Optional[str] = Field(description="The CLIP Encoder used for conditioning.") + + +""" +Minimal Uploads Metadata Model +""" + + +class UploadsMetadata(BaseModel): + """Limited metadata for an uploaded image/tensor.""" + + width: Optional[int] = Field(description="Width of the image/tensor in pixels.") + height: Optional[int] = Field(description="Height of the image/tensor in pixels.") + # The extra field will be the contents of the PNG file's tEXt chunk. It may have come + # from another SD application or InvokeAI, so we need to make it very flexible. I think it's + # best to just store it as a string and let the frontend parse it. + # If the upload is a tensor type, this will be omitted. + extra: Optional[str] = Field( + description="Extra metadata, extracted from the PNG tEXt chunk." + ) + + +""" +Slimmed-down Image Storage Service Base + - No longer lists images or generates URLs - only stores and retrieves images. + - OSS implementation for disk storage +""" + + +class ImageStorageBase(ABC): + """Responsible for storing and retrieving images.""" + + @abstractmethod + def save( + self, + image: PILImage, + image_kind: ImageKind, + origin: ResourceOrigin, + context_id: str, + node_id: str, + metadata: CoreGenerationMetadata, + ) -> str: + """Saves an image and its thumbnail, returning its unique identifier.""" + pass + + @abstractmethod + def get(self, id: str, thumbnail: bool = False) -> Union[PILImage, None]: + """Retrieves an image as a PIL Image.""" + pass + + @abstractmethod + def delete(self, id: str) -> None: + """Deletes an image.""" + pass + + +class TensorStorageBase(ABC): + """Responsible for storing and retrieving tensors.""" + + @abstractmethod + def save( + self, + tensor: Tensor, + tensor_kind: TensorKind, + origin: ResourceOrigin, + context_id: str, + node_id: str, + metadata: CoreGenerationMetadata, + ) -> str: + """Saves a tensor, returning its unique identifier.""" + pass + + @abstractmethod + def get(self, id: str, thumbnail: bool = False) -> Union[Tensor, None]: + """Retrieves a tensor as a torch Tensor.""" + pass + + @abstractmethod + def delete(self, id: str) -> None: + """Deletes a tensor.""" + pass + + +""" +New Url Service Base + - Abstracts the logic for generating URLs out of the storage service + - OSS implementation for locally-hosted URLs + - Also provides a method to get the internal path to a resource (for OSS, the FS path) +""" + + +class ResourceLocationServiceBase(ABC): + """Responsible for locating resources (eg images or tensors).""" + + @abstractmethod + def get_url(self, id: str) -> str: + """Gets the URL for a resource.""" + pass + + @abstractmethod + def get_path(self, id: str) -> str: + """Gets the path for a resource.""" + pass + + +""" +New Images Database Service Base + +This is a new service that will be responsible for the new `images` table(s): + - Storing images in the table + - Retrieving individual images and pages of images + - Deleting individual images + +Operations will typically use joins with the various `images` tables. +""" + + +class ImagesDbServiceBase(ABC): + """Responsible for interfacing with `images` table.""" + + class GeneratedImageEntity(BaseModel): + id: str = Field(description="The unique identifier for the image.") + session_id: str = Field(description="The session ID.") + node_id: str = Field(description="The node ID.") + metadata: CoreGenerationMetadata = Field( + description="The metadata for the image." + ) + + class UploadedImageEntity(BaseModel): + id: str = Field(description="The unique identifier for the image.") + metadata: UploadsMetadata = Field(description="The metadata for the image.") + + @abstractmethod + def get(self, id: str) -> Union[GeneratedImageEntity, UploadedImageEntity, None]: + """Gets an image from the `images` table.""" + pass + + @abstractmethod + def get_many( + self, image_kind: ImageKind, page: int = 0, per_page: int = 10 + ) -> PaginatedResults[Union[GeneratedImageEntity, UploadedImageEntity]]: + """Gets a page of images from the `images` table.""" + pass + + @abstractmethod + def delete(self, id: str) -> None: + """Deletes an image from the `images` table.""" + pass + + @abstractmethod + def set( + self, + id: str, + image_kind: ImageKind, + session_id: Optional[str], + node_id: Optional[str], + metadata: CoreGenerationMetadata | UploadsMetadata, + ) -> None: + """Sets an image in the `images` table.""" + pass + + +""" +New Tensor Database Service Base + +This is a new service that will be responsible for the new `tensor` table: + - Storing tensor in the table + - Retrieving individual tensor and pages of tensor + - Deleting individual tensor + +Operations will always use joins with the `tensor_metadata` table. +""" + + +class TensorDbServiceBase(ABC): + """Responsible for interfacing with `tensor` table.""" + + class GeneratedTensorEntity(BaseModel): + id: str = Field(description="The unique identifier for the tensor.") + session_id: str = Field(description="The session ID.") + node_id: str = Field(description="The node ID.") + metadata: CoreGenerationMetadata = Field( + description="The metadata for the tensor." + ) + + class UploadedTensorEntity(BaseModel): + id: str = Field(description="The unique identifier for the tensor.") + metadata: UploadsMetadata = Field(description="The metadata for the tensor.") + + @abstractmethod + def get(self, id: str) -> Union[GeneratedTensorEntity, UploadedTensorEntity, None]: + """Gets a tensor from the `tensor` table.""" + pass + + @abstractmethod + def get_many( + self, tensor_kind: TensorKind, page: int = 0, per_page: int = 10 + ) -> PaginatedResults[Union[GeneratedTensorEntity, UploadedTensorEntity]]: + """Gets a page of tensor from the `tensor` table.""" + pass + + @abstractmethod + def delete(self, id: str) -> None: + """Deletes a tensor from the `tensor` table.""" + pass + + @abstractmethod + def set( + self, + id: str, + tensor_kind: TensorKind, + session_id: Optional[str], + node_id: Optional[str], + metadata: CoreGenerationMetadata | UploadsMetadata, + ) -> None: + """Sets a tensor in the `tensor` table.""" + pass + + +""" +Database Changes + +The existing tables will remain as-is, new tables will be added. + +Tensor now also have the same types as images - `results`, `intermediates`, `uploads`. Storage, retrieval, and operations may diverge from images in the future, so they are managed separately. + +A few `images` tables are created to store all images: + - `results` and `intermediates` images have additional data: `session_id` and `node_id`, and may be further differentiated in the future. For this reason, they each get their own table. + - `uploads` do not get their own table, as they are never going to have more than an `id`, `image_kind` and `timestamp`. + - `images_metadata` holds the same image metadata that is written to the image. This table, along with the URL service, allow us to more efficiently serve images without having to read the image from storage. + +The same tables are made for `tensor` and for the moment, implementation is expected to be identical. + +Schemas for each table below. + +Insertions and updates of ancillary tables (e.g. `results_images`, `images_metadata`, etc) will need to be done manually in the services, but should be straightforward. Deletion via cascading will be handled by the database. +""" + + +def create_sql_values_string_from_string_enum(enum: Type[Enum]): + """ + Creates a string of the form "('value1'), ('value2'), ..., ('valueN')" from a StrEnum. + """ + + delimiter = ", " + values = [f"('{e.value}')" for e in enum] + return delimiter.join(values) + + +def create_sql_table_from_enum( + enum: Type[Enum], + table_name: str, + primary_key_name: str, + cursor: sqlite3.Cursor, + lock: threading.Lock, +): + """ + Creates and populates a table to be used as a functional enum. + """ + + try: + lock.acquire() + + values_string = create_sql_values_string_from_string_enum(enum) + + cursor.execute( + f"""--sql + CREATE TABLE IF NOT EXISTS {table_name} ( + {primary_key_name} TEXT PRIMARY KEY + ); + """ + ) + cursor.execute( + f"""--sql + INSERT OR IGNORE INTO {table_name} ({primary_key_name}) VALUES {values_string}; + """ + ) + finally: + lock.release() + + +""" +`resource_origins` functions as an enum for the ResourceOrigin model. +""" + + +def create_resource_origins_table(cursor: sqlite3.Cursor, lock: threading.Lock): + create_sql_table_from_enum( + enum=ResourceOrigin, + table_name="resource_origins", + primary_key_name="origin_name", + cursor=cursor, + lock=lock, + ) + + +""" +`image_kinds` functions as an enum for the ImageType model. +""" + + +def create_image_kinds_table(cursor: sqlite3.Cursor, lock: threading.Lock): + create_sql_table_from_enum( + enum=ImageKind, + table_name="image_kinds", + primary_key_name="kind_name", + cursor=cursor, + lock=lock, + ) + + +""" +`tensor_kinds` functions as an enum for the TensorType model. +""" + + +def create_tensor_kinds_table(cursor: sqlite3.Cursor, lock: threading.Lock): + create_sql_table_from_enum( + enum=TensorKind, + table_name="tensor_kinds", + primary_key_name="kind_name", + cursor=cursor, + lock=lock, + ) + + +""" +`images` stores all images, regardless of type +""" + + +def create_images_table(cursor: sqlite3.Cursor, lock: threading.Lock): + try: + lock.acquire() + + cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS images ( + id TEXT PRIMARY KEY, + origin TEXT, + image_kind TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY(origin) REFERENCES resource_origins(origin_name), + FOREIGN KEY(image_kind) REFERENCES image_kinds(kind_name) + ); + """ + ) + cursor.execute( + """--sql + CREATE UNIQUE INDEX IF NOT EXISTS idx_images_id ON images(id); + """ + ) + cursor.execute( + """--sql + CREATE INDEX IF NOT EXISTS idx_images_origin ON images(origin); + """ + ) + cursor.execute( + """--sql + CREATE INDEX IF NOT EXISTS idx_images_image_kind ON images(image_kind); + """ + ) + finally: + lock.release() + + +""" +`image_results` stores additional data specific to `results` images. +""" + + +def create_image_results_table(cursor: sqlite3.Cursor, lock: threading.Lock): + try: + lock.acquire() + + cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS image_results ( + images_id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + node_id TEXT NOT NULL, + FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE + ); + """ + ) + cursor.execute( + """--sql + CREATE UNIQUE INDEX IF NOT EXISTS idx_image_results_images_id ON image_results(id); + """ + ) + finally: + lock.release() + + +""" +`image_intermediates` stores additional data specific to `intermediates` images +""" + + +def create_image_intermediates_table(cursor: sqlite3.Cursor, lock: threading.Lock): + try: + lock.acquire() + + cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS image_intermediates ( + images_id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + node_id TEXT NOT NULL, + FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE + ); + """ + ) + cursor.execute( + """--sql + CREATE UNIQUE INDEX IF NOT EXISTS idx_image_intermediates_images_id ON image_intermediates(id); + """ + ) + finally: + lock.release() + + +""" +`images_metadata` stores basic metadata for any image type +""" + + +def create_images_metadata_table(cursor: sqlite3.Cursor, lock: threading.Lock): + try: + lock.acquire() + + cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS images_metadata ( + images_id TEXT PRIMARY KEY, + metadata TEXT, + FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE + ); + """ + ) + cursor.execute( + """--sql + CREATE UNIQUE INDEX IF NOT EXISTS idx_images_metadata_images_id ON images_metadata(images_id); + """ + ) + finally: + lock.release() + + +# `tensor` table: stores references to tensor + + +def create_tensors_table(cursor: sqlite3.Cursor, lock: threading.Lock): + try: + lock.acquire() + + cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS tensors ( + id TEXT PRIMARY KEY, + origin TEXT, + tensor_kind TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY(origin) REFERENCES resource_origins(origin_name), + FOREIGN KEY(tensor_kind) REFERENCES tensor_kinds(kind_name), + ); + """ + ) + cursor.execute( + """--sql + CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_id ON tensors(id); + """ + ) + cursor.execute( + """--sql + CREATE INDEX IF NOT EXISTS idx_tensors_origin ON tensors(origin); + """ + ) + cursor.execute( + """--sql + CREATE INDEX IF NOT EXISTS idx_tensors_tensor_kind ON tensors(tensor_kind); + """ + ) + finally: + lock.release() + + +# `results_tensor` stores additional data specific to `result` tensor + + +def create_tensor_results_table(cursor: sqlite3.Cursor, lock: threading.Lock): + try: + lock.acquire() + + cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS tensor_results ( + tensor_id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + node_id TEXT NOT NULL, + FOREIGN KEY(tensor_id) REFERENCES tensors(id) ON DELETE CASCADE + ); + """ + ) + cursor.execute( + """--sql + CREATE UNIQUE INDEX IF NOT EXISTS idx_tensor_results_tensor_id ON tensor_results(tensor_id); + """ + ) + finally: + lock.release() + + +# `tensor_intermediates` stores additional data specific to `intermediate` tensor + + +def create_tensor_intermediates_table(cursor: sqlite3.Cursor, lock: threading.Lock): + try: + lock.acquire() + + cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS tensor_intermediates ( + tensor_id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + node_id TEXT NOT NULL, + FOREIGN KEY(tensor_id) REFERENCES tensors(id) ON DELETE CASCADE + ); + """ + ) + cursor.execute( + """--sql + CREATE UNIQUE INDEX IF NOT EXISTS idx_tensor_intermediates_tensor_id ON tensor_intermediates(tensor_id); + """ + ) + finally: + lock.release() + + +# `tensors_metadata` table: stores generated/transformed metadata for tensor + + +def create_tensors_metadata_table(cursor: sqlite3.Cursor, lock: threading.Lock): + try: + lock.acquire() + + cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS tensors_metadata ( + tensor_id TEXT PRIMARY KEY, + metadata TEXT, + FOREIGN KEY(tensor_id) REFERENCES tensors(id) ON DELETE CASCADE + ); + """ + ) + cursor.execute( + """--sql + CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_metadata_tensor_id ON tensors_metadata(tensor_id); + """ + ) + finally: + lock.release() diff --git a/invokeai/app/services/results.py b/invokeai/app/services/results.py new file mode 100644 index 0000000000..df7bf7bc6b --- /dev/null +++ b/invokeai/app/services/results.py @@ -0,0 +1,466 @@ +from enum import Enum + +from abc import ABC, abstractmethod +import json +import sqlite3 +from threading import Lock +from typing import Any, Union + +import networkx as nx + +from pydantic import BaseModel, Field, parse_obj_as, parse_raw_as +from invokeai.app.invocations.image import ImageOutput +from invokeai.app.services.graph import Edge, GraphExecutionState +from invokeai.app.invocations.latent import LatentsOutput +from invokeai.app.services.item_storage import PaginatedResults +from invokeai.app.util.misc import get_timestamp + + +class ResultType(str, Enum): + image_output = "image_output" + latents_output = "latents_output" + + +class Result(BaseModel): + """A session result""" + + id: str = Field(description="Result ID") + session_id: str = Field(description="Session ID") + node_id: str = Field(description="Node ID") + data: Union[LatentsOutput, ImageOutput] = Field(description="The result data") + + +class ResultWithSession(BaseModel): + """A result with its session""" + + result: Result = Field(description="The result") + session: GraphExecutionState = Field(description="The session") + + +# Create a directed graph +from typing import Any, TypedDict, Union +from networkx import DiGraph +import networkx as nx +import json + + +# We need to use a loose class for nodes to allow for graceful parsing - we cannot use the stricter +# model used by the system, because we may be a graph in an old format. We can, however, use the +# Edge model, because the edge format does not change. +class LooseGraph(BaseModel): + id: str + nodes: dict[str, dict[str, Any]] + edges: list[Edge] + + +# An intermediate type used during parsing +class NearestAncestor(TypedDict): + node_id: str + metadata: dict[str, Any] + + +# The ancestor types that contain the core metadata +ANCESTOR_TYPES = ['t2l', 'l2l'] + +# The core metadata parameters in the ancestor types +ANCESTOR_PARAMS = ['steps', 'model', 'cfg_scale', 'scheduler', 'strength'] + +# The core metadata parameters in the noise node +NOISE_FIELDS = ['seed', 'width', 'height'] + +# Find nearest t2l or l2l ancestor from a given l2i node +def find_nearest_ancestor(G: DiGraph, node_id: str) -> Union[NearestAncestor, None]: + """Returns metadata for the nearest ancestor of a given node. + + Parameters: + G (DiGraph): A directed graph. + node_id (str): The ID of the starting node. + + Returns: + NearestAncestor | None: An object with the ID and metadata of the nearest ancestor. + """ + + # Retrieve the node from the graph + node = G.nodes[node_id] + + # If the node type is one of the core metadata node types, gather necessary metadata and return + if node.get('type') in ANCESTOR_TYPES: + parsed_metadata = {param: val for param, val in node.items() if param in ANCESTOR_PARAMS} + return NearestAncestor(node_id=node_id, metadata=parsed_metadata) + + + # Else, look for the ancestor in the predecessor nodes + for predecessor in G.predecessors(node_id): + result = find_nearest_ancestor(G, predecessor) + if result: + return result + + # If there are no valid ancestors, return None + return None + + +def get_additional_metadata(graph: LooseGraph, node_id: str) -> Union[dict[str, Any], None]: + """Collects additional metadata from nodes connected to a given node. + + Parameters: + graph (LooseGraph): The graph. + node_id (str): The ID of the node. + + Returns: + dict | None: A dictionary containing additional metadata. + """ + + metadata = {} + + # Iterate over all edges in the graph + for edge in graph.edges: + dest_node_id = edge.destination.node_id + dest_field = edge.destination.field + source_node = graph.nodes[edge.source.node_id] + + # If the destination node ID matches the given node ID, gather necessary metadata + if dest_node_id == node_id: + # If the destination field is 'positive_conditioning', add the 'prompt' from the source node + if dest_field == 'positive_conditioning': + metadata['positive_conditioning'] = source_node.get('prompt') + # If the destination field is 'negative_conditioning', add the 'prompt' from the source node + if dest_field == 'negative_conditioning': + metadata['negative_conditioning'] = source_node.get('prompt') + # If the destination field is 'noise', add the core noise fields from the source node + if dest_field == 'noise': + for field in NOISE_FIELDS: + metadata[field] = source_node.get(field) + return metadata + +def build_core_metadata(graph_raw: str, node_id: str) -> Union[dict, None]: + """Builds the core metadata for a given node. + + Parameters: + graph_raw (str): The graph structure as a raw string. + node_id (str): The ID of the node. + + Returns: + dict | None: A dictionary containing core metadata. + """ + + # Create a directed graph to facilitate traversal + G = nx.DiGraph() + + # Convert the raw graph string into a JSON object + graph = parse_obj_as(LooseGraph, graph_raw) + + # Add nodes and edges to the graph + for node_id, node_data in graph.nodes.items(): + G.add_node(node_id, **node_data) + for edge in graph.edges: + G.add_edge(edge.source.node_id, edge.destination.node_id) + + # Find the nearest ancestor of the given node + ancestor = find_nearest_ancestor(G, node_id) + + # If no ancestor was found, return None + if ancestor is None: + return None + + metadata = ancestor['metadata'] + ancestor_id = ancestor['node_id'] + + # Get additional metadata related to the ancestor + addl_metadata = get_additional_metadata(graph, ancestor_id) + + # If additional metadata was found, add it to the main metadata + if addl_metadata is not None: + metadata.update(addl_metadata) + + return metadata + + + +class ResultsServiceABC(ABC): + """The Results service is responsible for retrieving results.""" + + @abstractmethod + def get( + self, result_id: str, result_type: ResultType + ) -> Union[ResultWithSession, None]: + pass + + @abstractmethod + def get_many( + self, result_type: ResultType, page: int = 0, per_page: int = 10 + ) -> PaginatedResults[ResultWithSession]: + pass + + @abstractmethod + def search( + self, query: str, page: int = 0, per_page: int = 10 + ) -> PaginatedResults[ResultWithSession]: + pass + + @abstractmethod + def handle_graph_execution_state_change(self, session: GraphExecutionState) -> None: + pass + + +class SqliteResultsService(ResultsServiceABC): + """SQLite implementation of the Results service.""" + + _filename: str + _conn: sqlite3.Connection + _cursor: sqlite3.Cursor + _lock: Lock + + def __init__(self, filename: str): + super().__init__() + + self._filename = filename + self._lock = Lock() + + self._conn = sqlite3.connect( + self._filename, check_same_thread=False + ) # TODO: figure out a better threading solution + self._cursor = self._conn.cursor() + + self._create_table() + + def _create_table(self): + try: + self._lock.acquire() + self._cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS results ( + id TEXT PRIMARY KEY, -- the result's name + result_type TEXT, -- `image_output` | `latents_output` + node_id TEXT, -- the node that produced this result + session_id TEXT, -- the session that produced this result + created_at INTEGER, -- the time at which this result was created + data TEXT -- the result itself + ); + """ + ) + self._cursor.execute( + """--sql + CREATE UNIQUE INDEX IF NOT EXISTS idx_result_id ON results(id); + """ + ) + finally: + self._lock.release() + + def _parse_joined_result(self, result_row: Any, column_names: list[str]): + result_raw = {} + session_raw = {} + + for idx, name in enumerate(column_names): + if name == "session": + session_raw = json.loads(result_row[idx]) + elif name == "data": + result_raw[name] = json.loads(result_row[idx]) + else: + result_raw[name] = result_row[idx] + + graph_raw = session_raw['execution_graph'] + + result = parse_obj_as(Result, result_raw) + session = parse_obj_as(GraphExecutionState, session_raw) + + m = build_core_metadata(graph_raw, result.node_id) + print(m) + + # g = session.execution_graph.nx_graph() + # ancestors = nx.dag.ancestors(g, result.node_id) + + # nodes = [session.execution_graph.get_node(result.node_id)] + # for ancestor in ancestors: + # nodes.append(session.execution_graph.get_node(ancestor)) + + # filtered_nodes = filter(lambda n: n.type in NODE_TYPE_ALLOWLIST, nodes) + # print(list(map(lambda n: n.dict(), filtered_nodes))) + # metadata = {} + # for node in nodes: + # if (node.type in ['txt2img', 'img2img',]) + # for field, value in node.dict().items(): + # if field not in ['type', 'id']: + # if field not in metadata: + # metadata[field] = value + + # print(ancestors) + # print(nodes) + # print(metadata) + + # for node in nodes: + # print(node.dict()) + + # print(nodes) + + return ResultWithSession( + result=result, + session=session, + ) + + def get( + self, result_id: str, result_type: ResultType + ) -> Union[ResultWithSession, None]: + """Retrieves a result by ID and type.""" + try: + self._lock.acquire() + self._cursor.execute( + """--sql + SELECT + results.id AS id, + results.result_type AS result_type, + results.node_id AS node_id, + results.session_id AS session_id, + results.data AS data, + graph_executions.item AS session + FROM results + JOIN graph_executions ON results.session_id = graph_executions.id + WHERE results.id = ? AND results.result_type = ? + """, + (result_id, result_type), + ) + + result_row = self._cursor.fetchone() + + if result_row is None: + return None + + column_names = list(map(lambda x: x[0], self._cursor.description)) + result_parsed = self._parse_joined_result(result_row, column_names) + finally: + self._lock.release() + + if not result_parsed: + return None + + return result_parsed + + def get_many( + self, + result_type: ResultType, + page: int = 0, + per_page: int = 10, + ) -> PaginatedResults[ResultWithSession]: + """Lists results of a given type.""" + try: + self._lock.acquire() + + self._cursor.execute( + f"""--sql + SELECT + results.id AS id, + results.result_type AS result_type, + results.node_id AS node_id, + results.session_id AS session_id, + results.data AS data, + graph_executions.item AS session + FROM results + JOIN graph_executions ON results.session_id = graph_executions.id + WHERE results.result_type = ? + LIMIT ? OFFSET ?; + """, + (result_type.value, per_page, page * per_page), + ) + + result_rows = self._cursor.fetchall() + column_names = list(map(lambda c: c[0], self._cursor.description)) + + result_parsed = [] + + for result_row in result_rows: + result_parsed.append( + self._parse_joined_result(result_row, column_names) + ) + + self._cursor.execute("""SELECT count(*) FROM results;""") + count = self._cursor.fetchone()[0] + finally: + self._lock.release() + + pageCount = int(count / per_page) + 1 + + return PaginatedResults[ResultWithSession]( + items=result_parsed, + page=page, + pages=pageCount, + per_page=per_page, + total=count, + ) + + def search( + self, + query: str, + page: int = 0, + per_page: int = 10, + ) -> PaginatedResults[ResultWithSession]: + """Finds results by query.""" + try: + self._lock.acquire() + self._cursor.execute( + """--sql + SELECT results.data, graph_executions.item + FROM results + JOIN graph_executions ON results.session_id = graph_executions.id + WHERE item LIKE ? + LIMIT ? OFFSET ?; + """, + (f"%{query}%", per_page, page * per_page), + ) + + result_rows = self._cursor.fetchall() + + items = list( + map( + lambda r: ResultWithSession( + result=parse_raw_as(Result, r[0]), + session=parse_raw_as(GraphExecutionState, r[1]), + ), + result_rows, + ) + ) + self._cursor.execute( + """--sql + SELECT count(*) FROM results WHERE item LIKE ?; + """, + (f"%{query}%",), + ) + count = self._cursor.fetchone()[0] + finally: + self._lock.release() + + pageCount = int(count / per_page) + 1 + + return PaginatedResults[ResultWithSession]( + items=items, page=page, pages=pageCount, per_page=per_page, total=count + ) + + def handle_graph_execution_state_change(self, session: GraphExecutionState) -> None: + """Updates the results table with the results from the session.""" + with self._conn as conn: + for node_id, result in session.results.items(): + # We'll only process 'image_output' or 'latents_output' + if result.type not in ["image_output", "latents_output"]: + continue + + # The id depends on the result type + if result.type == "image_output": + id = result.image.image_name + result_type = "image_output" + else: + id = result.latents.latents_name + result_type = "latents_output" + + # Insert the result into the results table, ignoring if it already exists + conn.execute( + """--sql + INSERT OR IGNORE INTO results (id, result_type, node_id, session_id, created_at, data) + VALUES (?, ?, ?, ?, ?, ?) + """, + ( + id, + result_type, + node_id, + session.id, + get_timestamp(), + result.json(), + ), + ) diff --git a/invokeai/app/services/urls.py b/invokeai/app/services/urls.py new file mode 100644 index 0000000000..16f8fc7494 --- /dev/null +++ b/invokeai/app/services/urls.py @@ -0,0 +1,32 @@ +import os +from abc import ABC, abstractmethod + +from invokeai.app.models.image import ImageType +from invokeai.app.util.thumbnails import get_thumbnail_name + + +class UrlServiceBase(ABC): + """Responsible for building URLs for resources (eg images or tensors)""" + + @abstractmethod + def get_image_url(self, image_type: ImageType, image_name: str) -> str: + """Gets the URL for an image""" + pass + + @abstractmethod + def get_thumbnail_url(self, image_type: ImageType, image_name: str) -> str: + """Gets the URL for an image's thumbnail""" + pass + + +class LocalUrlService(UrlServiceBase): + def __init__(self, base_url: str = "api/v1"): + self._base_url = base_url + + def get_image_url(self, image_type: ImageType, image_name: str) -> str: + image_basename = os.path.basename(image_name) + return f"{self._base_url}/images/{image_type.value}/{image_basename}" + + def get_thumbnail_url(self, image_type: ImageType, image_name: str) -> str: + thumbnail_basename = get_thumbnail_name(os.path.basename(image_name)) + return f"{self._base_url}/images/{image_type.value}/thumbnails/{thumbnail_basename}" diff --git a/invokeai/app/services/util/create_enum_table.py b/invokeai/app/services/util/create_enum_table.py new file mode 100644 index 0000000000..03cbfd6e90 --- /dev/null +++ b/invokeai/app/services/util/create_enum_table.py @@ -0,0 +1,39 @@ +from enum import Enum +import sqlite3 +from typing import Type + + +def create_sql_values_string_from_string_enum(enum: Type[Enum]): + """ + Creates a string of the form "('value1'), ('value2'), ..., ('valueN')" from a StrEnum. + """ + + delimiter = ", " + values = [f"('{e.value}')" for e in enum] + return delimiter.join(values) + + +def create_enum_table( + enum: Type[Enum], + table_name: str, + primary_key_name: str, + cursor: sqlite3.Cursor, +): + """ + Creates and populates a table to be used as a functional enum. + """ + + values_string = create_sql_values_string_from_string_enum(enum) + + cursor.execute( + f"""--sql + CREATE TABLE IF NOT EXISTS {table_name} ( + {primary_key_name} TEXT PRIMARY KEY + ); + """ + ) + cursor.execute( + f"""--sql + INSERT OR IGNORE INTO {table_name} ({primary_key_name}) VALUES {values_string}; + """ + ) diff --git a/invokeai/app/services/util/deserialize_image_record.py b/invokeai/app/services/util/deserialize_image_record.py new file mode 100644 index 0000000000..52014b78c5 --- /dev/null +++ b/invokeai/app/services/util/deserialize_image_record.py @@ -0,0 +1,33 @@ +from invokeai.app.models.metadata import ( + GeneratedImageOrLatentsMetadata, + UploadedImageOrLatentsMetadata, +) +from invokeai.app.models.image import ImageCategory, ImageType +from invokeai.app.services.models.image_record import ImageRecord +from invokeai.app.util.misc import get_iso_timestamp + + +def deserialize_image_record(image: dict) -> ImageRecord: + """Deserializes an image record.""" + + # All values *should* be present, except `session_id` and `node_id`, but provide some defaults just in case + + image_type = ImageType(image.get("image_type", ImageType.RESULT.value)) + raw_metadata = image.get("metadata", {}) + + if image_type == ImageType.UPLOAD: + metadata = UploadedImageOrLatentsMetadata.parse_obj(raw_metadata) + else: + metadata = GeneratedImageOrLatentsMetadata.parse_obj(raw_metadata) + + return ImageRecord( + image_name=image.get("id", "unknown"), + session_id=image.get("session_id", None), + node_id=image.get("node_id", None), + metadata=metadata, + image_type=image_type, + image_category=ImageCategory( + image.get("image_category", ImageCategory.IMAGE.value) + ), + created_at=image.get("created_at", get_iso_timestamp()), + ) diff --git a/invokeai/app/util/enum.py b/invokeai/app/util/enum.py new file mode 100644 index 0000000000..5bba5712c5 --- /dev/null +++ b/invokeai/app/util/enum.py @@ -0,0 +1,12 @@ +from enum import EnumMeta + + +class MetaEnum(EnumMeta): + """Metaclass to support `in` syntax value checking in String Enums""" + + def __contains__(cls, item): + try: + cls(item) + except ValueError: + return False + return True diff --git a/invokeai/app/util/misc.py b/invokeai/app/util/misc.py index c3d091b653..7c674674e2 100644 --- a/invokeai/app/util/misc.py +++ b/invokeai/app/util/misc.py @@ -6,6 +6,14 @@ def get_timestamp(): return int(datetime.datetime.now(datetime.timezone.utc).timestamp()) +def get_iso_timestamp() -> str: + return datetime.datetime.utcnow().isoformat() + + +def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime: + return datetime.datetime.fromisoformat(iso_timestamp) + + SEED_MAX = np.iinfo(np.int32).max diff --git a/invokeai/frontend/web/src/services/api/models/ImageOutput.ts b/invokeai/frontend/web/src/services/api/models/ImageOutput.ts index 09b842de26..d7db0c11de 100644 --- a/invokeai/frontend/web/src/services/api/models/ImageOutput.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageOutput.ts @@ -8,7 +8,7 @@ import type { ImageField } from './ImageField'; * Base class for invocations that output an image */ export type ImageOutput = { - type: 'image'; + type: 'image_output'; /** * The output image */ diff --git a/invokeai/frontend/web/src/services/api/models/RandomIntInvocation.ts b/invokeai/frontend/web/src/services/api/models/RandomIntInvocation.ts index af7cf85666..0a5220c31d 100644 --- a/invokeai/frontend/web/src/services/api/models/RandomIntInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/RandomIntInvocation.ts @@ -11,5 +11,13 @@ export type RandomIntInvocation = { */ id: string; type?: 'rand_int'; + /** + * The inclusive low value + */ + low?: number; + /** + * The exclusive high value + */ + high?: number; }; diff --git a/invokeai/frontend/web/src/services/api/schemas/$RandomIntInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$RandomIntInvocation.ts index e5b0387d5a..e70192fae5 100644 --- a/invokeai/frontend/web/src/services/api/schemas/$RandomIntInvocation.ts +++ b/invokeai/frontend/web/src/services/api/schemas/$RandomIntInvocation.ts @@ -12,5 +12,13 @@ export const $RandomIntInvocation = { type: { type: 'Enum', }, + low: { + type: 'number', + description: `The inclusive low value`, + }, + high: { + type: 'number', + description: `The exclusive high value`, + }, }, } as const; diff --git a/invokeai/frontend/web/src/services/types/guards.ts b/invokeai/frontend/web/src/services/types/guards.ts index 72cf1108fb..5065290220 100644 --- a/invokeai/frontend/web/src/services/types/guards.ts +++ b/invokeai/frontend/web/src/services/types/guards.ts @@ -10,11 +10,16 @@ import { CollectInvocationOutput, ImageType, ImageField, + LatentsOutput, } from 'services/api'; export const isImageOutput = ( output: GraphExecutionState['results'][string] -): output is ImageOutput => output.type === 'image'; +): output is ImageOutput => output.type === 'image_output'; + +export const isLatentsOutput = ( + output: GraphExecutionState['results'][string] +): output is LatentsOutput => output.type === 'latents_output'; export const isMaskOutput = ( output: GraphExecutionState['results'][string]