Modifying code from https://github.com/tencent-ailab/IP-Adapter. Also adding license notice at top.

This commit is contained in:
user1 2023-08-29 06:29:05 -07:00
parent 1ad98ce999
commit 8c1390166f
4 changed files with 94 additions and 68 deletions

View File

@ -1,3 +1,7 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed
# tencent-ailab comment:
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn

View File

@ -1,3 +1,6 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed
import os
from typing import List
@ -6,10 +9,13 @@ from diffusers import StableDiffusionPipeline
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from PIL import Image
from .utils import is_torch2_available
if is_torch2_available:
from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
else:
# FIXME: The PyTorch 2.0 versions of IPAttnProcessor and AttnProcessor raise errors here,
# so for now fall back to the default versions.
# from .utils import is_torch2_available
# if is_torch2_available:
# from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
# else:
# from .attention_processor import IPAttnProcessor, AttnProcessor
from .attention_processor import IPAttnProcessor, AttnProcessor
from .resampler import Resampler
@ -40,7 +46,11 @@ class IPAdapter:
self.ip_ckpt = ip_ckpt
self.num_tokens = num_tokens
self.pipe = sd_pipe.to(self.device)
# FIXME:
# InvokeAI's StableDiffusionPipeline has a to() method that isn't meant to be used,
# so for now assume that the pipeline is already on the correct device.
# self.pipe = sd_pipe.to(self.device)
self.pipe = sd_pipe
self.set_ip_adapter()
# load image encoder
@ -62,6 +72,8 @@ class IPAdapter:
def set_ip_adapter(self):
unet = self.pipe.unet
attn_procs = {}
print("Original UNet Attn Processors count:", len(unet.attn_processors))
print(unet.attn_processors.keys())
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
@ -75,9 +87,12 @@ class IPAdapter:
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor()
else:
print("swapping in IPAttnProcessor for", name)
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,
scale=1.0).to(self.device, dtype=torch.float16)
unet.set_attn_processor(attn_procs)
print("Modified UNet Attn Processors count:", len(unet.attn_processors))
print(unet.attn_processors.keys())
def load_ip_adapter(self):
state_dict = torch.load(self.ip_ckpt, map_location="cpu")
@ -100,6 +115,8 @@ class IPAdapter:
if isinstance(attn_processor, IPAttnProcessor):
attn_processor.scale = scale
# The IPAdapter.generate() method is not used by InvokeAI;
# it is left here for reference.
def generate(
self,
pil_image,

View File

@ -1,4 +1,6 @@
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# tencent-ailab comment: modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
import math
import torch

View File

@ -1,3 +1,6 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed
import inspect
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union