Modifying code from https://github.com/tencent-ailab/IP-Adapter and adding a license notice at the top of each copied file.

user1 2023-08-29 06:29:05 -07:00
parent 1ad98ce999
commit 8c1390166f
4 changed files with 94 additions and 68 deletions

View File

@@ -1,3 +1,7 @@
+# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
+# and modified as needed
+# tencent-ailab comment:
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
@@ -74,8 +78,8 @@ class AttnProcessor(nn.Module):
        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapater.
@@ -134,7 +138,7 @@ class IPAttnProcessor(nn.Module):
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # split hidden states
        encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :self.text_context_len, :], encoder_hidden_states[:, self.text_context_len:, :]
@@ -148,18 +152,18 @@ class IPAttnProcessor(nn.Module):
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # for ip-adapter
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = attn.head_to_batch_dim(ip_key)
        ip_value = attn.head_to_batch_dim(ip_value)

        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
@@ -176,8 +180,8 @@ class IPAttnProcessor(nn.Module):
        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class AttnProcessor2_0(torch.nn.Module):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
@@ -264,8 +268,8 @@ class AttnProcessor2_0(torch.nn.Module):
        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapater for PyTorch 2.0.
@@ -355,11 +359,11 @@ class IPAttnProcessor2_0(torch.nn.Module):
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # for ip-adapter
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
@@ -368,10 +372,10 @@ class IPAttnProcessor2_0(torch.nn.Module):
        ip_hidden_states = F.scaled_dot_product_attention(
            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        ip_hidden_states = ip_hidden_states.to(query.dtype)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
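
The code above is the heart of IP-Adapter's decoupled cross-attention: the text tokens and the image-prompt tokens are attended to separately with the same queries, and the image branch is mixed back in with a configurable scale. A minimal, standalone sketch of that computation (illustrative shapes and names only, not InvokeAI's API):

    import torch
    import torch.nn.functional as F

    def decoupled_cross_attention(query, text_key, text_value, ip_key, ip_value, scale=1.0):
        # query: (batch, heads, q_len, head_dim); keys/values: (batch, heads, kv_len, head_dim)
        text_out = F.scaled_dot_product_attention(query, text_key, text_value)  # attend to text tokens
        ip_out = F.scaled_dot_product_attention(query, ip_key, ip_value)        # same queries, image-prompt tokens
        return text_out + scale * ip_out                                        # scale controls the image prompt's influence

    q = torch.randn(1, 8, 77, 64)
    t_k, t_v = torch.randn(1, 8, 77, 64), torch.randn(1, 8, 77, 64)  # text context (77 tokens)
    i_k, i_v = torch.randn(1, 8, 4, 64), torch.randn(1, 8, 4, 64)    # image context (4 extra tokens)
    out = decoupled_cross_attention(q, t_k, t_v, i_k, i_v, scale=1.0)  # shape (1, 8, 77, 64)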

View File

@@ -1,3 +1,6 @@
+# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
+# and modified as needed
import os
from typing import List
@@ -6,11 +9,14 @@ from diffusers import StableDiffusionPipeline
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from PIL import Image

-from .utils import is_torch2_available
-if is_torch2_available:
-    from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
-else:
-    from .attention_processor import IPAttnProcessor, AttnProcessor
+# FIXME: Getting errors when trying to use PyTorch 2.0 versions of IPAttnProcessor and AttnProcessor
+# so for now falling back to the default versions
+# from .utils import is_torch2_available
+# if is_torch2_available:
+#     from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
+# else:
+#     from .attention_processor import IPAttnProcessor, AttnProcessor
+from .attention_processor import IPAttnProcessor, AttnProcessor
from .resampler import Resampler
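
For reference, the is_torch2_available flag that the commented-out import relied on is normally just a feature test on torch.nn.functional; a minimal sketch of such a guard (an assumption about utils.py, which is not part of this diff):

    import torch.nn.functional as F

    # True when PyTorch exposes fused scaled-dot-product attention, i.e. PyTorch >= 2.0
    is_torch2_available = hasattr(F, "scaled_dot_product_attention")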
@@ -18,12 +24,12 @@ class ImageProjModel(torch.nn.Module):
    """Projection Model"""
    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
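
ImageProjModel maps one pooled CLIP image embedding to a short sequence of extra context tokens: a (batch, clip_embeddings_dim) input is projected to clip_extra_context_tokens * cross_attention_dim values, reshaped to (batch, clip_extra_context_tokens, cross_attention_dim), and layer-normalized. A quick shape check with illustrative sizes (an SD 1.x-style cross_attention_dim is assumed):

    import torch

    proj = ImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4)
    image_embeds = torch.randn(2, 1024)   # pooled CLIP image_embeds for a batch of 2
    tokens = proj(image_embeds)
    print(tokens.shape)                   # torch.Size([2, 4, 768])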
@@ -32,25 +38,29 @@ class ImageProjModel(torch.nn.Module):
class IPAdapter:
    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.ip_ckpt = ip_ckpt
        self.num_tokens = num_tokens

-        self.pipe = sd_pipe.to(self.device)
+        # FIXME:
+        # InvokeAI StableDiffusionPipeline has a to() method that isn't meant to be used
+        # so for now assuming that pipeline is already on the correct device
+        # self.pipe = sd_pipe.to(self.device)
+        self.pipe = sd_pipe
        self.set_ip_adapter()

        # load image encoder
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(self.device, dtype=torch.float16)
        self.clip_image_processor = CLIPImageProcessor()
        # image proj model
        self.image_proj_model = self.init_proj()

        self.load_ip_adapter()

    def init_proj(self):
        image_proj_model = ImageProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
@@ -58,10 +68,12 @@ class IPAdapter:
            clip_extra_context_tokens=self.num_tokens,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    def set_ip_adapter(self):
        unet = self.pipe.unet
        attn_procs = {}
+        print("Original UNet Attn Processors count:", len(unet.attn_processors))
+        print(unet.attn_processors.keys())
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
@@ -75,16 +87,19 @@ class IPAdapter:
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
            else:
+                print("swapping in IPAttnProcessor for", name)
                attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,
                                                   scale=1.0).to(self.device, dtype=torch.float16)
        unet.set_attn_processor(attn_procs)
+        print("Modified UNet Attn Processors count:", len(unet.attn_processors))
+        print(unet.attn_processors.keys())

    def load_ip_adapter(self):
        state_dict = torch.load(self.ip_ckpt, map_location="cpu")
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        ip_layers.load_state_dict(state_dict["ip_adapter"])

    @torch.inference_mode()
    def get_image_embeds(self, pil_image):
        if isinstance(pil_image, Image.Image):
@@ -94,12 +109,14 @@ class IPAdapter:
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
        return image_prompt_embeds, uncond_image_prompt_embeds

    def set_scale(self, scale):
        for attn_processor in self.pipe.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = scale

+    # IPAdapter.generate() method is not used for InvokeAI
+    # left here for reference
    def generate(
        self,
        pil_image,
@@ -113,22 +130,22 @@ class IPAdapter:
        **kwargs,
    ):
        self.set_scale(scale)

        if isinstance(pil_image, Image.Image):
            num_prompts = 1
        else:
            num_prompts = len(pil_image)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
@@ -142,7 +159,7 @@ class IPAdapter:
            negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        images = self.pipe(
            prompt_embeds=prompt_embeds,
@@ -152,13 +169,13 @@ class IPAdapter:
            generator=generator,
            **kwargs,
        ).images

        return images


class IPAdapterXL(IPAdapter):
    """SDXL"""

    def generate(
        self,
        pil_image,
@@ -171,22 +188,22 @@ class IPAdapterXL(IPAdapter):
        **kwargs,
    ):
        self.set_scale(scale)

        if isinstance(pil_image, Image.Image):
            num_prompts = 1
        else:
            num_prompts = len(pil_image)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
@@ -199,7 +216,7 @@ class IPAdapterXL(IPAdapter):
                prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)

            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        images = self.pipe(
            prompt_embeds=prompt_embeds,
@@ -210,10 +227,10 @@ class IPAdapterXL(IPAdapter):
            generator=generator,
            **kwargs,
        ).images

        return images


class IPAdapterPlus(IPAdapter):
    """IP-Adapter with fine-grained features"""
@@ -229,7 +246,7 @@ class IPAdapterPlus(IPAdapter):
            ff_mult=4
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    @torch.inference_mode()
    def get_image_embeds(self, pil_image):
        if isinstance(pil_image, Image.Image):
@@ -240,4 +257,4 @@ class IPAdapterPlus(IPAdapter):
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2]
        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
        return image_prompt_embeds, uncond_image_prompt_embeds
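
Taken together, the wrapper is used in three steps: build it around a pipeline that is already on the target device (per the FIXME above), pull image-prompt embeddings for a PIL image, and tune the adapter strength with set_scale(). A hedged usage sketch; the model id and checkpoint paths below are placeholders, not anything shipped by this commit:

    import torch
    from PIL import Image
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
    adapter = IPAdapter(
        sd_pipe=pipe,                                # assumed to already be on the right device
        image_encoder_path="path/to/image_encoder",  # CLIPVisionModelWithProjection weights
        ip_ckpt="path/to/ip-adapter.bin",            # state dict with "image_proj" and "ip_adapter" keys
        device="cuda",
    )
    image_prompt_embeds, uncond_image_prompt_embeds = adapter.get_image_embeds(Image.open("reference.png"))
    adapter.set_scale(0.6)  # 0.0 ignores the image prompt; 1.0 applies it at full strength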

View File

@@ -1,4 +1,6 @@
-# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
+# tencent ailab comment: modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
import math

import torch
@@ -14,8 +16,8 @@ def FeedForward(dim, mult=4):
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


def reshape_tensor(x, heads):
    bs, length, width = x.shape
    #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
@@ -53,13 +55,13 @@ class PerceiverAttention(nn.Module):
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)
@@ -69,7 +71,7 @@ class PerceiverAttention(nn.Module):
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)
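
The comment about f16 stability refers to folding the softmax scaling into both operands: multiplying q and k each by dim_head ** -0.25 before the matmul gives the same result as dividing q @ k.T by sqrt(dim_head) afterwards, but keeps the intermediate products smaller, which matters in float16. A quick numeric check of the equivalence (the exact scale constant is an assumption; it is defined outside this hunk):

    import torch

    q = torch.randn(2, 8, 16, 64)
    k = torch.randn(2, 8, 16, 64)
    dim_head = q.shape[-1]

    scale = dim_head ** -0.25
    folded = (q * scale) @ (k * scale).transpose(-2, -1)
    reference = (q @ k.transpose(-2, -1)) / dim_head ** 0.5
    print(torch.allclose(folded, reference, atol=1e-5))  # True, up to float rounding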
@@ -88,14 +90,14 @@ class Resampler(nn.Module):
        ff_mult=4,
    ):
        super().__init__()

        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
@@ -108,14 +110,14 @@ class Resampler(nn.Module):
            )

    def forward(self, x):
        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)
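
Resampler is the image-projection model used by IPAdapterPlus: a fixed set of learned latent queries repeatedly cross-attends (PerceiverAttention) to the CLIP penultimate hidden states and is then projected to the UNet's cross-attention width. A shape-level sketch with illustrative sizes; the keyword names follow the upstream IP-Adapter signature and the concrete values are not taken from this commit:

    import torch

    resampler = Resampler(dim=1280, depth=4, dim_head=64, heads=16, num_queries=16,
                          embedding_dim=1280, output_dim=768, ff_mult=4)
    clip_hidden_states = torch.randn(2, 257, 1280)   # penultimate hidden_states[-2] from the image encoder
    tokens = resampler(clip_hidden_states)
    print(tokens.shape)                              # torch.Size([2, 16, 768])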

View File

@@ -1,3 +1,6 @@
+# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
+# and modified as needed
import inspect
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -365,4 +368,4 @@ def generate(
    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)