feat(nodes): add metadata handling

This commit is contained in:
psychedelicious 2023-05-22 15:48:12 +10:00 committed by Kent Keirsey
parent f071b03ceb
commit 5de3c41d19
9 changed files with 228 additions and 322 deletions

View File

@ -4,6 +4,7 @@ from logging import Logger
import os import os
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.urls import LocalUrlService from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
@ -18,7 +19,6 @@ from ..services.invocation_services import InvocationServices
from ..services.invoker import Invoker from ..services.invoker import Invoker
from ..services.processor import DefaultInvocationProcessor from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage from ..services.sqlite import SqliteItemStorage
from ..services.metadata import PngMetadataService
from .events import FastAPIEventService from .events import FastAPIEventService
@ -59,7 +59,7 @@ class ApiDependencies:
DiskLatentsStorage(f"{output_folder}/latents") DiskLatentsStorage(f"{output_folder}/latents")
) )
metadata = PngMetadataService() metadata = CoreMetadataService()
urls = LocalUrlService() urls = LocalUrlService()
@ -80,6 +80,7 @@ class ApiDependencies:
metadata=metadata, metadata=metadata,
url=urls, url=urls,
logger=logger, logger=logger,
graph_execution_manager=graph_execution_manager,
) )
services = InvocationServices( services = InvocationServices(

View File

@ -2,7 +2,6 @@ from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageType from invokeai.app.models.image import ImageType
from invokeai.app.services.metadata import InvokeAIMetadata
class ImageResponseMetadata(BaseModel): class ImageResponseMetadata(BaseModel):
@ -11,9 +10,9 @@ class ImageResponseMetadata(BaseModel):
created: int = Field(description="The creation timestamp of the image") created: int = Field(description="The creation timestamp of the image")
width: int = Field(description="The width of the image in pixels") width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels") height: int = Field(description="The height of the image in pixels")
invokeai: Optional[InvokeAIMetadata] = Field( # invokeai: Optional[InvokeAIMetadata] = Field(
description="The image's InvokeAI-specific metadata" # description="The image's InvokeAI-specific metadata"
) # )
class ImageResponse(BaseModel): class ImageResponse(BaseModel):

View File

@ -7,6 +7,7 @@ from pydantic import BaseModel, Field
import torch import torch
from invokeai.app.invocations.util.choose_model import choose_model from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.models.image import ImageCategory
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
@ -356,20 +357,30 @@ class LatentsToImageInvocation(BaseInvocation):
np_image = model.decode_latents(latents) np_image = model.decode_latents(latents)
image = model.numpy_to_pil(np_image)[0] image = model.numpy_to_pil(np_image)[0]
image_type = ImageType.RESULT # image_type = ImageType.RESULT
image_name = context.services.images.create_name( # image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id # context.graph_execution_state_id, self.id
# )
# metadata = context.services.metadata.build_metadata(
# session_id=context.graph_execution_state_id, node=self
# )
# torch.cuda.empty_cache()
# context.services.images.save(image_type, image_name, image, metadata)
image_dto = context.services.images_new.create(
image=image,
image_type=ImageType.RESULT,
image_category=ImageCategory.IMAGE,
session_id=context.graph_execution_state_id,
node_id=self.id,
) )
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
torch.cuda.empty_cache()
context.services.images.save(image_type, image_name, image, metadata)
return build_image_output( return build_image_output(
image_type=image_type, image_name=image_name, image=image image_type=image_dto.image_type,
image_name=image_dto.image_name,
image=image,
) )

View File

@ -1,5 +1,5 @@
from typing import Optional from typing import Optional
from pydantic import BaseModel, Field, StrictFloat, StrictInt, StrictStr from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr
class ImageMetadata(BaseModel): class ImageMetadata(BaseModel):
@ -8,11 +8,24 @@ class ImageMetadata(BaseModel):
Also includes any metadata from the image's PNG tEXt chunks. Also includes any metadata from the image's PNG tEXt chunks.
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node. Generated by traversing the execution graph, collecting the parameters of the nearest ancestors
of a given node.
Full metadata may be accessed by querying for the session in the `graph_executions` table. Full metadata may be accessed by querying for the session in the `graph_executions` table.
""" """
class Config:
extra = Extra.allow
"""
This lets the ImageMetadata class accept arbitrary additional fields. The CoreMetadataService
won't add any fields that are not already defined, but other a different metadata service
implementation might.
"""
type: Optional[StrictStr] = Field(
default=None,
description="The type of the ancestor node of the image output node.",
)
positive_conditioning: Optional[StrictStr] = Field( positive_conditioning: Optional[StrictStr] = Field(
default=None, description="The positive conditioning." default=None, description="The positive conditioning."
) )
@ -20,10 +33,10 @@ class ImageMetadata(BaseModel):
default=None, description="The negative conditioning." default=None, description="The negative conditioning."
) )
width: Optional[StrictInt] = Field( width: Optional[StrictInt] = Field(
default=None, description="Width of the image/tensor in pixels." default=None, description="Width of the image/latents in pixels."
) )
height: Optional[StrictInt] = Field( height: Optional[StrictInt] = Field(
default=None, description="Height of the image/tensor in pixels." default=None, description="Height of the image/latents in pixels."
) )
seed: Optional[StrictInt] = Field( seed: Optional[StrictInt] = Field(
default=None, description="The seed used for noise generation." default=None, description="The seed used for noise generation."
@ -42,18 +55,21 @@ class ImageMetadata(BaseModel):
) )
strength: Optional[StrictFloat] = Field( strength: Optional[StrictFloat] = Field(
default=None, default=None,
description="The strength used for image-to-image/tensor-to-tensor.", description="The strength used for image-to-image/latents-to-latents.",
) )
image: Optional[StrictStr] = Field( latents: Optional[StrictStr] = Field(
default=None, description="The ID of the initial image." default=None, description="The ID of the initial latents."
) )
tensor: Optional[StrictStr] = Field( vae: Optional[StrictStr] = Field(
default=None, description="The ID of the initial tensor." default=None, description="The VAE used for decoding."
)
unet: Optional[StrictStr] = Field(
default=None, description="The UNet used dor inference."
)
clip: Optional[StrictStr] = Field(
default=None, description="The CLIP Encoder used for conditioning."
) )
# Pending model refactor:
# vae: Optional[str] = Field(default=None,description="The VAE used for decoding.")
# unet: Optional[str] = Field(default=None,description="The UNet used dor inference.")
# clip: Optional[str] = Field(default=None,description="The CLIP Encoder used for conditioning.")
extra: Optional[StrictStr] = Field( extra: Optional[StrictStr] = Field(
default=None, description="Extra metadata, extracted from the PNG tEXt chunk." default=None,
description="Uploaded image metadata, extracted from the PNG tEXt chunk.",
) )

