Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Re-factor IPAdapter to patch UNet in a context manager.
This commit is contained in:
parent d669f0855d
commit 91596d9527
@@ -1,7 +1,10 @@
 # copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
 # and modified as needed

+from contextlib import contextmanager
+
 import torch
+from diffusers.models import UNet2DConditionModel

 # FIXME: Getting errors when trying to use PyTorch 2.0 versions of IPAttnProcessor and AttnProcessor
 # so for now falling back to the default versions
@@ -38,18 +41,14 @@ class ImageProjModel(torch.nn.Module):


 class IPAdapter:
-    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
+    def __init__(self, unet: UNet2DConditionModel, image_encoder_path, ip_ckpt, device, num_tokens=4):
+        self._unet = unet
         self.device = device
         self.image_encoder_path = image_encoder_path
         self.ip_ckpt = ip_ckpt
         self.num_tokens = num_tokens

-        # FIXME:
-        # InvokeAI StableDiffusionPipeline has a to() method that isn't meant to be used
-        # so for now assuming that pipeline is already on the correct device
-        # self.pipe = sd_pipe.to(self.device)
-        self.pipe = sd_pipe
-        self.set_ip_adapter()
+        self._attn_processors = self._prepare_attention_processors()

         # load image encoder
         self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
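With this change `IPAdapter.__init__` takes a bare `UNet2DConditionModel` instead of the whole `StableDiffusionPipeline`, stores it as `self._unet`, and builds its attention processors up front without touching the UNet. A hedged construction sketch follows; the repository id, file paths, and device are illustrative placeholders, not values from this commit:

import torch
from diffusers.models import UNet2DConditionModel

# Assumption: the SD 1.5 UNet and the IP-Adapter weights are available locally or on the Hub.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
).to("cuda")

ip_adapter = IPAdapter(
    unet,                                        # the UNet to be patched later, not a pipeline
    image_encoder_path="path/to/image_encoder",  # placeholder path
    ip_ckpt="path/to/ip-adapter_sd15.bin",       # placeholder checkpoint
    device="cuda",
)
# At this point the UNet's attention processors are still the originals;
# the IP-Adapter processors exist only in ip_adapter._attn_processors.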
@@ -63,44 +62,55 @@ class IPAdapter:

     def init_proj(self):
         image_proj_model = ImageProjModel(
-            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
+            cross_attention_dim=self._unet.config.cross_attention_dim,
             clip_embeddings_dim=self.image_encoder.config.projection_dim,
             clip_extra_context_tokens=self.num_tokens,
         ).to(self.device, dtype=torch.float16)
         return image_proj_model

-    def set_ip_adapter(self):
-        unet = self.pipe.unet
+    def _prepare_attention_processors(self):
+        """Creates a dict of attention processors that can later be injected into `self.unet`, and loads the IP-Adapter
+        attention weights into them.
+        """
         attn_procs = {}
-        print("Original UNet Attn Processors count:", len(unet.attn_processors))
-        print(unet.attn_processors.keys())
-        for name in unet.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+        for name in self._unet.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else self._unet.config.cross_attention_dim
             if name.startswith("mid_block"):
-                hidden_size = unet.config.block_out_channels[-1]
+                hidden_size = self._unet.config.block_out_channels[-1]
             elif name.startswith("up_blocks"):
                 block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+                hidden_size = list(reversed(self._unet.config.block_out_channels))[block_id]
             elif name.startswith("down_blocks"):
                 block_id = int(name[len("down_blocks.")])
-                hidden_size = unet.config.block_out_channels[block_id]
+                hidden_size = self._unet.config.block_out_channels[block_id]
             if cross_attention_dim is None:
                 attn_procs[name] = AttnProcessor()
             else:
-                print("swapping in IPAttnProcessor for", name)
                 attn_procs[name] = IPAttnProcessor(
                     hidden_size=hidden_size,
                     cross_attention_dim=cross_attention_dim,
                     scale=1.0,
                 ).to(self.device, dtype=torch.float16)
-        unet.set_attn_processor(attn_procs)
-        print("Modified UNet Attn Processors count:", len(unet.attn_processors))
-        print(unet.attn_processors.keys())
+        return attn_procs
+
+    @contextmanager
+    def apply_ip_adapter_attention(self):
+        """A context manager that patches `self._unet` with this IP-Adapter's attention processors while it is active.
+
+        Yields:
+            None
+        """
+        orig_attn_processors = self._unet.attn_processors
+        try:
+            self._unet.set_attn_processor(self._attn_processors)
+            yield None
+        finally:
+            self._unet.set_attn_processor(orig_attn_processors)

     def load_ip_adapter(self):
         state_dict = torch.load(self.ip_ckpt, map_location="cpu")
         self.image_proj_model.load_state_dict(state_dict["image_proj"])
-        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
+        ip_layers = torch.nn.ModuleList(self._attn_processors.values())
         ip_layers.load_state_dict(state_dict["ip_adapter"])

     @torch.inference_mode()
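The processors created in `_prepare_attention_processors()` are installed only while `apply_ip_adapter_attention()` is active and are swapped back out on exit. Continuing the construction sketch above, usage looks roughly like this (the denoising loop itself is elided):

ip_adapter.set_scale(1.0)  # scale is applied to the pre-built IPAttnProcessor instances

with ip_adapter.apply_ip_adapter_attention():
    # Inside this block, unet.attn_processors are the IP-Adapter processors.
    ...  # run the denoising loop here
# Outside the block, the UNet's original attention processors are restored.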
@@ -114,139 +124,15 @@ class IPAdapter:
         return image_prompt_embeds, uncond_image_prompt_embeds

     def set_scale(self, scale):
-        for attn_processor in self.pipe.unet.attn_processors.values():
+        for attn_processor in self._attn_processors.values():
             if isinstance(attn_processor, IPAttnProcessor):
                 attn_processor.scale = scale

-    # IPAdapter.generate() method is not used for InvokeAI. Left here for reference:
-    # def generate(
-    #     self,
-    #     pil_image,
-    #     prompt=None,
-    #     negative_prompt=None,
-    #     scale=1.0,
-    #     num_samples=4,
-    #     seed=-1,
-    #     guidance_scale=7.5,
-    #     num_inference_steps=30,
-    #     **kwargs,
-    # ):
-    #     self.set_scale(scale)
-
-    #     if isinstance(pil_image, Image.Image):
-    #         num_prompts = 1
-    #     else:
-    #         num_prompts = len(pil_image)
-
-    #     if prompt is None:
-    #         prompt = "best quality, high quality"
-    #     if negative_prompt is None:
-    #         negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-
-    #     if not isinstance(prompt, List):
-    #         prompt = [prompt] * num_prompts
-    #     if not isinstance(negative_prompt, List):
-    #         negative_prompt = [negative_prompt] * num_prompts
-
-    #     image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
-    #     bs_embed, seq_len, _ = image_prompt_embeds.shape
-    #     image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
-    #     image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-    #     uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
-    #     uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-
-    #     with torch.inference_mode():
-    #         prompt_embeds = self.pipe._encode_prompt(
-    #             prompt,
-    #             device=self.device,
-    #             num_images_per_prompt=num_samples,
-    #             do_classifier_free_guidance=True,
-    #             negative_prompt=negative_prompt,
-    #         )
-    #         negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
-    #         prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
-    #         negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
-
-    #     generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
-    #     images = self.pipe(
-    #         prompt_embeds=prompt_embeds,
-    #         negative_prompt_embeds=negative_prompt_embeds,
-    #         guidance_scale=guidance_scale,
-    #         num_inference_steps=num_inference_steps,
-    #         generator=generator,
-    #         **kwargs,
-    #     ).images
-
-    #     return images
-

 class IPAdapterXL(IPAdapter):
     """SDXL"""

     pass
-    # IPAdapterXL.generate() method is not used for InvokeAI. Left here for reference:
-    # def generate(
-    #     self,
-    #     pil_image,
-    #     prompt=None,
-    #     negative_prompt=None,
-    #     scale=1.0,
-    #     num_samples=4,
-    #     seed=-1,
-    #     num_inference_steps=30,
-    #     **kwargs,
-    # ):
-    #     self.set_scale(scale)
-
-    #     if isinstance(pil_image, Image.Image):
-    #         num_prompts = 1
-    #     else:
-    #         num_prompts = len(pil_image)
-
-    #     if prompt is None:
-    #         prompt = "best quality, high quality"
-    #     if negative_prompt is None:
-    #         negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-
-    #     if not isinstance(prompt, List):
-    #         prompt = [prompt] * num_prompts
-    #     if not isinstance(negative_prompt, List):
-    #         negative_prompt = [negative_prompt] * num_prompts
-
-    #     image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
-    #     bs_embed, seq_len, _ = image_prompt_embeds.shape
-    #     image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
-    #     image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-    #     uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
-    #     uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-
-    #     with torch.inference_mode():
-    #         (
-    #             prompt_embeds,
-    #             negative_prompt_embeds,
-    #             pooled_prompt_embeds,
-    #             negative_pooled_prompt_embeds,
-    #         ) = self.pipe.encode_prompt(
-    #             prompt,
-    #             num_images_per_prompt=num_samples,
-    #             do_classifier_free_guidance=True,
-    #             negative_prompt=negative_prompt,
-    #         )
-    #         prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
-    #         negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
-
-    #     generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
-    #     images = self.pipe(
-    #         prompt_embeds=prompt_embeds,
-    #         negative_prompt_embeds=negative_prompt_embeds,
-    #         pooled_prompt_embeds=pooled_prompt_embeds,
-    #         negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-    #         num_inference_steps=num_inference_steps,
-    #         generator=generator,
-    #         **kwargs,
-    #     ).images
-
-    #     return images


 class IPAdapterPlus(IPAdapter):
@@ -254,13 +140,13 @@ class IPAdapterPlus(IPAdapter):

     def init_proj(self):
         image_proj_model = Resampler(
-            dim=self.pipe.unet.config.cross_attention_dim,
+            dim=self._unet.config.cross_attention_dim,
             depth=4,
             dim_head=64,
             heads=12,
             num_queries=self.num_tokens,
             embedding_dim=self.image_encoder.config.hidden_size,
-            output_dim=self.pipe.unet.config.cross_attention_dim,
+            output_dim=self._unet.config.cross_attention_dim,
             ff_mult=4,
         ).to(self.device, dtype=torch.float16)
         return image_proj_model
@@ -2,6 +2,7 @@ from __future__ import annotations

 import dataclasses
 import inspect
+from contextlib import nullcontext
 from dataclasses import dataclass, field
 from typing import Any, Callable, List, Optional, Union

@@ -419,21 +420,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

         if ip_adapter_data is not None:
             # Initialize IPAdapter
-            # FIXME:
-            # WARNING!
-            # IPAdapter constructor modifies UNet model in-place
-            # Adds additional cross-attention layers to UNet model for image embedding
-            # need to figure out how to only do this if UNet hasn't already been modified by prior IPAdapter
-            # and how to undo if ip_adapter_image is removed
-            # Should reimplement to use existing model management context etc.
-            #
+            # TODO(ryand): Refactor to use model management for the IP-Adapter.
             if "sdxl" in ip_adapter_data.ip_adapter_model:
                 ip_adapter = IPAdapterXL(
-                    self, ip_adapter_data.image_encoder_model, ip_adapter_data.ip_adapter_model, self.unet.device
+                    self.unet, ip_adapter_data.image_encoder_model, ip_adapter_data.ip_adapter_model, self.unet.device
                 )
             elif "plus" in ip_adapter_data.ip_adapter_model:
                 ip_adapter = IPAdapterPlus(
-                    self,  # IPAdapterPlus first arg is StableDiffusionPipeline
+                    self.unet,
                     ip_adapter_data.image_encoder_model,
                     ip_adapter_data.ip_adapter_model,
                     self.unet.device,
@@ -441,7 +435,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 )
             else:
                 ip_adapter = IPAdapter(
-                    self,  # IPAdapter first arg is StableDiffusionPipeline
+                    self.unet,
                     ip_adapter_data.image_encoder_model,
                     ip_adapter_data.ip_adapter_model,
                     self.unet.device,
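The pipeline still chooses the adapter class by substring-matching the IP-Adapter model name; the only change in all three branches is that the first constructor argument is now `self.unet` rather than the pipeline itself. As an illustration only (this helper does not exist in the commit), the dispatch could be factored as:

def build_ip_adapter(unet, image_encoder_model, ip_adapter_model, device):
    # Hypothetical helper mirroring the if/elif/else above.
    if "sdxl" in ip_adapter_model:
        cls = IPAdapterXL
    elif "plus" in ip_adapter_model:
        cls = IPAdapterPlus
    else:
        cls = IPAdapter
    return cls(unet, image_encoder_model, ip_adapter_model, device)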
@@ -454,13 +448,20 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 image_prompt_embeds, uncond_image_prompt_embeds
             )

-        # TODO(ryand): Apply IP-Adapter or custom attention control
-        extra_conditioning_info = conditioning_data.extra
-        with self.invokeai_diffuser.custom_attention_context(
-            self.invokeai_diffuser.model,
-            extra_conditioning_info=extra_conditioning_info,
-            step_count=len(self.scheduler.timesteps),
-        ):
+        if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
+            attn_ctx = self.invokeai_diffuser.custom_attention_context(
+                self.invokeai_diffuser.model,
+                extra_conditioning_info=conditioning_data.extra,
+                step_count=len(self.scheduler.timesteps),
+            )
+        elif ip_adapter_data is not None:
+            # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
+            # As it is now, the IP-Adapter will silently be skipped.
+            attn_ctx = ip_adapter.apply_ip_adapter_attention()
+        else:
+            attn_ctx = nullcontext()
+
+        with attn_ctx:
             if callback is not None:
                 callback(
                     PipelineIntermediateState(
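The hunk above replaces an unconditional `with` over the cross-attention-control context with a select-then-enter pattern: exactly one context manager is chosen (cross-attention control, the IP-Adapter patch, or `nullcontext()` as a no-op) and entered once around the denoising loop. A self-contained sketch of the same shape, using toy context managers rather than InvokeAI code:

from contextlib import contextmanager, nullcontext


@contextmanager
def cross_attention_control():
    print("cross-attention control installed")
    yield
    print("cross-attention control removed")


@contextmanager
def ip_adapter_attention():
    print("IP-Adapter processors installed")
    yield
    print("original processors restored")


def run_denoising(wants_cross_attention_control: bool, has_ip_adapter: bool):
    if wants_cross_attention_control:
        attn_ctx = cross_attention_control()
    elif has_ip_adapter:
        attn_ctx = ip_adapter_attention()
    else:
        attn_ctx = nullcontext()

    with attn_ctx:
        pass  # denoising steps would run here


run_denoising(wants_cross_attention_control=False, has_ip_adapter=True)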
|
@ -11,16 +11,17 @@ import diffusers
|
|||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
from compel.cross_attention_control import Arguments
|
from compel.cross_attention_control import Arguments
|
||||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
|
||||||
from diffusers.models.attention_processor import AttentionProcessor
|
|
||||||
from diffusers.models.attention_processor import (
|
from diffusers.models.attention_processor import (
|
||||||
Attention,
|
Attention,
|
||||||
|
AttentionProcessor,
|
||||||
AttnProcessor,
|
AttnProcessor,
|
||||||
SlicedAttnProcessor,
|
SlicedAttnProcessor,
|
||||||
)
|
)
|
||||||
|
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
from ...util import torch_dtype
|
from ...util import torch_dtype
|
||||||
|
|
||||||
|
|
||||||
@@ -380,10 +381,13 @@ def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[
         # non-fatal error but .swap() won't work.
         logger.error(
             f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
-            + f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
+            + f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching"
+            " failed "
             + "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
-            + f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
-            + "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
+            + f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who"
+            " knows "
+            + "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will"
+            " not "
             + "work properly until it is fixed."
         )
     return attention_module_tuples
@@ -581,6 +585,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
         attention_mask=None,
         # kwargs
         swap_cross_attn_context: SwapCrossAttnContext = None,
+        **kwargs,
     ):
         attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS

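The only change here is the added `**kwargs`, which lets `SlicedSwapCrossAttnProcesser.__call__` accept and ignore keyword arguments it does not use instead of raising `TypeError` when a caller passes extras; presumably this keeps it compatible with whatever argument lists other attention call sites now forward. A tiny illustration of the idiom (toy class, not the real processor):

class TolerantProcessor:
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs):
        _ = kwargs  # unexpected keyword arguments are accepted and ignored
        return hidden_states


proc = TolerantProcessor()
out = proc(attn=None, hidden_states=[1.0], some_new_argument=True)
assert out == [1.0]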
|
@ -67,30 +67,26 @@ class InvokeAIDiffuserComponent:
|
|||||||
@contextmanager
|
@contextmanager
|
||||||
def custom_attention_context(
|
def custom_attention_context(
|
||||||
self,
|
self,
|
||||||
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
|
unet: UNet2DConditionModel,
|
||||||
extra_conditioning_info: Optional[ExtraConditioningInfo],
|
extra_conditioning_info: Optional[ExtraConditioningInfo],
|
||||||
step_count: int,
|
step_count: int,
|
||||||
):
|
):
|
||||||
old_attn_processors = None
|
old_attn_processors = unet.attn_processors
|
||||||
if extra_conditioning_info and (extra_conditioning_info.wants_cross_attention_control):
|
|
||||||
old_attn_processors = unet.attn_processors
|
|
||||||
# Load lora conditions into the model
|
|
||||||
if extra_conditioning_info.wants_cross_attention_control:
|
|
||||||
self.cross_attention_control_context = Context(
|
|
||||||
arguments=extra_conditioning_info.cross_attention_control_args,
|
|
||||||
step_count=step_count,
|
|
||||||
)
|
|
||||||
setup_cross_attention_control_attention_processors(
|
|
||||||
unet,
|
|
||||||
self.cross_attention_control_context,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
self.cross_attention_control_context = Context(
|
||||||
|
arguments=extra_conditioning_info.cross_attention_control_args,
|
||||||
|
step_count=step_count,
|
||||||
|
)
|
||||||
|
setup_cross_attention_control_attention_processors(
|
||||||
|
unet,
|
||||||
|
self.cross_attention_control_context,
|
||||||
|
)
|
||||||
|
|
||||||
yield None
|
yield None
|
||||||
finally:
|
finally:
|
||||||
self.cross_attention_control_context = None
|
self.cross_attention_control_context = None
|
||||||
if old_attn_processors is not None:
|
unet.set_attn_processor(old_attn_processors)
|
||||||
unet.set_attn_processor(old_attn_processors)
|
|
||||||
# TODO resuscitate attention map saving
|
# TODO resuscitate attention map saving
|
||||||
# self.remove_attention_map_saving()
|
# self.remove_attention_map_saving()
|
||||||
|
|
||||||
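`custom_attention_context` now snapshots `unet.attn_processors` unconditionally and always restores the snapshot in `finally`, so the UNet comes back in its original state even if an exception is raised mid-denoise. A small, self-contained demonstration of that guarantee with a toy stand-in for the UNet (not InvokeAI code):

from contextlib import contextmanager


class FakeUNet:
    def __init__(self):
        self.attn_processors = {"attn1": "original"}

    def set_attn_processor(self, procs):
        self.attn_processors = dict(procs)


@contextmanager
def patched_attention(unet):
    old_attn_processors = unet.attn_processors  # unconditional snapshot
    try:
        unet.set_attn_processor({"attn1": "patched"})
        yield
    finally:
        unet.set_attn_processor(old_attn_processors)  # restored even on error


unet = FakeUNet()
try:
    with patched_attention(unet):
        raise RuntimeError("simulated failure during denoising")
except RuntimeError:
    pass
assert unet.attn_processors == {"attn1": "original"}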