Added support for using multiple control nets. Unfortunately this breaks direct usage of Control node output port ==> TextToLatent control input port -- passing through a Collect node is now required. Working on fixing this...

This commit is contained in:
user1 2023-05-08 19:19:24 -07:00 committed by Kent Keirsey
parent 78b0b37ba6
commit c0863fa20f

View File

@ -175,8 +175,16 @@ class TextToLatentsInvocation(BaseInvocation):
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", ) progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
control: Optional[ControlField] = Field(default=None, description="The control to use") control: list[ControlField] = Field(default=None, description="The controlnet(s) to use")
# fmt: on # control: Union[list[ControlField] | None] = Field(default=None, description="The controlnet(s) to use")
# control: ControlField = Field(default=None, description="The controlnet(s) to use")
# control: Union[ControlField | list[ControlField] | None] = Field(default=None, description="The controlnet(s) to use")
# control: Any = Field(default=None, description="The controlnet(s) to use")
# control: Optional[ControlField] = Field(default=None, description="The control to use")
# control: List[ControlField] = Field(description="The controlnet(s) to use")
# control: Optional[list[ControlField]] = Field(default=None, description="The controlnet(s) to use")
# control: Optional[list[ControlField]] = Field(description="The controlnet(s) to use")
# fmt: on
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):
@ -246,6 +254,10 @@ class TextToLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name) noise = context.services.latents.get(self.noise.latents_name)
latents_shape = noise.shape
# assuming fixed dimensional scaling of 8:1 for image:latents
control_height_resize = latents_shape[2] * 8
control_width_resize = latents_shape[3] * 8
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
@ -259,58 +271,64 @@ class TextToLatentsInvocation(BaseInvocation):
print("type of control input: ", type(self.control)) print("type of control input: ", type(self.control))
if (self.control is None): if self.control is None:
control_model = None print("control input is None")
control_image_field = None control_list = None
control_weight = None elif isinstance(self.control, list) and len(self.control) == 0:
print("control input is empty list")
control_list = None
elif isinstance(self.control, ControlField):
print("control input is ControlField")
# control = [self.control]
control_list = [self.control]
# elif isinstance(self.control, list) and len(self.control)>0 and isinstance(self.control[0], ControlField):
elif isinstance(self.control, list) and len(self.control) > 0 and isinstance(self.control[0], ControlField):
print("control input is list[ControlField]")
# print("using first controlnet in list")
control_list = self.control
# control = self.control
else: else:
control_model_name = self.control.control_model print("input control is unrecognized:", type(self.control))
control_image_field = self.control.image control_list = None
control_weight = self.control.control_weight
# # loading controlnet model #if (self.control is None or (isinstance(self.control, list) and len(self.control)==0)):
# if (self.control_model is None or self.control_model==''): if (control_list is None):
# control_model = None control_models = None
# else: control_weights = None
# FIXME: change this to dropdown menu? control_images = None
# FIXME: generalize so don't have to hardcode torch_dtype and device # from above handling, any control that is not None should now be of type list[ControlField]
control_model = ControlNetModel.from_pretrained(control_model_name, else:
# FIXME: add checks to skip entry if model or image is None
# and if weight is None, populate with default 1.0?
control_models = []
control_images = []
control_weights = []
for control_info in control_list:
# handle control weights
control_weights.append(control_info.control_weight)
# handle control models
# FIXME: change this to dropdown menu?
# FIXME: generalize so don't have to hardcode torch_dtype and device
control_model = ControlNetModel.from_pretrained(control_info.control_model,
#torch_dtype=model.unet.dtype).to(model.device)
#torch.dtype=model.unet.dtype).to("cuda")
# torch.dtype = model.unet.dtype).to("cuda")
torch_dtype=torch.float16).to("cuda") torch_dtype=torch.float16).to("cuda")
model.control_model = control_model # torch_dtype = torch.float16).to(model.device)
# model.dtype).to(model.device)
control_models.append(control_model)
# loading controlnet image (currently requires pre-processed image) # handle control images
control_image = ( # loading controlnet image (currently requires pre-processed image)
None if control_image_field is None # control_image = prep_control_image(control_info.image)
else context.services.images.get( control_image_field = control_info.image
control_image_field.image_type, control_image_field.image_name input_image = context.services.images.get(control_image_field.image_type, control_image_field.image_name)
) # FIXME: still need to test with different widths, heights, devices, dtypes
) # and add in batch_size, num_images_per_prompt?
# and do real check for classifier_free_guidance?
latents_shape = noise.shape
control_height_resize = latents_shape[2] * 8
control_width_resize = latents_shape[3] * 8
# copied from old backend/txt2img.py
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
if control_image is not None:
if isinstance(control_model, ControlNetModel):
control_image = model.prepare_control_image( control_image = model.prepare_control_image(
image=control_image, image=input_image,
# do_classifier_free_guidance=do_classifier_free_guidance,
do_classifier_free_guidance=True,
width=control_width_resize,
height=control_height_resize,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=control_model.device,
dtype=control_model.dtype,
)
elif isinstance(control_model, MultiControlNetModel):
images = []
for image_ in control_image:
image_ = model.prepare_control_image(
image=image_,
# do_classifier_free_guidance=do_classifier_free_guidance, # do_classifier_free_guidance=do_classifier_free_guidance,
do_classifier_free_guidance=True, do_classifier_free_guidance=True,
width=control_width_resize, width=control_width_resize,
@ -319,9 +337,10 @@ class TextToLatentsInvocation(BaseInvocation):
# num_images_per_prompt=num_images_per_prompt, # num_images_per_prompt=num_images_per_prompt,
device=control_model.device, device=control_model.device,
dtype=control_model.dtype, dtype=control_model.dtype,
) )
images.append(image_) control_images.append(control_image)
control_image = images multi_control = MultiControlNetModel(control_models)
model.control_model = multi_control
# TODO: Verify the noise is the right size # TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = model.latents_from_embeddings( result_latents, result_attention_map_saver = model.latents_from_embeddings(
@ -330,7 +349,7 @@ class TextToLatentsInvocation(BaseInvocation):
num_inference_steps=self.steps, num_inference_steps=self.steps,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
callback=step_callback, callback=step_callback,
control_image=control_image, control_image=control_images,
) )
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699