mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
289 lines
12 KiB
Python
289 lines
12 KiB
Python
"""Utility (backend) functions used by model_install.py"""
|
|
import re
|
|
from logging import Logger
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import omegaconf
|
|
from huggingface_hub import HfFolder
|
|
from pydantic import BaseModel, Field
|
|
from pydantic.dataclasses import dataclass
|
|
from pydantic.networks import AnyHttpUrl
|
|
from requests import HTTPError
|
|
from tqdm import tqdm
|
|
|
|
import invokeai.configs as configs
|
|
from invokeai.app.services.config import InvokeAIAppConfig
|
|
from invokeai.app.services.download import DownloadQueueService
|
|
from invokeai.app.services.events.events_base import EventServiceBase
|
|
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
|
|
from invokeai.app.services.model_install import (
|
|
HFModelSource,
|
|
LocalModelSource,
|
|
ModelInstallService,
|
|
ModelInstallServiceBase,
|
|
ModelSource,
|
|
URLModelSource,
|
|
)
|
|
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
|
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
|
from invokeai.backend.model_manager import (
|
|
BaseModelType,
|
|
InvalidModelConfigException,
|
|
ModelType,
|
|
)
|
|
from invokeai.backend.model_manager.metadata import UnknownMetadataException
|
|
from invokeai.backend.util.logging import InvokeAILogger
|
|
|
|
# name of the starter models file
|
|
INITIAL_MODELS = "INITIAL_MODELS2.yaml"
|
|
|
|
|
|
def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
    """Build the SQL-backed model record store for the given configuration.

    The database is opened with ``init_db`` using a logger derived from
    ``app_config`` and an image-file store rooted at ``<output_path>/images``.

    Args:
        app_config: Application configuration; supplies ``output_path`` and is
            passed through to the logger factory and ``init_db``.

    Returns:
        A ``ModelRecordServiceSQL`` wrapping the initialized database.
    """
    database = init_db(
        config=app_config,
        logger=InvokeAILogger.get_logger(config=app_config),
        image_files=DiskImageFileStorage(f"{app_config.output_path}/images"),
    )
    record_store: ModelRecordServiceBase = ModelRecordServiceSQL(database)
    return record_store
|
|
|
|
|
|
def initialize_installer(
    app_config: InvokeAIAppConfig, event_bus: Optional[EventServiceBase] = None
) -> ModelInstallServiceBase:
    """Create a model install service and start it together with its download queue.

    Args:
        app_config: Application configuration used for the record store and the
            install service.
        event_bus: Optional event service handed to the install service.

    Returns:
        The started ``ModelInstallService``.
    """
    records = initialize_record_store(app_config)
    metadata = records.metadata_store
    queue = DownloadQueueService()
    service = ModelInstallService(
        app_config=app_config,
        record_store=records,
        metadata_store=metadata,
        download_queue=queue,
        event_bus=event_bus,
    )
    # The download queue is started before the install service.
    for component in (queue, service):
        component.start()
    return service
|
|
|
|
|
|
class UnifiedModelInfo(BaseModel):
    """Catch-all description of one model, starter-file entry or installed record alike."""

    # --- identity: combined elsewhere in this file into keys of the form "base/type/name"
    name: Optional[str] = None
    base: Optional[BaseModelType] = None
    type: Optional[ModelType] = None

    # --- where the model comes from: a local path, a repo id or a URL
    source: Optional[str] = None
    subfolder: Optional[str] = None

    # --- presentation and state flags
    description: Optional[str] = None
    recommended: bool = False
    installed: bool = False
    default: bool = False

    # --- strings compared against other models' `source` when resolving requirements
    requires: List[str] = Field(default_factory=list)
|
|
|
|
|
|
@dataclass
class InstallSelections:
    """The models a caller wants installed and the models it wants removed."""

    # Models to import; both lists default to a fresh empty list per instance.
    install_models: List[UnifiedModelInfo] = Field(default_factory=list)
    # Models to delete, each given as a bare name or as "base/type/name".
    remove_models: List[str] = Field(default_factory=list)
