Added support for using multiple ControlNets. Unfortunately this breaks wiring a Control node's output port directly into the TextToLatents control input port -- the connection now has to pass through a Collect node (see the sketch below). Working on fixing this...
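In the meantime the required wiring looks like this: each Control node feeds the Collect node's item input, and the Collect node's collection output feeds the TextToLatents control input. A rough sketch as a session-graph dict -- the node ids and the exact field names here are illustrative guesses, not taken from this commit:

# Hypothetical wiring for the Collect-node workaround.
# Node ids ("control_1", "collect_1", "t2l_1") and field names are
# assumptions for illustration; only the shape matters:
# N Control nodes -> one Collect node -> TextToLatents "control".
graph = {
    "nodes": {
        "control_1": {"id": "control_1", "type": "controlnet"},
        "control_2": {"id": "control_2", "type": "controlnet"},
        "collect_1": {"id": "collect_1", "type": "collect"},
        "t2l_1": {"id": "t2l_1", "type": "t2l"},
    },
    "edges": [
        # each Control node contributes one item to the Collect node...
        {"source": {"node_id": "control_1", "field": "control"},
         "destination": {"node_id": "collect_1", "field": "item"}},
        {"source": {"node_id": "control_2", "field": "control"},
         "destination": {"node_id": "collect_1", "field": "item"}},
        # ...and the collected list feeds TextToLatents' control input
        {"source": {"node_id": "collect_1", "field": "collection"},
         "destination": {"node_id": "t2l_1", "field": "control"}},
    ],
}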

user1 2023-05-08 19:19:24 -07:00 committed by Kent Keirsey
parent a2a2cfa765
commit bb96543d66


@@ -2,14 +2,14 @@
 import random
 import einops
-from pydantic import BaseModel, Field, validator
-import torch
 from typing import Literal, Optional, Union, List
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
+from pydantic import BaseModel, Field
+import torch
 from invokeai.app.invocations.util.choose_model import choose_model
-from invokeai.app.models.image import ImageCategory
 from invokeai.app.util.misc import SEED_MAX, get_random_seed
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
@@ -20,16 +20,13 @@ from ...backend.util.devices import choose_torch_device, torch_dtype
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.image_util.seamless import configure_model_padding
 from ...backend.prompting.conditioning import get_uc_and_c_and_ec
 from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
-from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
 import numpy as np
-from ..services.image_file_storage import ImageType
+from ..services.image_storage import ImageType
 from .baseinvocation import BaseInvocation, InvocationContext
-from .image import ImageField, ImageOutput
+from .image import ImageField, ImageOutput, build_image_output
 from .compel import ConditioningField
 from ...backend.stable_diffusion import PipelineIntermediateState
 from diffusers.schedulers import SchedulerMixin as Scheduler
@@ -90,13 +87,13 @@ SAMPLER_NAME_VALUES = Literal[
 def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
     scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
     scheduler_config = model.scheduler.config
     if "_backup" in scheduler_config:
         scheduler_config = scheduler_config["_backup"]
     scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
     scheduler = scheduler_class.from_config(scheduler_config)
     # hack copied over from generate.py
     if not hasattr(scheduler, 'uses_inpainting_model'):
         scheduler.uses_inpainting_model = lambda: False
@@ -146,17 +143,12 @@ class NoiseInvocation(BaseInvocation):
             },
         }

-    @validator("seed", pre=True)
-    def modulo_seed(cls, v):
-        """Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
-        return v % SEED_MAX
-
     def invoke(self, context: InvocationContext) -> NoiseOutput:
         device = torch.device(choose_torch_device())
         noise = get_noise(self.width, self.height, device, self.seed)
         name = f'{context.graph_execution_state_id}__{self.id}'
-        context.services.latents.save(name, noise)
+        context.services.latents.set(name, noise)
         return build_noise_output(latents_name=name, latents=noise)
@@ -173,22 +165,29 @@ class TextToLatentsInvocation(BaseInvocation):
     noise: Optional[LatentsField] = Field(description="The noise to use")
     steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
     cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
-    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
+    scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
     model: str = Field(default="", description="The model to use (currently ignored)")
-    # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
-    # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
+    seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
+    seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
     progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
-    control: Union[ControlField, List[ControlField]] = Field(default=None, description="The controlnet(s) to use")
-    # fmt: on
+    control: list[ControlField] = Field(default=None, description="The controlnet(s) to use")
+    # control: Union[list[ControlField] | None] = Field(default=None, description="The controlnet(s) to use")
+    # control: ControlField = Field(default=None, description="The controlnet(s) to use")
+    # control: Union[ControlField | list[ControlField] | None] = Field(default=None, description="The controlnet(s) to use")
+    # control: Any = Field(default=None, description="The controlnet(s) to use")
+    # control: Optional[ControlField] = Field(default=None, description="The control to use")
+    # control: List[ControlField] = Field(description="The controlnet(s) to use")
+    # control: Optional[list[ControlField]] = Field(default=None, description="The controlnet(s) to use")
+    # control: Optional[list[ControlField]] = Field(description="The controlnet(s) to use")
+    # fmt: on

     # Schema customisation
     class Config(InvocationConfig):
         schema_extra = {
             "ui": {
-                "tags": ["latents"],
+                "tags": ["latents", "image"],
                 "type_hints": {
-                    "model": "model",
-                    "control": "control",
+                    "model": "model"
                 }
             },
         }
@@ -214,17 +213,17 @@ class TextToLatentsInvocation(BaseInvocation):
             scheduler_name=self.scheduler
         )

-        # if isinstance(model, DiffusionPipeline):
-        #     for component in [model.unet, model.vae]:
-        #         configure_model_padding(component,
-        #                                 self.seamless,
-        #                                 self.seamless_axes
-        #                                 )
-        # else:
-        #     configure_model_padding(model,
-        #                             self.seamless,
-        #                             self.seamless_axes
-        #                             )
+        if isinstance(model, DiffusionPipeline):
+            for component in [model.unet, model.vae]:
+                configure_model_padding(component,
+                                        self.seamless,
+                                        self.seamless_axes
+                                        )
+        else:
+            configure_model_padding(model,
+                                    self.seamless,
+                                    self.seamless_axes
+                                    )

         return model
@@ -247,71 +246,13 @@ class TextToLatentsInvocation(BaseInvocation):
         ).add_scheduler_args_if_applicable(model.scheduler, eta=0.0)#ddim_eta)
         return conditioning_data

-    def prep_control_data(self,
-                          context: InvocationContext,
-                          model: StableDiffusionGeneratorPipeline,  # really only need model for dtype and device
-                          control_input: List[ControlField],
-                          latents_shape: List[int],
-                          do_classifier_free_guidance: bool = True,
-                          ) -> List[ControlNetData]:
-        # assuming fixed dimensional scaling of 8:1 for image:latents
-        control_height_resize = latents_shape[2] * 8
-        control_width_resize = latents_shape[3] * 8
-        if control_input is None:
-            # print("control input is None")
-            control_list = None
-        elif isinstance(control_input, list) and len(control_input) == 0:
-            # print("control input is empty list")
-            control_list = None
-        elif isinstance(control_input, ControlField):
-            # print("control input is ControlField")
-            control_list = [control_input]
-        elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
-            # print("control input is list[ControlField]")
-            control_list = control_input
-        else:
-            # print("input control is unrecognized:", type(self.control))
-            control_list = None
-        if (control_list is None):
-            control_data = None
-            # from above handling, any control that is not None should now be of type list[ControlField]
-        else:
-            # FIXME: add checks to skip entry if model or image is None
-            #        and if weight is None, populate with default 1.0?
-            control_data = []
-            control_models = []
-            for control_info in control_list:
-                # handle control models
-                control_model = ControlNetModel.from_pretrained(control_info.control_model,
-                                                                torch_dtype=model.unet.dtype).to(model.device)
-                control_models.append(control_model)
-                control_image_field = control_info.image
-                input_image = context.services.images.get(control_image_field.image_type, control_image_field.image_name)
-                # FIXME: still need to test with different widths, heights, devices, dtypes
-                #        and add in batch_size, num_images_per_prompt?
-                #        and do real check for classifier_free_guidance?
-                # prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
-                control_image = model.prepare_control_image(
-                    image=input_image,
-                    do_classifier_free_guidance=do_classifier_free_guidance,
-                    width=control_width_resize,
-                    height=control_height_resize,
-                    # batch_size=batch_size * num_images_per_prompt,
-                    # num_images_per_prompt=num_images_per_prompt,
-                    device=control_model.device,
-                    dtype=control_model.dtype,
-                )
-                control_item = ControlNetData(model=control_model,
-                                              image_tensor=control_image,
-                                              weight=control_info.control_weight,
-                                              begin_step_percent=control_info.begin_step_percent,
-                                              end_step_percent=control_info.end_step_percent)
-                control_data.append(control_item)
-        # MultiControlNetModel has been refactored out, just need list[ControlNetData]
-        return control_data
-
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         noise = context.services.latents.get(self.noise.latents_name)
+        latents_shape = noise.shape
+        # assuming fixed dimensional scaling of 8:1 for image:latents
+        control_height_resize = latents_shape[2] * 8
+        control_width_resize = latents_shape[3] * 8

         # Get the source node id (we are invoking the prepared node)
         graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
@@ -324,9 +265,77 @@ class TextToLatentsInvocation(BaseInvocation):
         conditioning_data = self.get_conditioning_data(context, model)

         print("type of control input: ", type(self.control))
-        control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
-                                              latents_shape=noise.shape,
-                                              do_classifier_free_guidance=(self.cfg_scale >= 1.0))
+        if self.control is None:
+            print("control input is None")
+            control_list = None
+        elif isinstance(self.control, list) and len(self.control) == 0:
+            print("control input is empty list")
+            control_list = None
+        elif isinstance(self.control, ControlField):
+            print("control input is ControlField")
+            # control = [self.control]
+            control_list = [self.control]
+        # elif isinstance(self.control, list) and len(self.control)>0 and isinstance(self.control[0], ControlField):
+        elif isinstance(self.control, list) and len(self.control) > 0 and isinstance(self.control[0], ControlField):
+            print("control input is list[ControlField]")
+            # print("using first controlnet in list")
+            control_list = self.control
+            # control = self.control
+        else:
+            print("input control is unrecognized:", type(self.control))
+            control_list = None
+        # if (self.control is None or (isinstance(self.control, list) and len(self.control)==0)):
+        if (control_list is None):
+            control_models = None
+            control_weights = None
+            control_images = None
+        # from above handling, any control that is not None should now be of type list[ControlField]
+        else:
+            # FIXME: add checks to skip entry if model or image is None
+            #        and if weight is None, populate with default 1.0?
+            control_models = []
+            control_images = []
+            control_weights = []
+            for control_info in control_list:
+                # handle control weights
+                control_weights.append(control_info.control_weight)
+                # handle control models
+                # FIXME: change this to dropdown menu?
+                # FIXME: generalize so don't have to hardcode torch_dtype and device
+                control_model = ControlNetModel.from_pretrained(control_info.control_model,
+                                                                # torch_dtype=model.unet.dtype).to(model.device)
+                                                                # torch.dtype=model.unet.dtype).to("cuda")
+                                                                # torch.dtype = model.unet.dtype).to("cuda")
+                                                                torch_dtype=torch.float16).to("cuda")
+                                                                # torch_dtype = torch.float16).to(model.device)
+                                                                # model.dtype).to(model.device)
+                control_models.append(control_model)
+                # handle control images
+                # loading controlnet image (currently requires pre-processed image)
+                # control_image = prep_control_image(control_info.image)
+                control_image_field = control_info.image
+                input_image = context.services.images.get(control_image_field.image_type, control_image_field.image_name)
+                # FIXME: still need to test with different widths, heights, devices, dtypes
+                #        and add in batch_size, num_images_per_prompt?
+                #        and do real check for classifier_free_guidance?
+                control_image = model.prepare_control_image(
+                    image=input_image,
+                    # do_classifier_free_guidance=do_classifier_free_guidance,
+                    do_classifier_free_guidance=True,
+                    width=control_width_resize,
+                    height=control_height_resize,
+                    # batch_size=batch_size * num_images_per_prompt,
+                    # num_images_per_prompt=num_images_per_prompt,
+                    device=control_model.device,
+                    dtype=control_model.dtype,
+                )
+                control_images.append(control_image)
+            multi_control = MultiControlNetModel(control_models)
+            model.control_model = multi_control

         # TODO: Verify the noise is the right size
         result_latents, result_attention_map_saver = model.latents_from_embeddings(
@@ -334,15 +343,15 @@ class TextToLatentsInvocation(BaseInvocation):
             noise=noise,
             num_inference_steps=self.steps,
             conditioning_data=conditioning_data,
-            control_data=control_data,  # list[ControlNetData]
             callback=step_callback,
+            control_image=control_images,
         )

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()

         name = f'{context.graph_execution_state_id}__{self.id}'
-        context.services.latents.save(name, result_latents)
+        context.services.latents.set(name, result_latents)
         return build_latents_output(latents_name=name, latents=result_latents)
@@ -355,6 +364,17 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
     latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
     strength: float = Field(default=0.5, description="The strength of the latents to use")

+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["latents"],
+                "type_hints": {
+                    "model": "model"
+                }
+            },
+        }
+
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         noise = context.services.latents.get(self.noise.latents_name)
         latent = context.services.latents.get(self.latents.latents_name)
@@ -369,11 +389,6 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         model = self.get_model(context.services.model_manager)
         conditioning_data = self.get_conditioning_data(context, model)

-        print("type of control input: ", type(self.control))
-        control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
-                                              latents_shape=noise.shape,
-                                              do_classifier_free_guidance=(self.cfg_scale >= 1.0))
-
         # TODO: Verify the noise is the right size
         initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
@@ -388,7 +403,6 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             noise=noise,
             num_inference_steps=self.steps,
             conditioning_data=conditioning_data,
-            control_data=control_data,  # list[ControlNetData]
             callback=step_callback
         )
@@ -396,7 +410,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         torch.cuda.empty_cache()

         name = f'{context.graph_execution_state_id}__{self.id}'
-        context.services.latents.save(name, result_latents)
+        context.services.latents.set(name, result_latents)
         return build_latents_output(latents_name=name, latents=result_latents)
@@ -433,24 +447,20 @@ class LatentsToImageInvocation(BaseInvocation):
             np_image = model.decode_latents(latents)
             image = model.numpy_to_pil(np_image)[0]

-            torch.cuda.empty_cache()
-
-            image_dto = context.services.images.create(
-                image=image,
-                image_type=ImageType.RESULT,
-                image_category=ImageCategory.GENERAL,
-                session_id=context.graph_execution_state_id,
-                node_id=self.id,
-                is_intermediate=self.is_intermediate
-            )
-
-            return ImageOutput(
-                image=ImageField(
-                    image_name=image_dto.image_name,
-                    image_type=image_dto.image_type,
-                ),
-                width=image_dto.width,
-                height=image_dto.height,
-            )
+            image_type = ImageType.RESULT
+            image_name = context.services.images.create_name(
+                context.graph_execution_state_id, self.id
+            )
+
+            metadata = context.services.metadata.build_metadata(
+                session_id=context.graph_execution_state_id, node=self
+            )
+
+            torch.cuda.empty_cache()
+
+            context.services.images.save(image_type, image_name, image, metadata)
+            return build_image_output(
+                image_type=image_type, image_name=image_name, image=image
+            )
@@ -485,7 +495,7 @@ class ResizeLatentsInvocation(BaseInvocation):
         torch.cuda.empty_cache()

         name = f"{context.graph_execution_state_id}__{self.id}"
-        context.services.latents.save(name, resized_latents)
+        context.services.latents.set(name, resized_latents)
         return build_latents_output(latents_name=name, latents=resized_latents)
@@ -515,7 +525,7 @@ class ScaleLatentsInvocation(BaseInvocation):
         torch.cuda.empty_cache()

         name = f"{context.graph_execution_state_id}__{self.id}"
-        context.services.latents.save(name, resized_latents)
+        context.services.latents.set(name, resized_latents)
         return build_latents_output(latents_name=name, latents=resized_latents)
@@ -539,7 +549,7 @@ class ImageToLatentsInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        image = context.services.images.get_pil_image(
+        image = context.services.images.get(
             self.image.image_type, self.image.image_name
         )
@@ -559,6 +569,5 @@ class ImageToLatentsInvocation(BaseInvocation):
         )

         name = f"{context.graph_execution_state_id}__{self.id}"
-        context.services.latents.save(name, latents)
+        context.services.latents.set(name, latents)
         return build_latents_output(latents_name=name, latents=latents)
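For readers skimming the diff: the heart of the change is the loop in TextToLatentsInvocation.invoke that turns the list of ControlField inputs into parallel lists of models, weights, and prepared control images, then hands the models to diffusers as a single MultiControlNetModel. A minimal standalone sketch of that pattern, assuming only the diffusers ControlNet classes imported in the diff; the repo ids and image files are placeholders:

import torch
from PIL import Image
from diffusers import ControlNetModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel

# Placeholder inputs: (controlnet repo id, pre-processed control image, weight).
controls = [
    ("lllyasviel/sd-controlnet-canny", Image.open("canny.png"), 1.0),
    ("lllyasviel/sd-controlnet-depth", Image.open("depth.png"), 0.5),
]

control_models, control_images, control_weights = [], [], []
for repo_id, image, weight in controls:
    # dtype and device are hardcoded here, mirroring the FIXME in the diff
    net = ControlNetModel.from_pretrained(repo_id, torch_dtype=torch.float16).to("cuda")
    control_models.append(net)
    control_images.append(image)
    control_weights.append(weight)

# One wrapper that runs every ControlNet and sums their residuals; the
# pipeline then receives control_images alongside it, as in the diff.
multi_control = MultiControlNetModel(control_models)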