feat(nodes): address feedback

- Address database feedback:
  - Remove all the extraneous tables. Only an `images` table now:
  - `image_type` and `image_category` are unrestricted strings. When creating images, the provided values are checked to ensure they are a valid type and category.
  - Add `updated_at` and `deleted_at` columns. `deleted_at` is currently unused.
  - Use SQLite's built-in timestamp features to populate these. Add a trigger to update `updated_at` when the row is updated. Currently no way to update a row.
  - Rename the `id` column in `images` to `image_name`
- Rename `ImageCategory.IMAGE` to `ImageCategory.GENERAL`
- Move all exceptions outside their base classes to make them more portable.
- Add `width` and `height` columns to the database. These store the actual dimensions of the image file, whereas the metadata's `width` and `height` refer to the respective generation parameters and are nullable.
- Make `deserialize_image_record` take a `dict` instead of `sqlite3.Row`
- Improve comments throughout
- Tidy up unused code/files and some minor organisation
This commit is contained in:
psychedelicious
2023-05-23 18:59:43 +10:00
committed by Kent Keirsey
parent 021e5a2aa3
commit 035425ef24
12 changed files with 273 additions and 846 deletions

View File

@ -30,7 +30,7 @@ async def upload_image(
image_type: ImageType, image_type: ImageType,
request: Request, request: Request,
response: Response, response: Response,
image_category: ImageCategory = ImageCategory.IMAGE, image_category: ImageCategory = ImageCategory.GENERAL,
) -> ImageDTO: ) -> ImageDTO:
"""Uploads an image""" """Uploads an image"""
if not file.content_type.startswith("image"): if not file.content_type.startswith("image"):

View File

@ -95,7 +95,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
image_dto = context.services.images_new.create( image_dto = context.services.images_new.create(
image=generate_output.image, image=generate_output.image,
image_type=ImageType.RESULT, image_type=ImageType.RESULT,
image_category=ImageCategory.IMAGE, image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
node_id=self.id, node_id=self.id,
) )
@ -119,7 +119,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# context.services.images_db.set( # context.services.images_db.set(
# id=image_name, # id=image_name,
# image_type=ImageType.RESULT, # image_type=ImageType.RESULT,
# image_category=ImageCategory.IMAGE, # image_category=ImageCategory.GENERAL,
# session_id=context.graph_execution_state_id, # session_id=context.graph_execution_state_id,
# node_id=self.id, # node_id=self.id,
# metadata=GeneratedImageOrLatentsMetadata(), # metadata=GeneratedImageOrLatentsMetadata(),

View File

@ -372,7 +372,7 @@ class LatentsToImageInvocation(BaseInvocation):
image_dto = context.services.images_new.create( image_dto = context.services.images_new.create(
image=image, image=image,
image_type=ImageType.RESULT, image_type=ImageType.RESULT,
image_category=ImageCategory.IMAGE, image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
node_id=self.id, node_id=self.id,
) )

View File

@ -2,7 +2,7 @@ from enum import Enum
from typing import Optional, Tuple from typing import Optional, Tuple
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.util.enum import MetaEnum from invokeai.app.util.metaenum import MetaEnum
class ImageType(str, Enum, metaclass=MetaEnum): class ImageType(str, Enum, metaclass=MetaEnum):
@ -13,20 +13,32 @@ class ImageType(str, Enum, metaclass=MetaEnum):
INTERMEDIATE = "intermediates" INTERMEDIATE = "intermediates"
class InvalidImageTypeException(ValueError):
"""Raised when a provided value is not a valid ImageType.
Subclasses `ValueError`.
"""
def __init__(self, message="Invalid image type."):
super().__init__(message)
class ImageCategory(str, Enum, metaclass=MetaEnum): class ImageCategory(str, Enum, metaclass=MetaEnum):
"""The category of an image. Use ImageCategory.OTHER for non-default categories.""" """The category of an image. Use ImageCategory.OTHER for non-default categories."""
IMAGE = "image" GENERAL = "general"
CONTROL_IMAGE = "control_image" CONTROL = "control"
OTHER = "other" OTHER = "other"
def is_image_type(obj): class InvalidImageCategoryException(ValueError):
try: """Raised when a provided value is not a valid ImageCategory.
ImageType(obj)
except ValueError: Subclasses `ValueError`.
return False """
return True
def __init__(self, message="Invalid image category."):
super().__init__(message)
class ImageField(BaseModel): class ImageField(BaseModel):

View File

