Create a UNetAttentionPatcher for patching UNet models with CustomAttnProcessor2_0 modules.

Ryan Dick
2024-03-08 14:15:16 -05:00
committed by Kent Keirsey
parent 31c456c1e6
commit 7ca677578e
4 changed files with 25 additions and 205 deletions


@@ -1,8 +1,8 @@
 import pytest
 import torch
-from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
 from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
+from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
 from invokeai.backend.util.test_utils import install_and_load_model
@@ -77,7 +77,7 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
     ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device)
     cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]}
-    ip_adapter_unet_patcher = UNetPatcher([ip_adapter])
+    ip_adapter_unet_patcher = UNetAttentionPatcher([ip_adapter])
     with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet):
         output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample
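
For context, a minimal sketch of the new usage pattern, mirroring the test change above. The names torch_device, unet, ip_adapter, and dummy_unet_input are assumed to be set up elsewhere in the test (e.g. via install_and_load_model); the embedding shape is taken from the diff, and the CustomAttnProcessor2_0 internals are not shown here.

    import torch

    from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher

    # Assumed setup (not shown): `unet`, `ip_adapter`, `torch_device`, and
    # `dummy_unet_input` are loaded/constructed as in the test above.
    ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device)
    cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]}

    # UNetAttentionPatcher replaces the former UNetPatcher. While the context
    # manager is active, the UNet's attention processors are swapped for
    # CustomAttnProcessor2_0 modules that consume the IP-Adapter image-prompt
    # embeddings passed via cross_attention_kwargs.
    patcher = UNetAttentionPatcher([ip_adapter])
    with patcher.apply_ip_adapter_attention(unet):
        output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample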