Add unit test

This commit is contained in:
Brandon Rising 2024-03-15 14:23:30 -04:00 committed by Brandon
parent 902e26507d
commit 9d0952c2ef
4 changed files with 62 additions and 4 deletions

View File

@ -132,7 +132,8 @@ class ModelProbe(object):
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
model_info = None model_info = None
model_type = ModelType(fields['type']) if 'type' in fields else None model_type = fields['type'] if 'type' in fields else None
model_type = ModelType(model_type) if isinstance(model_type, str) else model_type
if not model_type: if not model_type:
if format_type is ModelFormat.Diffusers: if format_type is ModelFormat.Diffusers:
model_type = cls.get_model_type_from_folder(model_path) model_type = cls.get_model_type_from_folder(model_path)
@ -157,7 +158,7 @@ class ModelProbe(object):
fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id() fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
fields["name"] = fields.get("name") or cls.get_model_name(model_path) fields["name"] = fields.get("name") or cls.get_model_name(model_path)
fields["description"] = ( fields["description"] = (
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}" fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
) )
fields["format"] = fields.get("format") or probe.get_format() fields["format"] = fields.get("format") or probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path) fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)

View File

@ -9,6 +9,7 @@ from pathlib import Path
import pytest import pytest
from pydantic import ValidationError from pydantic import ValidationError
from pydantic.networks import Url from pydantic.networks import Url
from time import sleep
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
@ -20,7 +21,7 @@ from invokeai.app.services.model_install import (
URLModelSource, URLModelSource,
) )
from invokeai.app.services.model_records import UnknownModelException from invokeai.app.services.model_records import UnknownModelException
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType, InvalidModelConfigException
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
OS = platform.uname().system OS = platform.uname().system
@ -53,7 +54,7 @@ def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_fil
def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None: def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
key = None key = None
with pytest.raises(ValidationError): with pytest.raises((ValidationError, InvalidModelConfigException)):
key = mm2_installer.register_path(embedding_file, {"name": "banana_sushi", "type": ModelType("lora")}) key = mm2_installer.register_path(embedding_file, {"name": "banana_sushi", "type": ModelType("lora")})
assert key is None assert key is None
@ -263,3 +264,36 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In
assert job.error assert job.error
assert "NOT FOUND" in job.error assert "NOT FOUND" in job.error
assert job.error_traceback.startswith("Traceback") assert job.error_traceback.startswith("Traceback")
@pytest.mark.parametrize(
"model_params",
[
# SDXL, Lora
{
"repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
"name": "test_lora",
"type": 'embedding',
},
# SDXL, Lora - incorrect type
{
"repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
"name": "test_lora",
"type": 'lora',
},
],
)
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
"""Test whether or not type is respected on configs when passed to heuristic import."""
config: Dict[str, Any] = {
"type": model_params["type"],
}
try:
assert("repo_id" in model_params)
install_job = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config)
while not install_job.in_terminal_state:
sleep(.01)
assert(install_job.config_out if model_params["type"] == "embedding" else not install_job.config_out)
except InvalidModelConfigException:
assert model_params["type"] != "embedding"

View File

@ -38,6 +38,7 @@ from tests.backend.model_manager.model_metadata.metadata_examples import (
RepoHFMetadata1, RepoHFMetadata1,
RepoHFMetadata1_nofp16, RepoHFMetadata1_nofp16,
RepoHFModelJson1, RepoHFModelJson1,
HFTestLoraMetadata,
) )
from tests.fixtures.sqlite_database import create_mock_sqlite_database from tests.fixtures.sqlite_database import create_mock_sqlite_database
@ -300,6 +301,23 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)}, headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
), ),
) )
with open(embedding_file, "rb") as f:
data = f.read() # file is small - just 15K
sess.mount(
"https://huggingface.co/api/models/InvokeAI-test/textual_inversion_tests?blobs=True",
TestAdapter(
HFTestLoraMetadata,
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(HFTestLoraMetadata)},
),
)
sess.mount(
"https://huggingface.co/InvokeAI-test/textual_inversion_tests/resolve/main/learned_embeds-steps-1000.safetensors",
TestAdapter(
data,
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(data)},
),
)
for root, _, files in os.walk(diffusers_dir): for root, _, files in os.walk(diffusers_dir):
for name in files: for name in files:
path = Path(root, name) path = Path(root, name)

File diff suppressed because one or more lines are too long