Compare commits

...

8 Commits

6 changed files with 193 additions and 23 deletions

View File

@@ -4,7 +4,9 @@ from functools import partial
 from typing import Literal, Optional, Union

 import numpy as np
+from diffusers import ControlNetModel
 from torch import Tensor
+import torch
 from pydantic import BaseModel, Field
@@ -53,6 +55,9 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
     seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
     model: str = Field(default="", description="The model to use (currently ignored)")
     progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
+    control_model: Optional[str] = Field(default=None, description="The control model to use")
+    control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
+    # control_strength: Optional[float] = Field(default=1.0, ge=0, le=1, description="The strength of the controlnet")
     # fmt: on

     # TODO: pass this an emitter method or something? or a session for dispatching?
@@ -70,20 +75,36 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
         )

     def invoke(self, context: InvocationContext) -> ImageOutput:
+        # Handle invalid model parameter
         model = choose_model(context.services.model_manager, self.model)

+        # loading controlnet image (currently requires pre-processed image)
+        control_image = (
+            None if self.control_image is None
+            else context.services.images.get(
+                self.control_image.image_type, self.control_image.image_name
+            )
+        )
+        # loading controlnet model
+        if (self.control_model is None or self.control_model==''):
+            control_model = None
+        else:
+            # FIXME: change this to dropdown menu?
+            control_model = ControlNetModel.from_pretrained(self.control_model,
+                                                            torch_dtype=torch.float16).to("cuda")
         # Get the source node id (we are invoking the prepared node)
         graph_execution_state = context.services.graph_execution_manager.get(
             context.graph_execution_state_id
         )
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]

-        outputs = Txt2Img(model).generate(
+        txt2img = Txt2Img(model, control_model=control_model)
+        outputs = txt2img.generate(
             prompt=self.prompt,
             step_callback=partial(self.dispatch_progress, context, source_node_id),
+            control_image=control_image,
             **self.dict(
-                exclude={"prompt"}
+                exclude={"prompt", "control_image" }
             ),  # Shorthand for passing all of the parameters above manually
         )

         # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
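As a point of reference, the control-model branch in this hunk boils down to the following standalone sketch (the repo id is only an example; fp16 and CUDA are hard-coded exactly as in the node, though a fuller implementation would derive them from the app's precision settings):

    import torch
    from diffusers import ControlNetModel

    def load_control_model(name):
        """Resolve an optional HF repo id (e.g. "lllyasviel/sd-controlnet-canny") into a ControlNetModel."""
        if not name:
            return None
        return ControlNetModel.from_pretrained(name, torch_dtype=torch.float16).to("cuda")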

View File

@@ -86,9 +86,11 @@ class InvokeAIGenerator(metaclass=ABCMeta):
     def __init__(self,
                  model_info: dict,
                  params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
+                 **kwargs,
                  ):
         self.model_info=model_info
         self.params=params
+        self.kwargs = kwargs

     def generate(self,
                  prompt: str='',
@@ -129,9 +131,12 @@ class InvokeAIGenerator(metaclass=ABCMeta):
             model=model,
             scheduler_name=generator_args.get('scheduler')
         )
-        uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
+
+        # get conditioning from prompt via Compel package
+        uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt, model=model)
+
         gen_class = self._generator_class()
-        generator = gen_class(model, self.params.precision)
+        generator = gen_class(model, self.params.precision, **self.kwargs)
         if self.params.variation_amount > 0:
             generator.set_variation(generator_args.get('seed'),
                                     generator_args.get('variation_amount'),
@@ -281,7 +286,7 @@ class Generator:
     precision: str
     model: DiffusionPipeline

-    def __init__(self, model: DiffusionPipeline, precision: str):
+    def __init__(self, model: DiffusionPipeline, precision: str, **kwargs):
         self.model = model
         self.precision = precision
         self.seed = None
@@ -354,7 +359,6 @@ class Generator:
         seed = seed if seed is not None and seed >= 0 else self.new_seed()
         first_seed = seed
         seed, initial_noise = self.generate_initial_noise(seed, width, height)
         # There used to be an additional self.model.ema_scope() here, but it breaks
         # the inpaint-1.5 model. Not sure what it did.... ?
         with scope(self.model.device.type):
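The net effect in this file is a kwargs pass-through: extra keyword arguments given to an InvokeAIGenerator subclass at construction time are stored on the instance and later handed to the concrete low-level Generator. A simplified illustration of the pattern (hypothetical class names, not the real InvokeAI classes):

    class LowLevelGen:
        def __init__(self, model, precision, **kwargs):
            self.model, self.precision, self.kwargs = model, precision, kwargs

    class HighLevelGen:
        def __init__(self, model_info, **kwargs):
            self.model_info = model_info
            self.kwargs = kwargs          # e.g. control_model=...

        def generate(self, gen_class, precision="float16"):
            # kwargs captured at construction reach the concrete generator here
            return gen_class(self.model_info, precision, **self.kwargs)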

View File

@@ -4,6 +4,10 @@ invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
 import PIL.Image
 import torch

+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
+
 from ..stable_diffusion import (
     ConditioningData,
     PostprocessingSettings,
@@ -13,8 +17,13 @@ from .base import Generator


 class Txt2Img(Generator):
-    def __init__(self, model, precision):
-        super().__init__(model, precision)
+    def __init__(self, model, precision,
+                 control_model: Optional[Union[ControlNetModel, List[ControlNetModel]]] = None,
+                 **kwargs):
+        self.control_model = control_model
+        if isinstance(self.control_model, list):
+            self.control_model = MultiControlNetModel(self.control_model)
+        super().__init__(model, precision, **kwargs)

     @torch.no_grad()
     def get_make_image(
@@ -42,9 +51,12 @@ class Txt2Img(Generator):
         kwargs are 'width' and 'height'
         """
         self.perlin = perlin
+        control_image = kwargs.get("control_image", None)
+        do_classifier_free_guidance = cfg_scale > 1.0

         # noinspection PyTypeChecker
         pipeline: StableDiffusionGeneratorPipeline = self.model
+        pipeline.control_model = self.control_model
         pipeline.scheduler = sampler

         uc, c, extra_conditioning_info = conditioning
@@ -61,6 +73,37 @@ class Txt2Img(Generator):
             ),
         ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)

+        # FIXME: still need to test with different widths, heights, devices, dtypes
+        #        and add in batch_size, num_images_per_prompt?
+        if control_image is not None:
+            if isinstance(self.control_model, ControlNetModel):
+                control_image = pipeline.prepare_control_image(
+                    image=control_image,
+                    do_classifier_free_guidance=do_classifier_free_guidance,
+                    width=width,
+                    height=height,
+                    # batch_size=batch_size * num_images_per_prompt,
+                    # num_images_per_prompt=num_images_per_prompt,
+                    device=self.control_model.device,
+                    dtype=self.control_model.dtype,
+                )
+            elif isinstance(self.control_model, MultiControlNetModel):
+                images = []
+                for image_ in control_image:
+                    image_ = self.model.prepare_control_image(
+                        image=image_,
+                        do_classifier_free_guidance=do_classifier_free_guidance,
+                        width=width,
+                        height=height,
+                        # batch_size=batch_size * num_images_per_prompt,
+                        # num_images_per_prompt=num_images_per_prompt,
+                        device=self.control_model.device,
+                        dtype=self.control_model.dtype,
+                    )
+                    images.append(image_)
+                control_image = images
+            kwargs["control_image"] = control_image
+
         def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
             pipeline_output = pipeline.image_from_embeddings(
                 latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
@@ -68,6 +111,7 @@ class Txt2Img(Generator):
                 num_inference_steps=steps,
                 conditioning_data=conditioning_data,
                 callback=step_callback,
+                **kwargs,
             )

             if (
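Taken together with the invocation changes in the first file, a hypothetical end-to-end call through the high-level Txt2Img wrapper would look roughly like this (the repo id, prompt, and pre-processed canny_image are placeholders; control_model is forwarded down to this generator and control_image rides through **kwargs into the pipeline):

    controlnet = ControlNetModel.from_pretrained(
        "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
    ).to("cuda")

    txt2img = Txt2Img(model, control_model=controlnet)   # model comes from the model manager
    outputs = txt2img.generate(
        prompt="a castle on a hill at sunset",
        control_image=canny_image,   # already edge-detected, per the node's current contract
    )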

View File

@@ -9,16 +9,20 @@ from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union

 import einops
 import PIL.Image
+import numpy as np
 from accelerate.utils import set_seed
 import psutil
 import torch
 import torchvision.transforms as T
 from compel import EmbeddingsProvider
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
     StableDiffusionPipeline,
 )
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
     StableDiffusionImg2ImgPipeline,
 )
@@ -27,6 +31,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
 )
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
+from diffusers.utils import PIL_INTERPOLATION
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.outputs import BaseOutput
 from torchvision.transforms.functional import resize as tv_resize
@@ -304,6 +309,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         feature_extractor: Optional[CLIPFeatureExtractor],
         requires_safety_checker: bool = False,
         precision: str = "float32",
+        control_model: ControlNetModel = None,
     ):
         super().__init__(
             vae,
@@ -324,6 +330,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            # FIXME: can't currently register control module
+            # control_model=control_model,
         )
         self.invokeai_diffuser = InvokeAIDiffuserComponent(
             self.unet, self._unet_forward, is_running_diffusers=True
@@ -343,6 +351,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         self._model_group = FullyLoadedModelGroup(self.unet.device)
         self._model_group.install(*self._submodels)
+        self.control_model = control_model

     def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
         """
@@ -464,6 +473,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         noise: torch.Tensor,
         callback: Callable[[PipelineIntermediateState], None] = None,
         run_id=None,
+        **kwargs,
     ) -> InvokeAIStableDiffusionPipelineOutput:
         r"""
         Function invoked when calling the pipeline for generation.
@@ -484,6 +494,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             noise=noise,
             run_id=run_id,
             callback=callback,
+            **kwargs,
         )
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
@@ -508,6 +519,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         additional_guidance: List[Callable] = None,
         run_id=None,
         callback: Callable[[PipelineIntermediateState], None] = None,
+        **kwargs,
     ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
         if timesteps is None:
             self.scheduler.set_timesteps(
@@ -525,6 +537,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             additional_guidance=additional_guidance,
             run_id=run_id,
             callback=callback,
+            **kwargs,
         )

         return result.latents, result.attention_map_saver
@@ -537,6 +550,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         noise: torch.Tensor,
         run_id: str = None,
         additional_guidance: List[Callable] = None,
+        **kwargs,
     ):
         self._adjust_memory_efficient_attention(latents)
         if run_id is None:
@@ -575,6 +589,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 step_index=i,
                 total_step_count=len(timesteps),
                 additional_guidance=additional_guidance,
+                **kwargs,
             )
             latents = step_output.prev_sample
@@ -615,6 +630,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         step_index: int,
         total_step_count: int,
         additional_guidance: List[Callable] = None,
+        **kwargs,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
@@ -626,6 +642,33 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         # i.e. before or after passing it to InvokeAIDiffuserComponent
         latent_model_input = self.scheduler.scale_model_input(latents, timestep)

+        if (self.control_model is not None) and (kwargs.get("control_image") is not None):
+            control_image = kwargs.get("control_image")  # should be a processed tensor derived from the control image(s)
+            control_scale = kwargs.get("control_scale", 1.0)  # control_scale default is 1.0
+            # handling case where using multiple control models but only specifying single control_scale
+            #   so reshape control_scale to match number of control models
+            if isinstance(self.control_model, MultiControlNetModel) and isinstance(control_scale, float):
+                control_scale = [control_scale] * len(self.control_model.nets)
+            if conditioning_data.guidance_scale > 1.0:
+                # expand the latents input to control model if doing classifier free guidance
+                #  (which I think for now is always true, there is conditional elsewhere that stops execution if
+                #   classifier_free_guidance is <= 1.0 ?)
+                latent_control_input = torch.cat([latent_model_input] * 2)
+            else:
+                latent_control_input = latent_model_input
+            # controlnet inference
+            down_block_res_samples, mid_block_res_sample = self.control_model(
+                latent_control_input,
+                timestep,
+                encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
+                                                 conditioning_data.text_embeddings]),
+                controlnet_cond=control_image,
+                conditioning_scale=control_scale,
+                return_dict=False,
+            )
+        else:
+            down_block_res_samples, mid_block_res_sample = None, None
+
         # predict the noise residual
         noise_pred = self.invokeai_diffuser.do_diffusion_step(
             latent_model_input,
@@ -635,6 +678,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             conditioning_data.guidance_scale,
             step_index=step_index,
             total_step_count=total_step_count,
+            down_block_additional_residuals=down_block_res_samples,
+            mid_block_additional_residual=mid_block_res_sample,
         )

         # compute the previous noisy sample x_t -> x_t-1
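For context, this is the same wiring a stock diffusers ControlNet pipeline uses: the ControlNet produces one residual per UNet down block plus a mid-block residual, and those are fed to the UNet through the two keyword arguments added above. A stripped-down sketch of that data flow (variable names are placeholders, not the pipeline's attributes):

    down_res, mid_res = controlnet(
        latent_model_input,
        t,
        encoder_hidden_states=embeddings,       # [uncond, cond] stacked for CFG
        controlnet_cond=control_image_tensor,   # output of prepare_control_image()
        conditioning_scale=1.0,
        return_dict=False,
    )
    noise_pred = unet(
        latent_model_input,
        t,
        encoder_hidden_states=embeddings,
        down_block_additional_residuals=down_res,
        mid_block_additional_residual=mid_res,
    ).sample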
@@ -656,6 +701,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         t,
         text_embeddings,
         cross_attention_kwargs: Optional[dict[str, Any]] = None,
+        **kwargs,
     ):
         """predict the noise residual"""
         if is_inpainting_model(self.unet) and latents.size(1) == 4:
@@ -675,7 +721,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

         # First three args should be positional, not keywords, so torch hooks can see them.
         return self.unet(
-            latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs
+            latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs,
+            **kwargs,
         ).sample

     def img2img_from_embeddings(
@@ -937,3 +984,48 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 debug_image(
                     img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
                 )
+
+    # Copied from diffusers pipeline_stable_diffusion_controlnet.py
+    # Returns torch.Tensor of shape (batch_size, 3, height, width)
+    def prepare_control_image(
+        self,
+        image,
+        width=512,
+        height=512,
+        batch_size=1,
+        num_images_per_prompt=1,
+        device="cuda",
+        dtype=torch.float16,
+        do_classifier_free_guidance=True,
+    ):
+        if not isinstance(image, torch.Tensor):
+            if isinstance(image, PIL.Image.Image):
+                image = [image]
+
+            if isinstance(image[0], PIL.Image.Image):
+                images = []
+                for image_ in image:
+                    image_ = image_.convert("RGB")
+                    image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
+                    image_ = np.array(image_)
+                    image_ = image_[None, :]
+                    images.append(image_)
+                image = images
+
+                image = np.concatenate(image, axis=0)
+                image = np.array(image).astype(np.float32) / 255.0
+                image = image.transpose(0, 3, 1, 2)
+                image = torch.from_numpy(image)
+            elif isinstance(image[0], torch.Tensor):
+                image = torch.cat(image, dim=0)
+
+        image_batch_size = image.shape[0]
+        if image_batch_size == 1:
+            repeat_by = batch_size
+        else:
+            # image batch size is the same as prompt batch size
+            repeat_by = num_images_per_prompt
+        image = image.repeat_interleave(repeat_by, dim=0)
+        image = image.to(device=device, dtype=dtype)
+        if do_classifier_free_guidance:
+            image = torch.cat([image] * 2)
+        return image
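As a quick sanity check of the preprocessing above (standalone, not part of the diff): a single PIL control image becomes a float tensor in [0, 1] with NCHW layout, duplicated along the batch axis when classifier-free guidance is active:

    import numpy as np
    import PIL.Image
    import torch

    img = PIL.Image.new("RGB", (768, 512))                              # stand-in for any control image
    arr = np.array(img.resize((512, 512))).astype(np.float32) / 255.0   # (512, 512, 3), values in [0, 1]
    t = torch.from_numpy(arr[None].transpose(0, 3, 1, 2))               # (1, 3, 512, 512)
    t = torch.cat([t] * 2)                                              # CFG duplication -> (2, 3, 512, 512)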

View File

@@ -1,7 +1,6 @@
 # adapted from bloc97's CrossAttentionControl colab
 # https://github.com/bloc97/CrossAttentionControl

 import enum
 import math
 from typing import Callable, Optional

View File

@@ -168,6 +168,7 @@ class InvokeAIDiffuserComponent:
         unconditional_guidance_scale: float,
         step_index: Optional[int] = None,
         total_step_count: Optional[int] = None,
+        **kwargs,
     ):
         """
         :param x: current latents
@@ -196,7 +197,7 @@ class InvokeAIDiffuserComponent:

         if wants_hybrid_conditioning:
             unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
-                x, sigma, unconditioning, conditioning
+                x, sigma, unconditioning, conditioning, **kwargs,
             )
         elif wants_cross_attention_control:
             (
@@ -208,13 +209,14 @@ class InvokeAIDiffuserComponent:
                 unconditioning,
                 conditioning,
                 cross_attention_control_types_to_do,
+                **kwargs,
             )
         elif self.sequential_guidance:
             (
                 unconditioned_next_x,
                 conditioned_next_x,
             ) = self._apply_standard_conditioning_sequentially(
-                x, sigma, unconditioning, conditioning
+                x, sigma, unconditioning, conditioning, **kwargs,
             )

         else:
@@ -222,7 +224,7 @@ class InvokeAIDiffuserComponent:
                 unconditioned_next_x,
                 conditioned_next_x,
             ) = self._apply_standard_conditioning(
-                x, sigma, unconditioning, conditioning
+                x, sigma, unconditioning, conditioning, **kwargs,
             )

         combined_next_x = self._combine(
@@ -269,13 +271,13 @@ class InvokeAIDiffuserComponent:

     # methods below are called from do_diffusion_step and should be considered private to this class.

-    def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
+    def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
         # fast batched path
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
         both_conditionings = torch.cat([unconditioning, conditioning])
         both_results = self.model_forward_callback(
-            x_twice, sigma_twice, both_conditionings
+            x_twice, sigma_twice, both_conditionings, **kwargs,
         )
         unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
         if conditioned_next_x.device.type == "mps":
@@ -289,16 +291,17 @@ class InvokeAIDiffuserComponent:
         sigma,
         unconditioning: torch.Tensor,
         conditioning: torch.Tensor,
+        **kwargs,
     ):
         # low-memory sequential path
-        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
-        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning)
+        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
+        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
         if conditioned_next_x.device.type == "mps":
             # prevent a result filled with zeros. seems to be a torch bug.
             conditioned_next_x = conditioned_next_x.clone()
         return unconditioned_next_x, conditioned_next_x

-    def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
+    def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
         assert isinstance(conditioning, dict)
         assert isinstance(unconditioning, dict)
         x_twice = torch.cat([x] * 2)
@@ -324,6 +327,7 @@ class InvokeAIDiffuserComponent:
         unconditioning,
         conditioning,
         cross_attention_control_types_to_do,
+        **kwargs,
     ):
         if self.is_running_diffusers:
             return self._apply_cross_attention_controlled_conditioning__diffusers(
@@ -332,6 +336,7 @@ class InvokeAIDiffuserComponent:
                 unconditioning,
                 conditioning,
                 cross_attention_control_types_to_do,
+                **kwargs,
             )
         else:
             return self._apply_cross_attention_controlled_conditioning__compvis(
@@ -340,6 +345,7 @@ class InvokeAIDiffuserComponent:
                 unconditioning,
                 conditioning,
                 cross_attention_control_types_to_do,
+                **kwargs,
             )

     def _apply_cross_attention_controlled_conditioning__diffusers(
@@ -349,6 +355,7 @@ class InvokeAIDiffuserComponent:
         unconditioning,
         conditioning,
         cross_attention_control_types_to_do,
+        **kwargs,
     ):
         context: Context = self.cross_attention_control_context
@@ -364,6 +371,7 @@ class InvokeAIDiffuserComponent:
             sigma,
             unconditioning,
             {"swap_cross_attn_context": cross_attn_processor_context},
+            **kwargs,
         )

         # do requested cross attention types for conditioning (positive prompt)
@@ -375,6 +383,7 @@ class InvokeAIDiffuserComponent:
             sigma,
             conditioning,
             {"swap_cross_attn_context": cross_attn_processor_context},
+            **kwargs,
         )
         return unconditioned_next_x, conditioned_next_x
@@ -385,6 +394,7 @@ class InvokeAIDiffuserComponent:
         unconditioning,
         conditioning,
         cross_attention_control_types_to_do,
+        **kwargs,
     ):
         # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
         # slower non-batched path (20% slower on mac MPS)
@@ -398,13 +408,13 @@ class InvokeAIDiffuserComponent:
         context: Context = self.cross_attention_control_context

         try:
-            unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
+            unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)

             # process x using the original prompt, saving the attention maps
             # print("saving attention maps for", cross_attention_control_types_to_do)
             for ca_type in cross_attention_control_types_to_do:
                 context.request_save_attention_maps(ca_type)
-            _ = self.model_forward_callback(x, sigma, conditioning)
+            _ = self.model_forward_callback(x, sigma, conditioning, **kwargs,)
             context.clear_requests(cleanup=False)

             # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
@@ -415,7 +425,7 @@ class InvokeAIDiffuserComponent:
                 self.conditioning.cross_attention_control_args.edited_conditioning
             )
             conditioned_next_x = self.model_forward_callback(
-                x, sigma, edited_conditioning
+                x, sigma, edited_conditioning, **kwargs,
             )
             context.clear_requests(cleanup=True)
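The practical consequence for this component: every conditioning path now forwards arbitrary extra keyword arguments, untouched, to model_forward_callback (bound to the pipeline's _unet_forward), which is how the ControlNet residuals computed in the pipeline's step() reach the UNet call. Sketched from the step hunk earlier in this compare (argument list abridged; the residual variables are assumed to exist at the call site):

    noise_pred = self.invokeai_diffuser.do_diffusion_step(
        latent_model_input,
        t,
        conditioning_data.unconditioned_embeddings,
        conditioning_data.text_embeddings,
        conditioning_data.guidance_scale,
        step_index=step_index,
        total_step_count=total_step_count,
        # carried through **kwargs at every hop and finally handed to self.unet(...):
        down_block_additional_residuals=down_block_res_samples,
        mid_block_additional_residual=mid_block_res_sample,
    )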