feat(db,nodes,api): refactor metadata

Metadata for the Linear UI is now sneakily provided via a `MetadataAccumulator` node, which the client populates / hooks up while building the graph.

Additionally, we provide the unexpanded graph with the metadata API response.

Both of these are embedded into the PNGs.

- Remove `metadata` from `ImageDTO`
- Split up the `images/` routes to accomodate this; metadata is only retrieved per-image
- `images/{image_name}` now gets the DTO
- `images/{image_name}/metadata` gets the new metadata
- `images/{image_name}/full` gets the full-sized image file
- Remove old metadata service
- Add `MetadataAccumulator` node, `CoreMetadataField`, hook up to `LatentsToImage` node
- Add `get_raw()` method to `ItemStorage`, retrieves the row from DB as a string, no pydantic parsing
- Update `images`related services to handle storing and retrieving the new metadata
- Add `get_metadata_graph_from_raw_session` which extracts the `graph` from `session` without needing to hydrate the session in pydantic, in preparation for providing it as metadata; also removes all references to the `MetadataAccumulator` node
This commit is contained in:
psychedelicious 2023-07-13 01:14:22 +10:00
parent eb0d55263b
commit 50bef87da7
15 changed files with 384 additions and 348 deletions

View File

@ -13,7 +13,6 @@ from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
@ -75,7 +74,6 @@ class ApiDependencies:
) )
urls = LocalUrlService() urls = LocalUrlService()
metadata = CoreMetadataService()
image_record_storage = SqliteImageRecordStorage(db_location) image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images") image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService() names = SimpleNameService()
@ -111,7 +109,6 @@ class ApiDependencies:
board_image_record_storage=board_image_record_storage, board_image_record_storage=board_image_record_storage,
image_record_storage=image_record_storage, image_record_storage=image_record_storage,
image_file_storage=image_file_storage, image_file_storage=image_file_storage,
metadata=metadata,
url=urls, url=urls,
logger=logger, logger=logger,
names=names, names=names,

View File