View File

@ -713,6 +713,13 @@ class Graph(BaseModel):
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges])) g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
return g return g
def nx_graph_with_data(self) -> nx.DiGraph:
"""Returns a NetworkX DiGraph representing the data and layout of this graph"""
g = nx.DiGraph()
g.add_nodes_from([n for n in self.nodes.items()])
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
return g
def nx_graph_flat( def nx_graph_flat(
self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None
) -> nx.DiGraph: ) -> nx.DiGraph:

View File

@ -6,11 +6,11 @@ from queue import Queue
from typing import Dict, Optional from typing import Dict, Optional
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
from PIL import Image from PIL import Image, PngImagePlugin
from PIL.PngImagePlugin import PngInfo
from send2trash import send2trash from send2trash import send2trash
from invokeai.app.models.image import ImageType from invokeai.app.models.image import ImageType
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
@ -54,7 +54,7 @@ class ImageFileStorageBase(ABC):
image: PILImageType, image: PILImageType,
image_type: ImageType, image_type: ImageType,
image_name: str, image_name: str,
pnginfo: Optional[PngInfo] = None, metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
) -> None: ) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp.""" """Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
@ -109,12 +109,18 @@ class DiskImageFileStorage(ImageFileStorageBase):
image: PILImageType, image: PILImageType,
image_type: ImageType, image_type: ImageType,
image_name: str, image_name: str,
pnginfo: Optional[PngInfo] = None, metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
) -> None: ) -> None:
try: try:
image_path = self.get_path(image_type, image_name) image_path = self.get_path(image_type, image_name)
image.save(image_path, "PNG", pnginfo=pnginfo)
if metadata is not None:
pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text("invokeai", metadata.json())
image.save(image_path, "PNG", pnginfo=pnginfo)
else:
image.save(image_path, "PNG")
thumbnail_name = get_thumbnail_name(image_name) thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True) thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True)

View File

