mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): add metadata handling
This commit is contained in:
parent
f071b03ceb
commit
5de3c41d19
@ -4,6 +4,7 @@ from logging import Logger
|
||||
import os
|
||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService
|
||||
from invokeai.app.services.metadata import CoreMetadataService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
@ -18,7 +19,6 @@ from ..services.invocation_services import InvocationServices
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.processor import DefaultInvocationProcessor
|
||||
from ..services.sqlite import SqliteItemStorage
|
||||
from ..services.metadata import PngMetadataService
|
||||
from .events import FastAPIEventService
|
||||
|
||||
|
||||
@ -59,7 +59,7 @@ class ApiDependencies:
|
||||
DiskLatentsStorage(f"{output_folder}/latents")
|
||||
)
|
||||
|
||||
metadata = PngMetadataService()
|
||||
metadata = CoreMetadataService()
|
||||
|
||||
urls = LocalUrlService()
|
||||
|
||||
@ -80,6 +80,7 @@ class ApiDependencies:
|
||||
metadata=metadata,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
)
|
||||
|
||||
services = InvocationServices(
|
||||
|
@ -2,7 +2,6 @@ from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.models.image import ImageType
|
||||
from invokeai.app.services.metadata import InvokeAIMetadata
|
||||
|
||||
|
||||
class ImageResponseMetadata(BaseModel):
|
||||
@ -11,9 +10,9 @@ class ImageResponseMetadata(BaseModel):
|
||||
created: int = Field(description="The creation timestamp of the image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
invokeai: Optional[InvokeAIMetadata] = Field(
|
||||
description="The image's InvokeAI-specific metadata"
|
||||
)
|
||||
# invokeai: Optional[InvokeAIMetadata] = Field(
|
||||
# description="The image's InvokeAI-specific metadata"
|
||||
# )
|
||||
|
||||
|
||||
class ImageResponse(BaseModel):
|
||||
|
@ -7,6 +7,7 @@ from pydantic import BaseModel, Field
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from invokeai.app.models.image import ImageCategory
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
@ -356,20 +357,30 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
np_image = model.decode_latents(latents)
|
||||
image = model.numpy_to_pil(np_image)[0]
|
||||
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
# image_type = ImageType.RESULT
|
||||
# image_name = context.services.images.create_name(
|
||||
# context.graph_execution_state_id, self.id
|
||||
# )
|
||||
|
||||
# metadata = context.services.metadata.build_metadata(
|
||||
# session_id=context.graph_execution_state_id, node=self
|
||||
# )
|
||||
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
# context.services.images.save(image_type, image_name, image, metadata)
|
||||
image_dto = context.services.images_new.create(
|
||||
image=image,
|
||||
image_type=ImageType.RESULT,
|
||||
image_category=ImageCategory.IMAGE,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
context.services.images.save(image_type, image_name, image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type, image_name=image_name, image=image
|
||||
image_type=image_dto.image_type,
|
||||
image_name=image_dto.image_name,
|
||||
image=image,
|
||||
)
|
||||
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field, StrictFloat, StrictInt, StrictStr
|
||||
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr
|
||||
|
||||
|
||||
class ImageMetadata(BaseModel):
|
||||
@ -8,11 +8,24 @@ class ImageMetadata(BaseModel):
|
||||
|
||||
Also includes any metadata from the image's PNG tEXt chunks.
|
||||
|
||||
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node.
|
||||
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors
|
||||
of a given node.
|
||||
|
||||
Full metadata may be accessed by querying for the session in the `graph_executions` table.
|
||||
"""
|
||||
|
||||
class Config:
|
||||
extra = Extra.allow
|
||||
"""
|
||||
This lets the ImageMetadata class accept arbitrary additional fields. The CoreMetadataService
|
||||
won't add any fields that are not already defined, but other a different metadata service
|
||||
implementation might.
|
||||
"""
|
||||
|
||||
type: Optional[StrictStr] = Field(
|
||||
default=None,
|
||||
description="The type of the ancestor node of the image output node.",
|
||||
)
|
||||
positive_conditioning: Optional[StrictStr] = Field(
|
||||
default=None, description="The positive conditioning."
|
||||
)
|
||||
@ -20,10 +33,10 @@ class ImageMetadata(BaseModel):
|
||||
default=None, description="The negative conditioning."
|
||||
)
|
||||
width: Optional[StrictInt] = Field(
|
||||
default=None, description="Width of the image/tensor in pixels."
|
||||
default=None, description="Width of the image/latents in pixels."
|
||||
)
|
||||
height: Optional[StrictInt] = Field(
|
||||
default=None, description="Height of the image/tensor in pixels."
|
||||
default=None, description="Height of the image/latents in pixels."
|
||||
)
|
||||
seed: Optional[StrictInt] = Field(
|
||||
default=None, description="The seed used for noise generation."
|
||||
@ -42,18 +55,21 @@ class ImageMetadata(BaseModel):
|
||||
)
|
||||
strength: Optional[StrictFloat] = Field(
|
||||
default=None,
|
||||
description="The strength used for image-to-image/tensor-to-tensor.",
|
||||
description="The strength used for image-to-image/latents-to-latents.",
|
||||
)
|
||||
image: Optional[StrictStr] = Field(
|
||||
default=None, description="The ID of the initial image."
|
||||
latents: Optional[StrictStr] = Field(
|
||||
default=None, description="The ID of the initial latents."
|
||||
)
|
||||
tensor: Optional[StrictStr] = Field(
|
||||
default=None, description="The ID of the initial tensor."
|
||||
vae: Optional[StrictStr] = Field(
|
||||
default=None, description="The VAE used for decoding."
|
||||
)
|
||||
unet: Optional[StrictStr] = Field(
|
||||
default=None, description="The UNet used dor inference."
|
||||
)
|
||||
clip: Optional[StrictStr] = Field(
|
||||
default=None, description="The CLIP Encoder used for conditioning."
|
||||
)
|
||||
# Pending model refactor:
|
||||
# vae: Optional[str] = Field(default=None,description="The VAE used for decoding.")
|
||||
# unet: Optional[str] = Field(default=None,description="The UNet used dor inference.")
|
||||
# clip: Optional[str] = Field(default=None,description="The CLIP Encoder used for conditioning.")
|
||||
extra: Optional[StrictStr] = Field(
|
||||
default=None, description="Extra metadata, extracted from the PNG tEXt chunk."
|
||||
default=None,
|
||||
description="Uploaded image metadata, extracted from the PNG tEXt chunk.",
|
||||
)
|
||||
|
@ -713,6 +713,13 @@ class Graph(BaseModel):
|
||||
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
||||
return g
|
||||
|
||||
def nx_graph_with_data(self) -> nx.DiGraph:
|
||||
"""Returns a NetworkX DiGraph representing the data and layout of this graph"""
|
||||
g = nx.DiGraph()
|
||||
g.add_nodes_from([n for n in self.nodes.items()])
|
||||
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
||||
return g
|
||||
|
||||
def nx_graph_flat(
|
||||
self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None
|
||||
) -> nx.DiGraph:
|
||||
|
@ -6,11 +6,11 @@ from queue import Queue
|
||||
from typing import Dict, Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
from PIL import Image, PngImagePlugin
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.models.image import ImageType
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
|
||||
|
||||
@ -54,7 +54,7 @@ class ImageFileStorageBase(ABC):
|
||||
image: PILImageType,
|
||||
image_type: ImageType,
|
||||
image_name: str,
|
||||
pnginfo: Optional[PngInfo] = None,
|
||||
metadata: Optional[ImageMetadata] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
||||
@ -109,12 +109,18 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
image: PILImageType,
|
||||
image_type: ImageType,
|
||||
image_name: str,
|
||||
pnginfo: Optional[PngInfo] = None,
|
||||
metadata: Optional[ImageMetadata] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
try:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
image.save(image_path, "PNG", pnginfo=pnginfo)
|
||||
|
||||
if metadata is not None:
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
pnginfo.add_text("invokeai", metadata.json())
|
||||
image.save(image_path, "PNG", pnginfo=pnginfo)
|
||||
else:
|
||||
image.save(image_path, "PNG")
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True)
|
||||
|
@ -1,10 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import json
|
||||
from logging import Logger
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
import uuid
|
||||
from PIL.Image import Image as PILImageType
|
||||
from PIL import PngImagePlugin
|
||||
|
||||
from invokeai.app.models.image import ImageCategory, ImageType
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
@ -17,12 +16,16 @@ from invokeai.app.services.models.image_record import (
|
||||
image_record_to_dto,
|
||||
)
|
||||
from invokeai.app.services.image_file_storage import ImageFileStorageBase
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
|
||||
from invokeai.app.services.metadata import MetadataServiceBase
|
||||
from invokeai.app.services.urls import UrlServiceBase
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.graph import GraphExecutionState
|
||||
|
||||
|
||||
class ImageServiceABC(ABC):
|
||||
"""
|
||||
High-level service for image management.
|
||||
@ -59,7 +62,9 @@ class ImageServiceABC(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_url(self, image_type: ImageType, image_name: str, thumbnail: bool = False) -> str:
|
||||
def get_url(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
"""Gets an image's or thumbnail's URL"""
|
||||
pass
|
||||
|
||||
@ -113,6 +118,7 @@ class ImageServiceDependencies:
|
||||
metadata: MetadataServiceBase
|
||||
urls: UrlServiceBase
|
||||
logger: Logger
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -121,12 +127,14 @@ class ImageServiceDependencies:
|
||||
metadata: MetadataServiceBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
):
|
||||
self.records = image_record_storage
|
||||
self.files = image_file_storage
|
||||
self.metadata = metadata
|
||||
self.urls = url
|
||||
self.logger = logger
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
|
||||
|
||||
class ImageService(ImageServiceABC):
|
||||
@ -139,6 +147,7 @@ class ImageService(ImageServiceABC):
|
||||
metadata: MetadataServiceBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
):
|
||||
self._services = ImageServiceDependencies(
|
||||
image_record_storage=image_record_storage,
|
||||
@ -146,6 +155,7 @@ class ImageService(ImageServiceABC):
|
||||
metadata=metadata,
|
||||
url=url,
|
||||
logger=logger,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
)
|
||||
|
||||
def create(
|
||||
@ -155,7 +165,6 @@ class ImageService(ImageServiceABC):
|
||||
image_category: ImageCategory,
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
metadata: Optional[ImageMetadata] = None,
|
||||
) -> ImageDTO:
|
||||
image_name = self._create_image_name(
|
||||
image_type=image_type,
|
||||
@ -165,12 +174,7 @@ class ImageService(ImageServiceABC):
|
||||
)
|
||||
|
||||
timestamp = get_iso_timestamp()
|
||||
|
||||
if metadata is not None:
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
pnginfo.add_text("invokeai", json.dumps(metadata))
|
||||
else:
|
||||
pnginfo = None
|
||||
metadata = self._get_metadata(session_id, node_id)
|
||||
|
||||
try:
|
||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||
@ -178,7 +182,7 @@ class ImageService(ImageServiceABC):
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
pnginfo=pnginfo,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
self._services.records.save(
|
||||
@ -237,24 +241,6 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.error("Problem getting image record")
|
||||
raise e
|
||||
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
try:
|
||||
return self._services.files.get_path(image_type, image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
|
||||
def get_url(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
try:
|
||||
return self._services.urls.get_image_url(image_type, image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
|
||||
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
|
||||
try:
|
||||
image_record = self._services.records.get(image_type, image_name)
|
||||
@ -273,6 +259,24 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.error("Problem getting image DTO")
|
||||
raise e
|
||||
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
try:
|
||||
return self._services.files.get_path(image_type, image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
|
||||
def get_url(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
try:
|
||||
return self._services.urls.get_image_url(image_type, image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
@ -353,3 +357,15 @@ class ImageService(ImageServiceABC):
|
||||
return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png"
|
||||
|
||||
return f"{image_type.value}_{image_category.value}_{uuid_str}.png"
|
||||
|
||||
def _get_metadata(
|
||||
self, session_id: Optional[str] = None, node_id: Optional[str] = None
|
||||
) -> Union[ImageMetadata, None]:
|
||||
"""Get the metadata for a node."""
|
||||
metadata = None
|
||||
|
||||
if node_id is not None and session_id is not None:
|
||||
session = self._services.graph_execution_manager.get(session_id)
|
||||
metadata = self._services.metadata.create_image_metadata(session, node_id)
|
||||
|
||||
return metadata
|
||||
|
@ -1,295 +1,142 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional, TypedDict
|
||||
from PIL import Image, PngImagePlugin
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Union
|
||||
import networkx as nx
|
||||
|
||||
from invokeai.app.models.image import ImageType, is_image_type
|
||||
|
||||
|
||||
class MetadataImageField(TypedDict):
|
||||
"""Pydantic-less ImageField, used for metadata parsing."""
|
||||
|
||||
image_type: ImageType
|
||||
image_name: str
|
||||
|
||||
|
||||
class MetadataLatentsField(TypedDict):
|
||||
"""Pydantic-less LatentsField, used for metadata parsing."""
|
||||
|
||||
latents_name: str
|
||||
|
||||
|
||||
class MetadataColorField(TypedDict):
|
||||
"""Pydantic-less ColorField, used for metadata parsing"""
|
||||
|
||||
r: int
|
||||
g: int
|
||||
b: int
|
||||
a: int
|
||||
|
||||
|
||||
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
|
||||
NodeMetadata = Dict[
|
||||
str,
|
||||
None
|
||||
| str
|
||||
| int
|
||||
| float
|
||||
| bool
|
||||
| MetadataImageField
|
||||
| MetadataLatentsField
|
||||
| MetadataColorField,
|
||||
]
|
||||
|
||||
|
||||
class InvokeAIMetadata(TypedDict, total=False):
|
||||
"""InvokeAI-specific metadata format."""
|
||||
|
||||
session_id: Optional[str]
|
||||
node: Optional[NodeMetadata]
|
||||
|
||||
|
||||
def build_invokeai_metadata_pnginfo(
|
||||
metadata: InvokeAIMetadata | None,
|
||||
) -> PngImagePlugin.PngInfo:
|
||||
"""Builds a PngInfo object with key `"invokeai"` and value `metadata`"""
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
|
||||
if metadata is not None:
|
||||
pnginfo.add_text("invokeai", json.dumps(metadata))
|
||||
|
||||
return pnginfo
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.services.graph import Edge, Graph, GraphExecutionState
|
||||
|
||||
|
||||
class MetadataServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def get_metadata(self, image: Image.Image) -> InvokeAIMetadata | None:
|
||||
"""Gets the InvokeAI metadata from a PIL Image, skipping invalid values"""
|
||||
pass
|
||||
"""Handles building metadata for nodes, images, and outputs."""
|
||||
|
||||
@abstractmethod
|
||||
def build_metadata(
|
||||
self, session_id: str, node: BaseModel
|
||||
) -> InvokeAIMetadata | None:
|
||||
"""Builds an InvokeAIMetadata object"""
|
||||
def create_image_metadata(
|
||||
self, session: GraphExecutionState, node_id: str
|
||||
) -> ImageMetadata:
|
||||
"""Builds an ImageMetadata object for a node."""
|
||||
pass
|
||||
|
||||
# @abstractmethod
|
||||
# def create_metadata(self, session_id: str, node_id: str) -> dict:
|
||||
# """Creates metadata for a result"""
|
||||
# pass
|
||||
|
||||
|
||||
class PngMetadataService(MetadataServiceBase):
|
||||
"""Handles loading and building metadata for images."""
|
||||
class CoreMetadataService(MetadataServiceBase):
|
||||
_ANCESTOR_TYPES = ["t2l", "l2l"]
|
||||
"""The ancestor types that contain the core metadata"""
|
||||
|
||||
# TODO: Use `InvocationsUnion` to **validate** metadata as representing a fully-functioning node
|
||||
def _load_metadata(self, image: Image.Image) -> dict | None:
|
||||
"""Loads a specific info entry from a PIL Image."""
|
||||
_ANCESTOR_PARAMS = ["type", "steps", "model", "cfg_scale", "scheduler", "strength"]
|
||||
"""The core metadata parameters in the ancestor types"""
|
||||
|
||||
try:
|
||||
info = image.info.get("invokeai")
|
||||
_NOISE_FIELDS = ["seed", "width", "height"]
|
||||
"""The core metadata parameters in the noise node"""
|
||||
|
||||
if type(info) is not str:
|
||||
return None
|
||||
|
||||
loaded_metadata = json.loads(info)
|
||||
|
||||
if type(loaded_metadata) is not dict:
|
||||
return None
|
||||
|
||||
if len(loaded_metadata.items()) == 0:
|
||||
return None
|
||||
|
||||
return loaded_metadata
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_metadata(self, image: Image.Image) -> dict | None:
|
||||
"""Retrieves an image's metadata as a dict"""
|
||||
loaded_metadata = self._load_metadata(image)
|
||||
|
||||
return loaded_metadata
|
||||
|
||||
def build_metadata(self, session_id: str, node: BaseModel) -> InvokeAIMetadata:
|
||||
metadata = InvokeAIMetadata(session_id=session_id, node=node.dict())
|
||||
def create_image_metadata(
|
||||
self, session: GraphExecutionState, node_id: str
|
||||
) -> ImageMetadata:
|
||||
metadata = self._build_metadata_from_graph(session, node_id)
|
||||
|
||||
return metadata
|
||||
|
||||
def _find_nearest_ancestor(self, G: nx.DiGraph, node_id: str) -> Union[str, None]:
|
||||
"""
|
||||
Finds the id of the nearest ancestor (of a valid type) of a given node.
|
||||
|
||||
from enum import Enum
|
||||
Parameters:
|
||||
G (nx.DiGraph): The execution graph, converted in to a networkx DiGraph. Its nodes must
|
||||
have the same data as the execution graph.
|
||||
node_id (str): The ID of the node.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import json
|
||||
import sqlite3
|
||||
from threading import Lock
|
||||
from typing import Any, Union
|
||||
Returns:
|
||||
str | None: The ID of the nearest ancestor, or None if there are no valid ancestors.
|
||||
"""
|
||||
|
||||
import networkx as nx
|
||||
# Retrieve the node from the graph
|
||||
node = G.nodes[node_id]
|
||||
|
||||
from pydantic import BaseModel, Field, parse_obj_as, parse_raw_as
|
||||
from invokeai.app.invocations.image import ImageOutput
|
||||
from invokeai.app.services.graph import Edge, GraphExecutionState
|
||||
from invokeai.app.invocations.latent import LatentsOutput
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
# If the node type is one of the core metadata node types, return its id
|
||||
if node.get("type") in self._ANCESTOR_TYPES:
|
||||
return node.get("id")
|
||||
|
||||
# Else, look for the ancestor in the predecessor nodes
|
||||
for predecessor in G.predecessors(node_id):
|
||||
result = self._find_nearest_ancestor(G, predecessor)
|
||||
if result:
|
||||
return result
|
||||
|
||||
class ResultType(str, Enum):
|
||||
image_output = "image_output"
|
||||
latents_output = "latents_output"
|
||||
# If there are no valid ancestors, return None
|
||||
return None
|
||||
|
||||
def _get_additional_metadata(
|
||||
self, graph: Graph, node_id: str
|
||||
) -> Union[dict[str, Any], None]:
|
||||
"""
|
||||
Returns additional metadata for a given node.
|
||||
|
||||
class Result(BaseModel):
|
||||
"""A session result"""
|
||||
Parameters:
|
||||
graph (Graph): The execution graph.
|
||||
node_id (str): The ID of the node.
|
||||
|
||||
id: str = Field(description="Result ID")
|
||||
session_id: str = Field(description="Session ID")
|
||||
node_id: str = Field(description="Node ID")
|
||||
data: Union[LatentsOutput, ImageOutput] = Field(description="The result data")
|
||||
Returns:
|
||||
dict[str, Any] | None: A dictionary of additional metadata.
|
||||
"""
|
||||
|
||||
metadata = {}
|
||||
|
||||
class ResultWithSession(BaseModel):
|
||||
"""A result with its session"""
|
||||
# Iterate over all edges in the graph
|
||||
for edge in graph.edges:
|
||||
dest_node_id = edge.destination.node_id
|
||||
dest_field = edge.destination.field
|
||||
source_node_dict = graph.nodes[edge.source.node_id].dict()
|
||||
|
||||
result: Result = Field(description="The result")
|
||||
session: GraphExecutionState = Field(description="The session")
|
||||
# If the destination node ID matches the given node ID, gather necessary metadata
|
||||
if dest_node_id == node_id:
|
||||
# If the destination field is 'positive_conditioning', add the 'prompt' from the source node
|
||||
if dest_field == "positive_conditioning":
|
||||
metadata["positive_conditioning"] = source_node_dict.get("prompt")
|
||||
# If the destination field is 'negative_conditioning', add the 'prompt' from the source node
|
||||
if dest_field == "negative_conditioning":
|
||||
metadata["negative_conditioning"] = source_node_dict.get("prompt")
|
||||
# If the destination field is 'noise', add the core noise fields from the source node
|
||||
if dest_field == "noise":
|
||||
for field in self._NOISE_FIELDS:
|
||||
metadata[field] = source_node_dict.get(field)
|
||||
return metadata
|
||||
|
||||
def _build_metadata_from_graph(
|
||||
self, session: GraphExecutionState, node_id: str
|
||||
) -> ImageMetadata:
|
||||
"""
|
||||
Builds an ImageMetadata object for a node.
|
||||
|
||||
# # Create a directed graph
|
||||
# from typing import Any, TypedDict, Union
|
||||
# from networkx import DiGraph
|
||||
# import networkx as nx
|
||||
# import json
|
||||
Parameters:
|
||||
session (GraphExecutionState): The session.
|
||||
node_id (str): The ID of the node.
|
||||
|
||||
Returns:
|
||||
ImageMetadata: The metadata for the node.
|
||||
"""
|
||||
|
||||
# # We need to use a loose class for nodes to allow for graceful parsing - we cannot use the stricter
|
||||
# # model used by the system, because we may be a graph in an old format. We can, however, use the
|
||||
# # Edge model, because the edge format does not change.
|
||||
# class LooseGraph(BaseModel):
|
||||
# id: str
|
||||
# nodes: dict[str, dict[str, Any]]
|
||||
# edges: list[Edge]
|
||||
graph = session.execution_graph
|
||||
|
||||
# Find the nearest ancestor of the given node
|
||||
ancestor_id = self._find_nearest_ancestor(graph.nx_graph_with_data(), node_id)
|
||||
|
||||
# # An intermediate type used during parsing
|
||||
# class NearestAncestor(TypedDict):
|
||||
# node_id: str
|
||||
# metadata: dict[str, Any]
|
||||
# If no ancestor was found, return an empty ImageMetadata object
|
||||
if ancestor_id is None:
|
||||
return ImageMetadata()
|
||||
|
||||
ancestor_node = graph.get_node(ancestor_id)
|
||||
|
||||
# # The ancestor types that contain the core metadata
|
||||
# ANCESTOR_TYPES = ['t2l', 'l2l']
|
||||
ancestor_metadata = {
|
||||
param: val
|
||||
for param, val in ancestor_node.dict().items()
|
||||
if param in self._ANCESTOR_PARAMS
|
||||
}
|
||||
|
||||
# # The core metadata parameters in the ancestor types
|
||||
# ANCESTOR_PARAMS = ['steps', 'model', 'cfg_scale', 'scheduler', 'strength']
|
||||
# Get additional metadata related to the ancestor
|
||||
addl_metadata = self._get_additional_metadata(graph, ancestor_id)
|
||||
|
||||
# # The core metadata parameters in the noise node
|
||||
# NOISE_FIELDS = ['seed', 'width', 'height']
|
||||
# If additional metadata was found, add it to the main metadata
|
||||
if addl_metadata is not None:
|
||||
ancestor_metadata.update(addl_metadata)
|
||||
|
||||
# # Find nearest t2l or l2l ancestor from a given l2i node
|
||||
# def find_nearest_ancestor(G: DiGraph, node_id: str) -> Union[NearestAncestor, None]:
|
||||
# """Returns metadata for the nearest ancestor of a given node.
|
||||
|
||||
# Parameters:
|
||||
# G (DiGraph): A directed graph.
|
||||
# node_id (str): The ID of the starting node.
|
||||
|
||||
# Returns:
|
||||
# NearestAncestor | None: An object with the ID and metadata of the nearest ancestor.
|
||||
# """
|
||||
|
||||
# # Retrieve the node from the graph
|
||||
# node = G.nodes[node_id]
|
||||
|
||||
# # If the node type is one of the core metadata node types, gather necessary metadata and return
|
||||
# if node.get('type') in ANCESTOR_TYPES:
|
||||
# parsed_metadata = {param: val for param, val in node.items() if param in ANCESTOR_PARAMS}
|
||||
# return NearestAncestor(node_id=node_id, metadata=parsed_metadata)
|
||||
|
||||
|
||||
# # Else, look for the ancestor in the predecessor nodes
|
||||
# for predecessor in G.predecessors(node_id):
|
||||
# result = find_nearest_ancestor(G, predecessor)
|
||||
# if result:
|
||||
# return result
|
||||
|
||||
# # If there are no valid ancestors, return None
|
||||
# return None
|
||||
|
||||
|
||||
# def get_additional_metadata(graph: LooseGraph, node_id: str) -> Union[dict[str, Any], None]:
|
||||
# """Collects additional metadata from nodes connected to a given node.
|
||||
|
||||
# Parameters:
|
||||
# graph (LooseGraph): The graph.
|
||||
# node_id (str): The ID of the node.
|
||||
|
||||
# Returns:
|
||||
# dict | None: A dictionary containing additional metadata.
|
||||
# """
|
||||
|
||||
# metadata = {}
|
||||
|
||||
# # Iterate over all edges in the graph
|
||||
# for edge in graph.edges:
|
||||
# dest_node_id = edge.destination.node_id
|
||||
# dest_field = edge.destination.field
|
||||
# source_node = graph.nodes[edge.source.node_id]
|
||||
|
||||
# # If the destination node ID matches the given node ID, gather necessary metadata
|
||||
# if dest_node_id == node_id:
|
||||
# # If the destination field is 'positive_conditioning', add the 'prompt' from the source node
|
||||
# if dest_field == 'positive_conditioning':
|
||||
# metadata['positive_conditioning'] = source_node.get('prompt')
|
||||
# # If the destination field is 'negative_conditioning', add the 'prompt' from the source node
|
||||
# if dest_field == 'negative_conditioning':
|
||||
# metadata['negative_conditioning'] = source_node.get('prompt')
|
||||
# # If the destination field is 'noise', add the core noise fields from the source node
|
||||
# if dest_field == 'noise':
|
||||
# for field in NOISE_FIELDS:
|
||||
# metadata[field] = source_node.get(field)
|
||||
# return metadata
|
||||
|
||||
# def build_core_metadata(graph_raw: str, node_id: str) -> Union[dict, None]:
|
||||
# """Builds the core metadata for a given node.
|
||||
|
||||
# Parameters:
|
||||
# graph_raw (str): The graph structure as a raw string.
|
||||
# node_id (str): The ID of the node.
|
||||
|
||||
# Returns:
|
||||
# dict | None: A dictionary containing core metadata.
|
||||
# """
|
||||
|
||||
# # Create a directed graph to facilitate traversal
|
||||
# G = nx.DiGraph()
|
||||
|
||||
# # Convert the raw graph string into a JSON object
|
||||
# graph = parse_obj_as(LooseGraph, graph_raw)
|
||||
|
||||
# # Add nodes and edges to the graph
|
||||
# for node_id, node_data in graph.nodes.items():
|
||||
# G.add_node(node_id, **node_data)
|
||||
# for edge in graph.edges:
|
||||
# G.add_edge(edge.source.node_id, edge.destination.node_id)
|
||||
|
||||
# # Find the nearest ancestor of the given node
|
||||
# ancestor = find_nearest_ancestor(G, node_id)
|
||||
|
||||
# # If no ancestor was found, return None
|
||||
# if ancestor is None:
|
||||
# return None
|
||||
|
||||
# metadata = ancestor['metadata']
|
||||
# ancestor_id = ancestor['node_id']
|
||||
|
||||
# # Get additional metadata related to the ancestor
|
||||
# addl_metadata = get_additional_metadata(graph, ancestor_id)
|
||||
|
||||
# # If additional metadata was found, add it to the main metadata
|
||||
# if addl_metadata is not None:
|
||||
# metadata.update(addl_metadata)
|
||||
|
||||
# return metadata
|
||||
return ImageMetadata(**ancestor_metadata)
|
||||
|
@ -62,9 +62,12 @@ def deserialize_image_record(image_row: sqlite3.Row) -> ImageRecord:
|
||||
|
||||
image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value))
|
||||
|
||||
raw_metadata = image_dict.get("metadata", "{}")
|
||||
raw_metadata = image_dict.get("metadata")
|
||||
|
||||
metadata = ImageMetadata.parse_raw(raw_metadata)
|
||||
if raw_metadata is not None:
|
||||
metadata = ImageMetadata.parse_raw(raw_metadata)
|
||||
else:
|
||||
metadata = None
|
||||
|
||||
return ImageRecord(
|
||||
image_name=image_dict.get("id", "unknown"),
|
||||
|
Loading…
Reference in New Issue
Block a user