mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Added support for using multiple control nets. Unfortunately this breaks direct usage of Control node output port ==> TextToLatent control input port -- passing through a Collect node is now required. Working on fixing this...
This commit is contained in:
parent
78b0b37ba6
commit
c0863fa20f
@ -175,8 +175,16 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
||||
control: Optional[ControlField] = Field(default=None, description="The control to use")
|
||||
# fmt: on
|
||||
control: list[ControlField] = Field(default=None, description="The controlnet(s) to use")
|
||||
# control: Union[list[ControlField] | None] = Field(default=None, description="The controlnet(s) to use")
|
||||
# control: ControlField = Field(default=None, description="The controlnet(s) to use")
|
||||
# control: Union[ControlField | list[ControlField] | None] = Field(default=None, description="The controlnet(s) to use")
|
||||
# control: Any = Field(default=None, description="The controlnet(s) to use")
|
||||
# control: Optional[ControlField] = Field(default=None, description="The control to use")
|
||||
# control: List[ControlField] = Field(description="The controlnet(s) to use")
|
||||
# control: Optional[list[ControlField]] = Field(default=None, description="The controlnet(s) to use")
|
||||
# control: Optional[list[ControlField]] = Field(description="The controlnet(s) to use")
|
||||
# fmt: on
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
@ -246,6 +254,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
latents_shape = noise.shape
|
||||
# assuming fixed dimensional scaling of 8:1 for image:latents
|
||||
control_height_resize = latents_shape[2] * 8
|
||||
control_width_resize = latents_shape[3] * 8
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
@ -259,58 +271,64 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
print("type of control input: ", type(self.control))
|
||||
|
||||
if (self.control is None):
|
||||
control_model = None
|
||||
control_image_field = None
|
||||
control_weight = None
|
||||
if self.control is None:
|
||||
print("control input is None")
|
||||
control_list = None
|
||||
elif isinstance(self.control, list) and len(self.control) == 0:
|
||||
print("control input is empty list")
|
||||
control_list = None
|
||||
elif isinstance(self.control, ControlField):
|
||||
print("control input is ControlField")
|
||||
# control = [self.control]
|
||||
control_list = [self.control]
|
||||
# elif isinstance(self.control, list) and len(self.control)>0 and isinstance(self.control[0], ControlField):
|
||||
elif isinstance(self.control, list) and len(self.control) > 0 and isinstance(self.control[0], ControlField):
|
||||
print("control input is list[ControlField]")
|
||||
# print("using first controlnet in list")
|
||||
control_list = self.control
|
||||
# control = self.control
|
||||
else:
|
||||
control_model_name = self.control.control_model
|
||||
control_image_field = self.control.image
|
||||
control_weight = self.control.control_weight
|
||||
print("input control is unrecognized:", type(self.control))
|
||||
control_list = None
|
||||
|
||||
# # loading controlnet model
|
||||
# if (self.control_model is None or self.control_model==''):
|
||||
# control_model = None
|
||||
# else:
|
||||
# FIXME: change this to dropdown menu?
|
||||
# FIXME: generalize so don't have to hardcode torch_dtype and device
|
||||
control_model = ControlNetModel.from_pretrained(control_model_name,
|
||||
#if (self.control is None or (isinstance(self.control, list) and len(self.control)==0)):
|
||||
if (control_list is None):
|
||||
control_models = None
|
||||
control_weights = None
|
||||
control_images = None
|
||||
# from above handling, any control that is not None should now be of type list[ControlField]
|
||||
else:
|
||||
# FIXME: add checks to skip entry if model or image is None
|
||||
# and if weight is None, populate with default 1.0?
|
||||
control_models = []
|
||||
control_images = []
|
||||
control_weights = []
|
||||
for control_info in control_list:
|
||||
# handle control weights
|
||||
control_weights.append(control_info.control_weight)
|
||||
|
||||
# handle control models
|
||||
# FIXME: change this to dropdown menu?
|
||||
# FIXME: generalize so don't have to hardcode torch_dtype and device
|
||||
control_model = ControlNetModel.from_pretrained(control_info.control_model,
|
||||
#torch_dtype=model.unet.dtype).to(model.device)
|
||||
#torch.dtype=model.unet.dtype).to("cuda")
|
||||
# torch.dtype = model.unet.dtype).to("cuda")
|
||||
torch_dtype=torch.float16).to("cuda")
|
||||
model.control_model = control_model
|
||||
# torch_dtype = torch.float16).to(model.device)
|
||||
# model.dtype).to(model.device)
|
||||
control_models.append(control_model)
|
||||
|
||||
# loading controlnet image (currently requires pre-processed image)
|
||||
control_image = (
|
||||
None if control_image_field is None
|
||||
else context.services.images.get(
|
||||
control_image_field.image_type, control_image_field.image_name
|
||||
)
|
||||
)
|
||||
|
||||
latents_shape = noise.shape
|
||||
control_height_resize = latents_shape[2] * 8
|
||||
control_width_resize = latents_shape[3] * 8
|
||||
|
||||
# copied from old backend/txt2img.py
|
||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||
# and add in batch_size, num_images_per_prompt?
|
||||
if control_image is not None:
|
||||
if isinstance(control_model, ControlNetModel):
|
||||
# handle control images
|
||||
# loading controlnet image (currently requires pre-processed image)
|
||||
# control_image = prep_control_image(control_info.image)
|
||||
control_image_field = control_info.image
|
||||
input_image = context.services.images.get(control_image_field.image_type, control_image_field.image_name)
|
||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||
# and add in batch_size, num_images_per_prompt?
|
||||
# and do real check for classifier_free_guidance?
|
||||
control_image = model.prepare_control_image(
|
||||
image=control_image,
|
||||
# do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
do_classifier_free_guidance=True,
|
||||
width=control_width_resize,
|
||||
height=control_height_resize,
|
||||
# batch_size=batch_size * num_images_per_prompt,
|
||||
# num_images_per_prompt=num_images_per_prompt,
|
||||
device=control_model.device,
|
||||
dtype=control_model.dtype,
|
||||
)
|
||||
elif isinstance(control_model, MultiControlNetModel):
|
||||
images = []
|
||||
for image_ in control_image:
|
||||
image_ = model.prepare_control_image(
|
||||
image=image_,
|
||||
image=input_image,
|
||||
# do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
do_classifier_free_guidance=True,
|
||||
width=control_width_resize,
|
||||
@ -319,9 +337,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
# num_images_per_prompt=num_images_per_prompt,
|
||||
device=control_model.device,
|
||||
dtype=control_model.dtype,
|
||||
)
|
||||
images.append(image_)
|
||||
control_image = images
|
||||
)
|
||||
control_images.append(control_image)
|
||||
multi_control = MultiControlNetModel(control_models)
|
||||
model.control_model = multi_control
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||
@ -330,7 +349,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
callback=step_callback,
|
||||
control_image=control_image,
|
||||
control_image=control_images,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
|
Loading…
Reference in New Issue
Block a user