From 3402cf6542c86d8b0a5493e94e49cf23e446b472 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 20 Sep 2023 20:30:35 -0400 Subject: [PATCH] preserve description in metadata when installing a starter model --- .../backend/model_manager/download/queue.py | 17 +++++++++-------- invokeai/frontend/install/model_install.py | 17 +++++++++++------ 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/invokeai/backend/model_manager/download/queue.py b/invokeai/backend/model_manager/download/queue.py index 0faad6fc61..d71d80aca8 100644 --- a/invokeai/backend/model_manager/download/queue.py +++ b/invokeai/backend/model_manager/download/queue.py @@ -504,10 +504,11 @@ class DownloadQueue(DownloadQueueBase): try: repo_id = job.source variant = job.variant - urls_to_download, metadata = self._get_repo_info(repo_id, variant) + if not job.metadata: + job.metadata = ModelSourceMetadata() + urls_to_download = self._get_repo_info(repo_id, variant=variant, metadata=job.metadata) if job.destination.name != Path(repo_id).name: job.destination = job.destination / Path(repo_id).name - job.metadata = metadata bytes_downloaded = dict() job.total_bytes = 0 @@ -535,10 +536,12 @@ class DownloadQueue(DownloadQueueBase): def _get_repo_info( self, repo_id: str, + metadata: ModelSourceMetadata, variant: Optional[str] = None, ) -> Tuple[List[Tuple[AnyHttpUrl, Path, Path]], ModelSourceMetadata]: """ Given a repo_id and an optional variant, return list of URLs to download to get the model. + The metadata field will be updated with model metadata from HuggingFace. Known variants currently are: 1. onnx @@ -561,12 +564,10 @@ class DownloadQueue(DownloadQueueBase): (hf_hub_url(repo_id, filename=x.as_posix()), x.parent or Path("."), x.name, sizes[x.as_posix()]) for x in self._select_variants(paths, variant) ] - return ( - urls, - ModelSourceMetadata( - license=model_info.cardData.get("license"), tags=model_info.tags, author=model_info.author - ), - ) + metadata.license = metadata.license or model_info.cardData.get("license") + metadata.tags = metadata.tags or model_info.tags + metadata.author = metadata.author or model_info.author + return urls def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]: """Select the proper variant files from a list of HuggingFace repo_id paths.""" diff --git a/invokeai/frontend/install/model_install.py b/invokeai/frontend/install/model_install.py index a0f8d1268e..91ae55e55e 100644 --- a/invokeai/frontend/install/model_install.py +++ b/invokeai/frontend/install/model_install.py @@ -33,7 +33,7 @@ import invokeai.configs as configs from invokeai.app.services.config import InvokeAIAppConfig # from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, SchedulerPredictionType -from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelInstall, ModelInstallJob, ModelType +from invokeai.backend.model_manager import BaseModelType, ModelInstall, ModelInstallJob, ModelType from invokeai.backend.model_manager.install import ModelSourceMetadata from invokeai.backend.util import choose_precision, choose_torch_device from invokeai.backend.util.logging import InvokeAILogger @@ -47,7 +47,6 @@ from invokeai.frontend.install.widgets import ( SingleSelectColumns, TextBox, WindowTooSmallException, - select_stable_diffusion_config_file, set_min_terminal_size, ) @@ -87,11 +86,12 @@ class InstallSelections: def make_printable(s: str) -> str: - """Replace non-printable characters in a string""" + """Replace non-printable characters in a string.""" return s.translate(NOPRINT_TRANS_TABLE) class addModelsForm(CyclingForm, npyscreen.FormMultiPage): + """Main form for interactive TUI.""" # for responsive resizing set to False, but this seems to cause a crash! FIX_MINIMUM_SIZE_WHEN_CREATED = True @@ -402,12 +402,16 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): for model in self.installer.store.all_models(): info = UnifiedModelInfo.parse_obj(model.dict()) info.installed = True - key = f"{model.base_model}/{model.model_type}/{model.name}" + key = f"{model.base_model.value}/{model.model_type.value}/{model.name}" all_models[key] = info installed_models.append(key) for key in INITIAL_MODELS_CONFIG.keys(): - if key not in all_models: + if key in all_models: + # we want to preserve the description + description = all_models[key].description or INITIAL_MODELS_CONFIG[key].get("description") + all_models[key].description = description + else: base_model, model_type, model_name = key.split("/") info = UnifiedModelInfo( name=model_name, @@ -436,7 +440,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): label_width = max([len(models[x].name) for x in self.starter_models]) description_width = window_width - label_width - checkbox_width - spacing_width - for key in self.starter_models: + for key in self.all_models: description = models[key].description description = ( description[0 : description_width - 3] + "..." @@ -587,6 +591,7 @@ def add_or_delete(installer: ModelInstall, selections: InstallSelections): # -------------------------------------------------------- def select_and_download_models(opt: Namespace): + """Prompt user for install/delete selections and execute.""" precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) config.precision = precision installer = ModelInstall(config=config, event_handlers=[tqdm_progress])