diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 9c08013fef..cb867354a5 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -13,7 +13,6 @@ from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage from invokeai.app.services.boards import BoardService, BoardServiceDependencies from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.images import ImageService, ImageServiceDependencies -from invokeai.app.services.metadata import CoreMetadataService from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.urls import LocalUrlService from invokeai.backend.util.logging import InvokeAILogger @@ -75,7 +74,6 @@ class ApiDependencies: ) urls = LocalUrlService() - metadata = CoreMetadataService() image_record_storage = SqliteImageRecordStorage(db_location) image_file_storage = DiskImageFileStorage(f"{output_folder}/images") names = SimpleNameService() @@ -111,7 +109,6 @@ class ApiDependencies: board_image_record_storage=board_image_record_storage, image_record_storage=image_record_storage, image_file_storage=image_file_storage, - metadata=metadata, url=urls, logger=logger, names=names, diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index a8c84b81b9..a0428e772e 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -1,20 +1,19 @@ import io from typing import Optional -from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile -from fastapi.routing import APIRouter + +from fastapi import (Body, HTTPException, Path, Query, Request, Response, + UploadFile) from fastapi.responses import FileResponse +from fastapi.routing import APIRouter from PIL import Image -from invokeai.app.models.image import ( - ImageCategory, - ResourceOrigin, -) + +from invokeai.app.invocations.metadata import ImageMetadata 
+from invokeai.app.models.image import ImageCategory, ResourceOrigin from invokeai.app.services.image_record_storage import OffsetPaginatedResults -from invokeai.app.services.models.image_record import ( - ImageDTO, - ImageRecordChanges, - ImageUrlsDTO, -) from invokeai.app.services.item_storage import PaginatedResults +from invokeai.app.services.models.image_record import (ImageDTO, + ImageRecordChanges, + ImageUrlsDTO) from ..dependencies import ApiDependencies @@ -103,23 +102,38 @@ async def update_image( @images_router.get( - "/{image_name}/metadata", - operation_id="get_image_metadata", + "/{image_name}", + operation_id="get_image_dto", response_model=ImageDTO, ) -async def get_image_metadata( +async def get_image_dto( image_name: str = Path(description="The name of image to get"), ) -> ImageDTO: - """Gets an image's metadata""" + """Gets an image's DTO""" try: return ApiDependencies.invoker.services.images.get_dto(image_name) except Exception as e: raise HTTPException(status_code=404) +@images_router.get( + "/{image_name}/metadata", + operation_id="get_image_metadata", + response_model=ImageMetadata, +) +async def get_image_metadata( + image_name: str = Path(description="The name of image to get"), +) -> ImageMetadata: + """Gets an image's metadata""" + + try: + return ApiDependencies.invoker.services.images.get_metadata(image_name) + except Exception as e: + raise HTTPException(status_code=404) + @images_router.get( - "/{image_name}", + "/{image_name}/full", operation_id="get_image_full", response_class=Response, responses={ @@ -208,10 +222,10 @@ async def get_image_urls( @images_router.get( "/", - operation_id="list_images_with_metadata", + operation_id="list_image_dtos", response_model=OffsetPaginatedResults[ImageDTO], ) -async def list_images_with_metadata( +async def list_image_dtos( image_origin: Optional[ResourceOrigin] = Query( default=None, description="The origin of images to list" ), @@ -227,7 +241,7 @@ async def list_images_with_metadata( offset: 
int = Query(default=0, description="The page offset"), limit: int = Query(default=10, description="The number of images per page"), ) -> OffsetPaginatedResults[ImageDTO]: - """Gets a list of images""" + """Gets a list of image DTOs""" image_dtos = ApiDependencies.invoker.services.images.get_many( offset, diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py index e3c6280ccb..888d36c4bf 100644 --- a/invokeai/app/cli_app.py +++ b/invokeai/app/cli_app.py @@ -34,7 +34,6 @@ from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage from invokeai.app.services.boards import BoardService, BoardServiceDependencies from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.images import ImageService, ImageServiceDependencies -from invokeai.app.services.metadata import CoreMetadataService from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.urls import LocalUrlService from .services.default_graphs import (default_text_to_image_graph_id, @@ -244,7 +243,6 @@ def invoke_cli(): ) urls = LocalUrlService() - metadata = CoreMetadataService() image_record_storage = SqliteImageRecordStorage(db_location) image_file_storage = DiskImageFileStorage(f"{output_folder}/images") names = SimpleNameService() @@ -277,7 +275,6 @@ def invoke_cli(): board_image_record_storage=board_image_record_storage, image_record_storage=image_record_storage, image_file_storage=image_file_storage, - metadata=metadata, url=urls, logger=logger, names=names, diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 5bdeaa5a9c..da99afc50f 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -9,9 +9,9 @@ from diffusers.image_processor import VaeImageProcessor from diffusers.schedulers import SchedulerMixin as Scheduler from pydantic import BaseModel, Field, validator +from invokeai.app.invocations.metadata import CoreMetadata from 
invokeai.app.util.step_callback import stable_diffusion_step_callback -from ..models.image import ImageCategory, ImageField, ResourceOrigin from ...backend.model_management.lora import ModelPatcher from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion.diffusers_pipeline import ( @@ -21,6 +21,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \ PostprocessingSettings from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.util.devices import torch_dtype +from ..models.image import ImageCategory, ImageField, ResourceOrigin from .baseinvocation import (BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext) from .compel import ConditioningField @@ -449,6 +450,8 @@ class LatentsToImageInvocation(BaseInvocation): tiled: bool = Field( default=False, description="Decode latents by overlaping tiles(less memory consumption)") + metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image") + # Schema customisation class Config(InvocationConfig): @@ -493,7 +496,8 @@ class LatentsToImageInvocation(BaseInvocation): image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate + is_intermediate=self.is_intermediate, + metadata=self.metadata.dict() if self.metadata else None, ) return ImageOutput( diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py new file mode 100644 index 0000000000..b7639c56c7 --- /dev/null +++ b/invokeai/app/invocations/metadata.py @@ -0,0 +1,124 @@ +from typing import Literal, Optional, Union + +from pydantic import BaseModel, Field + +from invokeai.app.invocations.baseinvocation import (BaseInvocation, + BaseInvocationOutput, + InvocationContext) +from invokeai.app.invocations.controlnet_image_processors import ControlField +from invokeai.app.invocations.model import 
(LoRAModelField, MainModelField, + VAEModelField) + + +class LoRAMetadataField(BaseModel): + """LoRA metadata for an image generated in InvokeAI.""" + lora: LoRAModelField = Field(description="The LoRA model") + weight: float = Field(description="The weight of the LoRA model") + + +class CoreMetadata(BaseModel): + """Core generation metadata for an image generated in InvokeAI.""" + + generation_mode: str = Field(description="The generation mode that output this image",) + positive_prompt: str = Field(description="The positive prompt parameter") + negative_prompt: str = Field(description="The negative prompt parameter") + width: int = Field(description="The width parameter") + height: int = Field(description="The height parameter") + seed: int = Field(description="The seed used for noise generation") + rand_device: str = Field(description="The device used for random number generation") + cfg_scale: float = Field(description="The classifier-free guidance scale parameter") + steps: int = Field(description="The number of steps used for inference") + scheduler: str = Field(description="The scheduler used for inference") + clip_skip: int = Field(description="The number of skipped CLIP layers",) + model: MainModelField = Field(description="The main model used for inference") + controlnets: list[ControlField]= Field(description="The ControlNets used for inference") + loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference") + strength: Union[float, None] = Field( + default=None, + description="The strength used for latents-to-latents", + ) + init_image: Union[str, None] = Field( + default=None, description="The name of the initial image" + ) + vae: Union[VAEModelField, None] = Field( + default=None, + description="The VAE used for decoding, if the main model's default was not used", + ) + + +class ImageMetadata(BaseModel): + """An image's generation metadata""" + + metadata: Optional[dict] = Field( + default=None, + description="The image's core 
metadata, if it was created in the Linear or Canvas UI", + ) + graph: Optional[dict] = Field( + default=None, description="The graph that created the image" + ) + + +class MetadataAccumulatorOutput(BaseInvocationOutput): + """The output of the MetadataAccumulator node""" + + type: Literal["metadata_accumulator_output"] = "metadata_accumulator_output" + + metadata: CoreMetadata = Field(description="The core metadata for the image") + + +class MetadataAccumulatorInvocation(BaseInvocation): + """Outputs a Core Metadata Object""" + + type: Literal["metadata_accumulator"] = "metadata_accumulator" + + generation_mode: str = Field(description="The generation mode that output this image",) + positive_prompt: str = Field(description="The positive prompt parameter") + negative_prompt: str = Field(description="The negative prompt parameter") + width: int = Field(description="The width parameter") + height: int = Field(description="The height parameter") + seed: int = Field(description="The seed used for noise generation") + rand_device: str = Field(description="The device used for random number generation") + cfg_scale: float = Field(description="The classifier-free guidance scale parameter") + steps: int = Field(description="The number of steps used for inference") + scheduler: str = Field(description="The scheduler used for inference") + clip_skip: int = Field(description="The number of skipped CLIP layers",) + model: MainModelField = Field(description="The main model used for inference") + controlnets: list[ControlField]= Field(description="The ControlNets used for inference") + loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference") + strength: Union[float, None] = Field( + default=None, + description="The strength used for latents-to-latents", + ) + init_image: Union[str, None] = Field( + default=None, description="The name of the initial image" + ) + vae: Union[VAEModelField, None] = Field( + default=None, + description="The VAE used for 
decoding, if the main model's default was not used", + ) + + + def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput: + """Collects and outputs a CoreMetadata object""" + + return MetadataAccumulatorOutput( + metadata=CoreMetadata( + generation_mode=self.generation_mode, + positive_prompt=self.positive_prompt, + negative_prompt=self.negative_prompt, + width=self.width, + height=self.height, + seed=self.seed, + rand_device=self.rand_device, + cfg_scale=self.cfg_scale, + steps=self.steps, + scheduler=self.scheduler, + model=self.model, + strength=self.strength, + init_image=self.init_image, + vae=self.vae, + controlnets=self.controlnets, + loras=self.loras, + clip_skip=self.clip_skip, + ) + ) diff --git a/invokeai/app/models/metadata.py b/invokeai/app/models/metadata.py deleted file mode 100644 index 8d90ca0bc8..0000000000 --- a/invokeai/app/models/metadata.py +++ /dev/null @@ -1,93 +0,0 @@ -from typing import Optional, Union, List -from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr - - -class ImageMetadata(BaseModel): - """ - Core generation metadata for an image/tensor generated in InvokeAI. - - Also includes any metadata from the image's PNG tEXt chunks. - - Generated by traversing the execution graph, collecting the parameters of the nearest ancestors - of a given node. - - Full metadata may be accessed by querying for the session in the `graph_executions` table. - """ - - class Config: - extra = Extra.allow - """ - This lets the ImageMetadata class accept arbitrary additional fields. The CoreMetadataService - won't add any fields that are not already defined, but other a different metadata service - implementation might. 
- """ - - type: Optional[StrictStr] = Field( - default=None, - description="The type of the ancestor node of the image output node.", - ) - """The type of the ancestor node of the image output node.""" - positive_conditioning: Optional[StrictStr] = Field( - default=None, description="The positive conditioning." - ) - """The positive conditioning""" - negative_conditioning: Optional[StrictStr] = Field( - default=None, description="The negative conditioning." - ) - """The negative conditioning""" - width: Optional[StrictInt] = Field( - default=None, description="Width of the image/latents in pixels." - ) - """Width of the image/latents in pixels""" - height: Optional[StrictInt] = Field( - default=None, description="Height of the image/latents in pixels." - ) - """Height of the image/latents in pixels""" - seed: Optional[StrictInt] = Field( - default=None, description="The seed used for noise generation." - ) - """The seed used for noise generation""" - # cfg_scale: Optional[StrictFloat] = Field( - # cfg_scale: Union[float, list[float]] = Field( - cfg_scale: Union[StrictFloat, List[StrictFloat]] = Field( - default=None, description="The classifier-free guidance scale." - ) - """The classifier-free guidance scale""" - steps: Optional[StrictInt] = Field( - default=None, description="The number of steps used for inference." - ) - """The number of steps used for inference""" - scheduler: Optional[StrictStr] = Field( - default=None, description="The scheduler used for inference." - ) - """The scheduler used for inference""" - model: Optional[StrictStr] = Field( - default=None, description="The model used for inference." - ) - """The model used for inference""" - strength: Optional[StrictFloat] = Field( - default=None, - description="The strength used for image-to-image/latents-to-latents.", - ) - """The strength used for image-to-image/latents-to-latents.""" - latents: Optional[StrictStr] = Field( - default=None, description="The ID of the initial latents." 
- ) - """The ID of the initial latents""" - vae: Optional[StrictStr] = Field( - default=None, description="The VAE used for decoding." - ) - """The VAE used for decoding""" - unet: Optional[StrictStr] = Field( - default=None, description="The UNet used dor inference." - ) - """The UNet used dor inference""" - clip: Optional[StrictStr] = Field( - default=None, description="The CLIP Encoder used for conditioning." - ) - """The CLIP Encoder used for conditioning""" - extra: Optional[StrictStr] = Field( - default=None, - description="Uploaded image metadata, extracted from the PNG tEXt chunk.", - ) - """Uploaded image metadata, extracted from the PNG tEXt chunk.""" diff --git a/invokeai/app/services/image_file_storage.py b/invokeai/app/services/image_file_storage.py index 136964afb5..c558d2469c 100644 --- a/invokeai/app/services/image_file_storage.py +++ b/invokeai/app/services/image_file_storage.py @@ -1,14 +1,14 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team +import json from abc import ABC, abstractmethod from pathlib import Path from queue import Queue from typing import Dict, Optional, Union -from PIL.Image import Image as PILImageType from PIL import Image, PngImagePlugin +from PIL.Image import Image as PILImageType from send2trash import send2trash -from invokeai.app.models.metadata import ImageMetadata from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail @@ -59,7 +59,8 @@ class ImageFileStorageBase(ABC): self, image: PILImageType, image_name: str, - metadata: Optional[ImageMetadata] = None, + metadata: Optional[dict] = None, + graph: Optional[dict] = None, thumbnail_size: int = 256, ) -> None: """Saves an image and a 256x256 WEBP thumbnail. 
Returns a tuple of the image name, thumbnail name, and created timestamp.""" @@ -110,20 +111,22 @@ class DiskImageFileStorage(ImageFileStorageBase): self, image: PILImageType, image_name: str, - metadata: Optional[ImageMetadata] = None, + metadata: Optional[dict] = None, + graph: Optional[dict] = None, thumbnail_size: int = 256, ) -> None: try: self.__validate_storage_folders() image_path = self.get_path(image_name) + pnginfo = PngImagePlugin.PngInfo() + if metadata is not None: - pnginfo = PngImagePlugin.PngInfo() - pnginfo.add_text("invokeai", metadata.json()) - image.save(image_path, "PNG", pnginfo=pnginfo) - else: - image.save(image_path, "PNG") + pnginfo.add_text("metadata", json.dumps(metadata)) + if graph is not None: + pnginfo.add_text("graph", json.dumps(graph)) + image.save(image_path, "PNG", pnginfo=pnginfo) thumbnail_name = get_thumbnail_name(image_name) thumbnail_path = self.get_path(thumbnail_name, thumbnail=True) thumbnail_image = make_thumbnail(image, thumbnail_size) diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py index 014006eb7a..7b37307ce8 100644 --- a/invokeai/app/services/image_record_storage.py +++ b/invokeai/app/services/image_record_storage.py @@ -1,3 +1,4 @@ +import json import sqlite3 import threading from abc import ABC, abstractmethod @@ -8,7 +9,6 @@ from pydantic import BaseModel, Field from pydantic.generics import GenericModel from invokeai.app.models.image import ImageCategory, ResourceOrigin -from invokeai.app.models.metadata import ImageMetadata from invokeai.app.services.models.image_record import ( ImageRecord, ImageRecordChanges, deserialize_image_record) @@ -48,6 +48,28 @@ class ImageRecordDeleteException(Exception): super().__init__(message) +IMAGE_DTO_COLS = ", ".join( + list( + map( + lambda c: "images." 
+ c, + [ + "image_name", + "image_origin", + "image_category", + "width", + "height", + "session_id", + "node_id", + "is_intermediate", + "created_at", + "updated_at", + "deleted_at", + ], + ) + ) +) + + class ImageRecordStorageBase(ABC): """Low-level service responsible for interfacing with the image record store.""" @@ -58,6 +80,11 @@ class ImageRecordStorageBase(ABC): """Gets an image record.""" pass + @abstractmethod + def get_metadata(self, image_name: str) -> Optional[dict]: + """Gets an image's metadata'.""" + pass + @abstractmethod def update( self, @@ -102,7 +129,7 @@ class ImageRecordStorageBase(ABC): height: int, session_id: Optional[str], node_id: Optional[str], - metadata: Optional[ImageMetadata], + metadata: Optional[dict], is_intermediate: bool = False, ) -> datetime: """Saves an image record.""" @@ -206,7 +233,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): self._cursor.execute( f"""--sql - SELECT * FROM images + SELECT {IMAGE_DTO_COLS} FROM images WHERE image_name = ?; """, (image_name,), @@ -224,6 +251,28 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): return deserialize_image_record(dict(result)) + def get_metadata(self, image_name: str) -> Optional[dict]: + try: + self._lock.acquire() + + self._cursor.execute( + f"""--sql + SELECT images.metadata FROM images + WHERE image_name = ?; + """, + (image_name,), + ) + + result = cast(Optional[sqlite3.Row], self._cursor.fetchone()) + if not result or not result[0]: + return None + return json.loads(result[0]) + except sqlite3.Error as e: + self._conn.rollback() + raise ImageRecordNotFoundException from e + finally: + self._lock.release() + def update( self, image_name: str, @@ -291,8 +340,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): WHERE 1=1 """ - images_query = """--sql - SELECT images.* + images_query = f"""--sql + SELECT {IMAGE_DTO_COLS} FROM images LEFT JOIN board_images ON board_images.image_name = images.image_name WHERE 1=1 @@ -410,12 +459,12 @@ class 
SqliteImageRecordStorage(ImageRecordStorageBase): width: int, height: int, node_id: Optional[str], - metadata: Optional[ImageMetadata], + metadata: Optional[dict], is_intermediate: bool = False, ) -> datetime: try: metadata_json = ( - None if metadata is None else metadata.json(exclude_none=True) + None if metadata is None else json.dumps(metadata) ) self._lock.acquire() self._cursor.execute( @@ -465,9 +514,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): finally: self._lock.release() - def get_most_recent_image_for_board( - self, board_id: str - ) -> Optional[ImageRecord]: + def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]: try: self._lock.acquire() self._cursor.execute( diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index 4fbea1aa2d..a7d0b6ddee 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -1,39 +1,30 @@ +import json from abc import ABC, abstractmethod from logging import Logger -from typing import Optional, TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Optional + from PIL.Image import Image as PILImageType -from invokeai.app.models.image import ( - ImageCategory, - ResourceOrigin, - InvalidImageCategoryException, - InvalidOriginException, -) -from invokeai.app.models.metadata import ImageMetadata -from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase -from invokeai.app.services.image_record_storage import ( - ImageRecordDeleteException, - ImageRecordNotFoundException, - ImageRecordSaveException, - ImageRecordStorageBase, - OffsetPaginatedResults, -) -from invokeai.app.services.models.image_record import ( - ImageRecord, - ImageDTO, - ImageRecordChanges, - image_record_to_dto, -) +from invokeai.app.invocations.metadata import ImageMetadata +from invokeai.app.models.image import (ImageCategory, + InvalidImageCategoryException, + InvalidOriginException, ResourceOrigin) +from 
invokeai.app.services.board_image_record_storage import \ + BoardImageRecordStorageBase +from invokeai.app.services.graph import Graph from invokeai.app.services.image_file_storage import ( - ImageFileDeleteException, - ImageFileNotFoundException, - ImageFileSaveException, - ImageFileStorageBase, -) -from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults -from invokeai.app.services.metadata import MetadataServiceBase + ImageFileDeleteException, ImageFileNotFoundException, + ImageFileSaveException, ImageFileStorageBase) +from invokeai.app.services.image_record_storage import ( + ImageRecordDeleteException, ImageRecordNotFoundException, + ImageRecordSaveException, ImageRecordStorageBase, OffsetPaginatedResults) +from invokeai.app.services.item_storage import ItemStorageABC +from invokeai.app.services.models.image_record import (ImageDTO, ImageRecord, + ImageRecordChanges, + image_record_to_dto) from invokeai.app.services.resource_name import NameServiceBase from invokeai.app.services.urls import UrlServiceBase +from invokeai.app.util.metadata import get_metadata_graph_from_raw_session if TYPE_CHECKING: from invokeai.app.services.graph import GraphExecutionState @@ -51,6 +42,7 @@ class ImageServiceABC(ABC): node_id: Optional[str] = None, session_id: Optional[str] = None, is_intermediate: bool = False, + metadata: Optional[dict] = None, ) -> ImageDTO: """Creates an image, storing the file and its metadata.""" pass @@ -79,6 +71,11 @@ class ImageServiceABC(ABC): """Gets an image DTO.""" pass + @abstractmethod + def get_metadata(self, image_name: str) -> ImageMetadata: + """Gets an image's metadata.""" + pass + @abstractmethod def get_path(self, image_name: str, thumbnail: bool = False) -> str: """Gets an image's path.""" @@ -124,7 +121,6 @@ class ImageServiceDependencies: image_records: ImageRecordStorageBase image_files: ImageFileStorageBase board_image_records: BoardImageRecordStorageBase - metadata: MetadataServiceBase urls: UrlServiceBase 
logger: Logger names: NameServiceBase @@ -135,7 +131,6 @@ class ImageServiceDependencies: image_record_storage: ImageRecordStorageBase, image_file_storage: ImageFileStorageBase, board_image_record_storage: BoardImageRecordStorageBase, - metadata: MetadataServiceBase, url: UrlServiceBase, logger: Logger, names: NameServiceBase, @@ -144,7 +139,6 @@ class ImageServiceDependencies: self.image_records = image_record_storage self.image_files = image_file_storage self.board_image_records = board_image_record_storage - self.metadata = metadata self.urls = url self.logger = logger self.names = names @@ -165,6 +159,7 @@ class ImageService(ImageServiceABC): node_id: Optional[str] = None, session_id: Optional[str] = None, is_intermediate: bool = False, + metadata: Optional[dict] = None, ) -> ImageDTO: if image_origin not in ResourceOrigin: raise InvalidOriginException @@ -174,7 +169,16 @@ class ImageService(ImageServiceABC): image_name = self._services.names.create_image_name() - metadata = self._get_metadata(session_id, node_id) + graph = None + + if session_id is not None: + session_raw = self._services.graph_execution_manager.get_raw(session_id) + if session_raw is not None: + try: + graph = get_metadata_graph_from_raw_session(session_raw) + except Exception as e: + self._services.logger.warn(f"Failed to parse session graph: {e}") + graph = None (width, height) = image.size @@ -191,14 +195,12 @@ class ImageService(ImageServiceABC): is_intermediate=is_intermediate, # Nullable fields node_id=node_id, - session_id=session_id, metadata=metadata, + session_id=session_id, ) self._services.image_files.save( - image_name=image_name, - image=image, - metadata=metadata, + image_name=image_name, image=image, metadata=metadata, graph=graph ) image_dto = self.get_dto(image_name) @@ -268,6 +270,34 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem getting image DTO") raise e + def get_metadata(self, image_name: str) -> Optional[ImageMetadata]: + try: + 
image_record = self._services.image_records.get(image_name) + + if not image_record.session_id: + return ImageMetadata() + + session_raw = self._services.graph_execution_manager.get_raw( + image_record.session_id + ) + graph = None + + if session_raw: + try: + graph = get_metadata_graph_from_raw_session(session_raw) + except Exception as e: + self._services.logger.warn(f"Failed to parse session graph: {e}") + graph = None + + metadata = self._services.image_records.get_metadata(image_name) + return ImageMetadata(graph=graph, metadata=metadata) + except ImageRecordNotFoundException: + self._services.logger.error("Image record not found") + raise + except Exception as e: + self._services.logger.error("Problem getting image DTO") + raise e + def get_path(self, image_name: str, thumbnail: bool = False) -> str: try: return self._services.image_files.get_path(image_name, thumbnail) @@ -367,15 +397,3 @@ class ImageService(ImageServiceABC): except Exception as e: self._services.logger.error("Problem deleting image records and files") raise e - - def _get_metadata( - self, session_id: Optional[str] = None, node_id: Optional[str] = None - ) -> Optional[ImageMetadata]: - """Get the metadata for a node.""" - metadata = None - - if node_id is not None and session_id is not None: - session = self._services.graph_execution_manager.get(session_id) - metadata = self._services.metadata.create_image_metadata(session, node_id) - - return metadata diff --git a/invokeai/app/services/item_storage.py b/invokeai/app/services/item_storage.py index 394f67797d..709d88bf97 100644 --- a/invokeai/app/services/item_storage.py +++ b/invokeai/app/services/item_storage.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Callable, Generic, TypeVar +from typing import Callable, Generic, Optional, TypeVar from pydantic import BaseModel, Field from pydantic.generics import GenericModel @@ -29,14 +29,22 @@ class ItemStorageABC(ABC, Generic[T]): @abstractmethod def get(self, 
item_id: str) -> T: + """Gets the item, parsing it into a Pydantic model""" + pass + + @abstractmethod + def get_raw(self, item_id: str) -> Optional[str]: + """Gets the raw item as a string, skipping Pydantic parsing""" pass @abstractmethod def set(self, item: T) -> None: + """Sets the item""" pass @abstractmethod def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: + """Gets a paginated list of items""" pass @abstractmethod diff --git a/invokeai/app/services/metadata.py b/invokeai/app/services/metadata.py deleted file mode 100644 index cc169db3ce..0000000000 --- a/invokeai/app/services/metadata.py +++ /dev/null @@ -1,142 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Optional -import networkx as nx - -from invokeai.app.models.metadata import ImageMetadata -from invokeai.app.services.graph import Graph, GraphExecutionState - - -class MetadataServiceBase(ABC): - """Handles building metadata for nodes, images, and outputs.""" - - @abstractmethod - def create_image_metadata( - self, session: GraphExecutionState, node_id: str - ) -> ImageMetadata: - """Builds an ImageMetadata object for a node.""" - pass - - -class CoreMetadataService(MetadataServiceBase): - _ANCESTOR_TYPES = ["t2l", "l2l"] - """The ancestor types that contain the core metadata""" - - _ANCESTOR_PARAMS = ["type", "steps", "model", "cfg_scale", "scheduler", "strength"] - """The core metadata parameters in the ancestor types""" - - _NOISE_FIELDS = ["seed", "width", "height"] - """The core metadata parameters in the noise node""" - - def create_image_metadata( - self, session: GraphExecutionState, node_id: str - ) -> ImageMetadata: - metadata = self._build_metadata_from_graph(session, node_id) - - return metadata - - def _find_nearest_ancestor(self, G: nx.DiGraph, node_id: str) -> Optional[str]: - """ - Finds the id of the nearest ancestor (of a valid type) of a given node. - - Parameters: - G (nx.DiGraph): The execution graph, converted in to a networkx DiGraph. 
Its nodes must - have the same data as the execution graph. - node_id (str): The ID of the node. - - Returns: - str | None: The ID of the nearest ancestor, or None if there are no valid ancestors. - """ - - # Retrieve the node from the graph - node = G.nodes[node_id] - - # If the node type is one of the core metadata node types, return its id - if node.get("type") in self._ANCESTOR_TYPES: - return node.get("id") - - # Else, look for the ancestor in the predecessor nodes - for predecessor in G.predecessors(node_id): - result = self._find_nearest_ancestor(G, predecessor) - if result: - return result - - # If there are no valid ancestors, return None - return None - - def _get_additional_metadata( - self, graph: Graph, node_id: str - ) -> Optional[dict[str, Any]]: - """ - Returns additional metadata for a given node. - - Parameters: - graph (Graph): The execution graph. - node_id (str): The ID of the node. - - Returns: - dict[str, Any] | None: A dictionary of additional metadata. - """ - - metadata = {} - - # Iterate over all edges in the graph - for edge in graph.edges: - dest_node_id = edge.destination.node_id - dest_field = edge.destination.field - source_node_dict = graph.nodes[edge.source.node_id].dict() - - # If the destination node ID matches the given node ID, gather necessary metadata - if dest_node_id == node_id: - # Prompt - if dest_field == "positive_conditioning": - metadata["positive_conditioning"] = source_node_dict.get("prompt") - # Negative prompt - if dest_field == "negative_conditioning": - metadata["negative_conditioning"] = source_node_dict.get("prompt") - # Seed, width and height - if dest_field == "noise": - for field in self._NOISE_FIELDS: - metadata[field] = source_node_dict.get(field) - return metadata - - def _build_metadata_from_graph( - self, session: GraphExecutionState, node_id: str - ) -> ImageMetadata: - """ - Builds an ImageMetadata object for a node. - - Parameters: - session (GraphExecutionState): The session. 
- node_id (str): The ID of the node. - - Returns: - ImageMetadata: The metadata for the node. - """ - - # We need to do all the traversal on the execution graph - graph = session.execution_graph - - # Find the nearest `t2l`/`l2l` ancestor of the given node - ancestor_id = self._find_nearest_ancestor(graph.nx_graph_with_data(), node_id) - - # If no ancestor was found, return an empty ImageMetadata object - if ancestor_id is None: - return ImageMetadata() - - ancestor_node = graph.get_node(ancestor_id) - - # Grab all the core metadata from the ancestor node - ancestor_metadata = { - param: val - for param, val in ancestor_node.dict().items() - if param in self._ANCESTOR_PARAMS - } - - # Get this image's prompts and noise parameters - addl_metadata = self._get_additional_metadata(graph, ancestor_id) - - # If additional metadata was found, add it to the main metadata - if addl_metadata is not None: - ancestor_metadata.update(addl_metadata) - - return ImageMetadata(**ancestor_metadata) diff --git a/invokeai/app/services/models/image_record.py b/invokeai/app/services/models/image_record.py index c40d2138f8..cf10f6e8b2 100644 --- a/invokeai/app/services/models/image_record.py +++ b/invokeai/app/services/models/image_record.py @@ -1,13 +1,14 @@ import datetime from typing import Optional, Union + from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr + from invokeai.app.models.image import ImageCategory, ResourceOrigin -from invokeai.app.models.metadata import ImageMetadata from invokeai.app.util.misc import get_iso_timestamp class ImageRecord(BaseModel): - """Deserialized image record.""" + """Deserialized image record without metadata.""" image_name: str = Field(description="The unique name of the image.") """The unique name of the image.""" @@ -43,11 +44,6 @@ class ImageRecord(BaseModel): description="The node ID that generated this image, if it is a generated image.", ) """The node ID that generated this image, if it is a generated image.""" - metadata: 
Optional[ImageMetadata] = Field( - default=None, - description="A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.", - ) - """A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.""" class ImageRecordChanges(BaseModel, extra=Extra.forbid): @@ -112,6 +108,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: # Retrieve all the values, setting "reasonable" defaults if they are not present. + # TODO: do we really need to handle default values here? ideally the data is the correct shape... image_name = image_dict.get("image_name", "unknown") image_origin = ResourceOrigin( image_dict.get("image_origin", ResourceOrigin.INTERNAL.value) @@ -128,13 +125,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: deleted_at = image_dict.get("deleted_at", get_iso_timestamp()) is_intermediate = image_dict.get("is_intermediate", False) - raw_metadata = image_dict.get("metadata") - - if raw_metadata is not None: - metadata = ImageMetadata.parse_raw(raw_metadata) - else: - metadata = None - return ImageRecord( image_name=image_name, image_origin=image_origin, @@ -143,7 +133,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: height=height, session_id=session_id, node_id=node_id, - metadata=metadata, created_at=created_at, updated_at=updated_at, deleted_at=deleted_at, diff --git a/invokeai/app/services/sqlite.py b/invokeai/app/services/sqlite.py index e71f039bcc..8902415096 100644 --- a/invokeai/app/services/sqlite.py +++ b/invokeai/app/services/sqlite.py @@ -1,6 +1,6 @@ import sqlite3 from threading import Lock -from typing import Generic, TypeVar, Optional, Union, get_args +from typing import Generic, Optional, TypeVar, get_args from pydantic import BaseModel, parse_raw_as @@ -78,6 +78,21 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): return self._parse_item(result[0]) + def get_raw(self, id: str) -> Optional[str]: + try: + 
self._lock.acquire() + self._cursor.execute( + f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),) + ) + result = self._cursor.fetchone() + finally: + self._lock.release() + + if not result: + return None + + return result[0] + def delete(self, id: str): try: self._lock.acquire() diff --git a/invokeai/app/services/urls.py b/invokeai/app/services/urls.py index 5920e9e6c1..73d8ddadf4 100644 --- a/invokeai/app/services/urls.py +++ b/invokeai/app/services/urls.py @@ -22,4 +22,4 @@ class LocalUrlService(UrlServiceBase): if thumbnail: return f"{self._base_url}/images/{image_basename}/thumbnail" - return f"{self._base_url}/images/{image_basename}" + return f"{self._base_url}/images/{image_basename}/full" diff --git a/invokeai/app/util/metadata.py b/invokeai/app/util/metadata.py new file mode 100644 index 0000000000..5ca5f14e12 --- /dev/null +++ b/invokeai/app/util/metadata.py @@ -0,0 +1,55 @@ +import json +from typing import Optional + +from pydantic import ValidationError + +from invokeai.app.services.graph import Edge + + +def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]: + """ + Parses raw session string, returning a dict of the graph. + + Only the general graph shape is validated; none of the fields are validated. + + Any `metadata_accumulator` nodes and edges are removed. + + Any validation failure will return None. 
+    """
+
+    graph = json.loads(session_raw).get("graph", None)
+
+    # sanity check make sure the graph is at least reasonably shaped
+    if (
+        type(graph) is not dict
+        or "nodes" not in graph
+        or type(graph["nodes"]) is not dict
+        or "edges" not in graph
+        or type(graph["edges"]) is not list
+    ):
+        # something has gone terribly awry, return None
+        return None
+
+    try:
+        # delete the `metadata_accumulator` node
+        del graph["nodes"]["metadata_accumulator"]
+    except KeyError:
+        # no accumulator node, all good
+        pass
+
+    # drop any edges to or from it (filter into a new list; never delete while iterating)
+    kept_edges = []
+    for edge in graph["edges"]:
+        try:
+            Edge(**edge)
+        except ValidationError:
+            # something has gone terribly awry, return None
+            return None
+        if (
+            edge["source"]["node_id"] != "metadata_accumulator"
+            and edge["destination"]["node_id"] != "metadata_accumulator"
+        ):
+            kept_edges.append(edge)
+    graph["edges"] = kept_edges
+
+    return graph