mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix race condition between downloading last file and starting install
This commit is contained in:
parent
e18533e3b5
commit
813a086cfe
@ -221,29 +221,21 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
except Empty:
|
||||
continue
|
||||
try:
|
||||
print(f'DEBUG: job [{job.id}] started', flush=True)
|
||||
job.job_started = get_iso_timestamp()
|
||||
self._do_download(job)
|
||||
print(f'DEBUG: job [{job.id}] download completed', flush=True)
|
||||
self._signal_job_complete(job)
|
||||
print(f'DEBUG: job [{job.id}] signaled completion', flush=True)
|
||||
except (OSError, HTTPError) as excp:
|
||||
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
|
||||
job.error = traceback.format_exc()
|
||||
self._signal_job_error(job, excp)
|
||||
print(f'DEBUG: job [{job.id}] signaled error', flush=True)
|
||||
except DownloadJobCancelledException:
|
||||
self._signal_job_cancelled(job)
|
||||
self._cleanup_cancelled_job(job)
|
||||
print(f'DEBUG: job [{job.id}] signaled cancelled', flush=True)
|
||||
|
||||
finally:
|
||||
print(f'DEBUG: job [{job.id}] signalling completion', flush=True)
|
||||
job.job_ended = get_iso_timestamp()
|
||||
self._job_completed_event.set() # signal a change to terminal state
|
||||
print(f'DEBUG: job [{job.id}] set job completion event', flush=True)
|
||||
self._queue.task_done()
|
||||
print(f'DEBUG: job [{job.id}] done', flush=True)
|
||||
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
|
||||
|
||||
def _do_download(self, job: DownloadJob) -> None:
|
||||
@ -251,7 +243,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
url = job.source
|
||||
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
|
||||
open_mode = "wb"
|
||||
print(f'DEBUG: In _do_download [0]', flush=True)
|
||||
|
||||
# Make a streaming request. This will retrieve headers including
|
||||
# content-length and content-disposition, but not fetch any content itself
|
||||
@ -263,8 +254,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
content_length = int(resp.headers.get("content-length", 0))
|
||||
job.total_bytes = content_length
|
||||
|
||||
print(f'DEBUG: In _do_download [1]')
|
||||
|
||||
if job.dest.is_dir():
|
||||
file_name = os.path.basename(str(url.path)) # default is to use the last bit of the URL
|
||||
|
||||
@ -292,7 +281,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
# signal caller that the download is starting. At this point, key fields such as
|
||||
# download_path and total_bytes will be populated. We call it here because the might
|
||||
# discover that the local file is already complete and generate a COMPLETED status.
|
||||
print(f'DEBUG: In _do_download [2]', flush=True)
|
||||
self._signal_job_started(job)
|
||||
|
||||
# "range not satisfiable" - local file is at least as large as the remote file
|
||||
@ -308,15 +296,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
elif resp.status_code != 200:
|
||||
raise HTTPError(resp.reason)
|
||||
|
||||
print(f'DEBUG: In _do_download [3]', flush=True)
|
||||
|
||||
self._logger.debug(f"{job.source}: Downloading {job.download_path}")
|
||||
report_delta = job.total_bytes / 100 # report every 1% change
|
||||
last_report_bytes = 0
|
||||
|
||||
# DOWNLOAD LOOP
|
||||
with open(in_progress_path, open_mode) as file:
|
||||
print(f'DEBUG: In _do_download loop [4]', flush=True)
|
||||
for data in resp.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
|
||||
if job.cancelled:
|
||||
raise DownloadJobCancelledException("Job was cancelled at caller's request")
|
||||
@ -329,8 +314,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
# if we get here we are done and can rename the file to the original dest
|
||||
self._logger.debug(f"{job.source}: saved to {job.download_path} (bytes={job.bytes})")
|
||||
in_progress_path.rename(job.download_path)
|
||||
print(f'DEBUG: In _do_download [5]', flush=True)
|
||||
|
||||
|
||||
def _validate_filename(self, directory: str, filename: str) -> bool:
|
||||
pc_name_max = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 260 # hardcoded for windows
|
||||
|
@ -28,6 +28,7 @@ class InstallStatus(str, Enum):
|
||||
|
||||
WAITING = "waiting" # waiting to be dequeued
|
||||
DOWNLOADING = "downloading" # downloading of model files in process
|
||||
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
|
||||
RUNNING = "running" # being processed
|
||||
COMPLETED = "completed" # finished running
|
||||
ERROR = "error" # terminated with an error message
|
||||
@ -229,6 +230,11 @@ class ModelInstallJob(BaseModel):
|
||||
"""Return true if job is downloading."""
|
||||
return self.status == InstallStatus.DOWNLOADING
|
||||
|
||||
@property
|
||||
def downloads_done(self) -> bool:
|
||||
"""Return true if job's downloads ae done."""
|
||||
return self.status == InstallStatus.DOWNLOADS_DONE
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
"""Return true if job is running."""
|
||||
|
@ -28,7 +28,6 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.hash import FastModelHash
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
AnyModelRepoMetadata,
|
||||
CivitaiMetadataFetch,
|
||||
@ -153,7 +152,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
|
||||
info: AnyModelConfig = self._probe_model(Path(model_path), config)
|
||||
old_hash = info.current_hash
|
||||
|
||||
if preferred_name := config.get("name"):
|
||||
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
||||
@ -167,8 +165,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
raise DuplicateModelException(
|
||||
f"A model named {model_path.name} is already installed at {dest_path.as_posix()}"
|
||||
) from excp
|
||||
new_hash = FastModelHash.hash(new_path)
|
||||
assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
|
||||
|
||||
return self._register(
|
||||
new_path,
|
||||
@ -370,7 +366,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._signal_job_errored(job)
|
||||
|
||||
elif (
|
||||
job.waiting or job.downloading
|
||||
job.waiting or job.downloads_done
|
||||
): # local jobs will be in waiting state, remote jobs will be downloading state
|
||||
job.total_bytes = self._stat_size(job.local_path)
|
||||
job.bytes = job.total_bytes
|
||||
@ -448,7 +444,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
installed.update(self.scan_directory(models_dir))
|
||||
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
|
||||
|
||||
def _sync_model_path(self, key: str, ignore_hash_change: bool = False) -> AnyModelConfig:
|
||||
def _sync_model_path(self, key: str) -> AnyModelConfig:
|
||||
"""
|
||||
Move model into the location indicated by its basetype, type and name.
|
||||
|
||||
@ -469,14 +465,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
new_path = models_dir / model.base.value / model.type.value / model.name
|
||||
self._logger.info(f"Moving {model.name} to {new_path}.")
|
||||
new_path = self._move_model(old_path, new_path)
|
||||
new_hash = FastModelHash.hash(new_path)
|
||||
model.path = new_path.relative_to(models_dir).as_posix()
|
||||
if model.current_hash != new_hash:
|
||||
assert (
|
||||
ignore_hash_change
|
||||
), f"{model.name}: Model hash changed during installation, model is possibly corrupted"
|
||||
model.current_hash = new_hash
|
||||
self._logger.info(f"Model has new hash {model.current_hash}, but will continue to be identified by {key}")
|
||||
self.record_store.update_model(key, model)
|
||||
return model
|
||||
|
||||
@ -749,8 +738,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._download_cache.pop(download_job.source, None)
|
||||
|
||||
# are there any more active jobs left in this task?
|
||||
if all(x.complete for x in install_job.download_parts):
|
||||
# now enqueue job for actual installation into the models directory
|
||||
if install_job.downloading and all(x.complete for x in install_job.download_parts):
|
||||
install_job.status = InstallStatus.DOWNLOADS_DONE
|
||||
self._install_queue.put(install_job)
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
|
@ -187,10 +187,10 @@ version = { attr = "invokeai.version.__version__" }
|
||||
|
||||
#=== Begin: PyTest and Coverage
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers -m \"not slow\""
|
||||
addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers --timeout 60 -m \"not slow\""
|
||||
markers = [
|
||||
"slow: Marks tests as slow. Disabled by default. To run all tests, use -m \"\". To run only slow tests, use -m \"slow\".",
|
||||
"timeout: 60"
|
||||
"timeout: Marks the timeout override."
|
||||
]
|
||||
[tool.coverage.run]
|
||||
branch = true
|
||||
|
@ -166,6 +166,7 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
|
||||
# assert re.search("division by zero", captured.err)
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=15, method="thread")
|
||||
def test_cancel(tmp_path: Path, session: Session) -> None:
|
||||
event_bus = TestEventService()
|
||||
|
@ -220,6 +220,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
|
||||
event_names = [x.event_name for x in bus.events]
|
||||
assert event_names == ["model_install_downloading", "model_install_running", "model_install_completed"]
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
@ -149,6 +150,7 @@ def mm2_installer(
|
||||
|
||||
def stop_installer() -> None:
|
||||
installer.stop()
|
||||
time.sleep(0.1) # avoid error message from the logger when it is closed before thread prints final message
|
||||
|
||||
request.addfinalizer(stop_installer)
|
||||
return installer
|
||||
|
Loading…
Reference in New Issue
Block a user