Re-factor IPAdapter to patch UNet in a context manager.

Ryan Dick 2023-09-08 15:39:22 -04:00
parent d669f0855d
commit 91596d9527
4 changed files with 76 additions and 188 deletions
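The change replaces the constructor-time, in-place modification of the UNet with a context manager that installs the IP-Adapter attention processors only while denoising runs and restores the originals afterwards. A minimal sketch of that pattern, distilled from the diff below (swap_attn_processors and new_processors are placeholder names, not code from this commit):

    from contextlib import contextmanager

    from diffusers.models import UNet2DConditionModel


    @contextmanager
    def swap_attn_processors(unet: UNet2DConditionModel, new_processors: dict):
        """Temporarily install new_processors on the UNet; restore the originals on exit."""
        orig_processors = unet.attn_processors  # snapshot of the processors currently installed
        try:
            unet.set_attn_processor(new_processors)
            yield
        finally:
            # Runs even if the body raises, so the UNet is always left as it was found.
            unet.set_attn_processor(orig_processors)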

View File

@@ -1,7 +1,10 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed
from contextlib import contextmanager
import torch
from diffusers.models import UNet2DConditionModel
# FIXME: Getting errors when trying to use PyTorch 2.0 versions of IPAttnProcessor and AttnProcessor
# so for now falling back to the default versions
@@ -38,18 +41,14 @@ class ImageProjModel(torch.nn.Module):
class IPAdapter:
def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
def __init__(self, unet: UNet2DConditionModel, image_encoder_path, ip_ckpt, device, num_tokens=4):
self._unet = unet
self.device = device
self.image_encoder_path = image_encoder_path
self.ip_ckpt = ip_ckpt
self.num_tokens = num_tokens
# FIXME:
# InvokeAI StableDiffusionPipeline has a to() method that isn't meant to be used
# so for now assuming that pipeline is already on the correct device
# self.pipe = sd_pipe.to(self.device)
self.pipe = sd_pipe
self.set_ip_adapter()
self._attn_processors = self._prepare_attention_processors()
# load image encoder
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
@@ -63,44 +62,55 @@ class IPAdapter:
def init_proj(self):
image_proj_model = ImageProjModel(
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
cross_attention_dim=self._unet.config.cross_attention_dim,
clip_embeddings_dim=self.image_encoder.config.projection_dim,
clip_extra_context_tokens=self.num_tokens,
).to(self.device, dtype=torch.float16)
return image_proj_model
def set_ip_adapter(self):
unet = self.pipe.unet
def _prepare_attention_processors(self):
"""Creates a dict of attention processors that can later be injected into `self.unet`, and loads the IP-Adapter
attention weights into them.
"""
attn_procs = {}
print("Original UNet Attn Processors count:", len(unet.attn_processors))
print(unet.attn_processors.keys())
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
for name in self._unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else self._unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
hidden_size = self._unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
hidden_size = list(reversed(self._unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
hidden_size = self._unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor()
else:
print("swapping in IPAttnProcessor for", name)
attn_procs[name] = IPAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
).to(self.device, dtype=torch.float16)
unet.set_attn_processor(attn_procs)
print("Modified UNet Attn Processors count:", len(unet.attn_processors))
print(unet.attn_processors.keys())
return attn_procs
@contextmanager
def apply_ip_adapter_attention(self):
"""A context manager that patches `self._unet` with this IP-Adapter's attention processors while it is active.
Yields:
None
"""
orig_attn_processors = self._unet.attn_processors
try:
self._unet.set_attn_processor(self._attn_processors)
yield None
finally:
self._unet.set_attn_processor(orig_attn_processors)
def load_ip_adapter(self):
state_dict = torch.load(self.ip_ckpt, map_location="cpu")
self.image_proj_model.load_state_dict(state_dict["image_proj"])
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
ip_layers = torch.nn.ModuleList(self._attn_processors.values())
ip_layers.load_state_dict(state_dict["ip_adapter"])
@torch.inference_mode()
@@ -114,139 +124,15 @@ class IPAdapter:
return image_prompt_embeds, uncond_image_prompt_embeds
def set_scale(self, scale):
for attn_processor in self.pipe.unet.attn_processors.values():
for attn_processor in self._attn_processors.values():
if isinstance(attn_processor, IPAttnProcessor):
attn_processor.scale = scale
# IPAdapter.generate() method is not used for InvokeAI. Left here for reference:
# def generate(
# self,
# pil_image,
# prompt=None,
# negative_prompt=None,
# scale=1.0,
# num_samples=4,
# seed=-1,
# guidance_scale=7.5,
# num_inference_steps=30,
# **kwargs,
# ):
# self.set_scale(scale)
# if isinstance(pil_image, Image.Image):
# num_prompts = 1
# else:
# num_prompts = len(pil_image)
# if prompt is None:
# prompt = "best quality, high quality"
# if negative_prompt is None:
# negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
# if not isinstance(prompt, List):
# prompt = [prompt] * num_prompts
# if not isinstance(negative_prompt, List):
# negative_prompt = [negative_prompt] * num_prompts
# image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
# bs_embed, seq_len, _ = image_prompt_embeds.shape
# image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
# image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
# uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
# uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
# with torch.inference_mode():
# prompt_embeds = self.pipe._encode_prompt(
# prompt,
# device=self.device,
# num_images_per_prompt=num_samples,
# do_classifier_free_guidance=True,
# negative_prompt=negative_prompt,
# )
# negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
# prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
# negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
# generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
# images = self.pipe(
# prompt_embeds=prompt_embeds,
# negative_prompt_embeds=negative_prompt_embeds,
# guidance_scale=guidance_scale,
# num_inference_steps=num_inference_steps,
# generator=generator,
# **kwargs,
# ).images
# return images
class IPAdapterXL(IPAdapter):
"""SDXL"""
pass
# IPAdapterXL.generate() method is not used for InvokeAI. Left here for reference:
# def generate(
# self,
# pil_image,
# prompt=None,
# negative_prompt=None,
# scale=1.0,
# num_samples=4,
# seed=-1,
# num_inference_steps=30,
# **kwargs,
# ):
# self.set_scale(scale)
# if isinstance(pil_image, Image.Image):
# num_prompts = 1
# else:
# num_prompts = len(pil_image)
# if prompt is None:
# prompt = "best quality, high quality"
# if negative_prompt is None:
# negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
# if not isinstance(prompt, List):
# prompt = [prompt] * num_prompts
# if not isinstance(negative_prompt, List):
# negative_prompt = [negative_prompt] * num_prompts
# image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
# bs_embed, seq_len, _ = image_prompt_embeds.shape
# image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
# image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
# uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
# uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
# with torch.inference_mode():
# (
# prompt_embeds,
# negative_prompt_embeds,
# pooled_prompt_embeds,
# negative_pooled_prompt_embeds,
# ) = self.pipe.encode_prompt(
# prompt,
# num_images_per_prompt=num_samples,
# do_classifier_free_guidance=True,
# negative_prompt=negative_prompt,
# )
# prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
# negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
# generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
# images = self.pipe(
# prompt_embeds=prompt_embeds,
# negative_prompt_embeds=negative_prompt_embeds,
# pooled_prompt_embeds=pooled_prompt_embeds,
# negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
# num_inference_steps=num_inference_steps,
# generator=generator,
# **kwargs,
# ).images
# return images
class IPAdapterPlus(IPAdapter):
@@ -254,13 +140,13 @@ class IPAdapterPlus(IPAdapter):
def init_proj(self):
image_proj_model = Resampler(
dim=self.pipe.unet.config.cross_attention_dim,
dim=self._unet.config.cross_attention_dim,
depth=4,
dim_head=64,
heads=12,
num_queries=self.num_tokens,
embedding_dim=self.image_encoder.config.hidden_size,
output_dim=self.pipe.unet.config.cross_attention_dim,
output_dim=self._unet.config.cross_attention_dim,
ff_mult=4,
).to(self.device, dtype=torch.float16)
return image_proj_model
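With this file's changes in place, a caller builds the adapter against a bare UNet and patches it only for the duration of denoising. A rough usage sketch (pipeline, pil_image, run_denoising_loop, and the paths are placeholders, not InvokeAI code):

    ip_adapter = IPAdapter(
        unet=pipeline.unet,                          # any diffusers UNet2DConditionModel
        image_encoder_path="path/to/image_encoder",  # placeholder path
        ip_ckpt="path/to/ip-adapter.bin",            # placeholder path
        device="cuda",
    )
    image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(pil_image)

    with ip_adapter.apply_ip_adapter_attention():
        # The IP-Adapter attention processors are installed only inside this block;
        # the original processors are restored when the block exits.
        run_denoising_loop(image_prompt_embeds, uncond_image_prompt_embeds)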

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import dataclasses
import inspect
from contextlib import nullcontext
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Union
@@ -419,21 +420,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if ip_adapter_data is not None:
# Initialize IPAdapter
# FIXME:
# WARNING!
# IPAdapter constructor modifies UNet model in-place
# Adds additional cross-attention layers to UNet model for image embedding
# need to figure out how to only do this if UNet hasn't already been modified by prior IPAdapter
# and how to undo if ip_adapter_image is removed
# Should reimplement to use existing model management context etc.
#
# TODO(ryand): Refactor to use model management for the IP-Adapter.
if "sdxl" in ip_adapter_data.ip_adapter_model:
ip_adapter = IPAdapterXL(
self, ip_adapter_data.image_encoder_model, ip_adapter_data.ip_adapter_model, self.unet.device
self.unet, ip_adapter_data.image_encoder_model, ip_adapter_data.ip_adapter_model, self.unet.device
)
elif "plus" in ip_adapter_data.ip_adapter_model:
ip_adapter = IPAdapterPlus(
self, # IPAdapterPlus first arg is StableDiffusionPipeline
self.unet,
ip_adapter_data.image_encoder_model,
ip_adapter_data.ip_adapter_model,
self.unet.device,
@@ -441,7 +435,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
)
else:
ip_adapter = IPAdapter(
self, # IPAdapter first arg is StableDiffusionPipeline
self.unet,
ip_adapter_data.image_encoder_model,
ip_adapter_data.ip_adapter_model,
self.unet.device,
@@ -454,13 +448,20 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
image_prompt_embeds, uncond_image_prompt_embeds
)
# TODO(ryand): Apply IP-Adapter or custom attention control
extra_conditioning_info = conditioning_data.extra
with self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model,
extra_conditioning_info=extra_conditioning_info,
step_count=len(self.scheduler.timesteps),
):
if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
attn_ctx = self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model,
extra_conditioning_info=conditioning_data.extra,
step_count=len(self.scheduler.timesteps),
)
elif ip_adapter_data is not None:
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
# As it is now, the IP-Adapter will silently be skipped.
attn_ctx = ip_adapter.apply_ip_adapter_attention()
else:
attn_ctx = nullcontext()
with attn_ctx:
if callback is not None:
callback(
PipelineIntermediateState(
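The pipeline now chooses one of three context managers up front so a single `with attn_ctx:` wraps the denoising loop; `contextlib.nullcontext()` covers the case where neither cross-attention control nor an IP-Adapter is active. A minimal illustration of that selection pattern (generic names, not the pipeline code):

    from contextlib import contextmanager, nullcontext

    @contextmanager
    def patched_attention():
        print("patch")
        yield
        print("restore")

    def denoise(use_patch: bool):
        # Decide on the context manager once, then always enter it.
        attn_ctx = patched_attention() if use_patch else nullcontext()
        with attn_ctx:
            print("denoising step")

    denoise(True)   # patch / denoising step / restore
    denoise(False)  # denoising step only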

View File

@@ -11,16 +11,17 @@ import diffusers
import psutil
import torch
from compel.cross_attention_control import Arguments
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.attention_processor import (
Attention,
AttentionProcessor,
AttnProcessor,
SlicedAttnProcessor,
)
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from torch import nn
import invokeai.backend.util.logging as logger
from ...util import torch_dtype
@@ -380,10 +381,13 @@ def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[
# non-fatal error but .swap() won't work.
logger.error(
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
+ f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
+ f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching"
" failed "
+ "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
+ f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
+ "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
+ f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who"
" knows "
+ "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will"
" not "
+ "work properly until it is fixed."
)
return attention_module_tuples
@@ -581,6 +585,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
attention_mask=None,
# kwargs
swap_cross_attn_context: SwapCrossAttnContext = None,
**kwargs,
):
attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
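The `**kwargs` added to the processor's `__call__` lets it absorb extra keyword arguments that different diffusers versions may pass to attention processors (for example temb or scale), instead of raising a TypeError. A minimal sketch of a processor written that way (illustrative only, not the InvokeAI class):

    import torch
    from diffusers.models.attention_processor import Attention, AttnProcessor

    class TolerantAttnProcessor(AttnProcessor):
        def __call__(
            self,
            attn: Attention,
            hidden_states: torch.Tensor,
            encoder_hidden_states=None,
            attention_mask=None,
            **kwargs,  # swallow any extra args the calling diffusers version passes
        ):
            # Delegate to the stock processor; the extra kwargs are intentionally ignored here.
            return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)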

View File

@@ -67,30 +67,26 @@ class InvokeAIDiffuserComponent:
@contextmanager
def custom_attention_context(
self,
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
unet: UNet2DConditionModel,
extra_conditioning_info: Optional[ExtraConditioningInfo],
step_count: int,
):
old_attn_processors = None
if extra_conditioning_info and (extra_conditioning_info.wants_cross_attention_control):
old_attn_processors = unet.attn_processors
# Load lora conditions into the model
if extra_conditioning_info.wants_cross_attention_control:
self.cross_attention_control_context = Context(
arguments=extra_conditioning_info.cross_attention_control_args,
step_count=step_count,
)
setup_cross_attention_control_attention_processors(
unet,
self.cross_attention_control_context,
)
old_attn_processors = unet.attn_processors
try:
self.cross_attention_control_context = Context(
arguments=extra_conditioning_info.cross_attention_control_args,
step_count=step_count,
)
setup_cross_attention_control_attention_processors(
unet,
self.cross_attention_control_context,
)
yield None
finally:
self.cross_attention_control_context = None
if old_attn_processors is not None:
unet.set_attn_processor(old_attn_processors)
unet.set_attn_processor(old_attn_processors)
# TODO resuscitate attention map saving
# self.remove_attention_map_saving()
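Both context managers (the IP-Adapter one from the first file and this simplified one) now guarantee via try/finally that the UNet is left exactly as it was found. A test-style sketch of that invariant, reusing the pipeline and adapter from the earlier usage sketch (illustrative assertions, not part of this commit):

    unet = pipeline.unet                   # same UNet the adapter was built against
    before = dict(unet.attn_processors)    # snapshot of the original processor instances

    with ip_adapter.apply_ip_adapter_attention():
        # Inside the block, the cross-attention layers use IPAttnProcessor instances.
        assert any(type(p).__name__ == "IPAttnProcessor" for p in unet.attn_processors.values())

    # On exit the original processors are back, even if the body had raised.
    assert all(unet.attn_processors[k] is before[k] for k in before)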