mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
several small model install enhancements
- Support extended HF repoid syntax in TUI. This allows installation of subfolders and safetensors files, as in `XpucT/Deliberate::Deliberate_v5.safetensors` - Add `error` and `error_traceback` properties to the install job objects. - Rename the `heuristic_import` route to `heuristic_install`. - Fix the example `config` input in the `heuristic_install` route.
This commit is contained in:
parent
1cec0bb179
commit
cc41e8912c
@ -382,8 +382,8 @@ async def add_model_record(
|
|||||||
|
|
||||||
|
|
||||||
@model_manager_router.post(
|
@model_manager_router.post(
|
||||||
"/heuristic_import",
|
"/heuristic_install",
|
||||||
operation_id="heuristic_import_model",
|
operation_id="heuristic_install_model",
|
||||||
responses={
|
responses={
|
||||||
201: {"description": "The model imported successfully"},
|
201: {"description": "The model imported successfully"},
|
||||||
415: {"description": "Unrecognized file/folder format"},
|
415: {"description": "Unrecognized file/folder format"},
|
||||||
@ -392,12 +392,12 @@ async def add_model_record(
|
|||||||
},
|
},
|
||||||
status_code=201,
|
status_code=201,
|
||||||
)
|
)
|
||||||
async def heuristic_import(
|
async def heuristic_install(
|
||||||
source: str,
|
source: str,
|
||||||
config: Optional[Dict[str, Any]] = Body(
|
config: Optional[Dict[str, Any]] = Body(
|
||||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||||
default=None,
|
default=None,
|
||||||
example={"name": "modelT", "description": "antique cars"},
|
example={"name": "string", "description": "string"},
|
||||||
),
|
),
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
) -> ModelInstallJob:
|
) -> ModelInstallJob:
|
||||||
|
@ -177,6 +177,12 @@ class ModelInstallJob(BaseModel):
|
|||||||
download_parts: Set[DownloadJob] = Field(
|
download_parts: Set[DownloadJob] = Field(
|
||||||
default_factory=set, description="Download jobs contributing to this install"
|
default_factory=set, description="Download jobs contributing to this install"
|
||||||
)
|
)
|
||||||
|
error: Optional[str] = Field(
|
||||||
|
default=None, description="On an error condition, this field will contain the text of the exception"
|
||||||
|
)
|
||||||
|
error_traceback: Optional[str] = Field(
|
||||||
|
default=None, description="On an error condition, this field will contain the exception traceback"
|
||||||
|
)
|
||||||
# internal flags and transitory settings
|
# internal flags and transitory settings
|
||||||
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
||||||
_exception: Optional[Exception] = PrivateAttr(default=None)
|
_exception: Optional[Exception] = PrivateAttr(default=None)
|
||||||
@ -184,6 +190,8 @@ class ModelInstallJob(BaseModel):
|
|||||||
def set_error(self, e: Exception) -> None:
|
def set_error(self, e: Exception) -> None:
|
||||||
"""Record the error and traceback from an exception."""
|
"""Record the error and traceback from an exception."""
|
||||||
self._exception = e
|
self._exception = e
|
||||||
|
self.error = str(e)
|
||||||
|
self.error_traceback = self._format_error(e)
|
||||||
self.status = InstallStatus.ERROR
|
self.status = InstallStatus.ERROR
|
||||||
|
|
||||||
def cancel(self) -> None:
|
def cancel(self) -> None:
|
||||||
@ -195,10 +203,9 @@ class ModelInstallJob(BaseModel):
|
|||||||
"""Class name of the exception that led to status==ERROR."""
|
"""Class name of the exception that led to status==ERROR."""
|
||||||
return self._exception.__class__.__name__ if self._exception else None
|
return self._exception.__class__.__name__ if self._exception else None
|
||||||
|
|
||||||
@property
|
def _format_error(self, exception: Exception) -> str:
|
||||||
def error(self) -> Optional[str]:
|
|
||||||
"""Error traceback."""
|
"""Error traceback."""
|
||||||
return "".join(traceback.format_exception(self._exception)) if self._exception else None
|
return "".join(traceback.format_exception(exception))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cancelled(self) -> bool:
|
def cancelled(self) -> bool:
|
||||||
|
@ -1,14 +1,11 @@
|
|||||||
"""Utility (backend) functions used by model_install.py"""
|
"""Utility (backend) functions used by model_install.py"""
|
||||||
import re
|
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import omegaconf
|
import omegaconf
|
||||||
from huggingface_hub import HfFolder
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from pydantic.dataclasses import dataclass
|
from pydantic.dataclasses import dataclass
|
||||||
from pydantic.networks import AnyHttpUrl
|
|
||||||
from requests import HTTPError
|
from requests import HTTPError
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
@ -18,12 +15,8 @@ from invokeai.app.services.download import DownloadQueueService
|
|||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
|
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
|
||||||
from invokeai.app.services.model_install import (
|
from invokeai.app.services.model_install import (
|
||||||
HFModelSource,
|
|
||||||
LocalModelSource,
|
|
||||||
ModelInstallService,
|
ModelInstallService,
|
||||||
ModelInstallServiceBase,
|
ModelInstallServiceBase,
|
||||||
ModelSource,
|
|
||||||
URLModelSource,
|
|
||||||
)
|
)
|
||||||
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
|
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
|
||||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||||
@ -31,7 +24,6 @@ from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
|||||||
from invokeai.backend.model_manager import (
|
from invokeai.backend.model_manager import (
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
InvalidModelConfigException,
|
InvalidModelConfigException,
|
||||||
ModelRepoVariant,
|
|
||||||
ModelType,
|
ModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.metadata import UnknownMetadataException
|
from invokeai.backend.model_manager.metadata import UnknownMetadataException
|
||||||
@ -226,37 +218,13 @@ class InstallHelper(object):
|
|||||||
additional_models.append(reverse_source[requirement])
|
additional_models.append(reverse_source[requirement])
|
||||||
model_list.extend(additional_models)
|
model_list.extend(additional_models)
|
||||||
|
|
||||||
def _make_install_source(self, model_info: UnifiedModelInfo) -> ModelSource:
|
|
||||||
assert model_info.source
|
|
||||||
model_path_id_or_url = model_info.source.strip("\"' ")
|
|
||||||
model_path = Path(model_path_id_or_url)
|
|
||||||
|
|
||||||
if model_path.exists(): # local file on disk
|
|
||||||
return LocalModelSource(path=model_path.absolute(), inplace=True)
|
|
||||||
|
|
||||||
# parsing huggingface repo ids
|
|
||||||
# we're going to do a little trick that allows for extended repo_ids of form "foo/bar:fp16"
|
|
||||||
variants = "|".join([x.lower() for x in ModelRepoVariant.__members__])
|
|
||||||
if match := re.match(f"^([^/]+/[^/]+?)(?::({variants}))?$", model_path_id_or_url):
|
|
||||||
repo_id = match.group(1)
|
|
||||||
repo_variant = ModelRepoVariant(match.group(2)) if match.group(2) else None
|
|
||||||
subfolder = Path(model_info.subfolder) if model_info.subfolder else None
|
|
||||||
return HFModelSource(
|
|
||||||
repo_id=repo_id,
|
|
||||||
access_token=HfFolder.get_token(),
|
|
||||||
subfolder=subfolder,
|
|
||||||
variant=repo_variant,
|
|
||||||
)
|
|
||||||
if re.match(r"^(http|https):", model_path_id_or_url):
|
|
||||||
return URLModelSource(url=AnyHttpUrl(model_path_id_or_url))
|
|
||||||
raise ValueError(f"Unsupported model source: {model_path_id_or_url}")
|
|
||||||
|
|
||||||
def add_or_delete(self, selections: InstallSelections) -> None:
|
def add_or_delete(self, selections: InstallSelections) -> None:
|
||||||
"""Add or delete selected models."""
|
"""Add or delete selected models."""
|
||||||
installer = self._installer
|
installer = self._installer
|
||||||
self._add_required_models(selections.install_models)
|
self._add_required_models(selections.install_models)
|
||||||
for model in selections.install_models:
|
for model in selections.install_models:
|
||||||
source = self._make_install_source(model)
|
assert model.source
|
||||||
|
model_path_id_or_url = model.source.strip("\"' ")
|
||||||
config = (
|
config = (
|
||||||
{
|
{
|
||||||
"description": model.description,
|
"description": model.description,
|
||||||
@ -267,12 +235,12 @@ class InstallHelper(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
installer.import_model(
|
installer.heuristic_import(
|
||||||
source=source,
|
source=model_path_id_or_url,
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
except (UnknownMetadataException, InvalidModelConfigException, HTTPError, OSError) as e:
|
except (UnknownMetadataException, InvalidModelConfigException, HTTPError, OSError) as e:
|
||||||
self._logger.warning(f"{source}: {e}")
|
self._logger.warning(f"{model.source}: {e}")
|
||||||
|
|
||||||
for model_to_remove in selections.remove_models:
|
for model_to_remove in selections.remove_models:
|
||||||
parts = model_to_remove.split("/")
|
parts = model_to_remove.split("/")
|
||||||
|
@ -256,4 +256,4 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In
|
|||||||
assert job.error_type == "HTTPError"
|
assert job.error_type == "HTTPError"
|
||||||
assert job.error
|
assert job.error
|
||||||
assert "NOT FOUND" in job.error
|
assert "NOT FOUND" in job.error
|
||||||
assert "Traceback" in job.error
|
assert job.error_traceback.startswith("Traceback")
|
||||||
|
Loading…
Reference in New Issue
Block a user