|
|
|
|
|
|
class TqdmEventService(EventServiceBase):
    """Event service that draws tqdm progress bars for downloads and logs install results."""

    def __init__(self) -> None:
        """Start with no progress bars and a logger named after this module."""
        super().__init__()
        # Both dictionaries are keyed by the download's local path.
        self._bars: Dict[str, tqdm] = {}  # progress bar per download
        self._last: Dict[str, int] = {}  # byte count last reported per download
        self._logger = InvokeAILogger.get_logger(__name__)

    def dispatch(self, event_name: str, payload: Any) -> None:
        """Handle a single model-install event.

        ``event_name`` is not consulted; the event kind is read from
        ``payload["event"]`` and the details from ``payload["data"]``.
        Download events advance a progress bar; completed, error and cancelled
        events are logged. Any other event kind is ignored.
        """
        details = payload["data"]
        origin = details["source"]
        kind = payload["event"]

        if kind == "model_install_downloading":
            path = details["local_path"]
            expected = details["total_bytes"]
            received = details["bytes"]
            if path not in self._bars:
                self._bars[path] = tqdm(desc=Path(path).name, initial=0, total=expected, unit="iB", unit_scale=True)
                self._last[path] = 0
            # tqdm.update() takes an increment, so feed it the delta since the last event.
            self._bars[path].update(received - self._last[path])
            self._last[path] = received
            return

        if kind == "model_install_completed":
            self._logger.info(f"{origin}: installed successfully.")
            return

        if kind == "model_install_error":
            error = details["error"]
            self._logger.warning(f"{origin}: installation failed with error {error}")
            return

        if kind == "model_install_cancelled":
            self._logger.warning(f"{origin}: installation cancelled")
