Add unit test

This commit is contained in:
Brandon Rising 2024-03-15 14:23:30 -04:00 committed by Brandon
parent 902e26507d
commit 9d0952c2ef
4 changed files with 62 additions and 4 deletions

View File

@ -132,7 +132,8 @@ class ModelProbe(object):
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
model_info = None
model_type = ModelType(fields['type']) if 'type' in fields else None
model_type = fields['type'] if 'type' in fields else None
model_type = ModelType(model_type) if isinstance(model_type, str) else model_type
if not model_type:
if format_type is ModelFormat.Diffusers:
model_type = cls.get_model_type_from_folder(model_path)
@ -157,7 +158,7 @@ class ModelProbe(object):
fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
fields["description"] = (
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
)
fields["format"] = fields.get("format") or probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)

View File

@ -9,6 +9,7 @@ from pathlib import Path
import pytest
from pydantic import ValidationError
from pydantic.networks import Url
from time import sleep
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events.events_base import EventServiceBase
@ -20,7 +21,7 @@ from invokeai.app.services.model_install import (
URLModelSource,
)
from invokeai.app.services.model_records import UnknownModelException
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType, InvalidModelConfigException
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
OS = platform.uname().system
@ -53,7 +54,7 @@ def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_fil
def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
key = None
with pytest.raises(ValidationError):
with pytest.raises((ValidationError, InvalidModelConfigException)):
key = mm2_installer.register_path(embedding_file, {"name": "banana_sushi", "type": ModelType("lora")})
assert key is None
@ -263,3 +264,36 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In
assert job.error
assert "NOT FOUND" in job.error
assert job.error_traceback.startswith("Traceback")
@pytest.mark.parametrize(
"model_params",
[
# SDXL, Lora
{
"repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
"name": "test_lora",
"type": 'embedding',
},
# SDXL, Lora - incorrect type
{
"repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
"name": "test_lora",
"type": 'lora',
},
],
)
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
"""Test whether or not type is respected on configs when passed to heuristic import."""
config: Dict[str, Any] = {
"type": model_params["type"],
}
try:
assert("repo_id" in model_params)
install_job = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config)
while not install_job.in_terminal_state:
sleep(.01)
assert(install_job.config_out if model_params["type"] == "embedding" else not install_job.config_out)
except InvalidModelConfigException:
assert model_params["type"] != "embedding"

View File

@ -38,6 +38,7 @@ from tests.backend.model_manager.model_metadata.metadata_examples import (
RepoHFMetadata1,
RepoHFMetadata1_nofp16,
RepoHFModelJson1,
HFTestLoraMetadata,
)
from tests.fixtures.sqlite_database import create_mock_sqlite_database
@ -300,6 +301,23 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
),
)
with open(embedding_file, "rb") as f:
data = f.read() # file is small - just 15K
sess.mount(
"https://huggingface.co/api/models/InvokeAI-test/textual_inversion_tests?blobs=True",
TestAdapter(
HFTestLoraMetadata,
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(HFTestLoraMetadata)},
),
)
sess.mount(
"https://huggingface.co/InvokeAI-test/textual_inversion_tests/resolve/main/learned_embeds-steps-1000.safetensors",
TestAdapter(
data,
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(data)},
),
)
for root, _, files in os.walk(diffusers_dir):
for name in files:
path = Path(root, name)

File diff suppressed because one or more lines are too long