mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add unit test
This commit is contained in:
parent
902e26507d
commit
9d0952c2ef
@ -132,7 +132,8 @@ class ModelProbe(object):
|
||||
|
||||
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
||||
model_info = None
|
||||
model_type = ModelType(fields['type']) if 'type' in fields else None
|
||||
model_type = fields['type'] if 'type' in fields else None
|
||||
model_type = ModelType(model_type) if isinstance(model_type, str) else model_type
|
||||
if not model_type:
|
||||
if format_type is ModelFormat.Diffusers:
|
||||
model_type = cls.get_model_type_from_folder(model_path)
|
||||
@ -157,7 +158,7 @@ class ModelProbe(object):
|
||||
fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
|
||||
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
|
||||
fields["description"] = (
|
||||
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
||||
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
|
||||
)
|
||||
fields["format"] = fields.get("format") or probe.get_format()
|
||||
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
|
||||
|
@ -9,6 +9,7 @@ from pathlib import Path
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from pydantic.networks import Url
|
||||
from time import sleep
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
@ -20,7 +21,7 @@ from invokeai.app.services.model_install import (
|
||||
URLModelSource,
|
||||
)
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType, InvalidModelConfigException
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
|
||||
OS = platform.uname().system
|
||||
@ -53,7 +54,7 @@ def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_fil
|
||||
|
||||
def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
||||
key = None
|
||||
with pytest.raises(ValidationError):
|
||||
with pytest.raises((ValidationError, InvalidModelConfigException)):
|
||||
key = mm2_installer.register_path(embedding_file, {"name": "banana_sushi", "type": ModelType("lora")})
|
||||
assert key is None
|
||||
|
||||
@ -263,3 +264,36 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In
|
||||
assert job.error
|
||||
assert "NOT FOUND" in job.error
|
||||
assert job.error_traceback.startswith("Traceback")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_params",
|
||||
[
|
||||
# SDXL, Lora
|
||||
{
|
||||
"repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
|
||||
"name": "test_lora",
|
||||
"type": 'embedding',
|
||||
},
|
||||
# SDXL, Lora - incorrect type
|
||||
{
|
||||
"repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
|
||||
"name": "test_lora",
|
||||
"type": 'lora',
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
|
||||
"""Test whether or not type is respected on configs when passed to heuristic import."""
|
||||
config: Dict[str, Any] = {
|
||||
"type": model_params["type"],
|
||||
}
|
||||
try:
|
||||
assert("repo_id" in model_params)
|
||||
install_job = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config)
|
||||
|
||||
while not install_job.in_terminal_state:
|
||||
sleep(.01)
|
||||
assert(install_job.config_out if model_params["type"] == "embedding" else not install_job.config_out)
|
||||
except InvalidModelConfigException:
|
||||
assert model_params["type"] != "embedding"
|
||||
|
@ -38,6 +38,7 @@ from tests.backend.model_manager.model_metadata.metadata_examples import (
|
||||
RepoHFMetadata1,
|
||||
RepoHFMetadata1_nofp16,
|
||||
RepoHFModelJson1,
|
||||
HFTestLoraMetadata,
|
||||
)
|
||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
||||
|
||||
@ -300,6 +301,23 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
||||
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
|
||||
),
|
||||
)
|
||||
|
||||
with open(embedding_file, "rb") as f:
|
||||
data = f.read() # file is small - just 15K
|
||||
sess.mount(
|
||||
"https://huggingface.co/api/models/InvokeAI-test/textual_inversion_tests?blobs=True",
|
||||
TestAdapter(
|
||||
HFTestLoraMetadata,
|
||||
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(HFTestLoraMetadata)},
|
||||
),
|
||||
)
|
||||
sess.mount(
|
||||
"https://huggingface.co/InvokeAI-test/textual_inversion_tests/resolve/main/learned_embeds-steps-1000.safetensors",
|
||||
TestAdapter(
|
||||
data,
|
||||
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(data)},
|
||||
),
|
||||
)
|
||||
for root, _, files in os.walk(diffusers_dir):
|
||||
for name in files:
|
||||
path = Path(root, name)
|
||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user