diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index 75a4bcc259..6373ec1f78 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -386,6 +386,17 @@ class EventServiceBase: }, ) + def emit_model_install_downloads_done(self, source: str) -> None: + """ + Emit once when all parts are downloaded, but before the probing and registration start. + + :param source: Source of the model; local path, repo_id or url + """ + self.__emit_model_event( + event_name="model_install_downloads_done", + payload={"source": source}, + ) + def emit_model_install_running(self, source: str) -> None: """ Emit once when an install job becomes active. diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 6330cc0969..992f5a11fe 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -124,15 +124,28 @@ class ModelInstallService(ModelInstallServiceBase): if not self._running: raise Exception("Attempt to stop the install service before it was started") self._stop_event.set() - with self._install_queue.mutex: - self._install_queue.queue.clear() # get rid of pending jobs - active_jobs = [x for x in self.list_jobs() if x.running] - if active_jobs: - self._logger.warning("Waiting for active install job to complete") - self.wait_for_installs() + self._clear_pending_jobs() self._download_cache.clear() self._running = False + def _clear_pending_jobs(self) -> None: + for job in self.list_jobs(): + if not job.in_terminal_state: + self._logger.warning("Cancelling job {job.id}") + self.cancel_job(job) + while True: + try: + job = self._install_queue.get(block=False) + self._install_queue.task_done() + except Empty: + break + + def _put_in_queue(self, job: ModelInstallJob) -> None: + if self._stop_event.is_set(): + self.cancel_job(job) + else: + self._install_queue.put(job) + def register_path( self, model_path: Union[Path, str], @@ -218,7 +231,7 @@ class ModelInstallService(ModelInstallServiceBase): if isinstance(source, LocalModelSource): install_job = self._import_local_model(source, config) - self._install_queue.put(install_job) # synchronously install + self._put_in_queue(install_job) # synchronously install elif isinstance(source, HFModelSource): install_job = self._import_from_hf(source, config) elif isinstance(source, URLModelSource): @@ -253,7 +266,6 @@ class ModelInstallService(ModelInstallServiceBase): raise TimeoutError("Timeout exceeded") return job - # TODO: Better name? Maybe wait_for_jobs()? Maybe too easily confused with above def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102 """Block until all installation jobs are done.""" start = time.time() @@ -412,7 +424,6 @@ class ModelInstallService(ModelInstallServiceBase): job = self._install_queue.get(timeout=1) except Empty: continue - assert job.local_path is not None try: if job.cancelled: @@ -462,8 +473,6 @@ class ModelInstallService(ModelInstallServiceBase): self._install_completed_event.set() self._install_queue.task_done() - self._logger.info("Install thread exiting") - # -------------------------------------------------------------------------------------------- # Internal functions that manage the models directory # -------------------------------------------------------------------------------------------- @@ -779,14 +788,14 @@ class ModelInstallService(ModelInstallServiceBase): self._logger.info(f"{download_job.source}: model download complete") with self._lock: install_job = self._download_cache[download_job.source] - self._download_cache.pop(download_job.source, None) # are there any more active jobs left in this task? if install_job.downloading and all(x.complete for x in install_job.download_parts): - install_job.status = InstallStatus.DOWNLOADS_DONE - self._install_queue.put(install_job) + self._signal_job_downloads_done(install_job) + self._put_in_queue(install_job) # Let other threads know that the number of downloads has changed + self._download_cache.pop(download_job.source, None) self._downloads_changed_event.set() def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None: @@ -826,7 +835,7 @@ class ModelInstallService(ModelInstallServiceBase): if all(x.in_terminal_state for x in install_job.download_parts): # When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources - self._install_queue.put(install_job) + self._put_in_queue(install_job) # ------------------------------------------------------------------------------------------------ # Internal methods that put events on the event bus @@ -859,6 +868,12 @@ class ModelInstallService(ModelInstallServiceBase): id=job.id, ) + def _signal_job_downloads_done(self, job: ModelInstallJob) -> None: + job.status = InstallStatus.DOWNLOADS_DONE + self._logger.info(f"{job.source}: all parts of this model are downloaded") + if self._event_bus: + self._event_bus.emit_model_install_downloads_done(str(job.source)) + def _signal_job_completed(self, job: ModelInstallJob) -> None: job.status = InstallStatus.COMPLETED assert job.config_out diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index c0360c3e7d..79895ed380 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -221,9 +221,14 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: model_record = store.get_model(key) assert Path(model_record.path).exists() - assert len(bus.events) == 3 + assert len(bus.events) == 4 event_names = [x.event_name for x in bus.events] - assert event_names == ["model_install_downloading", "model_install_running", "model_install_completed"] + assert event_names == [ + "model_install_downloading", + "model_install_downloads_done", + "model_install_running", + "model_install_completed", + ] @pytest.mark.timeout(timeout=20, method="thread") @@ -250,7 +255,12 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co assert hasattr(bus, "events") # the dummyeventservice has this assert len(bus.events) >= 3 event_names = {x.event_name for x in bus.events} - assert event_names == {"model_install_downloading", "model_install_running", "model_install_completed"} + assert event_names == { + "model_install_downloading", + "model_install_downloads_done", + "model_install_running", + "model_install_completed", + } def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: