mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(nodes): controlnet input accepts list or single controlnet
This commit is contained in:
parent
7467fa5e57
commit
020f3ccf07
@ -110,7 +110,7 @@ class ControlOutput(BaseInvocationOutput):
|
|||||||
"""node output for ControlNet info"""
|
"""node output for ControlNet info"""
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["control_output"] = "control_output"
|
type: Literal["control_output"] = "control_output"
|
||||||
control: Optional[ControlField] = Field(default=None, description="The control info dict")
|
control: ControlField = Field(default=None, description="The control info dict")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
@ -178,8 +178,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
||||||
control: list[ControlField] = Field(default=None, description="The controlnet(s) to use")
|
control: Union[ControlField, List[ControlField]] = Field(default=None, description="The controlnet(s) to use")
|
||||||
# control: ControlField | List[ControlField] = Field(default=None, description="The controlnet(s) to use")
|
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
|
@ -60,6 +60,35 @@ def get_input_field(node: BaseInvocation, field: str) -> Any:
|
|||||||
node_input_field = node_inputs.get(field) or None
|
node_input_field = node_inputs.get(field) or None
|
||||||
return node_input_field
|
return node_input_field
|
||||||
|
|
||||||
|
from typing import Optional, Union, List, get_args
|
||||||
|
|
||||||
|
def is_union_subtype(t1, t2):
|
||||||
|
t1_args = get_args(t1)
|
||||||
|
t2_args = get_args(t2)
|
||||||
|
|
||||||
|
if not t1_args:
|
||||||
|
# t1 is a single type
|
||||||
|
return t1 in t2_args
|
||||||
|
else:
|
||||||
|
# t1 is a Union, check that all of its types are in t2_args
|
||||||
|
return all(arg in t2_args for arg in t1_args)
|
||||||
|
|
||||||
|
def is_list_or_contains_list(t):
|
||||||
|
t_args = get_args(t)
|
||||||
|
|
||||||
|
# If the type is a List
|
||||||
|
if get_origin(t) is list:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# If the type is a Union
|
||||||
|
elif t_args:
|
||||||
|
# Check if any of the types in the Union is a List
|
||||||
|
for arg in t_args:
|
||||||
|
if get_origin(arg) is list:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
|
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
|
||||||
if not from_type:
|
if not from_type:
|
||||||
@ -85,7 +114,8 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
|
|||||||
if to_type in get_args(from_type):
|
if to_type in get_args(from_type):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if not issubclass(from_type, to_type):
|
# if not issubclass(from_type, to_type):
|
||||||
|
if not is_union_subtype(from_type, to_type):
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
@ -694,7 +724,11 @@ class Graph(BaseModel):
|
|||||||
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
|
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
|
||||||
|
|
||||||
# Verify that all outputs are lists
|
# Verify that all outputs are lists
|
||||||
if not all((get_origin(f) == list for f in output_fields)):
|
# if not all((get_origin(f) == list for f in output_fields)):
|
||||||
|
# return False
|
||||||
|
|
||||||
|
# Verify that all outputs are lists
|
||||||
|
if not all(is_list_or_contains_list(f) for f in output_fields):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Verify that all outputs match the input type (are a base class or the same class)
|
# Verify that all outputs match the input type (are a base class or the same class)
|
||||||
|
Loading…
Reference in New Issue
Block a user