From 9bd7dabed34e04a3c5301fa39cb7a366cae69c24 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 19 Mar 2024 23:45:02 -0400 Subject: [PATCH] refactor big _install_next_item() loop --- .../model_install/model_install_default.py | 60 ++++++++++--------- 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index bd41022037..5f5c12abb7 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -438,37 +438,14 @@ class ModelInstallService(ModelInstallServiceBase): elif ( job.waiting or job.downloads_done - ): # local jobs will be in waiting state, remote jobs will be downloading state - job.total_bytes = self._stat_size(job.local_path) - job.bytes = job.total_bytes - self._signal_job_running(job) - job.config_in["source"] = str(job.source) - job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__] - # enter the metadata, if there is any - if isinstance(job.source_metadata, (HuggingFaceMetadata)): - job.config_in["source_api_response"] = job.source_metadata.api_response - - if job.inplace: - key = self.register_path(job.local_path, job.config_in) - else: - key = self.install_path(job.local_path, job.config_in) - job.config_out = self.record_store.get_model(key) - self._signal_job_completed(job) + ): + self._register_or_install(job) except InvalidModelConfigException as excp: - if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts): - job.set_error( - InvalidModelConfigException( - f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download." - ) - ) - else: - job.set_error(excp) - self._signal_job_errored(job) + self._set_error(job, excp) except (OSError, DuplicateModelException) as excp: - job.set_error(excp) - self._signal_job_errored(job) + self._set_error(job, excp) finally: # if this is an install of a remote file, then clean up the temporary directory @@ -477,6 +454,35 @@ class ModelInstallService(ModelInstallServiceBase): self._install_completed_event.set() self._install_queue.task_done() + def _register_or_install(self, job: ModelInstallJob) -> None: + # local jobs will be in waiting state, remote jobs will be downloading state + job.total_bytes = self._stat_size(job.local_path) + job.bytes = job.total_bytes + self._signal_job_running(job) + job.config_in["source"] = str(job.source) + job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__] + # enter the metadata, if there is any + if isinstance(job.source_metadata, (HuggingFaceMetadata)): + job.config_in["source_api_response"] = job.source_metadata.api_response + + if job.inplace: + key = self.register_path(job.local_path, job.config_in) + else: + key = self.install_path(job.local_path, job.config_in) + job.config_out = self.record_store.get_model(key) + self._signal_job_completed(job) + + def _set_error(self, job: ModelInstallJob, excp: Exception) -> None: + if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts): + job.set_error( + InvalidModelConfigException( + f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download." + ) + ) + else: + job.set_error(excp) + self._signal_job_errored(job) + # -------------------------------------------------------------------------------------------- # Internal functions that manage the models directory # --------------------------------------------------------------------------------------------