@ -1,20 +1,19 @@
import io import io
from typing import Optional from typing import Optional
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.routing import APIRouter from fastapi import (Body, HTTPException, Path, Query, Request, Response,
UploadFile)
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from PIL import Image from PIL import Image
from invokeai.app.models.image import (
ImageCategory, from invokeai.app.invocations.metadata import ImageMetadata
ResourceOrigin, from invokeai.app.models.image import ImageCategory, ResourceOrigin
)
from invokeai.app.services.image_record_storage import OffsetPaginatedResults from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.image_record import (
ImageDTO,
ImageRecordChanges,
ImageUrlsDTO,
)
from invokeai.app.services.item_storage import PaginatedResults from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.services.models.image_record import (ImageDTO,
ImageRecordChanges,
ImageUrlsDTO)
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
@ -103,23 +102,38 @@ async def update_image(
@images_router.get( @images_router.get(
"/{image_name}/metadata", "/{image_name}",
operation_id="get_image_metadata", operation_id="get_image_dto",
response_model=ImageDTO, response_model=ImageDTO,
) )
async def get_image_metadata( async def get_image_dto(
image_name: str = Path(description="The name of image to get"), image_name: str = Path(description="The name of image to get"),
) -> ImageDTO: ) -> ImageDTO:
"""Gets an image's metadata""" """Gets an image's DTO"""
try: try:
return ApiDependencies.invoker.services.images.get_dto(image_name) return ApiDependencies.invoker.services.images.get_dto(image_name)
except Exception as e: except Exception as e:
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@images_router.get(
"/{image_name}/metadata",
operation_id="get_image_metadata",
response_model=ImageMetadata,
)
async def get_image_metadata(
image_name: str = Path(description="The name of image to get"),
) -> ImageMetadata:
"""Gets an image's metadata"""
try:
return ApiDependencies.invoker.services.images.get_metadata(image_name)
except Exception as e:
raise HTTPException(status_code=404)
@images_router.get( @images_router.get(
"/{image_name}", "/{image_name}/full",
operation_id="get_image_full", operation_id="get_image_full",
response_class=Response, response_class=Response,
responses={ responses={
@ -208,10 +222,10 @@ async def get_image_urls(
@images_router.get( @images_router.get(
"/", "/",
operation_id="list_images_with_metadata", operation_id="list_image_dtos",
response_model=OffsetPaginatedResults[ImageDTO], response_model=OffsetPaginatedResults[ImageDTO],
) )
async def list_images_with_metadata( async def list_image_dtos(
image_origin: Optional[ResourceOrigin] = Query( image_origin: Optional[ResourceOrigin] = Query(
default=None, description="The origin of images to list" default=None, description="The origin of images to list"
), ),
@ -227,7 +241,7 @@ async def list_images_with_metadata(
offset: int = Query(default=0, description="The page offset"), offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of images per page"), limit: int = Query(default=10, description="The number of images per page"),
) -> OffsetPaginatedResults[ImageDTO]: ) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a list of images""" """Gets a list of image DTOs"""
image_dtos = ApiDependencies.invoker.services.images.get_many( image_dtos = ApiDependencies.invoker.services.images.get_many(
offset, offset,

View File

@ -34,7 +34,6 @@ from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService from invokeai.app.services.urls import LocalUrlService
from .services.default_graphs import (default_text_to_image_graph_id, from .services.default_graphs import (default_text_to_image_graph_id,
@ -244,7 +243,6 @@ def invoke_cli():
) )
urls = LocalUrlService() urls = LocalUrlService()
metadata = CoreMetadataService()
image_record_storage = SqliteImageRecordStorage(db_location) image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images") image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService() names = SimpleNameService()
@ -277,7 +275,6 @@ def invoke_cli():
board_image_record_storage=board_image_record_storage, board_image_record_storage=board_image_record_storage,
image_record_storage=image_record_storage, image_record_storage=image_record_storage,
image_file_storage=image_file_storage, image_file_storage=image_file_storage,
metadata=metadata,
url=urls, url=urls,
logger=logger, logger=logger,
names=names, names=names,

View File

@ -9,9 +9,9 @@ from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import ( from ...backend.stable_diffusion.diffusers_pipeline import (
@ -21,6 +21,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
PostprocessingSettings PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import torch_dtype from ...backend.util.devices import torch_dtype
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (BaseInvocation, BaseInvocationOutput, from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext) InvocationConfig, InvocationContext)
from .compel import ConditioningField from .compel import ConditioningField
@ -449,6 +450,8 @@ class LatentsToImageInvocation(BaseInvocation):
tiled: bool = Field( tiled: bool = Field(
default=False, default=False,
description="Decode latents by overlaping tiles(less memory consumption)") description="Decode latents by overlaping tiles(less memory consumption)")
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):
@ -493,7 +496,8 @@ class LatentsToImageInvocation(BaseInvocation):
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
) )
return ImageOutput( return ImageOutput(

View File

@ -0,0 +1,124 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import (BaseInvocation,
BaseInvocationOutput,
InvocationContext)
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.model import (LoRAModelField, MainModelField,
VAEModelField)
class LoRAMetadataField(BaseModel):
"""LoRA metadata for an image generated in InvokeAI."""
lora: LoRAModelField = Field(description="The LoRA model")
weight: float = Field(description="The weight of the LoRA model")
class CoreMetadata(BaseModel):
"""Core generation metadata for an image generated in InvokeAI."""
generation_mode: str = Field(description="The generation mode that output this image",)
positive_prompt: str = Field(description="The positive prompt parameter")
negative_prompt: str = Field(description="The negative prompt parameter")
width: int = Field(description="The width parameter")
height: int = Field(description="The height parameter")
seed: int = Field(description="The seed used for noise generation")
rand_device: str = Field(description="The device used for random number generation")
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
steps: int = Field(description="The number of steps used for inference")
scheduler: str = Field(description="The scheduler used for inference")
clip_skip: int = Field(description="The number of skipped CLIP layers",)
model: MainModelField = Field(description="The main model used for inference")
controlnets: list[ControlField]= Field(description="The ControlNets used for inference")
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
strength: Union[float, None] = Field(
default=None,
description="The strength used for latents-to-latents",
)
init_image: Union[str, None] = Field(
default=None, description="The name of the initial image"
)
vae: Union[VAEModelField, None] = Field(
default=None,
description="The VAE used for decoding, if the main model's default was not used",
)
class ImageMetadata(BaseModel):
"""An image's generation metadata"""
metadata: Optional[dict] = Field(
default=None,
description="The image's core metadata, if it was created in the Linear or Canvas UI",
)
graph: Optional[dict] = Field(
default=None, description="The graph that created the image"
)
class MetadataAccumulatorOutput(BaseInvocationOutput):
"""The output of the MetadataAccumulator node"""
type: Literal["metadata_accumulator_output"] = "metadata_accumulator_output"
metadata: CoreMetadata = Field(description="The core metadata for the image")
class MetadataAccumulatorInvocation(BaseInvocation):
"""Outputs a Core Metadata Object"""
type: Literal["metadata_accumulator"] = "metadata_accumulator"
generation_mode: str = Field(description="The generation mode that output this image",)
positive_prompt: str = Field(description="The positive prompt parameter")
negative_prompt: str = Field(description="The negative prompt parameter")
width: int = Field(description="The width parameter")
height: int = Field(description="The height parameter")
seed: int = Field(description="The seed used for noise generation")
rand_device: str = Field(description="The device used for random number generation")
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
steps: int = Field(description="The number of steps used for inference")
scheduler: str = Field(description="The scheduler used for inference")
clip_skip: int = Field(description="The number of skipped CLIP layers",)
model: MainModelField = Field(description="The main model used for inference")
controlnets: list[ControlField]= Field(description="The ControlNets used for inference")
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
strength: Union[float, None] = Field(
default=None,
description="The strength used for latents-to-latents",
)
init_image: Union[str, None] = Field(
default=None, description="The name of the initial image"
)
vae: Union[VAEModelField, None] = Field(
default=None,
description="The VAE used for decoding, if the main model's default was not used",
)
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
"""Collects and outputs a CoreMetadata object"""
return MetadataAccumulatorOutput(
metadata=CoreMetadata(
generation_mode=self.generation_mode,
positive_prompt=self.positive_prompt,
negative_prompt=self.negative_prompt,
width=self.width,
height=self.height,
seed=self.seed,
rand_device=self.rand_device,
cfg_scale=self.cfg_scale,
steps=self.steps,
scheduler=self.scheduler,
model=self.model,
strength=self.strength,
init_image=self.init_image,
vae=self.vae,
controlnets=self.controlnets,
loras=self.loras,
clip_skip=self.clip_skip,
)
)

View File

@ -1,93 +0,0 @@
from typing import Optional, Union, List
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr
class ImageMetadata(BaseModel):
"""
Core generation metadata for an image/tensor generated in InvokeAI.
Also includes any metadata from the image's PNG tEXt chunks.
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors
of a given node.
Full metadata may be accessed by querying for the session in the `graph_executions` table.
"""
class Config:
extra = Extra.allow
"""
This lets the ImageMetadata class accept arbitrary additional fields. The CoreMetadataService
won't add any fields that are not already defined, but other a different metadata service
implementation might.
"""
type: Optional[StrictStr] = Field(
default=None,
description="The type of the ancestor node of the image output node.",
)
"""The type of the ancestor node of the image output node."""
positive_conditioning: Optional[StrictStr] = Field(
default=None, description="The positive conditioning."
)
"""The positive conditioning"""
negative_conditioning: Optional[StrictStr] = Field(
default=None, description="The negative conditioning."
)
"""The negative conditioning"""
width: Optional[StrictInt] = Field(
default=None, description="Width of the image/latents in pixels."
)
"""Width of the image/latents in pixels"""
height: Optional[StrictInt] = Field(
default=None, description="Height of the image/latents in pixels."
)
"""Height of the image/latents in pixels"""
seed: Optional[StrictInt] = Field(
default=None, description="The seed used for noise generation."
)
"""The seed used for noise generation"""
# cfg_scale: Optional[StrictFloat] = Field(
# cfg_scale: Union[float, list[float]] = Field(
cfg_scale: Union[StrictFloat, List[StrictFloat]] = Field(
default=None, description="The classifier-free guidance scale."
)
"""The classifier-free guidance scale"""
steps: Optional[StrictInt] = Field(
default=None, description="The number of steps used for inference."
)
"""The number of steps used for inference"""
scheduler: Optional[StrictStr] = Field(
default=None, description="The scheduler used for inference."
)
"""The scheduler used for inference"""
model: Optional[StrictStr] = Field(
default=None, description="The model used for inference."
)
"""The model used for inference"""
strength: Optional[StrictFloat] = Field(
default=None,
description="The strength used for image-to-image/latents-to-latents.",
)
"""The strength used for image-to-image/latents-to-latents."""
latents: Optional[StrictStr] = Field(
default=None, description="The ID of the initial latents."
)
"""The ID of the initial latents"""
vae: Optional[StrictStr] = Field(
default=None, description="The VAE used for decoding."
)
"""The VAE used for decoding"""
unet: Optional[StrictStr] = Field(
default=None, description="The UNet used dor inference."
)
"""The UNet used dor inference"""
clip: Optional[StrictStr] = Field(
default=None, description="The CLIP Encoder used for conditioning."
)
"""The CLIP Encoder used for conditioning"""
extra: Optional[StrictStr] = Field(
default=None,
description="Uploaded image metadata, extracted from the PNG tEXt chunk.",
)
"""Uploaded image metadata, extracted from the PNG tEXt chunk."""

View File

@ -1,14 +1,14 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from queue import Queue from queue import Queue
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from PIL.Image import Image as PILImageType
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from PIL.Image import Image as PILImageType
from send2trash import send2trash from send2trash import send2trash
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
@ -59,7 +59,8 @@ class ImageFileStorageBase(ABC):
self, self,
image: PILImageType, image: PILImageType,
image_name: str, image_name: str,
metadata: Optional[ImageMetadata] = None, metadata: Optional[dict] = None,
graph: Optional[dict] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
) -> None: ) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp.""" """Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
@ -110,20 +111,22 @@ class DiskImageFileStorage(ImageFileStorageBase):
self, self,
image: PILImageType, image: PILImageType,
image_name: str, image_name: str,
metadata: Optional[ImageMetadata] = None, metadata: Optional[dict] = None,
graph: Optional[dict] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
) -> None: ) -> None:
try: try:
self.__validate_storage_folders() self.__validate_storage_folders()
image_path = self.get_path(image_name) image_path = self.get_path(image_name)
pnginfo = PngImagePlugin.PngInfo()
if metadata is not None: if metadata is not None:
pnginfo = PngImagePlugin.PngInfo() pnginfo.add_text("metadata", json.dumps(metadata))
pnginfo.add_text("invokeai", metadata.json()) if graph is not None:
image.save(image_path, "PNG", pnginfo=pnginfo) pnginfo.add_text("graph", json.dumps(graph))
else:
image.save(image_path, "PNG")
image.save(image_path, "PNG", pnginfo=pnginfo)
thumbnail_name = get_thumbnail_name(image_name) thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True) thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
thumbnail_image = make_thumbnail(image, thumbnail_size) thumbnail_image = make_thumbnail(image, thumbnail_size)

View File

@ -1,3 +1,4 @@
import json
import sqlite3 import sqlite3
import threading import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -8,7 +9,6 @@ from pydantic import BaseModel, Field
from pydantic.generics import GenericModel from pydantic.generics import GenericModel
from invokeai.app.models.image import ImageCategory, ResourceOrigin from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.models.image_record import ( from invokeai.app.services.models.image_record import (
ImageRecord, ImageRecordChanges, deserialize_image_record) ImageRecord, ImageRecordChanges, deserialize_image_record)
@ -48,6 +48,28 @@ class ImageRecordDeleteException(Exception):
super().__init__(message) super().__init__(message)
IMAGE_DTO_COLS = ", ".join(
list(
map(
lambda c: "images." + c,
[
"image_name",
"image_origin",
"image_category",
"width",
"height",
"session_id",
"node_id",
"is_intermediate",
"created_at",
"updated_at",
"deleted_at",
],
)
)
)
class ImageRecordStorageBase(ABC): class ImageRecordStorageBase(ABC):
"""Low-level service responsible for interfacing with the image record store.""" """Low-level service responsible for interfacing with the image record store."""
@ -58,6 +80,11 @@ class ImageRecordStorageBase(ABC):
"""Gets an image record.""" """Gets an image record."""
pass pass
@abstractmethod
def get_metadata(self, image_name: str) -> Optional[dict]:
"""Gets an image's metadata'."""
pass
@abstractmethod @abstractmethod
def update( def update(
self, self,
@ -102,7 +129,7 @@ class ImageRecordStorageBase(ABC):
height: int, height: int,
session_id: Optional[str], session_id: Optional[str],
node_id: Optional[str], node_id: Optional[str],
metadata: Optional[ImageMetadata], metadata: Optional[dict],
is_intermediate: bool = False, is_intermediate: bool = False,
) -> datetime: ) -> datetime:
"""Saves an image record.""" """Saves an image record."""
@ -206,7 +233,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self._cursor.execute( self._cursor.execute(
f"""--sql f"""--sql
SELECT * FROM images SELECT {IMAGE_DTO_COLS} FROM images
WHERE image_name = ?; WHERE image_name = ?;
""", """,
(image_name,), (image_name,),
@ -224,6 +251,28 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return deserialize_image_record(dict(result)) return deserialize_image_record(dict(result))
def get_metadata(self, image_name: str) -> Optional[dict]:
try:
self._lock.acquire()
self._cursor.execute(
f"""--sql
SELECT images.metadata FROM images
WHERE image_name = ?;
""",
(image_name,),
)
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
if not result or not result[0]:
return None
return json.loads(result[0])
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordNotFoundException from e
finally:
self._lock.release()
def update( def update(
self, self,
image_name: str, image_name: str,
@ -291,8 +340,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
WHERE 1=1 WHERE 1=1
""" """
images_query = """--sql images_query = f"""--sql
SELECT images.* SELECT {IMAGE_DTO_COLS}
FROM images FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1 WHERE 1=1
@ -410,12 +459,12 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
width: int, width: int,
height: int, height: int,
node_id: Optional[str], node_id: Optional[str],
metadata: Optional[ImageMetadata], metadata: Optional[dict],
is_intermediate: bool = False, is_intermediate: bool = False,
) -> datetime: ) -> datetime:
try: try:
metadata_json = ( metadata_json = (
None if metadata is None else metadata.json(exclude_none=True) None if metadata is None else json.dumps(metadata)
) )
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
@ -465,9 +514,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
finally: finally:
self._lock.release() self._lock.release()
def get_most_recent_image_for_board( def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
self, board_id: str
) -> Optional[ImageRecord]:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(

View File

@ -1,39 +1,30 @@
import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from logging import Logger from logging import Logger
from typing import Optional, TYPE_CHECKING, Union from typing import TYPE_CHECKING, Optional
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
from invokeai.app.models.image import ( from invokeai.app.invocations.metadata import ImageMetadata
ImageCategory, from invokeai.app.models.image import (ImageCategory,
ResourceOrigin, InvalidImageCategoryException,
InvalidImageCategoryException, InvalidOriginException, ResourceOrigin)
InvalidOriginException, from invokeai.app.services.board_image_record_storage import \
) BoardImageRecordStorageBase
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.services.graph import Graph
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.image_record_storage import (
ImageRecordDeleteException,
ImageRecordNotFoundException,
ImageRecordSaveException,
ImageRecordStorageBase,
OffsetPaginatedResults,
)
from invokeai.app.services.models.image_record import (
ImageRecord,
ImageDTO,
ImageRecordChanges,
image_record_to_dto,
)
from invokeai.app.services.image_file_storage import ( from invokeai.app.services.image_file_storage import (
ImageFileDeleteException, ImageFileDeleteException, ImageFileNotFoundException,
ImageFileNotFoundException, ImageFileSaveException, ImageFileStorageBase)
ImageFileSaveException, from invokeai.app.services.image_record_storage import (
ImageFileStorageBase, ImageRecordDeleteException, ImageRecordNotFoundException,
) ImageRecordSaveException, ImageRecordStorageBase, OffsetPaginatedResults)
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.metadata import MetadataServiceBase from invokeai.app.services.models.image_record import (ImageDTO, ImageRecord,
ImageRecordChanges,
image_record_to_dto)
from invokeai.app.services.resource_name import NameServiceBase from invokeai.app.services.resource_name import NameServiceBase
from invokeai.app.services.urls import UrlServiceBase from invokeai.app.services.urls import UrlServiceBase
from invokeai.app.util.metadata import get_metadata_graph_from_raw_session
if TYPE_CHECKING: if TYPE_CHECKING:
from invokeai.app.services.graph import GraphExecutionState from invokeai.app.services.graph import GraphExecutionState
@ -51,6 +42,7 @@ class ImageServiceABC(ABC):
node_id: Optional[str] = None, node_id: Optional[str] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
is_intermediate: bool = False, is_intermediate: bool = False,
metadata: Optional[dict] = None,
) -> ImageDTO: ) -> ImageDTO:
"""Creates an image, storing the file and its metadata.""" """Creates an image, storing the file and its metadata."""
pass pass
@ -79,6 +71,11 @@ class ImageServiceABC(ABC):
"""Gets an image DTO.""" """Gets an image DTO."""
pass pass
@abstractmethod
def get_metadata(self, image_name: str) -> ImageMetadata:
"""Gets an image's metadata."""
pass
@abstractmethod @abstractmethod
def get_path(self, image_name: str, thumbnail: bool = False) -> str: def get_path(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets an image's path.""" """Gets an image's path."""
@ -124,7 +121,6 @@ class ImageServiceDependencies:
image_records: ImageRecordStorageBase image_records: ImageRecordStorageBase
image_files: ImageFileStorageBase image_files: ImageFileStorageBase
board_image_records: BoardImageRecordStorageBase board_image_records: BoardImageRecordStorageBase
metadata: MetadataServiceBase
urls: UrlServiceBase urls: UrlServiceBase
logger: Logger logger: Logger
names: NameServiceBase names: NameServiceBase
@ -135,7 +131,6 @@ class ImageServiceDependencies:
image_record_storage: ImageRecordStorageBase, image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase, image_file_storage: ImageFileStorageBase,
board_image_record_storage: BoardImageRecordStorageBase, board_image_record_storage: BoardImageRecordStorageBase,
metadata: MetadataServiceBase,
url: UrlServiceBase, url: UrlServiceBase,
logger: Logger, logger: Logger,
names: NameServiceBase, names: NameServiceBase,
@ -144,7 +139,6 @@ class ImageServiceDependencies:
self.image_records = image_record_storage self.image_records = image_record_storage
self.image_files = image_file_storage self.image_files = image_file_storage
self.board_image_records = board_image_record_storage self.board_image_records = board_image_record_storage
self.metadata = metadata
self.urls = url self.urls = url
self.logger = logger self.logger = logger
self.names = names self.names = names
@ -165,6 +159,7 @@ class ImageService(ImageServiceABC):
node_id: Optional[str] = None, node_id: Optional[str] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
is_intermediate: bool = False, is_intermediate: bool = False,
metadata: Optional[dict] = None,
) -> ImageDTO: ) -> ImageDTO:
if image_origin not in ResourceOrigin: if image_origin not in ResourceOrigin:
raise InvalidOriginException raise InvalidOriginException
@ -174,7 +169,16 @@ class ImageService(ImageServiceABC):
image_name = self._services.names.create_image_name() image_name = self._services.names.create_image_name()
metadata = self._get_metadata(session_id, node_id) graph = None
if session_id is not None:
session_raw = self._services.graph_execution_manager.get_raw(session_id)
if session_raw is not None:
try:
graph = get_metadata_graph_from_raw_session(session_raw)
except Exception as e:
self._services.logger.warn(f"Failed to parse session graph: {e}")
graph = None
(width, height) = image.size (width, height) = image.size
@ -191,14 +195,12 @@ class ImageService(ImageServiceABC):
is_intermediate=is_intermediate, is_intermediate=is_intermediate,
# Nullable fields # Nullable fields
node_id=node_id, node_id=node_id,
session_id=session_id,
metadata=metadata, metadata=metadata,
session_id=session_id,
) )
self._services.image_files.save( self._services.image_files.save(
image_name=image_name, image_name=image_name, image=image, metadata=metadata, graph=graph
image=image,
metadata=metadata,
) )
image_dto = self.get_dto(image_name) image_dto = self.get_dto(image_name)
@ -268,6 +270,34 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image DTO") self._services.logger.error("Problem getting image DTO")
raise e raise e
def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
try:
image_record = self._services.image_records.get(image_name)
if not image_record.session_id:
return ImageMetadata()
session_raw = self._services.graph_execution_manager.get_raw(
image_record.session_id
)
graph = None
if session_raw:
try:
graph = get_metadata_graph_from_raw_session(session_raw)
except Exception as e:
self._services.logger.warn(f"Failed to parse session graph: {e}")
graph = None
metadata = self._services.image_records.get_metadata(image_name)
return ImageMetadata(graph=graph, metadata=metadata)
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
raise
except Exception as e:
self._services.logger.error("Problem getting image DTO")
raise e
def get_path(self, image_name: str, thumbnail: bool = False) -> str: def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try: try:
return self._services.image_files.get_path(image_name, thumbnail) return self._services.image_files.get_path(image_name, thumbnail)
@ -367,15 +397,3 @@ class ImageService(ImageServiceABC):
except Exception as e: except Exception as e:
self._services.logger.error("Problem deleting image records and files") self._services.logger.error("Problem deleting image records and files")
raise e raise e
def _get_metadata(
self, session_id: Optional[str] = None, node_id: Optional[str] = None
) -> Optional[ImageMetadata]:
"""Get the metadata for a node."""
metadata = None
if node_id is not None and session_id is not None:
session = self._services.graph_execution_manager.get(session_id)
metadata = self._services.metadata.create_image_metadata(session, node_id)
return metadata

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable, Generic, TypeVar from typing import Callable, Generic, Optional, TypeVar
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from pydantic.generics import GenericModel from pydantic.generics import GenericModel
@ -29,14 +29,22 @@ class ItemStorageABC(ABC, Generic[T]):
@abstractmethod @abstractmethod
def get(self, item_id: str) -> T: def get(self, item_id: str) -> T:
"""Gets the item, parsing it into a Pydantic model"""
pass
@abstractmethod
def get_raw(self, item_id: str) -> Optional[str]:
"""Gets the raw item as a string, skipping Pydantic parsing"""
pass pass
@abstractmethod @abstractmethod
def set(self, item: T) -> None: def set(self, item: T) -> None:
"""Sets the item"""
pass pass
@abstractmethod @abstractmethod
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
"""Gets a paginated list of items"""
pass pass
@abstractmethod @abstractmethod

View File

@ -1,142 +0,0 @@
from abc import ABC, abstractmethod
from typing import Any, Optional
import networkx as nx
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.graph import Graph, GraphExecutionState
class MetadataServiceBase(ABC):
"""Handles building metadata for nodes, images, and outputs."""
@abstractmethod
def create_image_metadata(
self, session: GraphExecutionState, node_id: str
) -> ImageMetadata:
"""Builds an ImageMetadata object for a node."""
pass
class CoreMetadataService(MetadataServiceBase):
_ANCESTOR_TYPES = ["t2l", "l2l"]
"""The ancestor types that contain the core metadata"""
_ANCESTOR_PARAMS = ["type", "steps", "model", "cfg_scale", "scheduler", "strength"]
"""The core metadata parameters in the ancestor types"""
_NOISE_FIELDS = ["seed", "width", "height"]
"""The core metadata parameters in the noise node"""
def create_image_metadata(
self, session: GraphExecutionState, node_id: str
) -> ImageMetadata:
metadata = self._build_metadata_from_graph(session, node_id)
return metadata
def _find_nearest_ancestor(self, G: nx.DiGraph, node_id: str) -> Optional[str]:
"""
Finds the id of the nearest ancestor (of a valid type) of a given node.
Parameters:
G (nx.DiGraph): The execution graph, converted in to a networkx DiGraph. Its nodes must
have the same data as the execution graph.
node_id (str): The ID of the node.
Returns:
str | None: The ID of the nearest ancestor, or None if there are no valid ancestors.
"""
# Retrieve the node from the graph
node = G.nodes[node_id]
# If the node type is one of the core metadata node types, return its id
if node.get("type") in self._ANCESTOR_TYPES:
return node.get("id")
# Else, look for the ancestor in the predecessor nodes
for predecessor in G.predecessors(node_id):
result = self._find_nearest_ancestor(G, predecessor)
if result:
return result
# If there are no valid ancestors, return None
return None
def _get_additional_metadata(
self, graph: Graph, node_id: str
) -> Optional[dict[str, Any]]:
"""
Returns additional metadata for a given node.
Parameters:
graph (Graph): The execution graph.
node_id (str): The ID of the node.
Returns:
dict[str, Any] | None: A dictionary of additional metadata.
"""
metadata = {}
# Iterate over all edges in the graph
for edge in graph.edges:
dest_node_id = edge.destination.node_id
dest_field = edge.destination.field
source_node_dict = graph.nodes[edge.source.node_id].dict()
# If the destination node ID matches the given node ID, gather necessary metadata
if dest_node_id == node_id:
# Prompt
if dest_field == "positive_conditioning":
metadata["positive_conditioning"] = source_node_dict.get("prompt")
# Negative prompt
if dest_field == "negative_conditioning":
metadata["negative_conditioning"] = source_node_dict.get("prompt")
# Seed, width and height
if dest_field == "noise":
for field in self._NOISE_FIELDS:
metadata[field] = source_node_dict.get(field)
return metadata
def _build_metadata_from_graph(
self, session: GraphExecutionState, node_id: str
) -> ImageMetadata:
"""
Builds an ImageMetadata object for a node.
Parameters:
session (GraphExecutionState): The session.
node_id (str): The ID of the node.
Returns:
ImageMetadata: The metadata for the node.
"""
# We need to do all the traversal on the execution graph
graph = session.execution_graph
# Find the nearest `t2l`/`l2l` ancestor of the given node
ancestor_id = self._find_nearest_ancestor(graph.nx_graph_with_data(), node_id)
# If no ancestor was found, return an empty ImageMetadata object
if ancestor_id is None:
return ImageMetadata()
ancestor_node = graph.get_node(ancestor_id)
# Grab all the core metadata from the ancestor node
ancestor_metadata = {
param: val
for param, val in ancestor_node.dict().items()
if param in self._ANCESTOR_PARAMS
}
# Get this image's prompts and noise parameters
addl_metadata = self._get_additional_metadata(graph, ancestor_id)
# If additional metadata was found, add it to the main metadata
if addl_metadata is not None:
ancestor_metadata.update(addl_metadata)
return ImageMetadata(**ancestor_metadata)

View File

@ -1,13 +1,14 @@
import datetime import datetime
from typing import Optional, Union from typing import Optional, Union
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
from invokeai.app.models.image import ImageCategory, ResourceOrigin from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.misc import get_iso_timestamp from invokeai.app.util.misc import get_iso_timestamp
class ImageRecord(BaseModel): class ImageRecord(BaseModel):
"""Deserialized image record.""" """Deserialized image record without metadata."""
image_name: str = Field(description="The unique name of the image.") image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image.""" """The unique name of the image."""
@ -43,11 +44,6 @@ class ImageRecord(BaseModel):
description="The node ID that generated this image, if it is a generated image.", description="The node ID that generated this image, if it is a generated image.",
) )
"""The node ID that generated this image, if it is a generated image.""" """The node ID that generated this image, if it is a generated image."""
metadata: Optional[ImageMetadata] = Field(
default=None,
description="A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.",
)
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
class ImageRecordChanges(BaseModel, extra=Extra.forbid): class ImageRecordChanges(BaseModel, extra=Extra.forbid):
@ -112,6 +108,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
# Retrieve all the values, setting "reasonable" defaults if they are not present. # Retrieve all the values, setting "reasonable" defaults if they are not present.
# TODO: do we really need to handle default values here? ideally the data is the correct shape...
image_name = image_dict.get("image_name", "unknown") image_name = image_dict.get("image_name", "unknown")
image_origin = ResourceOrigin( image_origin = ResourceOrigin(
image_dict.get("image_origin", ResourceOrigin.INTERNAL.value) image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)
@ -128,13 +125,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
deleted_at = image_dict.get("deleted_at", get_iso_timestamp()) deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
is_intermediate = image_dict.get("is_intermediate", False) is_intermediate = image_dict.get("is_intermediate", False)
raw_metadata = image_dict.get("metadata")
if raw_metadata is not None:
metadata = ImageMetadata.parse_raw(raw_metadata)
else:
metadata = None
return ImageRecord( return ImageRecord(
image_name=image_name, image_name=image_name,
image_origin=image_origin, image_origin=image_origin,
@ -143,7 +133,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
height=height, height=height,
session_id=session_id, session_id=session_id,
node_id=node_id, node_id=node_id,
metadata=metadata,
created_at=created_at, created_at=created_at,
updated_at=updated_at, updated_at=updated_at,
deleted_at=deleted_at, deleted_at=deleted_at,

View File

@ -1,6 +1,6 @@
import sqlite3 import sqlite3
from threading import Lock from threading import Lock
from typing import Generic, TypeVar, Optional, Union, get_args from typing import Generic, Optional, TypeVar, get_args
from pydantic import BaseModel, parse_raw_as from pydantic import BaseModel, parse_raw_as
@ -78,6 +78,21 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
return self._parse_item(result[0]) return self._parse_item(result[0])
def get_raw(self, id: str) -> Optional[str]:
try:
self._lock.acquire()
self._cursor.execute(
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
result = self._cursor.fetchone()
finally:
self._lock.release()
if not result:
return None
return result[0]
def delete(self, id: str): def delete(self, id: str):
try: try:
self._lock.acquire() self._lock.acquire()

View File

@ -22,4 +22,4 @@ class LocalUrlService(UrlServiceBase):
if thumbnail: if thumbnail:
return f"{self._base_url}/images/{image_basename}/thumbnail" return f"{self._base_url}/images/{image_basename}/thumbnail"
return f"{self._base_url}/images/{image_basename}" return f"{self._base_url}/images/{image_basename}/full"

View File

@ -0,0 +1,55 @@
import json
from typing import Optional
from pydantic import ValidationError
from invokeai.app.services.graph import Edge
def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
"""
Parses raw session string, returning a dict of the graph.
Only the general graph shape is validated; none of the fields are validated.
Any `metadata_accumulator` nodes and edges are removed.
Any validation failure will return None.
"""
graph = json.loads(session_raw).get("graph", None)
# sanity check make sure the graph is at least reasonably shaped
if (
type(graph) is not dict
or "nodes" not in graph
or type(graph["nodes"]) is not dict
or "edges" not in graph
or type(graph["edges"]) is not list
):
# something has gone terribly awry, return an empty dict
return None
try:
# delete the `metadata_accumulator` node
del graph["nodes"]["metadata_accumulator"]
except KeyError:
# no accumulator node, all good
pass
# delete any edges to or from it
for i, edge in enumerate(graph["edges"]):
try:
# try to parse the edge
Edge(**edge)
except ValidationError:
# something has gone terribly awry, return an empty dict
return None
if (
edge["source"]["node_id"] == "metadata_accumulator"
or edge["destination"]["node_id"] == "metadata_accumulator"
):
del graph["edges"][i]
return graph