mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
test(model_management): add a couple tests for _get_model_path
This commit is contained in:
parent
65ed224bfc
commit
44bf308192
@ -533,7 +533,7 @@ class ModelManager(object):
|
|||||||
model_path = self.resolve_model_path(model_path)
|
model_path = self.resolve_model_path(model_path)
|
||||||
return model_path, is_submodel_override
|
return model_path, is_submodel_override
|
||||||
|
|
||||||
def _get_model_config(self, base_model, model_name, model_type) -> ModelConfigBase:
|
def _get_model_config(self, base_model: BaseModelType, model_name: str, model_type: ModelType) -> ModelConfigBase:
|
||||||
"""Get a model's config object."""
|
"""Get a model's config object."""
|
||||||
model_key = self.create_key(model_name, base_model, model_type)
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
try:
|
try:
|
||||||
|
@ -100,7 +100,7 @@ dependencies = [
|
|||||||
"dev" = [
|
"dev" = [
|
||||||
"pudb",
|
"pudb",
|
||||||
]
|
]
|
||||||
"test" = ["pytest>6.0.0", "pytest-cov", "black"]
|
"test" = ["pytest>6.0.0", "pytest-cov", "pytest-datadir", "black"]
|
||||||
"xformers" = [
|
"xformers" = [
|
||||||
"xformers~=0.0.19; sys_platform!='darwin'",
|
"xformers~=0.0.19; sys_platform!='darwin'",
|
||||||
"triton; sys_platform=='linux'",
|
"triton; sys_platform=='linux'",
|
||||||
|
36
tests/test_model_manager.py
Normal file
36
tests/test_model_manager.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
from invokeai.backend import ModelManager, BaseModelType, ModelType, SubModelType
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def model_manager(datadir) -> ModelManager:
|
||||||
|
InvokeAIAppConfig.get_config(root=datadir)
|
||||||
|
return ModelManager(datadir / "configs" / "relative_sub.models.yaml")
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_model_names(model_manager: ModelManager):
|
||||||
|
names = model_manager.model_names()
|
||||||
|
assert names[:2] == [
|
||||||
|
("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main),
|
||||||
|
("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_model_path_for_diffusers(model_manager: ModelManager, datadir: Path):
|
||||||
|
model_config = model_manager._get_model_config(BaseModelType.StableDiffusionXL, "SDXL base", ModelType.Main)
|
||||||
|
top_model_path, is_override = model_manager._get_model_path(model_config)
|
||||||
|
expected_model_path = datadir / "models" / "sdxl" / "main" / "SDXL base 1_0"
|
||||||
|
assert top_model_path == expected_model_path
|
||||||
|
assert not is_override
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: Path):
|
||||||
|
model_config = model_manager._get_model_config(BaseModelType.StableDiffusionXL, "SDXL with VAE", ModelType.Main)
|
||||||
|
vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
|
||||||
|
expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix"
|
||||||
|
assert vae_model_path == expected_vae_path
|
||||||
|
assert is_override
|
15
tests/test_model_manager/configs/relative_sub.models.yaml
Normal file
15
tests/test_model_manager/configs/relative_sub.models.yaml
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
__metadata__:
|
||||||
|
version: 3.0.0
|
||||||
|
|
||||||
|
sdxl/main/SDXL base:
|
||||||
|
path: sdxl/main/SDXL base 1_0
|
||||||
|
description: SDXL base v1.0
|
||||||
|
variant: normal
|
||||||
|
format: diffusers
|
||||||
|
|
||||||
|
sdxl/main/SDXL with VAE:
|
||||||
|
path: sdxl/main/SDXL base 1_0
|
||||||
|
description: SDXL base v1.0
|
||||||
|
vae: sdxl/vae/sdxl-vae-fp16-fix/
|
||||||
|
variant: normal
|
||||||
|
format: diffusers
|
Loading…
Reference in New Issue
Block a user