mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
preserve description in metadata when installing a starter model
This commit is contained in:
@ -504,10 +504,11 @@ class DownloadQueue(DownloadQueueBase):
|
||||
try:
|
||||
repo_id = job.source
|
||||
variant = job.variant
|
||||
urls_to_download, metadata = self._get_repo_info(repo_id, variant)
|
||||
if not job.metadata:
|
||||
job.metadata = ModelSourceMetadata()
|
||||
urls_to_download = self._get_repo_info(repo_id, variant=variant, metadata=job.metadata)
|
||||
if job.destination.name != Path(repo_id).name:
|
||||
job.destination = job.destination / Path(repo_id).name
|
||||
job.metadata = metadata
|
||||
bytes_downloaded = dict()
|
||||
job.total_bytes = 0
|
||||
|
||||
@ -535,10 +536,12 @@ class DownloadQueue(DownloadQueueBase):
|
||||
def _get_repo_info(
|
||||
self,
|
||||
repo_id: str,
|
||||
metadata: ModelSourceMetadata,
|
||||
variant: Optional[str] = None,
|
||||
) -> Tuple[List[Tuple[AnyHttpUrl, Path, Path]], ModelSourceMetadata]:
|
||||
"""
|
||||
Given a repo_id and an optional variant, return list of URLs to download to get the model.
|
||||
The metadata field will be updated with model metadata from HuggingFace.
|
||||
|
||||
Known variants currently are:
|
||||
1. onnx
|
||||
@ -561,12 +564,10 @@ class DownloadQueue(DownloadQueueBase):
|
||||
(hf_hub_url(repo_id, filename=x.as_posix()), x.parent or Path("."), x.name, sizes[x.as_posix()])
|
||||
for x in self._select_variants(paths, variant)
|
||||
]
|
||||
return (
|
||||
urls,
|
||||
ModelSourceMetadata(
|
||||
license=model_info.cardData.get("license"), tags=model_info.tags, author=model_info.author
|
||||
),
|
||||
)
|
||||
metadata.license = metadata.license or model_info.cardData.get("license")
|
||||
metadata.tags = metadata.tags or model_info.tags
|
||||
metadata.author = metadata.author or model_info.author
|
||||
return urls
|
||||
|
||||
def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]:
|
||||
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
|
||||
|
@ -33,7 +33,7 @@ import invokeai.configs as configs
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
# from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, SchedulerPredictionType
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelInstall, ModelInstallJob, ModelType
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelInstall, ModelInstallJob, ModelType
|
||||
from invokeai.backend.model_manager.install import ModelSourceMetadata
|
||||
from invokeai.backend.util import choose_precision, choose_torch_device
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
@ -47,7 +47,6 @@ from invokeai.frontend.install.widgets import (
|
||||
SingleSelectColumns,
|
||||
TextBox,
|
||||
WindowTooSmallException,
|
||||
select_stable_diffusion_config_file,
|
||||
set_min_terminal_size,
|
||||
)
|
||||
|
||||
@ -87,11 +86,12 @@ class InstallSelections:
|
||||
|
||||
|
||||
def make_printable(s: str) -> str:
|
||||
"""Replace non-printable characters in a string"""
|
||||
"""Replace non-printable characters in a string."""
|
||||
return s.translate(NOPRINT_TRANS_TABLE)
|
||||
|
||||
|
||||
class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
"""Main form for interactive TUI."""
|
||||
# for responsive resizing set to False, but this seems to cause a crash!
|
||||
FIX_MINIMUM_SIZE_WHEN_CREATED = True
|
||||
|
||||
@ -402,12 +402,16 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
for model in self.installer.store.all_models():
|
||||
info = UnifiedModelInfo.parse_obj(model.dict())
|
||||
info.installed = True
|
||||
key = f"{model.base_model}/{model.model_type}/{model.name}"
|
||||
key = f"{model.base_model.value}/{model.model_type.value}/{model.name}"
|
||||
all_models[key] = info
|
||||
installed_models.append(key)
|
||||
|
||||
for key in INITIAL_MODELS_CONFIG.keys():
|
||||
if key not in all_models:
|
||||
if key in all_models:
|
||||
# we want to preserve the description
|
||||
description = all_models[key].description or INITIAL_MODELS_CONFIG[key].get("description")
|
||||
all_models[key].description = description
|
||||
else:
|
||||
base_model, model_type, model_name = key.split("/")
|
||||
info = UnifiedModelInfo(
|
||||
name=model_name,
|
||||
@ -436,7 +440,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
label_width = max([len(models[x].name) for x in self.starter_models])
|
||||
description_width = window_width - label_width - checkbox_width - spacing_width
|
||||
|
||||
for key in self.starter_models:
|
||||
for key in self.all_models:
|
||||
description = models[key].description
|
||||
description = (
|
||||
description[0 : description_width - 3] + "..."
|
||||
@ -587,6 +591,7 @@ def add_or_delete(installer: ModelInstall, selections: InstallSelections):
|
||||
|
||||
# --------------------------------------------------------
|
||||
def select_and_download_models(opt: Namespace):
|
||||
"""Prompt user for install/delete selections and execute."""
|
||||
precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
|
||||
config.precision = precision
|
||||
installer = ModelInstall(config=config, event_handlers=[tqdm_progress])
|
||||
|
Reference in New Issue
Block a user