define install abstract base class

This commit is contained in:
Lincoln Stein 2023-11-20 21:57:10 -05:00
parent 6f719b2c7a
commit 6c56233edc
5 changed files with 362 additions and 0 deletions

View File

@ -0,0 +1 @@
from .events_base import EventServiceBase # noqa F401

View File

@ -1,5 +1,7 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import traceback
from typing import Any, Optional
from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage
@ -313,3 +315,55 @@ class EventServiceBase:
event_name="queue_cleared",
payload={"queue_id": queue_id},
)
def emit_model_install_started(self, source: str) -> None:
"""
Emitted when an install job is started.
:param source: Source of the model; local path, repo_id or url
"""
self.__emit_queue_event(
event_name="model_install_started",
payload={"source": source},
)
def emit_model_install_completed(self, source: str, dest: str) -> None:
"""
Emitted when an install job is completed successfully.
:param source: Source of the model; local path, repo_id or url
:param dest: Destination of the model files; always a local path
"""
self.__emit_queue_event(
event_name="model_install_completed",
payload={
"source": source,
"dest": dest,
},
)
def emit_model_install_error(self,
source:str,
exception: Exception,
) -> None:
"""
Emitted when an install job encounters an exception.
:param source: Source of the model
:param exception: The exception that raised the error
"""
# Revisit:
# it makes more sense to receive an exception and break it out
# into error_type and error here, rather than at the caller's side
error_type = exception.__class__.__name__,
error = traceback.format_exc(),
self.__emit_queue_event(
event_name="model_install_error",
payload={
"source": source,
"error_type": error_type,
"error": error,
},
)

View File

@ -0,0 +1 @@
from .model_install_default import ModelInstallService # noqa F401

View File

@ -0,0 +1,228 @@
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Dict, Optional, Union
from pydantic import BaseModel, Field
from pydantic.networks import AnyHttpUrl
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events import EventServiceBase
class InstallStatus(str, Enum):
"""State of an install job running in the background."""
WAITING = "waiting" # waiting to be dequeued
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
class ModelInstallStatus(BaseModel):
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process") # noqa #501
error: Optional[Exception] = Field(default=None, description="Exception that led to status==ERROR") # noqa #501
class ModelInstallServiceBase(ABC):
"""Abstract base class for InvokeAI model installation."""
@abstractmethod
def __init__(
self,
config: InvokeAIAppConfig,
store: ModelRecordServiceBase,
event_bus: Optional["EventServiceBase"] = None,
):
"""
Create ModelInstallService object.
:param config: Systemwide InvokeAIAppConfig.
:param store: Systemwide ModelConfigStore
:param event_bus: InvokeAI event bus for reporting events to.
"""
pass
@property
@abstractmethod
def config(self) -> InvokeAIAppConfig:
"""Return the app_config used by the installer."""
pass
@property
@abstractmethod
def store(self) -> ModelRecordServiceBase:
"""Return the storage backend used by the installer."""
pass
@abstractmethod
def register_path(
self,
model_path: Union[Path, str],
name: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
) -> str:
"""
Probe and register the model at model_path.
This keeps the model in its current location.
:param model_path: Filesystem Path to the model.
:param name: Name for the model (optional)
:param description: Description for the model (optional)
:param metadata: Dict of attributes that will override probed values.
:returns id: The string ID of the registered model.
"""
pass
@abstractmethod
def install_path(
self,
model_path: Union[Path, str],
name: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
)-> str:
"""
Probe, register and install the model in the models directory.
This moves the model from its current location into
the models directory handled by InvokeAI.
:param model_path: Filesystem Path to the model.
:param name: Name for the model (optional)
:param description: Description for the model (optional)
:param metadata: Dict of attributes that will override probed values.
:returns id: The string ID of the registered model.
"""
pass
@abstractmethod
def install_model(
self,
source: Union[str, Path, AnyHttpUrl],
inplace: bool = True,
name: Optional[str] = None,
description: Optional[str] = None,
variant: Optional[str] = None,
subfolder: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
access_token: Optional[str] = None,
) -> ModelInstallStatus:
"""Install the indicated model.
:param source: Either a URL or a HuggingFace repo_id.
:param inplace: If True, local paths will not be moved into
the models directory, but registered in place (the default).
:param variant: For HuggingFace models, this optional parameter
specifies which variant to download (e.g. 'fp16')
:param subfolder: When downloading HF repo_ids this can be used to
specify a subfolder of the HF repository to download from.
:param metadata: Optional dict. Any fields in this dict
will override corresponding probe fields. Use it to override
`base_type`, `model_type`, `format`, `prediction_type`, `image_size`,
and `ztsnr_training`.
:param access_token: Access token for use in downloading remote
models.
This will download the model located at `source`,
probe it, and install it into the models directory.
This call is executed asynchronously in a separate
thread and will issue the following events on the event bus:
- model_install_started
- model_install_error
- model_install_completed
The `inplace` flag does not affect the behavior of downloaded
models, which are always moved into the `models` directory.
The call returns a ModelInstallStatus object which can be
polled to learn the current status and/or error message.
Variants recognized by HuggingFace currently are:
1. onnx
2. openvino
3. fp16
4. None (usually returns fp32 model)
"""
pass
@abstractmethod
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
"""
Wait for all pending installs to complete.
This will block until all pending downloads have
completed, been cancelled, or errored out. It will
block indefinitely if one or more jobs are in the
paused state.
It will return a dict that maps the source model
path, URL or repo_id to the ID of the installed model.
"""
pass
@abstractmethod
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
"""
Recursively scan directory for new models and register or install them.
:param scan_dir: Path to the directory to scan.
:param install: Install if True, otherwise register in place.
:returns list of IDs: Returns list of IDs of models registered/installed
"""
pass
@abstractmethod
def sync_to_config(self):
"""Synchronize models on disk to those in memory."""
pass
@abstractmethod
def hash(self, model_path: Union[Path, str]) -> str:
"""
Compute and return the fast hash of the model.
:param model_path: Path to the model on disk.
:return str: FastHash of the model for use as an ID.
"""
pass
# The following are internal methods
@abstractmethod
def _create_name(self, model_path: Union[Path, str]) -> str:
"""
Creates a default name for the model.
:param model_path: Path to the model on disk.
:return str: Model name
"""
pass
@abstractmethod
def _create_description(self, model_path: Union[Path, str]) -> str:
"""
Creates a default description for the model.
:param model_path: Path to the model on disk.
:return str: Model description
"""
pass
@abstractmethod
def _create_id(self, model_path: Union[Path, str], name: Optional[str] = None) -> str:
"""
Creates a unique ID for the model for use with the model records module. # noqa E501
:param model_path: Path to the model on disk.
:param name: (optional) non-default name for the model
:return str: Unique ID
"""
pass