@ -26,50 +26,66 @@ class ImageMetadata(BaseModel):
default=None, default=None,
description="The type of the ancestor node of the image output node.", description="The type of the ancestor node of the image output node.",
) )
"""The type of the ancestor node of the image output node."""
positive_conditioning: Optional[StrictStr] = Field( positive_conditioning: Optional[StrictStr] = Field(
default=None, description="The positive conditioning." default=None, description="The positive conditioning."
) )
"""The positive conditioning"""
negative_conditioning: Optional[StrictStr] = Field( negative_conditioning: Optional[StrictStr] = Field(
default=None, description="The negative conditioning." default=None, description="The negative conditioning."
) )
"""The negative conditioning"""
width: Optional[StrictInt] = Field( width: Optional[StrictInt] = Field(
default=None, description="Width of the image/latents in pixels." default=None, description="Width of the image/latents in pixels."
) )
"""Width of the image/latents in pixels"""
height: Optional[StrictInt] = Field( height: Optional[StrictInt] = Field(
default=None, description="Height of the image/latents in pixels." default=None, description="Height of the image/latents in pixels."
) )
"""Height of the image/latents in pixels"""
seed: Optional[StrictInt] = Field( seed: Optional[StrictInt] = Field(
default=None, description="The seed used for noise generation." default=None, description="The seed used for noise generation."
) )
"""The seed used for noise generation"""
cfg_scale: Optional[StrictFloat] = Field( cfg_scale: Optional[StrictFloat] = Field(
default=None, description="The classifier-free guidance scale." default=None, description="The classifier-free guidance scale."
) )
"""The classifier-free guidance scale"""
steps: Optional[StrictInt] = Field( steps: Optional[StrictInt] = Field(
default=None, description="The number of steps used for inference." default=None, description="The number of steps used for inference."
) )
"""The number of steps used for inference"""
scheduler: Optional[StrictStr] = Field( scheduler: Optional[StrictStr] = Field(
default=None, description="The scheduler used for inference." default=None, description="The scheduler used for inference."
) )
"""The scheduler used for inference"""
model: Optional[StrictStr] = Field( model: Optional[StrictStr] = Field(
default=None, description="The model used for inference." default=None, description="The model used for inference."
) )
"""The model used for inference"""
strength: Optional[StrictFloat] = Field( strength: Optional[StrictFloat] = Field(
default=None, default=None,
description="The strength used for image-to-image/latents-to-latents.", description="The strength used for image-to-image/latents-to-latents.",
) )
"""The strength used for image-to-image/latents-to-latents."""
latents: Optional[StrictStr] = Field( latents: Optional[StrictStr] = Field(
default=None, description="The ID of the initial latents." default=None, description="The ID of the initial latents."
) )
"""The ID of the initial latents"""
vae: Optional[StrictStr] = Field( vae: Optional[StrictStr] = Field(
default=None, description="The VAE used for decoding." default=None, description="The VAE used for decoding."
) )
"""The VAE used for decoding"""
unet: Optional[StrictStr] = Field( unet: Optional[StrictStr] = Field(
default=None, description="The UNet used dor inference." default=None, description="The UNet used dor inference."
) )
"""The UNet used dor inference"""
clip: Optional[StrictStr] = Field( clip: Optional[StrictStr] = Field(
default=None, description="The CLIP Encoder used for conditioning." default=None, description="The CLIP Encoder used for conditioning."
) )
"""The CLIP Encoder used for conditioning"""
extra: Optional[StrictStr] = Field( extra: Optional[StrictStr] = Field(
default=None, default=None,
description="Uploaded image metadata, extracted from the PNG tEXt chunk.", description="Uploaded image metadata, extracted from the PNG tEXt chunk.",
) )
"""Uploaded image metadata, extracted from the PNG tEXt chunk."""

View File

@ -1,28 +0,0 @@
# TODO: Make a new model for this
from enum import Enum
from invokeai.app.util.enum import MetaEnum
class ResourceType(str, Enum, metaclass=MetaEnum):
"""The type of a resource."""
IMAGES = "images"
TENSORS = "tensors"
# class ResourceOrigin(str, Enum, metaclass=MetaEnum):
# """The origin of a resource (eg image or tensor)."""
# RESULTS = "results"
# UPLOADS = "uploads"
# INTERMEDIATES = "intermediates"
class TensorKind(str, Enum, metaclass=MetaEnum):
"""The kind of a tensor. Use TensorKind.OTHER for non-default kinds."""
IMAGE_LATENTS = "image_latents"
CONDITIONING = "conditioning"
OTHER = "other"

View File

