diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 1ad53f31ca..ae351d4476 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -4,6 +4,7 @@ from logging import Logger import os from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.images import ImageService +from invokeai.app.services.metadata import CoreMetadataService from invokeai.app.services.urls import LocalUrlService from invokeai.backend.util.logging import InvokeAILogger @@ -18,7 +19,6 @@ from ..services.invocation_services import InvocationServices from ..services.invoker import Invoker from ..services.processor import DefaultInvocationProcessor from ..services.sqlite import SqliteItemStorage -from ..services.metadata import PngMetadataService from .events import FastAPIEventService @@ -59,7 +59,7 @@ class ApiDependencies: DiskLatentsStorage(f"{output_folder}/latents") ) - metadata = PngMetadataService() + metadata = CoreMetadataService() urls = LocalUrlService() @@ -80,6 +80,7 @@ class ApiDependencies: metadata=metadata, url=urls, logger=logger, + graph_execution_manager=graph_execution_manager, ) services = InvocationServices( diff --git a/invokeai/app/api/models/images.py b/invokeai/app/api/models/images.py index 866e181561..fa04702326 100644 --- a/invokeai/app/api/models/images.py +++ b/invokeai/app/api/models/images.py @@ -2,7 +2,6 @@ from typing import Optional from pydantic import BaseModel, Field from invokeai.app.models.image import ImageType -from invokeai.app.services.metadata import InvokeAIMetadata class ImageResponseMetadata(BaseModel): @@ -11,9 +10,9 @@ class ImageResponseMetadata(BaseModel): created: int = Field(description="The creation timestamp of the image") width: int = Field(description="The width of the image in pixels") height: int = Field(description="The height of the image in pixels") - invokeai: Optional[InvokeAIMetadata] = Field( - description="The image's InvokeAI-specific metadata" - ) + # invokeai: Optional[InvokeAIMetadata] = Field( + # description="The image's InvokeAI-specific metadata" + # ) class ImageResponse(BaseModel): diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 64993e011a..7259beb1a8 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, Field import torch from invokeai.app.invocations.util.choose_model import choose_model +from invokeai.app.models.image import ImageCategory from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.step_callback import stable_diffusion_step_callback @@ -356,20 +357,30 @@ class LatentsToImageInvocation(BaseInvocation): np_image = model.decode_latents(latents) image = model.numpy_to_pil(np_image)[0] - image_type = ImageType.RESULT - image_name = context.services.images.create_name( - context.graph_execution_state_id, self.id + # image_type = ImageType.RESULT + # image_name = context.services.images.create_name( + # context.graph_execution_state_id, self.id + # ) + + # metadata = context.services.metadata.build_metadata( + # session_id=context.graph_execution_state_id, node=self + # ) + + # torch.cuda.empty_cache() + + # context.services.images.save(image_type, image_name, image, metadata) + image_dto = context.services.images_new.create( + image=image, + image_type=ImageType.RESULT, + image_category=ImageCategory.IMAGE, + session_id=context.graph_execution_state_id, + node_id=self.id, ) - metadata = context.services.metadata.build_metadata( - session_id=context.graph_execution_state_id, node=self - ) - - torch.cuda.empty_cache() - - context.services.images.save(image_type, image_name, image, metadata) return build_image_output( - image_type=image_type, image_name=image_name, image=image + image_type=image_dto.image_type, + image_name=image_dto.image_name, + image=image, ) diff --git a/invokeai/app/models/metadata.py b/invokeai/app/models/metadata.py index 35998fa27e..481f2c1ff6 100644 --- a/invokeai/app/models/metadata.py +++ b/invokeai/app/models/metadata.py @@ -1,5 +1,5 @@ from typing import Optional -from pydantic import BaseModel, Field, StrictFloat, StrictInt, StrictStr +from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr class ImageMetadata(BaseModel): @@ -8,11 +8,24 @@ class ImageMetadata(BaseModel): Also includes any metadata from the image's PNG tEXt chunks. - Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node. + Generated by traversing the execution graph, collecting the parameters of the nearest ancestors + of a given node. Full metadata may be accessed by querying for the session in the `graph_executions` table. """ + class Config: + extra = Extra.allow + """ + This lets the ImageMetadata class accept arbitrary additional fields. The CoreMetadataService + won't add any fields that are not already defined, but other a different metadata service + implementation might. + """ + + type: Optional[StrictStr] = Field( + default=None, + description="The type of the ancestor node of the image output node.", + ) positive_conditioning: Optional[StrictStr] = Field( default=None, description="The positive conditioning." ) @@ -20,10 +33,10 @@ class ImageMetadata(BaseModel): default=None, description="The negative conditioning." ) width: Optional[StrictInt] = Field( - default=None, description="Width of the image/tensor in pixels." + default=None, description="Width of the image/latents in pixels." ) height: Optional[StrictInt] = Field( - default=None, description="Height of the image/tensor in pixels." + default=None, description="Height of the image/latents in pixels." ) seed: Optional[StrictInt] = Field( default=None, description="The seed used for noise generation." @@ -42,18 +55,21 @@ class ImageMetadata(BaseModel): ) strength: Optional[StrictFloat] = Field( default=None, - description="The strength used for image-to-image/tensor-to-tensor.", + description="The strength used for image-to-image/latents-to-latents.", ) - image: Optional[StrictStr] = Field( - default=None, description="The ID of the initial image." + latents: Optional[StrictStr] = Field( + default=None, description="The ID of the initial latents." ) - tensor: Optional[StrictStr] = Field( - default=None, description="The ID of the initial tensor." + vae: Optional[StrictStr] = Field( + default=None, description="The VAE used for decoding." + ) + unet: Optional[StrictStr] = Field( + default=None, description="The UNet used dor inference." + ) + clip: Optional[StrictStr] = Field( + default=None, description="The CLIP Encoder used for conditioning." ) - # Pending model refactor: - # vae: Optional[str] = Field(default=None,description="The VAE used for decoding.") - # unet: Optional[str] = Field(default=None,description="The UNet used dor inference.") - # clip: Optional[str] = Field(default=None,description="The CLIP Encoder used for conditioning.") extra: Optional[StrictStr] = Field( - default=None, description="Extra metadata, extracted from the PNG tEXt chunk." + default=None, + description="Uploaded image metadata, extracted from the PNG tEXt chunk.", ) diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py index ab6e4ed49d..44688ada0a 100644 --- a/invokeai/app/services/graph.py +++ b/invokeai/app/services/graph.py @@ -713,6 +713,13 @@ class Graph(BaseModel): g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges])) return g + def nx_graph_with_data(self) -> nx.DiGraph: + """Returns a NetworkX DiGraph representing the data and layout of this graph""" + g = nx.DiGraph() + g.add_nodes_from([n for n in self.nodes.items()]) + g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges])) + return g + def nx_graph_flat( self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None ) -> nx.DiGraph: diff --git a/invokeai/app/services/image_file_storage.py b/invokeai/app/services/image_file_storage.py index 3a99940068..dadb9584d5 100644 --- a/invokeai/app/services/image_file_storage.py +++ b/invokeai/app/services/image_file_storage.py @@ -6,11 +6,11 @@ from queue import Queue from typing import Dict, Optional from PIL.Image import Image as PILImageType -from PIL import Image -from PIL.PngImagePlugin import PngInfo +from PIL import Image, PngImagePlugin from send2trash import send2trash from invokeai.app.models.image import ImageType +from invokeai.app.models.metadata import ImageMetadata from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail @@ -54,7 +54,7 @@ class ImageFileStorageBase(ABC): image: PILImageType, image_type: ImageType, image_name: str, - pnginfo: Optional[PngInfo] = None, + metadata: Optional[ImageMetadata] = None, thumbnail_size: int = 256, ) -> None: """Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp.""" @@ -109,12 +109,18 @@ class DiskImageFileStorage(ImageFileStorageBase): image: PILImageType, image_type: ImageType, image_name: str, - pnginfo: Optional[PngInfo] = None, + metadata: Optional[ImageMetadata] = None, thumbnail_size: int = 256, ) -> None: try: image_path = self.get_path(image_type, image_name) - image.save(image_path, "PNG", pnginfo=pnginfo) + + if metadata is not None: + pnginfo = PngImagePlugin.PngInfo() + pnginfo.add_text("invokeai", metadata.json()) + image.save(image_path, "PNG", pnginfo=pnginfo) + else: + image.save(image_path, "PNG") thumbnail_name = get_thumbnail_name(image_name) thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True) diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index 9b46ebcc09..fc4c85fbdf 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -1,10 +1,9 @@ from abc import ABC, abstractmethod import json from logging import Logger -from typing import Optional, Union +from typing import Optional, TYPE_CHECKING, Union import uuid from PIL.Image import Image as PILImageType -from PIL import PngImagePlugin from invokeai.app.models.image import ImageCategory, ImageType from invokeai.app.models.metadata import ImageMetadata @@ -17,12 +16,16 @@ from invokeai.app.services.models.image_record import ( image_record_to_dto, ) from invokeai.app.services.image_file_storage import ImageFileStorageBase -from invokeai.app.services.item_storage import PaginatedResults +from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults from invokeai.app.services.metadata import MetadataServiceBase from invokeai.app.services.urls import UrlServiceBase from invokeai.app.util.misc import get_iso_timestamp +if TYPE_CHECKING: + from invokeai.app.services.graph import GraphExecutionState + + class ImageServiceABC(ABC): """ High-level service for image management. @@ -59,7 +62,9 @@ class ImageServiceABC(ABC): pass @abstractmethod - def get_url(self, image_type: ImageType, image_name: str, thumbnail: bool = False) -> str: + def get_url( + self, image_type: ImageType, image_name: str, thumbnail: bool = False + ) -> str: """Gets an image's or thumbnail's URL""" pass @@ -113,6 +118,7 @@ class ImageServiceDependencies: metadata: MetadataServiceBase urls: UrlServiceBase logger: Logger + graph_execution_manager: ItemStorageABC["GraphExecutionState"] def __init__( self, @@ -121,12 +127,14 @@ class ImageServiceDependencies: metadata: MetadataServiceBase, url: UrlServiceBase, logger: Logger, + graph_execution_manager: ItemStorageABC["GraphExecutionState"], ): self.records = image_record_storage self.files = image_file_storage self.metadata = metadata self.urls = url self.logger = logger + self.graph_execution_manager = graph_execution_manager class ImageService(ImageServiceABC): @@ -139,6 +147,7 @@ class ImageService(ImageServiceABC): metadata: MetadataServiceBase, url: UrlServiceBase, logger: Logger, + graph_execution_manager: ItemStorageABC["GraphExecutionState"], ): self._services = ImageServiceDependencies( image_record_storage=image_record_storage, @@ -146,6 +155,7 @@ class ImageService(ImageServiceABC): metadata=metadata, url=url, logger=logger, + graph_execution_manager=graph_execution_manager, ) def create( @@ -155,7 +165,6 @@ class ImageService(ImageServiceABC): image_category: ImageCategory, node_id: Optional[str] = None, session_id: Optional[str] = None, - metadata: Optional[ImageMetadata] = None, ) -> ImageDTO: image_name = self._create_image_name( image_type=image_type, @@ -165,12 +174,7 @@ class ImageService(ImageServiceABC): ) timestamp = get_iso_timestamp() - - if metadata is not None: - pnginfo = PngImagePlugin.PngInfo() - pnginfo.add_text("invokeai", json.dumps(metadata)) - else: - pnginfo = None + metadata = self._get_metadata(session_id, node_id) try: # TODO: Consider using a transaction here to ensure consistency between storage and database @@ -178,7 +182,7 @@ class ImageService(ImageServiceABC): image_type=image_type, image_name=image_name, image=image, - pnginfo=pnginfo, + metadata=metadata, ) self._services.records.save( @@ -237,24 +241,6 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem getting image record") raise e - def get_path( - self, image_type: ImageType, image_name: str, thumbnail: bool = False - ) -> str: - try: - return self._services.files.get_path(image_type, image_name, thumbnail) - except Exception as e: - self._services.logger.error("Problem getting image path") - raise e - - def get_url( - self, image_type: ImageType, image_name: str, thumbnail: bool = False - ) -> str: - try: - return self._services.urls.get_image_url(image_type, image_name, thumbnail) - except Exception as e: - self._services.logger.error("Problem getting image path") - raise e - def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO: try: image_record = self._services.records.get(image_type, image_name) @@ -273,6 +259,24 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem getting image DTO") raise e + def get_path( + self, image_type: ImageType, image_name: str, thumbnail: bool = False + ) -> str: + try: + return self._services.files.get_path(image_type, image_name, thumbnail) + except Exception as e: + self._services.logger.error("Problem getting image path") + raise e + + def get_url( + self, image_type: ImageType, image_name: str, thumbnail: bool = False + ) -> str: + try: + return self._services.urls.get_image_url(image_type, image_name, thumbnail) + except Exception as e: + self._services.logger.error("Problem getting image path") + raise e + def get_many( self, image_type: ImageType, @@ -353,3 +357,15 @@ class ImageService(ImageServiceABC): return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png" return f"{image_type.value}_{image_category.value}_{uuid_str}.png" + + def _get_metadata( + self, session_id: Optional[str] = None, node_id: Optional[str] = None + ) -> Union[ImageMetadata, None]: + """Get the metadata for a node.""" + metadata = None + + if node_id is not None and session_id is not None: + session = self._services.graph_execution_manager.get(session_id) + metadata = self._services.metadata.create_image_metadata(session, node_id) + + return metadata diff --git a/invokeai/app/services/metadata.py b/invokeai/app/services/metadata.py index 40ec189cd0..07509f4e3c 100644 --- a/invokeai/app/services/metadata.py +++ b/invokeai/app/services/metadata.py @@ -1,295 +1,142 @@ import json from abc import ABC, abstractmethod -from typing import Any, Dict, Optional, TypedDict -from PIL import Image, PngImagePlugin -from pydantic import BaseModel +from typing import Any, Union +import networkx as nx -from invokeai.app.models.image import ImageType, is_image_type - - -class MetadataImageField(TypedDict): - """Pydantic-less ImageField, used for metadata parsing.""" - - image_type: ImageType - image_name: str - - -class MetadataLatentsField(TypedDict): - """Pydantic-less LatentsField, used for metadata parsing.""" - - latents_name: str - - -class MetadataColorField(TypedDict): - """Pydantic-less ColorField, used for metadata parsing""" - - r: int - g: int - b: int - a: int - - -# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports -NodeMetadata = Dict[ - str, - None - | str - | int - | float - | bool - | MetadataImageField - | MetadataLatentsField - | MetadataColorField, -] - - -class InvokeAIMetadata(TypedDict, total=False): - """InvokeAI-specific metadata format.""" - - session_id: Optional[str] - node: Optional[NodeMetadata] - - -def build_invokeai_metadata_pnginfo( - metadata: InvokeAIMetadata | None, -) -> PngImagePlugin.PngInfo: - """Builds a PngInfo object with key `"invokeai"` and value `metadata`""" - pnginfo = PngImagePlugin.PngInfo() - - if metadata is not None: - pnginfo.add_text("invokeai", json.dumps(metadata)) - - return pnginfo +from invokeai.app.models.metadata import ImageMetadata +from invokeai.app.services.graph import Edge, Graph, GraphExecutionState class MetadataServiceBase(ABC): - @abstractmethod - def get_metadata(self, image: Image.Image) -> InvokeAIMetadata | None: - """Gets the InvokeAI metadata from a PIL Image, skipping invalid values""" - pass + """Handles building metadata for nodes, images, and outputs.""" @abstractmethod - def build_metadata( - self, session_id: str, node: BaseModel - ) -> InvokeAIMetadata | None: - """Builds an InvokeAIMetadata object""" + def create_image_metadata( + self, session: GraphExecutionState, node_id: str + ) -> ImageMetadata: + """Builds an ImageMetadata object for a node.""" pass - # @abstractmethod - # def create_metadata(self, session_id: str, node_id: str) -> dict: - # """Creates metadata for a result""" - # pass -class PngMetadataService(MetadataServiceBase): - """Handles loading and building metadata for images.""" +class CoreMetadataService(MetadataServiceBase): + _ANCESTOR_TYPES = ["t2l", "l2l"] + """The ancestor types that contain the core metadata""" - # TODO: Use `InvocationsUnion` to **validate** metadata as representing a fully-functioning node - def _load_metadata(self, image: Image.Image) -> dict | None: - """Loads a specific info entry from a PIL Image.""" + _ANCESTOR_PARAMS = ["type", "steps", "model", "cfg_scale", "scheduler", "strength"] + """The core metadata parameters in the ancestor types""" - try: - info = image.info.get("invokeai") + _NOISE_FIELDS = ["seed", "width", "height"] + """The core metadata parameters in the noise node""" - if type(info) is not str: - return None - - loaded_metadata = json.loads(info) - - if type(loaded_metadata) is not dict: - return None - - if len(loaded_metadata.items()) == 0: - return None - - return loaded_metadata - except: - return None - - def get_metadata(self, image: Image.Image) -> dict | None: - """Retrieves an image's metadata as a dict""" - loaded_metadata = self._load_metadata(image) - - return loaded_metadata - - def build_metadata(self, session_id: str, node: BaseModel) -> InvokeAIMetadata: - metadata = InvokeAIMetadata(session_id=session_id, node=node.dict()) + def create_image_metadata( + self, session: GraphExecutionState, node_id: str + ) -> ImageMetadata: + metadata = self._build_metadata_from_graph(session, node_id) return metadata + def _find_nearest_ancestor(self, G: nx.DiGraph, node_id: str) -> Union[str, None]: + """ + Finds the id of the nearest ancestor (of a valid type) of a given node. -from enum import Enum + Parameters: + G (nx.DiGraph): The execution graph, converted in to a networkx DiGraph. Its nodes must + have the same data as the execution graph. + node_id (str): The ID of the node. -from abc import ABC, abstractmethod -import json -import sqlite3 -from threading import Lock -from typing import Any, Union + Returns: + str | None: The ID of the nearest ancestor, or None if there are no valid ancestors. + """ -import networkx as nx + # Retrieve the node from the graph + node = G.nodes[node_id] -from pydantic import BaseModel, Field, parse_obj_as, parse_raw_as -from invokeai.app.invocations.image import ImageOutput -from invokeai.app.services.graph import Edge, GraphExecutionState -from invokeai.app.invocations.latent import LatentsOutput -from invokeai.app.services.item_storage import PaginatedResults -from invokeai.app.util.misc import get_timestamp + # If the node type is one of the core metadata node types, return its id + if node.get("type") in self._ANCESTOR_TYPES: + return node.get("id") + # Else, look for the ancestor in the predecessor nodes + for predecessor in G.predecessors(node_id): + result = self._find_nearest_ancestor(G, predecessor) + if result: + return result -class ResultType(str, Enum): - image_output = "image_output" - latents_output = "latents_output" + # If there are no valid ancestors, return None + return None + def _get_additional_metadata( + self, graph: Graph, node_id: str + ) -> Union[dict[str, Any], None]: + """ + Returns additional metadata for a given node. -class Result(BaseModel): - """A session result""" + Parameters: + graph (Graph): The execution graph. + node_id (str): The ID of the node. - id: str = Field(description="Result ID") - session_id: str = Field(description="Session ID") - node_id: str = Field(description="Node ID") - data: Union[LatentsOutput, ImageOutput] = Field(description="The result data") + Returns: + dict[str, Any] | None: A dictionary of additional metadata. + """ + metadata = {} -class ResultWithSession(BaseModel): - """A result with its session""" + # Iterate over all edges in the graph + for edge in graph.edges: + dest_node_id = edge.destination.node_id + dest_field = edge.destination.field + source_node_dict = graph.nodes[edge.source.node_id].dict() - result: Result = Field(description="The result") - session: GraphExecutionState = Field(description="The session") + # If the destination node ID matches the given node ID, gather necessary metadata + if dest_node_id == node_id: + # If the destination field is 'positive_conditioning', add the 'prompt' from the source node + if dest_field == "positive_conditioning": + metadata["positive_conditioning"] = source_node_dict.get("prompt") + # If the destination field is 'negative_conditioning', add the 'prompt' from the source node + if dest_field == "negative_conditioning": + metadata["negative_conditioning"] = source_node_dict.get("prompt") + # If the destination field is 'noise', add the core noise fields from the source node + if dest_field == "noise": + for field in self._NOISE_FIELDS: + metadata[field] = source_node_dict.get(field) + return metadata + def _build_metadata_from_graph( + self, session: GraphExecutionState, node_id: str + ) -> ImageMetadata: + """ + Builds an ImageMetadata object for a node. -# # Create a directed graph -# from typing import Any, TypedDict, Union -# from networkx import DiGraph -# import networkx as nx -# import json + Parameters: + session (GraphExecutionState): The session. + node_id (str): The ID of the node. + Returns: + ImageMetadata: The metadata for the node. + """ -# # We need to use a loose class for nodes to allow for graceful parsing - we cannot use the stricter -# # model used by the system, because we may be a graph in an old format. We can, however, use the -# # Edge model, because the edge format does not change. -# class LooseGraph(BaseModel): -# id: str -# nodes: dict[str, dict[str, Any]] -# edges: list[Edge] + graph = session.execution_graph + # Find the nearest ancestor of the given node + ancestor_id = self._find_nearest_ancestor(graph.nx_graph_with_data(), node_id) -# # An intermediate type used during parsing -# class NearestAncestor(TypedDict): -# node_id: str -# metadata: dict[str, Any] + # If no ancestor was found, return an empty ImageMetadata object + if ancestor_id is None: + return ImageMetadata() + ancestor_node = graph.get_node(ancestor_id) -# # The ancestor types that contain the core metadata -# ANCESTOR_TYPES = ['t2l', 'l2l'] + ancestor_metadata = { + param: val + for param, val in ancestor_node.dict().items() + if param in self._ANCESTOR_PARAMS + } -# # The core metadata parameters in the ancestor types -# ANCESTOR_PARAMS = ['steps', 'model', 'cfg_scale', 'scheduler', 'strength'] + # Get additional metadata related to the ancestor + addl_metadata = self._get_additional_metadata(graph, ancestor_id) -# # The core metadata parameters in the noise node -# NOISE_FIELDS = ['seed', 'width', 'height'] + # If additional metadata was found, add it to the main metadata + if addl_metadata is not None: + ancestor_metadata.update(addl_metadata) -# # Find nearest t2l or l2l ancestor from a given l2i node -# def find_nearest_ancestor(G: DiGraph, node_id: str) -> Union[NearestAncestor, None]: -# """Returns metadata for the nearest ancestor of a given node. - -# Parameters: -# G (DiGraph): A directed graph. -# node_id (str): The ID of the starting node. - -# Returns: -# NearestAncestor | None: An object with the ID and metadata of the nearest ancestor. -# """ - -# # Retrieve the node from the graph -# node = G.nodes[node_id] - -# # If the node type is one of the core metadata node types, gather necessary metadata and return -# if node.get('type') in ANCESTOR_TYPES: -# parsed_metadata = {param: val for param, val in node.items() if param in ANCESTOR_PARAMS} -# return NearestAncestor(node_id=node_id, metadata=parsed_metadata) - - -# # Else, look for the ancestor in the predecessor nodes -# for predecessor in G.predecessors(node_id): -# result = find_nearest_ancestor(G, predecessor) -# if result: -# return result - -# # If there are no valid ancestors, return None -# return None - - -# def get_additional_metadata(graph: LooseGraph, node_id: str) -> Union[dict[str, Any], None]: -# """Collects additional metadata from nodes connected to a given node. - -# Parameters: -# graph (LooseGraph): The graph. -# node_id (str): The ID of the node. - -# Returns: -# dict | None: A dictionary containing additional metadata. -# """ - -# metadata = {} - -# # Iterate over all edges in the graph -# for edge in graph.edges: -# dest_node_id = edge.destination.node_id -# dest_field = edge.destination.field -# source_node = graph.nodes[edge.source.node_id] - -# # If the destination node ID matches the given node ID, gather necessary metadata -# if dest_node_id == node_id: -# # If the destination field is 'positive_conditioning', add the 'prompt' from the source node -# if dest_field == 'positive_conditioning': -# metadata['positive_conditioning'] = source_node.get('prompt') -# # If the destination field is 'negative_conditioning', add the 'prompt' from the source node -# if dest_field == 'negative_conditioning': -# metadata['negative_conditioning'] = source_node.get('prompt') -# # If the destination field is 'noise', add the core noise fields from the source node -# if dest_field == 'noise': -# for field in NOISE_FIELDS: -# metadata[field] = source_node.get(field) -# return metadata - -# def build_core_metadata(graph_raw: str, node_id: str) -> Union[dict, None]: -# """Builds the core metadata for a given node. - -# Parameters: -# graph_raw (str): The graph structure as a raw string. -# node_id (str): The ID of the node. - -# Returns: -# dict | None: A dictionary containing core metadata. -# """ - -# # Create a directed graph to facilitate traversal -# G = nx.DiGraph() - -# # Convert the raw graph string into a JSON object -# graph = parse_obj_as(LooseGraph, graph_raw) - -# # Add nodes and edges to the graph -# for node_id, node_data in graph.nodes.items(): -# G.add_node(node_id, **node_data) -# for edge in graph.edges: -# G.add_edge(edge.source.node_id, edge.destination.node_id) - -# # Find the nearest ancestor of the given node -# ancestor = find_nearest_ancestor(G, node_id) - -# # If no ancestor was found, return None -# if ancestor is None: -# return None - -# metadata = ancestor['metadata'] -# ancestor_id = ancestor['node_id'] - -# # Get additional metadata related to the ancestor -# addl_metadata = get_additional_metadata(graph, ancestor_id) - -# # If additional metadata was found, add it to the main metadata -# if addl_metadata is not None: -# metadata.update(addl_metadata) - -# return metadata + return ImageMetadata(**ancestor_metadata) diff --git a/invokeai/app/services/models/image_record.py b/invokeai/app/services/models/image_record.py index 6e15574eb9..29a2d71232 100644 --- a/invokeai/app/services/models/image_record.py +++ b/invokeai/app/services/models/image_record.py @@ -62,9 +62,12 @@ def deserialize_image_record(image_row: sqlite3.Row) -> ImageRecord: image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value)) - raw_metadata = image_dict.get("metadata", "{}") + raw_metadata = image_dict.get("metadata") - metadata = ImageMetadata.parse_raw(raw_metadata) + if raw_metadata is not None: + metadata = ImageMetadata.parse_raw(raw_metadata) + else: + metadata = None return ImageRecord( image_name=image_dict.get("id", "unknown"),