View File

@ -0,0 +1,78 @@
from pathlib import Path
from typing import Dict, Optional, Union
from pydantic import BaseModel, Field
from queue import Queue
from pydantic.networks import AnyHttpUrl
from .model_install_base import InstallStatus, ModelInstallStatus, ModelInstallServiceBase
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.app.services.events import EventServiceBase
class ModelInstallService(ModelInstallBase, BaseModel):
"""class for InvokeAI model installation."""
config: InvokeAIAppConfig
store: ModelRecordServiceBase
event_bus: Optional[EventServiceBase] = Field(default=None)
install_queue: Queue = Field(default_factory=Queue)
def register_path(
self,
model_path: Union[Path, str],
name: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
) -> str:
raise NotImplementedError
def install_path(
self,
model_path: Union[Path, str],
name: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
) -> str:
raise NotImplementedError
def install_model(
self,
source: Union[str, Path, AnyHttpUrl],
inplace: bool = True,
name: Optional[str] = None,
description: Optional[str] = None,
variant: Optional[str] = None,
subfolder: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
access_token: Optional[str] = None,
) -> ModelInstallStatus:
raise NotImplementedError
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
raise NotImplementedError
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
raise NotImplementedError
def sync_to_config(self):
raise NotImplementedError
def hash(self, model_path: Union[Path, str]) -> str:
raise NotImplementedError
# The following are internal methods
def _create_name(self, model_path: Union[Path, str]) -> str:
model_path = Path(model_path)
if model_path.suffix in {'.safetensors', '.bin', '.pt', '.ckpt'}:
return model_path.stem
else:
return model_path.name
def _create_description(self, model_path: Union[Path, str]) -> str:
info: ModelProbeInfo = ModelProbe.probe(Path(model_path))
name: str = name or self._create_name(model_path)
return f"a {info.model_type} model based on {info.base_type}"
def _create_id(self, model_path: Union[Path, str], name: Optional[str] = None) -> str:
name: str = name or self._create_name(model_path)
raise NotImplementedError