Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Create a UNetAttentionPatcher for patching UNet models with CustomAttnProcessor2_0 modules.
@@ -1,8 +1,8 @@
 import pytest
 import torch

-from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
 from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
+from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
 from invokeai.backend.util.test_utils import install_and_load_model


@@ -77,7 +77,7 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
     ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device)

     cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]}
-    ip_adapter_unet_patcher = UNetPatcher([ip_adapter])
+    ip_adapter_unet_patcher = UNetAttentionPatcher([ip_adapter])
     with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet):
         output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample
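For orientation, the snippet below sketches the usage pattern this test exercises after the rename. It only restates what the diff shows: a UNetAttentionPatcher is built from a list of IP-Adapters and temporarily patches the UNet's attention layers (with CustomAttnProcessor2_0 modules, per the commit message) via the apply_ip_adapter_attention context manager, while the image prompt embeddings are passed through cross_attention_kwargs. The unet, ip_adapter, dummy_unet_input, and torch_device arguments are placeholders standing in for the test's fixtures; this is a sketch under those assumptions, not a verbatim excerpt of the test.

import torch

from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher


def run_patched_unet(unet, ip_adapter, dummy_unet_input, torch_device):
    # Random image-prompt embeddings of the shape the test uses; in real use
    # these would come from the IP-Adapter's image encoder.
    ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device)
    cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]}

    # The patcher swaps in the patched attention processors for the duration
    # of the context manager block (assumption: unet is the diffusers UNet and
    # dummy_unet_input holds its usual forward() arguments, as in the test).
    ip_adapter_unet_patcher = UNetAttentionPatcher([ip_adapter])
    with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet):
        output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample
    return output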