Switch to using torch 2.0 attention for IP-Adapter (more memory-efficient).

This commit is contained in:
Ryan Dick
2023-09-18 16:30:53 -04:00
parent 382e2139bd
commit b05b8ef677
2 changed files with 8 additions and 143 deletions

View File

@ -6,18 +6,10 @@ from typing import Optional, Union
import torch
from diffusers.models import UNet2DConditionModel
# FIXME: Getting errors when trying to use PyTorch 2.0 versions of IPAttnProcessor and AttnProcessor
# so for now falling back to the default versions
# from .utils import is_torch2_available
# if is_torch2_available:
# from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
# else:
# from .attention_processor import IPAttnProcessor, AttnProcessor
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from .attention_processor import AttnProcessor, IPAttnProcessor
from .attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
from .resampler import Resampler
@ -118,9 +110,9 @@ class IPAdapter:
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor()
attn_procs[name] = AttnProcessor2_0()
else:
attn_procs[name] = IPAttnProcessor(
attn_procs[name] = IPAttnProcessor2_0(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
@ -138,7 +130,7 @@ class IPAdapter:
def set_scale(self, scale):
if self._attn_processors is not None:
for attn_processor in self._attn_processors.values():
if isinstance(attn_processor, IPAttnProcessor):
if isinstance(attn_processor, IPAttnProcessor2_0):
attn_processor.scale = scale
@contextmanager
@ -156,7 +148,7 @@ class IPAdapter:
# Set scale
self.set_scale(scale)
# for attn_processor in self._attn_processors.values():
# if isinstance(attn_processor, IPAttnProcessor):
# if isinstance(attn_processor, IPAttnProcessor2_0):
# attn_processor.scale = scale
orig_attn_processors = unet.attn_processors