Compare commits

...

5 Commits

5 changed files with 102 additions and 96 deletions

View File

@ -87,6 +87,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
self._queue.queue.clear() self._queue.queue.clear()
self.join() # wait for all active jobs to finish self.join() # wait for all active jobs to finish
self._stop_event.set() self._stop_event.set()
for thread in self._worker_pool:
thread.join()
self._worker_pool.clear() self._worker_pool.clear()
def submit_download_job( def submit_download_job(

View File

@ -34,6 +34,7 @@ from invokeai.backend.model_manager.config import (
from invokeai.backend.model_manager.metadata import ( from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata, AnyModelRepoMetadata,
HuggingFaceMetadataFetch, HuggingFaceMetadataFetch,
ModelMetadataFetchBase,
ModelMetadataWithFiles, ModelMetadataWithFiles,
RemoteModelFile, RemoteModelFile,
) )
@ -92,6 +93,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {} self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
self._running = False self._running = False
self._session = session self._session = session
self._install_thread: Optional[threading.Thread] = None
self._next_job_id = 0 self._next_job_id = 0
@property @property
@ -126,6 +128,8 @@ class ModelInstallService(ModelInstallServiceBase):
self._stop_event.set() self._stop_event.set()
self._clear_pending_jobs() self._clear_pending_jobs()
self._download_cache.clear() self._download_cache.clear()
assert self._install_thread is not None
self._install_thread.join()
self._running = False self._running = False
def _clear_pending_jobs(self) -> None: def _clear_pending_jobs(self) -> None:
@ -275,6 +279,7 @@ class ModelInstallService(ModelInstallServiceBase):
if timeout > 0 and time.time() - start > timeout: if timeout > 0 and time.time() - start > timeout:
raise TimeoutError("Timeout exceeded") raise TimeoutError("Timeout exceeded")
self._install_queue.join() self._install_queue.join()
return self._install_jobs return self._install_jobs
def cancel_job(self, job: ModelInstallJob) -> None: def cancel_job(self, job: ModelInstallJob) -> None:
@ -415,15 +420,16 @@ class ModelInstallService(ModelInstallServiceBase):
# Internal functions that manage the installer threads # Internal functions that manage the installer threads
# -------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------
def _start_installer_thread(self) -> None: def _start_installer_thread(self) -> None:
threading.Thread(target=self._install_next_item, daemon=True).start() self._install_thread = threading.Thread(target=self._install_next_item, daemon=True)
self._install_thread.start()
self._running = True self._running = True
def _install_next_item(self) -> None: def _install_next_item(self) -> None:
done = False self._logger.info(f"Installer thread {threading.get_ident()} starting")
while not done: while True:
if self._stop_event.is_set(): if self._stop_event.is_set():
done = True break
continue self._logger.info(f"Installer thread {threading.get_ident()} running")
try: try:
job = self._install_queue.get(timeout=1) job = self._install_queue.get(timeout=1)
except Empty: except Empty:
@ -436,9 +442,25 @@ class ModelInstallService(ModelInstallServiceBase):
elif job.errored: elif job.errored:
self._signal_job_errored(job) self._signal_job_errored(job)
elif ( elif job.waiting or job.downloads_done:
job.waiting or job.downloads_done self._register_or_install(job)
): # local jobs will be in waiting state, remote jobs will be downloading state
except InvalidModelConfigException as excp:
self._set_error(job, excp)
except (OSError, DuplicateModelException) as excp:
self._set_error(job, excp)
finally:
# if this is an install of a remote file, then clean up the temporary directory
if job._install_tmpdir is not None:
rmtree(job._install_tmpdir)
self._install_completed_event.set()
self._install_queue.task_done()
self._logger.info(f"Installer thread {threading.get_ident()} exiting")
def _register_or_install(self, job: ModelInstallJob) -> None:
# local jobs will be in waiting state, remote jobs will be downloading state
job.total_bytes = self._stat_size(job.local_path) job.total_bytes = self._stat_size(job.local_path)
job.bytes = job.total_bytes job.bytes = job.total_bytes
self._signal_job_running(job) self._signal_job_running(job)
@ -455,7 +477,7 @@ class ModelInstallService(ModelInstallServiceBase):
job.config_out = self.record_store.get_model(key) job.config_out = self.record_store.get_model(key)
self._signal_job_completed(job) self._signal_job_completed(job)
except InvalidModelConfigException as excp: def _set_error(self, job: ModelInstallJob, excp: Exception) -> None:
if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts): if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts):
job.set_error( job.set_error(
InvalidModelConfigException( InvalidModelConfigException(
@ -466,17 +488,6 @@ class ModelInstallService(ModelInstallServiceBase):
job.set_error(excp) job.set_error(excp)
self._signal_job_errored(job) self._signal_job_errored(job)
except (OSError, DuplicateModelException) as excp:
job.set_error(excp)
self._signal_job_errored(job)
finally:
# if this is an install of a remote file, then clean up the temporary directory
if job._install_tmpdir is not None:
rmtree(job._install_tmpdir)
self._install_completed_event.set()
self._install_queue.task_done()
# -------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------
# Internal functions that manage the models directory # Internal functions that manage the models directory
# -------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------
@ -905,7 +916,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id) self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)
@staticmethod @staticmethod
def get_fetcher_from_url(url: str): def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()): if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
return HuggingFaceMetadataFetch return HuggingFaceMetadataFetch
raise ValueError(f"Unsupported model source: '{url}'") raise ValueError(f"Unsupported model source: '{url}'")

View File

@ -51,6 +51,7 @@ def session() -> Session:
return sess return sess
@pytest.mark.timeout(timeout=20, method="thread")
def test_basic_queue_download(tmp_path: Path, session: Session) -> None: def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
events = set() events = set()
@ -80,6 +81,7 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
queue.stop() queue.stop()
@pytest.mark.timeout(timeout=20, method="thread")
def test_errors(tmp_path: Path, session: Session) -> None: def test_errors(tmp_path: Path, session: Session) -> None:
queue = DownloadQueueService( queue = DownloadQueueService(
requests_session=session, requests_session=session,
@ -101,6 +103,7 @@ def test_errors(tmp_path: Path, session: Session) -> None:
queue.stop() queue.stop()
@pytest.mark.timeout(timeout=20, method="thread")
def test_event_bus(tmp_path: Path, session: Session) -> None: def test_event_bus(tmp_path: Path, session: Session) -> None:
event_bus = TestEventService() event_bus = TestEventService()
@ -136,6 +139,7 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
queue.stop() queue.stop()
@pytest.mark.timeout(timeout=20, method="thread")
def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None: def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
queue = DownloadQueueService( queue = DownloadQueueService(
requests_session=session, requests_session=session,

View File

@ -5,6 +5,7 @@ Test the model installer
import platform import platform
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import Any, Dict
import pytest import pytest
from pydantic import ValidationError from pydantic import ValidationError
@ -276,48 +277,48 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In
# TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test # TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test
# @pytest.mark.parametrize( @pytest.mark.parametrize(
# "model_params", "model_params",
# [ [
# # SDXL, Lora # SDXL, Lora
# { {
# "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors", "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
# "name": "test_lora", "name": "test_lora",
# "type": "embedding", "type": "embedding",
# }, },
# # SDXL, Lora - incorrect type # SDXL, Lora - incorrect type
# { {
# "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors", "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
# "name": "test_lora", "name": "test_lora",
# "type": "lora", "type": "lora",
# }, },
# ], ],
# ) )
# @pytest.mark.timeout(timeout=40, method="thread") @pytest.mark.timeout(timeout=40, method="thread")
# def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]): def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
# """Test whether or not type is respected on configs when passed to heuristic import.""" """Test whether or not type is respected on configs when passed to heuristic import."""
# assert "name" in model_params and "type" in model_params assert "name" in model_params and "type" in model_params
# config1: Dict[str, Any] = { config1: Dict[str, Any] = {
# "name": f"{model_params['name']}_1", "name": f"{model_params['name']}_1",
# "type": model_params["type"], "type": model_params["type"],
# "hash": "placeholder1", "hash": "placeholder1",
# } }
# config2: Dict[str, Any] = { config2: Dict[str, Any] = {
# "name": f"{model_params['name']}_2", "name": f"{model_params['name']}_2",
# "type": ModelType(model_params["type"]), "type": ModelType(model_params["type"]),
# "hash": "placeholder2", "hash": "placeholder2",
# } }
# assert "repo_id" in model_params assert "repo_id" in model_params
# install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1) install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
# mm2_installer.wait_for_job(install_job1, timeout=20) mm2_installer.wait_for_job(install_job1, timeout=20)
# if model_params["type"] != "embedding": if model_params["type"] != "embedding":
# assert install_job1.errored assert install_job1.errored
# assert install_job1.error_type == "InvalidModelConfigException" assert install_job1.error_type == "InvalidModelConfigException"
# return return
# assert install_job1.complete assert install_job1.complete
# assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
# install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2) install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
# mm2_installer.wait_for_job(install_job2, timeout=20) mm2_installer.wait_for_job(install_job2, timeout=20)
# assert install_job2.complete assert install_job2.complete
# assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out

View File

@ -2,13 +2,11 @@
import os import os
import shutil import shutil
import time
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List from typing import Any, Dict, List
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
from pytest import FixtureRequest
from requests.sessions import Session from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession from requests_testadapter import TestAdapter, TestSession
@ -99,16 +97,12 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
@pytest.fixture @pytest.fixture
def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> DownloadQueueServiceBase: def mm2_download_queue(mm2_session: Session) -> DownloadQueueServiceBase:
download_queue = DownloadQueueService(requests_session=mm2_session) download_queue = DownloadQueueService(requests_session=mm2_session)
download_queue.start() download_queue.start()
yield download_queue
def stop_queue() -> None:
download_queue.stop() download_queue.stop()
request.addfinalizer(stop_queue)
return download_queue
@pytest.fixture @pytest.fixture
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase: def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
@ -130,7 +124,6 @@ def mm2_installer(
mm2_app_config: InvokeAIAppConfig, mm2_app_config: InvokeAIAppConfig,
mm2_download_queue: DownloadQueueServiceBase, mm2_download_queue: DownloadQueueServiceBase,
mm2_session: Session, mm2_session: Session,
request: FixtureRequest,
) -> ModelInstallServiceBase: ) -> ModelInstallServiceBase:
logger = InvokeAILogger.get_logger() logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(mm2_app_config, logger) db = create_mock_sqlite_database(mm2_app_config, logger)
@ -145,13 +138,8 @@ def mm2_installer(
session=mm2_session, session=mm2_session,
) )
installer.start() installer.start()
yield installer
def stop_installer() -> None:
installer.stop() installer.stop()
time.sleep(0.1) # avoid error message from the logger when it is closed before thread prints final message
request.addfinalizer(stop_installer)
return installer
@pytest.fixture @pytest.fixture