mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
make download manager optional in InvokeAIServices during development
This commit is contained in:
@ -35,7 +35,7 @@ class InvocationServices:
|
||||
model_manager: "ModelManagerServiceBase"
|
||||
processor: "InvocationProcessorABC"
|
||||
performance_statistics: "InvocationStatsServiceBase"
|
||||
download_manager: "DownloadQueueServiceBase"
|
||||
download_manager: Optional["DownloadQueueServiceBase"]
|
||||
queue: "InvocationQueueABC"
|
||||
|
||||
def __init__(
|
||||
@ -52,8 +52,8 @@ class InvocationServices:
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
processor: "InvocationProcessorABC",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
download_manager: "DownloadQueueServiceBase",
|
||||
queue: "InvocationQueueABC",
|
||||
download_manager: Optional["DownloadQueueServiceBase"] = None, # optional for now pending design decisions
|
||||
):
|
||||
self.board_images = board_images
|
||||
self.boards = boards
|
||||
@ -66,6 +66,7 @@ class InvocationServices:
|
||||
self.latents = latents
|
||||
self.logger = logger
|
||||
self.model_manager = model_manager
|
||||
self.download_manager = download_manager
|
||||
self.processor = processor
|
||||
self.performance_statistics = performance_statistics
|
||||
self.queue = queue
|
||||
|
@ -419,10 +419,11 @@ class DownloadQueue(DownloadQueueBase):
|
||||
self._update_job_status(job, DownloadJobStatus.COMPLETED)
|
||||
subqueue.release() # get rid of the subqueue
|
||||
|
||||
def _get_repo_info(self,
|
||||
repo_id: str,
|
||||
variant: Optional[str] = None,
|
||||
) -> Tuple[List[Tuple[AnyHttpUrl, Path, Path]], Dict[str, str]]:
|
||||
def _get_repo_info(
|
||||
self,
|
||||
repo_id: str,
|
||||
variant: Optional[str] = None,
|
||||
) -> Tuple[List[Tuple[AnyHttpUrl, Path, Path]], Dict[str, str]]:
|
||||
"""Given a repo_id and an optional variant, return list of URLs to download to get the model."""
|
||||
model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True)
|
||||
sibs = model_info.siblings
|
||||
@ -439,7 +440,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
(hf_hub_url(repo_id, filename=x.as_posix()), x.parent or Path("."), x.name, sizes[x.as_posix()])
|
||||
for x in self._select_variants(paths, variant)
|
||||
]
|
||||
return (urls, {'cardData': model_info.cardData, 'tags': model_info.tags, 'author': model_info.author})
|
||||
return (urls, {"cardData": model_info.cardData, "tags": model_info.tags, "author": model_info.author})
|
||||
|
||||
def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]:
|
||||
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
|
||||
|
@ -320,11 +320,11 @@ class ModelInstall(ModelInstallBase):
|
||||
info = self._store.get_model(id)
|
||||
info.description = f"Downloaded model {info.name}"
|
||||
info.source_url = str(job.source)
|
||||
if card_data := job.metadata.get('cardData'):
|
||||
info.license = card_data.get('license')
|
||||
if author := job.metadata.get('author'):
|
||||
if card_data := job.metadata.get("cardData"):
|
||||
info.license = card_data.get("license")
|
||||
if author := job.metadata.get("author"):
|
||||
info.author = author
|
||||
if tags := job.metadata.get('tags'):
|
||||
if tags := job.metadata.get("tags"):
|
||||
info.tags = tags
|
||||
self._store.update_model(id, info)
|
||||
self._async_installs[job.source] = id
|
||||
@ -337,9 +337,7 @@ class ModelInstall(ModelInstallBase):
|
||||
# Better to do the cleanup in the callback
|
||||
self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
|
||||
return queue.create_download_job(
|
||||
source=source,
|
||||
destdir=self._tmpdir.name,
|
||||
event_handlers=[complete_installation]
|
||||
source=source, destdir=self._tmpdir.name, event_handlers=[complete_installation]
|
||||
)
|
||||
|
||||
def wait_for_downloads(self) -> Dict[str, str]: # noqa D102
|
||||
|
Reference in New Issue
Block a user