feat(nodes): add high-level images service

feat(nodes): add ResultsServiceABC & SqliteResultsService

**Doesn't actually work because of circular imports. Can't even test it.**

- add a base class for ResultsService and SQLite implementation
- use `graph_execution_manager` `on_changed` callback to keep `results` table in sync
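
A minimal sketch of the intended wiring, assuming `SqliteResultsService` exposes the
`handle_graph_execution_state_change` handler referenced by the commented-out
registration in the `ApiDependencies` diff below (names are from this commit; the
exact call signature is an assumption):

    results = SqliteResultsService(filename=db_location)
    graph_execution_manager = SqliteItemStorage[GraphExecutionState](
        filename=db_location, table_name="graph_executions"
    )
    # on_changed fires with each inserted/updated GraphExecutionState; the handler
    # upserts that session's outputs into the `results` table
    graph_execution_manager.on_changed(results.handle_graph_execution_state_change)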

fix(nodes): fix results service bugs

chore(ui): regen api

fix(ui): fix type guards

feat(nodes): add `result_type` to results table, fix types

fix(nodes): do not shadow `list` builtin

feat(nodes): add results router

It still doesn't work, due to circular imports

fix(nodes): Result class should use outputs classes, not fields

feat(ui): crude results router

fix(ui): send to canvas in `CurrentImageButtons` not working

feat(nodes): add core metadata builder

feat(nodes): add design doc

feat(nodes): wip latents db stuff

feat(nodes): images_db_service and resources router

feat(nodes): wip images db & router

feat(nodes): update image related names

feat(nodes): update urlservice

feat(nodes): add high-level images service
psychedelicious 2023-05-17 19:13:53 +10:00 committed by Kent Keirsey
parent fb0b63c580
commit 9c89d3452c
29 changed files with 2851 additions and 83 deletions

View File

@ -1,9 +1,13 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import os
from types import ModuleType
from invokeai.app.services.database.images.sqlite_images_db_service import (
SqliteImageDb,
)
from invokeai.app.services.urls import LocalUrlService
import invokeai.backend.util.logging as logger
from typing import types
from ..services.default_graphs import create_system_graphs
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
@ -17,6 +21,7 @@ from ..services.invoker import Invoker
from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage
from ..services.metadata import PngMetadataService
from ..services.results import SqliteResultsService
from .events import FastAPIEventService
@ -50,28 +55,41 @@ class ApiDependencies:
os.path.join(os.path.dirname(__file__), "../../../../outputs")
)
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))
latents = ForwardCacheLatentsStorage(
DiskLatentsStorage(f"{output_folder}/latents")
)
metadata = PngMetadataService()
images = DiskImageStorage(f'{output_folder}/images', metadata_service=metadata)
urls = LocalUrlService()
images = DiskImageStorage(f"{output_folder}/images", metadata_service=metadata)
# TODO: build a file/path manager?
db_location = os.path.join(output_folder, "invokeai.db")
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
)
images_db = SqliteImageDb(filename=db_location)
# register event handler to update the `results` table when a graph execution state is inserted or updated
# graph_execution_manager.on_changed(results.handle_graph_execution_state_change)
services = InvocationServices(
model_manager=get_model_manager(config,logger),
model_manager=get_model_manager(config, logger),
events=events,
latents=latents,
images=images,
metadata=metadata,
images_db=images_db,
urls=urls,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"
),
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config,logger),
configuration=config,

View File

@ -0,0 +1,74 @@
from fastapi import HTTPException, Path, Query
from fastapi.routing import APIRouter
from invokeai.app.models.image import (
ImageCategory,
ImageType,
)
from invokeai.app.services.image_db import ImageRecordServiceBase
from invokeai.app.services.image_storage import ImageStorageBase
from invokeai.app.services.models.image_record import ImageRecord
from invokeai.app.services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies
image_records_router = APIRouter(prefix="/v1/records/images", tags=["records"])
@image_records_router.get("/{image_type}/{image_name}", operation_id="get_image_record")
async def get_image_record(
image_type: ImageType = Path(description="The type of the image record to get"),
image_name: str = Path(description="The id of the image record to get"),
) -> ImageRecord:
"""Gets an image record by id"""
try:
return ApiDependencies.invoker.services.images_new.get_record(
image_type=image_type, image_name=image_name
)
except ImageRecordServiceBase.ImageRecordNotFoundException:
raise HTTPException(status_code=404)
@image_records_router.get(
"/",
operation_id="list_image_records",
)
async def list_image_records(
image_type: ImageType = Query(description="The type of image records to get"),
image_category: ImageCategory = Query(
description="The kind of image records to get"
),
page: int = Query(default=0, description="The page of image records to get"),
per_page: int = Query(
default=10, description="The number of image records per page"
),
) -> PaginatedResults[ImageRecord]:
"""Gets a list of image records by type and category"""
images = ApiDependencies.invoker.services.images_new.get_many(
image_type=image_type,
image_category=image_category,
page=page,
per_page=per_page,
)
return images
@image_records_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
async def delete_image_record(
image_type: ImageType = Query(description="The type of image records to get"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image record"""
try:
ApiDependencies.invoker.services.images_new.delete(
image_type=image_type, image_name=image_name
)
except ImageStorageBase.ImageFileDeleteException:
# TODO: log this
pass
except ImageRecordServiceBase.ImageRecordDeleteException:
# TODO: log this
pass
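
Assuming this router ends up mounted under `/api` like the others (and `app` is the
FastAPI instance from the app module), the record endpoints can be exercised with
FastAPI's test client; the image name is illustrative:

    from fastapi.testclient import TestClient

    client = TestClient(app)
    # single record lookup: 200 with an ImageRecord body, or 404 if unknown
    client.get("/api/v1/records/images/results/some_image.png")
    # paginated listing, filtered by type and category
    client.get(
        "/api/v1/records/images/",
        params={"image_type": "results", "image_category": "image", "page": 0},
    )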

View File

@ -14,23 +14,39 @@ from invokeai.app.api.models.images import (
ImageResponse,
ImageResponseMetadata,
)
from invokeai.app.models.image import ImageType
from invokeai.app.services.item_storage import PaginatedResults
from ...services.image_storage import ImageType
from ..dependencies import ApiDependencies
images_router = APIRouter(prefix="/v1/images", tags=["images"])
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
# @images_router.get("/{image_type}/{image_name}", operation_id="get_image")
# async def get_image(
# image_type: ImageType = Path(description="The type of image to get"),
# image_name: str = Path(description="The name of the image to get"),
# ) -> FileResponse:
# """Gets an image"""
# path = ApiDependencies.invoker.services.images.get_path(
# image_type=image_type, image_name=image_name
# )
# if ApiDependencies.invoker.services.images.validate_path(path):
# return FileResponse(path)
# else:
# raise HTTPException(status_code=404)
@images_router.get("/{image_type}/{image_id}", operation_id="get_image")
async def get_image(
image_type: ImageType = Path(description="The type of image to get"),
image_name: str = Path(description="The name of the image to get"),
image_type: ImageType = Path(description="The type of the image to get"),
image_id: str = Path(description="The id of the image to get"),
) -> FileResponse:
"""Gets an image"""
path = ApiDependencies.invoker.services.images.get_path(
image_type=image_type, image_name=image_name
image_type=image_type, image_id=image_id
)
if ApiDependencies.invoker.services.images.validate_path(path):
@ -41,7 +57,7 @@ async def get_image(
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
async def delete_image(
image_type: ImageType = Path(description="The type of image to delete"),
image_type: ImageType = Path(description="The type of the image to delete"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image and its thumbnail"""
@ -52,16 +68,16 @@ async def delete_image(
@images_router.get(
"/{thumbnail_type}/thumbnails/{thumbnail_name}", operation_id="get_thumbnail"
"/{image_type}/thumbnails/{thumbnail_id}", operation_id="get_thumbnail"
)
async def get_thumbnail(
thumbnail_type: ImageType = Path(description="The type of thumbnail to get"),
thumbnail_name: str = Path(description="The name of the thumbnail to get"),
image_type: ImageType = Path(description="The type of the thumbnail to get"),
thumbnail_id: str = Path(description="The id of the thumbnail to get"),
) -> FileResponse | Response:
"""Gets a thumbnail"""
path = ApiDependencies.invoker.services.images.get_path(
image_type=thumbnail_type, image_name=thumbnail_name, is_thumbnail=True
image_type=image_type, image_id=thumbnail_id, is_thumbnail=True
)
if ApiDependencies.invoker.services.images.validate_path(path):

View File

@ -0,0 +1,42 @@
from fastapi import HTTPException, Path, Query
from fastapi.routing import APIRouter
from invokeai.app.services.results import ResultType, ResultWithSession
from invokeai.app.services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies
results_router = APIRouter(prefix="/v1/results", tags=["results"])
@results_router.get("/{result_type}/{result_name}", operation_id="get_result")
async def get_result(
result_type: ResultType = Path(description="The type of result to get"),
result_name: str = Path(description="The name of the result to get"),
) -> ResultWithSession:
"""Gets a result"""
result = ApiDependencies.invoker.services.results.get(
result_id=result_name, result_type=result_type
)
if result is not None:
return result
else:
raise HTTPException(status_code=404)
@results_router.get(
"/",
operation_id="list_results",
responses={200: {"model": PaginatedResults[ResultWithSession]}},
)
async def list_results(
result_type: ResultType = Query(description="The type of results to get"),
page: int = Query(default=0, description="The page of results to get"),
per_page: int = Query(default=10, description="The number of results per page"),
) -> PaginatedResults[ResultWithSession]:
"""Gets a list of results"""
results = ApiDependencies.invoker.services.results.get_many(
result_type=result_type, page=page, per_page=per_page
)
return results

View File

@ -3,6 +3,7 @@ import asyncio
from inspect import signature
import uvicorn
from invokeai.app.models import resources
import invokeai.backend.util.logging as logger
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
@ -14,11 +15,12 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.schema import schema
from .api.dependencies import ApiDependencies
from .api.routers import images, sessions, models
from .api.routers import image_resources, images, sessions, models
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation
from .services.config import InvokeAIAppConfig
# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
@ -73,6 +75,8 @@ app.include_router(images.images_router, prefix="/api")
app.include_router(models.models_router, prefix="/api")
app.include_router(image_resources.image_resources_router, prefix="/api")
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
@ -121,6 +125,7 @@ app.openapi = custom_openapi
# Override API doc favicons
app.mount("/static", StaticFiles(directory="static/dream_web"), name="static")
@app.get("/docs", include_in_schema=False)
def overridden_swagger():
return get_swagger_ui_html(
@ -138,8 +143,12 @@ def overridden_redoc():
redoc_favicon_url="/static/favicon.ico",
)
# Must mount *after* the other routes else it borks em
app.mount("/", StaticFiles(directory="invokeai/frontend/web/dist", html=True), name="ui")
app.mount(
"/", StaticFiles(directory="invokeai/frontend/web/dist", html=True), name="ui"
)
def invoke_api():
# Start our own event loop for eventing usage

View File

@ -1,12 +1,15 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from __future__ import annotations
from abc import ABC, abstractmethod
from inspect import signature
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING
from pydantic import BaseModel, Field
from ..services.invocation_services import InvocationServices
if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices
class InvocationContext:
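
This hunk addresses the circular imports called out in the commit messages above: the
`InvocationServices` import is now evaluated only by type checkers, so at runtime
`baseinvocation` no longer imports the services module that imports it back. A minimal
sketch of the pattern:

    from __future__ import annotations  # annotations become strings, resolved lazily
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # seen by mypy/pyright only; never executed, so no import cycle at runtime
        from ..services.invocation_services import InvocationServices

    class InvocationContext:
        services: InvocationServices  # safe: the annotation is never evaluated at runtime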

View File

@ -10,6 +10,8 @@ from pydantic import BaseModel, Field
from invokeai.app.models.image import ColorField, ImageField, ImageType
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.models.metadata import GeneratedImageOrLatentsMetadata
from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.generator.inpaint import infill_methods
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
@ -106,6 +108,16 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
context.services.images.save(
image_type, image_name, generate_output.image, metadata
)
context.services.images_db.set(
id=image_name,
image_type=ImageType.RESULT,
image_category=ImageCategory.IMAGE,
session_id=context.graph_execution_state_id,
node_id=self.id,
metadata=GeneratedImageOrLatentsMetadata(),
)
return build_image_output(
image_type=image_type,
image_name=image_name,

View File

@ -31,7 +31,7 @@ class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image"] = "image"
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")

View File

@ -2,11 +2,23 @@ from enum import Enum
from typing import Optional, Tuple
from pydantic import BaseModel, Field
from invokeai.app.util.enum import MetaEnum
class ImageType(str, Enum, metaclass=MetaEnum):
"""The type of an image."""
class ImageType(str, Enum):
RESULT = "results"
INTERMEDIATE = "intermediates"
UPLOAD = "uploads"
INTERMEDIATE = "intermediates"
class ImageCategory(str, Enum, metaclass=MetaEnum):
"""The category of an image. Use ImageCategory.OTHER for non-default categories."""
IMAGE = "image"
CONTROL_IMAGE = "control_image"
OTHER = "other"
def is_image_type(obj):

View File

@ -0,0 +1,70 @@
from typing import Optional
from pydantic import BaseModel, Field, StrictFloat, StrictInt, StrictStr
class GeneratedImageOrLatentsMetadata(BaseModel):
"""Core generation metadata for an image/tensor generated in InvokeAI.
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node.
Full metadata may be accessed by querying for the session in the `graph_executions` table.
"""
positive_conditioning: Optional[StrictStr] = Field(
default=None, description="The positive conditioning."
)
negative_conditioning: Optional[StrictStr] = Field(
default=None, description="The negative conditioning."
)
width: Optional[StrictInt] = Field(
default=None, description="Width of the image/tensor in pixels."
)
height: Optional[StrictInt] = Field(
default=None, description="Height of the image/tensor in pixels."
)
seed: Optional[StrictInt] = Field(
default=None, description="The seed used for noise generation."
)
cfg_scale: Optional[StrictFloat] = Field(
default=None, description="The classifier-free guidance scale."
)
steps: Optional[StrictInt] = Field(
default=None, description="The number of steps used for inference."
)
scheduler: Optional[StrictStr] = Field(
default=None, description="The scheduler used for inference."
)
model: Optional[StrictStr] = Field(
default=None, description="The model used for inference."
)
strength: Optional[StrictFloat] = Field(
default=None,
description="The strength used for image-to-image/tensor-to-tensor.",
)
image: Optional[StrictStr] = Field(
default=None, description="The ID of the initial image."
)
tensor: Optional[StrictStr] = Field(
default=None, description="The ID of the initial tensor."
)
# Pending model refactor:
# vae: Optional[str] = Field(default=None,description="The VAE used for decoding.")
# unet: Optional[str] = Field(default=None,description="The UNet used for inference.")
# clip: Optional[str] = Field(default=None,description="The CLIP Encoder used for conditioning.")
class UploadedImageOrLatentsMetadata(BaseModel):
"""Limited metadata for an uploaded image/tensor."""
width: Optional[StrictInt] = Field(
default=None, description="Width of the image/tensor in pixels."
)
height: Optional[StrictInt] = Field(
default=None, description="Height of the image/tensor in pixels."
)
# The extra field will be the contents of the PNG file's tEXt chunk. It may have come
# from another SD application or InvokeAI, so it needs to be flexible.
# If the upload is not an image or `image_latents` tensor, this will be omitted.
extra: Optional[StrictStr] = Field(
default=None, description="Extra metadata, extracted from the PNG tEXt chunk."
)
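
The `Strict*` field types refuse pydantic's usual coercion, so a stringified number is
rejected rather than silently converted (pydantic v1 behavior; values are illustrative):

    meta = GeneratedImageOrLatentsMetadata(width=512, cfg_scale=7.5)  # ok
    meta.json(exclude_none=True)  # '{"width": 512, "cfg_scale": 7.5}'
    GeneratedImageOrLatentsMetadata(width="512")  # raises ValidationError (StrictInt)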

View File

@ -0,0 +1,28 @@
# TODO: Make a new model for this
from enum import Enum
from invokeai.app.util.enum import MetaEnum
class ResourceType(str, Enum, metaclass=MetaEnum):
"""The type of a resource."""
IMAGES = "images"
TENSORS = "tensors"
# class ResourceOrigin(str, Enum, metaclass=MetaEnum):
# """The origin of a resource (eg image or tensor)."""
# RESULTS = "results"
# UPLOADS = "uploads"
# INTERMEDIATES = "intermediates"
class TensorKind(str, Enum, metaclass=MetaEnum):
"""The kind of a tensor. Use TensorKind.OTHER for non-default kinds."""
IMAGE_LATENTS = "image_latents"
CONDITIONING = "conditioning"
OTHER = "other"

View File

@ -0,0 +1,578 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"from abc import ABC, abstractmethod\n",
"from enum import Enum\n",
"import enum\n",
"import sqlite3\n",
"import threading\n",
"from typing import Optional, Type, TypeVar, Union\n",
"from PIL.Image import Image as PILImage\n",
"from pydantic import BaseModel, Field\n",
"from torch import Tensor"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"\n",
"class ResourceOrigin(str, Enum):\n",
" \"\"\"The origin of a resource (eg image or tensor).\"\"\"\n",
"\n",
" RESULTS = \"results\"\n",
" UPLOADS = \"uploads\"\n",
" INTERMEDIATES = \"intermediates\"\n",
"\n",
"\n",
"class ImageKind(str, Enum):\n",
" \"\"\"The kind of an image. Use ImageKind.OTHER for non-default kinds.\"\"\"\n",
"\n",
" IMAGE = \"image\"\n",
" CONTROL_IMAGE = \"control_image\"\n",
" OTHER = \"other\"\n",
"\n",
"\n",
"class TensorKind(str, Enum):\n",
" \"\"\"The kind of a tensor. Use TensorKind.OTHER for non-default kinds.\"\"\"\n",
"\n",
" IMAGE_LATENTS = \"image_latents\"\n",
" CONDITIONING = \"conditioning\"\n",
" OTHER = \"other\"\n"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"\n",
"def create_sql_values_string_from_string_enum(enum: Type[Enum]):\n",
" \"\"\"\n",
" Creates a string of the form \"('value1'), ('value2'), ..., ('valueN')\" from a StrEnum.\n",
" \"\"\"\n",
"\n",
" delimiter = \", \"\n",
" values = [f\"('{e.value}')\" for e in enum]\n",
" return delimiter.join(values)\n",
"\n",
"\n",
"def create_sql_table_from_enum(\n",
" enum: Type[Enum],\n",
" table_name: str,\n",
" primary_key_name: str,\n",
" conn: sqlite3.Connection,\n",
" cursor: sqlite3.Cursor,\n",
" lock: threading.Lock,\n",
"):\n",
" \"\"\"\n",
" Creates and populates a table to be used as a functional enum.\n",
" \"\"\"\n",
"\n",
" try:\n",
" lock.acquire()\n",
"\n",
" values_string = create_sql_values_string_from_string_enum(enum)\n",
"\n",
" cursor.execute(\n",
" f\"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS {table_name} (\n",
" {primary_key_name} TEXT PRIMARY KEY\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" f\"\"\"--sql\n",
" INSERT OR IGNORE INTO {table_name} ({primary_key_name}) VALUES {values_string};\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"\"\"\"\n",
"`resource_origins` functions as an enum for the ResourceOrigin model.\n",
"\"\"\"\n",
"\n",
"\n",
"# def create_resource_origins_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
"# create_sql_table_from_enum(\n",
"# enum=ResourceOrigin,\n",
"# table_name=\"resource_origins\",\n",
"# primary_key_name=\"origin_name\",\n",
"# conn=conn,\n",
"# cursor=cursor,\n",
"# lock=lock,\n",
"# )\n",
"\n",
"\n",
"\"\"\"\n",
"`image_kinds` functions as an enum for the ImageType model.\n",
"\"\"\"\n",
"\n",
"\n",
"# def create_image_kinds_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" # create_sql_table_from_enum(\n",
" # enum=ImageKind,\n",
" # table_name=\"image_kinds\",\n",
" # primary_key_name=\"kind_name\",\n",
" # conn=conn,\n",
" # cursor=cursor,\n",
" # lock=lock,\n",
" # )\n",
"\n",
"\n",
"\"\"\"\n",
"`tensor_kinds` functions as an enum for the TensorType model.\n",
"\"\"\"\n",
"\n",
"\n",
"# def create_tensor_kinds_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" # create_sql_table_from_enum(\n",
" # enum=TensorKind,\n",
" # table_name=\"tensor_kinds\",\n",
" # primary_key_name=\"kind_name\",\n",
" # conn=conn,\n",
" # cursor=cursor,\n",
" # lock=lock,\n",
" # )\n",
"\n",
"\n",
"\"\"\"\n",
"`images` stores all images, regardless of type\n",
"\"\"\"\n",
"\n",
"\n",
"def create_images_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS images (\n",
" id TEXT PRIMARY KEY,\n",
" origin TEXT,\n",
" image_kind TEXT,\n",
" created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\n",
" FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),\n",
" FOREIGN KEY(image_kind) REFERENCES image_kinds(kind_name)\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_id ON images(id);\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE INDEX IF NOT EXISTS idx_images_origin ON images(origin);\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE INDEX IF NOT EXISTS idx_images_image_kind ON images(image_kind);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"\"\"\"\n",
"`images_results` stores additional data specific to `results` images.\n",
"\"\"\"\n",
"\n",
"\n",
"def create_images_results_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS images_results (\n",
" images_id TEXT PRIMARY KEY,\n",
" session_id TEXT NOT NULL,\n",
" node_id TEXT NOT NULL,\n",
" FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_results_images_id ON images_results(images_id);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"\"\"\"\n",
"`images_intermediates` stores additional data specific to `intermediates` images\n",
"\"\"\"\n",
"\n",
"\n",
"def create_images_intermediates_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS images_intermediates (\n",
" images_id TEXT PRIMARY KEY,\n",
" session_id TEXT NOT NULL,\n",
" node_id TEXT NOT NULL,\n",
" FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_intermediates_images_id ON images_intermediates(images_id);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"\"\"\"\n",
"`images_metadata` stores basic metadata for any image type\n",
"\"\"\"\n",
"\n",
"\n",
"def create_images_metadata_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS images_metadata (\n",
" images_id TEXT PRIMARY KEY,\n",
" metadata TEXT,\n",
" FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_metadata_images_id ON images_metadata(images_id);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"# `tensors` table: stores references to tensor\n",
"\n",
"\n",
"def create_tensors_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS tensors (\n",
" id TEXT PRIMARY KEY,\n",
" origin TEXT,\n",
" tensor_kind TEXT,\n",
" created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\n",
" FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),\n",
" FOREIGN KEY(tensor_kind) REFERENCES tensor_kinds(kind_name)\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_id ON tensors(id);\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE INDEX IF NOT EXISTS idx_tensors_origin ON tensors(origin);\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE INDEX IF NOT EXISTS idx_tensors_tensor_kind ON tensors(tensor_kind);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"# `tensors_results` stores additional data specific to `result` tensor\n",
"\n",
"\n",
"def create_tensors_results_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS tensors_results (\n",
" tensors_id TEXT PRIMARY KEY,\n",
" session_id TEXT NOT NULL,\n",
" node_id TEXT NOT NULL,\n",
" FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_results_tensors_id ON tensors_results(tensors_id);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"# `tensors_intermediates` stores additional data specific to `intermediate` tensor\n",
"\n",
"\n",
"def create_tensors_intermediates_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS tensors_intermediates (\n",
" tensors_id TEXT PRIMARY KEY,\n",
" session_id TEXT NOT NULL,\n",
" node_id TEXT NOT NULL,\n",
" FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_intermediates_tensors_id ON tensors_intermediates(tensors_id);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"# `tensors_metadata` table: stores generated/transformed metadata for tensor\n",
"\n",
"\n",
"def create_tensors_metadata_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS tensors_metadata (\n",
" tensors_id TEXT PRIMARY KEY,\n",
" metadata TEXT,\n",
" FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_metadata_tensors_id ON tensors_metadata(tensors_id);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"db_path = '/home/bat/Documents/Code/outputs/test.db'\n",
"if (os.path.exists(db_path)):\n",
" os.remove(db_path)\n",
"\n",
"conn = sqlite3.connect(\n",
" db_path, check_same_thread=False\n",
")\n",
"cursor = conn.cursor()\n",
"lock = threading.Lock()"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"create_sql_table_from_enum(\n",
" enum=ResourceOrigin,\n",
" table_name=\"resource_origins\",\n",
" primary_key_name=\"origin_name\",\n",
" conn=conn,\n",
" cursor=cursor,\n",
" lock=lock,\n",
")\n",
"\n",
"create_sql_table_from_enum(\n",
" enum=ImageKind,\n",
" table_name=\"image_kinds\",\n",
" primary_key_name=\"kind_name\",\n",
" conn=conn,\n",
" cursor=cursor,\n",
" lock=lock,\n",
")\n",
"\n",
"create_sql_table_from_enum(\n",
" enum=TensorKind,\n",
" table_name=\"tensor_kinds\",\n",
" primary_key_name=\"kind_name\",\n",
" conn=conn,\n",
" cursor=cursor,\n",
" lock=lock,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"create_images_table(conn, cursor, lock)\n",
"create_images_results_table(conn, cursor, lock)\n",
"create_images_intermediates_table(conn, cursor, lock)\n",
"create_images_metadata_table(conn, cursor, lock)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"create_tensors_table(conn, cursor, lock)\n",
"create_tensors_results_table(conn, cursor, lock)\n",
"create_tensors_intermediates_table(conn, cursor, lock)\n",
"create_tensors_metadata_table(conn, cursor, lock)"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"\n",
"from pydantic import StrictStr\n",
"\n",
"\n",
"class GeneratedImageOrLatentsMetadata(BaseModel):\n",
" \"\"\"Core generation metadata for an image/tensor generated in InvokeAI.\n",
"\n",
" Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node.\n",
"\n",
" Full metadata may be accessed by querying for the session in the `graph_executions` table.\n",
" \"\"\"\n",
"\n",
" positive_conditioning: Optional[StrictStr] = Field(\n",
" default=None, description=\"The positive conditioning.\"\n",
" )\n",
" negative_conditioning: Optional[str] = Field(\n",
" default=None, description=\"The negative conditioning.\"\n",
" )\n",
" width: Optional[int] = Field(\n",
" default=None, description=\"Width of the image/tensor in pixels.\"\n",
" )\n",
" height: Optional[int] = Field(\n",
" default=None, description=\"Height of the image/tensor in pixels.\"\n",
" )\n",
" seed: Optional[int] = Field(\n",
" default=None, description=\"The seed used for noise generation.\"\n",
" )\n",
" cfg_scale: Optional[float] = Field(\n",
" default=None, description=\"The classifier-free guidance scale.\"\n",
" )\n",
" steps: Optional[int] = Field(\n",
" default=None, description=\"The number of steps used for inference.\"\n",
" )\n",
" scheduler: Optional[str] = Field(\n",
" default=None, description=\"The scheduler used for inference.\"\n",
" )\n",
" model: Optional[str] = Field(\n",
" default=None, description=\"The model used for inference.\"\n",
" )\n",
" strength: Optional[float] = Field(\n",
" default=None,\n",
" description=\"The strength used for image-to-image/tensor-to-tensor.\",\n",
" )\n",
" image: Optional[str] = Field(\n",
" default=None, description=\"The ID of the initial image.\"\n",
" )\n",
" tensor: Optional[str] = Field(\n",
" default=None, description=\"The ID of the initial tensor.\"\n",
" )\n",
" # Pending model refactor:\n",
" # vae: Optional[str] = Field(default=None,description=\"The VAE used for decoding.\")\n",
" # unet: Optional[str] = Field(default=None,description=\"The UNet used dor inference.\")\n",
" # clip: Optional[str] = Field(default=None,description=\"The CLIP Encoder used for conditioning.\")\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"GeneratedImageOrLatentsMetadata(positive_conditioning='123', negative_conditioning=None, width=None, height=None, seed=None, cfg_scale=None, steps=None, scheduler=None, model=None, strength=None, image=None, tensor=None)"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"GeneratedImageOrLatentsMetadata(positive_conditioning='123')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -0,0 +1,329 @@
from abc import ABC, abstractmethod
import datetime
import sqlite3
import threading
from typing import Optional, Union

from invokeai.app.models.metadata import (
GeneratedImageOrLatentsMetadata,
UploadedImageOrLatentsMetadata,
)
from invokeai.app.models.image import (
ImageCategory,
ImageType,
)
from invokeai.app.services.util.create_enum_table import create_enum_table
from invokeai.app.services.models.image_record import ImageRecord
from invokeai.app.services.util.deserialize_image_record import (
deserialize_image_record,
)
from invokeai.app.services.item_storage import PaginatedResults
class ImageRecordServiceBase(ABC):
"""Low-level service responsible for interfacing with the image record store."""
class ImageRecordNotFoundException(Exception):
"""Raised when an image record is not found."""
def __init__(self, message="Image record not found"):
super().__init__(message)
class ImageRecordSaveException(Exception):
"""Raised when an image record cannot be saved."""
def __init__(self, message="Image record not saved"):
super().__init__(message)
class ImageRecordDeleteException(Exception):
"""Raised when an image record cannot be deleted."""
def __init__(self, message="Image record not deleted"):
super().__init__(message)
@abstractmethod
def get(self, image_type: ImageType, image_name: str) -> ImageRecord:
"""Gets an image record."""
pass
@abstractmethod
def get_many(
self,
image_type: ImageType,
image_category: ImageCategory,
page: int = 0,
per_page: int = 10,
) -> PaginatedResults[ImageRecord]:
"""Gets a page of image records."""
pass
@abstractmethod
def delete(self, image_type: ImageType, image_name: str) -> None:
"""Deletes an image record."""
pass
@abstractmethod
def save(
self,
image_name: str,
image_type: ImageType,
image_category: ImageCategory,
session_id: Optional[str],
node_id: Optional[str],
metadata: Optional[
GeneratedImageOrLatentsMetadata | UploadedImageOrLatentsMetadata
],
# NOTE: no utcnow() default here - a function-call default would be evaluated only
# once, at class-definition time. Callers must pass an explicit ISO timestamp.
created_at: str,
) -> None:
"""Saves an image record."""
pass
class SqliteImageRecordService(ImageRecordServiceBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the tables for the `images` database."""
# Create the `images` table.
self._cursor.execute(
f"""--sql
CREATE TABLE IF NOT EXISTS images (
id TEXT PRIMARY KEY,
image_type TEXT, -- constrained to image_types rows via foreign key (FKs alone do not reject NULLs)
image_category TEXT, -- constrained to image_categories rows via foreign key
session_id TEXT, -- nullable
node_id TEXT, -- nullable
metadata TEXT, -- nullable
created_at TEXT NOT NULL,
FOREIGN KEY(image_type) REFERENCES image_types(type_name),
FOREIGN KEY(image_category) REFERENCES image_categories(category_name)
);
"""
)
# Create the `images` table indices.
self._cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_id ON images(id);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_type ON images(image_type);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
"""
)
# Create the tables for image-related enums
create_enum_table(
enum=ImageType,
table_name="image_types",
primary_key_name="type_name",
cursor=self._cursor,
)
create_enum_table(
enum=ImageCategory,
table_name="image_categories",
primary_key_name="category_name",
cursor=self._cursor,
)
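# After these calls, image_types holds one row per ImageType value ("results",
# "uploads", "intermediates") and image_categories one row per ImageCategory value
# ("image", "control_image", "other"), giving the foreign keys above a fixed set of
# legal values.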
# Create the `tags` table. TODO: do this elsewhere, shouldn't be in images db service
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS tags (
id INTEGER PRIMARY KEY AUTOINCREMENT,
tag_name TEXT UNIQUE NOT NULL
);
"""
)
# Create the `images_tags` junction table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS images_tags (
image_id TEXT,
tag_id INTEGER,
PRIMARY KEY (image_id, tag_id),
FOREIGN KEY(image_id) REFERENCES images(id) ON DELETE CASCADE,
FOREIGN KEY(tag_id) REFERENCES tags(id) ON DELETE CASCADE
);
"""
)
# Create the `images_favorites` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS images_favorites (
image_id TEXT PRIMARY KEY,
favorited_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(image_id) REFERENCES images(id) ON DELETE CASCADE
);
"""
)
def get(self, image_type: ImageType, image_name: str) -> ImageRecord:
try:
self._lock.acquire()
self._cursor.execute(
f"""--sql
SELECT * FROM images
WHERE id = ?;
""",
(image_name,),
)
result = self._cursor.fetchone()
except sqlite3.Error as e:
self._conn.rollback()
raise self.ImageRecordNotFoundException from e
finally:
self._lock.release()
if not result:
raise self.ImageRecordNotFoundException
return deserialize_image_record(result)
def get_many(
self,
image_type: ImageType,
image_category: ImageCategory,
page: int = 0,
per_page: int = 10,
) -> PaginatedResults[ImageRecord]:
try:
self._lock.acquire()
self._cursor.execute(
f"""--sql
SELECT * FROM images
WHERE image_type = ? AND image_category = ?
LIMIT ? OFFSET ?;
""",
(image_type.value, image_category.value, per_page, page * per_page),
)
result = self._cursor.fetchall()
images = list(map(lambda r: deserialize_image_record(r), result))
self._cursor.execute(
"""--sql
SELECT count(*) FROM images
WHERE image_type = ? AND image_category = ?
""",
(image_type.value, image_category.value),
)
count = self._cursor.fetchone()[0]
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
# ceiling division, so an exact multiple of per_page doesn't produce a trailing empty page
page_count = -(-count // per_page)
return PaginatedResults(
items=images, page=page, pages=page_count, per_page=per_page, total=count
)
def delete(self, image_type: ImageType, image_name: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM images
WHERE id = ?;
""",
(image_name,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordServiceBase.ImageRecordDeleteException from e
finally:
self._lock.release()
def save(
self,
image_name: str,
image_type: ImageType,
image_category: ImageCategory,
session_id: Optional[str],
node_id: Optional[str],
metadata: Union[
GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata, None
],
created_at: str,
) -> None:
try:
metadata_json = (
None if metadata is None else metadata.json(exclude_none=True)
)
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO images (
id,
image_type,
image_category,
node_id,
session_id,
metadata,
created_at
)
VALUES (?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
image_type.value,
image_category.value,
node_id,
session_id,
metadata_json,
created_at,
),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordServiceBase.ImageRecordSaveException from e
finally:
self._lock.release()

View File

@ -27,7 +27,25 @@ from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
class ImageStorageBase(ABC):
"""Responsible for storing and retrieving images."""
"""Low-level service responsible for storing and retrieving images."""
class ImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage."""
def __init__(self, message="Image file not found"):
super().__init__(message)
class ImageFileSaveException(Exception):
"""Raised when an image cannot be saved."""
def __init__(self, message="Image file not saved"):
super().__init__(message)
class ImageFileDeleteException(Exception):
"""Raised when an image cannot be deleted."""
def __init__(self, message="Image file not deleted"):
super().__init__(message)
@abstractmethod
def get(self, image_type: ImageType, image_name: str) -> Image:
@ -136,7 +154,7 @@ class DiskImageStorage(ImageStorageBase):
page_of_images.append(
ImageResponse(
image_type=image_type.value,
image_type=image_type,
image_name=filename,
# TODO: DiskImageStorage should not be building URLs...?
image_url=self.get_uri(image_type, filename),
@ -164,14 +182,17 @@ class DiskImageStorage(ImageStorageBase):
)
def get(self, image_type: ImageType, image_name: str) -> Image:
image_path = self.get_path(image_type, image_name)
cache_item = self.__get_cache(image_path)
if cache_item:
return cache_item
try:
image_path = self.get_path(image_type, image_name)
cache_item = self.__get_cache(image_path)
if cache_item:
return cache_item
image = PILImage.open(image_path)
self.__set_cache(image_path, image)
return image
image = PILImage.open(image_path)
self.__set_cache(image_path, image)
return image
except Exception as e:
raise ImageStorageBase.ImageFileNotFoundException from e
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(
@ -209,8 +230,10 @@ class DiskImageStorage(ImageStorageBase):
try:
os.stat(path)
return True
except Exception:
except FileNotFoundError:
return False
except Exception as e:
raise e
def save(
self,
@ -219,45 +242,53 @@ class DiskImageStorage(ImageStorageBase):
image: Image,
metadata: InvokeAIMetadata | None = None,
) -> SavedImage:
image_path = self.get_path(image_type, image_name)
try:
image_path = self.get_path(image_type, image_name)
# TODO: Reading the image and then saving it strips the metadata...
if metadata:
pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata)
image.save(image_path, "PNG", pnginfo=pnginfo)
else:
image.save(image_path) # this saved image has an empty info
# TODO: Reading the image and then saving it strips the metadata...
if metadata:
pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata)
image.save(image_path, "PNG", pnginfo=pnginfo)
else:
image.save(image_path) # this saved image has an empty info
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, is_thumbnail=True)
thumbnail_image = make_thumbnail(image)
thumbnail_image.save(thumbnail_path)
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(
image_type, thumbnail_name, is_thumbnail=True
)
thumbnail_image = make_thumbnail(image)
thumbnail_image.save(thumbnail_path)
self.__set_cache(image_path, image)
self.__set_cache(thumbnail_path, thumbnail_image)
self.__set_cache(image_path, image)
self.__set_cache(thumbnail_path, thumbnail_image)
return SavedImage(
image_name=image_name,
thumbnail_name=thumbnail_name,
created=int(os.path.getctime(image_path)),
)
return SavedImage(
image_name=image_name,
thumbnail_name=thumbnail_name,
created=int(os.path.getctime(image_path)),
)
except Exception as e:
raise ImageStorageBase.ImageFileSaveException from e
def delete(self, image_type: ImageType, image_name: str) -> None:
basename = os.path.basename(image_name)
image_path = self.get_path(image_type, basename)
try:
basename = os.path.basename(image_name)
image_path = self.get_path(image_type, basename)
if os.path.exists(image_path):
send2trash(image_path)
if image_path in self.__cache:
del self.__cache[image_path]
if os.path.exists(image_path):
send2trash(image_path)
if image_path in self.__cache:
del self.__cache[image_path]
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, True)
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, True)
if os.path.exists(thumbnail_path):
send2trash(thumbnail_path)
if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path]
if os.path.exists(thumbnail_path):
send2trash(thumbnail_path)
if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path]
except Exception as e:
raise ImageStorageBase.ImageFileDeleteException from e
def __get_cache(self, image_name: str) -> Image | None:
return None if image_name not in self.__cache else self.__cache[image_name]

View File

@ -0,0 +1,219 @@
from typing import Union
import uuid
from PIL.Image import Image as PILImageType
from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.models.metadata import (
GeneratedImageOrLatentsMetadata,
UploadedImageOrLatentsMetadata,
)
from invokeai.app.services.image_db import (
ImageRecordServiceBase,
)
from invokeai.app.services.models.image_record import ImageRecord
from invokeai.app.services.image_storage import ImageStorageBase
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.app.services.urls import UrlServiceBase
from invokeai.app.util.misc import get_iso_timestamp
class ImageServiceDependencies:
"""Service dependencies for the ImageManagementService."""
db: ImageRecordServiceBase
storage: ImageStorageBase
metadata: MetadataServiceBase
url: UrlServiceBase
def __init__(
self,
image_db_service: ImageRecordServiceBase,
image_storage_service: ImageStorageBase,
image_metadata_service: MetadataServiceBase,
url_service: UrlServiceBase,
):
self.db = image_db_service
self.storage = image_storage_service
self.metadata = image_metadata_service
self.url = url_service
class ImageService:
"""High-level service for image management."""
_services: ImageServiceDependencies
def __init__(
self,
image_db_service: ImageRecordServiceBase,
image_storage_service: ImageStorageBase,
image_metadata_service: MetadataServiceBase,
url_service: UrlServiceBase,
):
self._services = ImageServiceDependencies(
image_db_service=image_db_service,
image_storage_service=image_storage_service,
image_metadata_service=image_metadata_service,
url_service=url_service,
)
def _create_image_name(
self,
image_type: ImageType,
image_category: ImageCategory,
node_id: Union[str, None],
session_id: Union[str, None],
) -> str:
"""Creates an image name."""
uuid_str = str(uuid.uuid4())
if node_id is not None and session_id is not None:
return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png"
return f"{image_type.value}_{image_category.value}_{uuid_str}.png"
def create(
self,
image: PILImageType,
image_type: ImageType,
image_category: ImageCategory,
node_id: Union[str, None],
session_id: Union[str, None],
metadata: Union[
GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata, None
],
) -> ImageRecord:
"""Creates an image, storing the file and its metadata."""
image_name = self._create_image_name(
image_type=image_type,
image_category=image_category,
node_id=node_id,
session_id=session_id,
)
timestamp = get_iso_timestamp()
try:
# TODO: Consider using a transaction here to ensure consistency between storage and database
self._services.storage.save(
image_type=image_type,
image_name=image_name,
image=image,
metadata=metadata,
)
self._services.db.save(
image_name=image_name,
image_type=image_type,
image_category=image_category,
node_id=node_id,
session_id=session_id,
metadata=metadata,
created_at=timestamp,
)
image_url = self._services.url.get_image_url(
image_type=image_type, image_name=image_name
)
thumbnail_url = self._services.url.get_thumbnail_url(
image_type=image_type, image_name=image_name
)
return ImageRecord(
image_name=image_name,
image_type=image_type,
image_category=image_category,
node_id=node_id,
session_id=session_id,
metadata=metadata,
created_at=timestamp,
image_url=image_url,
thumbnail_url=thumbnail_url,
)
except ImageRecordServiceBase.ImageRecordSaveException:
# TODO: log this
raise
except ImageStorageBase.ImageFileSaveException:
# TODO: log this
raise
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
"""Gets an image as a PIL image."""
try:
pil_image = self._services.storage.get(
image_type=image_type, image_name=image_name
)
return pil_image
except ImageStorageBase.ImageFileNotFoundException:
# TODO: log this
raise
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
"""Gets an image record."""
try:
image_record = self._services.db.get(
image_type=image_type, image_name=image_name
)
return image_record
except ImageRecordServiceBase.ImageRecordNotFoundException:
# TODO: log this
raise
def delete(self, image_type: ImageType, image_name: str):
"""Deletes an image."""
# TODO: Consider using a transaction here to ensure consistency between storage and database
try:
self._services.storage.delete(image_type=image_type, image_name=image_name)
self._services.db.delete(image_type=image_type, image_name=image_name)
except ImageRecordServiceBase.ImageRecordDeleteException:
# TODO: log this
raise
except ImageStorageBase.ImageFileDeleteException:
# TODO: log this
raise
def get_many(
self,
image_type: ImageType,
image_category: ImageCategory,
page: int = 0,
per_page: int = 10,
) -> PaginatedResults[ImageRecord]:
"""Gets a paginated list of image records."""
try:
results = self._services.db.get_many(
image_type=image_type,
image_category=image_category,
page=page,
per_page=per_page,
)
for r in results.items:
r.image_url = self._services.url.get_image_url(
image_type=image_type, image_name=r.image_name
)
r.thumbnail_url = self._services.url.get_thumbnail_url(
image_type=image_type, image_name=r.image_name
)
return results
except Exception as e:
raise e
def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
"""Adds a tag to an image."""
raise NotImplementedError("The 'add_tag' method is not implemented yet.")
def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
"""Removes a tag from an image."""
raise NotImplementedError("The 'remove_tag' method is not implemented yet.")
def favorite(self, image_type: ImageType, image_id: str) -> None:
"""Favorites an image."""
raise NotImplementedError("The 'favorite' method is not implemented yet.")
def unfavorite(self, image_type: ImageType, image_id: str) -> None:
"""Unfavorites an image."""
raise NotImplementedError("The 'unfavorite' method is not implemented yet.")

View File

@ -1,7 +1,12 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from typing import types
from types import ModuleType
from invokeai.app.services.image_db import (
ImageRecordServiceBase,
)
from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.app.services.urls import UrlServiceBase
from invokeai.backend import ModelManager
from .events import EventServiceBase
@ -12,6 +17,7 @@ from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC
from .config import InvokeAISettings
class InvocationServices:
"""Services that can be used by invocations"""
@ -23,26 +29,32 @@ class InvocationServices:
model_manager: ModelManager
restoration: RestorationServices
configuration: InvokeAISettings
images_db: ImageRecordServiceBase
urls: UrlServiceBase
images_new: ImageService
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
graph_library: ItemStorageABC["LibraryGraph"]
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
processor: "InvocationProcessorABC"
def __init__(
self,
model_manager: ModelManager,
events: EventServiceBase,
logger: types.ModuleType,
latents: LatentsStorageBase,
images: ImageStorageBase,
metadata: MetadataServiceBase,
queue: InvocationQueueABC,
graph_library: ItemStorageABC["LibraryGraph"],
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC",
restoration: RestorationServices,
configuration: InvokeAISettings=None,
self,
model_manager: ModelManager,
events: EventServiceBase,
logger: ModuleType,
latents: LatentsStorageBase,
images: ImageStorageBase,
metadata: MetadataServiceBase,
queue: InvocationQueueABC,
images_db: ImageRecordServiceBase,
images_new: ImageService,
urls: UrlServiceBase,
graph_library: ItemStorageABC["LibraryGraph"],
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC",
restoration: RestorationServices,
configuration: InvokeAISettings=None,
):
self.model_manager = model_manager
self.events = events
@ -51,8 +63,13 @@ class InvocationServices:
self.images = images
self.metadata = metadata
self.queue = queue
self.images_db = images_db
self.images_new = images_new
self.urls = urls
self.graph_library = graph_library
self.graph_execution_manager = graph_execution_manager
self.processor = processor
self.restoration = restoration
self.configuration = configuration

View File

@ -22,16 +22,24 @@ class MetadataLatentsField(TypedDict):
class MetadataColorField(TypedDict):
"""Pydantic-less ColorField, used for metadata parsing"""
r: int
g: int
b: int
a: int
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
NodeMetadata = Dict[
str, None | str | int | float | bool | MetadataImageField | MetadataLatentsField | MetadataColorField
str,
None
| str
| int
| float
| bool
| MetadataImageField
| MetadataLatentsField
| MetadataColorField,
]
@ -67,6 +75,11 @@ class MetadataServiceBase(ABC):
"""Builds an InvokeAIMetadata object"""
pass
@abstractmethod
def create_metadata(self, session_id: str, node_id: str) -> dict:
"""Creates metadata for a result"""
pass
class PngMetadataService(MetadataServiceBase):
"""Handles loading and building metadata for images."""

View File

@ -0,0 +1,29 @@
import datetime
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field
from invokeai.app.models.metadata import (
GeneratedImageOrLatentsMetadata,
UploadedImageOrLatentsMetadata,
)
from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.models.resources import ResourceType
class ImageRecord(BaseModel):
"""Deserialized image record."""
image_name: str = Field(description="The name of the image.")
image_type: ImageType = Field(description="The type of the image.")
image_category: ImageCategory = Field(description="The category of the image.")
created_at: Union[datetime.datetime, str] = Field(
description="The created timestamp of the image."
)
session_id: Optional[str] = Field(default=None, description="The session ID.")
node_id: Optional[str] = Field(default=None, description="The node ID.")
metadata: Optional[
Union[GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata]
] = Field(default=None, description="The image's metadata.")
image_url: Optional[str] = Field(default=None, description="The URL of the image.")
thumbnail_url: Optional[str] = Field(
default=None, description="The thumbnail URL of the image."
)

View File

@ -0,0 +1,657 @@
from abc import ABC, abstractmethod
from enum import Enum
import enum
import sqlite3
import threading
from typing import Optional, Type, TypeVar, Union
from PIL.Image import Image as PILImage
from pydantic import BaseModel, Field
from torch import Tensor
from invokeai.app.services.item_storage import PaginatedResults
"""
Substantial proposed changes to the management of images and tensors.
tl;dr:
With the upcoming move to latents-only nodes, we need to handle metadata differently. After struggling with this unsuccessfully - trying to smoosh it into the existing setup - I believe we need to expand the scope of the refactor to include the management of images and latents - and make `latents` a special case of `tensor`.
full story:
The consensus for tensor-only nodes' metadata was to traverse the execution graph and grab the core parameters to write to the image. This was straightforward, and I've written functions to find the nearest t2l/l2l, noise, and compel nodes and build the metadata from those.
But struggling to integrate this, and the edge cases it brought up, surfaced a number of issues deeper in the system (some of them in code I had previously implemented). The ImageStorageService is doing way too much, and we need to be able to retrieve the session for a given image/latents id, which is not currently feasible due to SQLite's JSON parsing performance.
I made a new ResultsService and `results` table in the db to facilitate this. This first attempt failed because it didn't handle uploads and left the codebase messy.
So I've spent the day trying to figure out how to handle this in a sane way, and I think I've got something decent. I've described some changes to service bases and the database below.
The gist of it is to store the core parameters for an image in its metadata when the image is saved, but never to read from it. Instead, the same metadata is stored in the database, which will be set up for efficient access. So when a page of images is requested, the metadata comes from the db instead of a filesystem operation.
The URL generation responsibilities have been split off the image storage service into a URL service. New database services/tables for images and tensors are added. These services will provide paginated images/tensors for the API to serve. This also paves the way for handling tensors as first-class outputs. (A worked example of how the pieces compose follows the image service bases below.)
"""
# TODO: Make a new model for this
class ResourceOrigin(str, Enum):
"""The origin of a resource (eg image or tensor)."""
RESULTS = "results"
UPLOADS = "uploads"
INTERMEDIATES = "intermediates"
class ImageKind(str, Enum):
"""The kind of an image."""
IMAGE = "image"
CONTROL_IMAGE = "control_image"
class TensorKind(str, Enum):
"""The kind of a tensor."""
IMAGE_TENSOR = "tensor"
CONDITIONING = "conditioning"
"""
Core Generation Metadata Pydantic Model
I've already implemented the code to traverse a session to build this object.
"""
class CoreGenerationMetadata(BaseModel):
"""Core generation metadata for an image/tensor generated in InvokeAI.
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node.
Full metadata may be accessed by querying for the session in the `graph_executions` table.
"""
positive_conditioning: Optional[str] = Field(
description="The positive conditioning."
)
negative_conditioning: Optional[str] = Field(
description="The negative conditioning."
)
width: Optional[int] = Field(description="Width of the image/tensor in pixels.")
height: Optional[int] = Field(description="Height of the image/tensor in pixels.")
seed: Optional[int] = Field(description="The seed used for noise generation.")
cfg_scale: Optional[float] = Field(
description="The classifier-free guidance scale."
)
steps: Optional[int] = Field(description="The number of steps used for inference.")
scheduler: Optional[str] = Field(description="The scheduler used for inference.")
model: Optional[str] = Field(description="The model used for inference.")
strength: Optional[float] = Field(
description="The strength used for image-to-image/tensor-to-tensor."
)
image: Optional[str] = Field(description="The ID of the initial image.")
tensor: Optional[str] = Field(description="The ID of the initial tensor.")
# Pending model refactor:
# vae: Optional[str] = Field(description="The VAE used for decoding.")
    # unet: Optional[str] = Field(description="The UNet used for inference.")
# clip: Optional[str] = Field(description="The CLIP Encoder used for conditioning.")
"""
Minimal Uploads Metadata Model
"""
class UploadsMetadata(BaseModel):
"""Limited metadata for an uploaded image/tensor."""
width: Optional[int] = Field(description="Width of the image/tensor in pixels.")
height: Optional[int] = Field(description="Height of the image/tensor in pixels.")
# The extra field will be the contents of the PNG file's tEXt chunk. It may have come
# from another SD application or InvokeAI, so we need to make it very flexible. I think it's
# best to just store it as a string and let the frontend parse it.
# If the upload is a tensor type, this will be omitted.
extra: Optional[str] = Field(
description="Extra metadata, extracted from the PNG tEXt chunk."
)
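# A sketch of how `extra` might be populated at upload time. This helper is
# hypothetical and not part of the proposal; PIL exposes a PNG's decoded
# tEXt/iTXt chunks via the `text` attribute on PNG images.
def _example_extract_extra_metadata(pil_image: PILImage) -> Optional[str]:
    # Non-PNG images may not have a `text` attribute, so probe defensively
    text_chunks = getattr(pil_image, "text", None)
    if not text_chunks:
        return None
    # Serialize the raw chunks as one string; the frontend decides how to parse it
    import json

    return json.dumps(text_chunks)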
"""
Slimmed-down Image Storage Service Base
- No longer lists images or generates URLs - only stores and retrieves images.
- OSS implementation for disk storage
"""
class ImageStorageBase(ABC):
"""Responsible for storing and retrieving images."""
@abstractmethod
def save(
self,
image: PILImage,
image_kind: ImageKind,
origin: ResourceOrigin,
context_id: str,
node_id: str,
metadata: CoreGenerationMetadata,
) -> str:
"""Saves an image and its thumbnail, returning its unique identifier."""
pass
@abstractmethod
def get(self, id: str, thumbnail: bool = False) -> Union[PILImage, None]:
"""Retrieves an image as a PIL Image."""
pass
@abstractmethod
def delete(self, id: str) -> None:
"""Deletes an image."""
pass
class TensorStorageBase(ABC):
"""Responsible for storing and retrieving tensors."""
@abstractmethod
def save(
self,
tensor: Tensor,
tensor_kind: TensorKind,
origin: ResourceOrigin,
context_id: str,
node_id: str,
metadata: CoreGenerationMetadata,
) -> str:
"""Saves a tensor, returning its unique identifier."""
pass
@abstractmethod
def get(self, id: str, thumbnail: bool = False) -> Union[Tensor, None]:
"""Retrieves a tensor as a torch Tensor."""
pass
@abstractmethod
def delete(self, id: str) -> None:
"""Deletes a tensor."""
pass
"""
New Url Service Base
- Abstracts the logic for generating URLs out of the storage service
- OSS implementation for locally-hosted URLs
- Also provides a method to get the internal path to a resource (for OSS, the FS path)
"""
class ResourceLocationServiceBase(ABC):
"""Responsible for locating resources (eg images or tensors)."""
@abstractmethod
def get_url(self, id: str) -> str:
"""Gets the URL for a resource."""
pass
@abstractmethod
def get_path(self, id: str) -> str:
"""Gets the path for a resource."""
pass
"""
New Images Database Service Base
This is a new service that will be responsible for the new `images` table(s):
- Storing images in the table
- Retrieving individual images and pages of images
- Deleting individual images
Operations will typically use joins with the various `images` tables.
"""
class ImagesDbServiceBase(ABC):
"""Responsible for interfacing with `images` table."""
class GeneratedImageEntity(BaseModel):
id: str = Field(description="The unique identifier for the image.")
session_id: str = Field(description="The session ID.")
node_id: str = Field(description="The node ID.")
metadata: CoreGenerationMetadata = Field(
description="The metadata for the image."
)
class UploadedImageEntity(BaseModel):
id: str = Field(description="The unique identifier for the image.")
metadata: UploadsMetadata = Field(description="The metadata for the image.")
@abstractmethod
def get(self, id: str) -> Union[GeneratedImageEntity, UploadedImageEntity, None]:
"""Gets an image from the `images` table."""
pass
@abstractmethod
def get_many(
self, image_kind: ImageKind, page: int = 0, per_page: int = 10
) -> PaginatedResults[Union[GeneratedImageEntity, UploadedImageEntity]]:
"""Gets a page of images from the `images` table."""
pass
@abstractmethod
def delete(self, id: str) -> None:
"""Deletes an image from the `images` table."""
pass
@abstractmethod
def set(
self,
id: str,
image_kind: ImageKind,
session_id: Optional[str],
node_id: Optional[str],
metadata: CoreGenerationMetadata | UploadsMetadata,
) -> None:
"""Sets an image in the `images` table."""
pass
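# --- Illustrative only ---
# A sketch of how the service bases above are intended to compose when a
# result image is saved. The function and its wiring are hypothetical; only
# the ABCs are part of this proposal.
def _example_save_result_image(
    image: PILImage,
    session_id: str,
    node_id: str,
    metadata: CoreGenerationMetadata,
    storage: ImageStorageBase,
    images_db: ImagesDbServiceBase,
    urls: ResourceLocationServiceBase,
) -> str:
    # Write the image (and thumbnail) to storage; the metadata is embedded in
    # the file here, but is never read back from it afterwards
    image_id = storage.save(
        image, ImageKind.IMAGE, ResourceOrigin.RESULTS, session_id, node_id, metadata
    )
    # Record the same metadata in the db, which serves all future reads
    images_db.set(image_id, ImageKind.IMAGE, session_id, node_id, metadata)
    # URL generation is the URL service's job, not the storage service's
    return urls.get_url(image_id)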
"""
New Tensor Database Service Base
This is a new service that will be responsible for the new `tensors` table:
- Storing tensors in the table
- Retrieving individual tensors and pages of tensors
- Deleting individual tensors
Operations will always use joins with the `tensors_metadata` table.
"""
class TensorDbServiceBase(ABC):
    """Responsible for interfacing with the `tensors` table."""
class GeneratedTensorEntity(BaseModel):
id: str = Field(description="The unique identifier for the tensor.")
session_id: str = Field(description="The session ID.")
node_id: str = Field(description="The node ID.")
metadata: CoreGenerationMetadata = Field(
description="The metadata for the tensor."
)
class UploadedTensorEntity(BaseModel):
id: str = Field(description="The unique identifier for the tensor.")
metadata: UploadsMetadata = Field(description="The metadata for the tensor.")
@abstractmethod
def get(self, id: str) -> Union[GeneratedTensorEntity, UploadedTensorEntity, None]:
"""Gets a tensor from the `tensor` table."""
pass
@abstractmethod
def get_many(
self, tensor_kind: TensorKind, page: int = 0, per_page: int = 10
) -> PaginatedResults[Union[GeneratedTensorEntity, UploadedTensorEntity]]:
"""Gets a page of tensor from the `tensor` table."""
pass
@abstractmethod
def delete(self, id: str) -> None:
"""Deletes a tensor from the `tensor` table."""
pass
@abstractmethod
def set(
self,
id: str,
tensor_kind: TensorKind,
session_id: Optional[str],
node_id: Optional[str],
metadata: CoreGenerationMetadata | UploadsMetadata,
) -> None:
"""Sets a tensor in the `tensor` table."""
pass
"""
Database Changes
The existing tables will remain as-is, new tables will be added.
Tensors now also have the same types as images - `results`, `intermediates`, `uploads`. Storage, retrieval, and operations may diverge from images in the future, so they are managed separately.
A few `images` tables are created to store all images:
- `results` and `intermediates` images have additional data: `session_id` and `node_id`, and may be further differentiated in the future. For this reason, they each get their own table.
- `uploads` do not get their own table, as they are never going to have more than an `id`, `image_kind` and `timestamp`.
- `images_metadata` holds the same image metadata that is written to the image. This table, along with the URL service, allow us to more efficiently serve images without having to read the image from storage.
The same tables are made for `tensors`, and for the moment the implementation is expected to be identical.
Schemas for each table below.
Insertions and updates of ancillary tables (e.g. `results_images`, `images_metadata`, etc) will need to be done manually in the services, but should be straightforward. Deletion via cascading will be handled by the database.
"""
def create_sql_values_string_from_string_enum(enum: Type[Enum]):
"""
Creates a string of the form "('value1'), ('value2'), ..., ('valueN')" from a StrEnum.
"""
delimiter = ", "
values = [f"('{e.value}')" for e in enum]
return delimiter.join(values)
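# For example, create_sql_values_string_from_string_enum(ResourceOrigin) returns
# "('results'), ('uploads'), ('intermediates')", ready for an INSERT ... VALUES.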
def create_sql_table_from_enum(
enum: Type[Enum],
table_name: str,
primary_key_name: str,
cursor: sqlite3.Cursor,
lock: threading.Lock,
):
"""
Creates and populates a table to be used as a functional enum.
"""
try:
lock.acquire()
values_string = create_sql_values_string_from_string_enum(enum)
cursor.execute(
f"""--sql
CREATE TABLE IF NOT EXISTS {table_name} (
{primary_key_name} TEXT PRIMARY KEY
);
"""
)
cursor.execute(
f"""--sql
INSERT OR IGNORE INTO {table_name} ({primary_key_name}) VALUES {values_string};
"""
)
finally:
lock.release()
"""
`resource_origins` functions as an enum for the ResourceOrigin model.
"""
def create_resource_origins_table(cursor: sqlite3.Cursor, lock: threading.Lock):
create_sql_table_from_enum(
enum=ResourceOrigin,
table_name="resource_origins",
primary_key_name="origin_name",
cursor=cursor,
lock=lock,
)
"""
`image_kinds` functions as an enum for the ImageKind model.
"""
def create_image_kinds_table(cursor: sqlite3.Cursor, lock: threading.Lock):
create_sql_table_from_enum(
enum=ImageKind,
table_name="image_kinds",
primary_key_name="kind_name",
cursor=cursor,
lock=lock,
)
"""
`tensor_kinds` functions as an enum for the TensorKind model.
"""
def create_tensor_kinds_table(cursor: sqlite3.Cursor, lock: threading.Lock):
create_sql_table_from_enum(
enum=TensorKind,
table_name="tensor_kinds",
primary_key_name="kind_name",
cursor=cursor,
lock=lock,
)
"""
`images` stores all images, regardless of type
"""
def create_images_table(cursor: sqlite3.Cursor, lock: threading.Lock):
try:
lock.acquire()
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS images (
id TEXT PRIMARY KEY,
origin TEXT,
image_kind TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),
FOREIGN KEY(image_kind) REFERENCES image_kinds(kind_name)
);
"""
)
cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_id ON images(id);
"""
)
cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_origin ON images(origin);
"""
)
cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_kind ON images(image_kind);
"""
)
finally:
lock.release()
"""
`image_results` stores additional data specific to `results` images.
"""
def create_image_results_table(cursor: sqlite3.Cursor, lock: threading.Lock):
try:
lock.acquire()
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS image_results (
images_id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
node_id TEXT NOT NULL,
FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE
);
"""
)
cursor.execute(
"""--sql
            CREATE UNIQUE INDEX IF NOT EXISTS idx_image_results_images_id ON image_results(images_id);
"""
)
finally:
lock.release()
"""
`image_intermediates` stores additional data specific to `intermediates` images
"""
def create_image_intermediates_table(cursor: sqlite3.Cursor, lock: threading.Lock):
try:
lock.acquire()
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS image_intermediates (
images_id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
node_id TEXT NOT NULL,
FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE
);
"""
)
cursor.execute(
"""--sql
            CREATE UNIQUE INDEX IF NOT EXISTS idx_image_intermediates_images_id ON image_intermediates(images_id);
"""
)
finally:
lock.release()
"""
`images_metadata` stores basic metadata for any image type
"""
def create_images_metadata_table(cursor: sqlite3.Cursor, lock: threading.Lock):
try:
lock.acquire()
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS images_metadata (
images_id TEXT PRIMARY KEY,
metadata TEXT,
FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE
);
"""
)
cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_metadata_images_id ON images_metadata(images_id);
"""
)
finally:
lock.release()
# `tensors` stores all tensors, regardless of type
def create_tensors_table(cursor: sqlite3.Cursor, lock: threading.Lock):
try:
lock.acquire()
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS tensors (
id TEXT PRIMARY KEY,
origin TEXT,
tensor_kind TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),
                FOREIGN KEY(tensor_kind) REFERENCES tensor_kinds(kind_name)
);
"""
)
cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_id ON tensors(id);
"""
)
cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_tensors_origin ON tensors(origin);
"""
)
cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_tensors_tensor_kind ON tensors(tensor_kind);
"""
)
finally:
lock.release()
# `tensor_results` stores additional data specific to `results` tensors
def create_tensor_results_table(cursor: sqlite3.Cursor, lock: threading.Lock):
try:
lock.acquire()
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS tensor_results (
tensor_id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
node_id TEXT NOT NULL,
FOREIGN KEY(tensor_id) REFERENCES tensors(id) ON DELETE CASCADE
);
"""
)
cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_tensor_results_tensor_id ON tensor_results(tensor_id);
"""
)
finally:
lock.release()
# `tensor_intermediates` stores additional data specific to `intermediates` tensors
def create_tensor_intermediates_table(cursor: sqlite3.Cursor, lock: threading.Lock):
try:
lock.acquire()
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS tensor_intermediates (
tensor_id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
node_id TEXT NOT NULL,
FOREIGN KEY(tensor_id) REFERENCES tensors(id) ON DELETE CASCADE
);
"""
)
cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_tensor_intermediates_tensor_id ON tensor_intermediates(tensor_id);
"""
)
finally:
lock.release()
# `tensors_metadata` table: stores generated/transformed metadata for tensors
def create_tensors_metadata_table(cursor: sqlite3.Cursor, lock: threading.Lock):
try:
lock.acquire()
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS tensors_metadata (
tensor_id TEXT PRIMARY KEY,
metadata TEXT,
FOREIGN KEY(tensor_id) REFERENCES tensors(id) ON DELETE CASCADE
);
"""
)
cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_metadata_tensor_id ON tensors_metadata(tensor_id);
"""
)
finally:
lock.release()


@ -0,0 +1,466 @@
from enum import Enum
from abc import ABC, abstractmethod
import json
import sqlite3
from threading import Lock
from typing import Any, TypedDict, Union
import networkx as nx
from networkx import DiGraph
from pydantic import BaseModel, Field, parse_obj_as, parse_raw_as
from invokeai.app.invocations.image import ImageOutput
from invokeai.app.services.graph import Edge, GraphExecutionState
from invokeai.app.invocations.latent import LatentsOutput
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.util.misc import get_timestamp
class ResultType(str, Enum):
image_output = "image_output"
latents_output = "latents_output"
class Result(BaseModel):
"""A session result"""
id: str = Field(description="Result ID")
session_id: str = Field(description="Session ID")
node_id: str = Field(description="Node ID")
data: Union[LatentsOutput, ImageOutput] = Field(description="The result data")
class ResultWithSession(BaseModel):
"""A result with its session"""
result: Result = Field(description="The result")
session: GraphExecutionState = Field(description="The session")
# We need to use a loose class for nodes to allow for graceful parsing - we cannot use the stricter
# model used by the system, because the graph may be in an old format. We can, however, use the
# Edge model, because the edge format does not change.
class LooseGraph(BaseModel):
id: str
nodes: dict[str, dict[str, Any]]
edges: list[Edge]
# An intermediate type used during parsing
class NearestAncestor(TypedDict):
node_id: str
metadata: dict[str, Any]
# The ancestor types that contain the core metadata
ANCESTOR_TYPES = ['t2l', 'l2l']
# The core metadata parameters in the ancestor types
ANCESTOR_PARAMS = ['steps', 'model', 'cfg_scale', 'scheduler', 'strength']
# The core metadata parameters in the noise node
NOISE_FIELDS = ['seed', 'width', 'height']
# Find nearest t2l or l2l ancestor from a given l2i node
def find_nearest_ancestor(G: DiGraph, node_id: str) -> Union[NearestAncestor, None]:
"""Returns metadata for the nearest ancestor of a given node.
Parameters:
G (DiGraph): A directed graph.
node_id (str): The ID of the starting node.
Returns:
NearestAncestor | None: An object with the ID and metadata of the nearest ancestor.
"""
# Retrieve the node from the graph
node = G.nodes[node_id]
# If the node type is one of the core metadata node types, gather necessary metadata and return
if node.get('type') in ANCESTOR_TYPES:
parsed_metadata = {param: val for param, val in node.items() if param in ANCESTOR_PARAMS}
return NearestAncestor(node_id=node_id, metadata=parsed_metadata)
# Else, look for the ancestor in the predecessor nodes
for predecessor in G.predecessors(node_id):
result = find_nearest_ancestor(G, predecessor)
if result:
return result
# If there are no valid ancestors, return None
return None
def get_additional_metadata(graph: LooseGraph, node_id: str) -> Union[dict[str, Any], None]:
"""Collects additional metadata from nodes connected to a given node.
Parameters:
graph (LooseGraph): The graph.
node_id (str): The ID of the node.
Returns:
dict | None: A dictionary containing additional metadata.
"""
metadata = {}
# Iterate over all edges in the graph
for edge in graph.edges:
dest_node_id = edge.destination.node_id
dest_field = edge.destination.field
source_node = graph.nodes[edge.source.node_id]
# If the destination node ID matches the given node ID, gather necessary metadata
if dest_node_id == node_id:
# If the destination field is 'positive_conditioning', add the 'prompt' from the source node
if dest_field == 'positive_conditioning':
metadata['positive_conditioning'] = source_node.get('prompt')
# If the destination field is 'negative_conditioning', add the 'prompt' from the source node
if dest_field == 'negative_conditioning':
metadata['negative_conditioning'] = source_node.get('prompt')
# If the destination field is 'noise', add the core noise fields from the source node
if dest_field == 'noise':
for field in NOISE_FIELDS:
metadata[field] = source_node.get(field)
return metadata
def build_core_metadata(graph_raw: dict[str, Any], node_id: str) -> Union[dict, None]:
"""Builds the core metadata for a given node.
Parameters:
    graph_raw (dict): The raw graph structure, as parsed from the session's JSON.
node_id (str): The ID of the node.
Returns:
dict | None: A dictionary containing core metadata.
"""
# Create a directed graph to facilitate traversal
G = nx.DiGraph()
    # Parse the raw graph dict into the loose model (tolerates old graph formats)
graph = parse_obj_as(LooseGraph, graph_raw)
    # Add nodes and edges to the graph (a distinct loop variable avoids
    # shadowing the `node_id` argument, which is used again below)
    for g_node_id, node_data in graph.nodes.items():
        G.add_node(g_node_id, **node_data)
for edge in graph.edges:
G.add_edge(edge.source.node_id, edge.destination.node_id)
# Find the nearest ancestor of the given node
ancestor = find_nearest_ancestor(G, node_id)
# If no ancestor was found, return None
if ancestor is None:
return None
metadata = ancestor['metadata']
ancestor_id = ancestor['node_id']
# Get additional metadata related to the ancestor
addl_metadata = get_additional_metadata(graph, ancestor_id)
# If additional metadata was found, add it to the main metadata
if addl_metadata is not None:
metadata.update(addl_metadata)
return metadata
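# Example usage - a sketch with a hypothetical minimal graph. The node ID,
# type, and parameter values below are illustrative only.
def _example_build_core_metadata() -> Union[dict, None]:
    graph_raw = {
        "id": "example-graph",
        "nodes": {
            "t2l_1": {
                "id": "t2l_1",
                "type": "t2l",
                "steps": 30,
                "cfg_scale": 7.5,
                "scheduler": "k_lms",
                "model": "stable-diffusion-1.5",
            },
        },
        "edges": [],  # conditioning/noise edges omitted for brevity
    }
    # Starting at the t2l node itself, the traversal returns its parameters:
    # {'steps': 30, 'cfg_scale': 7.5, 'scheduler': 'k_lms', 'model': 'stable-diffusion-1.5'}
    return build_core_metadata(graph_raw, "t2l_1")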
class ResultsServiceABC(ABC):
"""The Results service is responsible for retrieving results."""
@abstractmethod
def get(
self, result_id: str, result_type: ResultType
) -> Union[ResultWithSession, None]:
pass
@abstractmethod
def get_many(
self, result_type: ResultType, page: int = 0, per_page: int = 10
) -> PaginatedResults[ResultWithSession]:
pass
@abstractmethod
def search(
self, query: str, page: int = 0, per_page: int = 10
) -> PaginatedResults[ResultWithSession]:
pass
@abstractmethod
def handle_graph_execution_state_change(self, session: GraphExecutionState) -> None:
pass
class SqliteResultsService(ResultsServiceABC):
"""SQLite implementation of the Results service."""
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: Lock
def __init__(self, filename: str):
super().__init__()
self._filename = filename
self._lock = Lock()
self._conn = sqlite3.connect(
self._filename, check_same_thread=False
) # TODO: figure out a better threading solution
self._cursor = self._conn.cursor()
self._create_table()
def _create_table(self):
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS results (
id TEXT PRIMARY KEY, -- the result's name
result_type TEXT, -- `image_output` | `latents_output`
node_id TEXT, -- the node that produced this result
session_id TEXT, -- the session that produced this result
created_at INTEGER, -- the time at which this result was created
data TEXT -- the result itself
);
"""
)
self._cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_result_id ON results(id);
"""
)
finally:
self._lock.release()
def _parse_joined_result(self, result_row: Any, column_names: list[str]):
result_raw = {}
session_raw = {}
for idx, name in enumerate(column_names):
if name == "session":
session_raw = json.loads(result_row[idx])
elif name == "data":
result_raw[name] = json.loads(result_row[idx])
else:
result_raw[name] = result_row[idx]
graph_raw = session_raw['execution_graph']
result = parse_obj_as(Result, result_raw)
session = parse_obj_as(GraphExecutionState, session_raw)
        # TODO: attach the built core metadata to the parsed result once the
        # Result model can carry it; for now it is computed but unused
        core_metadata = build_core_metadata(graph_raw, result.node_id)
return ResultWithSession(
result=result,
session=session,
)
def get(
self, result_id: str, result_type: ResultType
) -> Union[ResultWithSession, None]:
"""Retrieves a result by ID and type."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT
results.id AS id,
results.result_type AS result_type,
results.node_id AS node_id,
results.session_id AS session_id,
results.data AS data,
graph_executions.item AS session
FROM results
JOIN graph_executions ON results.session_id = graph_executions.id
WHERE results.id = ? AND results.result_type = ?
""",
            (result_id, result_type.value),
)
result_row = self._cursor.fetchone()
if result_row is None:
return None
column_names = list(map(lambda x: x[0], self._cursor.description))
result_parsed = self._parse_joined_result(result_row, column_names)
finally:
self._lock.release()
if not result_parsed:
return None
return result_parsed
def get_many(
self,
result_type: ResultType,
page: int = 0,
per_page: int = 10,
) -> PaginatedResults[ResultWithSession]:
"""Lists results of a given type."""
try:
self._lock.acquire()
self._cursor.execute(
f"""--sql
SELECT
results.id AS id,
results.result_type AS result_type,
results.node_id AS node_id,
results.session_id AS session_id,
results.data AS data,
graph_executions.item AS session
FROM results
JOIN graph_executions ON results.session_id = graph_executions.id
WHERE results.result_type = ?
LIMIT ? OFFSET ?;
""",
(result_type.value, per_page, page * per_page),
)
result_rows = self._cursor.fetchall()
column_names = list(map(lambda c: c[0], self._cursor.description))
result_parsed = []
for result_row in result_rows:
result_parsed.append(
self._parse_joined_result(result_row, column_names)
)
self._cursor.execute("""SELECT count(*) FROM results;""")
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
        pageCount = (count + per_page - 1) // per_page  # ceiling division
return PaginatedResults[ResultWithSession](
items=result_parsed,
page=page,
pages=pageCount,
per_page=per_page,
total=count,
)
def search(
self,
query: str,
page: int = 0,
per_page: int = 10,
) -> PaginatedResults[ResultWithSession]:
"""Finds results by query."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT results.data, graph_executions.item
FROM results
JOIN graph_executions ON results.session_id = graph_executions.id
                WHERE graph_executions.item LIKE ?
LIMIT ? OFFSET ?;
""",
(f"%{query}%", per_page, page * per_page),
)
result_rows = self._cursor.fetchall()
items = list(
map(
lambda r: ResultWithSession(
result=parse_raw_as(Result, r[0]),
session=parse_raw_as(GraphExecutionState, r[1]),
),
result_rows,
)
)
self._cursor.execute(
"""--sql
                SELECT count(*) FROM results
                JOIN graph_executions ON results.session_id = graph_executions.id
                WHERE graph_executions.item LIKE ?;
""",
(f"%{query}%",),
)
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
        pageCount = (count + per_page - 1) // per_page  # ceiling division
return PaginatedResults[ResultWithSession](
items=items, page=page, pages=pageCount, per_page=per_page, total=count
)
def handle_graph_execution_state_change(self, session: GraphExecutionState) -> None:
"""Updates the results table with the results from the session."""
with self._conn as conn:
for node_id, result in session.results.items():
# We'll only process 'image_output' or 'latents_output'
if result.type not in ["image_output", "latents_output"]:
continue
# The id depends on the result type
if result.type == "image_output":
id = result.image.image_name
result_type = "image_output"
else:
id = result.latents.latents_name
result_type = "latents_output"
# Insert the result into the results table, ignoring if it already exists
conn.execute(
"""--sql
INSERT OR IGNORE INTO results (id, result_type, node_id, session_id, created_at, data)
VALUES (?, ?, ?, ?, ?, ?)
""",
(
                        result_id,
result_type,
node_id,
session.id,
get_timestamp(),
result.json(),
),
)
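# Example usage - a sketch; `session` is a completed GraphExecutionState:
#
#   results_service = SqliteResultsService(filename="invokeai.db")
#   results_service.handle_graph_execution_state_change(session)
#   page = results_service.get_many(ResultType.image_output, page=0, per_page=10)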


@ -0,0 +1,32 @@
import os
from abc import ABC, abstractmethod
from invokeai.app.models.image import ImageType
from invokeai.app.util.thumbnails import get_thumbnail_name
class UrlServiceBase(ABC):
"""Responsible for building URLs for resources (eg images or tensors)"""
@abstractmethod
def get_image_url(self, image_type: ImageType, image_name: str) -> str:
"""Gets the URL for an image"""
pass
@abstractmethod
def get_thumbnail_url(self, image_type: ImageType, image_name: str) -> str:
"""Gets the URL for an image's thumbnail"""
pass
class LocalUrlService(UrlServiceBase):
def __init__(self, base_url: str = "api/v1"):
self._base_url = base_url
def get_image_url(self, image_type: ImageType, image_name: str) -> str:
image_basename = os.path.basename(image_name)
return f"{self._base_url}/images/{image_type.value}/{image_basename}"
def get_thumbnail_url(self, image_type: ImageType, image_name: str) -> str:
thumbnail_basename = get_thumbnail_name(os.path.basename(image_name))
return f"{self._base_url}/images/{image_type.value}/thumbnails/{thumbnail_basename}"


@ -0,0 +1,39 @@
from enum import Enum
import sqlite3
from typing import Type
def create_sql_values_string_from_string_enum(enum: Type[Enum]):
"""
Creates a string of the form "('value1'), ('value2'), ..., ('valueN')" from a StrEnum.
"""
delimiter = ", "
values = [f"('{e.value}')" for e in enum]
return delimiter.join(values)
def create_enum_table(
enum: Type[Enum],
table_name: str,
primary_key_name: str,
cursor: sqlite3.Cursor,
):
"""
Creates and populates a table to be used as a functional enum.
"""
values_string = create_sql_values_string_from_string_enum(enum)
cursor.execute(
f"""--sql
CREATE TABLE IF NOT EXISTS {table_name} (
{primary_key_name} TEXT PRIMARY KEY
);
"""
)
cursor.execute(
f"""--sql
INSERT OR IGNORE INTO {table_name} ({primary_key_name}) VALUES {values_string};
"""
)
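# Example usage - a sketch with a hypothetical enum and cursor:
#
#   class Fruit(str, Enum):
#       APPLE = "apple"
#       PEAR = "pear"
#
#   create_enum_table(Fruit, table_name="fruits", primary_key_name="name", cursor=cur)
#
# This yields a `fruits` table whose rows are exactly ('apple') and ('pear');
# foreign keys referencing it then act like a CHECK constraint on the enum.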


@ -0,0 +1,33 @@
from invokeai.app.models.metadata import (
GeneratedImageOrLatentsMetadata,
UploadedImageOrLatentsMetadata,
)
from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.services.models.image_record import ImageRecord
from invokeai.app.util.misc import get_iso_timestamp
def deserialize_image_record(image: dict) -> ImageRecord:
"""Deserializes an image record."""
# All values *should* be present, except `session_id` and `node_id`, but provide some defaults just in case
image_type = ImageType(image.get("image_type", ImageType.RESULT.value))
raw_metadata = image.get("metadata", {})
if image_type == ImageType.UPLOAD:
metadata = UploadedImageOrLatentsMetadata.parse_obj(raw_metadata)
else:
metadata = GeneratedImageOrLatentsMetadata.parse_obj(raw_metadata)
return ImageRecord(
image_name=image.get("id", "unknown"),
session_id=image.get("session_id", None),
node_id=image.get("node_id", None),
metadata=metadata,
image_type=image_type,
image_category=ImageCategory(
image.get("image_category", ImageCategory.IMAGE.value)
),
created_at=image.get("created_at", get_iso_timestamp()),
)
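# Example - a sketch deserializing a row as it might come back from sqlite.
# The values are illustrative, and ImageType.RESULT is assumed to have the
# value "results":
#
#   record = deserialize_image_record(
#       {
#           "id": "abc123.png",
#           "image_type": "results",
#           "image_category": "image",
#           "session_id": "sess-1",
#           "node_id": "node-1",
#           "metadata": {"width": 512, "height": 512},
#           "created_at": "2023-05-17T09:00:00",
#       }
#   )
#   record.image_name  # -> "abc123.png"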

invokeai/app/util/enum.py

@ -0,0 +1,12 @@
from enum import EnumMeta
class MetaEnum(EnumMeta):
"""Metaclass to support `in` syntax value checking in String Enums"""
def __contains__(cls, item):
try:
cls(item)
except ValueError:
return False
return True
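# Example usage - a sketch:
#
#   class Color(str, Enum, metaclass=MetaEnum):
#       RED = "red"
#
#   "red" in Color   # -> True
#   "blue" in Color  # -> False, instead of raising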


@ -6,6 +6,14 @@ def get_timestamp():
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
def get_iso_timestamp() -> str:
return datetime.datetime.utcnow().isoformat()
def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
return datetime.datetime.fromisoformat(iso_timestamp)
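# Example - the two helpers round-trip:
#
#   ts = get_iso_timestamp()                  # e.g. "2023-05-17T09:13:53.123456"
#   dt = get_datetime_from_iso_timestamp(ts)
#   assert dt.isoformat() == ts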
SEED_MAX = np.iinfo(np.int32).max


@ -8,7 +8,7 @@ import type { ImageField } from './ImageField';
* Base class for invocations that output an image
*/
export type ImageOutput = {
type: 'image';
type: 'image_output';
/**
* The output image
*/


@ -11,5 +11,13 @@ export type RandomIntInvocation = {
*/
id: string;
type?: 'rand_int';
/**
* The inclusive low value
*/
low?: number;
/**
* The exclusive high value
*/
high?: number;
};


@ -12,5 +12,13 @@ export const $RandomIntInvocation = {
type: {
type: 'Enum',
},
low: {
type: 'number',
description: `The inclusive low value`,
},
high: {
type: 'number',
description: `The exclusive high value`,
},
},
} as const;


@ -10,11 +10,16 @@ import {
CollectInvocationOutput,
ImageType,
ImageField,
LatentsOutput,
} from 'services/api';
export const isImageOutput = (
output: GraphExecutionState['results'][string]
): output is ImageOutput => output.type === 'image';
): output is ImageOutput => output.type === 'image_output';
export const isLatentsOutput = (
output: GraphExecutionState['results'][string]
): output is LatentsOutput => output.type === 'latents_output';
export const isMaskOutput = (
output: GraphExecutionState['results'][string]