diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index fad875cdbb..8f90aed22b 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -175,8 +175,8 @@ class TextToLatentsInvocation(BaseInvocation):
     # seamless:   bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
     # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
     progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
-    control: Optional[ControlField] = Field(default=None, description="The control to use")
+    control: Optional[Union[ControlField, list[ControlField]]] = Field(default=None, description="The controlnet(s) to use")
     # fmt: on

     # Schema customisation
     class Config(InvocationConfig):
@@ -246,6 +246,10 @@ class TextToLatentsInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         noise = context.services.latents.get(self.noise.latents_name)

+        latents_shape = noise.shape
+        # assuming fixed dimensional scaling of 8:1 for image:latents
+        control_height_resize = latents_shape[2] * 8
+        control_width_resize = latents_shape[3] * 8
         # Get the source node id (we are invoking the prepared node)
         graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
@@ -259,58 +263,52 @@ class TextToLatentsInvocation(BaseInvocation):
         print("type of control input: ", type(self.control))
-        if (self.control is None):
-            control_model = None
-            control_image_field = None
-            control_weight = None
+        if self.control is None:
+            print("control input is None")
+            control_list = None
+        elif isinstance(self.control, list) and len(self.control) == 0:
+            print("control input is empty list")
+            control_list = None
+        elif isinstance(self.control, ControlField):
+            print("control input is ControlField")
+            control_list = [self.control]
+        elif isinstance(self.control, list) and len(self.control) > 0 and isinstance(self.control[0], ControlField):
+            print("control input is list[ControlField]")
+            control_list = self.control
         else:
-            control_model_name = self.control.control_model
-            control_image_field = self.control.image
-            control_weight = self.control.control_weight
+            print("input control is unrecognized:", type(self.control))
+            control_list = None

-            # # loading controlnet model
-            # if (self.control_model is None or self.control_model==''):
-            #     control_model = None
-            # else:
-            #     FIXME: change this to dropdown menu?
-            #     FIXME: generalize so don't have to hardcode torch_dtype and device
-            control_model = ControlNetModel.from_pretrained(control_model_name,
+        # from the normalization above, any non-None control_list is a non-empty list[ControlField]
+        if control_list is None:
+            control_models = None
+            control_weights = None
+            control_images = None
+        else:
+            # FIXME: add checks to skip entry if model or image is None
+            #        and if weight is None, populate with default 1.0?
+            control_models = []
+            control_images = []
+            control_weights = []
+            for control_info in control_list:
+                # handle control weights
+                control_weights.append(control_info.control_weight)
+
+                # handle control models
+                # FIXME: change this to dropdown menu?
+                # FIXME: generalize so don't have to hardcode torch_dtype and device
+                control_model = ControlNetModel.from_pretrained(control_info.control_model,
                                                                 torch_dtype=torch.float16).to("cuda")
-            model.control_model = control_model
+                control_models.append(control_model)

-            # loading controlnet image (currently requires pre-processed image)
-            control_image = (
-                None if control_image_field is None
-                else context.services.images.get(
-                    control_image_field.image_type, control_image_field.image_name
-                )
-            )
-
-            latents_shape = noise.shape
-            control_height_resize = latents_shape[2] * 8
-            control_width_resize = latents_shape[3] * 8
-
-            # copied from old backend/txt2img.py
-            # FIXME: still need to test with different widths, heights, devices, dtypes
-            #        and add in batch_size, num_images_per_prompt?
-            if control_image is not None:
-                if isinstance(control_model, ControlNetModel):
+                # handle control images (currently requires a pre-processed image)
+                control_image_field = control_info.image
+                input_image = context.services.images.get(control_image_field.image_type, control_image_field.image_name)
+                # FIXME: still need to test with different widths, heights, devices, dtypes
+                #        and add in batch_size, num_images_per_prompt?
+                #        and do real check for classifier_free_guidance?
                 control_image = model.prepare_control_image(
-                        image=control_image,
-                        # do_classifier_free_guidance=do_classifier_free_guidance,
-                        do_classifier_free_guidance=True,
-                        width=control_width_resize,
-                        height=control_height_resize,
-                        # batch_size=batch_size * num_images_per_prompt,
-                        # num_images_per_prompt=num_images_per_prompt,
-                        device=control_model.device,
-                        dtype=control_model.dtype,
-                    )
-                elif isinstance(control_model, MultiControlNetModel):
-                    images = []
-                    for image_ in control_image:
-                        image_ = model.prepare_control_image(
-                            image=image_,
+                    image=input_image,
                     # do_classifier_free_guidance=do_classifier_free_guidance,
                     do_classifier_free_guidance=True,
                     width=control_width_resize,
@@ -319,9 +317,12 @@ class TextToLatentsInvocation(BaseInvocation):
                     # num_images_per_prompt=num_images_per_prompt,
                     device=control_model.device,
                     dtype=control_model.dtype,
-                        )
-                    images.append(image_)
-                    control_image = images
+                )
+                control_images.append(control_image)
+            # wrap even a single ControlNet in a MultiControlNetModel so the
+            # downstream call site is the same for one or many ControlNets
+            multi_control = MultiControlNetModel(control_models)
+            model.control_model = multi_control

         # TODO: Verify the noise is the right size
         result_latents, result_attention_map_saver = model.latents_from_embeddings(
@@ -330,7 +331,7 @@ class TextToLatentsInvocation(BaseInvocation):
             num_inference_steps=self.steps,
             conditioning_data=conditioning_data,
             callback=step_callback,
-            control_image=control_image,
+            control_image=control_images,
         )

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
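
Reviewer note: the branching added in the @@ -259 hunk only normalizes the control input, which may arrive as None, an empty list, a single ControlField, or a list of ControlFields, into either None or a non-empty list before any model loading happens. A minimal standalone sketch of that normalization, with ControlField stubbed out (the stub and the normalize_control helper are illustrative only, not part of this diff):

    from dataclasses import dataclass
    from typing import Optional, Union

    @dataclass
    class ControlField:
        # stand-in for the real pydantic ControlField used by this invocation
        control_model: str
        control_weight: float = 1.0

    def normalize_control(
        control: Optional[Union[ControlField, list[ControlField]]],
    ) -> Optional[list[ControlField]]:
        # collapse the four accepted shapes into None or a non-empty list,
        # mirroring the if/elif chain in TextToLatentsInvocation.invoke()
        if control is None:
            return None
        if isinstance(control, ControlField):
            return [control]
        if isinstance(control, list) and len(control) > 0 and isinstance(control[0], ControlField):
            return control
        return None  # empty list or unrecognized input

    assert normalize_control(None) is None
    assert normalize_control([]) is None
    single = ControlField(control_model="lllyasviel/sd-controlnet-canny")
    assert normalize_control(single) == [single]
    assert normalize_control([single, single]) == [single, single]

Downstream, the loop collects control_models, control_images, and control_weights in parallel and always wraps the models in a MultiControlNetModel, so latents_from_embeddings() sees the same interface whether one or several ControlNets were supplied.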