Update IP-Adapter unit test for multi-image.

This commit is contained in:
Ryan Dick 2023-10-14 13:00:52 -04:00
parent 8464450a53
commit 49279bbe74

View File

@ -65,7 +65,10 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
ip_adapter.to(torch_device, dtype=torch.float32)
unet.to(torch_device, dtype=torch.float32)
cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [torch.randn((1, 4, 768)).to(torch_device)]}
# ip_embeds shape: (batch_size, num_ip_images, seq_len, ip_image_embedding_len)
ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device)
cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]}
ip_adapter_unet_patcher = UNetPatcher([ip_adapter])
with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet):
output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample