mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
5 Commits
lstein/doc
...
lstein/deb
Author | SHA1 | Date | |
---|---|---|---|
12f9bda524 | |||
b65eff1c65 | |||
ce687a2869 | |||
e452c6171b | |||
b15d05f8a8 |
@ -87,6 +87,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
self._queue.queue.clear()
|
self._queue.queue.clear()
|
||||||
self.join() # wait for all active jobs to finish
|
self.join() # wait for all active jobs to finish
|
||||||
self._stop_event.set()
|
self._stop_event.set()
|
||||||
|
for thread in self._worker_pool:
|
||||||
|
thread.join()
|
||||||
self._worker_pool.clear()
|
self._worker_pool.clear()
|
||||||
|
|
||||||
def submit_download_job(
|
def submit_download_job(
|
||||||
|
@ -34,6 +34,7 @@ from invokeai.backend.model_manager.config import (
|
|||||||
from invokeai.backend.model_manager.metadata import (
|
from invokeai.backend.model_manager.metadata import (
|
||||||
AnyModelRepoMetadata,
|
AnyModelRepoMetadata,
|
||||||
HuggingFaceMetadataFetch,
|
HuggingFaceMetadataFetch,
|
||||||
|
ModelMetadataFetchBase,
|
||||||
ModelMetadataWithFiles,
|
ModelMetadataWithFiles,
|
||||||
RemoteModelFile,
|
RemoteModelFile,
|
||||||
)
|
)
|
||||||
@ -92,6 +93,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
|
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
|
||||||
self._running = False
|
self._running = False
|
||||||
self._session = session
|
self._session = session
|
||||||
|
self._install_thread: Optional[threading.Thread] = None
|
||||||
self._next_job_id = 0
|
self._next_job_id = 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -126,6 +128,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._stop_event.set()
|
self._stop_event.set()
|
||||||
self._clear_pending_jobs()
|
self._clear_pending_jobs()
|
||||||
self._download_cache.clear()
|
self._download_cache.clear()
|
||||||
|
assert self._install_thread is not None
|
||||||
|
self._install_thread.join()
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
def _clear_pending_jobs(self) -> None:
|
def _clear_pending_jobs(self) -> None:
|
||||||
@ -275,6 +279,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
if timeout > 0 and time.time() - start > timeout:
|
if timeout > 0 and time.time() - start > timeout:
|
||||||
raise TimeoutError("Timeout exceeded")
|
raise TimeoutError("Timeout exceeded")
|
||||||
self._install_queue.join()
|
self._install_queue.join()
|
||||||
|
|
||||||
return self._install_jobs
|
return self._install_jobs
|
||||||
|
|
||||||
def cancel_job(self, job: ModelInstallJob) -> None:
|
def cancel_job(self, job: ModelInstallJob) -> None:
|
||||||
@ -415,15 +420,16 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
# Internal functions that manage the installer threads
|
# Internal functions that manage the installer threads
|
||||||
# --------------------------------------------------------------------------------------------
|
# --------------------------------------------------------------------------------------------
|
||||||
def _start_installer_thread(self) -> None:
|
def _start_installer_thread(self) -> None:
|
||||||
threading.Thread(target=self._install_next_item, daemon=True).start()
|
self._install_thread = threading.Thread(target=self._install_next_item, daemon=True)
|
||||||
|
self._install_thread.start()
|
||||||
self._running = True
|
self._running = True
|
||||||
|
|
||||||
def _install_next_item(self) -> None:
|
def _install_next_item(self) -> None:
|
||||||
done = False
|
self._logger.info(f"Installer thread {threading.get_ident()} starting")
|
||||||
while not done:
|
while True:
|
||||||
if self._stop_event.is_set():
|
if self._stop_event.is_set():
|
||||||
done = True
|
break
|
||||||
continue
|
self._logger.info(f"Installer thread {threading.get_ident()} running")
|
||||||
try:
|
try:
|
||||||
job = self._install_queue.get(timeout=1)
|
job = self._install_queue.get(timeout=1)
|
||||||
except Empty:
|
except Empty:
|
||||||
@ -436,9 +442,25 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
elif job.errored:
|
elif job.errored:
|
||||||
self._signal_job_errored(job)
|
self._signal_job_errored(job)
|
||||||
|
|
||||||
elif (
|
elif job.waiting or job.downloads_done:
|
||||||
job.waiting or job.downloads_done
|
self._register_or_install(job)
|
||||||
): # local jobs will be in waiting state, remote jobs will be downloading state
|
|
||||||
|
except InvalidModelConfigException as excp:
|
||||||
|
self._set_error(job, excp)
|
||||||
|
|
||||||
|
except (OSError, DuplicateModelException) as excp:
|
||||||
|
self._set_error(job, excp)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# if this is an install of a remote file, then clean up the temporary directory
|
||||||
|
if job._install_tmpdir is not None:
|
||||||
|
rmtree(job._install_tmpdir)
|
||||||
|
self._install_completed_event.set()
|
||||||
|
self._install_queue.task_done()
|
||||||
|
self._logger.info(f"Installer thread {threading.get_ident()} exiting")
|
||||||
|
|
||||||
|
def _register_or_install(self, job: ModelInstallJob) -> None:
|
||||||
|
# local jobs will be in waiting state, remote jobs will be downloading state
|
||||||
job.total_bytes = self._stat_size(job.local_path)
|
job.total_bytes = self._stat_size(job.local_path)
|
||||||
job.bytes = job.total_bytes
|
job.bytes = job.total_bytes
|
||||||
self._signal_job_running(job)
|
self._signal_job_running(job)
|
||||||
@ -455,7 +477,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
job.config_out = self.record_store.get_model(key)
|
job.config_out = self.record_store.get_model(key)
|
||||||
self._signal_job_completed(job)
|
self._signal_job_completed(job)
|
||||||
|
|
||||||
except InvalidModelConfigException as excp:
|
def _set_error(self, job: ModelInstallJob, excp: Exception) -> None:
|
||||||
if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts):
|
if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts):
|
||||||
job.set_error(
|
job.set_error(
|
||||||
InvalidModelConfigException(
|
InvalidModelConfigException(
|
||||||
@ -466,17 +488,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
job.set_error(excp)
|
job.set_error(excp)
|
||||||
self._signal_job_errored(job)
|
self._signal_job_errored(job)
|
||||||
|
|
||||||
except (OSError, DuplicateModelException) as excp:
|
|
||||||
job.set_error(excp)
|
|
||||||
self._signal_job_errored(job)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# if this is an install of a remote file, then clean up the temporary directory
|
|
||||||
if job._install_tmpdir is not None:
|
|
||||||
rmtree(job._install_tmpdir)
|
|
||||||
self._install_completed_event.set()
|
|
||||||
self._install_queue.task_done()
|
|
||||||
|
|
||||||
# --------------------------------------------------------------------------------------------
|
# --------------------------------------------------------------------------------------------
|
||||||
# Internal functions that manage the models directory
|
# Internal functions that manage the models directory
|
||||||
# --------------------------------------------------------------------------------------------
|
# --------------------------------------------------------------------------------------------
|
||||||
@ -905,7 +916,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)
|
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_fetcher_from_url(url: str):
|
def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:
|
||||||
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
|
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
|
||||||
return HuggingFaceMetadataFetch
|
return HuggingFaceMetadataFetch
|
||||||
raise ValueError(f"Unsupported model source: '{url}'")
|
raise ValueError(f"Unsupported model source: '{url}'")
|
||||||
|
@ -51,6 +51,7 @@ def session() -> Session:
|
|||||||
return sess
|
return sess
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.timeout(timeout=20, method="thread")
|
||||||
def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
|
def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
|
||||||
events = set()
|
events = set()
|
||||||
|
|
||||||
@ -80,6 +81,7 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
|
|||||||
queue.stop()
|
queue.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.timeout(timeout=20, method="thread")
|
||||||
def test_errors(tmp_path: Path, session: Session) -> None:
|
def test_errors(tmp_path: Path, session: Session) -> None:
|
||||||
queue = DownloadQueueService(
|
queue = DownloadQueueService(
|
||||||
requests_session=session,
|
requests_session=session,
|
||||||
@ -101,6 +103,7 @@ def test_errors(tmp_path: Path, session: Session) -> None:
|
|||||||
queue.stop()
|
queue.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.timeout(timeout=20, method="thread")
|
||||||
def test_event_bus(tmp_path: Path, session: Session) -> None:
|
def test_event_bus(tmp_path: Path, session: Session) -> None:
|
||||||
event_bus = TestEventService()
|
event_bus = TestEventService()
|
||||||
|
|
||||||
@ -136,6 +139,7 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
|
|||||||
queue.stop()
|
queue.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.timeout(timeout=20, method="thread")
|
||||||
def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
|
def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
|
||||||
queue = DownloadQueueService(
|
queue = DownloadQueueService(
|
||||||
requests_session=session,
|
requests_session=session,
|
||||||
|
@ -5,6 +5,7 @@ Test the model installer
|
|||||||
import platform
|
import platform
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
@ -276,48 +277,48 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In
|
|||||||
|
|
||||||
|
|
||||||
# TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test
|
# TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test
|
||||||
# @pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
# "model_params",
|
"model_params",
|
||||||
# [
|
[
|
||||||
# # SDXL, Lora
|
# SDXL, Lora
|
||||||
# {
|
{
|
||||||
# "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
|
"repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
|
||||||
# "name": "test_lora",
|
"name": "test_lora",
|
||||||
# "type": "embedding",
|
"type": "embedding",
|
||||||
# },
|
},
|
||||||
# # SDXL, Lora - incorrect type
|
# SDXL, Lora - incorrect type
|
||||||
# {
|
{
|
||||||
# "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
|
"repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
|
||||||
# "name": "test_lora",
|
"name": "test_lora",
|
||||||
# "type": "lora",
|
"type": "lora",
|
||||||
# },
|
},
|
||||||
# ],
|
],
|
||||||
# )
|
)
|
||||||
# @pytest.mark.timeout(timeout=40, method="thread")
|
@pytest.mark.timeout(timeout=40, method="thread")
|
||||||
# def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
|
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
|
||||||
# """Test whether or not type is respected on configs when passed to heuristic import."""
|
"""Test whether or not type is respected on configs when passed to heuristic import."""
|
||||||
# assert "name" in model_params and "type" in model_params
|
assert "name" in model_params and "type" in model_params
|
||||||
# config1: Dict[str, Any] = {
|
config1: Dict[str, Any] = {
|
||||||
# "name": f"{model_params['name']}_1",
|
"name": f"{model_params['name']}_1",
|
||||||
# "type": model_params["type"],
|
"type": model_params["type"],
|
||||||
# "hash": "placeholder1",
|
"hash": "placeholder1",
|
||||||
# }
|
}
|
||||||
# config2: Dict[str, Any] = {
|
config2: Dict[str, Any] = {
|
||||||
# "name": f"{model_params['name']}_2",
|
"name": f"{model_params['name']}_2",
|
||||||
# "type": ModelType(model_params["type"]),
|
"type": ModelType(model_params["type"]),
|
||||||
# "hash": "placeholder2",
|
"hash": "placeholder2",
|
||||||
# }
|
}
|
||||||
# assert "repo_id" in model_params
|
assert "repo_id" in model_params
|
||||||
# install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
|
install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
|
||||||
# mm2_installer.wait_for_job(install_job1, timeout=20)
|
mm2_installer.wait_for_job(install_job1, timeout=20)
|
||||||
# if model_params["type"] != "embedding":
|
if model_params["type"] != "embedding":
|
||||||
# assert install_job1.errored
|
assert install_job1.errored
|
||||||
# assert install_job1.error_type == "InvalidModelConfigException"
|
assert install_job1.error_type == "InvalidModelConfigException"
|
||||||
# return
|
return
|
||||||
# assert install_job1.complete
|
assert install_job1.complete
|
||||||
# assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
|
assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
|
||||||
|
|
||||||
# install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
|
install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
|
||||||
# mm2_installer.wait_for_job(install_job2, timeout=20)
|
mm2_installer.wait_for_job(install_job2, timeout=20)
|
||||||
# assert install_job2.complete
|
assert install_job2.complete
|
||||||
# assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out
|
assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out
|
||||||
|
@ -2,13 +2,11 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import time
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pytest import FixtureRequest
|
|
||||||
from requests.sessions import Session
|
from requests.sessions import Session
|
||||||
from requests_testadapter import TestAdapter, TestSession
|
from requests_testadapter import TestAdapter, TestSession
|
||||||
|
|
||||||
@ -99,16 +97,12 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> DownloadQueueServiceBase:
|
def mm2_download_queue(mm2_session: Session) -> DownloadQueueServiceBase:
|
||||||
download_queue = DownloadQueueService(requests_session=mm2_session)
|
download_queue = DownloadQueueService(requests_session=mm2_session)
|
||||||
download_queue.start()
|
download_queue.start()
|
||||||
|
yield download_queue
|
||||||
def stop_queue() -> None:
|
|
||||||
download_queue.stop()
|
download_queue.stop()
|
||||||
|
|
||||||
request.addfinalizer(stop_queue)
|
|
||||||
return download_queue
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
|
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
|
||||||
@ -130,7 +124,6 @@ def mm2_installer(
|
|||||||
mm2_app_config: InvokeAIAppConfig,
|
mm2_app_config: InvokeAIAppConfig,
|
||||||
mm2_download_queue: DownloadQueueServiceBase,
|
mm2_download_queue: DownloadQueueServiceBase,
|
||||||
mm2_session: Session,
|
mm2_session: Session,
|
||||||
request: FixtureRequest,
|
|
||||||
) -> ModelInstallServiceBase:
|
) -> ModelInstallServiceBase:
|
||||||
logger = InvokeAILogger.get_logger()
|
logger = InvokeAILogger.get_logger()
|
||||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||||
@ -145,13 +138,8 @@ def mm2_installer(
|
|||||||
session=mm2_session,
|
session=mm2_session,
|
||||||
)
|
)
|
||||||
installer.start()
|
installer.start()
|
||||||
|
yield installer
|
||||||
def stop_installer() -> None:
|
|
||||||
installer.stop()
|
installer.stop()
|
||||||
time.sleep(0.1) # avoid error message from the logger when it is closed before thread prints final message
|
|
||||||
|
|
||||||
request.addfinalizer(stop_installer)
|
|
||||||
return installer
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
Reference in New Issue
Block a user