mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Core implementation of ControlNet and MultiControlNet.
This commit is contained in:
parent
5569f205ee
commit
5ff98a4179
@ -9,16 +9,20 @@ from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
|
|||||||
|
|
||||||
import einops
|
import einops
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
|
import numpy as np
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
from compel import EmbeddingsProvider
|
from compel import EmbeddingsProvider
|
||||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||||
|
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
|
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
)
|
)
|
||||||
|
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
|
||||||
|
|
||||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
|
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
|
||||||
StableDiffusionImg2ImgPipeline,
|
StableDiffusionImg2ImgPipeline,
|
||||||
)
|
)
|
||||||
@ -27,6 +31,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
|
|||||||
)
|
)
|
||||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||||
|
from diffusers.utils import PIL_INTERPOLATION
|
||||||
from diffusers.utils.import_utils import is_xformers_available
|
from diffusers.utils.import_utils import is_xformers_available
|
||||||
from diffusers.utils.outputs import BaseOutput
|
from diffusers.utils.outputs import BaseOutput
|
||||||
from torchvision.transforms.functional import resize as tv_resize
|
from torchvision.transforms.functional import resize as tv_resize
|
||||||
@ -302,6 +307,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
feature_extractor: Optional[CLIPFeatureExtractor],
|
feature_extractor: Optional[CLIPFeatureExtractor],
|
||||||
requires_safety_checker: bool = False,
|
requires_safety_checker: bool = False,
|
||||||
precision: str = "float32",
|
precision: str = "float32",
|
||||||
|
control_model: ControlNetModel = None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
vae,
|
vae,
|
||||||
@ -322,6 +328,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
safety_checker=safety_checker,
|
safety_checker=safety_checker,
|
||||||
feature_extractor=feature_extractor,
|
feature_extractor=feature_extractor,
|
||||||
|
# FIXME: can't currently register control module
|
||||||
|
# control_model=control_model,
|
||||||
)
|
)
|
||||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(
|
self.invokeai_diffuser = InvokeAIDiffuserComponent(
|
||||||
self.unet, self._unet_forward, is_running_diffusers=True
|
self.unet, self._unet_forward, is_running_diffusers=True
|
||||||
@ -341,6 +349,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
|
|
||||||
self._model_group = FullyLoadedModelGroup(self.unet.device)
|
self._model_group = FullyLoadedModelGroup(self.unet.device)
|
||||||
self._model_group.install(*self._submodels)
|
self._model_group.install(*self._submodels)
|
||||||
|
self.control_model = control_model
|
||||||
|
|
||||||
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
|
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
@ -463,6 +472,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
noise: torch.Tensor,
|
noise: torch.Tensor,
|
||||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||||
run_id=None,
|
run_id=None,
|
||||||
|
**kwargs,
|
||||||
) -> InvokeAIStableDiffusionPipelineOutput:
|
) -> InvokeAIStableDiffusionPipelineOutput:
|
||||||
r"""
|
r"""
|
||||||
Function invoked when calling the pipeline for generation.
|
Function invoked when calling the pipeline for generation.
|
||||||
@ -483,6 +493,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
noise=noise,
|
noise=noise,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@ -507,6 +518,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
additional_guidance: List[Callable] = None,
|
additional_guidance: List[Callable] = None,
|
||||||
run_id=None,
|
run_id=None,
|
||||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||||
|
**kwargs,
|
||||||
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
||||||
if self.scheduler.config.get("cpu_only", False):
|
if self.scheduler.config.get("cpu_only", False):
|
||||||
scheduler_device = torch.device('cpu')
|
scheduler_device = torch.device('cpu')
|
||||||
@ -527,6 +539,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
additional_guidance=additional_guidance,
|
additional_guidance=additional_guidance,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
return result.latents, result.attention_map_saver
|
return result.latents, result.attention_map_saver
|
||||||
|
|
||||||
@ -539,6 +552,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
noise: torch.Tensor,
|
noise: torch.Tensor,
|
||||||
run_id: str = None,
|
run_id: str = None,
|
||||||
additional_guidance: List[Callable] = None,
|
additional_guidance: List[Callable] = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
self._adjust_memory_efficient_attention(latents)
|
self._adjust_memory_efficient_attention(latents)
|
||||||
if run_id is None:
|
if run_id is None:
|
||||||
@ -578,6 +592,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
step_index=i,
|
step_index=i,
|
||||||
total_step_count=len(timesteps),
|
total_step_count=len(timesteps),
|
||||||
additional_guidance=additional_guidance,
|
additional_guidance=additional_guidance,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
latents = step_output.prev_sample
|
latents = step_output.prev_sample
|
||||||
|
|
||||||
@ -618,6 +633,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
step_index: int,
|
step_index: int,
|
||||||
total_step_count: int,
|
total_step_count: int,
|
||||||
additional_guidance: List[Callable] = None,
|
additional_guidance: List[Callable] = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||||
timestep = t[0]
|
timestep = t[0]
|
||||||
@ -629,6 +645,33 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
||||||
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
|
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
|
||||||
|
|
||||||
|
if (self.control_model is not None) and (kwargs.get("control_image") is not None):
|
||||||
|
control_image = kwargs.get("control_image") # should be a processed tensor derived from the control image(s)
|
||||||
|
control_scale = kwargs.get("control_scale", 1.0) # control_scale default is 1.0
|
||||||
|
# handling case where using multiple control models but only specifying single control_scale
|
||||||
|
# so reshape control_scale to match number of control models
|
||||||
|
if isinstance(self.control_model, MultiControlNetModel) and isinstance(control_scale, float):
|
||||||
|
control_scale = [control_scale] * len(self.control_model.nets)
|
||||||
|
if conditioning_data.guidance_scale > 1.0:
|
||||||
|
# expand the latents input to control model if doing classifier free guidance
|
||||||
|
# (which I think for now is always true, there is conditional elsewhere that stops execution if
|
||||||
|
# classifier_free_guidance is <= 1.0 ?)
|
||||||
|
latent_control_input = torch.cat([latent_model_input] * 2)
|
||||||
|
else:
|
||||||
|
latent_control_input = latent_model_input
|
||||||
|
# controlnet inference
|
||||||
|
down_block_res_samples, mid_block_res_sample = self.control_model(
|
||||||
|
latent_control_input,
|
||||||
|
timestep,
|
||||||
|
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
|
||||||
|
conditioning_data.text_embeddings]),
|
||||||
|
controlnet_cond=control_image,
|
||||||
|
conditioning_scale=control_scale,
|
||||||
|
return_dict=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
down_block_res_samples, mid_block_res_sample = None, None
|
||||||
|
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
noise_pred = self.invokeai_diffuser.do_diffusion_step(
|
noise_pred = self.invokeai_diffuser.do_diffusion_step(
|
||||||
latent_model_input,
|
latent_model_input,
|
||||||
@ -638,6 +681,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
conditioning_data.guidance_scale,
|
conditioning_data.guidance_scale,
|
||||||
step_index=step_index,
|
step_index=step_index,
|
||||||
total_step_count=total_step_count,
|
total_step_count=total_step_count,
|
||||||
|
down_block_additional_residuals=down_block_res_samples,
|
||||||
|
mid_block_additional_residual=mid_block_res_sample,
|
||||||
)
|
)
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
@ -659,6 +704,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
t,
|
t,
|
||||||
text_embeddings,
|
text_embeddings,
|
||||||
cross_attention_kwargs: Optional[dict[str, Any]] = None,
|
cross_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""predict the noise residual"""
|
"""predict the noise residual"""
|
||||||
if is_inpainting_model(self.unet) and latents.size(1) == 4:
|
if is_inpainting_model(self.unet) and latents.size(1) == 4:
|
||||||
@ -678,7 +724,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
|
|
||||||
# First three args should be positional, not keywords, so torch hooks can see them.
|
# First three args should be positional, not keywords, so torch hooks can see them.
|
||||||
return self.unet(
|
return self.unet(
|
||||||
latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs
|
latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
**kwargs,
|
||||||
).sample
|
).sample
|
||||||
|
|
||||||
def img2img_from_embeddings(
|
def img2img_from_embeddings(
|
||||||
@ -940,3 +987,48 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
debug_image(
|
debug_image(
|
||||||
img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
|
img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from diffusers pipeline_stable_diffusion_controlnet.py
|
||||||
|
# Returns torch.Tensor of shape (batch_size, 3, height, width)
|
||||||
|
def prepare_control_image(
|
||||||
|
self,
|
||||||
|
image,
|
||||||
|
width=512,
|
||||||
|
height=512,
|
||||||
|
batch_size=1,
|
||||||
|
num_images_per_prompt=1,
|
||||||
|
device="cuda",
|
||||||
|
dtype=torch.float16,
|
||||||
|
do_classifier_free_guidance=True,
|
||||||
|
):
|
||||||
|
if not isinstance(image, torch.Tensor):
|
||||||
|
if isinstance(image, PIL.Image.Image):
|
||||||
|
image = [image]
|
||||||
|
|
||||||
|
if isinstance(image[0], PIL.Image.Image):
|
||||||
|
images = []
|
||||||
|
for image_ in image:
|
||||||
|
image_ = image_.convert("RGB")
|
||||||
|
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
||||||
|
image_ = np.array(image_)
|
||||||
|
image_ = image_[None, :]
|
||||||
|
images.append(image_)
|
||||||
|
image = images
|
||||||
|
image = np.concatenate(image, axis=0)
|
||||||
|
image = np.array(image).astype(np.float32) / 255.0
|
||||||
|
image = image.transpose(0, 3, 1, 2)
|
||||||
|
image = torch.from_numpy(image)
|
||||||
|
elif isinstance(image[0], torch.Tensor):
|
||||||
|
image = torch.cat(image, dim=0)
|
||||||
|
|
||||||
|
image_batch_size = image.shape[0]
|
||||||
|
if image_batch_size == 1:
|
||||||
|
repeat_by = batch_size
|
||||||
|
else:
|
||||||
|
# image batch size is the same as prompt batch size
|
||||||
|
repeat_by = num_images_per_prompt
|
||||||
|
image = image.repeat_interleave(repeat_by, dim=0)
|
||||||
|
image = image.to(device=device, dtype=dtype)
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
image = torch.cat([image] * 2)
|
||||||
|
return image
|
||||||
|
@ -181,6 +181,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
unconditional_guidance_scale: float,
|
unconditional_guidance_scale: float,
|
||||||
step_index: Optional[int] = None,
|
step_index: Optional[int] = None,
|
||||||
total_step_count: Optional[int] = None,
|
total_step_count: Optional[int] = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
:param x: current latents
|
:param x: current latents
|
||||||
@ -209,7 +210,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
|
|
||||||
if wants_hybrid_conditioning:
|
if wants_hybrid_conditioning:
|
||||||
unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
|
unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
|
||||||
x, sigma, unconditioning, conditioning
|
x, sigma, unconditioning, conditioning, **kwargs,
|
||||||
)
|
)
|
||||||
elif wants_cross_attention_control:
|
elif wants_cross_attention_control:
|
||||||
(
|
(
|
||||||
@ -221,13 +222,14 @@ class InvokeAIDiffuserComponent:
|
|||||||
unconditioning,
|
unconditioning,
|
||||||
conditioning,
|
conditioning,
|
||||||
cross_attention_control_types_to_do,
|
cross_attention_control_types_to_do,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
elif self.sequential_guidance:
|
elif self.sequential_guidance:
|
||||||
(
|
(
|
||||||
unconditioned_next_x,
|
unconditioned_next_x,
|
||||||
conditioned_next_x,
|
conditioned_next_x,
|
||||||
) = self._apply_standard_conditioning_sequentially(
|
) = self._apply_standard_conditioning_sequentially(
|
||||||
x, sigma, unconditioning, conditioning
|
x, sigma, unconditioning, conditioning, **kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -235,7 +237,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
unconditioned_next_x,
|
unconditioned_next_x,
|
||||||
conditioned_next_x,
|
conditioned_next_x,
|
||||||
) = self._apply_standard_conditioning(
|
) = self._apply_standard_conditioning(
|
||||||
x, sigma, unconditioning, conditioning
|
x, sigma, unconditioning, conditioning, **kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
combined_next_x = self._combine(
|
combined_next_x = self._combine(
|
||||||
@ -282,13 +284,13 @@ class InvokeAIDiffuserComponent:
|
|||||||
|
|
||||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||||
|
|
||||||
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
|
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||||
# fast batched path
|
# fast batched path
|
||||||
x_twice = torch.cat([x] * 2)
|
x_twice = torch.cat([x] * 2)
|
||||||
sigma_twice = torch.cat([sigma] * 2)
|
sigma_twice = torch.cat([sigma] * 2)
|
||||||
both_conditionings = torch.cat([unconditioning, conditioning])
|
both_conditionings = torch.cat([unconditioning, conditioning])
|
||||||
both_results = self.model_forward_callback(
|
both_results = self.model_forward_callback(
|
||||||
x_twice, sigma_twice, both_conditionings
|
x_twice, sigma_twice, both_conditionings, **kwargs,
|
||||||
)
|
)
|
||||||
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
||||||
if conditioned_next_x.device.type == "mps":
|
if conditioned_next_x.device.type == "mps":
|
||||||
@ -302,16 +304,17 @@ class InvokeAIDiffuserComponent:
|
|||||||
sigma,
|
sigma,
|
||||||
unconditioning: torch.Tensor,
|
unconditioning: torch.Tensor,
|
||||||
conditioning: torch.Tensor,
|
conditioning: torch.Tensor,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
# low-memory sequential path
|
# low-memory sequential path
|
||||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
|
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
|
||||||
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning)
|
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
|
||||||
if conditioned_next_x.device.type == "mps":
|
if conditioned_next_x.device.type == "mps":
|
||||||
# prevent a result filled with zeros. seems to be a torch bug.
|
# prevent a result filled with zeros. seems to be a torch bug.
|
||||||
conditioned_next_x = conditioned_next_x.clone()
|
conditioned_next_x = conditioned_next_x.clone()
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
|
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||||
assert isinstance(conditioning, dict)
|
assert isinstance(conditioning, dict)
|
||||||
assert isinstance(unconditioning, dict)
|
assert isinstance(unconditioning, dict)
|
||||||
x_twice = torch.cat([x] * 2)
|
x_twice = torch.cat([x] * 2)
|
||||||
@ -326,7 +329,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
else:
|
else:
|
||||||
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
|
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
|
||||||
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
|
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
|
||||||
x_twice, sigma_twice, both_conditionings
|
x_twice, sigma_twice, both_conditionings, **kwargs,
|
||||||
).chunk(2)
|
).chunk(2)
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
@ -337,6 +340,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
unconditioning,
|
unconditioning,
|
||||||
conditioning,
|
conditioning,
|
||||||
cross_attention_control_types_to_do,
|
cross_attention_control_types_to_do,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
if self.is_running_diffusers:
|
if self.is_running_diffusers:
|
||||||
return self._apply_cross_attention_controlled_conditioning__diffusers(
|
return self._apply_cross_attention_controlled_conditioning__diffusers(
|
||||||
@ -345,6 +349,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
unconditioning,
|
unconditioning,
|
||||||
conditioning,
|
conditioning,
|
||||||
cross_attention_control_types_to_do,
|
cross_attention_control_types_to_do,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self._apply_cross_attention_controlled_conditioning__compvis(
|
return self._apply_cross_attention_controlled_conditioning__compvis(
|
||||||
@ -353,6 +358,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
unconditioning,
|
unconditioning,
|
||||||
conditioning,
|
conditioning,
|
||||||
cross_attention_control_types_to_do,
|
cross_attention_control_types_to_do,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _apply_cross_attention_controlled_conditioning__diffusers(
|
def _apply_cross_attention_controlled_conditioning__diffusers(
|
||||||
@ -362,6 +368,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
unconditioning,
|
unconditioning,
|
||||||
conditioning,
|
conditioning,
|
||||||
cross_attention_control_types_to_do,
|
cross_attention_control_types_to_do,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
context: Context = self.cross_attention_control_context
|
context: Context = self.cross_attention_control_context
|
||||||
|
|
||||||
@ -377,6 +384,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
sigma,
|
sigma,
|
||||||
unconditioning,
|
unconditioning,
|
||||||
{"swap_cross_attn_context": cross_attn_processor_context},
|
{"swap_cross_attn_context": cross_attn_processor_context},
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# do requested cross attention types for conditioning (positive prompt)
|
# do requested cross attention types for conditioning (positive prompt)
|
||||||
@ -388,6 +396,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
sigma,
|
sigma,
|
||||||
conditioning,
|
conditioning,
|
||||||
{"swap_cross_attn_context": cross_attn_processor_context},
|
{"swap_cross_attn_context": cross_attn_processor_context},
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
@ -398,6 +407,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
unconditioning,
|
unconditioning,
|
||||||
conditioning,
|
conditioning,
|
||||||
cross_attention_control_types_to_do,
|
cross_attention_control_types_to_do,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
|
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
|
||||||
# slower non-batched path (20% slower on mac MPS)
|
# slower non-batched path (20% slower on mac MPS)
|
||||||
@ -411,13 +421,13 @@ class InvokeAIDiffuserComponent:
|
|||||||
context: Context = self.cross_attention_control_context
|
context: Context = self.cross_attention_control_context
|
||||||
|
|
||||||
try:
|
try:
|
||||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
|
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
|
||||||
|
|
||||||
# process x using the original prompt, saving the attention maps
|
# process x using the original prompt, saving the attention maps
|
||||||
# print("saving attention maps for", cross_attention_control_types_to_do)
|
# print("saving attention maps for", cross_attention_control_types_to_do)
|
||||||
for ca_type in cross_attention_control_types_to_do:
|
for ca_type in cross_attention_control_types_to_do:
|
||||||
context.request_save_attention_maps(ca_type)
|
context.request_save_attention_maps(ca_type)
|
||||||
_ = self.model_forward_callback(x, sigma, conditioning)
|
_ = self.model_forward_callback(x, sigma, conditioning, **kwargs,)
|
||||||
context.clear_requests(cleanup=False)
|
context.clear_requests(cleanup=False)
|
||||||
|
|
||||||
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
|
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
|
||||||
@ -428,7 +438,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
self.conditioning.cross_attention_control_args.edited_conditioning
|
self.conditioning.cross_attention_control_args.edited_conditioning
|
||||||
)
|
)
|
||||||
conditioned_next_x = self.model_forward_callback(
|
conditioned_next_x = self.model_forward_callback(
|
||||||
x, sigma, edited_conditioning
|
x, sigma, edited_conditioning, **kwargs,
|
||||||
)
|
)
|
||||||
context.clear_requests(cleanup=True)
|
context.clear_requests(cleanup=True)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user