mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
several small model install enhancements
- Support extended HF repoid syntax in TUI. This allows installation of subfolders and safetensors files, as in `XpucT/Deliberate::Deliberate_v5.safetensors` - Add `error` and `error_traceback` properties to the install job objects. - Rename the `heuristic_import` route to `heuristic_install`. - Fix the example `config` input in the `heuristic_install` route.
This commit is contained in:
parent
613f11a3ac
commit
cc12f57a5a
@ -382,8 +382,8 @@ async def add_model_record(
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
"/heuristic_import",
|
||||
operation_id="heuristic_import_model",
|
||||
"/heuristic_install",
|
||||
operation_id="heuristic_install_model",
|
||||
responses={
|
||||
201: {"description": "The model imported successfully"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
@ -392,12 +392,12 @@ async def add_model_record(
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def heuristic_import(
|
||||
async def heuristic_install(
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
default=None,
|
||||
example={"name": "modelT", "description": "antique cars"},
|
||||
example={"name": "string", "description": "string"},
|
||||
),
|
||||
access_token: Optional[str] = None,
|
||||
) -> ModelInstallJob:
|
||||
|
@ -177,6 +177,12 @@ class ModelInstallJob(BaseModel):
|
||||
download_parts: Set[DownloadJob] = Field(
|
||||
default_factory=set, description="Download jobs contributing to this install"
|
||||
)
|
||||
error: Optional[str] = Field(
|
||||
default=None, description="On an error condition, this field will contain the text of the exception"
|
||||
)
|
||||
error_traceback: Optional[str] = Field(
|
||||
default=None, description="On an error condition, this field will contain the exception traceback"
|
||||
)
|
||||
# internal flags and transitory settings
|
||||
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
||||
_exception: Optional[Exception] = PrivateAttr(default=None)
|
||||
@ -184,6 +190,8 @@ class ModelInstallJob(BaseModel):
|
||||
def set_error(self, e: Exception) -> None:
|
||||
"""Record the error and traceback from an exception."""
|
||||
self._exception = e
|
||||
self.error = str(e)
|
||||
self.error_traceback = self._format_error(e)
|
||||
self.status = InstallStatus.ERROR
|
||||
|
||||
def cancel(self) -> None:
|
||||
@ -195,10 +203,9 @@ class ModelInstallJob(BaseModel):
|
||||
"""Class name of the exception that led to status==ERROR."""
|
||||
return self._exception.__class__.__name__ if self._exception else None
|
||||
|
||||
@property
|
||||
def error(self) -> Optional[str]:
|
||||
def _format_error(self, exception: Exception) -> str:
|
||||
"""Error traceback."""
|
||||
return "".join(traceback.format_exception(self._exception)) if self._exception else None
|
||||
return "".join(traceback.format_exception(exception))
|
||||
|
||||
@property
|
||||
def cancelled(self) -> bool:
|
||||
|
@ -1,14 +1,11 @@
|
||||
"""Utility (backend) functions used by model_install.py"""
|
||||
import re
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import omegaconf
|
||||
from huggingface_hub import HfFolder
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.dataclasses import dataclass
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests import HTTPError
|
||||
from tqdm import tqdm
|
||||
|
||||
@ -18,12 +15,8 @@ from invokeai.app.services.download import DownloadQueueService
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
|
||||
from invokeai.app.services.model_install import (
|
||||
HFModelSource,
|
||||
LocalModelSource,
|
||||
ModelInstallService,
|
||||
ModelInstallServiceBase,
|
||||
ModelSource,
|
||||
URLModelSource,
|
||||
)
|
||||
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
@ -31,7 +24,6 @@ from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.model_manager import (
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import UnknownMetadataException
|
||||
@ -226,37 +218,13 @@ class InstallHelper(object):
|
||||
additional_models.append(reverse_source[requirement])
|
||||
model_list.extend(additional_models)
|
||||
|
||||
def _make_install_source(self, model_info: UnifiedModelInfo) -> ModelSource:
|
||||
assert model_info.source
|
||||
model_path_id_or_url = model_info.source.strip("\"' ")
|
||||
model_path = Path(model_path_id_or_url)
|
||||
|
||||
if model_path.exists(): # local file on disk
|
||||
return LocalModelSource(path=model_path.absolute(), inplace=True)
|
||||
|
||||
# parsing huggingface repo ids
|
||||
# we're going to do a little trick that allows for extended repo_ids of form "foo/bar:fp16"
|
||||
variants = "|".join([x.lower() for x in ModelRepoVariant.__members__])
|
||||
if match := re.match(f"^([^/]+/[^/]+?)(?::({variants}))?$", model_path_id_or_url):
|
||||
repo_id = match.group(1)
|
||||
repo_variant = ModelRepoVariant(match.group(2)) if match.group(2) else None
|
||||
subfolder = Path(model_info.subfolder) if model_info.subfolder else None
|
||||
return HFModelSource(
|
||||
repo_id=repo_id,
|
||||
access_token=HfFolder.get_token(),
|
||||
subfolder=subfolder,
|
||||
variant=repo_variant,
|
||||
)
|
||||
if re.match(r"^(http|https):", model_path_id_or_url):
|
||||
return URLModelSource(url=AnyHttpUrl(model_path_id_or_url))
|
||||
raise ValueError(f"Unsupported model source: {model_path_id_or_url}")
|
||||
|
||||
def add_or_delete(self, selections: InstallSelections) -> None:
|
||||
"""Add or delete selected models."""
|
||||
installer = self._installer
|
||||
self._add_required_models(selections.install_models)
|
||||
for model in selections.install_models:
|
||||
source = self._make_install_source(model)
|
||||
assert model.source
|
||||
model_path_id_or_url = model.source.strip("\"' ")
|
||||
config = (
|
||||
{
|
||||
"description": model.description,
|
||||
@ -267,12 +235,12 @@ class InstallHelper(object):
|
||||
)
|
||||
|
||||
try:
|
||||
installer.import_model(
|
||||
source=source,
|
||||
installer.heuristic_import(
|
||||
source=model_path_id_or_url,
|
||||
config=config,
|
||||
)
|
||||
except (UnknownMetadataException, InvalidModelConfigException, HTTPError, OSError) as e:
|
||||
self._logger.warning(f"{source}: {e}")
|
||||
self._logger.warning(f"{model.source}: {e}")
|
||||
|
||||
for model_to_remove in selections.remove_models:
|
||||
parts = model_to_remove.split("/")
|
||||
|
@ -256,4 +256,4 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In
|
||||
assert job.error_type == "HTTPError"
|
||||
assert job.error
|
||||
assert "NOT FOUND" in job.error
|
||||
assert "Traceback" in job.error
|
||||
assert job.error_traceback.startswith("Traceback")
|
||||
|
Loading…
Reference in New Issue
Block a user