mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Clean up IP-Adapter in diffusers_pipeline.py - WIP
This commit is contained in:
parent
cdbf40c9b2
commit
23fdf0156f
@ -465,12 +465,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
# Should reimplement to use existing model management context etc.
|
# Should reimplement to use existing model management context etc.
|
||||||
#
|
#
|
||||||
if "sdxl" in ip_adapter_data.ip_adapter_model:
|
if "sdxl" in ip_adapter_data.ip_adapter_model:
|
||||||
print("using IPAdapterXL")
|
|
||||||
ip_adapter = IPAdapterXL(
|
ip_adapter = IPAdapterXL(
|
||||||
self, ip_adapter_data.image_encoder_model, ip_adapter_data.ip_adapter_model, self.unet.device
|
self, ip_adapter_data.image_encoder_model, ip_adapter_data.ip_adapter_model, self.unet.device
|
||||||
)
|
)
|
||||||
elif "plus" in ip_adapter_data.ip_adapter_model:
|
elif "plus" in ip_adapter_data.ip_adapter_model:
|
||||||
print("using IPAdapterPlus")
|
|
||||||
ip_adapter = IPAdapterPlus(
|
ip_adapter = IPAdapterPlus(
|
||||||
self, # IPAdapterPlus first arg is StableDiffusionPipeline
|
self, # IPAdapterPlus first arg is StableDiffusionPipeline
|
||||||
ip_adapter_data.image_encoder_model,
|
ip_adapter_data.image_encoder_model,
|
||||||
@ -479,33 +477,25 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
num_tokens=16,
|
num_tokens=16,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print("using IPAdapter")
|
|
||||||
ip_adapter = IPAdapter(
|
ip_adapter = IPAdapter(
|
||||||
self, # IPAdapter first arg is StableDiffusionPipeline
|
self, # IPAdapter first arg is StableDiffusionPipeline
|
||||||
ip_adapter_data.image_encoder_model,
|
ip_adapter_data.image_encoder_model,
|
||||||
ip_adapter_data.ip_adapter_model,
|
ip_adapter_data.ip_adapter_model,
|
||||||
self.unet.device,
|
self.unet.device,
|
||||||
)
|
)
|
||||||
# IP-Adapter ==> add additional cross-attention layers to UNet model here?
|
|
||||||
ip_adapter.set_scale(ip_adapter_data.weight)
|
ip_adapter.set_scale(ip_adapter_data.weight)
|
||||||
print("ip_adapter:", ip_adapter)
|
|
||||||
|
|
||||||
# get image embedding from CLIP and ImageProjModel
|
# Get image embeddings from CLIP and ImageProjModel.
|
||||||
print("getting image embeddings from IP-Adapter...")
|
|
||||||
num_samples = 1 # hardwiring for first pass
|
|
||||||
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_data.image)
|
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_data.image)
|
||||||
print("image cond embeds shape:", image_prompt_embeds.shape)
|
|
||||||
print("image uncond embeds shape:", uncond_image_prompt_embeds.shape)
|
|
||||||
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
|
||||||
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
|
||||||
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
|
||||||
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
|
|
||||||
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
|
||||||
print("image cond embeds shape:", image_prompt_embeds.shape)
|
|
||||||
print("image uncond embeds shape:", uncond_image_prompt_embeds.shape)
|
|
||||||
|
|
||||||
# IP-Adapter: run IP-Adapter model here?
|
# The following commented block is kept for reference on how to repeat/reshape the image embeddings to
|
||||||
# and add output as additional cross-attention layers
|
# generate a batch of multiple images:
|
||||||
|
# bs_embed, seq_len, _ = image_prompt_embeds.shape
|
||||||
|
# image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
||||||
|
# image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
||||||
|
# uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
|
||||||
|
# uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
||||||
|
|
||||||
text_prompt_embeds = conditioning_data.text_embeddings.embeds
|
text_prompt_embeds = conditioning_data.text_embeddings.embeds
|
||||||
uncond_text_prompt_embeds = conditioning_data.unconditioned_embeddings.embeds
|
uncond_text_prompt_embeds = conditioning_data.unconditioned_embeddings.embeds
|
||||||
print("text embeds shape:", text_prompt_embeds.shape)
|
print("text embeds shape:", text_prompt_embeds.shape)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user