refactor big _install_next_item() loop

This commit is contained in:
Lincoln Stein 2024-03-19 23:45:02 -04:00 committed by psychedelicious
parent 30283a4767
commit 9bd7dabed3

View File

@ -438,37 +438,14 @@ class ModelInstallService(ModelInstallServiceBase):
elif ( elif (
job.waiting or job.downloads_done job.waiting or job.downloads_done
): # local jobs will be in waiting state, remote jobs will be downloading state ):
job.total_bytes = self._stat_size(job.local_path) self._register_or_install(job)
job.bytes = job.total_bytes
self._signal_job_running(job)
job.config_in["source"] = str(job.source)
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
# enter the metadata, if there is any
if isinstance(job.source_metadata, (HuggingFaceMetadata)):
job.config_in["source_api_response"] = job.source_metadata.api_response
if job.inplace:
key = self.register_path(job.local_path, job.config_in)
else:
key = self.install_path(job.local_path, job.config_in)
job.config_out = self.record_store.get_model(key)
self._signal_job_completed(job)
except InvalidModelConfigException as excp: except InvalidModelConfigException as excp:
if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts): self._set_error(job, excp)
job.set_error(
InvalidModelConfigException(
f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
)
)
else:
job.set_error(excp)
self._signal_job_errored(job)
except (OSError, DuplicateModelException) as excp: except (OSError, DuplicateModelException) as excp:
job.set_error(excp) self._set_error(job, excp)
self._signal_job_errored(job)
finally: finally:
# if this is an install of a remote file, then clean up the temporary directory # if this is an install of a remote file, then clean up the temporary directory
@ -477,6 +454,35 @@ class ModelInstallService(ModelInstallServiceBase):
self._install_completed_event.set() self._install_completed_event.set()
self._install_queue.task_done() self._install_queue.task_done()
def _register_or_install(self, job: ModelInstallJob) -> None:
# local jobs will be in waiting state, remote jobs will be downloading state
job.total_bytes = self._stat_size(job.local_path)
job.bytes = job.total_bytes
self._signal_job_running(job)
job.config_in["source"] = str(job.source)
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
# enter the metadata, if there is any
if isinstance(job.source_metadata, (HuggingFaceMetadata)):
job.config_in["source_api_response"] = job.source_metadata.api_response
if job.inplace:
key = self.register_path(job.local_path, job.config_in)
else:
key = self.install_path(job.local_path, job.config_in)
job.config_out = self.record_store.get_model(key)
self._signal_job_completed(job)
def _set_error(self, job: ModelInstallJob, excp: Exception) -> None:
if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts):
job.set_error(
InvalidModelConfigException(
f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
)
)
else:
job.set_error(excp)
self._signal_job_errored(job)
# -------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------
# Internal functions that manage the models directory # Internal functions that manage the models directory
# -------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------