mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Modifying code from https://github.com/tencent-ailab/IP-Adapter. Also adding license notice at top.
This commit is contained in:
parent
1ad98ce999
commit
8c1390166f
@ -1,3 +1,7 @@
|
|||||||
|
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
|
||||||
|
# and modified as needed
|
||||||
|
|
||||||
|
# tencent-ailab comment:
|
||||||
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
|
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -74,8 +78,8 @@ class AttnProcessor(nn.Module):
|
|||||||
hidden_states = hidden_states / attn.rescale_output_factor
|
hidden_states = hidden_states / attn.rescale_output_factor
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class IPAttnProcessor(nn.Module):
|
class IPAttnProcessor(nn.Module):
|
||||||
r"""
|
r"""
|
||||||
Attention processor for IP-Adapater.
|
Attention processor for IP-Adapater.
|
||||||
@ -134,7 +138,7 @@ class IPAttnProcessor(nn.Module):
|
|||||||
encoder_hidden_states = hidden_states
|
encoder_hidden_states = hidden_states
|
||||||
elif attn.norm_cross:
|
elif attn.norm_cross:
|
||||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||||
|
|
||||||
# split hidden states
|
# split hidden states
|
||||||
encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :self.text_context_len, :], encoder_hidden_states[:, self.text_context_len:, :]
|
encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :self.text_context_len, :], encoder_hidden_states[:, self.text_context_len:, :]
|
||||||
|
|
||||||
@ -148,18 +152,18 @@ class IPAttnProcessor(nn.Module):
|
|||||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||||
hidden_states = torch.bmm(attention_probs, value)
|
hidden_states = torch.bmm(attention_probs, value)
|
||||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||||
|
|
||||||
# for ip-adapter
|
# for ip-adapter
|
||||||
ip_key = self.to_k_ip(ip_hidden_states)
|
ip_key = self.to_k_ip(ip_hidden_states)
|
||||||
ip_value = self.to_v_ip(ip_hidden_states)
|
ip_value = self.to_v_ip(ip_hidden_states)
|
||||||
|
|
||||||
ip_key = attn.head_to_batch_dim(ip_key)
|
ip_key = attn.head_to_batch_dim(ip_key)
|
||||||
ip_value = attn.head_to_batch_dim(ip_value)
|
ip_value = attn.head_to_batch_dim(ip_value)
|
||||||
|
|
||||||
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
||||||
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
||||||
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
||||||
|
|
||||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||||
|
|
||||||
# linear proj
|
# linear proj
|
||||||
@ -176,8 +180,8 @@ class IPAttnProcessor(nn.Module):
|
|||||||
hidden_states = hidden_states / attn.rescale_output_factor
|
hidden_states = hidden_states / attn.rescale_output_factor
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class AttnProcessor2_0(torch.nn.Module):
|
class AttnProcessor2_0(torch.nn.Module):
|
||||||
r"""
|
r"""
|
||||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||||
@ -264,8 +268,8 @@ class AttnProcessor2_0(torch.nn.Module):
|
|||||||
hidden_states = hidden_states / attn.rescale_output_factor
|
hidden_states = hidden_states / attn.rescale_output_factor
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class IPAttnProcessor2_0(torch.nn.Module):
|
class IPAttnProcessor2_0(torch.nn.Module):
|
||||||
r"""
|
r"""
|
||||||
Attention processor for IP-Adapater for PyTorch 2.0.
|
Attention processor for IP-Adapater for PyTorch 2.0.
|
||||||
@ -355,11 +359,11 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
|||||||
|
|
||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
hidden_states = hidden_states.to(query.dtype)
|
hidden_states = hidden_states.to(query.dtype)
|
||||||
|
|
||||||
# for ip-adapter
|
# for ip-adapter
|
||||||
ip_key = self.to_k_ip(ip_hidden_states)
|
ip_key = self.to_k_ip(ip_hidden_states)
|
||||||
ip_value = self.to_v_ip(ip_hidden_states)
|
ip_value = self.to_v_ip(ip_hidden_states)
|
||||||
|
|
||||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
@ -368,10 +372,10 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
|||||||
ip_hidden_states = F.scaled_dot_product_attention(
|
ip_hidden_states = F.scaled_dot_product_attention(
|
||||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||||
)
|
)
|
||||||
|
|
||||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||||
|
|
||||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||||
|
|
||||||
# linear proj
|
# linear proj
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
|
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
|
||||||
|
# and modified as needed
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
@ -6,11 +9,14 @@ from diffusers import StableDiffusionPipeline
|
|||||||
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
|
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from .utils import is_torch2_available
|
# FIXME: Getting errors when trying to use PyTorch 2.0 versions of IPAttnProcessor and AttnProcessor
|
||||||
if is_torch2_available:
|
# so for now falling back to the default versions
|
||||||
from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
|
# from .utils import is_torch2_available
|
||||||
else:
|
# if is_torch2_available:
|
||||||
from .attention_processor import IPAttnProcessor, AttnProcessor
|
# from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
|
||||||
|
# else:
|
||||||
|
# from .attention_processor import IPAttnProcessor, AttnProcessor
|
||||||
|
from .attention_processor import IPAttnProcessor, AttnProcessor
|
||||||
from .resampler import Resampler
|
from .resampler import Resampler
|
||||||
|
|
||||||
|
|
||||||
@ -18,12 +24,12 @@ class ImageProjModel(torch.nn.Module):
|
|||||||
"""Projection Model"""
|
"""Projection Model"""
|
||||||
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
|
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.cross_attention_dim = cross_attention_dim
|
self.cross_attention_dim = cross_attention_dim
|
||||||
self.clip_extra_context_tokens = clip_extra_context_tokens
|
self.clip_extra_context_tokens = clip_extra_context_tokens
|
||||||
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
|
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
|
||||||
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
||||||
|
|
||||||
def forward(self, image_embeds):
|
def forward(self, image_embeds):
|
||||||
embeds = image_embeds
|
embeds = image_embeds
|
||||||
clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
|
clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
|
||||||
@ -32,25 +38,29 @@ class ImageProjModel(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class IPAdapter:
|
class IPAdapter:
|
||||||
|
|
||||||
def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
|
def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
|
||||||
|
|
||||||
self.device = device
|
self.device = device
|
||||||
self.image_encoder_path = image_encoder_path
|
self.image_encoder_path = image_encoder_path
|
||||||
self.ip_ckpt = ip_ckpt
|
self.ip_ckpt = ip_ckpt
|
||||||
self.num_tokens = num_tokens
|
self.num_tokens = num_tokens
|
||||||
|
|
||||||
self.pipe = sd_pipe.to(self.device)
|
# FIXME:
|
||||||
|
# InvokeAI StableDiffusionPipeline has a to() method that isn't meant to be used
|
||||||
|
# so for now assuming that pipeline is already on the correct device
|
||||||
|
# self.pipe = sd_pipe.to(self.device)
|
||||||
|
self.pipe = sd_pipe
|
||||||
self.set_ip_adapter()
|
self.set_ip_adapter()
|
||||||
|
|
||||||
# load image encoder
|
# load image encoder
|
||||||
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(self.device, dtype=torch.float16)
|
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(self.device, dtype=torch.float16)
|
||||||
self.clip_image_processor = CLIPImageProcessor()
|
self.clip_image_processor = CLIPImageProcessor()
|
||||||
# image proj model
|
# image proj model
|
||||||
self.image_proj_model = self.init_proj()
|
self.image_proj_model = self.init_proj()
|
||||||
|
|
||||||
self.load_ip_adapter()
|
self.load_ip_adapter()
|
||||||
|
|
||||||
def init_proj(self):
|
def init_proj(self):
|
||||||
image_proj_model = ImageProjModel(
|
image_proj_model = ImageProjModel(
|
||||||
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
|
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
|
||||||
@ -58,10 +68,12 @@ class IPAdapter:
|
|||||||
clip_extra_context_tokens=self.num_tokens,
|
clip_extra_context_tokens=self.num_tokens,
|
||||||
).to(self.device, dtype=torch.float16)
|
).to(self.device, dtype=torch.float16)
|
||||||
return image_proj_model
|
return image_proj_model
|
||||||
|
|
||||||
def set_ip_adapter(self):
|
def set_ip_adapter(self):
|
||||||
unet = self.pipe.unet
|
unet = self.pipe.unet
|
||||||
attn_procs = {}
|
attn_procs = {}
|
||||||
|
print("Original UNet Attn Processors count:", len(unet.attn_processors))
|
||||||
|
print(unet.attn_processors.keys())
|
||||||
for name in unet.attn_processors.keys():
|
for name in unet.attn_processors.keys():
|
||||||
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
||||||
if name.startswith("mid_block"):
|
if name.startswith("mid_block"):
|
||||||
@ -75,16 +87,19 @@ class IPAdapter:
|
|||||||
if cross_attention_dim is None:
|
if cross_attention_dim is None:
|
||||||
attn_procs[name] = AttnProcessor()
|
attn_procs[name] = AttnProcessor()
|
||||||
else:
|
else:
|
||||||
|
print("swapping in IPAttnProcessor for", name)
|
||||||
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,
|
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,
|
||||||
scale=1.0).to(self.device, dtype=torch.float16)
|
scale=1.0).to(self.device, dtype=torch.float16)
|
||||||
unet.set_attn_processor(attn_procs)
|
unet.set_attn_processor(attn_procs)
|
||||||
|
print("Modified UNet Attn Processors count:", len(unet.attn_processors))
|
||||||
|
print(unet.attn_processors.keys())
|
||||||
|
|
||||||
def load_ip_adapter(self):
|
def load_ip_adapter(self):
|
||||||
state_dict = torch.load(self.ip_ckpt, map_location="cpu")
|
state_dict = torch.load(self.ip_ckpt, map_location="cpu")
|
||||||
self.image_proj_model.load_state_dict(state_dict["image_proj"])
|
self.image_proj_model.load_state_dict(state_dict["image_proj"])
|
||||||
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
|
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
|
||||||
ip_layers.load_state_dict(state_dict["ip_adapter"])
|
ip_layers.load_state_dict(state_dict["ip_adapter"])
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def get_image_embeds(self, pil_image):
|
def get_image_embeds(self, pil_image):
|
||||||
if isinstance(pil_image, Image.Image):
|
if isinstance(pil_image, Image.Image):
|
||||||
@ -94,12 +109,14 @@ class IPAdapter:
|
|||||||
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
|
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
|
||||||
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
|
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
|
||||||
return image_prompt_embeds, uncond_image_prompt_embeds
|
return image_prompt_embeds, uncond_image_prompt_embeds
|
||||||
|
|
||||||
def set_scale(self, scale):
|
def set_scale(self, scale):
|
||||||
for attn_processor in self.pipe.unet.attn_processors.values():
|
for attn_processor in self.pipe.unet.attn_processors.values():
|
||||||
if isinstance(attn_processor, IPAttnProcessor):
|
if isinstance(attn_processor, IPAttnProcessor):
|
||||||
attn_processor.scale = scale
|
attn_processor.scale = scale
|
||||||
|
|
||||||
|
# IPAdapter.generate() method is not used for InvokeAI
|
||||||
|
# left here for reference
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
pil_image,
|
pil_image,
|
||||||
@ -113,22 +130,22 @@ class IPAdapter:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.set_scale(scale)
|
self.set_scale(scale)
|
||||||
|
|
||||||
if isinstance(pil_image, Image.Image):
|
if isinstance(pil_image, Image.Image):
|
||||||
num_prompts = 1
|
num_prompts = 1
|
||||||
else:
|
else:
|
||||||
num_prompts = len(pil_image)
|
num_prompts = len(pil_image)
|
||||||
|
|
||||||
if prompt is None:
|
if prompt is None:
|
||||||
prompt = "best quality, high quality"
|
prompt = "best quality, high quality"
|
||||||
if negative_prompt is None:
|
if negative_prompt is None:
|
||||||
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||||
|
|
||||||
if not isinstance(prompt, List):
|
if not isinstance(prompt, List):
|
||||||
prompt = [prompt] * num_prompts
|
prompt = [prompt] * num_prompts
|
||||||
if not isinstance(negative_prompt, List):
|
if not isinstance(negative_prompt, List):
|
||||||
negative_prompt = [negative_prompt] * num_prompts
|
negative_prompt = [negative_prompt] * num_prompts
|
||||||
|
|
||||||
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
|
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
|
||||||
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
||||||
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
||||||
@ -142,7 +159,7 @@ class IPAdapter:
|
|||||||
negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
|
negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
|
||||||
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
||||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
|
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
|
||||||
|
|
||||||
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
||||||
images = self.pipe(
|
images = self.pipe(
|
||||||
prompt_embeds=prompt_embeds,
|
prompt_embeds=prompt_embeds,
|
||||||
@ -152,13 +169,13 @@ class IPAdapter:
|
|||||||
generator=generator,
|
generator=generator,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
).images
|
).images
|
||||||
|
|
||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterXL(IPAdapter):
|
class IPAdapterXL(IPAdapter):
|
||||||
"""SDXL"""
|
"""SDXL"""
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
pil_image,
|
pil_image,
|
||||||
@ -171,22 +188,22 @@ class IPAdapterXL(IPAdapter):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.set_scale(scale)
|
self.set_scale(scale)
|
||||||
|
|
||||||
if isinstance(pil_image, Image.Image):
|
if isinstance(pil_image, Image.Image):
|
||||||
num_prompts = 1
|
num_prompts = 1
|
||||||
else:
|
else:
|
||||||
num_prompts = len(pil_image)
|
num_prompts = len(pil_image)
|
||||||
|
|
||||||
if prompt is None:
|
if prompt is None:
|
||||||
prompt = "best quality, high quality"
|
prompt = "best quality, high quality"
|
||||||
if negative_prompt is None:
|
if negative_prompt is None:
|
||||||
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||||
|
|
||||||
if not isinstance(prompt, List):
|
if not isinstance(prompt, List):
|
||||||
prompt = [prompt] * num_prompts
|
prompt = [prompt] * num_prompts
|
||||||
if not isinstance(negative_prompt, List):
|
if not isinstance(negative_prompt, List):
|
||||||
negative_prompt = [negative_prompt] * num_prompts
|
negative_prompt = [negative_prompt] * num_prompts
|
||||||
|
|
||||||
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
|
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
|
||||||
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
||||||
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
||||||
@ -199,7 +216,7 @@ class IPAdapterXL(IPAdapter):
|
|||||||
prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
|
prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
|
||||||
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
||||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
||||||
|
|
||||||
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
||||||
images = self.pipe(
|
images = self.pipe(
|
||||||
prompt_embeds=prompt_embeds,
|
prompt_embeds=prompt_embeds,
|
||||||
@ -210,10 +227,10 @@ class IPAdapterXL(IPAdapter):
|
|||||||
generator=generator,
|
generator=generator,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
).images
|
).images
|
||||||
|
|
||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterPlus(IPAdapter):
|
class IPAdapterPlus(IPAdapter):
|
||||||
"""IP-Adapter with fine-grained features"""
|
"""IP-Adapter with fine-grained features"""
|
||||||
|
|
||||||
@ -229,7 +246,7 @@ class IPAdapterPlus(IPAdapter):
|
|||||||
ff_mult=4
|
ff_mult=4
|
||||||
).to(self.device, dtype=torch.float16)
|
).to(self.device, dtype=torch.float16)
|
||||||
return image_proj_model
|
return image_proj_model
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def get_image_embeds(self, pil_image):
|
def get_image_embeds(self, pil_image):
|
||||||
if isinstance(pil_image, Image.Image):
|
if isinstance(pil_image, Image.Image):
|
||||||
@ -240,4 +257,4 @@ class IPAdapterPlus(IPAdapter):
|
|||||||
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
|
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
|
||||||
uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2]
|
uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2]
|
||||||
uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
|
uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
|
||||||
return image_prompt_embeds, uncond_image_prompt_embeds
|
return image_prompt_embeds, uncond_image_prompt_embeds
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
|
||||||
|
|
||||||
|
# tencent ailab comment: modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -14,8 +16,8 @@ def FeedForward(dim, mult=4):
|
|||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.Linear(inner_dim, dim, bias=False),
|
nn.Linear(inner_dim, dim, bias=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def reshape_tensor(x, heads):
|
def reshape_tensor(x, heads):
|
||||||
bs, length, width = x.shape
|
bs, length, width = x.shape
|
||||||
#(bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
#(bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
||||||
@ -53,13 +55,13 @@ class PerceiverAttention(nn.Module):
|
|||||||
"""
|
"""
|
||||||
x = self.norm1(x)
|
x = self.norm1(x)
|
||||||
latents = self.norm2(latents)
|
latents = self.norm2(latents)
|
||||||
|
|
||||||
b, l, _ = latents.shape
|
b, l, _ = latents.shape
|
||||||
|
|
||||||
q = self.to_q(latents)
|
q = self.to_q(latents)
|
||||||
kv_input = torch.cat((x, latents), dim=-2)
|
kv_input = torch.cat((x, latents), dim=-2)
|
||||||
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
||||||
|
|
||||||
q = reshape_tensor(q, self.heads)
|
q = reshape_tensor(q, self.heads)
|
||||||
k = reshape_tensor(k, self.heads)
|
k = reshape_tensor(k, self.heads)
|
||||||
v = reshape_tensor(v, self.heads)
|
v = reshape_tensor(v, self.heads)
|
||||||
@ -69,7 +71,7 @@ class PerceiverAttention(nn.Module):
|
|||||||
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
||||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||||
out = weight @ v
|
out = weight @ v
|
||||||
|
|
||||||
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
||||||
|
|
||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
@ -88,14 +90,14 @@ class Resampler(nn.Module):
|
|||||||
ff_mult=4,
|
ff_mult=4,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
|
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
|
||||||
|
|
||||||
self.proj_in = nn.Linear(embedding_dim, dim)
|
self.proj_in = nn.Linear(embedding_dim, dim)
|
||||||
|
|
||||||
self.proj_out = nn.Linear(dim, output_dim)
|
self.proj_out = nn.Linear(dim, output_dim)
|
||||||
self.norm_out = nn.LayerNorm(output_dim)
|
self.norm_out = nn.LayerNorm(output_dim)
|
||||||
|
|
||||||
self.layers = nn.ModuleList([])
|
self.layers = nn.ModuleList([])
|
||||||
for _ in range(depth):
|
for _ in range(depth):
|
||||||
self.layers.append(
|
self.layers.append(
|
||||||
@ -108,14 +110,14 @@ class Resampler(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
||||||
latents = self.latents.repeat(x.size(0), 1, 1)
|
latents = self.latents.repeat(x.size(0), 1, 1)
|
||||||
|
|
||||||
x = self.proj_in(x)
|
x = self.proj_in(x)
|
||||||
|
|
||||||
for attn, ff in self.layers:
|
for attn, ff in self.layers:
|
||||||
latents = attn(x, latents) + latents
|
latents = attn(x, latents) + latents
|
||||||
latents = ff(latents) + latents
|
latents = ff(latents) + latents
|
||||||
|
|
||||||
latents = self.proj_out(latents)
|
latents = self.proj_out(latents)
|
||||||
return self.norm_out(latents)
|
return self.norm_out(latents)
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
|
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
|
||||||
|
# and modified as needed
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
@ -365,4 +368,4 @@ def generate(
|
|||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (image, has_nsfw_concept)
|
return (image, has_nsfw_concept)
|
||||||
|
|
||||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||||
|
Loading…
Reference in New Issue
Block a user