import pytest
import torch

from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
from invokeai.backend.util.test_utils import install_and_load_model


def build_dummy_sd15_unet_input(torch_device):
    batch_size = 1
    num_channels = 4
    sizes = (32, 32)

    noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device)
    time_step = torch.tensor([10]).to(torch_device)
    encoder_hidden_states = torch.randn((batch_size, 77, 768)).to(torch_device)

    return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}


@pytest.mark.parametrize(
    "model_params",
    [
        # SD1.5, IPAdapter
        {
            "ip_adapter_model_id": "InvokeAI/ip_adapter_sd15",
            "ip_adapter_model_name": "ip_adapter_sd15",
            "base_model": BaseModelType.StableDiffusion1,
            "unet_model_id": "runwayml/stable-diffusion-v1-5",
            "unet_model_name": "stable-diffusion-v1-5",
        },
        # SD1.5, IPAdapterPlus
        {
            "ip_adapter_model_id": "InvokeAI/ip_adapter_plus_sd15",
            "ip_adapter_model_name": "ip_adapter_plus_sd15",
            "base_model": BaseModelType.StableDiffusion1,
            "unet_model_id": "runwayml/stable-diffusion-v1-5",
            "unet_model_name": "stable-diffusion-v1-5",
        },
        # SD1.5, IPAdapterFull
        {
            "ip_adapter_model_id": "InvokeAI/ip-adapter-full-face_sd15",
            "ip_adapter_model_name": "ip-adapter-full-face_sd15",
            "base_model": BaseModelType.StableDiffusion1,
            "unet_model_id": "runwayml/stable-diffusion-v1-5",
            "unet_model_name": "stable-diffusion-v1-5",
        },
    ],
)
@pytest.mark.slow
def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
    """Smoke test that IP-Adapter weights can be loaded and used to patch a UNet."""
    ip_adapter_info = install_and_load_model(
        model_installer=model_installer,
        model_path_id_or_url=model_params["ip_adapter_model_id"],
        model_name=model_params["ip_adapter_model_name"],
        base_model=model_params["base_model"],
        model_type=ModelType.IPAdapter,
    )

    unet_info = install_and_load_model(
        model_installer=model_installer,
        model_path_id_or_url=model_params["unet_model_id"],
        model_name=model_params["unet_model_name"],
        base_model=model_params["base_model"],
        model_type=ModelType.Main,
        submodel_type=SubModelType.UNet,
    )

    dummy_unet_input = build_dummy_sd15_unet_input(torch_device)

    with torch.no_grad(), ip_adapter_info as ip_adapter, unet_info as unet:
        ip_adapter.to(torch_device, dtype=torch.float32)
        unet.to(torch_device, dtype=torch.float32)

        # ip_embeds shape: (batch_size, num_ip_images, seq_len, ip_image_embedding_len)
        ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device)

        cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]}
        ip_adapter_unet_patcher = UNetPatcher([ip_adapter])
        with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet):
            output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample

    assert output.shape == dummy_unet_input["sample"].shape