@ -1,578 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"from abc import ABC, abstractmethod\n",
"from enum import Enum\n",
"import enum\n",
"import sqlite3\n",
"import threading\n",
"from typing import Optional, Type, TypeVar, Union\n",
"from PIL.Image import Image as PILImage\n",
"from pydantic import BaseModel, Field\n",
"from torch import Tensor"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"\n",
"class ResourceOrigin(str, Enum):\n",
" \"\"\"The origin of a resource (eg image or tensor).\"\"\"\n",
"\n",
" RESULTS = \"results\"\n",
" UPLOADS = \"uploads\"\n",
" INTERMEDIATES = \"intermediates\"\n",
"\n",
"\n",
"class ImageKind(str, Enum):\n",
" \"\"\"The kind of an image. Use ImageKind.OTHER for non-default kinds.\"\"\"\n",
"\n",
" IMAGE = \"image\"\n",
" CONTROL_IMAGE = \"control_image\"\n",
" OTHER = \"other\"\n",
"\n",
"\n",
"class TensorKind(str, Enum):\n",
" \"\"\"The kind of a tensor. Use TensorKind.OTHER for non-default kinds.\"\"\"\n",
"\n",
" IMAGE_LATENTS = \"image_latents\"\n",
" CONDITIONING = \"conditioning\"\n",
" OTHER = \"other\"\n"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"\n",
"def create_sql_values_string_from_string_enum(enum: Type[Enum]):\n",
" \"\"\"\n",
" Creates a string of the form \"('value1'), ('value2'), ..., ('valueN')\" from a StrEnum.\n",
" \"\"\"\n",
"\n",
" delimiter = \", \"\n",
" values = [f\"('{e.value}')\" for e in enum]\n",
" return delimiter.join(values)\n",
"\n",
"\n",
"def create_sql_table_from_enum(\n",
" enum: Type[Enum],\n",
" table_name: str,\n",
" primary_key_name: str,\n",
" conn: sqlite3.Connection,\n",
" cursor: sqlite3.Cursor,\n",
" lock: threading.Lock,\n",
"):\n",
" \"\"\"\n",
" Creates and populates a table to be used as a functional enum.\n",
" \"\"\"\n",
"\n",
" try:\n",
" lock.acquire()\n",
"\n",
" values_string = create_sql_values_string_from_string_enum(enum)\n",
"\n",
" cursor.execute(\n",
" f\"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS {table_name} (\n",
" {primary_key_name} TEXT PRIMARY KEY\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" f\"\"\"--sql\n",
" INSERT OR IGNORE INTO {table_name} ({primary_key_name}) VALUES {values_string};\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"\"\"\"\n",
"`resource_origins` functions as an enum for the ResourceOrigin model.\n",
"\"\"\"\n",
"\n",
"\n",
"# def create_resource_origins_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
"# create_sql_table_from_enum(\n",
"# enum=ResourceOrigin,\n",
"# table_name=\"resource_origins\",\n",
"# primary_key_name=\"origin_name\",\n",
"# conn=conn,\n",
"# cursor=cursor,\n",
"# lock=lock,\n",
"# )\n",
"\n",
"\n",
"\"\"\"\n",
"`image_kinds` functions as an enum for the ImageType model.\n",
"\"\"\"\n",
"\n",
"\n",
"# def create_image_kinds_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" # create_sql_table_from_enum(\n",
" # enum=ImageKind,\n",
" # table_name=\"image_kinds\",\n",
" # primary_key_name=\"kind_name\",\n",
" # conn=conn,\n",
" # cursor=cursor,\n",
" # lock=lock,\n",
" # )\n",
"\n",
"\n",
"\"\"\"\n",
"`tensor_kinds` functions as an enum for the TensorType model.\n",
"\"\"\"\n",
"\n",
"\n",
"# def create_tensor_kinds_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" # create_sql_table_from_enum(\n",
" # enum=TensorKind,\n",
" # table_name=\"tensor_kinds\",\n",
" # primary_key_name=\"kind_name\",\n",
" # conn=conn,\n",
" # cursor=cursor,\n",
" # lock=lock,\n",
" # )\n",
"\n",
"\n",
"\"\"\"\n",
"`images` stores all images, regardless of type\n",
"\"\"\"\n",
"\n",
"\n",
"def create_images_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS images (\n",
" id TEXT PRIMARY KEY,\n",
" origin TEXT,\n",
" image_kind TEXT,\n",
" created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\n",
" FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),\n",
" FOREIGN KEY(image_kind) REFERENCES image_kinds(kind_name)\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_id ON images(id);\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE INDEX IF NOT EXISTS idx_images_origin ON images(origin);\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE INDEX IF NOT EXISTS idx_images_image_kind ON images(image_kind);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"\"\"\"\n",
"`images_results` stores additional data specific to `results` images.\n",
"\"\"\"\n",
"\n",
"\n",
"def create_images_results_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS images_results (\n",
" images_id TEXT PRIMARY KEY,\n",
" session_id TEXT NOT NULL,\n",
" node_id TEXT NOT NULL,\n",
" FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_results_images_id ON images_results(images_id);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"\"\"\"\n",
"`images_intermediates` stores additional data specific to `intermediates` images\n",
"\"\"\"\n",
"\n",
"\n",
"def create_images_intermediates_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS images_intermediates (\n",
" images_id TEXT PRIMARY KEY,\n",
" session_id TEXT NOT NULL,\n",
" node_id TEXT NOT NULL,\n",
" FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_intermediates_images_id ON images_intermediates(images_id);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"\"\"\"\n",
"`images_metadata` stores basic metadata for any image type\n",
"\"\"\"\n",
"\n",
"\n",
"def create_images_metadata_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS images_metadata (\n",
" images_id TEXT PRIMARY KEY,\n",
" metadata TEXT,\n",
" FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_metadata_images_id ON images_metadata(images_id);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"# `tensors` table: stores references to tensor\n",
"\n",
"\n",
"def create_tensors_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS tensors (\n",
" id TEXT PRIMARY KEY,\n",
" origin TEXT,\n",
" tensor_kind TEXT,\n",
" created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\n",
" FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),\n",
" FOREIGN KEY(tensor_kind) REFERENCES tensor_kinds(kind_name)\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_id ON tensors(id);\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE INDEX IF NOT EXISTS idx_tensors_origin ON tensors(origin);\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE INDEX IF NOT EXISTS idx_tensors_tensor_kind ON tensors(tensor_kind);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"# `tensors_results` stores additional data specific to `result` tensor\n",
"\n",
"\n",
"def create_tensors_results_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS tensors_results (\n",
" tensors_id TEXT PRIMARY KEY,\n",
" session_id TEXT NOT NULL,\n",
" node_id TEXT NOT NULL,\n",
" FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_results_tensors_id ON tensors_results(tensors_id);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"# `tensors_intermediates` stores additional data specific to `intermediate` tensor\n",
"\n",
"\n",
"def create_tensors_intermediates_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS tensors_intermediates (\n",
" tensors_id TEXT PRIMARY KEY,\n",
" session_id TEXT NOT NULL,\n",
" node_id TEXT NOT NULL,\n",
" FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_intermediates_tensors_id ON tensors_intermediates(tensors_id);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"# `tensors_metadata` table: stores generated/transformed metadata for tensor\n",
"\n",
"\n",
"def create_tensors_metadata_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS tensors_metadata (\n",
" tensors_id TEXT PRIMARY KEY,\n",
" metadata TEXT,\n",
" FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_metadata_tensors_id ON tensors_metadata(tensors_id);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"db_path = '/home/bat/Documents/Code/outputs/test.db'\n",
"if (os.path.exists(db_path)):\n",
" os.remove(db_path)\n",
"\n",
"conn = sqlite3.connect(\n",
" db_path, check_same_thread=False\n",
")\n",
"cursor = conn.cursor()\n",
"lock = threading.Lock()"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"create_sql_table_from_enum(\n",
" enum=ResourceOrigin,\n",
" table_name=\"resource_origins\",\n",
" primary_key_name=\"origin_name\",\n",
" conn=conn,\n",
" cursor=cursor,\n",
" lock=lock,\n",
")\n",
"\n",
"create_sql_table_from_enum(\n",
" enum=ImageKind,\n",
" table_name=\"image_kinds\",\n",
" primary_key_name=\"kind_name\",\n",
" conn=conn,\n",
" cursor=cursor,\n",
" lock=lock,\n",
")\n",
"\n",
"create_sql_table_from_enum(\n",
" enum=TensorKind,\n",
" table_name=\"tensor_kinds\",\n",
" primary_key_name=\"kind_name\",\n",
" conn=conn,\n",
" cursor=cursor,\n",
" lock=lock,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"create_images_table(conn, cursor, lock)\n",
"create_images_results_table(conn, cursor, lock)\n",
"create_images_intermediates_table(conn, cursor, lock)\n",
"create_images_metadata_table(conn, cursor, lock)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"create_tensors_table(conn, cursor, lock)\n",
"create_tensors_results_table(conn, cursor, lock)\n",
"create_tensors_intermediates_table(conn, cursor, lock)\n",
"create_tensors_metadata_table(conn, cursor, lock)"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"\n",
"from pydantic import StrictStr\n",
"\n",
"\n",
"class GeneratedImageOrLatentsMetadata(BaseModel):\n",
" \"\"\"Core generation metadata for an image/tensor generated in InvokeAI.\n",
"\n",
" Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node.\n",
"\n",
" Full metadata may be accessed by querying for the session in the `graph_executions` table.\n",
" \"\"\"\n",
"\n",
" positive_conditioning: Optional[StrictStr] = Field(\n",
" default=None, description=\"The positive conditioning.\"\n",
" )\n",
" negative_conditioning: Optional[str] = Field(\n",
" default=None, description=\"The negative conditioning.\"\n",
" )\n",
" width: Optional[int] = Field(\n",
" default=None, description=\"Width of the image/tensor in pixels.\"\n",
" )\n",
" height: Optional[int] = Field(\n",
" default=None, description=\"Height of the image/tensor in pixels.\"\n",
" )\n",
" seed: Optional[int] = Field(\n",
" default=None, description=\"The seed used for noise generation.\"\n",
" )\n",
" cfg_scale: Optional[float] = Field(\n",
" default=None, description=\"The classifier-free guidance scale.\"\n",
" )\n",
" steps: Optional[int] = Field(\n",
" default=None, description=\"The number of steps used for inference.\"\n",
" )\n",
" scheduler: Optional[str] = Field(\n",
" default=None, description=\"The scheduler used for inference.\"\n",
" )\n",
" model: Optional[str] = Field(\n",
" default=None, description=\"The model used for inference.\"\n",
" )\n",
" strength: Optional[float] = Field(\n",
" default=None,\n",
" description=\"The strength used for image-to-image/tensor-to-tensor.\",\n",
" )\n",
" image: Optional[str] = Field(\n",
" default=None, description=\"The ID of the initial image.\"\n",
" )\n",
" tensor: Optional[str] = Field(\n",
" default=None, description=\"The ID of the initial tensor.\"\n",
" )\n",
" # Pending model refactor:\n",
" # vae: Optional[str] = Field(default=None,description=\"The VAE used for decoding.\")\n",
" # unet: Optional[str] = Field(default=None,description=\"The UNet used dor inference.\")\n",
" # clip: Optional[str] = Field(default=None,description=\"The CLIP Encoder used for conditioning.\")\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"GeneratedImageOrLatentsMetadata(positive_conditioning='123', negative_conditioning=None, width=None, height=None, seed=None, cfg_scale=None, steps=None, scheduler=None, model=None, strength=None, image=None, tensor=None)"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"GeneratedImageOrLatentsMetadata(positive_conditioning='123')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -14,27 +14,31 @@ from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
# TODO: Should these excpetions subclass existing python exceptions?
class ImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage."""
def __init__(self, message="Image file not found"):
super().__init__(message)
class ImageFileSaveException(Exception):
"""Raised when an image cannot be saved."""
def __init__(self, message="Image file not saved"):
super().__init__(message)
class ImageFileDeleteException(Exception):
"""Raised when an image cannot be deleted."""
def __init__(self, message="Image file not deleted"):
super().__init__(message)
class ImageFileStorageBase(ABC): class ImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files.""" """Low-level service responsible for storing and retrieving image files."""
class ImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage."""
def __init__(self, message="Image file not found"):
super().__init__(message)
class ImageFileSaveException(Exception):
"""Raised when an image cannot be saved."""
def __init__(self, message="Image file not saved"):
super().__init__(message)
class ImageFileDeleteException(Exception):
"""Raised when an image cannot be deleted."""
def __init__(self, message="Image file not deleted"):
super().__init__(message)
@abstractmethod @abstractmethod
def get(self, image_type: ImageType, image_name: str) -> PILImageType: def get(self, image_type: ImageType, image_name: str) -> PILImageType:
"""Retrieves an image as PIL Image.""" """Retrieves an image as PIL Image."""
@ -102,7 +106,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
self.__set_cache(image_path, image) self.__set_cache(image_path, image)
return image return image
except FileNotFoundError as e: except FileNotFoundError as e:
raise ImageFileStorageBase.ImageFileNotFoundException from e raise ImageFileNotFoundException from e
def save( def save(
self, self,
@ -130,7 +134,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
self.__set_cache(image_path, image) self.__set_cache(image_path, image)
self.__set_cache(thumbnail_path, thumbnail_image) self.__set_cache(thumbnail_path, thumbnail_image)
except Exception as e: except Exception as e:
raise ImageFileStorageBase.ImageFileSaveException from e raise ImageFileSaveException from e
def delete(self, image_type: ImageType, image_name: str) -> None: def delete(self, image_type: ImageType, image_name: str) -> None:
try: try:
@ -150,7 +154,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
if thumbnail_path in self.__cache: if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path] del self.__cache[thumbnail_path]
except Exception as e: except Exception as e:
raise ImageFileStorageBase.ImageFileDeleteException from e raise ImageFileDeleteException from e
# TODO: make this a bit more flexible for e.g. cloud storage # TODO: make this a bit more flexible for e.g. cloud storage
def get_path( def get_path(

View File

@ -1,7 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import datetime from datetime import datetime
from enum import Enum from typing import Optional, cast
from typing import Optional, Type
import sqlite3 import sqlite3
import threading import threading
from typing import Optional, Union from typing import Optional, Union
@ -18,62 +17,32 @@ from invokeai.app.services.models.image_record import (
from invokeai.app.services.item_storage import PaginatedResults from invokeai.app.services.item_storage import PaginatedResults
def create_sql_values_string_from_string_enum(enum: Type[Enum]): # TODO: Should these excpetions subclass existing python exceptions?
""" class ImageRecordNotFoundException(Exception):
Creates a string of the form "('value1'), ('value2'), ..., ('valueN')" from a StrEnum. """Raised when an image record is not found."""
"""
delimiter = ", " def __init__(self, message="Image record not found"):
values = [f"('{e.value}')" for e in enum] super().__init__(message)
return delimiter.join(values)
def create_enum_table( class ImageRecordSaveException(Exception):
enum: Type[Enum], """Raised when an image record cannot be saved."""
table_name: str,
primary_key_name: str,
cursor: sqlite3.Cursor,
):
"""
Creates and populates a table to be used as a functional enum.
"""
values_string = create_sql_values_string_from_string_enum(enum) def __init__(self, message="Image record not saved"):
super().__init__(message)
cursor.execute(
f"""--sql class ImageRecordDeleteException(Exception):
CREATE TABLE IF NOT EXISTS {table_name} ( """Raised when an image record cannot be deleted."""
{primary_key_name} TEXT PRIMARY KEY
); def __init__(self, message="Image record not deleted"):
""" super().__init__(message)
)
cursor.execute(
f"""--sql
INSERT OR IGNORE INTO {table_name} ({primary_key_name}) VALUES {values_string};
"""
)
class ImageRecordStorageBase(ABC): class ImageRecordStorageBase(ABC):
"""Low-level service responsible for interfacing with the image record store.""" """Low-level service responsible for interfacing with the image record store."""
class ImageRecordNotFoundException(Exception): # TODO: Implement an `update()` method
"""Raised when an image record is not found."""
def __init__(self, message="Image record not found"):
super().__init__(message)
class ImageRecordSaveException(Exception):
"""Raised when an image record cannot be saved."""
def __init__(self, message="Image record not saved"):
super().__init__(message)
class ImageRecordDeleteException(Exception):
"""Raised when an image record cannot be deleted."""
def __init__(self, message="Image record not deleted"):
super().__init__(message)
@abstractmethod @abstractmethod
def get(self, image_type: ImageType, image_name: str) -> ImageRecord: def get(self, image_type: ImageType, image_name: str) -> ImageRecord:
@ -91,6 +60,8 @@ class ImageRecordStorageBase(ABC):
"""Gets a page of image records.""" """Gets a page of image records."""
pass pass
# TODO: The database has a nullable `deleted_at` column, currently unused.
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
@abstractmethod @abstractmethod
def delete(self, image_type: ImageType, image_name: str) -> None: def delete(self, image_type: ImageType, image_name: str) -> None:
"""Deletes an image record.""" """Deletes an image record."""
@ -102,11 +73,12 @@ class ImageRecordStorageBase(ABC):
image_name: str, image_name: str,
image_type: ImageType, image_type: ImageType,
image_category: ImageCategory, image_category: ImageCategory,
width: int,
height: int,
session_id: Optional[str], session_id: Optional[str],
node_id: Optional[str], node_id: Optional[str],
metadata: Optional[ImageMetadata], metadata: Optional[ImageMetadata],
created_at: str = datetime.datetime.utcnow().isoformat(), ) -> datetime:
) -> None:
"""Saves an image record.""" """Saves an image record."""
pass pass
@ -141,17 +113,23 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Create the `images` table. # Create the `images` table.
self._cursor.execute( self._cursor.execute(
f"""--sql """--sql
CREATE TABLE IF NOT EXISTS images ( CREATE TABLE IF NOT EXISTS images (
id TEXT PRIMARY KEY, image_name TEXT NOT NULL PRIMARY KEY,
image_type TEXT, -- non-nullable via foreign key constraint -- This is an enum in python, unrestricted string here for flexibility
image_category TEXT, -- non-nullable via foreign key constraint image_type TEXT NOT NULL,
session_id TEXT, -- nullable -- This is an enum in python, unrestricted string here for flexibility
node_id TEXT, -- nullable image_category TEXT NOT NULL,
metadata TEXT, -- nullable width INTEGER NOT NULL,
created_at TEXT NOT NULL, height INTEGER NOT NULL,
FOREIGN KEY(image_type) REFERENCES image_types(type_name), session_id TEXT,
FOREIGN KEY(image_category) REFERENCES image_categories(category_name) node_id TEXT,
metadata TEXT,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
-- Soft delete, currently unused
deleted_at DATETIME
); );
""" """
) )
@ -159,7 +137,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Create the `images` table indices. # Create the `images` table indices.
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_id ON images(id); CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);
""" """
) )
self._cursor.execute( self._cursor.execute(
@ -172,53 +150,22 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category); CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
""" """
) )
# Create the tables for image-related enums
create_enum_table(
enum=ImageType,
table_name="image_types",
primary_key_name="type_name",
cursor=self._cursor,
)
create_enum_table(
enum=ImageCategory,
table_name="image_categories",
primary_key_name="category_name",
cursor=self._cursor,
)
# Create the `tags` table. TODO: do this elsewhere, shouldn't be in images db service
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
CREATE TABLE IF NOT EXISTS tags ( CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);
id INTEGER PRIMARY KEY AUTOINCREMENT,
tag_name TEXT UNIQUE NOT NULL
);
""" """
) )
# Create the `images_tags` junction table. # Add trigger for `updated_at`.
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
CREATE TABLE IF NOT EXISTS images_tags ( CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
image_id TEXT, AFTER UPDATE
tag_id INTEGER, ON images FOR EACH ROW
PRIMARY KEY (image_id, tag_id), BEGIN
FOREIGN KEY(image_id) REFERENCES images(id) ON DELETE CASCADE, UPDATE images SET updated_at = current_timestamp
FOREIGN KEY(tag_id) REFERENCES tags(id) ON DELETE CASCADE WHERE image_name = old.image_name;
); END;
"""
)
# Create the `images_favorites` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS images_favorites (
image_id TEXT PRIMARY KEY,
favorited_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(image_id) REFERENCES images(id) ON DELETE CASCADE
);
""" """
) )
@ -229,22 +176,22 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self._cursor.execute( self._cursor.execute(
f"""--sql f"""--sql
SELECT * FROM images SELECT * FROM images
WHERE id = ?; WHERE image_name = ?;
""", """,
(image_name,), (image_name,),
) )
result = self._cursor.fetchone() result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
except sqlite3.Error as e: except sqlite3.Error as e:
self._conn.rollback() self._conn.rollback()
raise self.ImageRecordNotFoundException from e raise ImageRecordNotFoundException from e
finally: finally:
self._lock.release() self._lock.release()
if not result: if not result:
raise self.ImageRecordNotFoundException raise ImageRecordNotFoundException
return deserialize_image_record(result) return deserialize_image_record(dict(result))
def get_many( def get_many(
self, self,
@ -260,14 +207,15 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
f"""--sql f"""--sql
SELECT * FROM images SELECT * FROM images
WHERE image_type = ? AND image_category = ? WHERE image_type = ? AND image_category = ?
ORDER BY created_at DESC
LIMIT ? OFFSET ?; LIMIT ? OFFSET ?;
""", """,
(image_type.value, image_category.value, per_page, page * per_page), (image_type.value, image_category.value, per_page, page * per_page),
) )
result = self._cursor.fetchall() result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = list(map(lambda r: deserialize_image_record(r), result)) images = list(map(lambda r: deserialize_image_record(dict(r)), result))
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
@ -296,14 +244,14 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
DELETE FROM images DELETE FROM images
WHERE id = ?; WHERE image_name = ?;
""", """,
(image_name,), (image_name,),
) )
self._conn.commit() self._conn.commit()
except sqlite3.Error as e: except sqlite3.Error as e:
self._conn.rollback() self._conn.rollback()
raise ImageRecordStorageBase.ImageRecordDeleteException from e raise ImageRecordDeleteException from e
finally: finally:
self._lock.release() self._lock.release()
@ -313,10 +261,11 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
image_type: ImageType, image_type: ImageType,
image_category: ImageCategory, image_category: ImageCategory,
session_id: Optional[str], session_id: Optional[str],
width: int,
height: int,
node_id: Optional[str], node_id: Optional[str],
metadata: Optional[ImageMetadata], metadata: Optional[ImageMetadata],
created_at: str, ) -> datetime:
) -> None:
try: try:
metadata_json = ( metadata_json = (
None if metadata is None else metadata.json(exclude_none=True) None if metadata is None else metadata.json(exclude_none=True)
@ -325,29 +274,44 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
INSERT OR IGNORE INTO images ( INSERT OR IGNORE INTO images (
id, image_name,
image_type, image_type,
image_category, image_category,
width,
height,
node_id, node_id,
session_id, session_id,
metadata, metadata
created_at
) )
VALUES (?, ?, ?, ?, ?, ?, ?); VALUES (?, ?, ?, ?, ?, ?, ?, ?);
""", """,
( (
image_name, image_name,
image_type.value, image_type.value,
image_category.value, image_category.value,
width,
height,
node_id, node_id,
session_id, session_id,
metadata_json, metadata_json,
created_at,
), ),
) )
self._conn.commit() self._conn.commit()
self._cursor.execute(
"""--sql
SELECT created_at
FROM images
WHERE image_name = ?;
""",
(image_name,),
)
created_at = datetime.fromisoformat(self._cursor.fetchone()[0])
return created_at
except sqlite3.Error as e: except sqlite3.Error as e:
self._conn.rollback() self._conn.rollback()
raise ImageRecordStorageBase.ImageRecordNotFoundException from e raise ImageRecordSaveException from e
finally: finally:
self._lock.release() self._lock.release()

