after stopping install and download services, wait for thread exit

This commit is contained in:
Lincoln Stein
2024-03-20 22:11:52 -04:00
committed by psychedelicious
parent 0cab1d1e04
commit 689cb9d31d
4 changed files with 62 additions and 86 deletions

View File

@ -87,6 +87,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
self._queue.queue.clear()
self.join() # wait for all active jobs to finish
self._stop_event.set()
for thread in self._worker_pool:
thread.join()
self._worker_pool.clear()
def submit_download_job(

View File

@ -93,6 +93,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
self._running = False
self._session = session
self._install_thread: Optional[threading.Thread] = None
self._next_job_id = 0
@property
@ -127,6 +128,8 @@ class ModelInstallService(ModelInstallServiceBase):
self._stop_event.set()
self._clear_pending_jobs()
self._download_cache.clear()
assert self._install_thread is not None
self._install_thread.join()
self._running = False
def _clear_pending_jobs(self) -> None:
@ -269,19 +272,14 @@ class ModelInstallService(ModelInstallServiceBase):
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102
"""Block until all installation jobs are done."""
self.printf("wait_for_installs(): ENTERING")
start = time.time()
while len(self._download_cache) > 0:
if self._downloads_changed_event.wait(timeout=0.25): # in case we miss an event
self._downloads_changed_event.clear()
if timeout > 0 and time.time() - start > timeout:
raise TimeoutError("Timeout exceeded")
self.printf(
f"wait_for_installs(): install_queue size={self._install_queue.qsize()}, download_cache={self._download_cache}"
)
self._install_queue.join()
self.printf("wait_for_installs(): EXITING")
return self._install_jobs
def cancel_job(self, job: ModelInstallJob) -> None:
@ -422,21 +420,21 @@ class ModelInstallService(ModelInstallServiceBase):
# Internal functions that manage the installer threads
# --------------------------------------------------------------------------------------------
def _start_installer_thread(self) -> None:
threading.Thread(target=self._install_next_item, daemon=True).start()
self._install_thread = threading.Thread(target=self._install_next_item, daemon=True)
self._install_thread.start()
self._running = True
def _install_next_item(self) -> None:
done = False
while not done:
self._logger.info(f"Installer thread {threading.get_ident()} starting")
while True:
if self._stop_event.is_set():
done = True
continue
break
self._logger.info(f"Installer thread {threading.get_ident()} running")
try:
job = self._install_queue.get(timeout=1)
except Empty:
continue
assert job.local_path is not None
self.printf(f"_install_next_item(source={job.source}, id={job.id}")
try:
if job.cancelled:
self._signal_job_cancelled(job)
@ -458,8 +456,8 @@ class ModelInstallService(ModelInstallServiceBase):
if job._install_tmpdir is not None:
rmtree(job._install_tmpdir)
self._install_completed_event.set()
self.printf("Signaling task done")
self._install_queue.task_done()
self._logger.info(f"Installer thread {threading.get_ident()} exiting")
def _register_or_install(self, job: ModelInstallJob) -> None:
# local jobs will be in waiting state, remote jobs will be downloading state
@ -804,25 +802,16 @@ class ModelInstallService(ModelInstallServiceBase):
def _download_complete_callback(self, download_job: DownloadJob) -> None:
self._logger.info(f"{download_job.source}: model download complete")
with self._lock:
self.printf("_LOCK")
install_job = self._download_cache[download_job.source]
self.printf(
f"_download_complete_callback(source={download_job.source}, job={install_job.source}, install_job.id={install_job.id})"
)
# are there any more active jobs left in this task?
if install_job.downloading and all(x.complete for x in install_job.download_parts):
self.printf(f"_enqueuing job {install_job.id}")
self._signal_job_downloads_done(install_job)
self._put_in_queue(install_job)
self.printf(f"_enqueued job {install_job.id}")
# Let other threads know that the number of downloads has changed
self.printf(f"popping {download_job.source}")
self._download_cache.pop(download_job.source, None)
self._downloads_changed_event.set()
self.printf("downloads_changed_event is set")
self.printf("_UNLOCK")
def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
with self._lock:
@ -931,7 +920,3 @@ class ModelInstallService(ModelInstallServiceBase):
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
return HuggingFaceMetadataFetch
raise ValueError(f"Unsupported model source: '{url}'")
@staticmethod
def printf(message: str) -> None:
print(f"[{time.time():18}] [{threading.get_ident():16}] {message}", flush=True)