From 020f3ccf079008dce4363c909061fb2c78938bcb Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 19 May 2023 16:10:39 +1000 Subject: [PATCH] fix(nodes): controlnet input accepts list or single controlnet --- .../controlnet_image_processors.py | 2 +- invokeai/app/invocations/latent.py | 3 +- invokeai/app/services/graph.py | 38 ++++++++++++++++++- 3 files changed, 38 insertions(+), 5 deletions(-) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 1cb531c77d..1987381a7e 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -110,7 +110,7 @@ class ControlOutput(BaseInvocationOutput): """node output for ControlNet info""" # fmt: off type: Literal["control_output"] = "control_output" - control: Optional[ControlField] = Field(default=None, description="The control info dict") + control: ControlField = Field(default=None, description="The control info dict") # fmt: on diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 1befe483f0..4e5b97919f 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -178,8 +178,7 @@ class TextToLatentsInvocation(BaseInvocation): # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", ) - control: list[ControlField] = Field(default=None, description="The controlnet(s) to use") - # control: ControlField | List[ControlField] = Field(default=None, description="The controlnet(s) to use") + control: Union[ControlField, List[ControlField]] = Field(default=None, description="The controlnet(s) to use") # fmt: on # Schema customisation diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py index 44688ada0a..60e196faa1 100644 --- a/invokeai/app/services/graph.py +++ b/invokeai/app/services/graph.py @@ -60,6 +60,35 @@ def get_input_field(node: BaseInvocation, field: str) -> Any: node_input_field = node_inputs.get(field) or None return node_input_field +from typing import Optional, Union, List, get_args + +def is_union_subtype(t1, t2): + t1_args = get_args(t1) + t2_args = get_args(t2) + + if not t1_args: + # t1 is a single type + return t1 in t2_args + else: + # t1 is a Union, check that all of its types are in t2_args + return all(arg in t2_args for arg in t1_args) + +def is_list_or_contains_list(t): + t_args = get_args(t) + + # If the type is a List + if get_origin(t) is list: + return True + + # If the type is a Union + elif t_args: + # Check if any of the types in the Union is a List + for arg in t_args: + if get_origin(arg) is list: + return True + + return False + def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool: if not from_type: @@ -85,7 +114,8 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool: if to_type in get_args(from_type): return True - if not issubclass(from_type, to_type): + # if not issubclass(from_type, to_type): + if not is_union_subtype(from_type, to_type): return False else: return False @@ -694,7 +724,11 @@ class Graph(BaseModel): input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore # Verify that all outputs are lists - if not all((get_origin(f) == list for f in output_fields)): + # if not all((get_origin(f) == list for f in output_fields)): + # return False + + # Verify that all outputs are lists + if not all(is_list_or_contains_list(f) for f in output_fields): return False # Verify that all outputs match the input type (are a base class or the same class)