Specify the image_embedding_len in the IPAttnProcessor rather than the text embedding length. This enables the IPAttnProcessor to handle text embeddings of varying lengths.

This commit is contained in:
Ryan Dick
2023-09-07 18:20:21 -04:00
parent 7703bf2ca1
commit c2d43f007b
2 changed files with 19 additions and 14 deletions

View File

@ -91,7 +91,10 @@ class IPAdapter:
else:
print("swapping in IPAttnProcessor for", name)
attn_procs[name] = IPAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
hidden_size=hidden_size,
image_embedding_len=self.num_tokens,
cross_attention_dim=cross_attention_dim,
scale=1.0,
).to(self.device, dtype=torch.float16)
unet.set_attn_processor(attn_procs)
print("Modified UNet Attn Processors count:", len(unet.attn_processors))