Clean up IP-Adapter in diffusers_pipeline.py - WIP

This commit is contained in:
Ryan Dick 2023-09-06 20:42:20 -04:00
parent cdbf40c9b2
commit 23fdf0156f


@@ -465,12 +465,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# Should reimplement to use existing model management context etc.
#
if "sdxl" in ip_adapter_data.ip_adapter_model:
print("using IPAdapterXL")
ip_adapter = IPAdapterXL(
self, ip_adapter_data.image_encoder_model, ip_adapter_data.ip_adapter_model, self.unet.device
)
elif "plus" in ip_adapter_data.ip_adapter_model:
print("using IPAdapterPlus")
ip_adapter = IPAdapterPlus(
self, # IPAdapterPlus first arg is StableDiffusionPipeline
ip_adapter_data.image_encoder_model,
@@ -479,33 +477,25 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        num_tokens=16,
    )
else:
    print("using IPAdapter")
    ip_adapter = IPAdapter(
        self, # IPAdapter first arg is StableDiffusionPipeline
        ip_adapter_data.image_encoder_model,
        ip_adapter_data.ip_adapter_model,
        self.unet.device,
    )
# IP-Adapter ==> add additional cross-attention layers to UNet model here?
ip_adapter.set_scale(ip_adapter_data.weight)
print("ip_adapter:", ip_adapter)
# get image embedding from CLIP and ImageProjModel
print("getting image embeddings from IP-Adapter...")
num_samples = 1 # hardwiring for first pass
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_data.image)
print("image cond embeds shape:", image_prompt_embeds.shape)
print("image uncond embeds shape:", uncond_image_prompt_embeds.shape)
bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
print("image cond embeds shape:", image_prompt_embeds.shape)
print("image uncond embeds shape:", uncond_image_prompt_embeds.shape)
# IP-Adapter: run IP-Adapter model here?
# and add output as additional cross-attention layers
# The following commented block is kept for reference on how to repeat/reshape the image embeddings to
# generate a batch of multiple images:
# bs_embed, seq_len, _ = image_prompt_embeds.shape
# image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
# image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
# uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
# uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
text_prompt_embeds = conditioning_data.text_embeddings.embeds
uncond_text_prompt_embeds = conditioning_data.unconditioned_embeddings.embeds
print("text embeds shape:", text_prompt_embeds.shape)