Cleaning up after ControlNet refactor in TextToLatentsInvocation

This commit is contained in:
user1 2023-05-17 17:30:08 -07:00 committed by Kent Keirsey
parent a9007c7e0f
commit d855a65e73

View File

@ -282,13 +282,9 @@ class TextToLatentsInvocation(BaseInvocation):
control_models = []
for control_info in control_list:
# handle control models
# FIXME: change this to dropdown menu
control_model = ControlNetModel.from_pretrained(control_info.control_model,
torch_dtype=model.unet.dtype).to(model.device)
control_models.append(control_model)
# handle control images
# loading controlnet image (currently requires pre-processed image)
# control_image = prep_control_image(control_info.image)
control_image_field = control_info.image
input_image = context.services.images.get(control_image_field.image_type, control_image_field.image_name)
# FIXME: still need to test with different widths, heights, devices, dtypes
@ -298,7 +294,6 @@ class TextToLatentsInvocation(BaseInvocation):
control_image = model.prepare_control_image(
image=input_image,
do_classifier_free_guidance=do_classifier_free_guidance,
# do_classifier_free_guidance=True,
width=control_width_resize,
height=control_height_resize,
# batch_size=batch_size * num_images_per_prompt,
@ -312,17 +307,11 @@ class TextToLatentsInvocation(BaseInvocation):
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent)
control_data.append(control_item)
# multi_control = MultiControlNetModel(control_models) # no longer need MultiControlNetModel
# model.control_model = multi_control
# model.control_model = control_models
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
# latents_shape = noise.shape
# assuming fixed dimensional scaling of 8:1 for image:latents
# control_height_resize = latents_shape[2] * 8
# control_width_resize = latents_shape[3] * 8
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
@ -337,8 +326,7 @@ class TextToLatentsInvocation(BaseInvocation):
print("type of control input: ", type(self.control))
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
latents_shape=noise.shape,
do_classifier_free_guidance=(self.cfg_scale >= 1.0),
)
do_classifier_free_guidance=(self.cfg_scale >= 1.0))
# TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = model.latents_from_embeddings(