Compare commits

...

18 Commits

SHA1 Message Date
7d46e8430b Run tests on this branch 2024-03-19 01:05:27 -04:00
b3f8f22998 more debug statements 2024-03-19 00:20:32 -04:00
6a75c5ba08 Merge branch 'allow-model-type-passthrough-on-probe' of github.com:invoke-ai/InvokeAI into allow-model-type-passthrough-on-probe 2024-03-19 00:01:21 -04:00
dcbb1ff894 add debugging statements to catch hang 2024-03-19 00:01:09 -04:00
5751455618 Remove redundant embedding file read 2024-03-18 23:58:33 -04:00
a2232c2e09 Disallow adding new installs to the queue during stop events 2024-03-18 15:37:21 -04:00
30da11998b Use wait_for_job instead of wait_for_installs 2024-03-18 12:23:46 -04:00
1f000306f3 Wrap in try except for InvalidModelConfigException 2024-03-18 12:10:06 -04:00
c778d74a42 Skip hashing in test_heuristic_import_with_type 2024-03-18 11:53:49 -04:00
97bcd54408 Increase timeout for test_heuristic_import_with_type, fix Url import 2024-03-18 11:50:32 -04:00
62d7e38030 Run ruff 2024-03-18 10:11:56 -04:00
9b6f3bded9 Fix test to run on windows vms 2024-03-18 10:10:45 -04:00
aec567179d Pull format type setting out of model_type if statement 2024-03-17 22:44:26 -04:00
28bfc1c935 Simplify logic for determining model type in probe 2024-03-17 22:44:26 -04:00
39f62ac63c Run Ruff 2024-03-17 22:44:26 -04:00
9d0952c2ef Add unit test 2024-03-17 22:44:26 -04:00
902e26507d Allow type field to be a string 2024-03-17 22:44:26 -04:00
b83427d7ce Allow users to specify model type and skip detection step of probe 2024-03-17 22:44:26 -04:00
6 changed files with 99 additions and 17 deletions

View File

@@ -9,6 +9,7 @@ on:
   push:
     branches:
       - 'main'
+      - 'bug-install-job-running-multiple-times'
   pull_request:
     types:
       - 'ready_for_review'

View File

@@ -133,6 +133,14 @@ class ModelInstallService(ModelInstallServiceBase):
         self._download_cache.clear()
         self._running = False

+    def _put_in_queue(self, job: ModelInstallJob) -> None:
+        print(f'DEBUG: in _put_in_queue(job={job.id})')
+        if self._stop_event.is_set():
+            self.cancel_job(job)
+        else:
+            print(f'DEBUG: putting {job.id} into the install queue')
+            self._install_queue.put(job)
+
     def register_path(
         self,
         model_path: Union[Path, str],
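Note: the new `_put_in_queue` helper routes every enqueue through a stop-event check, so a job arriving mid-shutdown is cancelled rather than left stranded in the queue after the worker exits. A minimal standalone sketch of the pattern (class and job names here are illustrative, not InvokeAI's API):

    import threading
    from dataclasses import dataclass
    from queue import Queue

    @dataclass
    class Job:  # stand-in for ModelInstallJob
        id: int
        status: str = "waiting"

    class InstallerSketch:
        def __init__(self) -> None:
            self._install_queue: "Queue[Job]" = Queue()
            self._stop_event = threading.Event()

        def put_in_queue(self, job: Job) -> None:
            # Enqueueing after stop is requested would leave the job
            # unprocessed forever; cancel it instead.
            if self._stop_event.is_set():
                self.cancel_job(job)
            else:
                self._install_queue.put(job)

        def cancel_job(self, job: Job) -> None:
            job.status = "cancelled"  # placeholder for the real cancellation path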
@@ -218,7 +226,7 @@ class ModelInstallService(ModelInstallServiceBase):
         if isinstance(source, LocalModelSource):
             install_job = self._import_local_model(source, config)
-            self._install_queue.put(install_job)  # synchronously install
+            self._put_in_queue(install_job)  # synchronously install
         elif isinstance(source, HFModelSource):
             install_job = self._import_from_hf(source, config)
         elif isinstance(source, URLModelSource):
@@ -403,10 +411,11 @@ class ModelInstallService(ModelInstallServiceBase):
                 done = True
                 continue
             try:
+                print(f'DEBUG: _install_next_item() checking for a job to install')
                 job = self._install_queue.get(timeout=1)
             except Empty:
                 continue
+            print(f'DEBUG: _install_next_item() got job {job.id}, status={job.status}')
             assert job.local_path is not None
             try:
                 if job.cancelled:
@@ -432,6 +441,7 @@ class ModelInstallService(ModelInstallServiceBase):
                     else:
                         key = self.install_path(job.local_path, job.config_in)
                         job.config_out = self.record_store.get_model(key)
+                        print(f'DEBUG: _install_next_item() signaling completion for job={job.id}, status={job.status}')
                         self._signal_job_completed(job)

             except InvalidModelConfigException as excp:
@@ -781,14 +791,16 @@ class ModelInstallService(ModelInstallServiceBase):
     def _download_complete_callback(self, download_job: DownloadJob) -> None:
         self._logger.info(f"{download_job.source}: model download complete")
+        print(f'DEBUG: _download_complete_callback(download_job={download_job.source}')
         with self._lock:
-            install_job = self._download_cache[download_job.source]
-            self._download_cache.pop(download_job.source, None)
+            install_job = self._download_cache.pop(download_job.source, None)
+            print(f'DEBUG: download_job={download_job.source} / install_job={install_job}')

             # are there any more active jobs left in this task?
-            if install_job.downloading and all(x.complete for x in install_job.download_parts):
+            if install_job and install_job.downloading and all(x.complete for x in install_job.download_parts):
+                print(f'DEBUG: setting job {install_job.id} to DOWNLOADS_DONE')
                 install_job.status = InstallStatus.DOWNLOADS_DONE
-                self._install_queue.put(install_job)
+                print(f'DEBUG: putting {install_job.id} into the install queue')
+                self._put_in_queue(install_job)

             # Let other threads know that the number of downloads has changed
             self._downloads_changed_event.set()
@@ -830,7 +842,7 @@ class ModelInstallService(ModelInstallServiceBase):
             if all(x.in_terminal_state for x in install_job.download_parts):
                 # When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
-                self._install_queue.put(install_job)
+                self._put_in_queue(install_job)

     # ------------------------------------------------------------------------------------------------
     # Internal methods that put events on the event bus

View File

@@ -132,11 +132,12 @@ class ModelProbe(object):
         format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
         model_info = None
-        model_type = None
-        if format_type is ModelFormat.Diffusers:
-            model_type = cls.get_model_type_from_folder(model_path)
-        else:
-            model_type = cls.get_model_type_from_checkpoint(model_path)
+        model_type = ModelType(fields["type"]) if "type" in fields and fields["type"] else None
+        if not model_type:
+            if format_type is ModelFormat.Diffusers:
+                model_type = cls.get_model_type_from_folder(model_path)
+            else:
+                model_type = cls.get_model_type_from_checkpoint(model_path)
         format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type

         probe_class = cls.PROBES[format_type].get(model_type)
@@ -156,7 +157,7 @@ class ModelProbe(object):
             fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
             fields["name"] = fields.get("name") or cls.get_model_name(model_path)
             fields["description"] = (
-                fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
+                fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
             )
             fields["format"] = fields.get("format") or probe.get_format()
             fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)

View File

@@ -5,10 +5,11 @@ Test the model installer
 import platform
 import uuid
 from pathlib import Path
+from typing import Any, Dict

 import pytest
 from pydantic import ValidationError
-from pydantic.networks import Url
+from pydantic_core import Url

 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.app.services.events.events_base import EventServiceBase
@@ -20,7 +21,7 @@ from invokeai.app.services.model_install import (
     URLModelSource,
 )
 from invokeai.app.services.model_records import UnknownModelException
-from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
+from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
 from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403

 OS = platform.uname().system
@@ -53,7 +54,7 @@ def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_file
 def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
     key = None
-    with pytest.raises(ValidationError):
+    with pytest.raises((ValidationError, InvalidModelConfigException)):
         key = mm2_installer.register_path(embedding_file, {"name": "banana_sushi", "type": ModelType("lora")})
     assert key is None
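Note: `pytest.raises` accepts a tuple of exception types and passes when any one of them is raised, which lets this test keep working whether registration fails pydantic validation or the probe's own `InvalidModelConfigException` check. For example:

    import pytest

    def boom(kind: str) -> None:
        exc = ValueError if kind == "v" else KeyError
        raise exc("bad input")

    def test_boom_accepts_either() -> None:
        # Passes because ValueError is one of the accepted types.
        with pytest.raises((ValueError, KeyError)):
            boom("v")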
@@ -263,3 +264,50 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig
     assert job.error
     assert "NOT FOUND" in job.error
     assert job.error_traceback.startswith("Traceback")
+
+
+@pytest.mark.parametrize(
+    "model_params",
+    [
+        # SDXL, Lora
+        {
+            "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
+            "name": "test_lora",
+            "type": "embedding",
+        },
+        # SDXL, Lora - incorrect type
+        {
+            "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
+            "name": "test_lora",
+            "type": "lora",
+        },
+    ],
+)
+@pytest.mark.timeout(timeout=40, method="thread")
+def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
+    """Test whether or not type is respected on configs when passed to heuristic import."""
+    assert "name" in model_params and "type" in model_params
+    config1: Dict[str, Any] = {
+        "name": f"{model_params['name']}_1",
+        "type": model_params["type"],
+        "hash": "placeholder1",
+    }
+    config2: Dict[str, Any] = {
+        "name": f"{model_params['name']}_2",
+        "type": ModelType(model_params["type"]),
+        "hash": "placeholder2",
+    }
+    assert "repo_id" in model_params
+    install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
+    mm2_installer.wait_for_job(install_job1, timeout=20)
+    if model_params["type"] != "embedding":
+        assert install_job1.errored
+        assert install_job1.error_type == "InvalidModelConfigException"
+        return
+    assert install_job1.complete
+    assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
+
+    install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
+    mm2_installer.wait_for_job(install_job2, timeout=20)
+    assert install_job2.complete
+    assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out
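Note: `@pytest.mark.parametrize` runs `test_heuristic_import_with_type` once per dict in the list, and the `timeout` marker (from the pytest-timeout plugin) aborts a hung run after 40 seconds using a watchdog thread, which is what the "Increase timeout" commit tunes. A minimal sketch of the same parametrization pattern:

    import pytest

    @pytest.mark.parametrize(
        "params",
        [
            {"name": "case_ok", "type": "embedding"},
            {"name": "case_bad", "type": "lora"},
        ],
    )
    def test_type_field(params: dict) -> None:
        # One invocation per dict; test ids are derived from the values.
        assert params["type"] in ("embedding", "lora")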

View File

@@ -33,6 +33,7 @@ from invokeai.backend.model_manager.config import (
 from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
 from invokeai.backend.util.logging import InvokeAILogger
 from tests.backend.model_manager.model_metadata.metadata_examples import (
+    HFTestLoraMetadata,
     RepoCivitaiModelMetadata1,
     RepoCivitaiVersionMetadata1,
     RepoHFMetadata1,
@@ -300,6 +301,20 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
             headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
         ),
     )
+    sess.mount(
+        "https://huggingface.co/api/models/InvokeAI-test/textual_inversion_tests?blobs=True",
+        TestAdapter(
+            HFTestLoraMetadata,
+            headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(HFTestLoraMetadata)},
+        ),
+    )
+    sess.mount(
+        "https://huggingface.co/InvokeAI-test/textual_inversion_tests/resolve/main/learned_embeds-steps-1000.safetensors",
+        TestAdapter(
+            data,
+            headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(data)},
+        ),
+    )
     for root, _, files in os.walk(diffusers_dir):
         for name in files:
             path = Path(root, name)
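Note: the two new `sess.mount(...)` calls pin canned responses to the exact Hugging Face URLs the new test hits, so it never performs a real download. Assuming the fixtures' `TestAdapter` comes from the requests-testadapter package (an assumption based on its usage in this file), the mechanism looks like this:

    import requests
    from requests_testadapter import TestAdapter  # assumed package

    sess = requests.Session()
    # Requests whose URL starts with this prefix are answered from canned bytes.
    sess.mount(
        "https://example.test/model.safetensors",
        TestAdapter(b"fake-weights", headers={"Content-Type": "application/octet-stream"}),
    )
    resp = sess.get("https://example.test/model.safetensors")
    assert resp.content == b"fake-weights"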

File diff suppressed because one or more lines are too long