mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Core ip_adapter files from https://github.com/tencent-ailab/IP-Adapter
Copied into InvokeAI since IP-Adapter repo is not a package. Is there a better way to do this for non-packaged Python code while still keeping InvokeAI install easy?
This commit is contained in:
parent
2bd3cf28ea
commit
1ad98ce999
1
invokeai/backend/ip_adapter/__init__.py
Normal file
1
invokeai/backend/ip_adapter/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .ip_adapter import IPAdapter, IPAdapterXL, IPAdapterPlus
|
390
invokeai/backend/ip_adapter/attention_processor.py
Normal file
390
invokeai/backend/ip_adapter/attention_processor.py
Normal file
@ -0,0 +1,390 @@
|
||||
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class AttnProcessor(nn.Module):
|
||||
r"""
|
||||
Default processor for performing attention-related computations.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=None,
|
||||
cross_attention_dim=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
):
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query = attn.head_to_batch_dim(query)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class IPAttnProcessor(nn.Module):
|
||||
r"""
|
||||
Attention processor for IP-Adapater.
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
The hidden size of the attention layer.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
text_context_len (`int`, defaults to 77):
|
||||
The context length of the text features.
|
||||
scale (`float`, defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.text_context_len = text_context_len
|
||||
self.scale = scale
|
||||
|
||||
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
):
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
# split hidden states
|
||||
encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :self.text_context_len, :], encoder_hidden_states[:, self.text_context_len:, :]
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query = attn.head_to_batch_dim(query)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# for ip-adapter
|
||||
ip_key = self.to_k_ip(ip_hidden_states)
|
||||
ip_value = self.to_v_ip(ip_hidden_states)
|
||||
|
||||
ip_key = attn.head_to_batch_dim(ip_key)
|
||||
ip_value = attn.head_to_batch_dim(ip_value)
|
||||
|
||||
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
||||
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
||||
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AttnProcessor2_0(torch.nn.Module):
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=None,
|
||||
cross_attention_dim=None,
|
||||
):
|
||||
super().__init__()
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
):
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class IPAttnProcessor2_0(torch.nn.Module):
|
||||
r"""
|
||||
Attention processor for IP-Adapater for PyTorch 2.0.
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
The hidden size of the attention layer.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
text_context_len (`int`, defaults to 77):
|
||||
The context length of the text features.
|
||||
scale (`float`, defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0):
|
||||
super().__init__()
|
||||
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.text_context_len = text_context_len
|
||||
self.scale = scale
|
||||
|
||||
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
):
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
# split hidden states
|
||||
encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :self.text_context_len, :], encoder_hidden_states[:, self.text_context_len:, :]
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# for ip-adapter
|
||||
ip_key = self.to_k_ip(ip_hidden_states)
|
||||
ip_value = self.to_v_ip(ip_hidden_states)
|
||||
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
|
||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
243
invokeai/backend/ip_adapter/ip_adapter.py
Normal file
243
invokeai/backend/ip_adapter/ip_adapter.py
Normal file
@ -0,0 +1,243 @@
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
|
||||
from PIL import Image
|
||||
|
||||
from .utils import is_torch2_available
|
||||
if is_torch2_available:
|
||||
from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
|
||||
else:
|
||||
from .attention_processor import IPAttnProcessor, AttnProcessor
|
||||
from .resampler import Resampler
|
||||
|
||||
|
||||
class ImageProjModel(torch.nn.Module):
|
||||
"""Projection Model"""
|
||||
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
|
||||
super().__init__()
|
||||
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.clip_extra_context_tokens = clip_extra_context_tokens
|
||||
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
|
||||
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
||||
|
||||
def forward(self, image_embeds):
|
||||
embeds = image_embeds
|
||||
clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
|
||||
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
|
||||
return clip_extra_context_tokens
|
||||
|
||||
|
||||
class IPAdapter:
|
||||
|
||||
def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
|
||||
|
||||
self.device = device
|
||||
self.image_encoder_path = image_encoder_path
|
||||
self.ip_ckpt = ip_ckpt
|
||||
self.num_tokens = num_tokens
|
||||
|
||||
self.pipe = sd_pipe.to(self.device)
|
||||
self.set_ip_adapter()
|
||||
|
||||
# load image encoder
|
||||
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(self.device, dtype=torch.float16)
|
||||
self.clip_image_processor = CLIPImageProcessor()
|
||||
# image proj model
|
||||
self.image_proj_model = self.init_proj()
|
||||
|
||||
self.load_ip_adapter()
|
||||
|
||||
def init_proj(self):
|
||||
image_proj_model = ImageProjModel(
|
||||
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
|
||||
clip_embeddings_dim=self.image_encoder.config.projection_dim,
|
||||
clip_extra_context_tokens=self.num_tokens,
|
||||
).to(self.device, dtype=torch.float16)
|
||||
return image_proj_model
|
||||
|
||||
def set_ip_adapter(self):
|
||||
unet = self.pipe.unet
|
||||
attn_procs = {}
|
||||
for name in unet.attn_processors.keys():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = unet.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
if cross_attention_dim is None:
|
||||
attn_procs[name] = AttnProcessor()
|
||||
else:
|
||||
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,
|
||||
scale=1.0).to(self.device, dtype=torch.float16)
|
||||
unet.set_attn_processor(attn_procs)
|
||||
|
||||
def load_ip_adapter(self):
|
||||
state_dict = torch.load(self.ip_ckpt, map_location="cpu")
|
||||
self.image_proj_model.load_state_dict(state_dict["image_proj"])
|
||||
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
|
||||
ip_layers.load_state_dict(state_dict["ip_adapter"])
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_image_embeds(self, pil_image):
|
||||
if isinstance(pil_image, Image.Image):
|
||||
pil_image = [pil_image]
|
||||
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
|
||||
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
|
||||
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
|
||||
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
|
||||
return image_prompt_embeds, uncond_image_prompt_embeds
|
||||
|
||||
def set_scale(self, scale):
|
||||
for attn_processor in self.pipe.unet.attn_processors.values():
|
||||
if isinstance(attn_processor, IPAttnProcessor):
|
||||
attn_processor.scale = scale
|
||||
|
||||
def generate(
|
||||
self,
|
||||
pil_image,
|
||||
prompt=None,
|
||||
negative_prompt=None,
|
||||
scale=1.0,
|
||||
num_samples=4,
|
||||
seed=-1,
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=30,
|
||||
**kwargs,
|
||||
):
|
||||
self.set_scale(scale)
|
||||
|
||||
if isinstance(pil_image, Image.Image):
|
||||
num_prompts = 1
|
||||
else:
|
||||
num_prompts = len(pil_image)
|
||||
|
||||
if prompt is None:
|
||||
prompt = "best quality, high quality"
|
||||
if negative_prompt is None:
|
||||
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||
|
||||
if not isinstance(prompt, List):
|
||||
prompt = [prompt] * num_prompts
|
||||
if not isinstance(negative_prompt, List):
|
||||
negative_prompt = [negative_prompt] * num_prompts
|
||||
|
||||
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
|
||||
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
||||
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
||||
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
||||
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
|
||||
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
||||
|
||||
with torch.inference_mode():
|
||||
prompt_embeds = self.pipe._encode_prompt(
|
||||
prompt, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
|
||||
negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
|
||||
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
guidance_scale=guidance_scale,
|
||||
num_inference_steps=num_inference_steps,
|
||||
generator=generator,
|
||||
**kwargs,
|
||||
).images
|
||||
|
||||
return images
|
||||
|
||||
|
||||
class IPAdapterXL(IPAdapter):
|
||||
"""SDXL"""
|
||||
|
||||
def generate(
|
||||
self,
|
||||
pil_image,
|
||||
prompt=None,
|
||||
negative_prompt=None,
|
||||
scale=1.0,
|
||||
num_samples=4,
|
||||
seed=-1,
|
||||
num_inference_steps=30,
|
||||
**kwargs,
|
||||
):
|
||||
self.set_scale(scale)
|
||||
|
||||
if isinstance(pil_image, Image.Image):
|
||||
num_prompts = 1
|
||||
else:
|
||||
num_prompts = len(pil_image)
|
||||
|
||||
if prompt is None:
|
||||
prompt = "best quality, high quality"
|
||||
if negative_prompt is None:
|
||||
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||
|
||||
if not isinstance(prompt, List):
|
||||
prompt = [prompt] * num_prompts
|
||||
if not isinstance(negative_prompt, List):
|
||||
negative_prompt = [negative_prompt] * num_prompts
|
||||
|
||||
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
|
||||
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
||||
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
||||
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
||||
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
|
||||
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
||||
|
||||
with torch.inference_mode():
|
||||
prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.pipe.encode_prompt(
|
||||
prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
|
||||
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
num_inference_steps=num_inference_steps,
|
||||
generator=generator,
|
||||
**kwargs,
|
||||
).images
|
||||
|
||||
return images
|
||||
|
||||
|
||||
class IPAdapterPlus(IPAdapter):
|
||||
"""IP-Adapter with fine-grained features"""
|
||||
|
||||
def init_proj(self):
|
||||
image_proj_model = Resampler(
|
||||
dim=self.pipe.unet.config.cross_attention_dim,
|
||||
depth=4,
|
||||
dim_head=64,
|
||||
heads=12,
|
||||
num_queries=self.num_tokens,
|
||||
embedding_dim=self.image_encoder.config.hidden_size,
|
||||
output_dim=self.pipe.unet.config.cross_attention_dim,
|
||||
ff_mult=4
|
||||
).to(self.device, dtype=torch.float16)
|
||||
return image_proj_model
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_image_embeds(self, pil_image):
|
||||
if isinstance(pil_image, Image.Image):
|
||||
pil_image = [pil_image]
|
||||
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
|
||||
clip_image = clip_image.to(self.device, dtype=torch.float16)
|
||||
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
|
||||
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
|
||||
uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2]
|
||||
uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
|
||||
return image_prompt_embeds, uncond_image_prompt_embeds
|
121
invokeai/backend/ip_adapter/resampler.py
Normal file
121
invokeai/backend/ip_adapter/resampler.py
Normal file
@ -0,0 +1,121 @@
|
||||
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# FFN
|
||||
def FeedForward(dim, mult=4):
|
||||
inner_dim = int(dim * mult)
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, inner_dim, bias=False),
|
||||
nn.GELU(),
|
||||
nn.Linear(inner_dim, dim, bias=False),
|
||||
)
|
||||
|
||||
|
||||
def reshape_tensor(x, heads):
|
||||
bs, length, width = x.shape
|
||||
#(bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
||||
x = x.view(bs, length, heads, -1)
|
||||
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
||||
x = x.transpose(1, 2)
|
||||
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
||||
x = x.reshape(bs, heads, length, -1)
|
||||
return x
|
||||
|
||||
|
||||
class PerceiverAttention(nn.Module):
|
||||
def __init__(self, *, dim, dim_head=64, heads=8):
|
||||
super().__init__()
|
||||
self.scale = dim_head**-0.5
|
||||
self.dim_head = dim_head
|
||||
self.heads = heads
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
||||
|
||||
|
||||
def forward(self, x, latents):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): image features
|
||||
shape (b, n1, D)
|
||||
latent (torch.Tensor): latent features
|
||||
shape (b, n2, D)
|
||||
"""
|
||||
x = self.norm1(x)
|
||||
latents = self.norm2(latents)
|
||||
|
||||
b, l, _ = latents.shape
|
||||
|
||||
q = self.to_q(latents)
|
||||
kv_input = torch.cat((x, latents), dim=-2)
|
||||
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
||||
|
||||
q = reshape_tensor(q, self.heads)
|
||||
k = reshape_tensor(k, self.heads)
|
||||
v = reshape_tensor(v, self.heads)
|
||||
|
||||
# attention
|
||||
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
||||
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
out = weight @ v
|
||||
|
||||
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class Resampler(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim=1024,
|
||||
depth=8,
|
||||
dim_head=64,
|
||||
heads=16,
|
||||
num_queries=8,
|
||||
embedding_dim=768,
|
||||
output_dim=1024,
|
||||
ff_mult=4,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
|
||||
|
||||
self.proj_in = nn.Linear(embedding_dim, dim)
|
||||
|
||||
self.proj_out = nn.Linear(dim, output_dim)
|
||||
self.norm_out = nn.LayerNorm(output_dim)
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
||||
FeedForward(dim=dim, mult=ff_mult),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
latents = self.latents.repeat(x.size(0), 1, 1)
|
||||
|
||||
x = self.proj_in(x)
|
||||
|
||||
for attn, ff in self.layers:
|
||||
latents = attn(x, latents) + latents
|
||||
latents = ff(latents) + latents
|
||||
|
||||
latents = self.proj_out(latents)
|
||||
return self.norm_out(latents)
|
368
invokeai/backend/ip_adapter/utils.py
Normal file
368
invokeai/backend/ip_adapter/utils.py
Normal file
@ -0,0 +1,368 @@
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from diffusers.utils import is_compiled_module
|
||||
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
||||
from diffusers.models import ControlNetModel
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
|
||||
|
||||
|
||||
def is_torch2_available():
|
||||
return hasattr(F, "scaled_dot_product_attention")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Union[
|
||||
torch.FloatTensor,
|
||||
PIL.Image.Image,
|
||||
np.ndarray,
|
||||
List[torch.FloatTensor],
|
||||
List[PIL.Image.Image],
|
||||
List[np.ndarray],
|
||||
] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
guess_mode: bool = False,
|
||||
control_guidance_start: Union[float, List[float]] = 0.0,
|
||||
control_guidance_end: Union[float, List[float]] = 1.0,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
|
||||
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
||||
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
|
||||
the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
|
||||
also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
|
||||
height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
|
||||
specified in init, images must be passed as a list such that each element of the list can be correctly
|
||||
batched for input to a single controlnet.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
|
||||
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
|
||||
to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
|
||||
corresponding scale as a list.
|
||||
guess_mode (`bool`, *optional*, defaults to `False`):
|
||||
In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
|
||||
you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
|
||||
The percentage of total steps at which the controlnet starts applying.
|
||||
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
||||
The percentage of total steps at which the controlnet stops applying.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
||||
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
||||
|
||||
# align format for control guidance
|
||||
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
||||
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
||||
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
||||
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
||||
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
||||
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
|
||||
control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
|
||||
control_guidance_end
|
||||
]
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
image,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
controlnet_conditioning_scale,
|
||||
control_guidance_start,
|
||||
control_guidance_end,
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
|
||||
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
|
||||
|
||||
global_pool_conditions = (
|
||||
controlnet.config.global_pool_conditions
|
||||
if isinstance(controlnet, ControlNetModel)
|
||||
else controlnet.nets[0].config.global_pool_conditions
|
||||
)
|
||||
guess_mode = guess_mode or global_pool_conditions
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
|
||||
# 4. Prepare image
|
||||
if isinstance(controlnet, ControlNetModel):
|
||||
image = self.prepare_image(
|
||||
image=image,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
height, width = image.shape[-2:]
|
||||
elif isinstance(controlnet, MultiControlNetModel):
|
||||
images = []
|
||||
|
||||
for image_ in image:
|
||||
image_ = self.prepare_image(
|
||||
image=image_,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
|
||||
images.append(image_)
|
||||
|
||||
image = images
|
||||
height, width = image[0].shape[-2:]
|
||||
else:
|
||||
assert False
|
||||
|
||||
# 5. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 6. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7.1 Create tensor stating which controlnets to keep
|
||||
controlnet_keep = []
|
||||
for i in range(len(timesteps)):
|
||||
keeps = [
|
||||
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
||||
for s, e in zip(control_guidance_start, control_guidance_end)
|
||||
]
|
||||
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# controlnet(s) inference
|
||||
if guess_mode and do_classifier_free_guidance:
|
||||
# Infer ControlNet only for the conditional batch.
|
||||
control_model_input = latents
|
||||
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
||||
controlnet_prompt_embeds = prompt_embeds[:, :77, :].chunk(2)[1]
|
||||
else:
|
||||
control_model_input = latent_model_input
|
||||
controlnet_prompt_embeds = prompt_embeds[:, :77, :]
|
||||
|
||||
if isinstance(controlnet_keep[i], list):
|
||||
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
||||
else:
|
||||
controlnet_cond_scale = controlnet_conditioning_scale
|
||||
if isinstance(controlnet_cond_scale, list):
|
||||
controlnet_cond_scale = controlnet_cond_scale[0]
|
||||
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
||||
|
||||
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
||||
control_model_input,
|
||||
t,
|
||||
encoder_hidden_states=controlnet_prompt_embeds,
|
||||
controlnet_cond=image,
|
||||
conditioning_scale=cond_scale,
|
||||
guess_mode=guess_mode,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
if guess_mode and do_classifier_free_guidance:
|
||||
# Infered ControlNet only for the conditional batch.
|
||||
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
||||
# add 0 to the unconditional batch to keep it unchanged.
|
||||
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
|
||||
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
down_block_additional_residuals=down_block_res_samples,
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# If we do sequential model offloading, let's offload unet and controlnet
|
||||
# manually for max memory savings
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.unet.to("cpu")
|
||||
self.controlnet.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
else:
|
||||
image = latents
|
||||
has_nsfw_concept = None
|
||||
|
||||
if has_nsfw_concept is None:
|
||||
do_denormalize = [True] * image.shape[0]
|
||||
else:
|
||||
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
Loading…
Reference in New Issue
Block a user