View File

@ -4,9 +4,17 @@ from typing import Optional, TYPE_CHECKING, Union
import uuid import uuid
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
from invokeai.app.models.image import ImageCategory, ImageType from invokeai.app.models.image import (
ImageCategory,
ImageType,
InvalidImageCategoryException,
InvalidImageTypeException,
)
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.image_record_storage import ( from invokeai.app.services.image_record_storage import (
ImageRecordDeleteException,
ImageRecordNotFoundException,
ImageRecordSaveException,
ImageRecordStorageBase, ImageRecordStorageBase,
) )
from invokeai.app.services.models.image_record import ( from invokeai.app.services.models.image_record import (
@ -14,7 +22,12 @@ from invokeai.app.services.models.image_record import (
ImageDTO, ImageDTO,
image_record_to_dto, image_record_to_dto,
) )
from invokeai.app.services.image_file_storage import ImageFileStorageBase from invokeai.app.services.image_file_storage import (
ImageFileDeleteException,
ImageFileNotFoundException,
ImageFileSaveException,
ImageFileStorageBase,
)
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
from invokeai.app.services.metadata import MetadataServiceBase from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.app.services.urls import UrlServiceBase from invokeai.app.services.urls import UrlServiceBase
@ -50,6 +63,11 @@ class ImageServiceABC(ABC):
"""Gets an image record.""" """Gets an image record."""
pass pass
@abstractmethod
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
"""Gets an image DTO."""
pass
@abstractmethod @abstractmethod
def get_path(self, image_type: ImageType, image_name: str) -> str: def get_path(self, image_type: ImageType, image_name: str) -> str:
"""Gets an image's path""" """Gets an image's path"""
@ -62,11 +80,6 @@ class ImageServiceABC(ABC):
"""Gets an image's or thumbnail's URL""" """Gets an image's or thumbnail's URL"""
pass pass
@abstractmethod
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
"""Gets an image DTO."""
pass
@abstractmethod @abstractmethod
def get_many( def get_many(
self, self,
@ -83,26 +96,6 @@ class ImageServiceABC(ABC):
"""Deletes an image.""" """Deletes an image."""
pass pass
@abstractmethod
def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
"""Adds a tag to an image."""
pass
@abstractmethod
def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
"""Removes a tag from an image."""
pass
@abstractmethod
def favorite(self, image_type: ImageType, image_id: str) -> None:
"""Favorites an image."""
pass
@abstractmethod
def unfavorite(self, image_type: ImageType, image_id: str) -> None:
"""Unfavorites an image."""
pass
class ImageServiceDependencies: class ImageServiceDependencies:
"""Service dependencies for the ImageService.""" """Service dependencies for the ImageService."""
@ -160,6 +153,12 @@ class ImageService(ImageServiceABC):
node_id: Optional[str] = None, node_id: Optional[str] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
) -> ImageDTO: ) -> ImageDTO:
if image_type not in ImageType:
raise InvalidImageTypeException
if image_category not in ImageCategory:
raise InvalidImageCategoryException
image_name = self._create_image_name( image_name = self._create_image_name(
image_type=image_type, image_type=image_type,
image_category=image_category, image_category=image_category,
@ -167,11 +166,25 @@ class ImageService(ImageServiceABC):
session_id=session_id, session_id=session_id,
) )
timestamp = get_iso_timestamp()
metadata = self._get_metadata(session_id, node_id) metadata = self._get_metadata(session_id, node_id)
(width, height) = image.size
try: try:
# TODO: Consider using a transaction here to ensure consistency between storage and database # TODO: Consider using a transaction here to ensure consistency between storage and database
created_at = self._services.records.save(
# Non-nullable fields
image_name=image_name,
image_type=image_type,
image_category=image_category,
width=width,
height=height,
# Nullable fields
node_id=node_id,
session_id=session_id,
metadata=metadata,
)
self._services.files.save( self._services.files.save(
image_type=image_type, image_type=image_type,
image_name=image_name, image_name=image_name,
@ -179,36 +192,34 @@ class ImageService(ImageServiceABC):
metadata=metadata, metadata=metadata,
) )
self._services.records.save(
image_name=image_name,
image_type=image_type,
image_category=image_category,
node_id=node_id,
session_id=session_id,
metadata=metadata,
created_at=timestamp,
)
image_url = self._services.urls.get_image_url(image_type, image_name) image_url = self._services.urls.get_image_url(image_type, image_name)
thumbnail_url = self._services.urls.get_image_url( thumbnail_url = self._services.urls.get_image_url(
image_type, image_name, True image_type, image_name, True
) )
return ImageDTO( return ImageDTO(
# Non-nullable fields
image_name=image_name, image_name=image_name,
image_type=image_type, image_type=image_type,
image_category=image_category, image_category=image_category,
width=width,
height=height,
# Nullable fields
node_id=node_id, node_id=node_id,
session_id=session_id, session_id=session_id,
metadata=metadata, metadata=metadata,
created_at=timestamp, # Meta fields
created_at=created_at,
updated_at=created_at, # this is always the same as the created_at at this time
deleted_at=None,
# Extra non-nullable fields for DTO
image_url=image_url, image_url=image_url,
thumbnail_url=thumbnail_url, thumbnail_url=thumbnail_url,
) )
except ImageRecordStorageBase.ImageRecordSaveException: except ImageRecordSaveException:
self._services.logger.error("Failed to save image record") self._services.logger.error("Failed to save image record")
raise raise
except ImageFileStorageBase.ImageFileSaveException: except ImageFileSaveException:
self._services.logger.error("Failed to save image file") self._services.logger.error("Failed to save image file")
raise raise
except Exception as e: except Exception as e:
@ -218,7 +229,7 @@ class ImageService(ImageServiceABC):
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
try: try:
return self._services.files.get(image_type, image_name) return self._services.files.get(image_type, image_name)
except ImageFileStorageBase.ImageFileNotFoundException: except ImageFileNotFoundException:
self._services.logger.error("Failed to get image file") self._services.logger.error("Failed to get image file")
raise raise
except Exception as e: except Exception as e:
@ -228,7 +239,7 @@ class ImageService(ImageServiceABC):
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord: def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
try: try:
return self._services.records.get(image_type, image_name) return self._services.records.get(image_type, image_name)
except ImageRecordStorageBase.ImageRecordNotFoundException: except ImageRecordNotFoundException:
self._services.logger.error("Image record not found") self._services.logger.error("Image record not found")
raise raise
except Exception as e: except Exception as e:
@ -246,7 +257,7 @@ class ImageService(ImageServiceABC):
) )
return image_dto return image_dto
except ImageRecordStorageBase.ImageRecordNotFoundException: except ImageRecordNotFoundException:
self._services.logger.error("Image record not found") self._services.logger.error("Image record not found")
raise raise
except Exception as e: except Exception as e:
@ -311,32 +322,19 @@ class ImageService(ImageServiceABC):
raise e raise e
def delete(self, image_type: ImageType, image_name: str): def delete(self, image_type: ImageType, image_name: str):
# TODO: Consider using a transaction here to ensure consistency between storage and database
try: try:
self._services.files.delete(image_type, image_name) self._services.files.delete(image_type, image_name)
self._services.records.delete(image_type, image_name) self._services.records.delete(image_type, image_name)
except ImageRecordStorageBase.ImageRecordDeleteException: except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record") self._services.logger.error(f"Failed to delete image record")
raise raise
except ImageFileStorageBase.ImageFileDeleteException: except ImageFileDeleteException:
self._services.logger.error(f"Failed to delete image file") self._services.logger.error(f"Failed to delete image file")
raise raise
except Exception as e: except Exception as e:
self._services.logger.error("Problem deleting image record and file") self._services.logger.error("Problem deleting image record and file")
raise e raise e
def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
raise NotImplementedError("The 'add_tag' method is not implemented yet.")
def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
raise NotImplementedError("The 'remove_tag' method is not implemented yet.")
def favorite(self, image_type: ImageType, image_id: str) -> None:
raise NotImplementedError("The 'favorite' method is not implemented yet.")
def unfavorite(self, image_type: ImageType, image_id: str) -> None:
raise NotImplementedError("The 'unfavorite' method is not implemented yet.")
def _create_image_name( def _create_image_name(
self, self,
image_type: ImageType, image_type: ImageType,

View File

@ -1,5 +1,4 @@
import datetime import datetime
import sqlite3
from typing import Optional, Union from typing import Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageCategory, ImageType from invokeai.app.models.image import ImageCategory, ImageType
@ -10,30 +9,60 @@ from invokeai.app.util.misc import get_iso_timestamp
class ImageRecord(BaseModel): class ImageRecord(BaseModel):
"""Deserialized image record.""" """Deserialized image record."""
image_name: str = Field(description="The name of the image.") image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image."""
image_type: ImageType = Field(description="The type of the image.") image_type: ImageType = Field(description="The type of the image.")
"""The type of the image."""
image_category: ImageCategory = Field(description="The category of the image.") image_category: ImageCategory = Field(description="The category of the image.")
"""The category of the image."""
width: int = Field(description="The width of the image in px.")
"""The actual width of the image in px. This may be different from the width in metadata."""
height: int = Field(description="The height of the image in px.")
"""The actual height of the image in px. This may be different from the height in metadata."""
created_at: Union[datetime.datetime, str] = Field( created_at: Union[datetime.datetime, str] = Field(
description="The created timestamp of the image." description="The created timestamp of the image."
) )
session_id: Optional[str] = Field(default=None, description="The session ID.") """The created timestamp of the image."""
node_id: Optional[str] = Field(default=None, description="The node ID.") updated_at: Union[datetime.datetime, str] = Field(
metadata: Optional[ImageMetadata] = Field( description="The updated timestamp of the image."
default=None, description="The image's metadata."
) )
"""The updated timestamp of the image."""
deleted_at: Union[datetime.datetime, str, None] = Field(
description="The deleted timestamp of the image."
)
"""The deleted timestamp of the image."""
session_id: Optional[str] = Field(
default=None,
description="The session ID that generated this image, if it is a generated image.",
)
"""The session ID that generated this image, if it is a generated image."""
node_id: Optional[str] = Field(
default=None,
description="The node ID that generated this image, if it is a generated image.",
)
"""The node ID that generated this image, if it is a generated image."""
metadata: Optional[ImageMetadata] = Field(
default=None,
description="A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.",
)
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
class ImageUrlsDTO(BaseModel): class ImageUrlsDTO(BaseModel):
"""The URLs for an image and its thumbnaill""" """The URLs for an image and its thumbnail."""
image_name: str = Field(description="The name of the image.") image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image."""
image_type: ImageType = Field(description="The type of the image.") image_type: ImageType = Field(description="The type of the image.")
"""The type of the image."""
image_url: str = Field(description="The URL of the image.") image_url: str = Field(description="The URL of the image.")
thumbnail_url: str = Field(description="The thumbnail URL of the image.") """The URL of the image."""
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
"""The URL of the image's thumbnail."""
class ImageDTO(ImageRecord, ImageUrlsDTO): class ImageDTO(ImageRecord, ImageUrlsDTO):
"""Deserialized image record with URLs.""" """Deserialized image record, enriched for the frontend with URLs."""
pass pass
@ -43,24 +72,29 @@ def image_record_to_dto(
) -> ImageDTO: ) -> ImageDTO:
"""Converts an image record to an image DTO.""" """Converts an image record to an image DTO."""
return ImageDTO( return ImageDTO(
image_name=image_record.image_name, **image_record.dict(),
image_type=image_record.image_type,
image_category=image_record.image_category,
created_at=image_record.created_at,
session_id=image_record.session_id,
node_id=image_record.node_id,
metadata=image_record.metadata,
image_url=image_url, image_url=image_url,
thumbnail_url=thumbnail_url, thumbnail_url=thumbnail_url,
) )
def deserialize_image_record(image_row: sqlite3.Row) -> ImageRecord: def deserialize_image_record(image_dict: dict) -> ImageRecord:
"""Deserializes an image record.""" """Deserializes an image record."""
image_dict = dict(image_row) # Retrieve all the values, setting "reasonable" defaults if they are not present.
image_name = image_dict.get("image_name", "unknown")
image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value)) image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value))
image_category = ImageCategory(
image_dict.get("image_category", ImageCategory.GENERAL.value)
)
width = image_dict.get("width", 0)
height = image_dict.get("height", 0)
session_id = image_dict.get("session_id", None)
node_id = image_dict.get("node_id", None)
created_at = image_dict.get("created_at", get_iso_timestamp())
updated_at = image_dict.get("updated_at", get_iso_timestamp())
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
raw_metadata = image_dict.get("metadata") raw_metadata = image_dict.get("metadata")
@ -70,13 +104,15 @@ def deserialize_image_record(image_row: sqlite3.Row) -> ImageRecord:
metadata = None metadata = None
return ImageRecord( return ImageRecord(
image_name=image_dict.get("id", "unknown"), image_name=image_name,
session_id=image_dict.get("session_id", None),
node_id=image_dict.get("node_id", None),
metadata=metadata,
image_type=image_type, image_type=image_type,
image_category=ImageCategory( image_category=image_category,
image_dict.get("image_category", ImageCategory.IMAGE.value) width=width,
), height=height,
created_at=image_dict.get("created_at", get_iso_timestamp()), session_id=session_id,
node_id=node_id,
metadata=metadata,
created_at=created_at,
updated_at=updated_at,
deleted_at=deleted_at,
) )

View File

@ -2,7 +2,10 @@ from enum import EnumMeta
class MetaEnum(EnumMeta): class MetaEnum(EnumMeta):
"""Metaclass to support `in` syntax value checking in String Enums""" """Metaclass to support additional features in Enums.
- `in` operator support: `'value' in MyEnum -> bool`
"""
def __contains__(cls, item): def __contains__(cls, item):
try: try: