POC of a test that depends on models.

This commit is contained in:
Ryan Dick 2023-09-22 18:44:10 -04:00
parent 78377469db
commit 1c8b1fbc53
5 changed files with 152 additions and 1 deletions

View File

@ -0,0 +1,77 @@
import contextlib
from pathlib import Path
from typing import Optional, Union
import pytest
import torch
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
from invokeai.backend.install.model_install_backend import ModelInstall
from invokeai.backend.model_management.model_manager import ModelInfo
from invokeai.backend.model_management.models.base import BaseModelType, ModelNotFoundException, ModelType, SubModelType
def slow(test_case):
"""Decorator for slow tests.
Tests should be marked as slow if they download a model, run model inference, or do anything else slow. To judge
whether a test is 'slow', consider how it would perform in a CPU-only environment with a low-bandwidth internet
connection.
"""
return pytest.mark.slow(test_case)
@pytest.fixture(scope="session")
def torch_device():
return "cuda" if torch.cuda.is_available() else "cpu"
@pytest.fixture(scope="module")
def model_installer():
"""A global ModelInstall pytest fixture to be used by many tests."""
# HACK(ryand): InvokeAIAppConfig.get_config() returns a singleton config object. This can lead to weird interactions
# between tests that need to alter the config. For example, some tests change the 'root' directory in the config,
# which can cause `install_and_load_model(...)` to re-download the model unnecessarily. As a temporary workaround,
# we pass a kwarg to get_config, which causes the config to be re-loaded. To fix this properly, we should stop using
# a singleton.
return ModelInstall(InvokeAIAppConfig.get_config(log_level="info"))
def install_and_load_model(
model_installer: ModelInstall,
model_path_id_or_url: Union[str, Path],
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel_type: Optional[SubModelType] = None,
) -> ModelInfo:
"""Install a model if it is not already installed, then get the ModelInfo for that model.
This is intended as a utility function for tests.
Args:
model_installer (ModelInstall): The model installer.
model_path_id_or_url (Union[str, Path]): The path, HF ID, URL, etc. where the model can be installed from if it
is not already installed.
model_name (str): The model name, forwarded to ModelManager.get_model(...).
base_model (BaseModelType): The base model, forwarded to ModelManager.get_model(...).
model_type (ModelType): The model type, forwarded to ModelManager.get_model(...).
submodel_type (Optional[SubModelType]): The submodel type, forwarded to ModelManager.get_model(...).
Returns:
ModelInfo
"""
# If the requested model is already installed, return its ModelInfo.
with contextlib.suppress(ModelNotFoundException):
return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)
# Install the requested model.
model_installer.heuristic_import(model_path_id_or_url)
try:
return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)
except ModelNotFoundException as e:
raise Exception(
"Failed to get model info after installing it. There could be a mismatch between the requested model and"
f" the installation id ('{model_path_id_or_url}'). Error: {e}"
)

View File

@ -178,7 +178,10 @@ version = { attr = "invokeai.version.__version__" }
#=== Begin: PyTest and Coverage
[tool.pytest.ini_options]
addopts = "--cov-report term --cov-report html --cov-report xml"
addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers -m \"not slow\""
markers = [
"slow: Marks tests as slow. Disabled by default. To run all tests, use -m \"\". To run only slow tests, use -m \"slow\"."
]
[tool.coverage.run]
branch = true
source = ["invokeai"]

View File

View File

View File

@ -0,0 +1,71 @@
import pytest
import torch
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
from invokeai.backend.util.test_utils import install_and_load_model, model_installer, slow, torch_device
def build_dummy_sd15_unet_input(torch_device):
batch_size = 1
num_channels = 4
sizes = (32, 32)
noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, 77, 768)).to(torch_device)
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
@pytest.mark.parametrize(
"model_params",
[
# SD1.5, IPAdapter
{
"ip_adapter_model_id": "InvokeAI/ip_adapter_sd15",
"ip_adapter_model_name": "ip_adapter_sd15",
"base_model": BaseModelType.StableDiffusion1,
"unet_model_id": "runwayml/stable-diffusion-v1-5",
"unet_model_name": "stable-diffusion-v1-5",
},
# SD1.5, IPAdapterPlus
{
"ip_adapter_model_id": "InvokeAI/ip_adapter_plus_sd15",
"ip_adapter_model_name": "ip_adapter_plus_sd15",
"base_model": BaseModelType.StableDiffusion1,
"unet_model_id": "runwayml/stable-diffusion-v1-5",
"unet_model_name": "stable-diffusion-v1-5",
},
],
)
@slow
def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
"""Smoke test that IP-Adapter weights can be loaded and used to patch a UNet."""
ip_adapter_info = install_and_load_model(
model_installer=model_installer,
model_path_id_or_url=model_params["ip_adapter_model_id"],
model_name=model_params["ip_adapter_model_name"],
base_model=model_params["base_model"],
model_type=ModelType.IPAdapter,
)
unet_info = install_and_load_model(
model_installer=model_installer,
model_path_id_or_url=model_params["unet_model_id"],
model_name=model_params["unet_model_name"],
base_model=model_params["base_model"],
model_type=ModelType.Main,
submodel_type=SubModelType.UNet,
)
dummy_unet_input = build_dummy_sd15_unet_input(torch_device)
with torch.no_grad(), ip_adapter_info as ip_adapter, unet_info as unet:
ip_adapter.to(torch_device, dtype=torch.float32)
unet.to(torch_device, dtype=torch.float32)
cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": torch.randn((1, 4, 768)).to(torch_device)}
with ip_adapter.apply_ip_adapter_attention(unet, 1.0):
output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample
assert output.shape == dummy_unet_input["sample"].shape