Compare commits

...

18 Commits

Author SHA1 Message Date
7d46e8430b Run tests on this branch 2024-03-19 01:05:27 -04:00
b3f8f22998 more debug statements 2024-03-19 00:20:32 -04:00
6a75c5ba08 Merge branch 'allow-model-type-passthrough-on-probe' of github.com:invoke-ai/InvokeAI into allow-model-type-passthrough-on-probe 2024-03-19 00:01:21 -04:00
dcbb1ff894 add debugging statements to catch hang 2024-03-19 00:01:09 -04:00
5751455618 Remove redundant embedding file read 2024-03-18 23:58:33 -04:00
a2232c2e09 Disallow adding new installs to the queue during stop events 2024-03-18 15:37:21 -04:00
30da11998b Use wait_for_job instead of wait_for_installs 2024-03-18 12:23:46 -04:00
1f000306f3 Wrap in try except for InvalidModelConfigException 2024-03-18 12:10:06 -04:00
c778d74a42 Skip hashing in test_heuristic_import_with_type 2024-03-18 11:53:49 -04:00
97bcd54408 Increase timeout for test_heuristic_import_with_type, fix Url import 2024-03-18 11:50:32 -04:00
62d7e38030 Run ruff 2024-03-18 10:11:56 -04:00
9b6f3bded9 Fix test to run on windows vms 2024-03-18 10:10:45 -04:00
aec567179d Pull format type setting out of model_type if statement 2024-03-17 22:44:26 -04:00
28bfc1c935 Simplify logic for determining model type in probe 2024-03-17 22:44:26 -04:00
39f62ac63c Run Ruff 2024-03-17 22:44:26 -04:00
9d0952c2ef Add unit test 2024-03-17 22:44:26 -04:00
902e26507d Allow type field to be a string 2024-03-17 22:44:26 -04:00
b83427d7ce Allow users to specify model type and skip detection step of probe 2024-03-17 22:44:26 -04:00
6 changed files with 99 additions and 17 deletions

View File

@ -9,6 +9,7 @@ on:
push:
branches:
- 'main'
- 'bug-install-job-running-multiple-times'
pull_request:
types:
- 'ready_for_review'

View File

@ -133,6 +133,14 @@ class ModelInstallService(ModelInstallServiceBase):
self._download_cache.clear()
self._running = False
def _put_in_queue(self, job: ModelInstallJob) -> None:
print(f'DEBUG: in _put_in_queue(job={job.id})')
if self._stop_event.is_set():
self.cancel_job(job)
else:
print(f'DEBUG: putting {job.id} into the install queue')
self._install_queue.put(job)
def register_path(
self,
model_path: Union[Path, str],
@ -218,7 +226,7 @@ class ModelInstallService(ModelInstallServiceBase):
if isinstance(source, LocalModelSource):
install_job = self._import_local_model(source, config)
self._install_queue.put(install_job) # synchronously install
self._put_in_queue(install_job) # synchronously install
elif isinstance(source, HFModelSource):
install_job = self._import_from_hf(source, config)
elif isinstance(source, URLModelSource):
@ -403,10 +411,11 @@ class ModelInstallService(ModelInstallServiceBase):
done = True
continue
try:
print(f'DEBUG: _install_next_item() checking for a job to install')
job = self._install_queue.get(timeout=1)
except Empty:
continue
print(f'DEBUG: _install_next_item() got job {job.id}, status={job.status}')
assert job.local_path is not None
try:
if job.cancelled:
@ -432,6 +441,7 @@ class ModelInstallService(ModelInstallServiceBase):
else:
key = self.install_path(job.local_path, job.config_in)
job.config_out = self.record_store.get_model(key)
print(f'DEBUG: _install_next_item() signaling completion for job={job.id}, status={job.status}')
self._signal_job_completed(job)
except InvalidModelConfigException as excp:
@ -781,14 +791,16 @@ class ModelInstallService(ModelInstallServiceBase):
def _download_complete_callback(self, download_job: DownloadJob) -> None:
self._logger.info(f"{download_job.source}: model download complete")
print(f'DEBUG: _download_complete_callback(download_job={download_job.source}')
with self._lock:
install_job = self._download_cache[download_job.source]
self._download_cache.pop(download_job.source, None)
install_job = self._download_cache.pop(download_job.source, None)
print(f'DEBUG: download_job={download_job.source} / install_job={install_job}')
# are there any more active jobs left in this task?
if install_job.downloading and all(x.complete for x in install_job.download_parts):
if install_job and install_job.downloading and all(x.complete for x in install_job.download_parts):
print(f'DEBUG: setting job {install_job.id} to DOWNLOADS_DONE')
install_job.status = InstallStatus.DOWNLOADS_DONE
self._install_queue.put(install_job)
print(f'DEBUG: putting {install_job.id} into the install queue')
self._put_in_queue(install_job)
# Let other threads know that the number of downloads has changed
self._downloads_changed_event.set()
@ -830,7 +842,7 @@ class ModelInstallService(ModelInstallServiceBase):
if all(x.in_terminal_state for x in install_job.download_parts):
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
self._install_queue.put(install_job)
self._put_in_queue(install_job)
# ------------------------------------------------------------------------------------------------
# Internal methods that put events on the event bus

View File

@ -132,7 +132,8 @@ class ModelProbe(object):
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
model_info = None
model_type = None
model_type = ModelType(fields["type"]) if "type" in fields and fields["type"] else None
if not model_type:
if format_type is ModelFormat.Diffusers:
model_type = cls.get_model_type_from_folder(model_path)
else:
@ -156,7 +157,7 @@ class ModelProbe(object):
fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
fields["description"] = (
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
)
fields["format"] = fields.get("format") or probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)

