mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Added support for using multiple control nets. Unfortunately this breaks direct usage of the Control node output port ==> TextToLatents control input port; passing through a Collect node is now required. Working on fixing this...
This commit is contained in:
parent
78b0b37ba6
commit
c0863fa20f
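As the commit message notes, a Control node can no longer feed the TextToLatents `control` input directly: the port is now list-valued, so each Control output has to pass through a Collect node that gathers the items into one collection. Below is a minimal sketch of that wiring as graph-style edges in Python dict form; the node ids and field names ("control_1", "collect_1", "t2l", "item", "collection") are illustrative assumptions, not taken from this commit.

```python
# Hypothetical edge list: two Control nodes -> one Collect node -> TextToLatents.
# Node ids and field names here are assumptions for illustration only.
edges = [
    {"source": {"node_id": "control_1", "field": "control"},
     "destination": {"node_id": "collect_1", "field": "item"}},
    {"source": {"node_id": "control_2", "field": "control"},
     "destination": {"node_id": "collect_1", "field": "item"}},
    # the Collect node aggregates its items into a single list-valued output
    {"source": {"node_id": "collect_1", "field": "collection"},
     "destination": {"node_id": "t2l", "field": "control"}},
]
```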
@@ -175,8 +175,16 @@ class TextToLatentsInvocation(BaseInvocation):
     # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
     # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
     progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
-    control: Optional[ControlField] = Field(default=None, description="The control to use")
+    control: list[ControlField] = Field(default=None, description="The controlnet(s) to use")
+    # control: Union[list[ControlField] | None] = Field(default=None, description="The controlnet(s) to use")
+    # control: ControlField = Field(default=None, description="The controlnet(s) to use")
+    # control: Union[ControlField | list[ControlField] | None] = Field(default=None, description="The controlnet(s) to use")
+    # control: Any = Field(default=None, description="The controlnet(s) to use")
+    # control: Optional[ControlField] = Field(default=None, description="The control to use")
+    # control: List[ControlField] = Field(description="The controlnet(s) to use")
+    # control: Optional[list[ControlField]] = Field(default=None, description="The controlnet(s) to use")
+    # control: Optional[list[ControlField]] = Field(description="The controlnet(s) to use")
     # fmt: on

     # Schema customisation
     class Config(InvocationConfig):
@@ -246,6 +254,10 @@ class TextToLatentsInvocation(BaseInvocation):

     def invoke(self, context: InvocationContext) -> LatentsOutput:
         noise = context.services.latents.get(self.noise.latents_name)
+        latents_shape = noise.shape
+        # assuming fixed dimensional scaling of 8:1 for image:latents
+        control_height_resize = latents_shape[2] * 8
+        control_width_resize = latents_shape[3] * 8

         # Get the source node id (we are invoking the prepared node)
         graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
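The four lines added to invoke() derive the ControlNet image size from the noise tensor: Stable Diffusion's VAE scales image pixels to latents at a fixed 8:1 per axis, so a (1, 4, 64, 64) latent corresponds to a 512x512 image. A quick check of that arithmetic (the example shape is mine):

```python
import torch

# SD latents are (batch, channels, height/8, width/8); this shape is illustrative
noise = torch.zeros(1, 4, 64, 64)
latents_shape = noise.shape
control_height_resize = latents_shape[2] * 8  # 512
control_width_resize = latents_shape[3] * 8   # 512
assert (control_height_resize, control_width_resize) == (512, 512)
```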
@@ -259,58 +271,64 @@ class TextToLatentsInvocation(BaseInvocation):
         print("type of control input: ", type(self.control))

-        if (self.control is None):
-            control_model = None
-            control_image_field = None
-            control_weight = None
+        if self.control is None:
+            print("control input is None")
+            control_list = None
+        elif isinstance(self.control, list) and len(self.control) == 0:
+            print("control input is empty list")
+            control_list = None
+        elif isinstance(self.control, ControlField):
+            print("control input is ControlField")
+            # control = [self.control]
+            control_list = [self.control]
+        # elif isinstance(self.control, list) and len(self.control)>0 and isinstance(self.control[0], ControlField):
+        elif isinstance(self.control, list) and len(self.control) > 0 and isinstance(self.control[0], ControlField):
+            print("control input is list[ControlField]")
+            # print("using first controlnet in list")
+            control_list = self.control
+            # control = self.control
         else:
-            control_model_name = self.control.control_model
-            control_image_field = self.control.image
-            control_weight = self.control.control_weight
+            print("input control is unrecognized:", type(self.control))
+            control_list = None

-        # # loading controlnet model
-        # if (self.control_model is None or self.control_model==''):
-        #     control_model = None
-        # else:
-        #     # FIXME: change this to dropdown menu?
-        #     # FIXME: generalize so don't have to hardcode torch_dtype and device
-        control_model = ControlNetModel.from_pretrained(control_model_name,
-                                                        torch_dtype=torch.float16).to("cuda")
-        model.control_model = control_model
+        #if (self.control is None or (isinstance(self.control, list) and len(self.control)==0)):
+        if (control_list is None):
+            control_models = None
+            control_weights = None
+            control_images = None
+        # from above handling, any control that is not None should now be of type list[ControlField]
+        else:
+            # FIXME: add checks to skip entry if model or image is None
+            #        and if weight is None, populate with default 1.0?
+            control_models = []
+            control_images = []
+            control_weights = []
+            for control_info in control_list:
+                # handle control weights
+                control_weights.append(control_info.control_weight)

+                # handle control models
+                # FIXME: change this to dropdown menu?
+                # FIXME: generalize so don't have to hardcode torch_dtype and device
+                control_model = ControlNetModel.from_pretrained(control_info.control_model,
+                                                                # torch_dtype=model.unet.dtype).to(model.device)
+                                                                # torch.dtype=model.unet.dtype).to("cuda")
+                                                                torch_dtype=torch.float16).to("cuda")
+                                                                # torch_dtype = torch.float16).to(model.device)
+                                                                # model.dtype).to(model.device)
+                control_models.append(control_model)

-        # loading controlnet image (currently requires pre-processed image)
-        control_image = (
-            None if control_image_field is None
-            else context.services.images.get(
-                control_image_field.image_type, control_image_field.image_name
-            )
-        )
-
-        latents_shape = noise.shape
-        control_height_resize = latents_shape[2] * 8
-        control_width_resize = latents_shape[3] * 8
-
-        # copied from old backend/txt2img.py
-        # FIXME: still need to test with different widths, heights, devices, dtypes
-        #        and add in batch_size, num_images_per_prompt?
-        if control_image is not None:
-            if isinstance(control_model, ControlNetModel):
+                # handle control images
+                # loading controlnet image (currently requires pre-processed image)
+                # control_image = prep_control_image(control_info.image)
+                control_image_field = control_info.image
+                input_image = context.services.images.get(control_image_field.image_type, control_image_field.image_name)
+                # FIXME: still need to test with different widths, heights, devices, dtypes
+                #        and add in batch_size, num_images_per_prompt?
+                #        and do real check for classifier_free_guidance?
                 control_image = model.prepare_control_image(
-                    image=control_image,
-                    # do_classifier_free_guidance=do_classifier_free_guidance,
-                    do_classifier_free_guidance=True,
-                    width=control_width_resize,
-                    height=control_height_resize,
-                    # batch_size=batch_size * num_images_per_prompt,
-                    # num_images_per_prompt=num_images_per_prompt,
-                    device=control_model.device,
-                    dtype=control_model.dtype,
-                )
-            elif isinstance(control_model, MultiControlNetModel):
-                images = []
-                for image_ in control_image:
-                    image_ = model.prepare_control_image(
-                        image=image_,
+                    image=input_image,
                     # do_classifier_free_guidance=do_classifier_free_guidance,
                     do_classifier_free_guidance=True,
                     width=control_width_resize,
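The if/elif chain added in this hunk does one job: coerce whatever arrives on the `control` port (None, a single ControlField, an empty list, or a list of ControlFields) into either None or a non-empty list. The same normalization, condensed into a standalone sketch (the helper name is mine, and a plain object stands in for the pydantic ControlField):

```python
from typing import Any, Optional


def normalize_control(control: Any) -> Optional[list]:
    """Coerce None / a single control / a list of controls into None or a non-empty list."""
    if control is None:
        return None
    if isinstance(control, list):
        # an empty list means "no control", same as None
        return control if len(control) > 0 else None
    # a single ControlField-like object becomes a one-element list
    return [control]
```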
@@ -319,9 +337,10 @@ class TextToLatentsInvocation(BaseInvocation):
                     # num_images_per_prompt=num_images_per_prompt,
                     device=control_model.device,
                     dtype=control_model.dtype,
                 )
-                    images.append(image_)
-                control_image = images
+                control_images.append(control_image)
+            multi_control = MultiControlNetModel(control_models)
+            model.control_model = multi_control

         # TODO: Verify the noise is the right size
         result_latents, result_attention_map_saver = model.latents_from_embeddings(
@@ -330,7 +349,7 @@ class TextToLatentsInvocation(BaseInvocation):
             num_inference_steps=self.steps,
             conditioning_data=conditioning_data,
             callback=step_callback,
-            control_image=control_image,
+            control_image=control_images,
         )

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
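Taken together, the loading loop and the final wrap follow the usual diffusers multi-ControlNet pattern: build one ControlNetModel per entry, then hand the list to MultiControlNetModel, which runs every net and sums their residuals. A condensed, self-contained sketch of that pattern (the function name is mine, the import path is as of diffusers ~0.16 and may differ across versions, and the hard-coded float16/"cuda" mirrors the FIXME in the diff):

```python
from typing import Optional

import torch
from diffusers import ControlNetModel
from diffusers.pipelines.controlnet import MultiControlNetModel


def load_multi_controlnet(model_names: list[str]) -> Optional[MultiControlNetModel]:
    """Load one ControlNetModel per name and wrap them for joint conditioning."""
    if not model_names:
        return None
    control_models = []
    for name in model_names:
        # FIXME carried over from the diff: dtype and device are hard-coded
        net = ControlNetModel.from_pretrained(name, torch_dtype=torch.float16).to("cuda")
        control_models.append(net)
    # MultiControlNetModel runs every ControlNet and sums their residuals
    return MultiControlNetModel(control_models)
```

Note that in the hunks shown, only control_images is visibly threaded through to latents_from_embeddings; the collected control_weights list is not yet consumed here.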