preserve description in metadata when installing a starter model

This commit is contained in:
Lincoln Stein
2023-09-20 20:30:35 -04:00
parent ed91f48a92
commit 3402cf6542
2 changed files with 20 additions and 14 deletions

View File

@ -504,10 +504,11 @@ class DownloadQueue(DownloadQueueBase):
try:
repo_id = job.source
variant = job.variant
urls_to_download, metadata = self._get_repo_info(repo_id, variant)
if not job.metadata:
job.metadata = ModelSourceMetadata()
urls_to_download = self._get_repo_info(repo_id, variant=variant, metadata=job.metadata)
if job.destination.name != Path(repo_id).name:
job.destination = job.destination / Path(repo_id).name
job.metadata = metadata
bytes_downloaded = dict()
job.total_bytes = 0
@ -535,10 +536,12 @@ class DownloadQueue(DownloadQueueBase):
def _get_repo_info(
self,
repo_id: str,
metadata: ModelSourceMetadata,
variant: Optional[str] = None,
) -> Tuple[List[Tuple[AnyHttpUrl, Path, Path]], ModelSourceMetadata]:
"""
Given a repo_id and an optional variant, return list of URLs to download to get the model.
The metadata field will be updated with model metadata from HuggingFace.
Known variants currently are:
1. onnx
@ -561,12 +564,10 @@ class DownloadQueue(DownloadQueueBase):
(hf_hub_url(repo_id, filename=x.as_posix()), x.parent or Path("."), x.name, sizes[x.as_posix()])
for x in self._select_variants(paths, variant)
]
return (
urls,
ModelSourceMetadata(
license=model_info.cardData.get("license"), tags=model_info.tags, author=model_info.author
),
)
metadata.license = metadata.license or model_info.cardData.get("license")
metadata.tags = metadata.tags or model_info.tags
metadata.author = metadata.author or model_info.author
return urls
def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]:
"""Select the proper variant files from a list of HuggingFace repo_id paths."""

View File

@ -33,7 +33,7 @@ import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig
# from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, SchedulerPredictionType
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelInstall, ModelInstallJob, ModelType
from invokeai.backend.model_manager import BaseModelType, ModelInstall, ModelInstallJob, ModelType
from invokeai.backend.model_manager.install import ModelSourceMetadata
from invokeai.backend.util import choose_precision, choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
@ -47,7 +47,6 @@ from invokeai.frontend.install.widgets import (
SingleSelectColumns,
TextBox,
WindowTooSmallException,
select_stable_diffusion_config_file,
set_min_terminal_size,
)
@ -87,11 +86,12 @@ class InstallSelections:
def make_printable(s: str) -> str:
"""Replace non-printable characters in a string"""
"""Replace non-printable characters in a string."""
return s.translate(NOPRINT_TRANS_TABLE)
class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
"""Main form for interactive TUI."""
# for responsive resizing set to False, but this seems to cause a crash!
FIX_MINIMUM_SIZE_WHEN_CREATED = True
@ -402,12 +402,16 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
for model in self.installer.store.all_models():
info = UnifiedModelInfo.parse_obj(model.dict())
info.installed = True
key = f"{model.base_model}/{model.model_type}/{model.name}"
key = f"{model.base_model.value}/{model.model_type.value}/{model.name}"
all_models[key] = info
installed_models.append(key)
for key in INITIAL_MODELS_CONFIG.keys():
if key not in all_models:
if key in all_models:
# we want to preserve the description
description = all_models[key].description or INITIAL_MODELS_CONFIG[key].get("description")
all_models[key].description = description
else:
base_model, model_type, model_name = key.split("/")
info = UnifiedModelInfo(
name=model_name,
@ -436,7 +440,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
label_width = max([len(models[x].name) for x in self.starter_models])
description_width = window_width - label_width - checkbox_width - spacing_width
for key in self.starter_models:
for key in self.all_models:
description = models[key].description
description = (
description[0 : description_width - 3] + "..."
@ -587,6 +591,7 @@ def add_or_delete(installer: ModelInstall, selections: InstallSelections):
# --------------------------------------------------------
def select_and_download_models(opt: Namespace):
"""Prompt user for install/delete selections and execute."""
precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
config.precision = precision
installer = ModelInstall(config=config, event_handlers=[tqdm_progress])