From bb96543d6620cf3b56c2bfcd60ed1bcbd75f9402 Mon Sep 17 00:00:00 2001
From: user1
Date: Mon, 8 May 2023 19:19:24 -0700
Subject: [PATCH] Added support for using multiple control nets.
 Unfortunately this breaks direct usage of Control node output port ==>
 TextToLatent control input port -- passing through a Collect node is now
 required. Working on fixing this...

---
 invokeai/app/invocations/latent.py | 269 +++++++++++++++--------------
 1 file changed, 139 insertions(+), 130 deletions(-)

diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 4e5b97919f..efb7c9ab74 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -2,14 +2,14 @@
 import random
 
 import einops
-from pydantic import BaseModel, Field, validator
-import torch
 from typing import Literal, Optional, Union, List
 
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
 
+from pydantic import BaseModel, Field
+import torch
+
 from invokeai.app.invocations.util.choose_model import choose_model
-from invokeai.app.models.image import ImageCategory
 from invokeai.app.util.misc import SEED_MAX, get_random_seed
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 
@@ -20,16 +20,13 @@ from ...backend.util.devices import choose_torch_device, torch_dtype
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.image_util.seamless import configure_model_padding
 from ...backend.prompting.conditioning import get_uc_and_c_and_ec
-
 from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
-from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData
-
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
 import numpy as np
-from ..services.image_file_storage import ImageType
+from ..services.image_storage import ImageType
 from .baseinvocation import BaseInvocation, InvocationContext
-from .image import ImageField, ImageOutput
+from .image import ImageField, ImageOutput, build_image_output
 from .compel import ConditioningField
 from ...backend.stable_diffusion import PipelineIntermediateState
 from diffusers.schedulers import SchedulerMixin as Scheduler
@@ -90,13 +87,13 @@ SAMPLER_NAME_VALUES = Literal[
 def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
     scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
-    
+
     scheduler_config = model.scheduler.config
     if "_backup" in scheduler_config:
         scheduler_config = scheduler_config["_backup"]
     scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
     scheduler = scheduler_class.from_config(scheduler_config)
-    
+
     # hack copied over from generate.py
     if not hasattr(scheduler, 'uses_inpainting_model'):
         scheduler.uses_inpainting_model = lambda: False
 
@@ -146,17 +143,12 @@ class NoiseInvocation(BaseInvocation):
         },
     }
 
-    @validator("seed", pre=True)
-    def modulo_seed(cls, v):
-        """Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
-        return v % SEED_MAX
-
     def invoke(self, context: InvocationContext) -> NoiseOutput:
         device = torch.device(choose_torch_device())
         noise = get_noise(self.width, self.height, device, self.seed)
 
         name = f'{context.graph_execution_state_id}__{self.id}'
-        context.services.latents.save(name, noise)
+        context.services.latents.set(name, noise)
         return build_noise_output(latents_name=name, latents=noise)
 
 
@@ -173,22 +165,29 @@ class TextToLatentsInvocation(BaseInvocation):
     noise: Optional[LatentsField] = Field(description="The noise to use")
     steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
     cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
-    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
+    scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
     model: str = Field(default="", description="The model to use (currently ignored)")
-    # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
-    # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
+    seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
+    seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
     progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
-    control: Union[ControlField, List[ControlField]] = Field(default=None, description="The controlnet(s) to use")
-    # fmt: on
+    control: list[ControlField] = Field(default=None, description="The controlnet(s) to use")
+    # control: Union[list[ControlField] | None] = Field(default=None, description="The controlnet(s) to use")
+    # control: ControlField = Field(default=None, description="The controlnet(s) to use")
+    # control: Union[ControlField | list[ControlField] | None] = Field(default=None, description="The controlnet(s) to use")
+    # control: Any = Field(default=None, description="The controlnet(s) to use")
+    # control: Optional[ControlField] = Field(default=None, description="The control to use")
+    # control: List[ControlField] = Field(description="The controlnet(s) to use")
+    # control: Optional[list[ControlField]] = Field(default=None, description="The controlnet(s) to use")
+    # control: Optional[list[ControlField]] = Field(description="The controlnet(s) to use")
+    # fmt: on
 
     # Schema customisation
     class Config(InvocationConfig):
         schema_extra = {
             "ui": {
-                "tags": ["latents"],
+                "tags": ["latents", "image"],
                 "type_hints": {
-                    "model": "model",
-                    "control": "control",
+                    "model": "model"
                 }
             },
         }
@@ -214,17 +213,17 @@ class TextToLatentsInvocation(BaseInvocation):
             scheduler_name=self.scheduler
         )
 
-        # if isinstance(model, DiffusionPipeline):
-        #     for component in [model.unet, model.vae]:
-        #         configure_model_padding(component,
-        #                                 self.seamless,
-        #                                 self.seamless_axes
-        #                                 )
-        # else:
-        #     configure_model_padding(model,
-        #                             self.seamless,
-        #                             self.seamless_axes
-        #                             )
+        if isinstance(model, DiffusionPipeline):
+            for component in [model.unet, model.vae]:
+                configure_model_padding(component,
+                                        self.seamless,
+                                        self.seamless_axes
+                                        )
+        else:
+            configure_model_padding(model,
+                                    self.seamless,
+                                    self.seamless_axes
+                                    )
 
         return model
 
@@ -247,71 +246,13 @@ class TextToLatentsInvocation(BaseInvocation):
         ).add_scheduler_args_if_applicable(model.scheduler, eta=0.0)#ddim_eta)
         return conditioning_data
 
-    def prep_control_data(self,
-                          context: InvocationContext,
-                          model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
-                          control_input: List[ControlField],
-                          latents_shape: List[int],
-                          do_classifier_free_guidance: bool = True,
-                          ) -> List[ControlNetData]:
-        # assuming fixed dimensional scaling of 8:1 for image:latents
-        control_height_resize = latents_shape[2] * 8
-        control_width_resize = latents_shape[3] * 8
-        if control_input is None:
-            # print("control input is None")
-            control_list = None
-        elif isinstance(control_input, list) and len(control_input) == 0:
-            # print("control input is empty list")
-            control_list = None
-        elif isinstance(control_input, ControlField):
-            # print("control input is ControlField")
-            control_list = [control_input]
-        elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
-            # print("control input is list[ControlField]")
-            control_list = control_input
-        else:
-            # print("input control is unrecognized:", type(self.control))
-            control_list = None
-        if (control_list is None):
-            control_data = None
-            # from above handling, any control that is not None should now be of type list[ControlField]
-        else:
-            # FIXME: add checks to skip entry if model or image is None
-            #        and if weight is None, populate with default 1.0?
-            control_data = []
-            control_models = []
-            for control_info in control_list:
-                # handle control models
-                control_model = ControlNetModel.from_pretrained(control_info.control_model,
-                                                                torch_dtype=model.unet.dtype).to(model.device)
-                control_models.append(control_model)
-                control_image_field = control_info.image
-                input_image = context.services.images.get(control_image_field.image_type, control_image_field.image_name)
-                # FIXME: still need to test with different widths, heights, devices, dtypes
-                #        and add in batch_size, num_images_per_prompt?
-                #        and do real check for classifier_free_guidance?
-                # prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
-                control_image = model.prepare_control_image(
-                    image=input_image,
-                    do_classifier_free_guidance=do_classifier_free_guidance,
-                    width=control_width_resize,
-                    height=control_height_resize,
-                    # batch_size=batch_size * num_images_per_prompt,
-                    # num_images_per_prompt=num_images_per_prompt,
-                    device=control_model.device,
-                    dtype=control_model.dtype,
-                )
-                control_item = ControlNetData(model=control_model,
-                                              image_tensor=control_image,
-                                              weight=control_info.control_weight,
-                                              begin_step_percent=control_info.begin_step_percent,
-                                              end_step_percent=control_info.end_step_percent)
-                control_data.append(control_item)
-        # MultiControlNetModel has been refactored out, just need list[ControlNetData]
-        return control_data
 
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         noise = context.services.latents.get(self.noise.latents_name)
+        latents_shape = noise.shape
+        # assuming fixed dimensional scaling of 8:1 for image:latents
+        control_height_resize = latents_shape[2] * 8
+        control_width_resize = latents_shape[3] * 8
 
         # Get the source node id (we are invoking the prepared node)
         graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
@@ -324,9 +265,77 @@ class TextToLatentsInvocation(BaseInvocation):
         conditioning_data = self.get_conditioning_data(context, model)
 
         print("type of control input: ", type(self.control))
-        control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
-                                              latents_shape=noise.shape,
-                                              do_classifier_free_guidance=(self.cfg_scale >= 1.0))
+
+        if self.control is None:
+            print("control input is None")
+            control_list = None
+        elif isinstance(self.control, list) and len(self.control) == 0:
+            print("control input is empty list")
+            control_list = None
+        elif isinstance(self.control, ControlField):
+            print("control input is ControlField")
+            # control = [self.control]
+            control_list = [self.control]
+        # elif isinstance(self.control, list) and len(self.control)>0 and isinstance(self.control[0], ControlField):
+        elif isinstance(self.control, list) and len(self.control) > 0 and isinstance(self.control[0], ControlField):
+            print("control input is list[ControlField]")
+            # print("using first controlnet in list")
+            control_list = self.control
+            # control = self.control
+        else:
+            print("input control is unrecognized:", type(self.control))
+            control_list = None
+
+        #if (self.control is None or (isinstance(self.control, list) and len(self.control)==0)):
+        if (control_list is None):
+            control_models = None
+            control_weights = None
+            control_images = None
+        # from above handling, any control that is not None should now be of type list[ControlField]
+        else:
+            # FIXME: add checks to skip entry if model or image is None
+            #        and if weight is None, populate with default 1.0?
+            control_models = []
+            control_images = []
+            control_weights = []
+            for control_info in control_list:
+                # handle control weights
+                control_weights.append(control_info.control_weight)
+
+                # handle control models
+                # FIXME: change this to dropdown menu?
+                # FIXME: generalize so don't have to hardcode torch_dtype and device
+                control_model = ControlNetModel.from_pretrained(control_info.control_model,
+                                                                #torch_dtype=model.unet.dtype).to(model.device)
+                                                                #torch.dtype=model.unet.dtype).to("cuda")
+                                                                # torch.dtype = model.unet.dtype).to("cuda")
+                                                                torch_dtype=torch.float16).to("cuda")
+                                                                # torch_dtype = torch.float16).to(model.device)
+                                                                # model.dtype).to(model.device)
+                control_models.append(control_model)
+
+                # handle control images
+                # loading controlnet image (currently requires pre-processed image)
+                # control_image = prep_control_image(control_info.image)
+                control_image_field = control_info.image
+                input_image = context.services.images.get(control_image_field.image_type, control_image_field.image_name)
+                # FIXME: still need to test with different widths, heights, devices, dtypes
+                #        and add in batch_size, num_images_per_prompt?
+                #        and do real check for classifier_free_guidance?
+                control_image = model.prepare_control_image(
+                    image=input_image,
+                    # do_classifier_free_guidance=do_classifier_free_guidance,
+                    do_classifier_free_guidance=True,
+                    width=control_width_resize,
+                    height=control_height_resize,
+                    # batch_size=batch_size * num_images_per_prompt,
+                    # num_images_per_prompt=num_images_per_prompt,
+                    device=control_model.device,
+                    dtype=control_model.dtype,
+                )
+                control_images.append(control_image)
+            multi_control = MultiControlNetModel(control_models)
+            model.control_model = multi_control
 
         # TODO: Verify the noise is the right size
         result_latents, result_attention_map_saver = model.latents_from_embeddings(
@@ -334,15 +343,15 @@ class TextToLatentsInvocation(BaseInvocation):
             noise=noise,
             num_inference_steps=self.steps,
             conditioning_data=conditioning_data,
-            control_data=control_data,  # list[ControlNetData]
             callback=step_callback,
+            control_image=control_images,
         )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
 
         name = f'{context.graph_execution_state_id}__{self.id}'
-        context.services.latents.save(name, result_latents)
+        context.services.latents.set(name, result_latents)
         return build_latents_output(latents_name=name, latents=result_latents)
 
 
@@ -355,6 +364,17 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
     latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
     strength: float = Field(default=0.5, description="The strength of the latents to use")
 
+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["latents"],
+                "type_hints": {
+                    "model": "model"
+                }
+            },
+        }
+
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         noise = context.services.latents.get(self.noise.latents_name)
         latent = context.services.latents.get(self.latents.latents_name)
@@ -369,11 +389,6 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         model = self.get_model(context.services.model_manager)
         conditioning_data = self.get_conditioning_data(context, model)
 
-        print("type of control input: ", type(self.control))
-        control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
-                                              latents_shape=noise.shape,
-                                              do_classifier_free_guidance=(self.cfg_scale >= 1.0))
-
         # TODO: Verify the noise is the right size
 
         initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
@@ -388,7 +403,6 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             noise=noise,
             num_inference_steps=self.steps,
            conditioning_data=conditioning_data,
-            control_data=control_data,  # list[ControlNetData]
             callback=step_callback
         )
 
@@ -396,7 +410,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         torch.cuda.empty_cache()
 
         name = f'{context.graph_execution_state_id}__{self.id}'
-        context.services.latents.save(name, result_latents)
+        context.services.latents.set(name, result_latents)
         return build_latents_output(latents_name=name, latents=result_latents)
 
 
@@ -433,24 +447,20 @@ class LatentsToImageInvocation(BaseInvocation):
             np_image = model.decode_latents(latents)
             image = model.numpy_to_pil(np_image)[0]
 
-            torch.cuda.empty_cache()
-
-            image_dto = context.services.images.create(
-                image=image,
-                image_type=ImageType.RESULT,
-                image_category=ImageCategory.GENERAL,
-                session_id=context.graph_execution_state_id,
-                node_id=self.id,
-                is_intermediate=self.is_intermediate
+            image_type = ImageType.RESULT
+            image_name = context.services.images.create_name(
+                context.graph_execution_state_id, self.id
             )
 
-            return ImageOutput(
-                image=ImageField(
-                    image_name=image_dto.image_name,
-                    image_type=image_dto.image_type,
-                ),
-                width=image_dto.width,
-                height=image_dto.height,
+            metadata = context.services.metadata.build_metadata(
+                session_id=context.graph_execution_state_id, node=self
+            )
+
+            torch.cuda.empty_cache()
+
+            context.services.images.save(image_type, image_name, image, metadata)
+            return build_image_output(
+                image_type=image_type, image_name=image_name, image=image
             )
 
 
@@ -485,7 +495,7 @@ class ResizeLatentsInvocation(BaseInvocation):
         torch.cuda.empty_cache()
 
         name = f"{context.graph_execution_state_id}__{self.id}"
-        context.services.latents.save(name, resized_latents)
+        context.services.latents.set(name, resized_latents)
        return build_latents_output(latents_name=name, latents=resized_latents)
 
 
@@ -515,7 +525,7 @@ class ScaleLatentsInvocation(BaseInvocation):
         torch.cuda.empty_cache()
 
         name = f"{context.graph_execution_state_id}__{self.id}"
-        context.services.latents.save(name, resized_latents)
+        context.services.latents.set(name, resized_latents)
         return build_latents_output(latents_name=name, latents=resized_latents)
 
 
@@ -539,7 +549,7 @@ class ImageToLatentsInvocation(BaseInvocation):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        image = context.services.images.get_pil_image(
+        image = context.services.images.get(
            self.image.image_type, self.image.image_name
         )
@@ -559,6 +569,5 @@ class ImageToLatentsInvocation(BaseInvocation):
         )
 
         name = f"{context.graph_execution_state_id}__{self.id}"
-        context.services.latents.save(name, latents)
+        context.services.latents.set(name, latents)
         return build_latents_output(latents_name=name, latents=latents)
-
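
Note on the port-type change described in the commit message: because `control` is now declared as `list[ControlField]`, a single Control node output (a bare `ControlField`) no longer matches the TextToLatents `control` input directly, which is why one or more Control outputs must be gathered through a Collect node into a list. For readers unfamiliar with the diffusers API used above, the sketch below is a minimal standalone approximation of the multi-ControlNet wiring that the new code in `TextToLatentsInvocation.invoke` performs. It is illustrative only, not InvokeAI code: the checkpoint IDs are placeholder examples, and the hardcoded `torch.float16` / `"cuda"` mirror the patch's own FIXME rather than a recommendation.

    # Standalone sketch (not InvokeAI code) of the multi-ControlNet setup above.
    import torch
    from diffusers import ControlNetModel
    from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import (
        MultiControlNetModel,
    )

    # One ControlNetModel per ControlField, as in the loop over control_list.
    # Checkpoint IDs here are examples only; any ControlNet checkpoints would do.
    control_models = [
        ControlNetModel.from_pretrained(repo_id, torch_dtype=torch.float16).to("cuda")
        for repo_id in ("lllyasviel/sd-controlnet-canny", "lllyasviel/sd-controlnet-depth")
    ]

    # MultiControlNetModel runs each wrapped ControlNet on every denoising step
    # and sums their residuals; the patch assigns this to model.control_model.
    multi_control = MultiControlNetModel(control_models)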