mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
18 Commits
v4.0.1
...
bug-instal
Author | SHA1 | Date | |
---|---|---|---|
7d46e8430b | |||
b3f8f22998 | |||
6a75c5ba08 | |||
dcbb1ff894 | |||
5751455618 | |||
a2232c2e09 | |||
30da11998b | |||
1f000306f3 | |||
c778d74a42 | |||
97bcd54408 | |||
62d7e38030 | |||
9b6f3bded9 | |||
aec567179d | |||
28bfc1c935 | |||
39f62ac63c | |||
9d0952c2ef | |||
902e26507d | |||
b83427d7ce |
1
.github/workflows/python-tests.yml
vendored
1
.github/workflows/python-tests.yml
vendored
@ -9,6 +9,7 @@ on:
|
|||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- 'main'
|
- 'main'
|
||||||
|
- 'bug-install-job-running-multiple-times'
|
||||||
pull_request:
|
pull_request:
|
||||||
types:
|
types:
|
||||||
- 'ready_for_review'
|
- 'ready_for_review'
|
||||||
|
@ -133,6 +133,14 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._download_cache.clear()
|
self._download_cache.clear()
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
|
def _put_in_queue(self, job: ModelInstallJob) -> None:
|
||||||
|
print(f'DEBUG: in _put_in_queue(job={job.id})')
|
||||||
|
if self._stop_event.is_set():
|
||||||
|
self.cancel_job(job)
|
||||||
|
else:
|
||||||
|
print(f'DEBUG: putting {job.id} into the install queue')
|
||||||
|
self._install_queue.put(job)
|
||||||
|
|
||||||
def register_path(
|
def register_path(
|
||||||
self,
|
self,
|
||||||
model_path: Union[Path, str],
|
model_path: Union[Path, str],
|
||||||
@ -218,7 +226,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
if isinstance(source, LocalModelSource):
|
if isinstance(source, LocalModelSource):
|
||||||
install_job = self._import_local_model(source, config)
|
install_job = self._import_local_model(source, config)
|
||||||
self._install_queue.put(install_job) # synchronously install
|
self._put_in_queue(install_job) # synchronously install
|
||||||
elif isinstance(source, HFModelSource):
|
elif isinstance(source, HFModelSource):
|
||||||
install_job = self._import_from_hf(source, config)
|
install_job = self._import_from_hf(source, config)
|
||||||
elif isinstance(source, URLModelSource):
|
elif isinstance(source, URLModelSource):
|
||||||
@ -403,10 +411,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
done = True
|
done = True
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
|
print(f'DEBUG: _install_next_item() checking for a job to install')
|
||||||
job = self._install_queue.get(timeout=1)
|
job = self._install_queue.get(timeout=1)
|
||||||
except Empty:
|
except Empty:
|
||||||
continue
|
continue
|
||||||
|
print(f'DEBUG: _install_next_item() got job {job.id}, status={job.status}')
|
||||||
assert job.local_path is not None
|
assert job.local_path is not None
|
||||||
try:
|
try:
|
||||||
if job.cancelled:
|
if job.cancelled:
|
||||||
@ -432,6 +441,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
else:
|
else:
|
||||||
key = self.install_path(job.local_path, job.config_in)
|
key = self.install_path(job.local_path, job.config_in)
|
||||||
job.config_out = self.record_store.get_model(key)
|
job.config_out = self.record_store.get_model(key)
|
||||||
|
print(f'DEBUG: _install_next_item() signaling completion for job={job.id}, status={job.status}')
|
||||||
self._signal_job_completed(job)
|
self._signal_job_completed(job)
|
||||||
|
|
||||||
except InvalidModelConfigException as excp:
|
except InvalidModelConfigException as excp:
|
||||||
@ -781,14 +791,16 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
def _download_complete_callback(self, download_job: DownloadJob) -> None:
|
def _download_complete_callback(self, download_job: DownloadJob) -> None:
|
||||||
self._logger.info(f"{download_job.source}: model download complete")
|
self._logger.info(f"{download_job.source}: model download complete")
|
||||||
|
print(f'DEBUG: _download_complete_callback(download_job={download_job.source}')
|
||||||
with self._lock:
|
with self._lock:
|
||||||
install_job = self._download_cache[download_job.source]
|
install_job = self._download_cache.pop(download_job.source, None)
|
||||||
self._download_cache.pop(download_job.source, None)
|
print(f'DEBUG: download_job={download_job.source} / install_job={install_job}')
|
||||||
|
|
||||||
# are there any more active jobs left in this task?
|
# are there any more active jobs left in this task?
|
||||||
if install_job.downloading and all(x.complete for x in install_job.download_parts):
|
if install_job and install_job.downloading and all(x.complete for x in install_job.download_parts):
|
||||||
|
print(f'DEBUG: setting job {install_job.id} to DOWNLOADS_DONE')
|
||||||
install_job.status = InstallStatus.DOWNLOADS_DONE
|
install_job.status = InstallStatus.DOWNLOADS_DONE
|
||||||
self._install_queue.put(install_job)
|
print(f'DEBUG: putting {install_job.id} into the install queue')
|
||||||
|
self._put_in_queue(install_job)
|
||||||
|
|
||||||
# Let other threads know that the number of downloads has changed
|
# Let other threads know that the number of downloads has changed
|
||||||
self._downloads_changed_event.set()
|
self._downloads_changed_event.set()
|
||||||
@ -830,7 +842,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
if all(x.in_terminal_state for x in install_job.download_parts):
|
if all(x.in_terminal_state for x in install_job.download_parts):
|
||||||
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
|
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
|
||||||
self._install_queue.put(install_job)
|
self._put_in_queue(install_job)
|
||||||
|
|
||||||
# ------------------------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------------------------
|
||||||
# Internal methods that put events on the event bus
|
# Internal methods that put events on the event bus
|
||||||
|
@ -132,11 +132,12 @@ class ModelProbe(object):
|
|||||||
|
|
||||||
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
||||||
model_info = None
|
model_info = None
|
||||||
model_type = None
|
model_type = ModelType(fields["type"]) if "type" in fields and fields["type"] else None
|
||||||
if format_type is ModelFormat.Diffusers:
|
if not model_type:
|
||||||
model_type = cls.get_model_type_from_folder(model_path)
|
if format_type is ModelFormat.Diffusers:
|
||||||
else:
|
model_type = cls.get_model_type_from_folder(model_path)
|
||||||
model_type = cls.get_model_type_from_checkpoint(model_path)
|
else:
|
||||||
|
model_type = cls.get_model_type_from_checkpoint(model_path)
|
||||||
format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type
|
format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type
|
||||||
|
|
||||||
probe_class = cls.PROBES[format_type].get(model_type)
|
probe_class = cls.PROBES[format_type].get(model_type)
|
||||||
@ -156,7 +157,7 @@ class ModelProbe(object):
|
|||||||
fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
|
fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
|
||||||
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
|
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
|
||||||
fields["description"] = (
|
fields["description"] = (
|
||||||
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
|
||||||
)
|
)
|
||||||
fields["format"] = fields.get("format") or probe.get_format()
|
fields["format"] = fields.get("format") or probe.get_format()
|
||||||
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
|
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
|
||||||
|
@ -5,10 +5,11 @@ Test the model installer
|
|||||||
import platform
|
import platform
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from pydantic.networks import Url
|
from pydantic_core import Url
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
@ -20,7 +21,7 @@ from invokeai.app.services.model_install import (
|
|||||||
URLModelSource,
|
URLModelSource,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.model_records import UnknownModelException
|
from invokeai.app.services.model_records import UnknownModelException
|
||||||
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
|
from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
|
||||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||||
|
|
||||||
OS = platform.uname().system
|
OS = platform.uname().system
|
||||||
@ -53,7 +54,7 @@ def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_fil
|
|||||||
|
|
||||||
def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
||||||
key = None
|
key = None
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises((ValidationError, InvalidModelConfigException)):
|
||||||
key = mm2_installer.register_path(embedding_file, {"name": "banana_sushi", "type": ModelType("lora")})
|
key = mm2_installer.register_path(embedding_file, {"name": "banana_sushi", "type": ModelType("lora")})
|
||||||
assert key is None
|
assert key is None
|
||||||
|
|
||||||
@ -263,3 +264,50 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In
|
|||||||
assert job.error
|
assert job.error
|
||||||
assert "NOT FOUND" in job.error
|
assert "NOT FOUND" in job.error
|
||||||
assert job.error_traceback.startswith("Traceback")
|
assert job.error_traceback.startswith("Traceback")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model_params",
|
||||||
|
[
|
||||||
|
# SDXL, Lora
|
||||||
|
{
|
||||||
|
"repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
|
||||||
|
"name": "test_lora",
|
||||||
|
"type": "embedding",
|
||||||
|
},
|
||||||
|
# SDXL, Lora - incorrect type
|
||||||
|
{
|
||||||
|
"repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
|
||||||
|
"name": "test_lora",
|
||||||
|
"type": "lora",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.timeout(timeout=40, method="thread")
|
||||||
|
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
|
||||||
|
"""Test whether or not type is respected on configs when passed to heuristic import."""
|
||||||
|
assert "name" in model_params and "type" in model_params
|
||||||
|
config1: Dict[str, Any] = {
|
||||||
|
"name": f"{model_params['name']}_1",
|
||||||
|
"type": model_params["type"],
|
||||||
|
"hash": "placeholder1",
|
||||||
|
}
|
||||||
|
config2: Dict[str, Any] = {
|
||||||
|
"name": f"{model_params['name']}_2",
|
||||||
|
"type": ModelType(model_params["type"]),
|
||||||
|
"hash": "placeholder2",
|
||||||
|
}
|
||||||
|
assert "repo_id" in model_params
|
||||||
|
install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
|
||||||
|
mm2_installer.wait_for_job(install_job1, timeout=20)
|
||||||
|
if model_params["type"] != "embedding":
|
||||||
|
assert install_job1.errored
|
||||||
|
assert install_job1.error_type == "InvalidModelConfigException"
|
||||||
|
return
|
||||||
|
assert install_job1.complete
|
||||||
|
assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
|
||||||
|
|
||||||
|
install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
|
||||||
|
mm2_installer.wait_for_job(install_job2, timeout=20)
|
||||||
|
assert install_job2.complete
|
||||||
|
assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out
|
||||||
|
@ -33,6 +33,7 @@ from invokeai.backend.model_manager.config import (
|
|||||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
|
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from tests.backend.model_manager.model_metadata.metadata_examples import (
|
from tests.backend.model_manager.model_metadata.metadata_examples import (
|
||||||
|
HFTestLoraMetadata,
|
||||||
RepoCivitaiModelMetadata1,
|
RepoCivitaiModelMetadata1,
|
||||||
RepoCivitaiVersionMetadata1,
|
RepoCivitaiVersionMetadata1,
|
||||||
RepoHFMetadata1,
|
RepoHFMetadata1,
|
||||||
@ -300,6 +301,20 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
|||||||
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
|
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
sess.mount(
|
||||||
|
"https://huggingface.co/api/models/InvokeAI-test/textual_inversion_tests?blobs=True",
|
||||||
|
TestAdapter(
|
||||||
|
HFTestLoraMetadata,
|
||||||
|
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(HFTestLoraMetadata)},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
sess.mount(
|
||||||
|
"https://huggingface.co/InvokeAI-test/textual_inversion_tests/resolve/main/learned_embeds-steps-1000.safetensors",
|
||||||
|
TestAdapter(
|
||||||
|
data,
|
||||||
|
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(data)},
|
||||||
|
),
|
||||||
|
)
|
||||||
for root, _, files in os.walk(diffusers_dir):
|
for root, _, files in os.walk(diffusers_dir):
|
||||||
for name in files:
|
for name in files:
|
||||||
path = Path(root, name)
|
path = Path(root, name)
|
||||||
|
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user