make download manager optional in InvokeAIServices during development

This commit is contained in:
Lincoln Stein
2023-09-09 14:06:36 -04:00
parent 64424c6db0
commit 3582cfa267
3 changed files with 14 additions and 14 deletions

View File

@ -35,7 +35,7 @@ class InvocationServices:
model_manager: "ModelManagerServiceBase"
processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase"
download_manager: "DownloadQueueServiceBase"
download_manager: Optional["DownloadQueueServiceBase"]
queue: "InvocationQueueABC"
def __init__(
@ -52,8 +52,8 @@ class InvocationServices:
model_manager: "ModelManagerServiceBase",
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
download_manager: "DownloadQueueServiceBase",
queue: "InvocationQueueABC",
download_manager: Optional["DownloadQueueServiceBase"] = None, # optional for now pending design decisions
):
self.board_images = board_images
self.boards = boards
@ -66,6 +66,7 @@ class InvocationServices:
self.latents = latents
self.logger = logger
self.model_manager = model_manager
self.download_manager = download_manager
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue

View File

@ -419,10 +419,11 @@ class DownloadQueue(DownloadQueueBase):
self._update_job_status(job, DownloadJobStatus.COMPLETED)
subqueue.release() # get rid of the subqueue
def _get_repo_info(self,
repo_id: str,
variant: Optional[str] = None,
) -> Tuple[List[Tuple[AnyHttpUrl, Path, Path]], Dict[str, str]]:
def _get_repo_info(
self,
repo_id: str,
variant: Optional[str] = None,
) -> Tuple[List[Tuple[AnyHttpUrl, Path, Path]], Dict[str, str]]:
"""Given a repo_id and an optional variant, return list of URLs to download to get the model."""
model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True)
sibs = model_info.siblings
@ -439,7 +440,7 @@ class DownloadQueue(DownloadQueueBase):
(hf_hub_url(repo_id, filename=x.as_posix()), x.parent or Path("."), x.name, sizes[x.as_posix()])
for x in self._select_variants(paths, variant)
]
return (urls, {'cardData': model_info.cardData, 'tags': model_info.tags, 'author': model_info.author})
return (urls, {"cardData": model_info.cardData, "tags": model_info.tags, "author": model_info.author})
def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]:
"""Select the proper variant files from a list of HuggingFace repo_id paths."""

View File

@ -320,11 +320,11 @@ class ModelInstall(ModelInstallBase):
info = self._store.get_model(id)
info.description = f"Downloaded model {info.name}"
info.source_url = str(job.source)
if card_data := job.metadata.get('cardData'):
info.license = card_data.get('license')
if author := job.metadata.get('author'):
if card_data := job.metadata.get("cardData"):
info.license = card_data.get("license")
if author := job.metadata.get("author"):
info.author = author
if tags := job.metadata.get('tags'):
if tags := job.metadata.get("tags"):
info.tags = tags
self._store.update_model(id, info)
self._async_installs[job.source] = id
@ -337,9 +337,7 @@ class ModelInstall(ModelInstallBase):
# Better to do the cleanup in the callback
self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
return queue.create_download_job(
source=source,
destdir=self._tmpdir.name,
event_handlers=[complete_installation]
source=source, destdir=self._tmpdir.name, event_handlers=[complete_installation]
)
def wait_for_downloads(self) -> Dict[str, str]: # noqa D102