several small model install enhancements

- Support extended HF repoid syntax in TUI. This allows
  installation of subfolders and safetensors files, as in
  `XpucT/Deliberate::Deliberate_v5.safetensors`

- Add `error` and `error_traceback` properties to the install
  job objects.

- Rename the `heuristic_import` route to `heuristic_install`.

- Fix the example `config` input in the `heuristic_install` route.
This commit is contained in:
Lincoln Stein 2024-02-21 17:15:54 -05:00 committed by psychedelicious
parent 613f11a3ac
commit cc12f57a5a
4 changed files with 20 additions and 45 deletions

View File

@ -382,8 +382,8 @@ async def add_model_record(
@model_manager_router.post(
"/heuristic_import",
operation_id="heuristic_import_model",
"/heuristic_install",
operation_id="heuristic_install_model",
responses={
201: {"description": "The model imported successfully"},
415: {"description": "Unrecognized file/folder format"},
@ -392,12 +392,12 @@ async def add_model_record(
},
status_code=201,
)
async def heuristic_import(
async def heuristic_install(
source: str,
config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
default=None,
example={"name": "modelT", "description": "antique cars"},
example={"name": "string", "description": "string"},
),
access_token: Optional[str] = None,
) -> ModelInstallJob:

View File

@ -177,6 +177,12 @@ class ModelInstallJob(BaseModel):
download_parts: Set[DownloadJob] = Field(
default_factory=set, description="Download jobs contributing to this install"
)
error: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the text of the exception"
)
error_traceback: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the exception traceback"
)
# internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None)
@ -184,6 +190,8 @@ class ModelInstallJob(BaseModel):
def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception."""
self._exception = e
self.error = str(e)
self.error_traceback = self._format_error(e)
self.status = InstallStatus.ERROR
def cancel(self) -> None:
@ -195,10 +203,9 @@ class ModelInstallJob(BaseModel):
"""Class name of the exception that led to status==ERROR."""
return self._exception.__class__.__name__ if self._exception else None
@property
def error(self) -> Optional[str]:
def _format_error(self, exception: Exception) -> str:
"""Error traceback."""
return "".join(traceback.format_exception(self._exception)) if self._exception else None
return "".join(traceback.format_exception(exception))
@property
def cancelled(self) -> bool:

View File

@ -1,14 +1,11 @@
"""Utility (backend) functions used by model_install.py"""
import re
from logging import Logger
from pathlib import Path
from typing import Any, Dict, List, Optional
import omegaconf
from huggingface_hub import HfFolder
from pydantic import BaseModel, Field
from pydantic.dataclasses import dataclass
from pydantic.networks import AnyHttpUrl
from requests import HTTPError
from tqdm import tqdm
@ -18,12 +15,8 @@ from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
from invokeai.app.services.model_install import (
HFModelSource,
LocalModelSource,
ModelInstallService,
ModelInstallServiceBase,
ModelSource,
URLModelSource,
)
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
@ -31,7 +24,6 @@ from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager import (
BaseModelType,
InvalidModelConfigException,
ModelRepoVariant,
ModelType,
)
from invokeai.backend.model_manager.metadata import UnknownMetadataException
@ -226,37 +218,13 @@ class InstallHelper(object):
additional_models.append(reverse_source[requirement])
model_list.extend(additional_models)
def _make_install_source(self, model_info: UnifiedModelInfo) -> ModelSource:
assert model_info.source
model_path_id_or_url = model_info.source.strip("\"' ")
model_path = Path(model_path_id_or_url)
if model_path.exists(): # local file on disk
return LocalModelSource(path=model_path.absolute(), inplace=True)
# parsing huggingface repo ids
# we're going to do a little trick that allows for extended repo_ids of form "foo/bar:fp16"
variants = "|".join([x.lower() for x in ModelRepoVariant.__members__])
if match := re.match(f"^([^/]+/[^/]+?)(?::({variants}))?$", model_path_id_or_url):
repo_id = match.group(1)
repo_variant = ModelRepoVariant(match.group(2)) if match.group(2) else None
subfolder = Path(model_info.subfolder) if model_info.subfolder else None
return HFModelSource(
repo_id=repo_id,
access_token=HfFolder.get_token(),
subfolder=subfolder,
variant=repo_variant,
)
if re.match(r"^(http|https):", model_path_id_or_url):
return URLModelSource(url=AnyHttpUrl(model_path_id_or_url))
raise ValueError(f"Unsupported model source: {model_path_id_or_url}")
def add_or_delete(self, selections: InstallSelections) -> None:
"""Add or delete selected models."""
installer = self._installer
self._add_required_models(selections.install_models)
for model in selections.install_models:
source = self._make_install_source(model)
assert model.source
model_path_id_or_url = model.source.strip("\"' ")
config = (
{
"description": model.description,
@ -267,12 +235,12 @@ class InstallHelper(object):
)
try:
installer.import_model(
source=source,
installer.heuristic_import(
source=model_path_id_or_url,
config=config,
)
except (UnknownMetadataException, InvalidModelConfigException, HTTPError, OSError) as e:
self._logger.warning(f"{source}: {e}")
self._logger.warning(f"{model.source}: {e}")
for model_to_remove in selections.remove_models:
parts = model_to_remove.split("/")

View File

@ -256,4 +256,4 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In
assert job.error_type == "HTTPError"
assert job.error
assert "NOT FOUND" in job.error
assert "Traceback" in job.error
assert job.error_traceback.startswith("Traceback")