mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): improve metadata service comments
This commit is contained in:
parent
7a1de3887e
commit
021e5a2aa3
@ -1,10 +1,9 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Union
|
||||
import networkx as nx
|
||||
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.services.graph import Edge, Graph, GraphExecutionState
|
||||
from invokeai.app.services.graph import Graph, GraphExecutionState
|
||||
|
||||
|
||||
class MetadataServiceBase(ABC):
|
||||
@ -18,7 +17,6 @@ class MetadataServiceBase(ABC):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
class CoreMetadataService(MetadataServiceBase):
|
||||
_ANCESTOR_TYPES = ["t2l", "l2l"]
|
||||
"""The ancestor types that contain the core metadata"""
|
||||
@ -89,13 +87,13 @@ class CoreMetadataService(MetadataServiceBase):
|
||||
|
||||
# If the destination node ID matches the given node ID, gather necessary metadata
|
||||
if dest_node_id == node_id:
|
||||
# If the destination field is 'positive_conditioning', add the 'prompt' from the source node
|
||||
# Prompt
|
||||
if dest_field == "positive_conditioning":
|
||||
metadata["positive_conditioning"] = source_node_dict.get("prompt")
|
||||
# If the destination field is 'negative_conditioning', add the 'prompt' from the source node
|
||||
# Negative prompt
|
||||
if dest_field == "negative_conditioning":
|
||||
metadata["negative_conditioning"] = source_node_dict.get("prompt")
|
||||
# If the destination field is 'noise', add the core noise fields from the source node
|
||||
# Seed, width and height
|
||||
if dest_field == "noise":
|
||||
for field in self._NOISE_FIELDS:
|
||||
metadata[field] = source_node_dict.get(field)
|
||||
@ -115,9 +113,10 @@ class CoreMetadataService(MetadataServiceBase):
|
||||
ImageMetadata: The metadata for the node.
|
||||
"""
|
||||
|
||||
# We need to do all the traversal on the execution graph
|
||||
graph = session.execution_graph
|
||||
|
||||
# Find the nearest ancestor of the given node
|
||||
# Find the nearest `t2l`/`l2l` ancestor of the given node
|
||||
ancestor_id = self._find_nearest_ancestor(graph.nx_graph_with_data(), node_id)
|
||||
|
||||
# If no ancestor was found, return an empty ImageMetadata object
|
||||
@ -126,13 +125,14 @@ class CoreMetadataService(MetadataServiceBase):
|
||||
|
||||
ancestor_node = graph.get_node(ancestor_id)
|
||||
|
||||
# Grab all the core metadata from the ancestor node
|
||||
ancestor_metadata = {
|
||||
param: val
|
||||
for param, val in ancestor_node.dict().items()
|
||||
if param in self._ANCESTOR_PARAMS
|
||||
}
|
||||
|
||||
# Get additional metadata related to the ancestor
|
||||
# Get this image's prompts and noise parameters
|
||||
addl_metadata = self._get_additional_metadata(graph, ancestor_id)
|
||||
|
||||
# If additional metadata was found, add it to the main metadata
|
||||
|
Loading…
Reference in New Issue
Block a user