cleanup: merge conflicts

blessedcoolant 2023-09-05 11:37:12 +12:00
parent 6bb378a101
commit 07381e5a26
8 changed files with 86 additions and 69 deletions

View File

@@ -29,6 +29,7 @@ CONTROLNET_RESIZE_VALUES = Literal[
     "just_resize_simple",
 ]
 
 class ControlNetModelField(BaseModel):
     """ControlNet model field"""
@@ -68,6 +69,7 @@ class ControlField(BaseModel):
                     raise ValueError("Control weights must be within -1 to 2 range")
         return v
 
 @invocation_output("control_output")
 class ControlOutput(BaseInvocationOutput):
     """node output for ControlNet info"""
@@ -78,7 +80,6 @@ class ControlOutput(BaseInvocationOutput):
     control: ControlField = OutputField(description=FieldDescriptions.control)
 
 @invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet")
 class ControlNetInvocation(BaseInvocation):
     """Collects ControlNet info to pass to other nodes"""
@@ -119,19 +120,21 @@ class ControlNetInvocation(BaseInvocation):
             ),
         )
 
 IP_ADAPTER_MODELS = Literal[
     "models_ip_adapter/models/ip-adapter_sd15.bin",
     "models_ip_adapter/models/ip-adapter-plus_sd15.bin",
     "models_ip_adapter/models/ip-adapter-plus-face_sd15.bin",
-    "models_ip_adapter/sdxl_models/ip-adapter_sdxl.bin"
+    "models_ip_adapter/sdxl_models/ip-adapter_sdxl.bin",
 ]
 
 IP_ADAPTER_IMAGE_ENCODER_MODELS = Literal[
     "models_ip_adapter/models/image_encoder/",
     "./models_ip_adapter/models/image_encoder/",
-    "models_ip_adapter/sdxl_models/image_encoder/"
+    "models_ip_adapter/sdxl_models/image_encoder/",
 ]
 
 @invocation("ipadapter", title="IP-Adapter", tags=["ipadapter"], category="ipadapter")
 class IPAdapterInvocation(BaseInvocation):
     """Collects IP-Adapter info to pass to other nodes"""
@@ -140,14 +143,15 @@ class IPAdapterInvocation(BaseInvocation):
     # Inputs
     image: ImageField = InputField(description="The control image")
-    #control_model: ControlNetModelField = InputField(
+    # control_model: ControlNetModelField = InputField(
     # default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
-    #)
-    ip_adapter_model: IP_ADAPTER_MODELS = InputField(default="./models_ip_adapter/models/ip-adapter_sd15.bin",
-        description="The IP-Adapter model")
+    # )
+    ip_adapter_model: IP_ADAPTER_MODELS = InputField(
+        default="./models_ip_adapter/models/ip-adapter_sd15.bin", description="The IP-Adapter model"
+    )
     image_encoder_model: IP_ADAPTER_IMAGE_ENCODER_MODELS = InputField(
-        default="./models_ip_adapter/models/image_encoder/",
-        description="The image encoder model")
+        default="./models_ip_adapter/models/image_encoder/", description="The image encoder model"
+    )
     control_weight: Union[float, List[float]] = InputField(
         default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
     )
@@ -172,9 +176,9 @@ class IPAdapterInvocation(BaseInvocation):
                 image_encoder_model=self.image_encoder_model,
                 control_weight=self.control_weight,
                 # rest are currently ignored
-                #begin_step_percent=self.begin_step_percent,
-                #end_step_percent=self.end_step_percent,
-                #control_mode=self.control_mode,
-                #resize_mode=self.resize_mode,
+                # begin_step_percent=self.begin_step_percent,
+                # end_step_percent=self.end_step_percent,
+                # control_mode=self.control_mode,
+                # resize_mode=self.resize_mode,
             ),
         )
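
Reviewer note: the Literal lists above act as the allowed choices for the IP-Adapter fields, since pydantic rejects any value outside a Literal at model-construction time. A minimal sketch of that behaviour, assuming a stripped-down KnownAdapters type and a hypothetical Example model that stand in for the real invocation class:

from typing import Literal

from pydantic import BaseModel, ValidationError

# Hypothetical stand-in for the IP_ADAPTER_MODELS Literal used by IPAdapterInvocation.
KnownAdapters = Literal[
    "models_ip_adapter/models/ip-adapter_sd15.bin",
    "models_ip_adapter/sdxl_models/ip-adapter_sdxl.bin",
]


class Example(BaseModel):
    ip_adapter_model: KnownAdapters = "models_ip_adapter/models/ip-adapter_sd15.bin"


Example(ip_adapter_model="models_ip_adapter/sdxl_models/ip-adapter_sdxl.bin")  # accepted
try:
    Example(ip_adapter_model="some/unknown/checkpoint.bin")
except ValidationError as err:
    print(err)  # rejected: value is not one of the permitted literals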

View File

@@ -1,7 +1,7 @@
 # Invocations for ControlNet image preprocessors
 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
 from builtins import bool, float
-from typing import Dict, List, Literal, Optional, Union
+from typing import Dict, List, Optional
 import cv2
 import numpy as np
@@ -27,17 +27,7 @@ from PIL import Image
 from invokeai.app.invocations.primitives import ImageField, ImageOutput
 from ..models.image import ImageCategory, ResourceOrigin
-from .baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    FieldDescriptions,
-    Input,
-    InputField,
-    InvocationContext,
-    OutputField,
-    UIType,
-    invocation,
-)
+from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation
 
 @invocation(

View File

@@ -65,7 +65,6 @@ from .control_adapter import ControlField
 from .model import ModelInfo, UNetField, VaeField
 
 DEFAULT_PRECISION = choose_precision(choose_torch_device())
 
 SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
@@ -387,7 +386,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
                     resize_mode=control_info.resize_mode,
                 )
                 control_item = ControlNetData(
                     model=control_model,  # model object
                     image_tensor=control_image,
                     weight=control_info.control_weight,
                     begin_step_percent=control_info.begin_step_percent,
@@ -404,7 +403,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 input_image = context.services.images.get_pil_image(control_image_field.image_name)
                 control_item = IPAdapterData(
                     ip_adapter_model=control_info.ip_adapter_model,  # name of model (NOT model object)
                     image_encoder_model=control_info.image_encoder_model,  # name of model (NOT model obj)
                     image=input_image,
                     weight=control_info.control_weight,
                 )
@@ -564,8 +563,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 conditioning_data=conditioning_data,
                 control_data=controlnet_data,  # list[ControlNetData],
                 ip_adapter_data=ip_adapter_data,  # list[IPAdapterData],
                 # ip_adapter_image=unwrapped_ip_adapter_image,
                 # ip_adapter_strength=self.ip_adapter_strength,
                 callback=step_callback,
             )
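
Reviewer note: the hunks above only pass model *names* and an image into the IP-Adapter data item; the actual models are loaded later in the pipeline. A simplified, hypothetical sketch of that container (the real IPAdapterData class lives in diffusers_pipeline.py and has more fields):

from dataclasses import dataclass
from typing import List

from PIL import Image


@dataclass
class IPAdapterDataSketch:
    ip_adapter_model: str       # name/path of the IP-Adapter checkpoint, NOT a model object
    image_encoder_model: str    # name/path of the CLIP image encoder, NOT a model object
    image: Image.Image
    weight: float = 1.0


ip_adapter_data: List[IPAdapterDataSketch] = []
ip_adapter_data.append(
    IPAdapterDataSketch(
        ip_adapter_model="models_ip_adapter/models/ip-adapter_sd15.bin",
        image_encoder_model="models_ip_adapter/models/image_encoder/",
        image=Image.new("RGB", (512, 512)),
        weight=1.0,
    )
)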

View File

@@ -12,6 +12,7 @@ class AttnProcessor(nn.Module):
     r"""
     Default processor for performing attention-related computations.
     """
 
     def __init__(
         self,
         hidden_size=None,
@@ -140,7 +141,10 @@ class IPAttnProcessor(nn.Module):
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         # split hidden states
-        encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :self.text_context_len, :], encoder_hidden_states[:, self.text_context_len:, :]
+        encoder_hidden_states, ip_hidden_states = (
+            encoder_hidden_states[:, : self.text_context_len, :],
+            encoder_hidden_states[:, self.text_context_len :, :],
+        )
 
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -186,6 +190,7 @@ class AttnProcessor2_0(torch.nn.Module):
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
     """
 
     def __init__(
         self,
         hidden_size=None,
@@ -338,7 +343,10 @@ class IPAttnProcessor2_0(torch.nn.Module):
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         # split hidden states
-        encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :self.text_context_len, :], encoder_hidden_states[:, self.text_context_len:, :]
+        encoder_hidden_states, ip_hidden_states = (
+            encoder_hidden_states[:, : self.text_context_len, :],
+            encoder_hidden_states[:, self.text_context_len :, :],
+        )
 
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
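
Reviewer note: the reflowed split above is the core of IP-Adapter's decoupled cross-attention: the conditioning sequence is text tokens followed by image-prompt tokens, cut apart at text_context_len. A small shape sketch (tensor sizes are illustrative only):

import torch

# 77 text tokens followed by 4 image-prompt tokens, mirroring the split in IPAttnProcessor.
text_context_len = 77
encoder_hidden_states = torch.randn(2, 77 + 4, 768)  # (batch, tokens, dim)

text_states = encoder_hidden_states[:, :text_context_len, :]  # (2, 77, 768)
ip_states = encoder_hidden_states[:, text_context_len:, :]    # (2, 4, 768)

assert text_states.shape == (2, 77, 768)
assert ip_states.shape == (2, 4, 768)
# Roughly, the processor then runs one cross-attention pass over text_states and a
# second, separately-projected pass over ip_states, scaling and adding the results.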

View File

@@ -22,6 +22,7 @@ from .resampler import Resampler
 class ImageProjModel(torch.nn.Module):
     """Projection Model"""
 
     def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
         super().__init__()
@@ -32,15 +33,15 @@ class ImageProjModel(torch.nn.Module):
     def forward(self, image_embeds):
         embeds = image_embeds
-        clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
+        clip_extra_context_tokens = self.proj(embeds).reshape(
+            -1, self.clip_extra_context_tokens, self.cross_attention_dim
+        )
         clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
         return clip_extra_context_tokens
 
 class IPAdapter:
     def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
         self.device = device
         self.image_encoder_path = image_encoder_path
         self.ip_ckpt = ip_ckpt
@@ -54,7 +55,9 @@ class IPAdapter:
         self.set_ip_adapter()
 
         # load image encoder
-        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(self.device, dtype=torch.float16)
+        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
+            self.device, dtype=torch.float16
+        )
         self.clip_image_processor = CLIPImageProcessor()
         # image proj model
         self.image_proj_model = self.init_proj()
@@ -88,8 +91,9 @@ class IPAdapter:
                 attn_procs[name] = AttnProcessor()
             else:
                 print("swapping in IPAttnProcessor for", name)
-                attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,
-                    scale=1.0).to(self.device, dtype=torch.float16)
+                attn_procs[name] = IPAttnProcessor(
+                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
+                ).to(self.device, dtype=torch.float16)
         unet.set_attn_processor(attn_procs)
         print("Modified UNet Attn Processors count:", len(unet.attn_processors))
         print(unet.attn_processors.keys())
@@ -155,7 +159,12 @@ class IPAdapter:
         with torch.inference_mode():
             prompt_embeds = self.pipe._encode_prompt(
-                prompt, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
+                prompt,
+                device=self.device,
+                num_images_per_prompt=num_samples,
+                do_classifier_free_guidance=True,
+                negative_prompt=negative_prompt,
+            )
             negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
             prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
             negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
@@ -212,8 +221,17 @@ class IPAdapterXL(IPAdapter):
         uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
 
         with torch.inference_mode():
-            prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.pipe.encode_prompt(
-                prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
+            (
+                prompt_embeds,
+                negative_prompt_embeds,
+                pooled_prompt_embeds,
+                negative_pooled_prompt_embeds,
+            ) = self.pipe.encode_prompt(
+                prompt,
+                num_images_per_prompt=num_samples,
+                do_classifier_free_guidance=True,
+                negative_prompt=negative_prompt,
+            )
             prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
             negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
@@ -243,7 +261,7 @@ class IPAdapterPlus(IPAdapter):
             num_queries=self.num_tokens,
             embedding_dim=self.image_encoder.config.hidden_size,
             output_dim=self.pipe.unet.config.cross_attention_dim,
-            ff_mult=4
+            ff_mult=4,
         ).to(self.device, dtype=torch.float16)
         return image_proj_model
@@ -255,6 +273,8 @@ class IPAdapterPlus(IPAdapter):
         clip_image = clip_image.to(self.device, dtype=torch.float16)
         clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
         image_prompt_embeds = self.image_proj_model(clip_image_embeds)
-        uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2]
+        uncond_clip_image_embeds = self.image_encoder(
+            torch.zeros_like(clip_image), output_hidden_states=True
+        ).hidden_states[-2]
         uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
         return image_prompt_embeds, uncond_image_prompt_embeds
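
Reviewer note: ImageProjModel turns a single pooled CLIP image embedding into a few "extra context tokens" that are later concatenated onto the text prompt embeddings (the torch.cat calls in the IPAdapter class above). A minimal shape sketch using the defaults shown in this file (cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4); values are random placeholders:

import torch

batch = 2
clip_embeddings_dim, cross_attention_dim, extra_tokens = 1024, 1024, 4

# Mirrors ImageProjModel: a Linear layer followed by a reshape into per-token embeddings.
proj = torch.nn.Linear(clip_embeddings_dim, extra_tokens * cross_attention_dim)
image_embeds = torch.randn(batch, clip_embeddings_dim)           # pooled CLIP image embedding

tokens = proj(image_embeds)                                      # (2, 4096)
tokens = tokens.reshape(-1, extra_tokens, cross_attention_dim)   # (2, 4, 1024)
assert tokens.shape == (batch, extra_tokens, cross_attention_dim)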

View File

@@ -20,7 +20,7 @@ def FeedForward(dim, mult=4):
 def reshape_tensor(x, heads):
     bs, length, width = x.shape
-    #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
+    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
     x = x.view(bs, length, heads, -1)
     # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
     x = x.transpose(1, 2)
@@ -44,7 +44,6 @@ class PerceiverAttention(nn.Module):
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
         self.to_out = nn.Linear(inner_dim, dim, bias=False)
 
     def forward(self, x, latents):
         """
         Args:
@@ -68,7 +67,7 @@ class PerceiverAttention(nn.Module):
         # attention
         scale = 1 / math.sqrt(math.sqrt(self.dim_head))
         weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
         out = weight @ v
@@ -110,7 +109,6 @@ class Resampler(nn.Module):
         )
 
     def forward(self, x):
         latents = self.latents.repeat(x.size(0), 1, 1)
 
         x = self.proj_in(x)
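
Reviewer note: two details in this file are worth a quick illustration: the head reshaping described by the comments in reshape_tensor, and the 1/sqrt(sqrt(d)) scaling applied to both q and k before the matmul, which keeps float16 intermediates small while giving the same result as dividing by sqrt(d) afterwards. A toy sketch with made-up sizes (q and k reuse the same tensor just to keep it short):

import math

import torch

bs, length, heads, dim_per_head = 2, 16, 8, 64
x = torch.randn(bs, length, heads * dim_per_head)

# (bs, length, width) --> (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.view(bs, length, heads, -1)
x = x.transpose(1, 2)
assert x.shape == (bs, heads, length, dim_per_head)

q = k = x
scale = 1 / math.sqrt(math.sqrt(dim_per_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1)
# Equivalent to the usual attention scaling, but computed without a large intermediate.
assert torch.allclose(weight, (q @ k.transpose(-2, -1)) / math.sqrt(dim_per_head), atol=1e-4)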

View File

@@ -15,7 +15,6 @@ from diffusers.models import ControlNetModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 
 def is_torch2_available():
     return hasattr(F, "scaled_dot_product_attention")
@@ -150,9 +149,7 @@ def generate(
         control_guidance_end = len(control_guidance_start) * [control_guidance_end]
     elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
         mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
-        control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
-            control_guidance_end
-        ]
+        control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [control_guidance_end]
 
     # 1. Check inputs. Raise error if not correct
     self.check_inputs(
@@ -192,9 +189,7 @@ def generate(
     guess_mode = guess_mode or global_pool_conditions
 
     # 3. Encode input prompt
-    text_encoder_lora_scale = (
-        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
-    )
+    text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
     prompt_embeds = self._encode_prompt(
         prompt,
         device,
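
Reviewer note: the control_guidance_start/end reflow above is pure formatting; the underlying logic broadcasts scalar guidance bounds into one entry per ControlNet so later per-net loops can index them. A hedged sketch of just that normalization step (normalize_guidance is a hypothetical helper distilled from the shown branch; the MultiControlNetModel handling is reduced to a plain count):

def normalize_guidance(control_guidance_start, control_guidance_end, num_controlnets):
    # Scalars become per-ControlNet lists; lists are passed through unchanged.
    if not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
        control_guidance_start = num_controlnets * [control_guidance_start]
        control_guidance_end = num_controlnets * [control_guidance_end]
    return control_guidance_start, control_guidance_end


print(normalize_guidance(0.0, 1.0, 3))  # ([0.0, 0.0, 0.0], [1.0, 1.0, 1.0])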

View File

@@ -179,6 +179,7 @@ class IPAdapterData:
     # weight: Union[float, List[float]] = Field(default=1.0)
     weight: float = Field(default=1.0)
 
 @dataclass
 class ConditioningData:
     unconditioned_embeddings: BasicConditioningInfo
@@ -442,7 +443,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         ip_adapter_data: List[IPAdapterData] = None,
         callback: Callable[[PipelineIntermediateState], None] = None,
     ):
         self._adjust_memory_efficient_attention(latents)
         if additional_guidance is None:
             additional_guidance = []
@@ -469,30 +469,33 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             #
             if "sdxl" in ip_adapter_info.ip_adapter_model:
                 print("using IPAdapterXL")
-                ip_adapter = IPAdapterXL(self,
-                    ip_adapter_info.image_encoder_model,
-                    ip_adapter_info.ip_adapter_model,
-                    self.unet.device)
+                ip_adapter = IPAdapterXL(
+                    self, ip_adapter_info.image_encoder_model, ip_adapter_info.ip_adapter_model, self.unet.device
+                )
             elif "plus" in ip_adapter_info.ip_adapter_model:
                 print("using IPAdapterPlus")
-                ip_adapter = IPAdapterPlus(self,  # IPAdapterPlus first arg is StableDiffusionPipeline
-                    ip_adapter_info.image_encoder_model,
-                    ip_adapter_info.ip_adapter_model,
-                    self.unet.device,
-                    num_tokens=16)
+                ip_adapter = IPAdapterPlus(
+                    self,  # IPAdapterPlus first arg is StableDiffusionPipeline
+                    ip_adapter_info.image_encoder_model,
+                    ip_adapter_info.ip_adapter_model,
+                    self.unet.device,
+                    num_tokens=16,
+                )
             else:
                 print("using IPAdapter")
-                ip_adapter = IPAdapter(self,  # IPAdapter first arg is StableDiffusionPipeline
-                    ip_adapter_info.image_encoder_model,
-                    ip_adapter_info.ip_adapter_model,
-                    self.unet.device)
+                ip_adapter = IPAdapter(
+                    self,  # IPAdapter first arg is StableDiffusionPipeline
+                    ip_adapter_info.image_encoder_model,
+                    ip_adapter_info.ip_adapter_model,
+                    self.unet.device,
+                )
             # IP-Adapter ==> add additional cross-attention layers to UNet model here?
             ip_adapter.set_scale(ip_adapter_info.weight)
             print("ip_adapter:", ip_adapter)
 
             # get image embedding from CLIP and ImageProjModel
             print("getting image embeddings from IP-Adapter...")
             num_samples = 1  # hardwiring for first pass
             image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_image)
             print("image cond embeds shape:", image_prompt_embeds.shape)
             print("image uncond embeds shape:", uncond_image_prompt_embeds.shape)