mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
POC of a test that depends on models.
This commit is contained in:
parent
78377469db
commit
1c8b1fbc53
77
invokeai/backend/util/test_utils.py
Normal file
77
invokeai/backend/util/test_utils.py
Normal file
@ -0,0 +1,77 @@
|
||||
import contextlib
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
|
||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||
from invokeai.backend.model_management.model_manager import ModelInfo
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelNotFoundException, ModelType, SubModelType
|
||||
|
||||
|
||||
def slow(test_case):
|
||||
"""Decorator for slow tests.
|
||||
|
||||
Tests should be marked as slow if they download a model, run model inference, or do anything else slow. To judge
|
||||
whether a test is 'slow', consider how it would perform in a CPU-only environment with a low-bandwidth internet
|
||||
connection.
|
||||
"""
|
||||
return pytest.mark.slow(test_case)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def torch_device():
|
||||
return "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def model_installer():
|
||||
"""A global ModelInstall pytest fixture to be used by many tests."""
|
||||
# HACK(ryand): InvokeAIAppConfig.get_config() returns a singleton config object. This can lead to weird interactions
|
||||
# between tests that need to alter the config. For example, some tests change the 'root' directory in the config,
|
||||
# which can cause `install_and_load_model(...)` to re-download the model unnecessarily. As a temporary workaround,
|
||||
# we pass a kwarg to get_config, which causes the config to be re-loaded. To fix this properly, we should stop using
|
||||
# a singleton.
|
||||
return ModelInstall(InvokeAIAppConfig.get_config(log_level="info"))
|
||||
|
||||
|
||||
def install_and_load_model(
|
||||
model_installer: ModelInstall,
|
||||
model_path_id_or_url: Union[str, Path],
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> ModelInfo:
|
||||
"""Install a model if it is not already installed, then get the ModelInfo for that model.
|
||||
|
||||
This is intended as a utility function for tests.
|
||||
|
||||
Args:
|
||||
model_installer (ModelInstall): The model installer.
|
||||
model_path_id_or_url (Union[str, Path]): The path, HF ID, URL, etc. where the model can be installed from if it
|
||||
is not already installed.
|
||||
model_name (str): The model name, forwarded to ModelManager.get_model(...).
|
||||
base_model (BaseModelType): The base model, forwarded to ModelManager.get_model(...).
|
||||
model_type (ModelType): The model type, forwarded to ModelManager.get_model(...).
|
||||
submodel_type (Optional[SubModelType]): The submodel type, forwarded to ModelManager.get_model(...).
|
||||
|
||||
Returns:
|
||||
ModelInfo
|
||||
"""
|
||||
# If the requested model is already installed, return its ModelInfo.
|
||||
with contextlib.suppress(ModelNotFoundException):
|
||||
return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)
|
||||
|
||||
# Install the requested model.
|
||||
model_installer.heuristic_import(model_path_id_or_url)
|
||||
|
||||
try:
|
||||
return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)
|
||||
except ModelNotFoundException as e:
|
||||
raise Exception(
|
||||
"Failed to get model info after installing it. There could be a mismatch between the requested model and"
|
||||
f" the installation id ('{model_path_id_or_url}'). Error: {e}"
|
||||
)
|
@ -178,7 +178,10 @@ version = { attr = "invokeai.version.__version__" }
|
||||
|
||||
#=== Begin: PyTest and Coverage
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--cov-report term --cov-report html --cov-report xml"
|
||||
addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers -m \"not slow\""
|
||||
markers = [
|
||||
"slow: Marks tests as slow. Disabled by default. To run all tests, use -m \"\". To run only slow tests, use -m \"slow\"."
|
||||
]
|
||||
[tool.coverage.run]
|
||||
branch = true
|
||||
source = ["invokeai"]
|
||||
|
0
tests/backend/__init__.py
Normal file
0
tests/backend/__init__.py
Normal file
0
tests/backend/ip_adapter/__init__.py
Normal file
0
tests/backend/ip_adapter/__init__.py
Normal file
71
tests/backend/ip_adapter/test_ip_adapter.py
Normal file
71
tests/backend/ip_adapter/test_ip_adapter.py
Normal file
@ -0,0 +1,71 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.util.test_utils import install_and_load_model, model_installer, slow, torch_device
|
||||
|
||||
|
||||
def build_dummy_sd15_unet_input(torch_device):
|
||||
batch_size = 1
|
||||
num_channels = 4
|
||||
sizes = (32, 32)
|
||||
|
||||
noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, 77, 768)).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_params",
|
||||
[
|
||||
# SD1.5, IPAdapter
|
||||
{
|
||||
"ip_adapter_model_id": "InvokeAI/ip_adapter_sd15",
|
||||
"ip_adapter_model_name": "ip_adapter_sd15",
|
||||
"base_model": BaseModelType.StableDiffusion1,
|
||||
"unet_model_id": "runwayml/stable-diffusion-v1-5",
|
||||
"unet_model_name": "stable-diffusion-v1-5",
|
||||
},
|
||||
# SD1.5, IPAdapterPlus
|
||||
{
|
||||
"ip_adapter_model_id": "InvokeAI/ip_adapter_plus_sd15",
|
||||
"ip_adapter_model_name": "ip_adapter_plus_sd15",
|
||||
"base_model": BaseModelType.StableDiffusion1,
|
||||
"unet_model_id": "runwayml/stable-diffusion-v1-5",
|
||||
"unet_model_name": "stable-diffusion-v1-5",
|
||||
},
|
||||
],
|
||||
)
|
||||
@slow
|
||||
def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
|
||||
"""Smoke test that IP-Adapter weights can be loaded and used to patch a UNet."""
|
||||
ip_adapter_info = install_and_load_model(
|
||||
model_installer=model_installer,
|
||||
model_path_id_or_url=model_params["ip_adapter_model_id"],
|
||||
model_name=model_params["ip_adapter_model_name"],
|
||||
base_model=model_params["base_model"],
|
||||
model_type=ModelType.IPAdapter,
|
||||
)
|
||||
|
||||
unet_info = install_and_load_model(
|
||||
model_installer=model_installer,
|
||||
model_path_id_or_url=model_params["unet_model_id"],
|
||||
model_name=model_params["unet_model_name"],
|
||||
base_model=model_params["base_model"],
|
||||
model_type=ModelType.Main,
|
||||
submodel_type=SubModelType.UNet,
|
||||
)
|
||||
|
||||
dummy_unet_input = build_dummy_sd15_unet_input(torch_device)
|
||||
|
||||
with torch.no_grad(), ip_adapter_info as ip_adapter, unet_info as unet:
|
||||
ip_adapter.to(torch_device, dtype=torch.float32)
|
||||
unet.to(torch_device, dtype=torch.float32)
|
||||
|
||||
cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": torch.randn((1, 4, 768)).to(torch_device)}
|
||||
with ip_adapter.apply_ip_adapter_attention(unet, 1.0):
|
||||
output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample
|
||||
|
||||
assert output.shape == dummy_unet_input["sample"].shape
|
Loading…
Reference in New Issue
Block a user