From cc41e8912cf3e80f2d7adc438fc81f83635dfa59 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 21 Feb 2024 17:15:54 -0500 Subject: [PATCH] several small model install enhancements - Support extended HF repoid syntax in TUI. This allows installation of subfolders and safetensors files, as in `XpucT/Deliberate::Deliberate_v5.safetensors` - Add `error` and `error_traceback` properties to the install job objects. - Rename the `heuristic_import` route to `heuristic_install`. - Fix the example `config` input in the `heuristic_install` route. --- invokeai/app/api/routers/model_manager.py | 8 ++-- .../model_install/model_install_base.py | 13 ++++-- invokeai/backend/install/install_helper.py | 42 +++---------------- .../model_install/test_model_install.py | 2 +- 4 files changed, 20 insertions(+), 45 deletions(-) diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index aee457406a..f57f5f97b6 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -382,8 +382,8 @@ async def add_model_record( @model_manager_router.post( - "/heuristic_import", - operation_id="heuristic_import_model", + "/heuristic_install", + operation_id="heuristic_install_model", responses={ 201: {"description": "The model imported successfully"}, 415: {"description": "Unrecognized file/folder format"}, @@ -392,12 +392,12 @@ async def add_model_record( }, status_code=201, ) -async def heuristic_import( +async def heuristic_install( source: str, config: Optional[Dict[str, Any]] = Body( description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ", default=None, - example={"name": "modelT", "description": "antique cars"}, + example={"name": "string", "description": "string"}, ), access_token: Optional[str] = None, ) -> ModelInstallJob: diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 080219af75..d1e8e4f8e5 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -177,6 +177,12 @@ class ModelInstallJob(BaseModel): download_parts: Set[DownloadJob] = Field( default_factory=set, description="Download jobs contributing to this install" ) + error: Optional[str] = Field( + default=None, description="On an error condition, this field will contain the text of the exception" + ) + error_traceback: Optional[str] = Field( + default=None, description="On an error condition, this field will contain the exception traceback" + ) # internal flags and transitory settings _install_tmpdir: Optional[Path] = PrivateAttr(default=None) _exception: Optional[Exception] = PrivateAttr(default=None) @@ -184,6 +190,8 @@ class ModelInstallJob(BaseModel): def set_error(self, e: Exception) -> None: """Record the error and traceback from an exception.""" self._exception = e + self.error = str(e) + self.error_traceback = self._format_error(e) self.status = InstallStatus.ERROR def cancel(self) -> None: @@ -195,10 +203,9 @@ class ModelInstallJob(BaseModel): """Class name of the exception that led to status==ERROR.""" return self._exception.__class__.__name__ if self._exception else None - @property - def error(self) -> Optional[str]: + def _format_error(self, exception: Exception) -> str: """Error traceback.""" - return "".join(traceback.format_exception(self._exception)) if self._exception else None + return "".join(traceback.format_exception(exception)) @property def cancelled(self) -> bool: diff --git a/invokeai/backend/install/install_helper.py b/invokeai/backend/install/install_helper.py index 3623b623a9..999dcdd100 100644 --- a/invokeai/backend/install/install_helper.py +++ b/invokeai/backend/install/install_helper.py @@ -1,14 +1,11 @@ """Utility (backend) functions used by model_install.py""" -import re from logging import Logger from pathlib import Path from typing import Any, Dict, List, Optional import omegaconf -from huggingface_hub import HfFolder from pydantic import BaseModel, Field from pydantic.dataclasses import dataclass -from pydantic.networks import AnyHttpUrl from requests import HTTPError from tqdm import tqdm @@ -18,12 +15,8 @@ from invokeai.app.services.download import DownloadQueueService from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage from invokeai.app.services.model_install import ( - HFModelSource, - LocalModelSource, ModelInstallService, ModelInstallServiceBase, - ModelSource, - URLModelSource, ) from invokeai.app.services.model_metadata import ModelMetadataStoreSQL from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL @@ -31,7 +24,6 @@ from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.backend.model_manager import ( BaseModelType, InvalidModelConfigException, - ModelRepoVariant, ModelType, ) from invokeai.backend.model_manager.metadata import UnknownMetadataException @@ -226,37 +218,13 @@ class InstallHelper(object): additional_models.append(reverse_source[requirement]) model_list.extend(additional_models) - def _make_install_source(self, model_info: UnifiedModelInfo) -> ModelSource: - assert model_info.source - model_path_id_or_url = model_info.source.strip("\"' ") - model_path = Path(model_path_id_or_url) - - if model_path.exists(): # local file on disk - return LocalModelSource(path=model_path.absolute(), inplace=True) - - # parsing huggingface repo ids - # we're going to do a little trick that allows for extended repo_ids of form "foo/bar:fp16" - variants = "|".join([x.lower() for x in ModelRepoVariant.__members__]) - if match := re.match(f"^([^/]+/[^/]+?)(?::({variants}))?$", model_path_id_or_url): - repo_id = match.group(1) - repo_variant = ModelRepoVariant(match.group(2)) if match.group(2) else None - subfolder = Path(model_info.subfolder) if model_info.subfolder else None - return HFModelSource( - repo_id=repo_id, - access_token=HfFolder.get_token(), - subfolder=subfolder, - variant=repo_variant, - ) - if re.match(r"^(http|https):", model_path_id_or_url): - return URLModelSource(url=AnyHttpUrl(model_path_id_or_url)) - raise ValueError(f"Unsupported model source: {model_path_id_or_url}") - def add_or_delete(self, selections: InstallSelections) -> None: """Add or delete selected models.""" installer = self._installer self._add_required_models(selections.install_models) for model in selections.install_models: - source = self._make_install_source(model) + assert model.source + model_path_id_or_url = model.source.strip("\"' ") config = ( { "description": model.description, @@ -267,12 +235,12 @@ class InstallHelper(object): ) try: - installer.import_model( - source=source, + installer.heuristic_import( + source=model_path_id_or_url, config=config, ) except (UnknownMetadataException, InvalidModelConfigException, HTTPError, OSError) as e: - self._logger.warning(f"{source}: {e}") + self._logger.warning(f"{model.source}: {e}") for model_to_remove in selections.remove_models: parts = model_to_remove.split("/") diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index 55f7e86541..80b106c5cb 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -256,4 +256,4 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In assert job.error_type == "HTTPError" assert job.error assert "NOT FOUND" in job.error - assert "Traceback" in job.error + assert job.error_traceback.startswith("Traceback")