feat(nodes): improve metadata service comments

This commit is contained in:
psychedelicious 2023-05-23 18:43:06 +10:00 committed by Kent Keirsey
parent 7a1de3887e
commit 021e5a2aa3

View File

@ -1,10 +1,9 @@
import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Union from typing import Any, Union
import networkx as nx import networkx as nx
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.graph import Edge, Graph, GraphExecutionState from invokeai.app.services.graph import Graph, GraphExecutionState
class MetadataServiceBase(ABC): class MetadataServiceBase(ABC):
@ -18,7 +17,6 @@ class MetadataServiceBase(ABC):
pass pass
class CoreMetadataService(MetadataServiceBase): class CoreMetadataService(MetadataServiceBase):
_ANCESTOR_TYPES = ["t2l", "l2l"] _ANCESTOR_TYPES = ["t2l", "l2l"]
"""The ancestor types that contain the core metadata""" """The ancestor types that contain the core metadata"""
@ -89,13 +87,13 @@ class CoreMetadataService(MetadataServiceBase):
# If the destination node ID matches the given node ID, gather necessary metadata # If the destination node ID matches the given node ID, gather necessary metadata
if dest_node_id == node_id: if dest_node_id == node_id:
# If the destination field is 'positive_conditioning', add the 'prompt' from the source node # Prompt
if dest_field == "positive_conditioning": if dest_field == "positive_conditioning":
metadata["positive_conditioning"] = source_node_dict.get("prompt") metadata["positive_conditioning"] = source_node_dict.get("prompt")
# If the destination field is 'negative_conditioning', add the 'prompt' from the source node # Negative prompt
if dest_field == "negative_conditioning": if dest_field == "negative_conditioning":
metadata["negative_conditioning"] = source_node_dict.get("prompt") metadata["negative_conditioning"] = source_node_dict.get("prompt")
# If the destination field is 'noise', add the core noise fields from the source node # Seed, width and height
if dest_field == "noise": if dest_field == "noise":
for field in self._NOISE_FIELDS: for field in self._NOISE_FIELDS:
metadata[field] = source_node_dict.get(field) metadata[field] = source_node_dict.get(field)
@ -115,9 +113,10 @@ class CoreMetadataService(MetadataServiceBase):
ImageMetadata: The metadata for the node. ImageMetadata: The metadata for the node.
""" """
# We need to do all the traversal on the execution graph
graph = session.execution_graph graph = session.execution_graph
# Find the nearest ancestor of the given node # Find the nearest `t2l`/`l2l` ancestor of the given node
ancestor_id = self._find_nearest_ancestor(graph.nx_graph_with_data(), node_id) ancestor_id = self._find_nearest_ancestor(graph.nx_graph_with_data(), node_id)
# If no ancestor was found, return an empty ImageMetadata object # If no ancestor was found, return an empty ImageMetadata object
@ -126,13 +125,14 @@ class CoreMetadataService(MetadataServiceBase):
ancestor_node = graph.get_node(ancestor_id) ancestor_node = graph.get_node(ancestor_id)
# Grab all the core metadata from the ancestor node
ancestor_metadata = { ancestor_metadata = {
param: val param: val
for param, val in ancestor_node.dict().items() for param, val in ancestor_node.dict().items()
if param in self._ANCESTOR_PARAMS if param in self._ANCESTOR_PARAMS
} }
# Get additional metadata related to the ancestor # Get this image's prompts and noise parameters
addl_metadata = self._get_additional_metadata(graph, ancestor_id) addl_metadata = self._get_additional_metadata(graph, ancestor_id)
# If additional metadata was found, add it to the main metadata # If additional metadata was found, add it to the main metadata