POC of a test that depends on models.
commit 1c8b1fbc53 (parent 78377469db)
invokeai/backend/util/test_utils.py (new file, +77 lines)
@@ -0,0 +1,77 @@
import contextlib
from pathlib import Path
from typing import Optional, Union

import pytest
import torch

from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
from invokeai.backend.install.model_install_backend import ModelInstall
from invokeai.backend.model_management.model_manager import ModelInfo
from invokeai.backend.model_management.models.base import BaseModelType, ModelNotFoundException, ModelType, SubModelType


def slow(test_case):
    """Decorator for slow tests.

    Tests should be marked as slow if they download a model, run model inference, or do anything else slow. To judge
    whether a test is 'slow', consider how it would perform in a CPU-only environment with a low-bandwidth internet
    connection.
    """
    return pytest.mark.slow(test_case)


@pytest.fixture(scope="session")
def torch_device():
    return "cuda" if torch.cuda.is_available() else "cpu"


@pytest.fixture(scope="module")
def model_installer():
    """A global ModelInstall pytest fixture to be used by many tests."""
    # HACK(ryand): InvokeAIAppConfig.get_config() returns a singleton config object. This can lead to weird interactions
    # between tests that need to alter the config. For example, some tests change the 'root' directory in the config,
    # which can cause `install_and_load_model(...)` to re-download the model unnecessarily. As a temporary workaround,
    # we pass a kwarg to get_config, which causes the config to be re-loaded. To fix this properly, we should stop using
    # a singleton.
    return ModelInstall(InvokeAIAppConfig.get_config(log_level="info"))


def install_and_load_model(
    model_installer: ModelInstall,
    model_path_id_or_url: Union[str, Path],
    model_name: str,
    base_model: BaseModelType,
    model_type: ModelType,
    submodel_type: Optional[SubModelType] = None,
) -> ModelInfo:
    """Install a model if it is not already installed, then get the ModelInfo for that model.

    This is intended as a utility function for tests.

    Args:
        model_installer (ModelInstall): The model installer.
        model_path_id_or_url (Union[str, Path]): The path, HF ID, URL, etc. where the model can be installed from if it
            is not already installed.
        model_name (str): The model name, forwarded to ModelManager.get_model(...).
        base_model (BaseModelType): The base model, forwarded to ModelManager.get_model(...).
        model_type (ModelType): The model type, forwarded to ModelManager.get_model(...).
        submodel_type (Optional[SubModelType]): The submodel type, forwarded to ModelManager.get_model(...).

    Returns:
        ModelInfo
    """
    # If the requested model is already installed, return its ModelInfo.
    with contextlib.suppress(ModelNotFoundException):
        return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)

    # Install the requested model.
    model_installer.heuristic_import(model_path_id_or_url)

    try:
        return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)
    except ModelNotFoundException as e:
        raise Exception(
            "Failed to get model info after installing it. There could be a mismatch between the requested model and"
            f" the installation id ('{model_path_id_or_url}'). Error: {e}"
        )
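For reference, a minimal sketch (a hypothetical test module, not part of this commit) of how these utilities are meant to be combined: the decorator gates the test behind the `slow` marker, the imported fixtures supply the installer and device, and `install_and_load_model` is idempotent across runs. The model identifiers below are the same ones used in the test added later in this commit.

import torch

from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
from invokeai.backend.util.test_utils import install_and_load_model, model_installer, slow, torch_device


@slow
def test_example_unet_loads(model_installer, torch_device):
    # The first run downloads and installs the model; later runs are served from the local cache
    # via ModelManager.get_model(...).
    unet_info = install_and_load_model(
        model_installer=model_installer,
        model_path_id_or_url="runwayml/stable-diffusion-v1-5",
        model_name="stable-diffusion-v1-5",
        base_model=BaseModelType.StableDiffusion1,
        model_type=ModelType.Main,
        submodel_type=SubModelType.UNet,
    )
    # ModelInfo acts as a context manager that yields the loaded model object.
    with torch.no_grad(), unet_info as unet:
        unet.to(torch_device)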
pyproject.toml
@@ -178,7 +178,10 @@ version = { attr = "invokeai.version.__version__" }
 
 #=== Begin: PyTest and Coverage
 [tool.pytest.ini_options]
-addopts = "--cov-report term --cov-report html --cov-report xml"
+addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers -m \"not slow\""
+markers = [
+    "slow: Marks tests as slow. Disabled by default. To run all tests, use -m \"\". To run only slow tests, use -m \"slow\"."
+]
 [tool.coverage.run]
 branch = true
 source = ["invokeai"]
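With this change, `--strict-markers` makes pytest reject any mark not registered under `markers`, and the default `-m "not slow"` in addopts deselects slow tests unless overridden on the command line: `pytest -m ""` runs everything, `pytest -m "slow"` runs only the slow tests. A minimal sketch (not part of the commit) of the same selection done programmatically through pytest's public entry point:

import sys

import pytest

if __name__ == "__main__":
    # Programmatic equivalent of `pytest -m "slow"`: override the default `-m "not slow"`
    # from addopts and run only tests carrying the registered `slow` marker.
    sys.exit(pytest.main(["-m", "slow"]))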
tests/backend/__init__.py (new file, +0 lines)
tests/backend/ip_adapter/__init__.py (new file, +0 lines)
tests/backend/ip_adapter/test_ip_adapter.py (new file, +71 lines)
@@ -0,0 +1,71 @@
import pytest
import torch

from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
from invokeai.backend.util.test_utils import install_and_load_model, model_installer, slow, torch_device


def build_dummy_sd15_unet_input(torch_device):
    batch_size = 1
    num_channels = 4
    sizes = (32, 32)

    noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device)
    time_step = torch.tensor([10]).to(torch_device)
    encoder_hidden_states = torch.randn((batch_size, 77, 768)).to(torch_device)

    return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}


@pytest.mark.parametrize(
    "model_params",
    [
        # SD1.5, IPAdapter
        {
            "ip_adapter_model_id": "InvokeAI/ip_adapter_sd15",
            "ip_adapter_model_name": "ip_adapter_sd15",
            "base_model": BaseModelType.StableDiffusion1,
            "unet_model_id": "runwayml/stable-diffusion-v1-5",
            "unet_model_name": "stable-diffusion-v1-5",
        },
        # SD1.5, IPAdapterPlus
        {
            "ip_adapter_model_id": "InvokeAI/ip_adapter_plus_sd15",
            "ip_adapter_model_name": "ip_adapter_plus_sd15",
            "base_model": BaseModelType.StableDiffusion1,
            "unet_model_id": "runwayml/stable-diffusion-v1-5",
            "unet_model_name": "stable-diffusion-v1-5",
        },
    ],
)
@slow
def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
    """Smoke test that IP-Adapter weights can be loaded and used to patch a UNet."""
    ip_adapter_info = install_and_load_model(
        model_installer=model_installer,
        model_path_id_or_url=model_params["ip_adapter_model_id"],
        model_name=model_params["ip_adapter_model_name"],
        base_model=model_params["base_model"],
        model_type=ModelType.IPAdapter,
    )

    unet_info = install_and_load_model(
        model_installer=model_installer,
        model_path_id_or_url=model_params["unet_model_id"],
        model_name=model_params["unet_model_name"],
        base_model=model_params["base_model"],
        model_type=ModelType.Main,
        submodel_type=SubModelType.UNet,
    )

    dummy_unet_input = build_dummy_sd15_unet_input(torch_device)

    with torch.no_grad(), ip_adapter_info as ip_adapter, unet_info as unet:
        ip_adapter.to(torch_device, dtype=torch.float32)
        unet.to(torch_device, dtype=torch.float32)

        cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": torch.randn((1, 4, 768)).to(torch_device)}
        with ip_adapter.apply_ip_adapter_attention(unet, 1.0):
            output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample

    assert output.shape == dummy_unet_input["sample"].shape
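As an aside (not part of the commit), the magic numbers in `build_dummy_sd15_unet_input` follow standard SD1.5 conventions: the UNet operates on 4-channel latents, a 32x32 latent corresponds to a 256x256-pixel image, and text conditioning is a CLIP sequence of 77 tokens with 768-dimensional embeddings. A standalone shape check under those assumptions:

import torch

# SD1.5 shape conventions assumed by the dummy UNet input above.
sample = torch.randn(1, 4, 32, 32)               # (batch, latent channels, latent h, latent w)
timestep = torch.tensor([10])                    # diffusion timestep
encoder_hidden_states = torch.randn(1, 77, 768)  # (batch, CLIP tokens, embedding dim)

assert sample.shape == (1, 4, 32, 32)
assert encoder_hidden_states.shape == (1, 77, 768)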