InvokeAI/tests/backend/ip_adapter/test_ip_adapter.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

77 lines
3.0 KiB
Python
Raw Normal View History

2023-09-22 22:44:10 +00:00
import pytest
import torch
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
2023-09-22 22:44:10 +00:00
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
from invokeai.backend.util.test_utils import install_and_load_model
2023-09-22 22:44:10 +00:00
def build_dummy_sd15_unet_input(torch_device):
    """Build random keyword arguments for a single UNet forward pass.

    Args:
        torch_device: Device (anything accepted by ``Tensor.to``) that the
            returned tensors are moved to.

    Returns:
        A dict with three tensors on ``torch_device``:
            - ``"sample"``: random noise of shape ``(1, 4, 32, 32)``.
            - ``"timestep"``: the single integer timestep ``[10]``.
            - ``"encoder_hidden_states"``: random values of shape
              ``(1, 77, 768)``.
    """
    batch_size = 1
    num_channels = 4
    sizes = (32, 32)

    # Tensors are created with default settings and then moved to the target device.
    noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device)
    time_step = torch.tensor([10]).to(torch_device)
    # NOTE(review): (77, 768) presumably matches the SD1.5 text-encoder output
    # (sequence length, hidden size) - confirm against the model config.
    encoder_hidden_states = torch.randn((batch_size, 77, 768)).to(torch_device)

    return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
@pytest.mark.parametrize(
    "model_params",
    [
        # SD1.5, IPAdapter
        {
            "ip_adapter_model_id": "InvokeAI/ip_adapter_sd15",
            "ip_adapter_model_name": "ip_adapter_sd15",
            "base_model": BaseModelType.StableDiffusion1,
            "unet_model_id": "runwayml/stable-diffusion-v1-5",
            "unet_model_name": "stable-diffusion-v1-5",
        },
        # SD1.5, IPAdapterPlus
        {
            "ip_adapter_model_id": "InvokeAI/ip_adapter_plus_sd15",
            "ip_adapter_model_name": "ip_adapter_plus_sd15",
            "base_model": BaseModelType.StableDiffusion1,
            "unet_model_id": "runwayml/stable-diffusion-v1-5",
            "unet_model_name": "stable-diffusion-v1-5",
        },
    ],
)
@pytest.mark.slow
def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
    """Smoke test that IP-Adapter weights can be loaded and used to patch a UNet.

    For each parametrized model combination the test installs and loads the
    IP-Adapter and the UNet submodel of the main model, patches the UNet with
    the IP-Adapter, runs a single forward pass on dummy inputs, and checks
    that the output has the same shape as the input sample.

    Args:
        model_params: Dict with the IP-Adapter / UNet model ids and names and
            the base model type (supplied by ``pytest.mark.parametrize``).
        model_installer: Fixture passed through to ``install_and_load_model``.
        torch_device: Fixture giving the device to run the models on.
    """
    ip_adapter_info = install_and_load_model(
        model_installer=model_installer,
        model_path_id_or_url=model_params["ip_adapter_model_id"],
        model_name=model_params["ip_adapter_model_name"],
        base_model=model_params["base_model"],
        model_type=ModelType.IPAdapter,
    )

    # Only the UNet submodel of the main model is needed for this test.
    unet_info = install_and_load_model(
        model_installer=model_installer,
        model_path_id_or_url=model_params["unet_model_id"],
        model_name=model_params["unet_model_name"],
        base_model=model_params["base_model"],
        model_type=ModelType.Main,
        submodel_type=SubModelType.UNet,
    )

    dummy_unet_input = build_dummy_sd15_unet_input(torch_device)

    with torch.no_grad(), ip_adapter_info as ip_adapter, unet_info as unet:
        # Put both models on the same device with the same dtype (float32).
        ip_adapter.to(torch_device, dtype=torch.float32)
        unet.to(torch_device, dtype=torch.float32)

        # ip_embeds shape: (batch_size, num_ip_images, seq_len, ip_image_embedding_len)
        ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device)

        cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]}
        ip_adapter_unet_patcher = UNetPatcher([ip_adapter])
        # The forward pass runs while the IP-Adapter attention patch is applied.
        with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet):
            output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample

    assert output.shape == dummy_unet_input["sample"].shape