Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
added debugging statements
This commit is contained in:
parent b15d05f8a8
commit e452c6171b
@@ -34,6 +34,7 @@ from invokeai.backend.model_manager.config import (
 from invokeai.backend.model_manager.metadata import (
     AnyModelRepoMetadata,
     HuggingFaceMetadataFetch,
+    ModelMetadataFetchBase,
     ModelMetadataWithFiles,
     RemoteModelFile,
 )
@@ -268,13 +269,19 @@ class ModelInstallService(ModelInstallServiceBase):
 
     def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]:  # noqa D102
         """Block until all installation jobs are done."""
+        self.printf("wait_for_installs(): ENTERING")
         start = time.time()
         while len(self._download_cache) > 0:
             if self._downloads_changed_event.wait(timeout=0.25):  # in case we miss an event
                 self._downloads_changed_event.clear()
             if timeout > 0 and time.time() - start > timeout:
                 raise TimeoutError("Timeout exceeded")
+        self.printf(
+            f"wait_for_installs(): install_queue size={self._install_queue.qsize()}, download_cache={self._download_cache}"
+        )
         self._install_queue.join()
+
+        self.printf("wait_for_installs(): EXITING")
         return self._install_jobs
 
     def cancel_job(self, job: ModelInstallJob) -> None:
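For context, wait_for_installs() above follows a standard event-polling pattern: sleep on a threading.Event with a short timeout (so a missed notification only costs 0.25 s), and enforce an optional overall deadline. A minimal standalone sketch of that pattern, with illustrative names that are not from the repository:

import threading
import time


def wait_until_empty(pending: set, changed: threading.Event, timeout: float = 0) -> None:
    """Block until `pending` is drained, waking on `changed` or every 0.25 s."""
    start = time.time()
    while pending:
        if changed.wait(timeout=0.25):  # wake early when a worker signals a change
            changed.clear()
        if timeout > 0 and time.time() - start > timeout:
            raise TimeoutError("Timeout exceeded")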
@@ -425,6 +432,7 @@ class ModelInstallService(ModelInstallServiceBase):
             except Empty:
                 continue
             assert job.local_path is not None
+            self.printf(f"_install_next_item(source={job.source}, id={job.id}")
             try:
                 if job.cancelled:
                     self._signal_job_cancelled(job)
@@ -432,9 +440,7 @@ class ModelInstallService(ModelInstallServiceBase):
                 elif job.errored:
                     self._signal_job_errored(job)
 
-                elif (
-                    job.waiting or job.downloads_done
-                ):
+                elif job.waiting or job.downloads_done:
                     self._register_or_install(job)
 
             except InvalidModelConfigException as excp:
@@ -448,6 +454,7 @@ class ModelInstallService(ModelInstallServiceBase):
                 if job._install_tmpdir is not None:
                     rmtree(job._install_tmpdir)
                 self._install_completed_event.set()
+                self.printf("Signaling task done")
                 self._install_queue.task_done()
 
     def _register_or_install(self, job: ModelInstallJob) -> None:
@@ -793,16 +800,25 @@ class ModelInstallService(ModelInstallServiceBase):
     def _download_complete_callback(self, download_job: DownloadJob) -> None:
         self._logger.info(f"{download_job.source}: model download complete")
         with self._lock:
+            self.printf("_LOCK")
             install_job = self._download_cache[download_job.source]
+            self.printf(
+                f"_download_complete_callback(source={download_job.source}, job={install_job.source}, install_job.id={install_job.id})"
+            )
 
             # are there any more active jobs left in this task?
             if install_job.downloading and all(x.complete for x in install_job.download_parts):
+                self.printf(f"_enqueuing job {install_job.id}")
                 self._signal_job_downloads_done(install_job)
                 self._put_in_queue(install_job)
+                self.printf(f"_enqueued job {install_job.id}")
 
             # Let other threads know that the number of downloads has changed
+            self.printf(f"popping {download_job.source}")
             self._download_cache.pop(download_job.source, None)
             self._downloads_changed_event.set()
+            self.printf("downloads_changed_event is set")
+            self.printf("_UNLOCK")
 
     def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
         with self._lock:
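This callback is the producer side of the handshake sketched after the wait_for_installs hunk above: under the lock it removes the finished download from the cache, then sets the event so any waiting thread re-checks whether the cache is empty. A minimal counterpart sketch, again with illustrative names only:

import threading

pending = {"model-a"}        # stand-in for the service's download cache
changed = threading.Event()  # stand-in for _downloads_changed_event
lock = threading.Lock()      # stand-in for _lock


def on_download_complete(source: str) -> None:
    # Drop the finished entry under the lock, then wake the waiters.
    with lock:
        pending.discard(source)
        changed.set()


on_download_complete("model-a")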
@@ -907,7 +923,11 @@ class ModelInstallService(ModelInstallServiceBase):
             self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)
 
     @staticmethod
-    def get_fetcher_from_url(url: str):
+    def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:
         if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
             return HuggingFaceMetadataFetch
         raise ValueError(f"Unsupported model source: '{url}'")
+
+    @staticmethod
+    def printf(message: str) -> None:
+        print(f"[{time.time():18}] [{threading.get_ident():16}] {message}", flush=True)
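The printf() helper added at the end is a flush-immediately print that tags each line with a wall-clock timestamp and the calling thread id, which makes it easier to correlate interleaved output from the installer and download threads. A standalone sketch of the same idea; the sample output line is illustrative:

import threading
import time


def printf(message: str) -> None:
    # Timestamp + thread-id prefix; flush so the line survives a hang or crash.
    print(f"[{time.time():18}] [{threading.get_ident():16}] {message}", flush=True)


printf("wait_for_installs(): ENTERING")
# e.g. [ 1724963537.284123] [ 140233848268608] wait_for_installs(): ENTERING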