feat(nodes): add metadata handling

This commit is contained in:
psychedelicious 2023-05-22 15:48:12 +10:00 committed by Kent Keirsey
parent f071b03ceb
commit 5de3c41d19
9 changed files with 228 additions and 322 deletions

View File

@ -4,6 +4,7 @@ from logging import Logger
import os
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger
@ -18,7 +19,6 @@ from ..services.invocation_services import InvocationServices
from ..services.invoker import Invoker
from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage
from ..services.metadata import PngMetadataService
from .events import FastAPIEventService
@ -59,7 +59,7 @@ class ApiDependencies:
DiskLatentsStorage(f"{output_folder}/latents")
)
metadata = PngMetadataService()
metadata = CoreMetadataService()
urls = LocalUrlService()
@ -80,6 +80,7 @@ class ApiDependencies:
metadata=metadata,
url=urls,
logger=logger,
graph_execution_manager=graph_execution_manager,
)
services = InvocationServices(

View File

@ -2,7 +2,6 @@ from typing import Optional
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageType
from invokeai.app.services.metadata import InvokeAIMetadata
class ImageResponseMetadata(BaseModel):
@ -11,9 +10,9 @@ class ImageResponseMetadata(BaseModel):
created: int = Field(description="The creation timestamp of the image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
invokeai: Optional[InvokeAIMetadata] = Field(
description="The image's InvokeAI-specific metadata"
)
# invokeai: Optional[InvokeAIMetadata] = Field(
# description="The image's InvokeAI-specific metadata"
# )
class ImageResponse(BaseModel):

View File

@ -7,6 +7,7 @@ from pydantic import BaseModel, Field
import torch
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.models.image import ImageCategory
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.step_callback import stable_diffusion_step_callback
@ -356,20 +357,30 @@ class LatentsToImageInvocation(BaseInvocation):
np_image = model.decode_latents(latents)
image = model.numpy_to_pil(np_image)[0]
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
# image_type = ImageType.RESULT
# image_name = context.services.images.create_name(
# context.graph_execution_state_id, self.id
# )
# metadata = context.services.metadata.build_metadata(
# session_id=context.graph_execution_state_id, node=self
# )
# torch.cuda.empty_cache()
# context.services.images.save(image_type, image_name, image, metadata)
image_dto = context.services.images_new.create(
image=image,
image_type=ImageType.RESULT,
image_category=ImageCategory.IMAGE,
session_id=context.graph_execution_state_id,
node_id=self.id,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
torch.cuda.empty_cache()
context.services.images.save(image_type, image_name, image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=image
image_type=image_dto.image_type,
image_name=image_dto.image_name,
image=image,
)

View File

@ -1,5 +1,5 @@
from typing import Optional
from pydantic import BaseModel, Field, StrictFloat, StrictInt, StrictStr
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr
class ImageMetadata(BaseModel):
@ -8,11 +8,24 @@ class ImageMetadata(BaseModel):
Also includes any metadata from the image's PNG tEXt chunks.
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node.
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors
of a given node.
Full metadata may be accessed by querying for the session in the `graph_executions` table.
"""
class Config:
extra = Extra.allow
"""
This lets the ImageMetadata class accept arbitrary additional fields. The CoreMetadataService
won't add any fields that are not already defined, but other a different metadata service
implementation might.
"""
type: Optional[StrictStr] = Field(
default=None,
description="The type of the ancestor node of the image output node.",
)
positive_conditioning: Optional[StrictStr] = Field(
default=None, description="The positive conditioning."
)
@ -20,10 +33,10 @@ class ImageMetadata(BaseModel):
default=None, description="The negative conditioning."
)
width: Optional[StrictInt] = Field(
default=None, description="Width of the image/tensor in pixels."
default=None, description="Width of the image/latents in pixels."
)
height: Optional[StrictInt] = Field(
default=None, description="Height of the image/tensor in pixels."
default=None, description="Height of the image/latents in pixels."
)
seed: Optional[StrictInt] = Field(
default=None, description="The seed used for noise generation."
@ -42,18 +55,21 @@ class ImageMetadata(BaseModel):
)
strength: Optional[StrictFloat] = Field(
default=None,
description="The strength used for image-to-image/tensor-to-tensor.",
description="The strength used for image-to-image/latents-to-latents.",
)
image: Optional[StrictStr] = Field(
default=None, description="The ID of the initial image."
latents: Optional[StrictStr] = Field(
default=None, description="The ID of the initial latents."
)
tensor: Optional[StrictStr] = Field(
default=None, description="The ID of the initial tensor."
vae: Optional[StrictStr] = Field(
default=None, description="The VAE used for decoding."
)
unet: Optional[StrictStr] = Field(
default=None, description="The UNet used dor inference."
)
clip: Optional[StrictStr] = Field(
default=None, description="The CLIP Encoder used for conditioning."
)
# Pending model refactor:
# vae: Optional[str] = Field(default=None,description="The VAE used for decoding.")
# unet: Optional[str] = Field(default=None,description="The UNet used dor inference.")
# clip: Optional[str] = Field(default=None,description="The CLIP Encoder used for conditioning.")
extra: Optional[StrictStr] = Field(
default=None, description="Extra metadata, extracted from the PNG tEXt chunk."
default=None,
description="Uploaded image metadata, extracted from the PNG tEXt chunk.",
)

View File

@ -713,6 +713,13 @@ class Graph(BaseModel):
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
return g
def nx_graph_with_data(self) -> nx.DiGraph:
"""Returns a NetworkX DiGraph representing the data and layout of this graph"""
g = nx.DiGraph()
g.add_nodes_from([n for n in self.nodes.items()])
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
return g
def nx_graph_flat(
self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None
) -> nx.DiGraph:

View File

@ -6,11 +6,11 @@ from queue import Queue
from typing import Dict, Optional
from PIL.Image import Image as PILImageType
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from PIL import Image, PngImagePlugin
from send2trash import send2trash
from invokeai.app.models.image import ImageType
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
@ -54,7 +54,7 @@ class ImageFileStorageBase(ABC):
image: PILImageType,
image_type: ImageType,
image_name: str,
pnginfo: Optional[PngInfo] = None,
metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256,
) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
@ -109,12 +109,18 @@ class DiskImageFileStorage(ImageFileStorageBase):
image: PILImageType,
image_type: ImageType,
image_name: str,
pnginfo: Optional[PngInfo] = None,
metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256,
) -> None:
try:
image_path = self.get_path(image_type, image_name)
image.save(image_path, "PNG", pnginfo=pnginfo)
if metadata is not None:
pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text("invokeai", metadata.json())
image.save(image_path, "PNG", pnginfo=pnginfo)
else:
image.save(image_path, "PNG")
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True)

View File

@ -1,10 +1,9 @@
from abc import ABC, abstractmethod
import json
from logging import Logger
from typing import Optional, Union
from typing import Optional, TYPE_CHECKING, Union
import uuid
from PIL.Image import Image as PILImageType
from PIL import PngImagePlugin
from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.models.metadata import ImageMetadata
@ -17,12 +16,16 @@ from invokeai.app.services.models.image_record import (
image_record_to_dto,
)
from invokeai.app.services.image_file_storage import ImageFileStorageBase
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.app.services.urls import UrlServiceBase
from invokeai.app.util.misc import get_iso_timestamp
if TYPE_CHECKING:
from invokeai.app.services.graph import GraphExecutionState
class ImageServiceABC(ABC):
"""
High-level service for image management.
@ -59,7 +62,9 @@ class ImageServiceABC(ABC):
pass
@abstractmethod
def get_url(self, image_type: ImageType, image_name: str, thumbnail: bool = False) -> str:
def get_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
"""Gets an image's or thumbnail's URL"""
pass
@ -113,6 +118,7 @@ class ImageServiceDependencies:
metadata: MetadataServiceBase
urls: UrlServiceBase
logger: Logger
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
def __init__(
self,
@ -121,12 +127,14 @@ class ImageServiceDependencies:
metadata: MetadataServiceBase,
url: UrlServiceBase,
logger: Logger,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
):
self.records = image_record_storage
self.files = image_file_storage
self.metadata = metadata
self.urls = url
self.logger = logger
self.graph_execution_manager = graph_execution_manager
class ImageService(ImageServiceABC):
@ -139,6 +147,7 @@ class ImageService(ImageServiceABC):
metadata: MetadataServiceBase,
url: UrlServiceBase,
logger: Logger,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
):
self._services = ImageServiceDependencies(
image_record_storage=image_record_storage,
@ -146,6 +155,7 @@ class ImageService(ImageServiceABC):
metadata=metadata,
url=url,
logger=logger,
graph_execution_manager=graph_execution_manager,
)
def create(
@ -155,7 +165,6 @@ class ImageService(ImageServiceABC):
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
metadata: Optional[ImageMetadata] = None,
) -> ImageDTO:
image_name = self._create_image_name(
image_type=image_type,
@ -165,12 +174,7 @@ class ImageService(ImageServiceABC):
)
timestamp = get_iso_timestamp()
if metadata is not None:
pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text("invokeai", json.dumps(metadata))
else:
pnginfo = None
metadata = self._get_metadata(session_id, node_id)
try:
# TODO: Consider using a transaction here to ensure consistency between storage and database
@ -178,7 +182,7 @@ class ImageService(ImageServiceABC):
image_type=image_type,
image_name=image_name,
image=image,
pnginfo=pnginfo,
metadata=metadata,
)
self._services.records.save(
@ -237,24 +241,6 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image record")
raise e
def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
try:
return self._services.files.get_path(image_type, image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
def get_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
try:
return self._services.urls.get_image_url(image_type, image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
try:
image_record = self._services.records.get(image_type, image_name)
@ -273,6 +259,24 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image DTO")
raise e
def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
try:
return self._services.files.get_path(image_type, image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
def get_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
try:
return self._services.urls.get_image_url(image_type, image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
def get_many(
self,
image_type: ImageType,
@ -353,3 +357,15 @@ class ImageService(ImageServiceABC):
return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png"
return f"{image_type.value}_{image_category.value}_{uuid_str}.png"
def _get_metadata(
self, session_id: Optional[str] = None, node_id: Optional[str] = None
) -> Union[ImageMetadata, None]:
"""Get the metadata for a node."""
metadata = None
if node_id is not None and session_id is not None:
session = self._services.graph_execution_manager.get(session_id)
metadata = self._services.metadata.create_image_metadata(session, node_id)
return metadata

View File

@ -1,295 +1,142 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, TypedDict
from PIL import Image, PngImagePlugin
from pydantic import BaseModel
from typing import Any, Union
import networkx as nx
from invokeai.app.models.image import ImageType, is_image_type
class MetadataImageField(TypedDict):
"""Pydantic-less ImageField, used for metadata parsing."""
image_type: ImageType
image_name: str
class MetadataLatentsField(TypedDict):
"""Pydantic-less LatentsField, used for metadata parsing."""
latents_name: str
class MetadataColorField(TypedDict):
"""Pydantic-less ColorField, used for metadata parsing"""
r: int
g: int
b: int
a: int
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
NodeMetadata = Dict[
str,
None
| str
| int
| float
| bool
| MetadataImageField
| MetadataLatentsField
| MetadataColorField,
]
class InvokeAIMetadata(TypedDict, total=False):
"""InvokeAI-specific metadata format."""
session_id: Optional[str]
node: Optional[NodeMetadata]
def build_invokeai_metadata_pnginfo(
metadata: InvokeAIMetadata | None,
) -> PngImagePlugin.PngInfo:
"""Builds a PngInfo object with key `"invokeai"` and value `metadata`"""
pnginfo = PngImagePlugin.PngInfo()
if metadata is not None:
pnginfo.add_text("invokeai", json.dumps(metadata))
return pnginfo
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.graph import Edge, Graph, GraphExecutionState
class MetadataServiceBase(ABC):
@abstractmethod
def get_metadata(self, image: Image.Image) -> InvokeAIMetadata | None:
"""Gets the InvokeAI metadata from a PIL Image, skipping invalid values"""
pass
"""Handles building metadata for nodes, images, and outputs."""
@abstractmethod
def build_metadata(
self, session_id: str, node: BaseModel
) -> InvokeAIMetadata | None:
"""Builds an InvokeAIMetadata object"""
def create_image_metadata(
self, session: GraphExecutionState, node_id: str
) -> ImageMetadata:
"""Builds an ImageMetadata object for a node."""
pass
# @abstractmethod
# def create_metadata(self, session_id: str, node_id: str) -> dict:
# """Creates metadata for a result"""
# pass
class PngMetadataService(MetadataServiceBase):
"""Handles loading and building metadata for images."""
class CoreMetadataService(MetadataServiceBase):
_ANCESTOR_TYPES = ["t2l", "l2l"]
"""The ancestor types that contain the core metadata"""
# TODO: Use `InvocationsUnion` to **validate** metadata as representing a fully-functioning node
def _load_metadata(self, image: Image.Image) -> dict | None:
"""Loads a specific info entry from a PIL Image."""
_ANCESTOR_PARAMS = ["type", "steps", "model", "cfg_scale", "scheduler", "strength"]
"""The core metadata parameters in the ancestor types"""
try:
info = image.info.get("invokeai")
_NOISE_FIELDS = ["seed", "width", "height"]
"""The core metadata parameters in the noise node"""
if type(info) is not str:
return None
loaded_metadata = json.loads(info)
if type(loaded_metadata) is not dict:
return None
if len(loaded_metadata.items()) == 0:
return None
return loaded_metadata
except:
return None
def get_metadata(self, image: Image.Image) -> dict | None:
"""Retrieves an image's metadata as a dict"""
loaded_metadata = self._load_metadata(image)
return loaded_metadata
def build_metadata(self, session_id: str, node: BaseModel) -> InvokeAIMetadata:
metadata = InvokeAIMetadata(session_id=session_id, node=node.dict())
def create_image_metadata(
self, session: GraphExecutionState, node_id: str
) -> ImageMetadata:
metadata = self._build_metadata_from_graph(session, node_id)
return metadata
def _find_nearest_ancestor(self, G: nx.DiGraph, node_id: str) -> Union[str, None]:
"""
Finds the id of the nearest ancestor (of a valid type) of a given node.
from enum import Enum
Parameters:
G (nx.DiGraph): The execution graph, converted in to a networkx DiGraph. Its nodes must
have the same data as the execution graph.
node_id (str): The ID of the node.
from abc import ABC, abstractmethod
import json
import sqlite3
from threading import Lock
from typing import Any, Union
Returns:
str | None: The ID of the nearest ancestor, or None if there are no valid ancestors.
"""
import networkx as nx
# Retrieve the node from the graph
node = G.nodes[node_id]
from pydantic import BaseModel, Field, parse_obj_as, parse_raw_as
from invokeai.app.invocations.image import ImageOutput
from invokeai.app.services.graph import Edge, GraphExecutionState
from invokeai.app.invocations.latent import LatentsOutput
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.util.misc import get_timestamp
# If the node type is one of the core metadata node types, return its id
if node.get("type") in self._ANCESTOR_TYPES:
return node.get("id")
# Else, look for the ancestor in the predecessor nodes
for predecessor in G.predecessors(node_id):
result = self._find_nearest_ancestor(G, predecessor)
if result:
return result
class ResultType(str, Enum):
image_output = "image_output"
latents_output = "latents_output"
# If there are no valid ancestors, return None
return None
def _get_additional_metadata(
self, graph: Graph, node_id: str
) -> Union[dict[str, Any], None]:
"""
Returns additional metadata for a given node.
class Result(BaseModel):
"""A session result"""
Parameters:
graph (Graph): The execution graph.
node_id (str): The ID of the node.
id: str = Field(description="Result ID")
session_id: str = Field(description="Session ID")
node_id: str = Field(description="Node ID")
data: Union[LatentsOutput, ImageOutput] = Field(description="The result data")
Returns:
dict[str, Any] | None: A dictionary of additional metadata.
"""
metadata = {}
class ResultWithSession(BaseModel):
"""A result with its session"""
# Iterate over all edges in the graph
for edge in graph.edges:
dest_node_id = edge.destination.node_id
dest_field = edge.destination.field
source_node_dict = graph.nodes[edge.source.node_id].dict()
result: Result = Field(description="The result")
session: GraphExecutionState = Field(description="The session")
# If the destination node ID matches the given node ID, gather necessary metadata
if dest_node_id == node_id:
# If the destination field is 'positive_conditioning', add the 'prompt' from the source node
if dest_field == "positive_conditioning":
metadata["positive_conditioning"] = source_node_dict.get("prompt")
# If the destination field is 'negative_conditioning', add the 'prompt' from the source node
if dest_field == "negative_conditioning":
metadata["negative_conditioning"] = source_node_dict.get("prompt")
# If the destination field is 'noise', add the core noise fields from the source node
if dest_field == "noise":
for field in self._NOISE_FIELDS:
metadata[field] = source_node_dict.get(field)
return metadata
def _build_metadata_from_graph(
self, session: GraphExecutionState, node_id: str
) -> ImageMetadata:
"""
Builds an ImageMetadata object for a node.
# # Create a directed graph
# from typing import Any, TypedDict, Union
# from networkx import DiGraph
# import networkx as nx
# import json
Parameters:
session (GraphExecutionState): The session.
node_id (str): The ID of the node.
Returns:
ImageMetadata: The metadata for the node.
"""
# # We need to use a loose class for nodes to allow for graceful parsing - we cannot use the stricter
# # model used by the system, because we may be a graph in an old format. We can, however, use the
# # Edge model, because the edge format does not change.
# class LooseGraph(BaseModel):
# id: str
# nodes: dict[str, dict[str, Any]]
# edges: list[Edge]
graph = session.execution_graph
# Find the nearest ancestor of the given node
ancestor_id = self._find_nearest_ancestor(graph.nx_graph_with_data(), node_id)
# # An intermediate type used during parsing
# class NearestAncestor(TypedDict):
# node_id: str
# metadata: dict[str, Any]
# If no ancestor was found, return an empty ImageMetadata object
if ancestor_id is None:
return ImageMetadata()
ancestor_node = graph.get_node(ancestor_id)
# # The ancestor types that contain the core metadata
# ANCESTOR_TYPES = ['t2l', 'l2l']
ancestor_metadata = {
param: val
for param, val in ancestor_node.dict().items()
if param in self._ANCESTOR_PARAMS
}
# # The core metadata parameters in the ancestor types
# ANCESTOR_PARAMS = ['steps', 'model', 'cfg_scale', 'scheduler', 'strength']
# Get additional metadata related to the ancestor
addl_metadata = self._get_additional_metadata(graph, ancestor_id)
# # The core metadata parameters in the noise node
# NOISE_FIELDS = ['seed', 'width', 'height']
# If additional metadata was found, add it to the main metadata
if addl_metadata is not None:
ancestor_metadata.update(addl_metadata)
# # Find nearest t2l or l2l ancestor from a given l2i node
# def find_nearest_ancestor(G: DiGraph, node_id: str) -> Union[NearestAncestor, None]:
# """Returns metadata for the nearest ancestor of a given node.
# Parameters:
# G (DiGraph): A directed graph.
# node_id (str): The ID of the starting node.
# Returns:
# NearestAncestor | None: An object with the ID and metadata of the nearest ancestor.
# """
# # Retrieve the node from the graph
# node = G.nodes[node_id]
# # If the node type is one of the core metadata node types, gather necessary metadata and return
# if node.get('type') in ANCESTOR_TYPES:
# parsed_metadata = {param: val for param, val in node.items() if param in ANCESTOR_PARAMS}
# return NearestAncestor(node_id=node_id, metadata=parsed_metadata)
# # Else, look for the ancestor in the predecessor nodes
# for predecessor in G.predecessors(node_id):
# result = find_nearest_ancestor(G, predecessor)
# if result:
# return result
# # If there are no valid ancestors, return None
# return None
# def get_additional_metadata(graph: LooseGraph, node_id: str) -> Union[dict[str, Any], None]:
# """Collects additional metadata from nodes connected to a given node.
# Parameters:
# graph (LooseGraph): The graph.
# node_id (str): The ID of the node.
# Returns:
# dict | None: A dictionary containing additional metadata.
# """
# metadata = {}
# # Iterate over all edges in the graph
# for edge in graph.edges:
# dest_node_id = edge.destination.node_id
# dest_field = edge.destination.field
# source_node = graph.nodes[edge.source.node_id]
# # If the destination node ID matches the given node ID, gather necessary metadata
# if dest_node_id == node_id:
# # If the destination field is 'positive_conditioning', add the 'prompt' from the source node
# if dest_field == 'positive_conditioning':
# metadata['positive_conditioning'] = source_node.get('prompt')
# # If the destination field is 'negative_conditioning', add the 'prompt' from the source node
# if dest_field == 'negative_conditioning':
# metadata['negative_conditioning'] = source_node.get('prompt')
# # If the destination field is 'noise', add the core noise fields from the source node
# if dest_field == 'noise':
# for field in NOISE_FIELDS:
# metadata[field] = source_node.get(field)
# return metadata
# def build_core_metadata(graph_raw: str, node_id: str) -> Union[dict, None]:
# """Builds the core metadata for a given node.
# Parameters:
# graph_raw (str): The graph structure as a raw string.
# node_id (str): The ID of the node.
# Returns:
# dict | None: A dictionary containing core metadata.
# """
# # Create a directed graph to facilitate traversal
# G = nx.DiGraph()
# # Convert the raw graph string into a JSON object
# graph = parse_obj_as(LooseGraph, graph_raw)
# # Add nodes and edges to the graph
# for node_id, node_data in graph.nodes.items():
# G.add_node(node_id, **node_data)
# for edge in graph.edges:
# G.add_edge(edge.source.node_id, edge.destination.node_id)
# # Find the nearest ancestor of the given node
# ancestor = find_nearest_ancestor(G, node_id)
# # If no ancestor was found, return None
# if ancestor is None:
# return None
# metadata = ancestor['metadata']
# ancestor_id = ancestor['node_id']
# # Get additional metadata related to the ancestor
# addl_metadata = get_additional_metadata(graph, ancestor_id)
# # If additional metadata was found, add it to the main metadata
# if addl_metadata is not None:
# metadata.update(addl_metadata)
# return metadata
return ImageMetadata(**ancestor_metadata)

View File

@ -62,9 +62,12 @@ def deserialize_image_record(image_row: sqlite3.Row) -> ImageRecord:
image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value))
raw_metadata = image_dict.get("metadata", "{}")
raw_metadata = image_dict.get("metadata")
metadata = ImageMetadata.parse_raw(raw_metadata)
if raw_metadata is not None:
metadata = ImageMetadata.parse_raw(raw_metadata)
else:
metadata = None
return ImageRecord(
image_name=image_dict.get("id", "unknown"),