Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit: TUI installer more or less working

(The hunks below were captured from the repository's web diff view, so removed and added lines appear together, without the usual +/- markers.)
@@ -183,6 +183,7 @@ INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
DEFAULT_MAX_VRAM = 0.5
DEFAULT_MAX_DISK_CACHE = 15 # gigs, enough for two sdxl models, or 5 sd-1 models


class InvokeAIAppConfig(InvokeAISettings):

@@ -242,6 +243,7 @@ class InvokeAIAppConfig(InvokeAISettings):
    # CACHE
    ram : Union[float, Literal["auto"]] = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", )
    vram : Union[float, Literal["auto"]] = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number or 'auto')", category="Model Cache", )
    disk : float = Field(default=DEFAULT_MAX_DISK_CACHE, ge=0, description="Maximum size (in GB) for the disk-based diffusers model conversion cache", category="Model Cache", )
    lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", )

    # DEVICE

@@ -408,6 +410,10 @@ class InvokeAIAppConfig(InvokeAISettings):
    def ram_cache_size(self) -> float:
        return self.max_cache_size or self.ram

    @property
    def conversion_cache_size(self) -> float:
        return self.disk

    @property
    def vram_cache_size(self) -> float:
        return self.max_vram_cache_size or self.vram
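For orientation, here is a small usage sketch (not part of the commit) showing how the new Model Cache settings surface on the config object. The property and field names come from the hunks above; the argument handling mirrors process_and_execute() later in this diff.

# Hypothetical usage sketch of the new cache settings (assumes a normal InvokeAI root).
from invokeai.app.services.config import InvokeAIAppConfig

config = InvokeAIAppConfig.get_config()
config.parse_args([])  # or ["--root", "/path/to/invokeai"]

# New-style settings, with fallback to the legacy max_cache_size / max_vram_cache_size fields.
print(config.ram_cache_size)         # RAM cache budget in GB (the `ram` field)
print(config.vram_cache_size)        # VRAM reservation in GB (the `vram` field)
print(config.conversion_cache_size)  # disk budget in GB for converted diffusers models (`disk`)
print(config.lazy_offload)           # keep models in VRAM until the space is needed
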
@@ -138,31 +138,28 @@ class ModelCache(object):
        self._cached_models = dict()
        self._cache_stack = list()

    # Note that the combination of model_path and submodel_type
    # are sufficient to generate a unique cache key. This key
    # is not the same as the unique hash used to identify models
    # in invokeai.backend.model_manager.storage
    def get_key(
        self,
        model_path: str,
        base_model: BaseModelType,
        model_type: ModelType,
        model_path: Path,
        submodel_type: Optional[SubModelType] = None,
    ):
        key = f"{model_path}:{base_model}:{model_type}"
        key = model_path.as_posix()
        if submodel_type:
            key += f":{submodel_type}"
        return key

    def _get_model_info(
        self,
        model_path: str,
        model_path: Path,
        model_class: Type[ModelBase],
        base_model: BaseModelType,
        model_type: ModelType,
    ):
        model_info_key = self.get_key(
            model_path=model_path,
            base_model=base_model,
            model_type=model_type,
            submodel_type=None,
        )
        model_info_key = self.get_key(model_path=model_path)

        if model_info_key not in self.model_infos:
            self.model_infos[model_info_key] = model_class(
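As a quick illustration of the new keying scheme described in the comment above, here is a hypothetical standalone helper whose body mirrors the new get_key():

from pathlib import Path
from typing import Optional

def make_cache_key(model_path: Path, submodel_type: Optional[str] = None) -> str:
    """Mirror of the new ModelCache.get_key(): the POSIX path plus an optional submodel suffix."""
    key = model_path.as_posix()
    if submodel_type:
        key += f":{submodel_type}"
    return key

# Same model, two distinct cache entries: the whole pipeline and just its VAE submodel.
print(make_cache_key(Path("/invokeai/models/sd-1/main/my-model")))
print(make_cache_key(Path("/invokeai/models/sd-1/main/my-model"), "vae"))
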
@@ -195,12 +192,8 @@ class ModelCache(object):
            base_model=base_model,
            model_type=model_type,
        )
        key = self.get_key(
            model_path=model_path,
            base_model=base_model,
            model_type=model_type,
            submodel_type=submodel,
        )
        key = self.get_key(model_path, submodel)

        # TODO: lock for no copies on simultaneous calls?
        cache_entry = self._cached_models.get(key, None)
        if cache_entry is None:

@@ -305,18 +298,6 @@ class ModelCache(object):
        self._cache_stack.remove(cache_id)
        self._cached_models.pop(cache_id, None)

    def model_hash(
        self,
        model_path: Union[str, Path],
    ) -> str:
        """
        Given the HF repo id or path to a model on disk, returns a unique
        hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs

        :param model_path: Path to model file/directory on disk.
        """
        return self._local_model_hash(model_path)

    def cache_size(self) -> float:
        """Return the current size of the cache, in GB."""
        return self._cache_size() / GIG

@@ -366,8 +347,8 @@ class ModelCache(object):

            refs = sys.getrefcount(cache_entry.model)

            # manualy clear local variable references of just finished function calls
            # for some reason python don't want to collect it even by gc.collect() immidiately
            # Manually clear local variable references of just finished function calls.
            # For some reason python doesn't want to garbage collect it even when gc.collect() is called
            if refs > 2:
                while True:
                    cleared = False

@@ -435,26 +416,6 @@ class ModelCache(object):
        if choose_torch_device() == torch.device("mps"):
            mps.empty_cache()

    def _local_model_hash(self, model_path: Union[str, Path]) -> str:
        sha = hashlib.sha256()
        path = Path(model_path)

        hashpath = path / "checksum.sha256"
        if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
            with open(hashpath) as f:
                hash = f.read()
            return hash

        self.logger.debug(f"computing hash of model {path.name}")
        for file in list(path.rglob("*.ckpt")) + list(path.rglob("*.safetensors")) + list(path.rglob("*.pth")):
            with open(file, "rb") as f:
                while chunk := f.read(self.sha_chunksize):
                    sha.update(chunk)
        hash = sha.hexdigest()
        with open(hashpath, "w") as f:
            f.write(hash)
        return hash


class VRAMUsage(object):
    def __init__(self):
@@ -39,8 +39,9 @@ Typical usage:
  # scan directory recursively and install all new models found
  ids: List[str] = installer.scan_directory('/path/to/directory')

  # unregister any model whose path is no longer valid
  ids: List[str] = installer.garbage_collect()
  # Synchronize with the models directory, adding missing models and
  # removing orphans
  installer.scan_models_directory()

  hash: str = installer.hash('/path/to/model') # should be same as id above

@@ -59,7 +60,7 @@ from pydantic import Field
from pydantic.networks import AnyHttpUrl

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util import Chdir, InvokeAILogger

from .config import (
    BaseModelType,

@@ -183,7 +184,9 @@ class ModelInstallBase(ABC):
        source: Union[str, Path, AnyHttpUrl],
        inplace: bool = True,
        variant: Optional[str] = None,
        info: Optional[ModelProbeInfo] = None,
        probe_override: Optional[Dict[str, Any]] = None,
        metadata: Optional[ModelSourceMetadata] = None,
        access_token: Optional[str] = None,
    ) -> DownloadJobBase:
        """
        Download and install the indicated model.

@@ -202,7 +205,11 @@ class ModelInstallBase(ABC):
        the models directory, but registered in place (the default).
        :param variant: For HuggingFace models, this optional parameter
           specifies which variant to download (e.g. 'fp16')
        :param info: Optional ModelProbeInfo object. If not provided, model will be probed.
        :param probe_override: Optional dict. Any fields in this dict
           will override corresponding probe fields. Use it to override
           `base_type`, `model_type`, `format`, `prediction_type` and `image_size`.
        :param metadata: Use this to override the fields 'description`,
           `author`, `tags`, `source` and `license`.
        :returns DownloadQueueBase object.

        The `inplace` flag does not affect the behavior of downloaded
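Putting the "Typical usage" docstring and the extended install() signature together, a hedged sketch of a call using the new probe_override and metadata parameters. The source, override values, and description are made up for illustration; the import path and the no-argument constructor are assumptions based on the hunks in this diff.

# Hypothetical call illustrating the new probe_override / metadata parameters.
from invokeai.backend.model_manager.install import ModelInstall, ModelSourceMetadata

installer = ModelInstall()  # falls back to InvokeAIAppConfig.get_config() per the constructor below

job = installer.install(
    "runwayml/stable-diffusion-v1-5",                    # local path, URL, or HF repo_id
    variant="fp16",                                      # optional HuggingFace variant
    probe_override={"prediction_type": "epsilon"},       # override selected probe fields
    metadata=ModelSourceMetadata(description="SD 1.5, installed via the new installer"),
)
installer.wait_for_installs()                            # block until queued downloads finish
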
@@ -258,6 +265,11 @@ class ModelInstallBase(ABC):
        """
        pass

    @abstractmethod
    def conditionally_delete(self, key: str): # noqa D102
        """Unregister the model. Delete its files only if they are within our models directory."""
        pass

    @abstractmethod
    def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
        """

@@ -269,18 +281,6 @@ class ModelInstallBase(ABC):
        """
        pass

    @abstractmethod
    def garbage_collect(self) -> List[str]:
        """
        Unregister any models whose paths are no longer valid.

        This checks each registered model's path. Models with paths that are
        no longer found on disk will be unregistered.

        :return List[str]: Return the list of model IDs that were unregistered.
        """
        pass

    @abstractmethod
    def hash(self, model_path: Union[Path, str]) -> str:
        """

@@ -305,11 +305,21 @@ class ModelInstallBase(ABC):
        """
        pass

    @abstractmethod
    def scan_models_directory(self):
        """
        Scan the models directory for new and missing models.

        New models will be added to the storage backend. Missing models
        will be deleted.
        """
        pass


class ModelInstall(ModelInstallBase):
    """Model installer class handles installation from a local path."""

    _config: InvokeAIAppConfig
    _app_config: InvokeAIAppConfig
    _logger: InvokeAILogger
    _store: ModelConfigStore
    _download_queue: DownloadQueueBase

@@ -348,14 +358,17 @@ class ModelInstall(ModelInstallBase):
        download: Optional[DownloadQueueBase] = None,
        event_handlers: Optional[List[DownloadEventHandler]] = None,
    ): # noqa D107 - use base class docstrings
        self._config = config or InvokeAIAppConfig.get_config()
        self._logger = logger or InvokeAILogger.getLogger(config=self._config)
        self._store = store or get_config_store(self._config.model_conf_path)
        self._download_queue = download or DownloadQueue(config=self._config, event_handlers=event_handlers)
        self._app_config = config or InvokeAIAppConfig.get_config()
        self._logger = logger or InvokeAILogger.getLogger(config=self._app_config)
        self._store = store or get_config_store(self._app_config.model_conf_path)
        self._download_queue = download or DownloadQueue(config=self._app_config, event_handlers=event_handlers)
        self._async_installs = dict()
        self._installed = set()
        self._tmpdir = None

        # this step synchronizes the `models` directory with the models db
        self.scan_models_directory()

    @property
    def queue(self) -> DownloadQueueBase:
        """Return the queue."""

@@ -397,7 +410,7 @@ class ModelInstall(ModelInstallBase):
            )
            config_file = config_file[SchedulerPredictionType.VPrediction]
            registration_data.update(
                config=Path(self._config.legacy_conf_dir, config_file).as_posix(),
                config=Path(self._app_config.legacy_conf_dir, config_file).as_posix(),
            )
        except KeyError as exc:
            raise InvalidModelException(

@@ -414,7 +427,7 @@ class ModelInstall(ModelInstallBase):
        model_path = Path(model_path)
        info: ModelProbeInfo = self._probe_model(model_path, overrides)

        dest_path = self._config.models_path / info.base_type.value / info.model_type.value / model_path.name
        dest_path = self._app_config.models_path / info.base_type.value / info.model_type.value / model_path.name
        return self._register(
            self._move_model(model_path, dest_path),
            info,

@@ -437,7 +450,10 @@ class ModelInstall(ModelInstallBase):
        info: ModelProbeInfo = ModelProbe.probe(model_path)
        if overrides: # used to override probe fields
            for key, value in overrides.items():
                setattr(info, key, value) # may generate a pydantic validation error
                try:
                    setattr(info, key, value) # skip validation errors
                except:
                    pass
        return info

    def unregister(self, key: str): # noqa D102

@@ -448,12 +464,23 @@ class ModelInstall(ModelInstallBase):
            rmtree(model.path)
        self.unregister(key)

    def conditionally_delete(self, key: str): # noqa D102
        """Unregister the model. Delete its files only if they are within our models directory."""
        model = self._store.get_model(key)
        models_dir = self._app_config.models_path
        model_path = models_dir / model.path
        if model_path.is_relative_to(models_dir):
            self.delete(key)
        else:
            self.unregister(key)

    def install(
        self,
        source: Union[str, Path, AnyHttpUrl],
        inplace: bool = True,
        variant: Optional[str] = None,
        probe_override: Optional[Dict[str, Any]] = None,
        metadata: Optional[ModelSourceMetadata] = None,
        access_token: Optional[str] = None,
    ) -> DownloadJobBase: # noqa D102
        queue = self._download_queue
@@ -465,6 +492,7 @@ class ModelInstall(ModelInstallBase):
            else self._complete_installation_handler
        )
        job.probe_override = probe_override
        job.metadata = metadata
        job.add_event_handler(handler)

        self._async_installs[source] = None

@@ -523,7 +551,7 @@ class ModelInstall(ModelInstallBase):
        """
        model = self._store.get_model(key)
        old_path = Path(model.path)
        models_dir = self._config.models_path
        models_dir = self._app_config.models_path

        if not old_path.is_relative_to(models_dir):
            return old_path

@@ -542,6 +570,10 @@ class ModelInstall(ModelInstallBase):
        variant: Optional[str] = None,
        access_token: Optional[str] = None,
    ) -> DownloadJobBase:
        # Clean up a common source of error. Doesn't work with Paths.
        if isinstance(source, str):
            source = source.strip()

        # In the event that we are being asked to install a path that is already on disk,
        # we simply probe and register/install it. The job does not actually do anything, but we
        # create one anyway in order to have similar behavior for local files, URLs and repo_ids.

@@ -551,7 +583,7 @@ class ModelInstall(ModelInstallBase):
            return ModelInstallPathJob(source=source, destination=Path(destdir))

        # choose a temporary directory inside the models directory
        models_dir = self._config.models_path
        models_dir = self._app_config.models_path
        self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)

        if re.match(REPO_ID_RE, str(source)):

@@ -577,15 +609,6 @@ class ModelInstall(ModelInstallBase):
        search.search(scan_dir)
        return list(self._installed)

    def garbage_collect(self) -> List[str]: # noqa D102
        unregistered = list()
        for model in self._store.all_models():
            path = Path(model.path)
            if not path.exists():
                self._store.del_model(model.key)
                unregistered.append(model.key)
        return unregistered

    def hash(self, model_path: Union[Path, str]) -> str: # noqa D102
        return FastModelHash.hash(model_path)

@@ -618,7 +641,7 @@ class ModelInstall(ModelInstallBase):
        # We are taking advantage of a side effect of get_model() that converts check points
        # into cached diffusers directories stored at `path`. It doesn't matter
        # what submodel type we request here, so we get the smallest.
        loader = ModelLoad(self._config)
        loader = ModelLoad(self._app_config)
        submodel = {"submodel_type": SubModelType.Scheduler} if info.model_type == ModelType.Main else {}
        converted_model: ModelInfo = loader.get_model(key, **submodel)

@@ -646,7 +669,7 @@ class ModelInstall(ModelInstallBase):
                rmtree(new_diffusers_path)
            raise excp

        if checkpoint_path.exists() and checkpoint_path.is_relative_to(self._config.models_path):
        if checkpoint_path.exists() and checkpoint_path.is_relative_to(self._app_config.models_path):
            checkpoint_path.unlink()

        return result

@@ -670,3 +693,30 @@ class ModelInstall(ModelInstallBase):
        except DuplicateModelException:
            pass
        return True

    def scan_models_directory(self):
        """
        Scan the models directory for new and missing models.

        New models will be added to the storage backend. Missing models
        will be deleted.
        """
        defunct_models = set()
        installed = set()

        with Chdir(self._app_config.models_path):
            self._logger.info("Checking for models that have been moved or deleted from disk.")
            for model_config in self._store.all_models():
                path = Path(model_config.path)
                if not path.exists():
                    self._logger.info(f"{model_config.name}: path {path.as_posix()} no longer exists. Unregistering.")
                    defunct_models.add(model_config.key)
            for key in defunct_models:
                self.unregister(key)

            self._logger.info(f"Scanning {self._app_config.models_path} for new models")
            for cur_base_model in BaseModelType:
                for cur_model_type in ModelType:
                    models_dir = Path(cur_base_model.value, cur_model_type.value)
                    installed.update(self.scan_directory(models_dir))
            self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
@@ -2,6 +2,7 @@
"""Model loader for InvokeAI."""

import hashlib
import shutil
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path

@@ -10,9 +11,9 @@ from typing import List, Optional, Union
import torch

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import Chdir, InvokeAILogger, choose_precision, choose_torch_device
from invokeai.backend.util import InvokeAILogger, choose_precision, choose_torch_device, directory_size

from .cache import CacheStats, ModelCache, ModelLocker
from .cache import GIG, CacheStats, ModelCache, ModelLocker
from .config import BaseModelType, ModelConfigBase, ModelType, SubModelType
from .download import DownloadEventHandler
from .install import ModelInstall, ModelInstallBase

@@ -174,8 +175,6 @@ class ModelLoad(ModelLoadBase):
            logger=self._logger,
        )

        self._scan_models_directory()

    @property
    def store(self) -> ModelConfigStore:
        """Return the ModelConfigStore instance used by this class."""

@@ -232,11 +231,12 @@ class ModelLoad(ModelLoadBase):
        if not model_path.exists():
            raise InvalidModelException(f"Files for model '{key}' not found at {model_path}")

        dst_convert_path = self._get_model_cache_path(model_path)
        dst_convert_path = self._get_model_convert_cache_path(model_path)
        model_path = model_class.convert_if_required(
            model_config=model_config,
            output_path=dst_convert_path,
        )
        self._trim_model_convert_cache() # keeps cache size under control

        model_context = self._cache.get_model(
            model_path=model_path,
@@ -273,9 +273,37 @@ class ModelLoad(ModelLoadBase):
        model_class = MODEL_CLASSES[base_model][model_type]
        return model_class

    def _get_model_cache_path(self, model_path):
    def _get_model_convert_cache_path(self, model_path):
        return self.resolve_model_path(Path(".cache") / hashlib.md5(str(model_path).encode()).hexdigest())

    def _trim_model_convert_cache(self):
        max_cache_size = self._app_config.conversion_cache_size * GIG
        cache_path = self.resolve_model_path(Path(".cache"))
        current_size = directory_size(cache_path)

        if current_size <= max_cache_size:
            return

        self.logger.debug("Convert cache has gotten too large. Trimming.")

        # For this to work, we make the assumption that the directory contains
        # either a 'unet/config.json' file, or a 'config.json' file at top level
        def by_atime(path: Path) -> float:
            for config in ["unet/config.json", "config.json"]:
                sentinel = path / config
                if sentinel.exists():
                    return sentinel.stat().st_atime
            return 0.0

        # sort by last access time - least accessed files will be at the end
        lru_models = sorted(cache_path.iterdir(), key=by_atime, reverse=True)
        while current_size > max_cache_size:
            next_victim = lru_models.pop()
            victim_size = directory_size(next_victim)
            self.logger.debug(f"Removing cached converted model {next_victim} to free {victim_size / GIG} GB")
            shutil.rmtree(next_victim)
            current_size -= victim_size

    def _get_model_path(
        self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
    ) -> (Path, bool):
@@ -298,25 +326,4 @@ class ModelLoad(ModelLoadBase):

    def sync_to_config(self):
        self._store = get_config_store(self._models_file)
        self._scan_models_directory()

    def _scan_models_directory(self):
        defunct_models = set()
        installed = set()

        with Chdir(self._app_config.models_path):
            self._logger.info("Checking for models that have been moved or deleted from disk.")
            for model_config in self._store.all_models():
                path = self.resolve_model_path(model_config.path)
                if not path.exists():
                    self._logger.info(f"{model_config.name}: path {path.as_posix()} no longer exists. Unregistering.")
                    defunct_models.add(model_config.key)
            for key in defunct_models:
                self._installer.unregister(key)

            self._logger.info(f"Scanning {self._app_config.models_path} for new models")
            for cur_base_model in BaseModelType:
                for cur_model_type in ModelType:
                    models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value))
                    installed.update(self._installer.scan_directory(models_dir))
            self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
        self.installer.scan_models_directory()
@@ -40,8 +40,6 @@ class StableDiffusionModelBase(DiffusersModel):
        output_path: str,
    ) -> str:
        if isinstance(model_config, MainCheckpointConfig):
            from invokeai.backend.model_manager.models.stable_diffusion import _convert_ckpt_and_cache

            return _convert_ckpt_and_cache(
                model_config=model_config,
                output_path=output_path,

@@ -225,7 +223,6 @@ class StableDiffusion2Model(StableDiffusionModelBase):


# TODO: rework
# pass precision - currently defaulting to fp16
def _convert_ckpt_and_cache(
    model_config: ModelConfigBase,
    output_path: str,
@@ -12,4 +12,11 @@ from .devices import ( # noqa: F401
    torch_dtype,
)
from .logging import InvokeAILogger # noqa: F401
from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name # noqa: F401
from .util import ( # noqa: F401
    Chdir,
    ask_user,
    directory_size,
    download_with_resume,
    instantiate_from_config,
    url_attachment_name,
)
@@ -101,6 +101,7 @@ def get_obj_from_str(string, reload=False):
    return getattr(importlib.import_module(module, package=None), cls)


# DEAD CODE?
def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
    # create dummy dataset instance


@@ -113,6 +114,7 @@ def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
    Q.put("Done")


# DEAD CODE?
def parallel_data_prefetch(
    func: callable,
    data,

@@ -363,6 +365,19 @@ def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str:
    return image_base64


def directory_size(directory: Path) -> int:
    """
    Returns the aggregate size of all files in a directory (bytes).
    """
    sum = 0
    for root, dirs, files in os.walk(directory):
        for f in files:
            sum += Path(root, f).stat().st_size
        for d in dirs:
            sum += Path(root, d).stat().st_size
    return sum


class Chdir(object):
    """Context manager to chdir to desired directory and change back after context exits:
    Args:
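A hypothetical usage sketch for the new directory_size() helper, tying it back to the conversion-cache budget used by the loader above. The ".cache" location and the import path for GIG are assumptions based on this diff, not confirmed APIs.

# Hypothetical check of the on-disk conversion cache against the configured budget.
from pathlib import Path

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager.cache import GIG   # assumed module path; the loader imports GIG from .cache
from invokeai.backend.util import directory_size

config = InvokeAIAppConfig.get_config()
cache_dir = Path(config.models_path) / ".cache"        # convert-cache location assumed from the loader code

used_gb = directory_size(cache_dir) / GIG if cache_dir.exists() else 0.0
print(f"conversion cache: {used_gb:.1f} GB used of {config.conversion_cache_size} GB allowed")
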
@@ -1,105 +1,105 @@
# This file predefines a few models that the user may want to install.
sd-1/main/stable-diffusion-v1-5:
  description: Stable Diffusion version 1.5 diffusers model (4.27 GB)
  repo_id: runwayml/stable-diffusion-v1-5
  source: runwayml/stable-diffusion-v1-5
  recommended: True
  default: True
sd-1/main/stable-diffusion-v1-5-inpainting:
  description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)
  repo_id: runwayml/stable-diffusion-inpainting
  source: runwayml/stable-diffusion-inpainting
  recommended: True
sd-2/main/stable-diffusion-2-1:
  description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)
  repo_id: stabilityai/stable-diffusion-2-1
  source: stabilityai/stable-diffusion-2-1
  recommended: False
sd-2/main/stable-diffusion-2-inpainting:
  description: Stable Diffusion version 2.0 inpainting model (5.21 GB)
  repo_id: stabilityai/stable-diffusion-2-inpainting
  source: stabilityai/stable-diffusion-2-inpainting
  recommended: False
sdxl/main/stable-diffusion-xl-base-1-0:
  description: Stable Diffusion XL base model (12 GB)
  repo_id: stabilityai/stable-diffusion-xl-base-1.0
  source: stabilityai/stable-diffusion-xl-base-1.0
  recommended: True
sdxl-refiner/main/stable-diffusion-xl-refiner-1-0:
  description: Stable Diffusion XL refiner model (12 GB)
  repo_id: stabilityai/stable-diffusion-xl-refiner-1.0
  source: stabilityai/stable-diffusion-xl-refiner-1.0
  recommended: False
sdxl/vae/sdxl-1-0-vae-fix:
  description: Fine tuned version of the SDXL-1.0 VAE
  repo_id: madebyollin/sdxl-vae-fp16-fix
sdxl/vae/sdxl-vae-fp16-fix:
  description: Version of the SDXL-1.0 VAE that works in half precision mode
  source: madebyollin/sdxl-vae-fp16-fix
  recommended: True
sd-1/main/Analog-Diffusion:
  description: An SD-1.5 model trained on diverse analog photographs (2.13 GB)
  repo_id: wavymulder/Analog-Diffusion
  source: wavymulder/Analog-Diffusion
  recommended: False
sd-1/main/Deliberate:
  description: Versatile model that produces detailed images up to 768px (4.27 GB)
  repo_id: XpucT/Deliberate
  source: XpucT/Deliberate
  recommended: False
sd-1/main/Dungeons-and-Diffusion:
  description: Dungeons & Dragons characters (2.13 GB)
  repo_id: 0xJustin/Dungeons-and-Diffusion
  source: 0xJustin/Dungeons-and-Diffusion
  recommended: False
sd-1/main/dreamlike-photoreal-2:
  description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB)
  repo_id: dreamlike-art/dreamlike-photoreal-2.0
  source: dreamlike-art/dreamlike-photoreal-2.0
  recommended: False
sd-1/main/Inkpunk-Diffusion:
  description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB)
  repo_id: Envvi/Inkpunk-Diffusion
  source: Envvi/Inkpunk-Diffusion
  recommended: False
sd-1/main/openjourney:
  description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB)
  repo_id: prompthero/openjourney
  source: prompthero/openjourney
  recommended: False
sd-1/main/seek.art_MEGA:
  repo_id: coreco/seek.art_MEGA
  source: coreco/seek.art_MEGA
  description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB)
  recommended: False
sd-1/main/trinart_stable_diffusion_v2:
  description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB)
  repo_id: naclbit/trinart_stable_diffusion_v2
  source: naclbit/trinart_stable_diffusion_v2
  recommended: False
sd-1/controlnet/canny:
  repo_id: lllyasviel/control_v11p_sd15_canny
  source: lllyasviel/control_v11p_sd15_canny
  recommended: True
sd-1/controlnet/inpaint:
  repo_id: lllyasviel/control_v11p_sd15_inpaint
  source: lllyasviel/control_v11p_sd15_inpaint
sd-1/controlnet/mlsd:
  repo_id: lllyasviel/control_v11p_sd15_mlsd
  source: lllyasviel/control_v11p_sd15_mlsd
sd-1/controlnet/depth:
  repo_id: lllyasviel/control_v11f1p_sd15_depth
  source: lllyasviel/control_v11f1p_sd15_depth
  recommended: True
sd-1/controlnet/normal_bae:
  repo_id: lllyasviel/control_v11p_sd15_normalbae
  source: lllyasviel/control_v11p_sd15_normalbae
sd-1/controlnet/seg:
  repo_id: lllyasviel/control_v11p_sd15_seg
  source: lllyasviel/control_v11p_sd15_seg
sd-1/controlnet/lineart:
  repo_id: lllyasviel/control_v11p_sd15_lineart
  source: lllyasviel/control_v11p_sd15_lineart
  recommended: True
sd-1/controlnet/lineart_anime:
  repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime
  source: lllyasviel/control_v11p_sd15s2_lineart_anime
sd-1/controlnet/openpose:
  repo_id: lllyasviel/control_v11p_sd15_openpose
  source: lllyasviel/control_v11p_sd15_openpose
  recommended: True
sd-1/controlnet/scribble:
  repo_id: lllyasviel/control_v11p_sd15_scribble
  source: lllyasviel/control_v11p_sd15_scribble
  recommended: False
sd-1/controlnet/softedge:
  repo_id: lllyasviel/control_v11p_sd15_softedge
  source: lllyasviel/control_v11p_sd15_softedge
sd-1/controlnet/shuffle:
  repo_id: lllyasviel/control_v11e_sd15_shuffle
  source: lllyasviel/control_v11e_sd15_shuffle
sd-1/controlnet/tile:
  repo_id: lllyasviel/control_v11f1e_sd15_tile
  source: lllyasviel/control_v11f1e_sd15_tile
sd-1/controlnet/ip2p:
  repo_id: lllyasviel/control_v11e_sd15_ip2p
  source: lllyasviel/control_v11e_sd15_ip2p
sd-1/embedding/EasyNegative:
  path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
  source: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
  recommended: True
sd-1/embedding/ahx-beta-453407d:
  repo_id: sd-concepts-library/ahx-beta-453407d
  source: sd-concepts-library/ahx-beta-453407d
sd-1/lora/LowRA:
  path: https://civitai.com/api/download/models/63006
  source: https://civitai.com/api/download/models/63006
  recommended: True
sd-1/lora/Ink scenery:
  path: https://civitai.com/api/download/models/83390
  source: https://civitai.com/api/download/models/83390
@@ -6,7 +6,6 @@

"""
This is the npyscreen frontend to the model installation application.
The work is actually done in backend code in model_install_backend.py.
"""

import argparse

@@ -16,18 +15,26 @@ import sys
import textwrap
import traceback
from argparse import Namespace
from dataclasses import dataclass, field
from multiprocessing import Process
from multiprocessing.connection import Connection, Pipe
from pathlib import Path
from shutil import get_terminal_size
from typing import List, Optional

import npyscreen
import omegaconf
import torch
from huggingface_hub import HfFolder
from npyscreen import widget
from pydantic import BaseModel

import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, SchedulerPredictionType
from invokeai.backend.model_management import ModelManager, ModelType

# from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, SchedulerPredictionType
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelInstall, ModelInstallJob, ModelType
from invokeai.backend.model_manager.install import ModelSourceMetadata
from invokeai.backend.util import choose_precision, choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.widgets import (

@@ -55,6 +62,29 @@ NOPRINT_TRANS_TABLE = {i: None for i in range(0, sys.maxunicode + 1) if not chr(
# maximum number of installed models we can display before overflowing vertically
MAX_OTHER_MODELS = 72

# name of the starter models file
INITIAL_MODELS = "INITIAL_MODELS.yaml"
INITIAL_MODELS_CONFIG = omegaconf.OmegaConf.load(Path(configs.__path__[0]) / INITIAL_MODELS)

ACCESS_TOKEN = HfFolder.get_token()


class UnifiedModelInfo(BaseModel):
    name: str
    base_model: BaseModelType
    model_type: ModelType
    source: Optional[str] = None
    description: Optional[str] = None
    recommended: bool = False
    installed: bool = False
    default: bool = False


@dataclass
class InstallSelections:
    install_models: List[UnifiedModelInfo] = field(default_factory=list)
    remove_models: List[UnifiedModelInfo] = field(default_factory=list)


def make_printable(s: str) -> str:
    """Replace non-printable characters in a string"""
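To make the data flow concrete, here is a hypothetical sketch (not part of the commit) of turning INITIAL_MODELS.yaml entries into UnifiedModelInfo records and an InstallSelections, following the same pattern as initialize_model_lists() further down. It assumes the two classes defined above are in scope.

# Hypothetical illustration of the INITIAL_MODELS.yaml -> UnifiedModelInfo -> InstallSelections flow.
from pathlib import Path

import omegaconf

import invokeai.configs as configs

starter_config = omegaconf.OmegaConf.load(Path(configs.__path__[0]) / "INITIAL_MODELS.yaml")

wanted = []
for key, entry in starter_config.items():
    if not entry.get("recommended", False):
        continue
    base_model, model_type, model_name = key.split("/")   # e.g. "sd-1/main/stable-diffusion-v1-5"
    wanted.append(
        UnifiedModelInfo(
            name=model_name,
            base_model=base_model,
            model_type=model_type,
            source=entry.source,
            description=entry.get("description"),
            recommended=True,
        )
    )

selections = InstallSelections(install_models=wanted)
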
@@ -74,17 +104,13 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
        super().__init__(parentApp=parentApp, name=name, *args, **keywords)

    def create(self):
        self.installer = self.parentApp.installer
        self.initialize_model_lists()
        self.model_labels = self._get_model_labels()
        self.keypress_timeout = 10
        self.counter = 0
        self.subprocess_connection = None

        if not config.model_conf_path.exists():
            with open(config.model_conf_path, "w") as file:
                print("# InvokeAI model configuration file", file=file)
        self.installer = ModelInstall(config)
        self.all_models = self.installer.all_models()
        self.starter_models = self.installer.starter_models()
        self.model_labels = self._get_model_labels()
        window_width, window_height = get_terminal_size()

        self.nextrely -= 1

@@ -154,7 +180,6 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
        )

        self.nextrely += 1
        done_label = "APPLY CHANGES"
        back_label = "BACK"
        cancel_label = "CANCEL"
        current_position = self.nextrely

@@ -170,14 +195,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
            npyscreen.ButtonPress, name=cancel_label, when_pressed_function=self.on_cancel
        )
        self.nextrely = current_position
        self.ok_button = self.add_widget_intelligent(
            npyscreen.ButtonPress,
            name=done_label,
            relx=(window_width - len(done_label)) // 2,
            when_pressed_function=self.on_execute,
        )

        label = "APPLY CHANGES & EXIT"
        label = "APPLY CHANGES"
        self.nextrely = current_position
        self.done = self.add_widget_intelligent(
            npyscreen.ButtonPress,

@@ -195,16 +214,15 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
    def add_starter_pipelines(self) -> dict[str, npyscreen.widget]:
        """Add widgets responsible for selecting diffusers models"""
        widgets = dict()
        models = self.all_models
        starters = self.starter_models
        starter_model_labels = self.model_labels

        self.installed_models = sorted([x for x in starters if models[x].installed])
        all_models = self.all_models # master dict of all models, indexed by key
        model_list = [x for x in self.starter_models if all_models[x].model_type in ["main", "vae"]]
        model_labels = [self.model_labels[x] for x in model_list]

        widgets.update(
            label1=self.add_widget_intelligent(
                CenteredTitleText,
                name="Select from a starter set of Stable Diffusion models from HuggingFace.",
                name="Select from a starter set of Stable Diffusion models from HuggingFace and Civitae.",
                editable=False,
                labelColor="CAUTION",
            )

@@ -214,23 +232,24 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
        # if user has already installed some initial models, then don't patronize them
        # by showing more recommendations
        show_recommended = len(self.installed_models) == 0
        keys = [x for x in models.keys() if x in starters]

        checked = [
            model_list.index(x)
            for x in model_list
            if (show_recommended and all_models[x].recommended) or all_models[x].installed
        ]
        widgets.update(
            models_selected=self.add_widget_intelligent(
                MultiSelectColumns,
                columns=1,
                name="Install Starter Models",
                values=[starter_model_labels[x] for x in keys],
                value=[
                    keys.index(x)
                    for x in keys
                    if (show_recommended and models[x].recommended) or (x in self.installed_models)
                ],
                max_height=len(starters) + 1,
                values=model_labels,
                value=checked,
                max_height=len(model_list) + 1,
                relx=4,
                scroll_exit=True,
            ),
            models=keys,
            models=model_list,
        )

        self.nextrely += 1

@@ -246,7 +265,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
    ) -> dict[str, npyscreen.widget]:
        """Generic code to create model selection widgets"""
        widgets = dict()
        model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude]
        all_models = self.all_models
        model_list = [x for x in all_models if all_models[x].model_type == model_type and x not in exclude]
        model_labels = [self.model_labels[x] for x in model_list]

        show_recommended = len(self.installed_models) == 0

@@ -282,7 +302,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
                value=[
                    model_list.index(x)
                    for x in model_list
                    if (show_recommended and self.all_models[x].recommended) or self.all_models[x].installed
                    if (show_recommended and all_models[x].recommended) or all_models[x].installed
                ],
                max_height=len(model_list) // columns + 1,
                relx=4,

@@ -334,8 +354,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
    def resize(self):
        super().resize()
        if s := self.starter_pipelines.get("models_selected"):
            keys = [x for x in self.all_models.keys() if x in self.starter_models]
            s.values = [self.model_labels[x] for x in keys]
            s.values = [self.model_labels[x] for x in self.starter_pipelines.get("models")]

    def _toggle_tables(self, value=None):
        selected_tab = value[0]
@@ -364,18 +383,61 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
        self.__class__.current_tab = selected_tab # for persistence
        self.display()

    def initialize_model_lists(self):
        """
        Initialize our model slots.

        Set up the following:
        self.installed_models -- list of installed model keys
        self.starter_models -- list of starter model keys from INITIAL_MODELS
        self.all_models -- dict of key => UnifiedModelInfo

        Each of these is a dict of key=>ModelConfigBase.
        """
        installed_models = list()
        starter_models = list()
        all_models = dict()

        # previously-installed models
        for model in self.installer.store.all_models():
            info = UnifiedModelInfo.parse_obj(model.dict())
            info.installed = True
            key = f"{model.base_model}/{model.model_type}/{model.name}"
            all_models[key] = info
            installed_models.append(key)

        for key in INITIAL_MODELS_CONFIG.keys():
            if key not in all_models:
                base_model, model_type, model_name = key.split("/")
                info = UnifiedModelInfo(
                    name=model_name,
                    model_type=model_type,
                    base_model=base_model,
                    source=INITIAL_MODELS_CONFIG[key].source,
                    description=INITIAL_MODELS_CONFIG[key].get("description"),
                    recommended=INITIAL_MODELS_CONFIG[key].get("recommended", False),
                    default=INITIAL_MODELS_CONFIG[key].get("default", False),
                )
                all_models[key] = info
                starter_models.append(key)

        self.installed_models = installed_models
        self.starter_models = starter_models
        self.all_models = all_models

    def _get_model_labels(self) -> dict[str, str]:
        """Return a list of trimmed labels for all models."""
        window_width, window_height = get_terminal_size()
        checkbox_width = 4
        spacing_width = 2
        result = dict()

        models = self.all_models
        label_width = max([len(models[x].name) for x in models])
        label_width = max([len(models[x].name) for x in self.starter_models])
        description_width = window_width - label_width - checkbox_width - spacing_width

        result = dict()
        for x in models.keys():
            description = models[x].description
        for key in self.starter_models:
            description = models[key].description
            description = (
                description[0 : description_width - 3] + "..."
                if description and len(description) > description_width

@@ -383,7 +445,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
                if description
                else ""
            )
            result[x] = f"%-{label_width}s %s" % (models[x].name, description)
            result[key] = f"%-{label_width}s %s" % (models[key].name, description)

        return result

    def _get_columns(self) -> int:

@@ -394,39 +457,13 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
    def confirm_deletions(self, selections: InstallSelections) -> bool:
        remove_models = selections.remove_models
        if len(remove_models) > 0:
            mods = "\n".join([ModelManager.parse_key(x)[0] for x in remove_models])
            mods = "\n".join([self.all_models[x].name for x in remove_models])
            return npyscreen.notify_ok_cancel(
                f"These unchecked models will be deleted from disk. Continue?\n---------\n{mods}"
            )
        else:
            return True

    def on_execute(self):
        self.marshall_arguments()
        app = self.parentApp
        if not self.confirm_deletions(app.install_selections):
            return

        self.monitor.entry_widget.buffer(["Processing..."], scroll_end=True)
        self.ok_button.hidden = True
        self.display()

        # TO DO: Spawn a worker thread, not a subprocess
        parent_conn, child_conn = Pipe()
        p = Process(
            target=process_and_execute,
            kwargs=dict(
                opt=app.program_opts,
                selections=app.install_selections,
                conn_out=child_conn,
            ),
        )
        p.start()
        child_conn.close()
        self.subprocess_connection = parent_conn
        self.subprocess = p
        app.install_selections = InstallSelections()

    def on_back(self):
        self.parentApp.switchFormPrevious()
        self.editing = False
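The TUI hands the real work to a child process and polls a Pipe for progress, as on_execute() above and while_waiting() below show. Here is a stripped-down, hypothetical sketch of that parent/child protocol (the real code additionally handles the "*need v2 config" round trip):

# Hypothetical minimal version of the subprocess/Pipe pattern used by the installer TUI.
from multiprocessing import Pipe, Process
from multiprocessing.connection import Connection


def worker(conn_out: Connection) -> None:
    # The child streams status lines and ends with the "*done*" sentinel, as process_and_execute() does.
    conn_out.send_bytes("Installing model 1 of 2...".encode("utf-8"))
    conn_out.send_bytes("Installing model 2 of 2...".encode("utf-8"))
    conn_out.send_bytes("*done*".encode("utf-8"))
    conn_out.close()


if __name__ == "__main__":
    parent_conn, child_conn = Pipe()
    p = Process(target=worker, args=(child_conn,))
    p.start()
    child_conn.close()  # the parent only reads from its own end

    while True:
        data = parent_conn.recv_bytes().decode("utf-8")
        if data == "*done*":
            break
        print(data)  # the TUI appends these lines to its Log Messages box instead
    p.join()
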
@@ -444,76 +481,6 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
        self.parentApp.user_cancelled = False
        self.editing = False

    ########## This routine monitors the child process that is performing model installation and removal #####
    def while_waiting(self):
        """Called during idle periods. Main task is to update the Log Messages box with messages
        from the child process that does the actual installation/removal"""
        c = self.subprocess_connection
        if not c:
            return

        monitor_widget = self.monitor.entry_widget
        while c.poll():
            try:
                data = c.recv_bytes().decode("utf-8")
                data.strip("\n")

                # processing child is requesting user input to select the
                # right configuration file
                if data.startswith("*need v2 config"):
                    _, model_path, *_ = data.split(":", 2)
                    self._return_v2_config(model_path)

                # processing child is done
                elif data == "*done*":
                    self._close_subprocess_and_regenerate_form()
                    break

                # update the log message box
                else:
                    data = make_printable(data)
                    data = data.replace("[A", "")
                    monitor_widget.buffer(
                        textwrap.wrap(
                            data,
                            width=monitor_widget.width,
                            subsequent_indent=" ",
                        ),
                        scroll_end=True,
                    )
                    self.display()
            except (EOFError, OSError):
                self.subprocess_connection = None

    def _return_v2_config(self, model_path: str):
        c = self.subprocess_connection
        model_name = Path(model_path).name
        message = select_stable_diffusion_config_file(model_name=model_name)
        c.send_bytes(message.encode("utf-8"))

    def _close_subprocess_and_regenerate_form(self):
        app = self.parentApp
        self.subprocess_connection.close()
        self.subprocess_connection = None
        self.monitor.entry_widget.buffer(["** Action Complete **"])
        self.display()

        # rebuild the form, saving and restoring some of the fields that need to be preserved.
        saved_messages = self.monitor.entry_widget.values

        app.main_form = app.addForm(
            "MAIN",
            addModelsForm,
            name="Install Stable Diffusion Models",
            multipage=self.multipage,
        )
        app.switchForm("MAIN")

        app.main_form.monitor.entry_widget.values = saved_messages
        app.main_form.monitor.entry_widget.buffer([""], scroll_end=True)
        # app.main_form.pipeline_models['autoload_directory'].value = autoload_dir
        # app.main_form.pipeline_models['autoscan_on_startup'].value = autoscan

    def marshall_arguments(self):
        """
        Assemble arguments and store as attributes of the application:
@@ -542,11 +509,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
            models_to_install = [x for x in selected if not self.all_models[x].installed]
            models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed]
            selections.remove_models.extend(models_to_remove)
            selections.install_models.extend(
                all_models[x].path or all_models[x].repo_id
                for x in models_to_install
                if all_models[x].path or all_models[x].repo_id
            )
            selections.install_models.extend([all_models[x] for x in models_to_install])

        # models located in the 'download_ids" section
        for section in ui_sections:

@@ -555,12 +518,12 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):


class AddModelApplication(npyscreen.NPSAppManaged):
    def __init__(self, opt):
    def __init__(self, opt: Namespace, installer: ModelInstall):
        super().__init__()
        self.program_opts = opt
        self.user_cancelled = False
        # self.autoload_pending = True
        self.install_selections = InstallSelections()
        self.installer = installer

    def onStart(self):
        npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
@@ -583,103 +546,64 @@ class StderrToMessage:
        pass


# --------------------------------------------------------
def ask_user_for_prediction_type(model_path: Path, tui_conn: Connection = None) -> SchedulerPredictionType:
    if tui_conn:
        logger.debug("Waiting for user response...")
        return _ask_user_for_pt_tui(model_path, tui_conn)
    else:
        return _ask_user_for_pt_cmdline(model_path)
def list_models(installer: ModelInstall, model_type: ModelType):
    """Print out all models of type model_type."""
    models = installer.store.search_by_name(model_type=model_type)
    print(f"Installed models of type `{model_type}`:")
    for model in models:
        path = (config.models_path / model.path).resolve()
        print(f"{model.name:40}{model.base_model:10}{path}")


def _ask_user_for_pt_cmdline(model_path: Path) -> SchedulerPredictionType:
    choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None]
    print(
        f"""
Please select the type of the V2 checkpoint named {model_path.name}:
[1] A model based on Stable Diffusion v2 trained on 512 pixel images (SD-2-base)
[2] A model based on Stable Diffusion v2 trained on 768 pixel images (SD-2-768)
[3] Skip this model and come back later.
"""
    )
    choice = None
    ok = False
    while not ok:
        try:
            choice = input("select> ").strip()
            choice = choices[int(choice) - 1]
            ok = True
        except (ValueError, IndexError):
            print(f"{choice} is not a valid choice")
        except EOFError:
            return
    return choice
def tqdm_progress(job: ModelInstallJob):
    pass


def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection) -> SchedulerPredictionType:
    try:
        tui_conn.send_bytes(f"*need v2 config for:{model_path}".encode("utf-8"))
        # note that we don't do any status checking here
        response = tui_conn.recv_bytes().decode("utf-8")
        if response is None:
            return None
        elif response == "epsilon":
            return SchedulerPredictionType.epsilon
        elif response == "v":
            return SchedulerPredictionType.VPrediction
        elif response == "abort":
            logger.info("Conversion aborted")
            return None
def add_or_delete(installer: ModelInstall, selections: InstallSelections):
    for model in selections.install_models:
        print(f"Installing {model.name}")
        metadata = ModelSourceMetadata(description=model.description)
        installer.install(
            model.source,
            variant="fp16" if config.precision == "float16" else None,
            access_token=ACCESS_TOKEN, # this is a global,
            metadata=metadata,
        )

    for model in selections.remove_models:
        base_model, model_type, model_name = model.split("/")
        matches = installer.store.search_by_name(base_model=base_model, model_type=model_type, model_name=model_name)
        if len(matches) > 1:
            print(f"{model} is ambiguous. Please use model_type:model_name (e.g. main:my_model) to disambiguate.")
        elif not matches:
            print(f"{model}: unknown model")
        else:
            return response
    except Exception:
        return None
            for m in matches:
                print(f"Deleting {m.model_type}:{m.name}")
                installer.conditionally_delete(m.key)


# --------------------------------------------------------
def process_and_execute(
    opt: Namespace,
    selections: InstallSelections,
    conn_out: Connection = None,
):
    # need to reinitialize config in subprocess
    config = InvokeAIAppConfig.get_config()
    args = ["--root", opt.root] if opt.root else []
    config.parse_args(args)

    # set up so that stderr is sent to conn_out
    if conn_out:
        translator = StderrToMessage(conn_out)
        sys.stderr = translator
        sys.stdout = translator
        logger = InvokeAILogger.getLogger()
        logger.handlers.clear()
        logger.addHandler(logging.StreamHandler(translator))

    installer = ModelInstall(config, prediction_type_helper=lambda x: ask_user_for_prediction_type(x, conn_out))
    installer.install(selections)

    if conn_out:
        conn_out.send_bytes("*done*".encode("utf-8"))
        conn_out.close()
    installer.wait_for_installs()


# --------------------------------------------------------
def select_and_download_models(opt: Namespace):
    precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
    config.precision = precision
    installer = ModelInstall(config, prediction_type_helper=ask_user_for_prediction_type)
    installer = ModelInstall(config=config, event_handlers=[tqdm_progress])

    if opt.list_models:
        installer.list_models(opt.list_models)
        list_models(installer, opt.list_models)

    elif opt.add or opt.delete:
        selections = InstallSelections(install_models=opt.add or [], remove_models=opt.delete or [])
        installer.install(selections)
        selections = InstallSelections(install_models=opt.add, remove_models=opt.delete)
        add_or_delete(installer, selections)

    elif opt.default_only:
        selections = InstallSelections(install_models=installer.default_model())
        installer.install(selections)
        add_or_delete(installer, selections)
    elif opt.yes_to_all:
        selections = InstallSelections(install_models=installer.recommended_models())
        installer.install(selections)
        add_or_delete(installer, selections)

    # this is where the TUI is called
    else:
@@ -691,17 +615,14 @@ def select_and_download_models(opt: Namespace):
                "Could not increase terminal size. Try running again with a larger window or smaller font size."
            )

        installApp = AddModelApplication(opt)
        installApp = AddModelApplication(opt, installer)
        try:
            installApp.run()
        except KeyboardInterrupt as e:
            if hasattr(installApp, "main_form"):
                if installApp.main_form.subprocess and installApp.main_form.subprocess.is_alive():
                    logger.info("Terminating subprocesses")
                    installApp.main_form.subprocess.terminate()
                    installApp.main_form.subprocess = None
            raise e
        process_and_execute(opt, installApp.install_selections)
            print("Aborted...")
            sys.exit(-1)

        add_or_delete(installer, installApp.install_selections)


# -------------------------------------

@@ -715,7 +636,7 @@ def main():
    parser.add_argument(
        "--delete",
        nargs="*",
        help="List of names of models to idelete",
        help="List of names of models to delete. Use type:name to disambiguate, as in `controlnet:my_model`",
    )
    parser.add_argument(
        "--full-precision",