mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add unit test
This commit is contained in:
parent
902e26507d
commit
9d0952c2ef
@ -132,7 +132,8 @@ class ModelProbe(object):
|
|||||||
|
|
||||||
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
||||||
model_info = None
|
model_info = None
|
||||||
model_type = ModelType(fields['type']) if 'type' in fields else None
|
model_type = fields['type'] if 'type' in fields else None
|
||||||
|
model_type = ModelType(model_type) if isinstance(model_type, str) else model_type
|
||||||
if not model_type:
|
if not model_type:
|
||||||
if format_type is ModelFormat.Diffusers:
|
if format_type is ModelFormat.Diffusers:
|
||||||
model_type = cls.get_model_type_from_folder(model_path)
|
model_type = cls.get_model_type_from_folder(model_path)
|
||||||
@ -157,7 +158,7 @@ class ModelProbe(object):
|
|||||||
fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
|
fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
|
||||||
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
|
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
|
||||||
fields["description"] = (
|
fields["description"] = (
|
||||||
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
|
||||||
)
|
)
|
||||||
fields["format"] = fields.get("format") or probe.get_format()
|
fields["format"] = fields.get("format") or probe.get_format()
|
||||||
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
|
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
|
||||||
|
@ -9,6 +9,7 @@ from pathlib import Path
|
|||||||
import pytest
|
import pytest
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from pydantic.networks import Url
|
from pydantic.networks import Url
|
||||||
|
from time import sleep
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
@ -20,7 +21,7 @@ from invokeai.app.services.model_install import (
|
|||||||
URLModelSource,
|
URLModelSource,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.model_records import UnknownModelException
|
from invokeai.app.services.model_records import UnknownModelException
|
||||||
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
|
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType, InvalidModelConfigException
|
||||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||||
|
|
||||||
OS = platform.uname().system
|
OS = platform.uname().system
|
||||||
@ -53,7 +54,7 @@ def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_fil
|
|||||||
|
|
||||||
def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
||||||
key = None
|
key = None
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises((ValidationError, InvalidModelConfigException)):
|
||||||
key = mm2_installer.register_path(embedding_file, {"name": "banana_sushi", "type": ModelType("lora")})
|
key = mm2_installer.register_path(embedding_file, {"name": "banana_sushi", "type": ModelType("lora")})
|
||||||
assert key is None
|
assert key is None
|
||||||
|
|
||||||
@ -263,3 +264,36 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In
|
|||||||
assert job.error
|
assert job.error
|
||||||
assert "NOT FOUND" in job.error
|
assert "NOT FOUND" in job.error
|
||||||
assert job.error_traceback.startswith("Traceback")
|
assert job.error_traceback.startswith("Traceback")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model_params",
|
||||||
|
[
|
||||||
|
# SDXL, Lora
|
||||||
|
{
|
||||||
|
"repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
|
||||||
|
"name": "test_lora",
|
||||||
|
"type": 'embedding',
|
||||||
|
},
|
||||||
|
# SDXL, Lora - incorrect type
|
||||||
|
{
|
||||||
|
"repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
|
||||||
|
"name": "test_lora",
|
||||||
|
"type": 'lora',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
|
||||||
|
"""Test whether or not type is respected on configs when passed to heuristic import."""
|
||||||
|
config: Dict[str, Any] = {
|
||||||
|
"type": model_params["type"],
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
assert("repo_id" in model_params)
|
||||||
|
install_job = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config)
|
||||||
|
|
||||||
|
while not install_job.in_terminal_state:
|
||||||
|
sleep(.01)
|
||||||
|
assert(install_job.config_out if model_params["type"] == "embedding" else not install_job.config_out)
|
||||||
|
except InvalidModelConfigException:
|
||||||
|
assert model_params["type"] != "embedding"
|
||||||
|
@ -38,6 +38,7 @@ from tests.backend.model_manager.model_metadata.metadata_examples import (
|
|||||||
RepoHFMetadata1,
|
RepoHFMetadata1,
|
||||||
RepoHFMetadata1_nofp16,
|
RepoHFMetadata1_nofp16,
|
||||||
RepoHFModelJson1,
|
RepoHFModelJson1,
|
||||||
|
HFTestLoraMetadata,
|
||||||
)
|
)
|
||||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
||||||
|
|
||||||
@ -300,6 +301,23 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
|||||||
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
|
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with open(embedding_file, "rb") as f:
|
||||||
|
data = f.read() # file is small - just 15K
|
||||||
|
sess.mount(
|
||||||
|
"https://huggingface.co/api/models/InvokeAI-test/textual_inversion_tests?blobs=True",
|
||||||
|
TestAdapter(
|
||||||
|
HFTestLoraMetadata,
|
||||||
|
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(HFTestLoraMetadata)},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
sess.mount(
|
||||||
|
"https://huggingface.co/InvokeAI-test/textual_inversion_tests/resolve/main/learned_embeds-steps-1000.safetensors",
|
||||||
|
TestAdapter(
|
||||||
|
data,
|
||||||
|
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(data)},
|
||||||
|
),
|
||||||
|
)
|
||||||
for root, _, files in os.walk(diffusers_dir):
|
for root, _, files in os.walk(diffusers_dir):
|
||||||
for name in files:
|
for name in files:
|
||||||
path = Path(root, name)
|
path = Path(root, name)
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user