Re-factor IPAdapter to patch UNet in a context manager.

Ryan Dick 2023-09-08 15:39:22 -04:00
parent d669f0855d
commit 91596d9527
4 changed files with 76 additions and 188 deletions
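For orientation, here is a minimal sketch of the usage pattern this commit moves to: the IPAdapter is now constructed against a bare UNet2DConditionModel, and the UNet's attention processors are only swapped while the new apply_ip_adapter_attention() context manager is active. The model ID, encoder/checkpoint paths, and variable names below are illustrative placeholders, not part of the commit:

    from diffusers import UNet2DConditionModel

    # Hypothetical inputs -- substitute real paths/IDs in practice.
    unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
    ip_adapter = IPAdapter(
        unet,
        image_encoder_path="path/to/image_encoder",
        ip_ckpt="path/to/ip_adapter.bin",
        device="cuda",
    )

    # Processors are patched on entry and restored on exit, even if an error is raised.
    with ip_adapter.apply_ip_adapter_attention():
        ...  # run the denoising loop here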

File 1 of 4

@@ -1,7 +1,10 @@
 # copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
 # and modified as needed
+from contextlib import contextmanager
+
 import torch
+from diffusers.models import UNet2DConditionModel

 # FIXME: Getting errors when trying to use PyTorch 2.0 versions of IPAttnProcessor and AttnProcessor
 # so for now falling back to the default versions
@@ -38,18 +41,14 @@ class ImageProjModel(torch.nn.Module):
 class IPAdapter:
-    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
+    def __init__(self, unet: UNet2DConditionModel, image_encoder_path, ip_ckpt, device, num_tokens=4):
+        self._unet = unet
         self.device = device
         self.image_encoder_path = image_encoder_path
         self.ip_ckpt = ip_ckpt
         self.num_tokens = num_tokens
-        # FIXME:
-        # InvokeAI StableDiffusionPipeline has a to() method that isn't meant to be used
-        # so for now assuming that pipeline is already on the correct device
-        # self.pipe = sd_pipe.to(self.device)
-        self.pipe = sd_pipe
-        self.set_ip_adapter()
+        self._attn_processors = self._prepare_attention_processors()

         # load image encoder
         self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
@@ -63,44 +62,55 @@ class IPAdapter:
     def init_proj(self):
         image_proj_model = ImageProjModel(
-            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
+            cross_attention_dim=self._unet.config.cross_attention_dim,
             clip_embeddings_dim=self.image_encoder.config.projection_dim,
             clip_extra_context_tokens=self.num_tokens,
         ).to(self.device, dtype=torch.float16)
         return image_proj_model

-    def set_ip_adapter(self):
-        unet = self.pipe.unet
+    def _prepare_attention_processors(self):
+        """Creates a dict of attention processors that can later be injected into `self.unet`, and loads the IP-Adapter
+        attention weights into them.
+        """
         attn_procs = {}
-        print("Original UNet Attn Processors count:", len(unet.attn_processors))
-        print(unet.attn_processors.keys())
-        for name in unet.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+        for name in self._unet.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else self._unet.config.cross_attention_dim
             if name.startswith("mid_block"):
-                hidden_size = unet.config.block_out_channels[-1]
+                hidden_size = self._unet.config.block_out_channels[-1]
             elif name.startswith("up_blocks"):
                 block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+                hidden_size = list(reversed(self._unet.config.block_out_channels))[block_id]
             elif name.startswith("down_blocks"):
                 block_id = int(name[len("down_blocks.")])
-                hidden_size = unet.config.block_out_channels[block_id]
+                hidden_size = self._unet.config.block_out_channels[block_id]
             if cross_attention_dim is None:
                 attn_procs[name] = AttnProcessor()
             else:
-                print("swapping in IPAttnProcessor for", name)
                 attn_procs[name] = IPAttnProcessor(
                     hidden_size=hidden_size,
                     cross_attention_dim=cross_attention_dim,
                     scale=1.0,
                 ).to(self.device, dtype=torch.float16)
-        unet.set_attn_processor(attn_procs)
-        print("Modified UNet Attn Processors count:", len(unet.attn_processors))
-        print(unet.attn_processors.keys())
+        return attn_procs
+
+    @contextmanager
+    def apply_ip_adapter_attention(self):
+        """A context manager that patches `self._unet` with this IP-Adapter's attention processors while it is active.
+
+        Yields:
+            None
+        """
+        orig_attn_processors = self._unet.attn_processors
+        try:
+            self._unet.set_attn_processor(self._attn_processors)
+            yield None
+        finally:
+            self._unet.set_attn_processor(orig_attn_processors)

     def load_ip_adapter(self):
         state_dict = torch.load(self.ip_ckpt, map_location="cpu")
         self.image_proj_model.load_state_dict(state_dict["image_proj"])
-        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
+        ip_layers = torch.nn.ModuleList(self._attn_processors.values())
         ip_layers.load_state_dict(state_dict["ip_adapter"])

     @torch.inference_mode()
@@ -114,139 +124,15 @@ class IPAdapter:
         return image_prompt_embeds, uncond_image_prompt_embeds

     def set_scale(self, scale):
-        for attn_processor in self.pipe.unet.attn_processors.values():
+        for attn_processor in self._attn_processors.values():
             if isinstance(attn_processor, IPAttnProcessor):
                 attn_processor.scale = scale

-    # IPAdapter.generate() method is not used for InvokeAI. Left here for reference:
-    # def generate(
-    #     self,
-    #     pil_image,
-    #     prompt=None,
-    #     negative_prompt=None,
-    #     scale=1.0,
-    #     num_samples=4,
-    #     seed=-1,
-    #     guidance_scale=7.5,
-    #     num_inference_steps=30,
-    #     **kwargs,
-    # ):
-    #     self.set_scale(scale)
-    #     if isinstance(pil_image, Image.Image):
-    #         num_prompts = 1
-    #     else:
-    #         num_prompts = len(pil_image)
-    #     if prompt is None:
-    #         prompt = "best quality, high quality"
-    #     if negative_prompt is None:
-    #         negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-    #     if not isinstance(prompt, List):
-    #         prompt = [prompt] * num_prompts
-    #     if not isinstance(negative_prompt, List):
-    #         negative_prompt = [negative_prompt] * num_prompts
-    #     image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
-    #     bs_embed, seq_len, _ = image_prompt_embeds.shape
-    #     image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
-    #     image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-    #     uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
-    #     uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-    #     with torch.inference_mode():
-    #         prompt_embeds = self.pipe._encode_prompt(
-    #             prompt,
-    #             device=self.device,
-    #             num_images_per_prompt=num_samples,
-    #             do_classifier_free_guidance=True,
-    #             negative_prompt=negative_prompt,
-    #         )
-    #         negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
-    #         prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
-    #         negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
-    #     generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
-    #     images = self.pipe(
-    #         prompt_embeds=prompt_embeds,
-    #         negative_prompt_embeds=negative_prompt_embeds,
-    #         guidance_scale=guidance_scale,
-    #         num_inference_steps=num_inference_steps,
-    #         generator=generator,
-    #         **kwargs,
-    #     ).images
-    #     return images

 class IPAdapterXL(IPAdapter):
     """SDXL"""
     pass

-    # IPAdapterXL.generate() method is not used for InvokeAI. Left here for reference:
-    # def generate(
-    #     self,
-    #     pil_image,
-    #     prompt=None,
-    #     negative_prompt=None,
-    #     scale=1.0,
-    #     num_samples=4,
-    #     seed=-1,
-    #     num_inference_steps=30,
-    #     **kwargs,
-    # ):
-    #     self.set_scale(scale)
-    #     if isinstance(pil_image, Image.Image):
-    #         num_prompts = 1
-    #     else:
-    #         num_prompts = len(pil_image)
-    #     if prompt is None:
-    #         prompt = "best quality, high quality"
-    #     if negative_prompt is None:
-    #         negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-    #     if not isinstance(prompt, List):
-    #         prompt = [prompt] * num_prompts
-    #     if not isinstance(negative_prompt, List):
-    #         negative_prompt = [negative_prompt] * num_prompts
-    #     image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
-    #     bs_embed, seq_len, _ = image_prompt_embeds.shape
-    #     image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
-    #     image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-    #     uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
-    #     uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-    #     with torch.inference_mode():
-    #         (
-    #             prompt_embeds,
-    #             negative_prompt_embeds,
-    #             pooled_prompt_embeds,
-    #             negative_pooled_prompt_embeds,
-    #         ) = self.pipe.encode_prompt(
-    #             prompt,
-    #             num_images_per_prompt=num_samples,
-    #             do_classifier_free_guidance=True,
-    #             negative_prompt=negative_prompt,
-    #         )
-    #         prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
-    #         negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
-    #     generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
-    #     images = self.pipe(
-    #         prompt_embeds=prompt_embeds,
-    #         negative_prompt_embeds=negative_prompt_embeds,
-    #         pooled_prompt_embeds=pooled_prompt_embeds,
-    #         negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-    #         num_inference_steps=num_inference_steps,
-    #         generator=generator,
-    #         **kwargs,
-    #     ).images
-    #     return images

 class IPAdapterPlus(IPAdapter):
@@ -254,13 +140,13 @@ class IPAdapterPlus(IPAdapter):
     def init_proj(self):
         image_proj_model = Resampler(
-            dim=self.pipe.unet.config.cross_attention_dim,
+            dim=self._unet.config.cross_attention_dim,
             depth=4,
             dim_head=64,
             heads=12,
             num_queries=self.num_tokens,
             embedding_dim=self.image_encoder.config.hidden_size,
-            output_dim=self.pipe.unet.config.cross_attention_dim,
+            output_dim=self._unet.config.cross_attention_dim,
             ff_mult=4,
         ).to(self.device, dtype=torch.float16)
         return image_proj_model
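As a concrete illustration of the hidden_size / cross_attention_dim lookup in _prepare_attention_processors above, here is how one processor name would be resolved for a typical SD 1.5 UNet (block_out_channels = (320, 640, 1280, 1280), cross_attention_dim = 768); the module name below is an example, not taken from the commit:

    # Example key of the kind produced by `unet.attn_processors`:
    name = "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor"

    # attn1 is self-attention (no cross_attention_dim); attn2 is cross-attention.
    cross_attention_dim = None if name.endswith("attn1.processor") else 768   # -> 768

    block_id = int(name[len("up_blocks.")])                                   # -> 3
    hidden_size = list(reversed((320, 640, 1280, 1280)))[block_id]            # -> 320
    # So this module would get IPAttnProcessor(hidden_size=320, cross_attention_dim=768).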

File 2 of 4

@@ -2,6 +2,7 @@ from __future__ import annotations

 import dataclasses
 import inspect
+from contextlib import nullcontext
 from dataclasses import dataclass, field
 from typing import Any, Callable, List, Optional, Union
@@ -419,21 +420,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if ip_adapter_data is not None:
             # Initialize IPAdapter
-            # FIXME:
-            # WARNING!
-            # IPAdapter constructor modifies UNet model in-place
-            # Adds additional cross-attention layers to UNet model for image embedding
-            # need to figure out how to only do this if UNet hasn't already been modified by prior IPAdapter
-            # and how to undo if ip_adapter_image is removed
-            # Should reimplement to use existing model management context etc.
-            #
+            # TODO(ryand): Refactor to use model management for the IP-Adapter.
             if "sdxl" in ip_adapter_data.ip_adapter_model:
                 ip_adapter = IPAdapterXL(
-                    self, ip_adapter_data.image_encoder_model, ip_adapter_data.ip_adapter_model, self.unet.device
+                    self.unet, ip_adapter_data.image_encoder_model, ip_adapter_data.ip_adapter_model, self.unet.device
                 )
             elif "plus" in ip_adapter_data.ip_adapter_model:
                 ip_adapter = IPAdapterPlus(
-                    self,  # IPAdapterPlus first arg is StableDiffusionPipeline
+                    self.unet,
                     ip_adapter_data.image_encoder_model,
                     ip_adapter_data.ip_adapter_model,
                     self.unet.device,
@@ -441,7 +435,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 )
             else:
                 ip_adapter = IPAdapter(
-                    self,  # IPAdapter first arg is StableDiffusionPipeline
+                    self.unet,
                     ip_adapter_data.image_encoder_model,
                     ip_adapter_data.ip_adapter_model,
                     self.unet.device,
@@ -454,13 +448,20 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 image_prompt_embeds, uncond_image_prompt_embeds
             )

-        # TODO(ryand): Apply IP-Adapter or custom attention control
-        extra_conditioning_info = conditioning_data.extra
-        with self.invokeai_diffuser.custom_attention_context(
-            self.invokeai_diffuser.model,
-            extra_conditioning_info=extra_conditioning_info,
-            step_count=len(self.scheduler.timesteps),
-        ):
+        if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
+            attn_ctx = self.invokeai_diffuser.custom_attention_context(
+                self.invokeai_diffuser.model,
+                extra_conditioning_info=conditioning_data.extra,
+                step_count=len(self.scheduler.timesteps),
+            )
+        elif ip_adapter_data is not None:
+            # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
+            # As it is now, the IP-Adapter will silently be skipped.
+            attn_ctx = ip_adapter.apply_ip_adapter_attention()
+        else:
+            attn_ctx = nullcontext()
+
+        with attn_ctx:
             if callback is not None:
                 callback(
                     PipelineIntermediateState(
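The hunk above picks one of three context managers up front and then enters whichever was chosen with a single with-statement. A standalone toy version of that pattern (illustrative only, not the pipeline code itself) looks like this:

    from contextlib import contextmanager, nullcontext

    @contextmanager
    def patched_attention():
        print("attention patched")
        try:
            yield
        finally:
            print("attention restored")

    want_patch = False  # stand-in for the ip_adapter_data / cross-attention-control checks
    attn_ctx = patched_attention() if want_patch else nullcontext()
    with attn_ctx:
        print("denoising loop runs here")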

File 3 of 4

@@ -11,16 +11,17 @@ import diffusers
 import psutil
 import torch
 from compel.cross_attention_control import Arguments
-from diffusers.models.unet_2d_condition import UNet2DConditionModel
-from diffusers.models.attention_processor import AttentionProcessor
 from diffusers.models.attention_processor import (
     Attention,
+    AttentionProcessor,
     AttnProcessor,
     SlicedAttnProcessor,
 )
+from diffusers.models.unet_2d_condition import UNet2DConditionModel
 from torch import nn

 import invokeai.backend.util.logging as logger

 from ...util import torch_dtype
@@ -380,10 +381,13 @@ def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[
         # non-fatal error but .swap() won't work.
         logger.error(
             f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
-            + f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
+            + f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching"
+            " failed "
             + "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
-            + f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
-            + "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
+            + f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who"
+            " knows "
+            + "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will"
+            " not "
             + "work properly until it is fixed."
         )
     return attention_module_tuples
@@ -581,6 +585,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
         attention_mask=None,
         # kwargs
         swap_cross_attn_context: SwapCrossAttnContext = None,
+        **kwargs,
     ):
         attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS

File 4 of 4

@@ -67,30 +67,26 @@ class InvokeAIDiffuserComponent:
     @contextmanager
     def custom_attention_context(
         self,
-        unet: UNet2DConditionModel,  # note: also may futz with the text encoder depending on requested LoRAs
+        unet: UNet2DConditionModel,
         extra_conditioning_info: Optional[ExtraConditioningInfo],
         step_count: int,
     ):
-        old_attn_processors = None
-        if extra_conditioning_info and (extra_conditioning_info.wants_cross_attention_control):
-            old_attn_processors = unet.attn_processors
-            # Load lora conditions into the model
-            if extra_conditioning_info.wants_cross_attention_control:
-                self.cross_attention_control_context = Context(
-                    arguments=extra_conditioning_info.cross_attention_control_args,
-                    step_count=step_count,
-                )
-                setup_cross_attention_control_attention_processors(
-                    unet,
-                    self.cross_attention_control_context,
-                )
+        old_attn_processors = unet.attn_processors
+
         try:
+            self.cross_attention_control_context = Context(
+                arguments=extra_conditioning_info.cross_attention_control_args,
+                step_count=step_count,
+            )
+            setup_cross_attention_control_attention_processors(
+                unet,
+                self.cross_attention_control_context,
+            )
             yield None
         finally:
             self.cross_attention_control_context = None
-            if old_attn_processors is not None:
-                unet.set_attn_processor(old_attn_processors)
+            unet.set_attn_processor(old_attn_processors)

         # TODO resuscitate attention map saving
         # self.remove_attention_map_saving()