|
|
|
|
class InstallHelper(object):
    """Capture information stored jointly in the starter models file and the installed models db.

    The starter models are read from the ``INITIAL_MODELS`` YAML file located in the
    ``invokeai.configs`` package; installed models come from the record store of the
    installer created in ``__init__``. Models are identified by keys of the form
    ``"<base>/<type>/<name>"``.
    """

    def __init__(self, app_config: InvokeAIAppConfig, logger: Logger):
        """Create new InstallHelper object.

        Args:
            app_config: Application configuration, passed on to the installer.
            logger: Logger used to report failed imports.
        """
        self._app_config = app_config
        self.all_models: Dict[str, UnifiedModelInfo] = {}

        omega = omegaconf.OmegaConf.load(Path(configs.__path__[0]) / INITIAL_MODELS)
        assert isinstance(omega, omegaconf.dictconfig.DictConfig)

        # Install events are rendered as progress bars / log lines by TqdmEventService.
        self._installer = initialize_installer(app_config, TqdmEventService())
        self._initial_models = omega
        self._installed_models: List[str] = []
        self._starter_models: List[str] = []
        self._default_model: Optional[str] = None
        self._logger = logger
        self._initialize_model_lists()

    @property
    def installer(self) -> ModelInstallServiceBase:
        """Return the installer object used internally."""
        return self._installer

    def _initialize_model_lists(self) -> None:
        """
        Initialize our model slots.

        Set up the following:
           installed_models -- list of installed model keys
           starter_models -- list of starter model keys from INITIAL_MODELS
           all_models -- dict of key => UnifiedModelInfo
           default_model -- key to default model
        """
        # Previously-installed models. This pass must run exactly once: repeating it
        # after the starter pass would list every key twice in _installed_models and
        # replace the entries whose description the starter pass had just filled in.
        for model in self._installer.record_store.all_models():
            info = UnifiedModelInfo.parse_obj(model.dict())
            info.installed = True
            model_key = f"{model.base.value}/{model.type.value}/{model.name}"
            self.all_models[model_key] = info
            self._installed_models.append(model_key)

        # Starter models from the INITIAL_MODELS file.
        for key in self._initial_models.keys():
            assert isinstance(key, str)
            entry = self._initial_models[key]
            if key in self.all_models:
                # Already installed: keep the installed record and only fall back to
                # the starter description when the record has none.
                existing = self.all_models[key]
                existing.description = existing.description or entry.get("description")
            else:
                base_model, model_type, model_name = key.split("/")
                self.all_models[key] = UnifiedModelInfo(
                    name=model_name,
                    type=ModelType(model_type),
                    base=BaseModelType(base_model),
                    source=entry.source,
                    description=entry.get("description"),
                    recommended=entry.get("recommended", False),
                    default=entry.get("default", False),
                    subfolder=entry.get("subfolder"),
                    requires=list(entry.get("requires", [])),
                )
            # The first starter model is the default until an entry flagged
            # `default` is seen; the last flagged entry wins.
            if not self.default_model():
                self._default_model = key
            elif entry.get("default", False):
                self._default_model = key
            self._starter_models.append(key)

    def recommended_models(self) -> List[UnifiedModelInfo]:
        """List of the starter models flagged as recommended."""
        return [model for model in self.starter_models() if model.recommended]

    def installed_models(self) -> List[UnifiedModelInfo]:
        """List of models already installed."""
        return [self._to_model(x) for x in self._installed_models]

    def starter_models(self) -> List[UnifiedModelInfo]:
        """List of starter models."""
        return [self._to_model(x) for x in self._starter_models]

    def default_model(self) -> Optional[UnifiedModelInfo]:
        """Return the default model, or None when no default has been chosen."""
        return self._to_model(self._default_model) if self._default_model else None

    def _to_model(self, key: str) -> UnifiedModelInfo:
        """Look up the model info stored under ``key``; raises KeyError for unknown keys."""
        return self.all_models[key]

    def _add_required_models(self, model_list: List[UnifiedModelInfo]) -> None:
        """Extend ``model_list`` in place with the models its entries require.

        Each string in a model's ``requires`` is compared with the ``source`` of the
        known models. A requirement is appended when it matches a known model, is not
        installed, and is not already present in the list, so each one is queued at
        most once.
        """
        by_source = {x.source: x for x in self.all_models.values()}
        # Sources that must not be queued (again): installed or already selected.
        covered = {x.source for x in self.installed_models()}
        covered.update(x.source for x in model_list)

        additional_models: List[UnifiedModelInfo] = []
        for model_info in model_list:
            for requirement in model_info.requires:
                if requirement in covered:
                    continue
                required = by_source.get(requirement)
                if required is None:
                    continue
                covered.add(requirement)
                additional_models.append(required)
        # Extend only after the scan so the list is not mutated while iterated.
        model_list.extend(additional_models)

    def _make_install_source(self, model_info: UnifiedModelInfo) -> ModelSource:
        """Turn a model's ``source`` string into an install source object.

        The string is stripped of surrounding quotes and spaces, then tried in order:
        an existing filesystem path (installed in place), a ``owner/name`` style
        Hugging Face repo id, and an ``http``/``https`` URL.

        Raises:
            ValueError: if the source matches none of the supported forms.
        """
        assert model_info.source
        model_path_id_or_url = model_info.source.strip("\"' ")
        model_path = Path(model_path_id_or_url)

        if model_path.exists():  # local file on disk
            return LocalModelSource(path=model_path.absolute(), inplace=True)
        if re.match(r"^[^/]+/[^/]+$", model_path_id_or_url):  # hugging face repo_id
            return HFModelSource(
                repo_id=model_path_id_or_url,
                access_token=HfFolder.get_token(),
                subfolder=model_info.subfolder,
            )
        if re.match(r"^(http|https):", model_path_id_or_url):
            return URLModelSource(url=AnyHttpUrl(model_path_id_or_url))
        raise ValueError(f"Unsupported model source: {model_path_id_or_url}")

    def add_or_delete(self, selections: InstallSelections) -> None:
        """Add or delete selected models.

        Models in ``selections.install_models`` (plus their not-yet-installed
        requirements) are handed to the installer; import failures are logged as
        warnings and do not stop the remaining work. Each entry of
        ``selections.remove_models`` is either a bare model name or a
        ``base/type/name`` key and is deleted only when it matches exactly one
        record. Finally the call blocks until the installer has finished.
        """
        installer = self._installer
        self._add_required_models(selections.install_models)
        for model in selections.install_models:
            source = self._make_install_source(model)
            # Only override name/description when the selection carries a name.
            config = {"description": model.description, "name": model.name} if model.name else None

            try:
                installer.import_model(source=source, config=config)
            except (UnknownMetadataException, InvalidModelConfigException, HTTPError, OSError) as e:
                self._logger.warning(f"{source}: {e}")

        for model_to_remove in selections.remove_models:
            parts = model_to_remove.split("/")
            if len(parts) == 1:
                base_model, model_type, model_name = (None, None, model_to_remove)
            else:
                base_model, model_type, model_name = parts
            matches = installer.record_store.search_by_attr(
                base_model=BaseModelType(base_model) if base_model else None,
                model_type=ModelType(model_type) if model_type else None,
                model_name=model_name,
            )
            # Report on the removal target itself, not on a leftover variable of the
            # install loop above (which is unbound when nothing was installed).
            if len(matches) > 1:
                print(
                    f"{model_to_remove} is ambiguous. "
                    "Please use base_model/model_type/model_name to disambiguate."
                )
            elif not matches:
                print(f"{model_to_remove}: unknown model")
            else:
                for m in matches:
                    print(f"Deleting {m.type}:{m.name}")
                    installer.delete(m.key)

        installer.wait_for_installs()
|