Compare commits

...

5 Commits

5 changed files with 102 additions and 96 deletions

View File

@@ -87,6 +87,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
self._queue.queue.clear()
self.join() # wait for all active jobs to finish
self._stop_event.set()
for thread in self._worker_pool:
thread.join()
self._worker_pool.clear()
def submit_download_job(
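
For context, the shutdown sequence in the hunk above follows a standard worker-pool teardown: drop pending work, signal a stop event, then join every worker before clearing the pool. Below is a minimal, self-contained sketch of that pattern using the same Event/Queue primitives; the class and method names are illustrative, not the service's real API.

    import queue
    import threading


    class WorkerPool:
        """Toy pool whose stop() mirrors the sequence above: drain, signal, join, clear."""

        def __init__(self, n_workers: int = 2) -> None:
            self._queue: "queue.Queue[int]" = queue.Queue()
            self._stop_event = threading.Event()
            self._worker_pool = [
                threading.Thread(target=self._worker, daemon=True) for _ in range(n_workers)
            ]
            for thread in self._worker_pool:
                thread.start()

        def submit(self, item: int) -> None:
            self._queue.put(item)

        def _worker(self) -> None:
            while not self._stop_event.is_set():
                try:
                    # short timeout so the stop event is re-checked regularly
                    item = self._queue.get(timeout=1)
                except queue.Empty:
                    continue
                print(f"processed {item}")
                self._queue.task_done()

        def stop(self) -> None:
            try:
                while True:                 # drop any work that has not started yet
                    self._queue.get_nowait()
                    self._queue.task_done()
            except queue.Empty:
                pass
            self._stop_event.set()          # ask the workers to exit their loops
            for thread in self._worker_pool:
                thread.join()               # wait for each worker thread to finish
            self._worker_pool.clear()


    pool = WorkerPool()
    pool.submit(1)
    pool.stop()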

View File

@@ -34,6 +34,7 @@ from invokeai.backend.model_manager.config import (
from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata,
HuggingFaceMetadataFetch,
ModelMetadataFetchBase,
ModelMetadataWithFiles,
RemoteModelFile,
)
@@ -92,6 +93,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
self._running = False
self._session = session
self._install_thread: Optional[threading.Thread] = None
self._next_job_id = 0
@property
@@ -126,6 +128,8 @@ class ModelInstallService(ModelInstallServiceBase):
self._stop_event.set()
self._clear_pending_jobs()
self._download_cache.clear()
assert self._install_thread is not None
self._install_thread.join()
self._running = False
def _clear_pending_jobs(self) -> None:
@@ -275,6 +279,7 @@ class ModelInstallService(ModelInstallServiceBase):
if timeout > 0 and time.time() - start > timeout:
raise TimeoutError("Timeout exceeded")
self._install_queue.join()
return self._install_jobs
def cancel_job(self, job: ModelInstallJob) -> None:
@@ -415,15 +420,16 @@ class ModelInstallService(ModelInstallServiceBase):
# Internal functions that manage the installer threads
# --------------------------------------------------------------------------------------------
def _start_installer_thread(self) -> None:
threading.Thread(target=self._install_next_item, daemon=True).start()
self._install_thread = threading.Thread(target=self._install_next_item, daemon=True)
self._install_thread.start()
self._running = True
def _install_next_item(self) -> None:
done = False
while not done:
self._logger.info(f"Installer thread {threading.get_ident()} starting")
while True:
if self._stop_event.is_set():
done = True
continue
break
self._logger.info(f"Installer thread {threading.get_ident()} running")
try:
job = self._install_queue.get(timeout=1)
except Empty:
@@ -436,39 +442,14 @@ class ModelInstallService(ModelInstallServiceBase):
elif job.errored:
self._signal_job_errored(job)
elif (
job.waiting or job.downloads_done
): # local jobs will be in waiting state, remote jobs will be downloading state
job.total_bytes = self._stat_size(job.local_path)
job.bytes = job.total_bytes
self._signal_job_running(job)
job.config_in["source"] = str(job.source)
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
# enter the metadata, if there is any
if isinstance(job.source_metadata, (HuggingFaceMetadata)):
job.config_in["source_api_response"] = job.source_metadata.api_response
if job.inplace:
key = self.register_path(job.local_path, job.config_in)
else:
key = self.install_path(job.local_path, job.config_in)
job.config_out = self.record_store.get_model(key)
self._signal_job_completed(job)
elif job.waiting or job.downloads_done:
self._register_or_install(job)
except InvalidModelConfigException as excp:
if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts):
job.set_error(
InvalidModelConfigException(
f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
)
)
else:
job.set_error(excp)
self._signal_job_errored(job)
self._set_error(job, excp)
except (OSError, DuplicateModelException) as excp:
job.set_error(excp)
self._signal_job_errored(job)
self._set_error(job, excp)
finally:
# if this is an install of a remote file, then clean up the temporary directory
@@ -476,6 +457,36 @@ class ModelInstallService(ModelInstallServiceBase):
rmtree(job._install_tmpdir)
self._install_completed_event.set()
self._install_queue.task_done()
self._logger.info(f"Installer thread {threading.get_ident()} exiting")
def _register_or_install(self, job: ModelInstallJob) -> None:
# local jobs will be in waiting state, remote jobs will be downloading state
job.total_bytes = self._stat_size(job.local_path)
job.bytes = job.total_bytes
self._signal_job_running(job)
job.config_in["source"] = str(job.source)
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
# enter the metadata, if there is any
if isinstance(job.source_metadata, (HuggingFaceMetadata)):
job.config_in["source_api_response"] = job.source_metadata.api_response
if job.inplace:
key = self.register_path(job.local_path, job.config_in)
else:
key = self.install_path(job.local_path, job.config_in)
job.config_out = self.record_store.get_model(key)
self._signal_job_completed(job)
def _set_error(self, job: ModelInstallJob, excp: Exception) -> None:
if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts):
job.set_error(
InvalidModelConfigException(
f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
)
)
else:
job.set_error(excp)
self._signal_job_errored(job)
# --------------------------------------------------------------------------------------------
# Internal functions that manage the models directory
@@ -905,7 +916,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)
@staticmethod
def get_fetcher_from_url(url: str):
def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
return HuggingFaceMetadataFetch
raise ValueError(f"Unsupported model source: '{url}'")
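
The installer hunks above replace an anonymous daemon thread with a stored thread handle plus an explicit stop path (set the event, join the thread). Below is a minimal sketch of that single-worker lifecycle under the same assumptions; the job handling is a placeholder, not the real register/install logic.

    import queue
    import threading
    from typing import Optional


    class Installer:
        """Illustrative single-thread worker with an explicit start/stop lifecycle."""

        def __init__(self) -> None:
            self._install_queue: "queue.Queue[str]" = queue.Queue()
            self._stop_event = threading.Event()
            self._install_thread: Optional[threading.Thread] = None
            self._running = False

        def start(self) -> None:
            # keep a handle to the thread so stop() can join it later
            self._install_thread = threading.Thread(target=self._install_next_item, daemon=True)
            self._install_thread.start()
            self._running = True

        def submit(self, job: str) -> None:
            self._install_queue.put(job)

        def _install_next_item(self) -> None:
            while True:
                if self._stop_event.is_set():
                    break                          # leave the loop instead of flipping a flag
                try:
                    job = self._install_queue.get(timeout=1)
                except queue.Empty:
                    continue
                try:
                    print(f"installing {job}")     # placeholder for the real install work
                finally:
                    self._install_queue.task_done()

        def stop(self) -> None:
            self._stop_event.set()
            assert self._install_thread is not None
            self._install_thread.join()            # block until the worker thread exits
            self._running = False


    installer = Installer()
    installer.start()
    installer.submit("model.safetensors")
    installer.stop()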

View File

@@ -51,6 +51,7 @@ def session() -> Session:
return sess
@pytest.mark.timeout(timeout=20, method="thread")
def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
events = set()
@@ -80,6 +81,7 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
queue.stop()
@pytest.mark.timeout(timeout=20, method="thread")
def test_errors(tmp_path: Path, session: Session) -> None:
queue = DownloadQueueService(
requests_session=session,
@@ -101,6 +103,7 @@ def test_errors(tmp_path: Path, session: Session) -> None:
queue.stop()
@pytest.mark.timeout(timeout=20, method="thread")
def test_event_bus(tmp_path: Path, session: Session) -> None:
event_bus = TestEventService()
@@ -136,6 +139,7 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
queue.stop()
@pytest.mark.timeout(timeout=20, method="thread")
def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
queue = DownloadQueueService(
requests_session=session,
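
The download-queue tests above gain @pytest.mark.timeout(timeout=20, method="thread") so a stuck worker thread fails the test instead of hanging the run. A stripped-down example of that marker, assuming the pytest-timeout plugin is installed (the test body itself is a placeholder):

    import time

    import pytest


    @pytest.mark.timeout(timeout=5, method="thread")
    def test_finishes_before_timeout() -> None:
        # With method="thread", pytest-timeout watches from a separate thread and
        # fails the test if it is still running after 5 seconds.
        time.sleep(0.1)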

View File

@@ -5,6 +5,7 @@ Test the model installer
import platform
import uuid
from pathlib import Path
from typing import Any, Dict
import pytest
from pydantic import ValidationError
@@ -276,48 +277,48 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In
# TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test
# @pytest.mark.parametrize(
# "model_params",
# [
# # SDXL, Lora
# {
# "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
# "name": "test_lora",
# "type": "embedding",
# },
# # SDXL, Lora - incorrect type
# {
# "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
# "name": "test_lora",
# "type": "lora",
# },
# ],
# )
# @pytest.mark.timeout(timeout=40, method="thread")
# def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
# """Test whether or not type is respected on configs when passed to heuristic import."""
# assert "name" in model_params and "type" in model_params
# config1: Dict[str, Any] = {
# "name": f"{model_params['name']}_1",
# "type": model_params["type"],
# "hash": "placeholder1",
# }
# config2: Dict[str, Any] = {
# "name": f"{model_params['name']}_2",
# "type": ModelType(model_params["type"]),
# "hash": "placeholder2",
# }
# assert "repo_id" in model_params
# install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
# mm2_installer.wait_for_job(install_job1, timeout=20)
# if model_params["type"] != "embedding":
# assert install_job1.errored
# assert install_job1.error_type == "InvalidModelConfigException"
# return
# assert install_job1.complete
# assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
@pytest.mark.parametrize(
"model_params",
[
# SDXL, Lora
{
"repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
"name": "test_lora",
"type": "embedding",
},
# SDXL, Lora - incorrect type
{
"repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
"name": "test_lora",
"type": "lora",
},
],
)
@pytest.mark.timeout(timeout=40, method="thread")
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
"""Test whether or not type is respected on configs when passed to heuristic import."""
assert "name" in model_params and "type" in model_params
config1: Dict[str, Any] = {
"name": f"{model_params['name']}_1",
"type": model_params["type"],
"hash": "placeholder1",
}
config2: Dict[str, Any] = {
"name": f"{model_params['name']}_2",
"type": ModelType(model_params["type"]),
"hash": "placeholder2",
}
assert "repo_id" in model_params
install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
mm2_installer.wait_for_job(install_job1, timeout=20)
if model_params["type"] != "embedding":
assert install_job1.errored
assert install_job1.error_type == "InvalidModelConfigException"
return
assert install_job1.complete
assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
# install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
# mm2_installer.wait_for_job(install_job2, timeout=20)
# assert install_job2.complete
# assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out
install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
mm2_installer.wait_for_job(install_job2, timeout=20)
assert install_job2.complete
assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out
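
For reference, the re-enabled test above runs once per parameter dict via pytest.mark.parametrize. A stripped-down version of that pattern with placeholder assertions (it does not exercise the installer itself):

    from typing import Dict

    import pytest


    @pytest.mark.parametrize(
        "model_params",
        [
            {"name": "test_lora", "type": "embedding"},
            {"name": "test_lora", "type": "lora"},
        ],
    )
    def test_type_drives_expectation(model_params: Dict[str, str]) -> None:
        # Each dict above becomes its own test case; branch on the "type" key the
        # same way the real test decides between the success and error paths.
        if model_params["type"] == "embedding":
            assert model_params["name"].startswith("test")
        else:
            assert model_params["type"] == "lora"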

View File

@@ -2,13 +2,11 @@
import os
import shutil
import time
from pathlib import Path
from typing import Any, Dict, List
import pytest
from pydantic import BaseModel
from pytest import FixtureRequest
from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession
@@ -99,15 +97,11 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
@pytest.fixture
def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> DownloadQueueServiceBase:
def mm2_download_queue(mm2_session: Session) -> DownloadQueueServiceBase:
download_queue = DownloadQueueService(requests_session=mm2_session)
download_queue.start()
def stop_queue() -> None:
download_queue.stop()
request.addfinalizer(stop_queue)
return download_queue
yield download_queue
download_queue.stop()
@pytest.fixture
@@ -130,7 +124,6 @@ def mm2_installer(
mm2_app_config: InvokeAIAppConfig,
mm2_download_queue: DownloadQueueServiceBase,
mm2_session: Session,
request: FixtureRequest,
) -> ModelInstallServiceBase:
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(mm2_app_config, logger)
@@ -145,13 +138,8 @@ def mm2_installer(
session=mm2_session,
)
installer.start()
def stop_installer() -> None:
installer.stop()
time.sleep(0.1) # avoid error message from the logger when it is closed before thread prints final message
request.addfinalizer(stop_installer)
return installer
yield installer
installer.stop()
@pytest.fixture
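
The conftest changes above swap request.addfinalizer teardown for pytest's yield-fixture style, which keeps setup and cleanup in one function and drops the FixtureRequest parameter. A minimal sketch of the two equivalent forms, using a stand-in resource rather than the real download queue or installer:

    from typing import Iterator

    import pytest


    class Resource:
        """Stand-in for a service with an explicit start/stop lifecycle."""

        def start(self) -> None:
            print("started")

        def stop(self) -> None:
            print("stopped")


    # Old style: return the object and register teardown via request.addfinalizer.
    @pytest.fixture
    def resource_old(request: pytest.FixtureRequest) -> Resource:
        res = Resource()
        res.start()
        request.addfinalizer(res.stop)
        return res


    # New style (as in the diff above): yield the object; code after the yield is teardown.
    @pytest.fixture
    def resource() -> Iterator[Resource]:
        res = Resource()
        res.start()
        yield res
        res.stop()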