mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

Compare commits: lstein/tes...bug-install-job-running-multiple-times

18 Commits
- 7d46e8430b
- b3f8f22998
- 6a75c5ba08
- dcbb1ff894
- 5751455618
- a2232c2e09
- 30da11998b
- 1f000306f3
- c778d74a42
- 97bcd54408
- 62d7e38030
- 9b6f3bded9
- aec567179d
- 28bfc1c935
- 39f62ac63c
- 9d0952c2ef
- 902e26507d
- b83427d7ce
.github/workflows/python-tests.yml (vendored): 1 changed line
```diff
@@ -9,6 +9,7 @@ on:
   push:
     branches:
       - 'main'
+      - 'bug-install-job-running-multiple-times'
   pull_request:
     types:
       - 'ready_for_review'
```
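The sole workflow change adds the working branch to the push trigger, so CI runs on every push to `bug-install-job-running-multiple-times` instead of only when a pull request is marked ready for review.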
In `ModelInstallService`:

```diff
@@ -133,6 +133,14 @@ class ModelInstallService(ModelInstallServiceBase):
         self._download_cache.clear()
         self._running = False
 
+    def _put_in_queue(self, job: ModelInstallJob) -> None:
+        print(f'DEBUG: in _put_in_queue(job={job.id})')
+        if self._stop_event.is_set():
+            self.cancel_job(job)
+        else:
+            print(f'DEBUG: putting {job.id} into the install queue')
+            self._install_queue.put(job)
+
     def register_path(
         self,
         model_path: Union[Path, str],
```
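This new helper funnels every enqueue through a single stop-event check, so a job submitted after shutdown has begun is cancelled instead of being parked in a queue no worker will ever drain; the hunks below replace the direct `self._install_queue.put(...)` calls with it. A minimal sketch of the pattern, using illustrative names rather than InvokeAI's actual API:

```python
import queue
import threading

stop_event = threading.Event()  # set when the service shuts down
install_queue: "queue.Queue[str]" = queue.Queue()

def put_in_queue(job: str) -> None:
    # Refuse new work once shutdown has begun; otherwise the job would sit
    # in the queue forever with no consumer thread left to pick it up.
    if stop_event.is_set():
        print(f"cancelling {job}")
    else:
        install_queue.put(job)

put_in_queue("job-1")  # enqueued
stop_event.set()
put_in_queue("job-2")  # cancelled, never enqueued
```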
```diff
@@ -218,7 +226,7 @@ class ModelInstallService(ModelInstallServiceBase):
 
         if isinstance(source, LocalModelSource):
             install_job = self._import_local_model(source, config)
-            self._install_queue.put(install_job)  # synchronously install
+            self._put_in_queue(install_job)  # synchronously install
         elif isinstance(source, HFModelSource):
             install_job = self._import_from_hf(source, config)
         elif isinstance(source, URLModelSource):
```
```diff
@@ -403,10 +411,11 @@ class ModelInstallService(ModelInstallServiceBase):
                 done = True
                 continue
             try:
+                print(f'DEBUG: _install_next_item() checking for a job to install')
                 job = self._install_queue.get(timeout=1)
             except Empty:
                 continue
-
+            print(f'DEBUG: _install_next_item() got job {job.id}, status={job.status}')
             assert job.local_path is not None
             try:
                 if job.cancelled:
```
```diff
@@ -432,6 +441,7 @@ class ModelInstallService(ModelInstallServiceBase):
                 else:
                     key = self.install_path(job.local_path, job.config_in)
                     job.config_out = self.record_store.get_model(key)
+                print(f'DEBUG: _install_next_item() signaling completion for job={job.id}, status={job.status}')
                 self._signal_job_completed(job)
 
             except InvalidModelConfigException as excp:
```
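The `print(f'DEBUG: ...')` calls added here and in the neighboring hunks are evidently temporary diagnostics for the duplicate-install bug the branch is named after; they trace each job as it enters the queue, is picked up, and is signaled complete.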
```diff
@@ -781,14 +791,16 @@ class ModelInstallService(ModelInstallServiceBase):
 
     def _download_complete_callback(self, download_job: DownloadJob) -> None:
         self._logger.info(f"{download_job.source}: model download complete")
+        print(f'DEBUG: _download_complete_callback(download_job={download_job.source}')
         with self._lock:
-            install_job = self._download_cache[download_job.source]
-            self._download_cache.pop(download_job.source, None)
-
+            install_job = self._download_cache.pop(download_job.source, None)
+            print(f'DEBUG: download_job={download_job.source} / install_job={install_job}')
             # are there any more active jobs left in this task?
-            if install_job.downloading and all(x.complete for x in install_job.download_parts):
+            if install_job and install_job.downloading and all(x.complete for x in install_job.download_parts):
+                print(f'DEBUG: setting job {install_job.id} to DOWNLOADS_DONE')
                 install_job.status = InstallStatus.DOWNLOADS_DONE
-                self._install_queue.put(install_job)
+                print(f'DEBUG: putting {install_job.id} into the install queue')
+                self._put_in_queue(install_job)
 
             # Let other threads know that the number of downloads has changed
             self._downloads_changed_event.set()
```
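Two defensive changes in this callback: the unguarded `self._download_cache[download_job.source]` lookup is collapsed into a single `dict.pop(key, None)`, and the condition that follows checks the result for `None` first. That closes the window in which a second callback for the same source raises `KeyError` after the first has already removed the entry. A small sketch of the idiom, with illustrative names:

```python
cache: dict[str, object] = {}

def on_complete(source: str) -> None:
    # pop() with a default never raises KeyError, so a repeated or racing
    # callback for the same source degrades to a harmless no-op.
    job = cache.pop(source, None)
    if job is not None:
        print(f"finalizing {job} exactly once")

cache["model-a"] = "job-1"
on_complete("model-a")  # finalizes job-1
on_complete("model-a")  # no-op instead of KeyError
```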
```diff
@@ -830,7 +842,7 @@ class ModelInstallService(ModelInstallServiceBase):
 
         if all(x.in_terminal_state for x in install_job.download_parts):
             # When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
-            self._install_queue.put(install_job)
+            self._put_in_queue(install_job)
 
    # ------------------------------------------------------------------------------------------------
    # Internal methods that put events on the event bus
```
In `ModelProbe`:

```diff
@@ -132,11 +132,12 @@ class ModelProbe(object):
 
         format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
         model_info = None
-        model_type = None
-        if format_type is ModelFormat.Diffusers:
-            model_type = cls.get_model_type_from_folder(model_path)
-        else:
-            model_type = cls.get_model_type_from_checkpoint(model_path)
+        model_type = ModelType(fields["type"]) if "type" in fields and fields["type"] else None
+        if not model_type:
+            if format_type is ModelFormat.Diffusers:
+                model_type = cls.get_model_type_from_folder(model_path)
+            else:
+                model_type = cls.get_model_type_from_checkpoint(model_path)
         format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type
 
         probe_class = cls.PROBES[format_type].get(model_type)
```
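With this change an explicitly supplied `type` in `fields` takes precedence, and probing the folder or checkpoint contents becomes a fallback for when no type was given. This is the behavior exercised by the new `test_heuristic_import_with_type` test below.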
```diff
@@ -156,7 +157,7 @@ class ModelProbe(object):
         fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
         fields["name"] = fields.get("name") or cls.get_model_name(model_path)
         fields["description"] = (
-            fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
+            fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
         )
         fields["format"] = fields.get("format") or probe.get_format()
         fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
```
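The default description now reads the resolved `model_type` local rather than `fields['type']`, presumably because `fields['type']` may be unset when the caller did not supply one, while `model_type` is always resolved by this point.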
In the model installer tests:

```diff
@@ -5,10 +5,11 @@ Test the model installer
 import platform
 import uuid
 from pathlib import Path
+from typing import Any, Dict
 
 import pytest
 from pydantic import ValidationError
-from pydantic.networks import Url
+from pydantic_core import Url
 
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.app.services.events.events_base import EventServiceBase
```
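`Url` is now imported from `pydantic_core`, where the concrete class lives in Pydantic 2, and `Any`/`Dict` are brought in for the config dictionaries used by the new test below.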
```diff
@@ -20,7 +21,7 @@ from invokeai.app.services.model_install import (
     URLModelSource,
 )
 from invokeai.app.services.model_records import UnknownModelException
-from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
+from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
 from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403
 
 OS = platform.uname().system
```
```diff
@@ -53,7 +54,7 @@ def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_file
 
 def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
     key = None
-    with pytest.raises(ValidationError):
+    with pytest.raises((ValidationError, InvalidModelConfigException)):
         key = mm2_installer.register_path(embedding_file, {"name": "banana_sushi", "type": ModelType("lora")})
     assert key is None
 
```
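Because the probe now honors an explicit `type`, a bad override can surface as `InvalidModelConfigException` instead of a Pydantic `ValidationError`, so the test accepts either. `pytest.raises` takes a tuple of exception types, just as `except` does; a minimal illustration:

```python
import pytest

def test_accepts_either_exception() -> None:
    # The context manager passes if the block raises any one of the listed types.
    with pytest.raises((ValueError, KeyError)):
        raise KeyError("either exception type satisfies pytest.raises")
```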
```diff
@@ -263,3 +264,50 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig
     assert job.error
     assert "NOT FOUND" in job.error
     assert job.error_traceback.startswith("Traceback")
+
+
+@pytest.mark.parametrize(
+    "model_params",
+    [
+        # SDXL, Lora
+        {
+            "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
+            "name": "test_lora",
+            "type": "embedding",
+        },
+        # SDXL, Lora - incorrect type
+        {
+            "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
+            "name": "test_lora",
+            "type": "lora",
+        },
+    ],
+)
+@pytest.mark.timeout(timeout=40, method="thread")
+def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
+    """Test whether or not type is respected on configs when passed to heuristic import."""
+    assert "name" in model_params and "type" in model_params
+    config1: Dict[str, Any] = {
+        "name": f"{model_params['name']}_1",
+        "type": model_params["type"],
+        "hash": "placeholder1",
+    }
+    config2: Dict[str, Any] = {
+        "name": f"{model_params['name']}_2",
+        "type": ModelType(model_params["type"]),
+        "hash": "placeholder2",
+    }
+    assert "repo_id" in model_params
+    install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
+    mm2_installer.wait_for_job(install_job1, timeout=20)
+    if model_params["type"] != "embedding":
+        assert install_job1.errored
+        assert install_job1.error_type == "InvalidModelConfigException"
+        return
+    assert install_job1.complete
+    assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
+
+    install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
+    mm2_installer.wait_for_job(install_job2, timeout=20)
+    assert install_job2.complete
+    assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out
```
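The two parametrize cases run the same file through `heuristic_import` with a correct type (`embedding`) and an incorrect one (`lora`): the correct case must install cleanly, while the incorrect one must fail with `InvalidModelConfigException`. Inside the test, `config1` passes the type as a plain string and `config2` as a `ModelType` enum, so both spellings of an explicit type override are verified.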
In the test fixtures:

```diff
@@ -33,6 +33,7 @@ from invokeai.backend.model_manager.config import (
 from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
 from invokeai.backend.util.logging import InvokeAILogger
 from tests.backend.model_manager.model_metadata.metadata_examples import (
+    HFTestLoraMetadata,
     RepoCivitaiModelMetadata1,
     RepoCivitaiVersionMetadata1,
     RepoHFMetadata1,
```
```diff
@@ -300,6 +301,20 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
             headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
         ),
     )
+    sess.mount(
+        "https://huggingface.co/api/models/InvokeAI-test/textual_inversion_tests?blobs=True",
+        TestAdapter(
+            HFTestLoraMetadata,
+            headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(HFTestLoraMetadata)},
+        ),
+    )
+    sess.mount(
+        "https://huggingface.co/InvokeAI-test/textual_inversion_tests/resolve/main/learned_embeds-steps-1000.safetensors",
+        TestAdapter(
+            data,
+            headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(data)},
+        ),
+    )
     for root, _, files in os.walk(diffusers_dir):
         for name in files:
             path = Path(root, name)
```
File diff suppressed because one or more lines are too long
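The new mounts stub out the Hugging Face endpoints that the type-override test hits, so it runs without network access. The fixture appears to use the requests-testadapter package, whose `TestAdapter` answers any request under the mounted URL prefix with canned bytes; a self-contained sketch of the pattern, with a hypothetical URL:

```python
from requests import Session
from requests_testadapter import TestAdapter  # pip install requests-testadapter

sess = Session()
payload = b'{"ok": true}'

# Session.mount routes every request whose URL starts with the given prefix
# through this adapter instead of the real network.
sess.mount(
    "https://example.test/api/",  # hypothetical prefix, for illustration only
    TestAdapter(payload, headers={"Content-Type": "application/json"}),
)

resp = sess.get("https://example.test/api/models")
assert resp.status_code == 200
assert resp.json() == {"ok": True}
```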