@ -1,10 +1,9 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import json import json
from logging import Logger from logging import Logger
from typing import Optional, Union from typing import Optional, TYPE_CHECKING, Union
import uuid import uuid
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
from PIL import PngImagePlugin
from invokeai.app.models.image import ImageCategory, ImageType from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.metadata import ImageMetadata
@ -17,12 +16,16 @@ from invokeai.app.services.models.image_record import (
image_record_to_dto, image_record_to_dto,
) )
from invokeai.app.services.image_file_storage import ImageFileStorageBase from invokeai.app.services.image_file_storage import ImageFileStorageBase
from invokeai.app.services.item_storage import PaginatedResults from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
from invokeai.app.services.metadata import MetadataServiceBase from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.app.services.urls import UrlServiceBase from invokeai.app.services.urls import UrlServiceBase
from invokeai.app.util.misc import get_iso_timestamp from invokeai.app.util.misc import get_iso_timestamp
if TYPE_CHECKING:
from invokeai.app.services.graph import GraphExecutionState
class ImageServiceABC(ABC): class ImageServiceABC(ABC):
""" """
High-level service for image management. High-level service for image management.
@ -59,7 +62,9 @@ class ImageServiceABC(ABC):
pass pass
@abstractmethod @abstractmethod
def get_url(self, image_type: ImageType, image_name: str, thumbnail: bool = False) -> str: def get_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
"""Gets an image's or thumbnail's URL""" """Gets an image's or thumbnail's URL"""
pass pass
@ -113,6 +118,7 @@ class ImageServiceDependencies:
metadata: MetadataServiceBase metadata: MetadataServiceBase
urls: UrlServiceBase urls: UrlServiceBase
logger: Logger logger: Logger
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
def __init__( def __init__(
self, self,
@ -121,12 +127,14 @@ class ImageServiceDependencies:
metadata: MetadataServiceBase, metadata: MetadataServiceBase,
url: UrlServiceBase, url: UrlServiceBase,
logger: Logger, logger: Logger,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
): ):
self.records = image_record_storage self.records = image_record_storage
self.files = image_file_storage self.files = image_file_storage
self.metadata = metadata self.metadata = metadata
self.urls = url self.urls = url
self.logger = logger self.logger = logger
self.graph_execution_manager = graph_execution_manager
class ImageService(ImageServiceABC): class ImageService(ImageServiceABC):
@ -139,6 +147,7 @@ class ImageService(ImageServiceABC):
metadata: MetadataServiceBase, metadata: MetadataServiceBase,
url: UrlServiceBase, url: UrlServiceBase,
logger: Logger, logger: Logger,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
): ):
self._services = ImageServiceDependencies( self._services = ImageServiceDependencies(
image_record_storage=image_record_storage, image_record_storage=image_record_storage,
@ -146,6 +155,7 @@ class ImageService(ImageServiceABC):
metadata=metadata, metadata=metadata,
url=url, url=url,
logger=logger, logger=logger,
graph_execution_manager=graph_execution_manager,
) )
def create( def create(
@ -155,7 +165,6 @@ class ImageService(ImageServiceABC):
image_category: ImageCategory, image_category: ImageCategory,
node_id: Optional[str] = None, node_id: Optional[str] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
metadata: Optional[ImageMetadata] = None,
) -> ImageDTO: ) -> ImageDTO:
image_name = self._create_image_name( image_name = self._create_image_name(
image_type=image_type, image_type=image_type,
@ -165,12 +174,7 @@ class ImageService(ImageServiceABC):
) )
timestamp = get_iso_timestamp() timestamp = get_iso_timestamp()
metadata = self._get_metadata(session_id, node_id)
if metadata is not None:
pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text("invokeai", json.dumps(metadata))
else:
pnginfo = None
try: try:
# TODO: Consider using a transaction here to ensure consistency between storage and database # TODO: Consider using a transaction here to ensure consistency between storage and database
@ -178,7 +182,7 @@ class ImageService(ImageServiceABC):
image_type=image_type, image_type=image_type,
image_name=image_name, image_name=image_name,
image=image, image=image,
pnginfo=pnginfo, metadata=metadata,
) )
self._services.records.save( self._services.records.save(
@ -237,24 +241,6 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image record") self._services.logger.error("Problem getting image record")
raise e raise e
def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
try:
return self._services.files.get_path(image_type, image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
def get_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
try:
return self._services.urls.get_image_url(image_type, image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO: def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
try: try:
image_record = self._services.records.get(image_type, image_name) image_record = self._services.records.get(image_type, image_name)
@ -273,6 +259,24 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image DTO") self._services.logger.error("Problem getting image DTO")
raise e raise e
def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
try:
return self._services.files.get_path(image_type, image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
def get_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
try:
return self._services.urls.get_image_url(image_type, image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
def get_many( def get_many(
self, self,
image_type: ImageType, image_type: ImageType,
@ -353,3 +357,15 @@ class ImageService(ImageServiceABC):
return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png" return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png"
return f"{image_type.value}_{image_category.value}_{uuid_str}.png" return f"{image_type.value}_{image_category.value}_{uuid_str}.png"
def _get_metadata(
self, session_id: Optional[str] = None, node_id: Optional[str] = None
) -> Union[ImageMetadata, None]:
"""Get the metadata for a node."""
metadata = None
if node_id is not None and session_id is not None:
session = self._services.graph_execution_manager.get(session_id)
metadata = self._services.metadata.create_image_metadata(session, node_id)
return metadata

View File

@ -1,295 +1,142 @@
import json import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, TypedDict from typing import Any, Union
from PIL import Image, PngImagePlugin import networkx as nx
from pydantic import BaseModel
from invokeai.app.models.image import ImageType, is_image_type from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.graph import Edge, Graph, GraphExecutionState
class MetadataImageField(TypedDict):
"""Pydantic-less ImageField, used for metadata parsing."""
image_type: ImageType
image_name: str
class MetadataLatentsField(TypedDict):
"""Pydantic-less LatentsField, used for metadata parsing."""
latents_name: str
class MetadataColorField(TypedDict):
"""Pydantic-less ColorField, used for metadata parsing"""
r: int
g: int
b: int
a: int
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
NodeMetadata = Dict[
str,
None
| str
| int
| float
| bool
| MetadataImageField
| MetadataLatentsField
| MetadataColorField,
]
class InvokeAIMetadata(TypedDict, total=False):
"""InvokeAI-specific metadata format."""
session_id: Optional[str]
node: Optional[NodeMetadata]
def build_invokeai_metadata_pnginfo(
metadata: InvokeAIMetadata | None,
) -> PngImagePlugin.PngInfo:
"""Builds a PngInfo object with key `"invokeai"` and value `metadata`"""
pnginfo = PngImagePlugin.PngInfo()
if metadata is not None:
pnginfo.add_text("invokeai", json.dumps(metadata))
return pnginfo
class MetadataServiceBase(ABC): class MetadataServiceBase(ABC):
@abstractmethod """Handles building metadata for nodes, images, and outputs."""
def get_metadata(self, image: Image.Image) -> InvokeAIMetadata | None:
"""Gets the InvokeAI metadata from a PIL Image, skipping invalid values"""
pass
@abstractmethod @abstractmethod
def build_metadata( def create_image_metadata(
self, session_id: str, node: BaseModel self, session: GraphExecutionState, node_id: str
) -> InvokeAIMetadata | None: ) -> ImageMetadata:
"""Builds an InvokeAIMetadata object""" """Builds an ImageMetadata object for a node."""
pass pass
# @abstractmethod
# def create_metadata(self, session_id: str, node_id: str) -> dict:
# """Creates metadata for a result"""
# pass
class PngMetadataService(MetadataServiceBase): class CoreMetadataService(MetadataServiceBase):
"""Handles loading and building metadata for images.""" _ANCESTOR_TYPES = ["t2l", "l2l"]
"""The ancestor types that contain the core metadata"""
# TODO: Use `InvocationsUnion` to **validate** metadata as representing a fully-functioning node _ANCESTOR_PARAMS = ["type", "steps", "model", "cfg_scale", "scheduler", "strength"]
def _load_metadata(self, image: Image.Image) -> dict | None: """The core metadata parameters in the ancestor types"""
"""Loads a specific info entry from a PIL Image."""
try: _NOISE_FIELDS = ["seed", "width", "height"]
info = image.info.get("invokeai") """The core metadata parameters in the noise node"""
if type(info) is not str: def create_image_metadata(
return None self, session: GraphExecutionState, node_id: str
) -> ImageMetadata:
loaded_metadata = json.loads(info) metadata = self._build_metadata_from_graph(session, node_id)
if type(loaded_metadata) is not dict:
return None
if len(loaded_metadata.items()) == 0:
return None
return loaded_metadata
except:
return None
def get_metadata(self, image: Image.Image) -> dict | None:
"""Retrieves an image's metadata as a dict"""
loaded_metadata = self._load_metadata(image)
return loaded_metadata
def build_metadata(self, session_id: str, node: BaseModel) -> InvokeAIMetadata:
metadata = InvokeAIMetadata(session_id=session_id, node=node.dict())
return metadata return metadata
def _find_nearest_ancestor(self, G: nx.DiGraph, node_id: str) -> Union[str, None]:
"""
Finds the id of the nearest ancestor (of a valid type) of a given node.
from enum import Enum Parameters:
G (nx.DiGraph): The execution graph, converted in to a networkx DiGraph. Its nodes must
have the same data as the execution graph.
node_id (str): The ID of the node.
from abc import ABC, abstractmethod Returns:
import json str | None: The ID of the nearest ancestor, or None if there are no valid ancestors.
import sqlite3 """
from threading import Lock
from typing import Any, Union
import networkx as nx # Retrieve the node from the graph
node = G.nodes[node_id]
from pydantic import BaseModel, Field, parse_obj_as, parse_raw_as # If the node type is one of the core metadata node types, return its id
from invokeai.app.invocations.image import ImageOutput if node.get("type") in self._ANCESTOR_TYPES:
from invokeai.app.services.graph import Edge, GraphExecutionState return node.get("id")
from invokeai.app.invocations.latent import LatentsOutput
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.util.misc import get_timestamp
# Else, look for the ancestor in the predecessor nodes
for predecessor in G.predecessors(node_id):
result = self._find_nearest_ancestor(G, predecessor)
if result:
return result
class ResultType(str, Enum): # If there are no valid ancestors, return None
image_output = "image_output" return None
latents_output = "latents_output"
def _get_additional_metadata(
self, graph: Graph, node_id: str
) -> Union[dict[str, Any], None]:
"""
Returns additional metadata for a given node.
class Result(BaseModel): Parameters:
"""A session result""" graph (Graph): The execution graph.
node_id (str): The ID of the node.
id: str = Field(description="Result ID") Returns:
session_id: str = Field(description="Session ID") dict[str, Any] | None: A dictionary of additional metadata.
node_id: str = Field(description="Node ID") """
data: Union[LatentsOutput, ImageOutput] = Field(description="The result data")
metadata = {}
class ResultWithSession(BaseModel): # Iterate over all edges in the graph
"""A result with its session""" for edge in graph.edges:
dest_node_id = edge.destination.node_id
dest_field = edge.destination.field
source_node_dict = graph.nodes[edge.source.node_id].dict()
result: Result = Field(description="The result") # If the destination node ID matches the given node ID, gather necessary metadata
session: GraphExecutionState = Field(description="The session") if dest_node_id == node_id:
# If the destination field is 'positive_conditioning', add the 'prompt' from the source node
if dest_field == "positive_conditioning":
metadata["positive_conditioning"] = source_node_dict.get("prompt")
# If the destination field is 'negative_conditioning', add the 'prompt' from the source node
if dest_field == "negative_conditioning":
metadata["negative_conditioning"] = source_node_dict.get("prompt")
# If the destination field is 'noise', add the core noise fields from the source node
if dest_field == "noise":
for field in self._NOISE_FIELDS:
metadata[field] = source_node_dict.get(field)
return metadata
def _build_metadata_from_graph(
self, session: GraphExecutionState, node_id: str
) -> ImageMetadata:
"""
Builds an ImageMetadata object for a node.
# # Create a directed graph Parameters:
# from typing import Any, TypedDict, Union session (GraphExecutionState): The session.
# from networkx import DiGraph node_id (str): The ID of the node.
# import networkx as nx
# import json
Returns:
ImageMetadata: The metadata for the node.
"""
# # We need to use a loose class for nodes to allow for graceful parsing - we cannot use the stricter graph = session.execution_graph
# # model used by the system, because we may be a graph in an old format. We can, however, use the
# # Edge model, because the edge format does not change.
# class LooseGraph(BaseModel):
# id: str
# nodes: dict[str, dict[str, Any]]
# edges: list[Edge]
# Find the nearest ancestor of the given node
ancestor_id = self._find_nearest_ancestor(graph.nx_graph_with_data(), node_id)
# # An intermediate type used during parsing # If no ancestor was found, return an empty ImageMetadata object
# class NearestAncestor(TypedDict): if ancestor_id is None:
# node_id: str return ImageMetadata()
# metadata: dict[str, Any]
ancestor_node = graph.get_node(ancestor_id)
# # The ancestor types that contain the core metadata ancestor_metadata = {
# ANCESTOR_TYPES = ['t2l', 'l2l'] param: val
for param, val in ancestor_node.dict().items()
if param in self._ANCESTOR_PARAMS
}
# # The core metadata parameters in the ancestor types # Get additional metadata related to the ancestor
# ANCESTOR_PARAMS = ['steps', 'model', 'cfg_scale', 'scheduler', 'strength'] addl_metadata = self._get_additional_metadata(graph, ancestor_id)
# # The core metadata parameters in the noise node # If additional metadata was found, add it to the main metadata
# NOISE_FIELDS = ['seed', 'width', 'height'] if addl_metadata is not None:
ancestor_metadata.update(addl_metadata)
# # Find nearest t2l or l2l ancestor from a given l2i node return ImageMetadata(**ancestor_metadata)
# def find_nearest_ancestor(G: DiGraph, node_id: str) -> Union[NearestAncestor, None]:
# """Returns metadata for the nearest ancestor of a given node.
# Parameters:
# G (DiGraph): A directed graph.
# node_id (str): The ID of the starting node.
# Returns:
# NearestAncestor | None: An object with the ID and metadata of the nearest ancestor.
# """
# # Retrieve the node from the graph
# node = G.nodes[node_id]
# # If the node type is one of the core metadata node types, gather necessary metadata and return
# if node.get('type') in ANCESTOR_TYPES:
# parsed_metadata = {param: val for param, val in node.items() if param in ANCESTOR_PARAMS}
# return NearestAncestor(node_id=node_id, metadata=parsed_metadata)
# # Else, look for the ancestor in the predecessor nodes
# for predecessor in G.predecessors(node_id):
# result = find_nearest_ancestor(G, predecessor)
# if result:
# return result
# # If there are no valid ancestors, return None
# return None
# def get_additional_metadata(graph: LooseGraph, node_id: str) -> Union[dict[str, Any], None]:
# """Collects additional metadata from nodes connected to a given node.
# Parameters:
# graph (LooseGraph): The graph.
# node_id (str): The ID of the node.
# Returns:
# dict | None: A dictionary containing additional metadata.
# """
# metadata = {}
# # Iterate over all edges in the graph
# for edge in graph.edges:
# dest_node_id = edge.destination.node_id
# dest_field = edge.destination.field
# source_node = graph.nodes[edge.source.node_id]
# # If the destination node ID matches the given node ID, gather necessary metadata
# if dest_node_id == node_id:
# # If the destination field is 'positive_conditioning', add the 'prompt' from the source node
# if dest_field == 'positive_conditioning':
# metadata['positive_conditioning'] = source_node.get('prompt')
# # If the destination field is 'negative_conditioning', add the 'prompt' from the source node
# if dest_field == 'negative_conditioning':
# metadata['negative_conditioning'] = source_node.get('prompt')
# # If the destination field is 'noise', add the core noise fields from the source node
# if dest_field == 'noise':
# for field in NOISE_FIELDS:
# metadata[field] = source_node.get(field)
# return metadata
# def build_core_metadata(graph_raw: str, node_id: str) -> Union[dict, None]:
# """Builds the core metadata for a given node.
# Parameters:
# graph_raw (str): The graph structure as a raw string.
# node_id (str): The ID of the node.
# Returns:
# dict | None: A dictionary containing core metadata.
# """
# # Create a directed graph to facilitate traversal
# G = nx.DiGraph()
# # Convert the raw graph string into a JSON object
# graph = parse_obj_as(LooseGraph, graph_raw)
# # Add nodes and edges to the graph
# for node_id, node_data in graph.nodes.items():
# G.add_node(node_id, **node_data)
# for edge in graph.edges:
# G.add_edge(edge.source.node_id, edge.destination.node_id)
# # Find the nearest ancestor of the given node
# ancestor = find_nearest_ancestor(G, node_id)
# # If no ancestor was found, return None
# if ancestor is None:
# return None
# metadata = ancestor['metadata']
# ancestor_id = ancestor['node_id']
# # Get additional metadata related to the ancestor
# addl_metadata = get_additional_metadata(graph, ancestor_id)
# # If additional metadata was found, add it to the main metadata
# if addl_metadata is not None:
# metadata.update(addl_metadata)
# return metadata

View File

@ -62,9 +62,12 @@ def deserialize_image_record(image_row: sqlite3.Row) -> ImageRecord:
image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value)) image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value))
raw_metadata = image_dict.get("metadata", "{}") raw_metadata = image_dict.get("metadata")
metadata = ImageMetadata.parse_raw(raw_metadata) if raw_metadata is not None:
metadata = ImageMetadata.parse_raw(raw_metadata)
else:
metadata = None
return ImageRecord( return ImageRecord(
image_name=image_dict.get("id", "unknown"), image_name=image_dict.get("id", "unknown"),