mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(nodes): controlnet input accepts list or single controlnet
This commit is contained in:
parent
7467fa5e57
commit
020f3ccf07
@ -110,7 +110,7 @@ class ControlOutput(BaseInvocationOutput):
|
||||
"""node output for ControlNet info"""
|
||||
# fmt: off
|
||||
type: Literal["control_output"] = "control_output"
|
||||
control: Optional[ControlField] = Field(default=None, description="The control info dict")
|
||||
control: ControlField = Field(default=None, description="The control info dict")
|
||||
# fmt: on
|
||||
|
||||
|
||||
|
@ -178,8 +178,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
||||
control: list[ControlField] = Field(default=None, description="The controlnet(s) to use")
|
||||
# control: ControlField | List[ControlField] = Field(default=None, description="The controlnet(s) to use")
|
||||
control: Union[ControlField, List[ControlField]] = Field(default=None, description="The controlnet(s) to use")
|
||||
# fmt: on
|
||||
|
||||
# Schema customisation
|
||||
|
@ -60,6 +60,35 @@ def get_input_field(node: BaseInvocation, field: str) -> Any:
|
||||
node_input_field = node_inputs.get(field) or None
|
||||
return node_input_field
|
||||
|
||||
from typing import Optional, Union, List, get_args
|
||||
|
||||
def is_union_subtype(t1, t2):
|
||||
t1_args = get_args(t1)
|
||||
t2_args = get_args(t2)
|
||||
|
||||
if not t1_args:
|
||||
# t1 is a single type
|
||||
return t1 in t2_args
|
||||
else:
|
||||
# t1 is a Union, check that all of its types are in t2_args
|
||||
return all(arg in t2_args for arg in t1_args)
|
||||
|
||||
def is_list_or_contains_list(t):
|
||||
t_args = get_args(t)
|
||||
|
||||
# If the type is a List
|
||||
if get_origin(t) is list:
|
||||
return True
|
||||
|
||||
# If the type is a Union
|
||||
elif t_args:
|
||||
# Check if any of the types in the Union is a List
|
||||
for arg in t_args:
|
||||
if get_origin(arg) is list:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
|
||||
if not from_type:
|
||||
@ -85,7 +114,8 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
|
||||
if to_type in get_args(from_type):
|
||||
return True
|
||||
|
||||
if not issubclass(from_type, to_type):
|
||||
# if not issubclass(from_type, to_type):
|
||||
if not is_union_subtype(from_type, to_type):
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
@ -694,7 +724,11 @@ class Graph(BaseModel):
|
||||
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
|
||||
|
||||
# Verify that all outputs are lists
|
||||
if not all((get_origin(f) == list for f in output_fields)):
|
||||
# if not all((get_origin(f) == list for f in output_fields)):
|
||||
# return False
|
||||
|
||||
# Verify that all outputs are lists
|
||||
if not all(is_list_or_contains_list(f) for f in output_fields):
|
||||
return False
|
||||
|
||||
# Verify that all outputs match the input type (are a base class or the same class)
|
||||
|
Loading…
Reference in New Issue
Block a user