several small model install enhancements

- Support extended HF repoid syntax in TUI. This allows
  installation of subfolders and safetensors files, as in
  `XpucT/Deliberate::Deliberate_v5.safetensors`

- Add `error` and `error_traceback` properties to the install
  job objects.

- Rename the `heuristic_import` route to `heuristic_install`.

- Fix the example `config` input in the `heuristic_install` route.
This commit is contained in:
Lincoln Stein 2024-02-21 17:15:54 -05:00 committed by psychedelicious
parent 613f11a3ac
commit cc12f57a5a
4 changed files with 20 additions and 45 deletions

View File

@ -382,8 +382,8 @@ async def add_model_record(
@model_manager_router.post( @model_manager_router.post(
"/heuristic_import", "/heuristic_install",
operation_id="heuristic_import_model", operation_id="heuristic_install_model",
responses={ responses={
201: {"description": "The model imported successfully"}, 201: {"description": "The model imported successfully"},
415: {"description": "Unrecognized file/folder format"}, 415: {"description": "Unrecognized file/folder format"},
@ -392,12 +392,12 @@ async def add_model_record(
}, },
status_code=201, status_code=201,
) )
async def heuristic_import( async def heuristic_install(
source: str, source: str,
config: Optional[Dict[str, Any]] = Body( config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ", description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
default=None, default=None,
example={"name": "modelT", "description": "antique cars"}, example={"name": "string", "description": "string"},
), ),
access_token: Optional[str] = None, access_token: Optional[str] = None,
) -> ModelInstallJob: ) -> ModelInstallJob:

View File

@ -177,6 +177,12 @@ class ModelInstallJob(BaseModel):
download_parts: Set[DownloadJob] = Field( download_parts: Set[DownloadJob] = Field(
default_factory=set, description="Download jobs contributing to this install" default_factory=set, description="Download jobs contributing to this install"
) )
error: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the text of the exception"
)
error_traceback: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the exception traceback"
)
# internal flags and transitory settings # internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None) _install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None) _exception: Optional[Exception] = PrivateAttr(default=None)
@ -184,6 +190,8 @@ class ModelInstallJob(BaseModel):
def set_error(self, e: Exception) -> None: def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception.""" """Record the error and traceback from an exception."""
self._exception = e self._exception = e
self.error = str(e)
self.error_traceback = self._format_error(e)
self.status = InstallStatus.ERROR self.status = InstallStatus.ERROR
def cancel(self) -> None: def cancel(self) -> None:
@ -195,10 +203,9 @@ class ModelInstallJob(BaseModel):
"""Class name of the exception that led to status==ERROR.""" """Class name of the exception that led to status==ERROR."""
return self._exception.__class__.__name__ if self._exception else None return self._exception.__class__.__name__ if self._exception else None
@property def _format_error(self, exception: Exception) -> str:
def error(self) -> Optional[str]:
"""Error traceback.""" """Error traceback."""
return "".join(traceback.format_exception(self._exception)) if self._exception else None return "".join(traceback.format_exception(exception))
@property @property
def cancelled(self) -> bool: def cancelled(self) -> bool:

View File

@ -1,14 +1,11 @@
"""Utility (backend) functions used by model_install.py""" """Utility (backend) functions used by model_install.py"""
import re
from logging import Logger from logging import Logger
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import omegaconf import omegaconf
from huggingface_hub import HfFolder
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from pydantic.networks import AnyHttpUrl
from requests import HTTPError from requests import HTTPError
from tqdm import tqdm from tqdm import tqdm
@ -18,12 +15,8 @@ from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
from invokeai.app.services.model_install import ( from invokeai.app.services.model_install import (
HFModelSource,
LocalModelSource,
ModelInstallService, ModelInstallService,
ModelInstallServiceBase, ModelInstallServiceBase,
ModelSource,
URLModelSource,
) )
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
@ -31,7 +24,6 @@ from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager import ( from invokeai.backend.model_manager import (
BaseModelType, BaseModelType,
InvalidModelConfigException, InvalidModelConfigException,
ModelRepoVariant,
ModelType, ModelType,
) )
from invokeai.backend.model_manager.metadata import UnknownMetadataException from invokeai.backend.model_manager.metadata import UnknownMetadataException
@ -226,37 +218,13 @@ class InstallHelper(object):
additional_models.append(reverse_source[requirement]) additional_models.append(reverse_source[requirement])
model_list.extend(additional_models) model_list.extend(additional_models)
def _make_install_source(self, model_info: UnifiedModelInfo) -> ModelSource:
assert model_info.source
model_path_id_or_url = model_info.source.strip("\"' ")
model_path = Path(model_path_id_or_url)
if model_path.exists(): # local file on disk
return LocalModelSource(path=model_path.absolute(), inplace=True)
# parsing huggingface repo ids
# we're going to do a little trick that allows for extended repo_ids of form "foo/bar:fp16"
variants = "|".join([x.lower() for x in ModelRepoVariant.__members__])
if match := re.match(f"^([^/]+/[^/]+?)(?::({variants}))?$", model_path_id_or_url):
repo_id = match.group(1)
repo_variant = ModelRepoVariant(match.group(2)) if match.group(2) else None
subfolder = Path(model_info.subfolder) if model_info.subfolder else None
return HFModelSource(
repo_id=repo_id,
access_token=HfFolder.get_token(),
subfolder=subfolder,
variant=repo_variant,
)
if re.match(r"^(http|https):", model_path_id_or_url):
return URLModelSource(url=AnyHttpUrl(model_path_id_or_url))
raise ValueError(f"Unsupported model source: {model_path_id_or_url}")
def add_or_delete(self, selections: InstallSelections) -> None: def add_or_delete(self, selections: InstallSelections) -> None:
"""Add or delete selected models.""" """Add or delete selected models."""
installer = self._installer installer = self._installer
self._add_required_models(selections.install_models) self._add_required_models(selections.install_models)
for model in selections.install_models: for model in selections.install_models:
source = self._make_install_source(model) assert model.source
model_path_id_or_url = model.source.strip("\"' ")
config = ( config = (
{ {
"description": model.description, "description": model.description,
@ -267,12 +235,12 @@ class InstallHelper(object):
) )
try: try:
installer.import_model( installer.heuristic_import(
source=source, source=model_path_id_or_url,
config=config, config=config,
) )
except (UnknownMetadataException, InvalidModelConfigException, HTTPError, OSError) as e: except (UnknownMetadataException, InvalidModelConfigException, HTTPError, OSError) as e:
self._logger.warning(f"{source}: {e}") self._logger.warning(f"{model.source}: {e}")
for model_to_remove in selections.remove_models: for model_to_remove in selections.remove_models:
parts = model_to_remove.split("/") parts = model_to_remove.split("/")

View File

@ -256,4 +256,4 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In
assert job.error_type == "HTTPError" assert job.error_type == "HTTPError"
assert job.error assert job.error
assert "NOT FOUND" in job.error assert "NOT FOUND" in job.error
assert "Traceback" in job.error assert job.error_traceback.startswith("Traceback")