Modifying code from https://github.com/tencent-ailab/IP-Adapter. Also adding license notice at top.

user1 2023-08-29 06:29:05 -07:00
parent 1ad98ce999
commit 8c1390166f
4 changed files with 94 additions and 68 deletions

View File

@@ -1,3 +1,7 @@
+# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
+# and modified as needed
+# tencent-ailab comment:
 # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
 import torch
 import torch.nn as nn

View File

@@ -1,3 +1,6 @@
+# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
+# and modified as needed
 import os
 from typing import List
@@ -6,11 +9,14 @@ from diffusers import StableDiffusionPipeline
 from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
 from PIL import Image
-from .utils import is_torch2_available
-if is_torch2_available:
-    from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
-else:
-    from .attention_processor import IPAttnProcessor, AttnProcessor
+# FIXME: Getting errors when trying to use PyTorch 2.0 versions of IPAttnProcessor and AttnProcessor
+# so for now falling back to the default versions
+# from .utils import is_torch2_available
+# if is_torch2_available:
+#     from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
+# else:
+#     from .attention_processor import IPAttnProcessor, AttnProcessor
+from .attention_processor import IPAttnProcessor, AttnProcessor
 from .resampler import Resampler
@@ -40,7 +46,11 @@ class IPAdapter:
         self.ip_ckpt = ip_ckpt
         self.num_tokens = num_tokens
-        self.pipe = sd_pipe.to(self.device)
+        # FIXME:
+        # InvokeAI StableDiffusionPipeline has a to() method that isn't meant to be used
+        # so for now assuming that pipeline is already on the correct device
+        # self.pipe = sd_pipe.to(self.device)
+        self.pipe = sd_pipe
         self.set_ip_adapter()

         # load image encoder
@@ -62,6 +72,8 @@ class IPAdapter:
     def set_ip_adapter(self):
         unet = self.pipe.unet
         attn_procs = {}
+        print("Original UNet Attn Processors count:", len(unet.attn_processors))
+        print(unet.attn_processors.keys())
         for name in unet.attn_processors.keys():
             cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
             if name.startswith("mid_block"):
@@ -75,9 +87,12 @@ class IPAdapter:
             if cross_attention_dim is None:
                 attn_procs[name] = AttnProcessor()
             else:
+                print("swapping in IPAttnProcessor for", name)
                 attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,
                                                    scale=1.0).to(self.device, dtype=torch.float16)
         unet.set_attn_processor(attn_procs)
+        print("Modified UNet Attn Processors count:", len(unet.attn_processors))
+        print(unet.attn_processors.keys())

     def load_ip_adapter(self):
         state_dict = torch.load(self.ip_ckpt, map_location="cpu")
@@ -100,6 +115,8 @@ class IPAdapter:
             if isinstance(attn_processor, IPAttnProcessor):
                 attn_processor.scale = scale

+    # IPAdapter.generate() method is not used for InvokeAI
+    # left here for reference
     def generate(
         self,
         pil_image,
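
For context, a minimal usage sketch (not part of this commit) of how the modified IPAdapter is expected to be driven after this change: the caller moves the pipeline to the target device before constructing the adapter, since __init__() no longer calls sd_pipe.to(). In InvokeAI itself the pipeline object is InvokeAI's own StableDiffusionPipeline and is assumed to already be on the correct device; a diffusers pipeline is used here only to keep the sketch self-contained. The import path, model ID, and checkpoint paths are placeholders.

import torch
from diffusers import StableDiffusionPipeline
from ip_adapter import IPAdapter  # hypothetical import path; class shown in this diff

# Caller is responsible for device placement; the adapter no longer moves the pipeline.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

ip_adapter = IPAdapter(
    pipe,
    image_encoder_path="models/image_encoder",  # placeholder path
    ip_ckpt="models/ip-adapter_sd15.bin",       # placeholder path
    device="cuda",
)
ip_adapter.set_scale(0.7)  # scale the influence of the image prompt on cross-attention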

View File

@@ -1,4 +1,6 @@
-# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
+# tencent ailab comment: modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
 import math
 import torch

View File

@@ -1,3 +1,6 @@
+# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
+# and modified as needed
 import inspect
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union