View File

@ -5,10 +5,11 @@ Test the model installer
import platform
import uuid
from pathlib import Path
from typing import Any, Dict
import pytest
from pydantic import ValidationError
from pydantic.networks import Url
from pydantic_core import Url
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events.events_base import EventServiceBase
@ -20,7 +21,7 @@ from invokeai.app.services.model_install import (
URLModelSource,
)
from invokeai.app.services.model_records import UnknownModelException
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
OS = platform.uname().system
@ -53,7 +54,7 @@ def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_fil
def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
key = None
with pytest.raises(ValidationError):
with pytest.raises((ValidationError, InvalidModelConfigException)):
key = mm2_installer.register_path(embedding_file, {"name": "banana_sushi", "type": ModelType("lora")})
assert key is None
@ -263,3 +264,50 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In
assert job.error
assert "NOT FOUND" in job.error
assert job.error_traceback.startswith("Traceback")
@pytest.mark.parametrize(
"model_params",
[
# SDXL, Lora
{
"repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
"name": "test_lora",
"type": "embedding",
},
# SDXL, Lora - incorrect type
{
"repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
"name": "test_lora",
"type": "lora",
},
],
)
@pytest.mark.timeout(timeout=40, method="thread")
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
"""Test whether or not type is respected on configs when passed to heuristic import."""
assert "name" in model_params and "type" in model_params
config1: Dict[str, Any] = {
"name": f"{model_params['name']}_1",
"type": model_params["type"],
"hash": "placeholder1",
}
config2: Dict[str, Any] = {
"name": f"{model_params['name']}_2",
"type": ModelType(model_params["type"]),
"hash": "placeholder2",
}
assert "repo_id" in model_params
install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
mm2_installer.wait_for_job(install_job1, timeout=20)
if model_params["type"] != "embedding":
assert install_job1.errored
assert install_job1.error_type == "InvalidModelConfigException"
return
assert install_job1.complete
assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
mm2_installer.wait_for_job(install_job2, timeout=20)
assert install_job2.complete
assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out

View File

@ -33,6 +33,7 @@ from invokeai.backend.model_manager.config import (
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager.model_metadata.metadata_examples import (
HFTestLoraMetadata,
RepoCivitaiModelMetadata1,
RepoCivitaiVersionMetadata1,
RepoHFMetadata1,
@ -300,6 +301,20 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
),
)
sess.mount(
"https://huggingface.co/api/models/InvokeAI-test/textual_inversion_tests?blobs=True",
TestAdapter(
HFTestLoraMetadata,
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(HFTestLoraMetadata)},
),
)
sess.mount(
"https://huggingface.co/InvokeAI-test/textual_inversion_tests/resolve/main/learned_embeds-steps-1000.safetensors",
TestAdapter(
data,
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(data)},
),
)
for root, _, files in os.walk(diffusers_dir):
for name in files:
path = Path(root, name)

File diff suppressed because one or more lines are too long