Port the command-line tools to use model_manager2 (#5546)

* Port the command-line tools to use model_manager2

1. Reimplement the following:

  - invokeai-model-install
  - invokeai-merge
  - invokeai-ti

  To avoid breaking the original model manager, the updated tools
  have been renamed invokeai-model-install2 and invokeai-merge2. The
  textual inversion training script should continue to work with
  existing installations. The "starter" models now live in
  `invokeai/configs/INITIAL_MODELS2.yaml`.

  When the full model manager 2 is in place and working, I'll rename
  these files and commands.

2. Add the `merge` route to the web API. This will merge two or three models,
   resulting in a new one.

   - Note that because the model installer selectively installs the `fp16` variant
     of models (rather than both 16- and 32-bit versions as previously),
     the diffusers merge script will choke on any HuggingFace diffusers models
     that were downloaded with the new installer. Previously-downloaded models
     should continue to merge correctly. I have a PR
     upstream https://github.com/huggingface/diffusers/pull/6670 to fix
     this.

3. (more important!)
  During implementation of the CLI tools, found and fixed a number of small
  runtime bugs in the model_manager2 implementation:

  - During model database migration, if a registered model's file was
    not found on disk, the migration would be aborted. Now the
    offending model is skipped with a log warning.

  - Caught and fixed a condition in which the installer would download the
    entire diffusers repo when the user provided a single `.safetensors`
    file URL.

  - Caught and fixed a condition in which the installer would raise an
    exception and stop the app when a request for an unknown model's metadata
    was passed to Civitai. Now an error is logged and the installer continues.

  - Replaced the LoWRA starter LoRA with FlatColor. The former has been removed
    from Civitai.

* fix ruff issue

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
This commit is contained in:
Lincoln Stein 2024-02-02 12:18:47 -05:00 committed by GitHub
parent d3320dc4ee
commit f2777f5096
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 2297 additions and 78 deletions

View File

@ -1,7 +1,7 @@
# Copyright (c) 2023 Lincoln D. Stein
"""FastAPI route for model configuration records."""
import pathlib
from hashlib import sha1
from random import randbytes
from typing import Any, Dict, List, Optional, Set
@ -27,6 +27,7 @@ from invokeai.backend.model_manager.config import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..dependencies import ApiDependencies
@ -415,3 +416,57 @@ async def sync_models_to_config() -> Response:
"""
ApiDependencies.invoker.services.model_install.sync_to_config()
return Response(status_code=204)
@model_records_router.put(
    "/merge",
    operation_id="merge",
)
async def merge(
    keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
    merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
    alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
    force: bool = Body(
        description="Force merging of models created with different versions of diffusers",
        default=False,
    ),
    interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
    merge_dest_directory: Optional[str] = Body(
        description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
        default=None,
    ),
) -> AnyModelConfig:
    """
    Merge diffusers models.

    keys: List of 2-3 model keys to merge together. All models must use the same base type.
    merged_model_name: Name for the merged model [Concat model names]
    alpha: Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
    force: If true, force the merge even if the models were generated by different versions of the diffusers library [False]
    interp: Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
    merge_dest_directory: Specify a directory to store the merged model in [models directory]

    Responds with 404 when one of the keys is unknown and with 400 when the
    merge request is rejected as invalid.
    """
    logger = ApiDependencies.invoker.services.logger
    try:
        logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
        dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
        installer = ApiDependencies.invoker.services.model_install
        merger = ModelMerger(installer)
        # Looking the names up first also surfaces unknown keys before any merging work starts.
        model_names = [installer.record_store.get_model(x).name for x in keys]
        response = merger.merge_diffusion_models_and_save(
            model_keys=keys,
            merged_model_name=merged_model_name or "+".join(model_names),
            alpha=alpha,
            interp=interp,
            force=force,
            merge_dest_directory=dest,
        )
    except UnknownModelException:
        raise HTTPException(
            status_code=404,
            detail=f"One or more of the models '{keys}' not found",
        )
    except (ValueError, AssertionError) as e:
        # ModelMerger validates its inputs with `assert`; report those failures as a
        # bad request (with the assertion message) instead of an opaque server error.
        raise HTTPException(status_code=400, detail=str(e))
    return response

View File

@ -208,7 +208,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
job = self._queue.get(timeout=1)
except Empty:
continue
try:
job.job_started = get_iso_timestamp()
self._do_download(job)

View File

@ -165,8 +165,8 @@ class ModelInstallJob(BaseModel):
)
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
bytes: Optional[int] = Field(
default=None, description="For a remote model, the number of bytes downloaded so far (may not be available)"
bytes: int = Field(
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
)
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
source_metadata: Optional[AnyModelRepoMetadata] = Field(

View File

@ -535,19 +535,19 @@ class ModelInstallService(ModelInstallServiceBase):
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
# URLs from Civitai or HuggingFace will be handled specially
url_patterns = {
r"https?://civitai.com/": CivitaiMetadataFetch,
r"https?://huggingface.co/": HuggingFaceMetadataFetch,
r"^https?://civitai.com/": CivitaiMetadataFetch,
r"^https?://huggingface.co/[^/]+/[^/]+$": HuggingFaceMetadataFetch,
}
metadata = None
for pattern, fetcher in url_patterns.items():
if re.match(pattern, str(source.url), re.IGNORECASE):
metadata = fetcher(self._session).from_url(source.url)
break
self._logger.debug(f"metadata={metadata}")
if metadata and isinstance(metadata, ModelMetadataWithFiles):
remote_files = metadata.download_urls(session=self._session)
else:
remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
return self._import_remote_model(
source=source,
config=config,
@ -586,6 +586,7 @@ class ModelInstallService(ModelInstallServiceBase):
assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below
self._logger.info(f"Queuing {source} for downloading")
self._logger.debug(f"remote_files={remote_files}")
for model_file in remote_files:
url = model_file.url
path = model_file.path

View File

@ -72,7 +72,12 @@ class MigrateModelYamlToDb1:
continue
base_type, model_type, model_name = str(model_key).split("/")
hash = FastModelHash.hash(self.config.models_path / stanza.path)
try:
hash = FastModelHash.hash(self.config.models_path / stanza.path)
except OSError:
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
continue
assert isinstance(model_key, str)
new_key = sha1(model_key.encode("utf-8")).hexdigest()

View File

@ -0,0 +1,281 @@
"""Utility (backend) functions used by model_install.py"""
import re
from logging import Logger
from pathlib import Path
from typing import Any, Dict, List, Optional
import omegaconf
from huggingface_hub import HfFolder
from pydantic import BaseModel, Field
from pydantic.dataclasses import dataclass
from pydantic.networks import AnyHttpUrl
from requests import HTTPError
from tqdm import tqdm
import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
from invokeai.app.services.model_install import (
HFModelSource,
LocalModelSource,
ModelInstallService,
ModelInstallServiceBase,
ModelSource,
URLModelSource,
)
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager import (
BaseModelType,
InvalidModelConfigException,
ModelType,
)
from invokeai.backend.model_manager.metadata import UnknownMetadataException
from invokeai.backend.util.logging import InvokeAILogger
# name of the starter models file
INITIAL_MODELS = "INITIAL_MODELS2.yaml"
def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
    """
    Create the SQL-backed model record store for an application config.

    :param app_config: Configuration used to obtain the logger, locate the
        image output directory, and initialize the database.
    :return: A ModelRecordServiceSQL wrapping the initialized database.
    """
    database = init_db(
        config=app_config,
        logger=InvokeAILogger.get_logger(config=app_config),
        image_files=DiskImageFileStorage(f"{app_config.output_path}/images"),
    )
    store: ModelRecordServiceBase = ModelRecordServiceSQL(database)
    return store
def initialize_installer(
    app_config: InvokeAIAppConfig, event_bus: Optional[EventServiceBase] = None
) -> ModelInstallServiceBase:
    """
    Build and start a ModelInstallService together with its download queue.

    :param app_config: Configuration handed to the record store and the installer.
    :param event_bus: Optional event service passed through to the installer.
    :return: The installer, with both its download queue and itself already started.
    """
    store = initialize_record_store(app_config)
    metadata = store.metadata_store
    queue = DownloadQueueService()
    service = ModelInstallService(
        app_config=app_config,
        record_store=store,
        metadata_store=metadata,
        download_queue=queue,
        event_bus=event_bus,
    )
    # The queue is started before the installer, as in the original sequence.
    queue.start()
    service.start()
    return service
class UnifiedModelInfo(BaseModel):
    """
    Catchall class for information in INITIAL_MODELS2.yaml.

    Instances are built either from a starter-model stanza or from the record
    of an already-installed model; every field is optional or has a default.
    """

    # Identification of the model.
    name: Optional[str] = None
    base: Optional[BaseModelType] = None
    type: Optional[ModelType] = None

    # Where to get the model from: a local path, a HuggingFace repo_id, or a URL.
    source: Optional[str] = None
    # Subfolder forwarded to the HuggingFace source when `source` is a repo_id.
    subfolder: Optional[str] = None
    description: Optional[str] = None

    # Flags describing the model's status in the starter list / local install.
    recommended: bool = False
    installed: bool = False
    default: bool = False

    # Sources of other models that should be installed alongside this one.
    requires: List[str] = Field(default_factory=list)
@dataclass
class InstallSelections:
    """The user's choices: which models to install and which to remove."""

    # Models to be installed.
    install_models: List[UnifiedModelInfo] = Field(default_factory=list)
    # Models to be removed, given either as "model_name" or as "base/type/name".
    remove_models: List[str] = Field(default_factory=list)
class TqdmEventService(EventServiceBase):
    """An event service that shows model download progress as tqdm bars."""

    def __init__(self) -> None:
        """Set up empty per-destination bookkeeping for progress bars."""
        super().__init__()
        # One progress bar, and the last byte count reported to it, per local path.
        self._bars: Dict[str, tqdm] = {}
        self._last: Dict[str, int] = {}

    def dispatch(self, event_name: str, payload: Any) -> None:
        """
        Update the matching progress bar for a download-progress event.

        Only payloads whose "event" is "model_install_downloading" are acted
        upon; anything else is ignored.
        """
        if payload["event"] != "model_install_downloading":
            return

        data = payload["data"]
        dest = data["local_path"]
        total_bytes = data["total_bytes"]
        downloaded = data["bytes"]

        if dest not in self._bars:
            self._bars[dest] = tqdm(desc=Path(dest).name, initial=0, total=total_bytes, unit="iB", unit_scale=True)
            self._last[dest] = 0

        # tqdm expects increments, while the event carries the running total.
        self._bars[dest].update(downloaded - self._last[dest])
        self._last[dest] = downloaded
class InstallHelper(object):
    """Capture information stored jointly in INITIAL_MODELS.yaml and the installed models db."""

    def __init__(self, app_config: InvokeAIAppConfig, logger: Logger):
        """
        Create new InstallHelper object.

        Loads the starter-model definitions from the INITIAL_MODELS file that
        ships in the `invokeai.configs` package, creates an installer whose
        events are rendered by a TqdmEventService, and builds the model lists.

        :param app_config: Application configuration passed on to the installer.
        :param logger: Logger used to report models that fail to install.
        """
        self._app_config = app_config
        self.all_models: Dict[str, UnifiedModelInfo] = {}

        omega = omegaconf.OmegaConf.load(Path(configs.__path__[0]) / INITIAL_MODELS)
        assert isinstance(omega, omegaconf.dictconfig.DictConfig)

        self._installer = initialize_installer(app_config, TqdmEventService())
        self._initial_models = omega
        self._installed_models: List[str] = []
        self._starter_models: List[str] = []
        self._default_model: Optional[str] = None
        self._logger = logger

        self._initialize_model_lists()

    @property
    def installer(self) -> ModelInstallServiceBase:
        """Return the installer object used internally."""
        return self._installer

    def _initialize_model_lists(self) -> None:
        """
        Initialize our model slots.

        Set up the following:
        installed_models  -- list of installed model keys
        starter_models -- list of starter model keys from INITIAL_MODELS
        all_models -- dict of key => UnifiedModelInfo
        default_model -- key to default model
        """
        # Previously-installed models. This pass runs exactly once so that each
        # installed key is listed a single time.
        for model in self._installer.record_store.all_models():
            info = UnifiedModelInfo.parse_obj(model.dict())
            info.installed = True
            model_key = f"{model.base.value}/{model.type.value}/{model.name}"
            self.all_models[model_key] = info
            self._installed_models.append(model_key)

        # Starter models; keys have the form "base/type/name".
        for key in self._initial_models.keys():
            assert isinstance(key, str)
            stanza = self._initial_models[key]
            if key in self.all_models:
                # we want to preserve the description
                description = self.all_models[key].description or stanza.get("description")
                self.all_models[key].description = description
            else:
                base_model, model_type, model_name = key.split("/")
                info = UnifiedModelInfo(
                    name=model_name,
                    type=ModelType(model_type),
                    base=BaseModelType(base_model),
                    source=stanza.source,
                    description=stanza.get("description"),
                    recommended=stanza.get("recommended", False),
                    default=stanza.get("default", False),
                    subfolder=stanza.get("subfolder"),
                    requires=list(stanza.get("requires", [])),
                )
                self.all_models[key] = info

            # The first starter model is the fallback default; a later stanza
            # flagged `default` takes over.
            if not self.default_model():
                self._default_model = key
            elif stanza.get("default", False):
                self._default_model = key
            self._starter_models.append(key)

    def recommended_models(self) -> List[UnifiedModelInfo]:
        """List of the models recommended in INITIAL_MODELS.yaml."""
        starters = (self._to_model(x) for x in self._starter_models)
        return [info for info in starters if info.recommended]

    def installed_models(self) -> List[UnifiedModelInfo]:
        """List of models already installed."""
        return [self._to_model(x) for x in self._installed_models]

    def starter_models(self) -> List[UnifiedModelInfo]:
        """List of starter models."""
        return [self._to_model(x) for x in self._starter_models]

    def default_model(self) -> Optional[UnifiedModelInfo]:
        """Return the default model."""
        return self._to_model(self._default_model) if self._default_model else None

    def _to_model(self, key: str) -> UnifiedModelInfo:
        """Look up the UnifiedModelInfo stored under `key`; raises KeyError if absent."""
        return self.all_models[key]

    def _add_required_models(self, model_list: List[UnifiedModelInfo]) -> None:
        """
        Extend `model_list` in place with the models its entries require.

        A requirement is matched against the `source` of the known models and
        is skipped when a model with that source is already installed or when
        no known model has that source.
        """
        installed = {x.source for x in self.installed_models()}
        reverse_source = {x.source: x for x in self.all_models.values()}
        additional_models: List[UnifiedModelInfo] = []
        for model_info in model_list:
            for requirement in model_info.requires:
                if requirement not in installed and reverse_source.get(requirement):
                    additional_models.append(reverse_source[requirement])
        model_list.extend(additional_models)

    def _make_install_source(self, model_info: UnifiedModelInfo) -> ModelSource:
        """
        Turn a model's `source` string into the matching ModelSource object.

        Checked in order: an existing local path, a HuggingFace repo_id of the
        form "owner/name", then an http(s) URL.

        :raises ValueError: if the source matches none of these forms.
        """
        assert model_info.source
        model_path_id_or_url = model_info.source.strip("\"' ")
        model_path = Path(model_path_id_or_url)

        if model_path.exists():  # local file on disk
            return LocalModelSource(path=model_path.absolute(), inplace=True)
        if re.match(r"^[^/]+/[^/]+$", model_path_id_or_url):  # hugging face repo_id
            return HFModelSource(
                repo_id=model_path_id_or_url,
                access_token=HfFolder.get_token(),
                subfolder=model_info.subfolder,
            )
        if re.match(r"^(http|https):", model_path_id_or_url):
            return URLModelSource(url=AnyHttpUrl(model_path_id_or_url))
        raise ValueError(f"Unsupported model source: {model_path_id_or_url}")

    def add_or_delete(self, selections: InstallSelections) -> None:
        """
        Add or delete selected models.

        Models in `selections.install_models` (plus anything they require) are
        handed to the installer; failures are logged as warnings and do not
        stop the remaining installs. Entries in `selections.remove_models` are
        then looked up and deleted when they match exactly one installed model.
        Finally this waits for the installer to finish its pending installs.
        """
        installer = self._installer
        self._add_required_models(selections.install_models)
        for model in selections.install_models:
            source = self._make_install_source(model)
            config = (
                {
                    "description": model.description,
                    "name": model.name,
                }
                if model.name
                else None
            )

            try:
                installer.import_model(
                    source=source,
                    config=config,
                )
            except (UnknownMetadataException, InvalidModelConfigException, HTTPError, OSError) as e:
                self._logger.warning(f"{source}: {e}")

        for model_to_remove in selections.remove_models:
            parts = model_to_remove.split("/")
            if len(parts) == 1:
                base_model, model_type, model_name = (None, None, model_to_remove)
            else:
                base_model, model_type, model_name = parts
            matches = installer.record_store.search_by_attr(
                base_model=BaseModelType(base_model) if base_model else None,
                model_type=ModelType(model_type) if model_type else None,
                model_name=model_name,
            )
            # Report on the name being removed, not on the install-loop variable.
            if len(matches) > 1:
                print(
                    f"{model_to_remove} is ambiguous. Please use model_type:model_name (e.g. main:my_model) to disambiguate."
                )
            elif not matches:
                print(f"{model_to_remove}: unknown model")
            else:
                for m in matches:
                    print(f"Deleting {m.type}:{m.name}")
                    installer.delete(m.key)

        installer.wait_for_installs()

View File

@ -849,7 +849,7 @@ def migrate_if_needed(opt: Namespace, root: Path) -> bool:
# -------------------------------------
def main():
def main() -> None:
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
parser.add_argument(
"--skip-sd-weights",

View File

@ -0,0 +1,177 @@
"""
invokeai.backend.model_manager.merge exports:
merge_diffusion_models() -- combine multiple models by location and return a pipeline object
merge_diffusion_models_and_save() -- combine multiple models by model key, save the result to disk and register it
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""
import warnings
from enum import Enum
from pathlib import Path
from typing import Any, List, Optional, Set
import torch
from diffusers import AutoPipelineForText2Image
from diffusers import logging as dlogging
from invokeai.app.services.model_install import ModelInstallServiceBase
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
from . import (
AnyModelConfig,
BaseModelType,
ModelType,
ModelVariantType,
)
from .config import MainDiffusersConfig
class MergeInterpolationMethod(str, Enum):
    """Names of the interpolation methods that can be requested for a model merge."""

    # Members are strings, so they compare equal to their plain-text values.
    WeightedSum = "weighted_sum"
    Sigmoid = "sigmoid"
    InvSigmoid = "inv_sigmoid"
    AddDifference = "add_difference"
class ModelMerger(object):
    """Wrapper class for model merge function."""

    def __init__(self, installer: ModelInstallServiceBase):
        """
        Initialize a ModelMerger object.

        :param installer: Model install service; its app config and record store are used
               to locate the source models and to register the merged result.
        """
        self._installer = installer

    def merge_diffusion_models(
        self,
        model_paths: List[Path],
        alpha: float = 0.5,
        interp: Optional[MergeInterpolationMethod] = None,
        force: bool = False,
        variant: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:  # pipe.merge is an untyped function.
        """
        Merge the models found at `model_paths` and return the merged pipeline.

        :param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
        :param alpha: The interpolation parameter. Ranges from 0 to 1.  It affects the ratio in which the checkpoints are merged. A 0.8 alpha
               would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
        :param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
               Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
        :param force:  Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
        :param variant: Model variant to load; "fp16" selects float16, anything else uses the dtype of the chosen torch device.

        **kwargs - the default DiffusionPipeline.get_config_dict kwargs:
             cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
        """
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            verbosity = dlogging.get_verbosity()
            dlogging.set_verbosity_error()
            try:
                dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())

                # Note that checkpoint_merger will not work with downloaded HuggingFace fp16 models
                # until upstream https://github.com/huggingface/diffusers/pull/6670 is merged and released.
                pipe = AutoPipelineForText2Image.from_pretrained(
                    model_paths[0],
                    custom_pipeline="checkpoint_merger",
                    torch_dtype=dtype,
                    variant=variant,
                )
                merged_pipe = pipe.merge(
                    pretrained_model_name_or_path_list=model_paths,
                    alpha=alpha,
                    interp=interp.value if interp else None,  # diffusers API treats None as "weighted sum"
                    force=force,
                    torch_dtype=dtype,
                    variant=variant,
                    **kwargs,
                )
            finally:
                # Restore the caller's diffusers log level even when loading or merging fails.
                dlogging.set_verbosity(verbosity)
        return merged_pipe

    def merge_diffusion_models_and_save(
        self,
        model_keys: List[str],
        merged_model_name: str,
        alpha: float = 0.5,
        force: bool = False,
        interp: Optional[MergeInterpolationMethod] = None,
        merge_dest_directory: Optional[Path] = None,
        variant: Optional[str] = None,
        **kwargs: Any,
    ) -> AnyModelConfig:
        """
        Merge the models identified by `model_keys`, save the result, and register it.

        :param model_keys: up to three models, designated by their keys in the model record store
        :param merged_model_name: name for new model
        :param alpha: The interpolation parameter. Ranges from 0 to 1.  It affects the ratio in which the checkpoints are merged. A 0.8 alpha
               would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
        :param interp: The interpolation method to use for the merging. Supports "weighted_sum", "sigmoid", "inv_sigmoid", "add_difference" and None.
               Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
        :param force:  Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
        :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
        :param variant: Currently ignored: the variant is always derived from the app config's `full_precision` setting.
        :return: The configuration record of the newly registered merged model.
        :raises AssertionError: if the selection of models or the merge method is not supported.

        **kwargs - the default DiffusionPipeline.get_config_dict kwargs:
             cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
        """
        model_paths: List[Path] = []
        model_names: List[str] = []
        config = self._installer.app_config
        store = self._installer.record_store
        base_models: Set[BaseModelType] = set()
        vae = None
        variant = None if self._installer.app_config.full_precision else "fp16"

        assert (
            len(model_keys) <= 2 or interp == MergeInterpolationMethod.AddDifference
        ), "When merging three models, only the 'add_difference' merge method is supported"

        for key in model_keys:
            info = store.get_model(key)
            model_names.append(info.name)
            assert isinstance(
                info, MainDiffusersConfig
            ), f"{info.name} ({info.key}) is not a diffusers model. It must be optimized before merging"
            assert info.variant == ModelVariantType(
                "normal"
            ), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"

            # pick up the first model's vae
            if key == model_keys[0]:
                vae = info.vae

            # tally base models used
            base_models.add(info.base)
            model_paths.append(config.models_path / info.path)

        assert len(base_models) == 1, f"All models to merge must have same base model, but found bases {base_models}"
        base_model = base_models.pop()

        # None (the default) and "weighted_sum" both select the merge pipeline's default
        # method. Anything else is coerced to the enum, which raises ValueError for
        # unknown method names.
        if interp is None or interp == MergeInterpolationMethod.WeightedSum:
            merge_method = None
        else:
            merge_method = MergeInterpolationMethod(interp)
        merged_pipe = self.merge_diffusion_models(model_paths, alpha, merge_method, force, variant=variant, **kwargs)

        dump_path = (
            Path(merge_dest_directory)
            if merge_dest_directory
            else config.models_path / base_model.value / ModelType.Main.value
        )
        dump_path.mkdir(parents=True, exist_ok=True)
        dump_path = dump_path / merged_model_name

        dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
        merged_pipe.save_pretrained(dump_path.as_posix(), safe_serialization=True, torch_dtype=dtype, variant=variant)

        # register model and get its unique key
        key = self._installer.register_path(dump_path)

        # update model's config
        model_config = self._installer.record_store.get_model(key)
        model_config.update(
            {
                "name": merged_model_name,
                "description": f"Merge of models {', '.join(model_names)}",
                "vae": vae,
            }
        )
        self._installer.record_store.update_model(key, model_config)
        return model_config

View File

@ -170,6 +170,8 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
if model_id is None:
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
version = self._requests.get(version_url).json()
if error := version.get("error"):
raise UnknownMetadataException(error)
model_id = version["modelId"]
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)

View File

@ -11,6 +11,7 @@ import logging
import math
import os
import random
from argparse import Namespace
from pathlib import Path
from typing import Optional
@ -30,8 +31,6 @@ from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfFolder, Repository, whoami
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
from PIL import Image
from torch.utils.data import Dataset
@ -41,8 +40,8 @@ from transformers import CLIPTextModel, CLIPTokenizer
# invokeai stuff
from invokeai.app.services.config import InvokeAIAppConfig, PagingArgumentParser
from invokeai.app.services.model_manager import ModelManagerService
from invokeai.backend.model_management.models import SubModelType
from invokeai.backend.install.install_helper import initialize_record_store
from invokeai.backend.model_manager import BaseModelType, ModelType
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
PIL_INTERPOLATION = {
@ -77,7 +76,7 @@ def save_progress(text_encoder, placeholder_token_id, accelerator, placeholder_t
torch.save(learned_embeds_dict, save_path)
def parse_args():
def parse_args() -> Namespace:
config = InvokeAIAppConfig.get_config()
parser = PagingArgumentParser(description="Textual inversion training")
general_group = parser.add_argument_group("General")
@ -444,7 +443,7 @@ class TextualInversionDataset(Dataset):
self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
def __len__(self):
def __len__(self) -> int:
return self._length
def __getitem__(self, i):
@ -509,11 +508,10 @@ def do_textual_inversion_training(
initializer_token: str,
save_steps: int = 500,
only_save_embeds: bool = False,
revision: str = None,
tokenizer_name: str = None,
tokenizer_name: Optional[str] = None,
learnable_property: str = "object",
repeats: int = 100,
seed: int = None,
seed: Optional[int] = None,
resolution: int = 512,
center_crop: bool = False,
train_batch_size: int = 16,
@ -530,18 +528,18 @@ def do_textual_inversion_training(
adam_weight_decay: float = 1e-02,
adam_epsilon: float = 1e-08,
push_to_hub: bool = False,
hub_token: str = None,
hub_token: Optional[str] = None,
logging_dir: Path = Path("logs"),
mixed_precision: str = "fp16",
allow_tf32: bool = False,
report_to: str = "tensorboard",
local_rank: int = -1,
checkpointing_steps: int = 500,
resume_from_checkpoint: Path = None,
resume_from_checkpoint: Optional[Path] = None,
enable_xformers_memory_efficient_attention: bool = False,
hub_model_id: str = None,
hub_model_id: Optional[str] = None,
**kwargs,
):
) -> None:
assert model, "Please specify a base model with --model"
assert train_data_dir, "Please specify a directory containing the training images using --train_data_dir"
assert placeholder_token, "Please specify a trigger term using --placeholder_token"
@ -564,8 +562,6 @@ def do_textual_inversion_training(
project_config=accelerator_config,
)
model_manager = ModelManagerService(config, logger)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@ -603,44 +599,37 @@ def do_textual_inversion_training(
elif output_dir is not None:
os.makedirs(output_dir, exist_ok=True)
known_models = model_manager.model_names()
model_name = model.split("/")[-1]
model_meta = next((mm for mm in known_models if mm[0].endswith(model_name)), None)
assert model_meta is not None, f"Unknown model: {model}"
model_info = model_manager.model_info(*model_meta)
assert model_info["model_format"] == "diffusers", "This script only works with models of type 'diffusers'"
tokenizer_info = model_manager.get_model(*model_meta, submodel=SubModelType.Tokenizer)
noise_scheduler_info = model_manager.get_model(*model_meta, submodel=SubModelType.Scheduler)
text_encoder_info = model_manager.get_model(*model_meta, submodel=SubModelType.TextEncoder)
vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae)
unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet)
model_records = initialize_record_store(config)
base, type, name = model.split("/") # note frontend still returns old-style keys
try:
model_config = model_records.search_by_attr(
model_name=name, model_type=ModelType(type), base_model=BaseModelType(base)
)[0]
except IndexError:
raise Exception(f"Unknown model {model}")
model_path = config.models_path / model_config.path
pipeline_args = {"local_files_only": True}
if tokenizer_name:
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
else:
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_info.location, subfolder="tokenizer", **pipeline_args)
tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer", **pipeline_args)
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(
noise_scheduler_info.location, subfolder="scheduler", **pipeline_args
)
noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler", **pipeline_args)
text_encoder = CLIPTextModel.from_pretrained(
text_encoder_info.location,
model_path,
subfolder="text_encoder",
revision=revision,
**pipeline_args,
)
vae = AutoencoderKL.from_pretrained(
vae_info.location,
model_path,
subfolder="vae",
revision=revision,
**pipeline_args,
)
unet = UNet2DConditionModel.from_pretrained(
unet_info.location,
model_path,
subfolder="unet",
revision=revision,
**pipeline_args,
)
@ -728,7 +717,7 @@ def do_textual_inversion_training(
max_train_steps = num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
scheduler = get_scheduler(
lr_scheduler,
optimizer=optimizer,
num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
@ -737,7 +726,7 @@ def do_textual_inversion_training(
# Prepare everything with our `accelerator`.
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder, optimizer, train_dataloader, lr_scheduler
text_encoder, optimizer, train_dataloader, scheduler
)
# For mixed precision training we cast the unet and vae weights to half-precision
@ -863,7 +852,7 @@ def do_textual_inversion_training(
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
scheduler.step()
optimizer.zero_grad()
# Let's make sure we don't update any embedding weights besides the newly added token
@ -893,7 +882,7 @@ def do_textual_inversion_training(
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
logs = {"loss": loss.detach().item(), "lr": scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
@ -910,7 +899,7 @@ def do_textual_inversion_training(
save_full_model = not only_save_embeds
if save_full_model:
pipeline = StableDiffusionPipeline.from_pretrained(
unet_info.location,
model_path,
text_encoder=accelerator.unwrap_model(text_encoder),
vae=vae,
unet=unet,

View File

@ -0,0 +1,157 @@
# This file predefines a few models that the user may want to install.
sd-1/main/stable-diffusion-v1-5:
description: Stable Diffusion version 1.5 diffusers model (4.27 GB)
source: runwayml/stable-diffusion-v1-5
recommended: True
default: True
sd-1/main/stable-diffusion-v1-5-inpainting:
description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)
source: runwayml/stable-diffusion-inpainting
recommended: True
sd-2/main/stable-diffusion-2-1:
description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)
source: stabilityai/stable-diffusion-2-1
recommended: False
sd-2/main/stable-diffusion-2-inpainting:
description: Stable Diffusion version 2.0 inpainting model (5.21 GB)
source: stabilityai/stable-diffusion-2-inpainting
recommended: False
sdxl/main/stable-diffusion-xl-base-1-0:
description: Stable Diffusion XL base model (12 GB)
source: stabilityai/stable-diffusion-xl-base-1.0
recommended: True
sdxl-refiner/main/stable-diffusion-xl-refiner-1-0:
description: Stable Diffusion XL refiner model (12 GB)
source: stabilityai/stable-diffusion-xl-refiner-1.0
recommended: False
sdxl/vae/sdxl-vae-fp16-fix:
description: Version of the SDXL-1.0 VAE that works in half precision mode
source: madebyollin/sdxl-vae-fp16-fix
recommended: True
sd-1/main/Analog-Diffusion:
description: An SD-1.5 model trained on diverse analog photographs (2.13 GB)
source: wavymulder/Analog-Diffusion
recommended: False
sd-1/main/Deliberate:
description: Versatile model that produces detailed images up to 768px (4.27 GB)
source: XpucT/Deliberate
recommended: False
sd-1/main/Dungeons-and-Diffusion:
description: Dungeons & Dragons characters (2.13 GB)
source: 0xJustin/Dungeons-and-Diffusion
recommended: False
sd-1/main/dreamlike-photoreal-2:
description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB)
source: dreamlike-art/dreamlike-photoreal-2.0
recommended: False
sd-1/main/Inkpunk-Diffusion:
description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB)
source: Envvi/Inkpunk-Diffusion
recommended: False
sd-1/main/openjourney:
description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB)
source: prompthero/openjourney
recommended: False
sd-1/main/seek.art_MEGA:
source: coreco/seek.art_MEGA
description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB)
recommended: False
sd-1/main/trinart_stable_diffusion_v2:
description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB)
source: naclbit/trinart_stable_diffusion_v2
recommended: False
sd-1/controlnet/qrcode_monster:
source: monster-labs/control_v1p_sd15_qrcode_monster
subfolder: v2
sd-1/controlnet/canny:
source: lllyasviel/control_v11p_sd15_canny
recommended: True
sd-1/controlnet/inpaint:
source: lllyasviel/control_v11p_sd15_inpaint
sd-1/controlnet/mlsd:
source: lllyasviel/control_v11p_sd15_mlsd
sd-1/controlnet/depth:
source: lllyasviel/control_v11f1p_sd15_depth
recommended: True
sd-1/controlnet/normal_bae:
source: lllyasviel/control_v11p_sd15_normalbae
sd-1/controlnet/seg:
source: lllyasviel/control_v11p_sd15_seg
sd-1/controlnet/lineart:
source: lllyasviel/control_v11p_sd15_lineart
recommended: True
sd-1/controlnet/lineart_anime:
source: lllyasviel/control_v11p_sd15s2_lineart_anime
sd-1/controlnet/openpose:
source: lllyasviel/control_v11p_sd15_openpose
recommended: True
sd-1/controlnet/scribble:
source: lllyasviel/control_v11p_sd15_scribble
recommended: False
sd-1/controlnet/softedge:
source: lllyasviel/control_v11p_sd15_softedge
sd-1/controlnet/shuffle:
source: lllyasviel/control_v11e_sd15_shuffle
sd-1/controlnet/tile:
source: lllyasviel/control_v11f1e_sd15_tile
sd-1/controlnet/ip2p:
source: lllyasviel/control_v11e_sd15_ip2p
sd-1/t2i_adapter/canny-sd15:
source: TencentARC/t2iadapter_canny_sd15v2
sd-1/t2i_adapter/sketch-sd15:
source: TencentARC/t2iadapter_sketch_sd15v2
sd-1/t2i_adapter/depth-sd15:
source: TencentARC/t2iadapter_depth_sd15v2
sd-1/t2i_adapter/zoedepth-sd15:
source: TencentARC/t2iadapter_zoedepth_sd15v1
sdxl/t2i_adapter/canny-sdxl:
source: TencentARC/t2i-adapter-canny-sdxl-1.0
sdxl/t2i_adapter/zoedepth-sdxl:
source: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0
sdxl/t2i_adapter/lineart-sdxl:
source: TencentARC/t2i-adapter-lineart-sdxl-1.0
sdxl/t2i_adapter/sketch-sdxl:
source: TencentARC/t2i-adapter-sketch-sdxl-1.0
sd-1/embedding/EasyNegative:
source: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
recommended: True
description: A textual inversion to use in the negative prompt to reduce bad anatomy
sd-1/lora/FlatColor:
source: https://civitai.com/models/6433/loraflatcolor
recommended: True
description: A LoRA that generates scenery using solid blocks of color
sd-1/lora/Ink scenery:
source: https://civitai.com/api/download/models/83390
description: Generate india ink-like landscapes
sd-1/ip_adapter/ip_adapter_sd15:
source: InvokeAI/ip_adapter_sd15
recommended: True
requires:
- InvokeAI/ip_adapter_sd_image_encoder
description: IP-Adapter for SD 1.5 models
sd-1/ip_adapter/ip_adapter_plus_sd15:
source: InvokeAI/ip_adapter_plus_sd15
recommended: False
requires:
- InvokeAI/ip_adapter_sd_image_encoder
description: Refined IP-Adapter for SD 1.5 models
sd-1/ip_adapter/ip_adapter_plus_face_sd15:
source: InvokeAI/ip_adapter_plus_face_sd15
recommended: False
requires:
- InvokeAI/ip_adapter_sd_image_encoder
description: Refined IP-Adapter for SD 1.5 models, adapted for faces
sdxl/ip_adapter/ip_adapter_sdxl:
source: InvokeAI/ip_adapter_sdxl
recommended: False
requires:
- InvokeAI/ip_adapter_sdxl_image_encoder
description: IP-Adapter for SDXL models
any/clip_vision/ip_adapter_sd_image_encoder:
source: InvokeAI/ip_adapter_sd_image_encoder
recommended: False
description: Required model for using IP-Adapters with SD-1/2 models
any/clip_vision/ip_adapter_sdxl_image_encoder:
source: InvokeAI/ip_adapter_sdxl_image_encoder
recommended: False
description: Required model for using IP-Adapters with SDXL models

View File

@ -2,3 +2,5 @@
Wrapper for invokeai.backend.install.invokeai_configure
"""
from ...backend.install.invokeai_configure import main as invokeai_configure # noqa: F401
__all__ = ["invokeai_configure"]

View File

@ -0,0 +1,645 @@
#!/usr/bin/env python
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
# Before running stable-diffusion on an internet-isolated machine,
# run this script from one with internet connectivity. The
# two machines must share a common .cache directory.
"""
This is the npyscreen frontend to the model installation application.
It is currently named model_install2.py, but will ultimately replace model_install.py.
"""
import argparse
import curses
import sys
import traceback
import warnings
from argparse import Namespace
from shutil import get_terminal_size
from typing import Any, Dict, List, Optional, Set
import npyscreen
import torch
from npyscreen import widget
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_install import ModelInstallService
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections, UnifiedModelInfo
from invokeai.backend.model_manager import ModelType
from invokeai.backend.util import choose_precision, choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.widgets import (
MIN_COLS,
MIN_LINES,
CenteredTitleText,
CyclingForm,
MultiSelectColumns,
SingleSelectColumns,
TextBox,
WindowTooSmallException,
set_min_terminal_size,
)
# Silence every UserWarning for the lifetime of this process.
warnings.filterwarnings("ignore", category=UserWarning)  # noqa: E402

# Application configuration object; main() later calls config.parse_args() on it.
config = InvokeAIAppConfig.get_config()
# Module-level logger, limited to WARNING and above.
logger = InvokeAILogger.get_logger("ModelInstallService")
logger.setLevel("WARNING")
# logger.setLevel('DEBUG')

# build a table mapping all non-printable characters to None
# for stripping control characters
# from https://stackoverflow.com/questions/92438/stripping-non-printable-characters-from-a-string-in-python
# NOTE: the table is built at import time by visiting every code point up to sys.maxunicode.
NOPRINT_TRANS_TABLE = {i: None for i in range(0, sys.maxunicode + 1) if not chr(i).isprintable()}

# maximum number of installed models we can display before overflowing vertically
MAX_OTHER_MODELS = 72
def make_printable(s: str) -> str:
    """Return a copy of `s` with all non-printable characters removed.

    The characters to drop are the ones NOPRINT_TRANS_TABLE maps to None.
    """
    cleaned = s.translate(NOPRINT_TRANS_TABLE)
    return cleaned
class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
"""Main form for interactive TUI."""
# for responsive resizing set to False, but this seems to cause a crash!
FIX_MINIMUM_SIZE_WHEN_CREATED = True
# for persistence
current_tab = 0
def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, multipage: bool = False, **keywords: Any):
    """Record the form options, then hand over to the base-class initializer.

    Args:
        parentApp: The npyscreen application that owns this form.
        name: Title of the form.
        multipage: When true, create() adds a BACK button instead of CANCEL.
        **keywords: Passed through unchanged to the base-class initializer.
    """
    # Both attributes are set before the base initializer runs.
    # NOTE(review): presumably the base initializer calls create(), which
    # reads self.multipage -- confirm before moving these lines.
    self.subprocess = None
    self.multipage = multipage
    super().__init__(parentApp=parentApp, name=name, **keywords)
def create(self) -> None:
    """Build all widgets of the form.

    The form consists of a help line, a tab selector, one widget table per
    tab, and the BACK/CANCEL and APPLY CHANGES buttons. Every table is added
    starting at the same row (`top_of_table`); `_toggle_tables` then hides
    all tables except the one belonging to the selected tab.
    """
    self.installer = self.parentApp.install_helper.installer
    self.model_labels = self._get_model_labels()
    self.keypress_timeout = 10
    self.counter = 0
    self.subprocess_connection = None

    window_width, window_height = get_terminal_size()

    # npyscreen has no typing hints
    self.nextrely -= 1  # type: ignore
    self.add_widget_intelligent(
        npyscreen.FixedText,
        value="Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields. Cursor keys navigate, and <space> selects.",
        editable=False,
        color="CAUTION",
    )
    self.nextrely += 1  # type: ignore
    # The order of these tab names must match the order of the widget groups
    # listed in _toggle_tables(), which indexes that list by tab position.
    self.tabs = self.add_widget_intelligent(
        SingleSelectColumns,
        values=[
            "STARTERS",
            "MAINS",
            "CONTROLNETS",
            "T2I-ADAPTERS",
            "IP-ADAPTERS",
            "LORAS",
            "TI EMBEDDINGS",
        ],
        value=[self.current_tab],
        columns=7,
        max_height=2,
        relx=8,
        scroll_exit=True,
    )
    self.tabs.on_changed = self._toggle_tables

    # Remember where the first table starts so every later table can be
    # placed at the same row; bottom_of_table tracks the lowest row reached.
    top_of_table = self.nextrely  # type: ignore
    self.starter_pipelines = self.add_starter_pipelines()
    bottom_of_table = self.nextrely  # type: ignore

    self.nextrely = top_of_table
    # Starter models already appear on the STARTERS tab, so they are excluded here.
    self.pipeline_models = self.add_pipeline_widgets(
        model_type=ModelType.Main, window_width=window_width, exclude=self.starter_models
    )
    # self.pipeline_models['autoload_pending'] = True
    bottom_of_table = max(bottom_of_table, self.nextrely)

    self.nextrely = top_of_table
    self.controlnet_models = self.add_model_widgets(
        model_type=ModelType.ControlNet,
        window_width=window_width,
    )
    bottom_of_table = max(bottom_of_table, self.nextrely)

    self.nextrely = top_of_table
    self.t2i_models = self.add_model_widgets(
        model_type=ModelType.T2IAdapter,
        window_width=window_width,
    )
    bottom_of_table = max(bottom_of_table, self.nextrely)

    self.nextrely = top_of_table
    self.ipadapter_models = self.add_model_widgets(
        model_type=ModelType.IPAdapter,
        window_width=window_width,
    )
    bottom_of_table = max(bottom_of_table, self.nextrely)

    self.nextrely = top_of_table
    self.lora_models = self.add_model_widgets(
        model_type=ModelType.Lora,
        window_width=window_width,
    )
    bottom_of_table = max(bottom_of_table, self.nextrely)

    self.nextrely = top_of_table
    self.ti_models = self.add_model_widgets(
        model_type=ModelType.TextualInversion,
        window_width=window_width,
    )
    bottom_of_table = max(bottom_of_table, self.nextrely)

    # Continue below the tallest table.
    self.nextrely = bottom_of_table + 1

    self.nextrely += 1
    back_label = "BACK"
    cancel_label = "CANCEL"
    current_position = self.nextrely
    # The left-hand button is BACK in multipage mode and CANCEL otherwise.
    if self.multipage:
        self.back_button = self.add_widget_intelligent(
            npyscreen.ButtonPress,
            name=back_label,
            when_pressed_function=self.on_back,
        )
    else:
        self.nextrely = current_position
        self.cancel_button = self.add_widget_intelligent(
            npyscreen.ButtonPress, name=cancel_label, when_pressed_function=self.on_cancel
        )
        self.nextrely = current_position

    label = "APPLY CHANGES"
    # Reset the row so APPLY CHANGES is added at the same row as the button above.
    self.nextrely = current_position
    self.done = self.add_widget_intelligent(
        npyscreen.ButtonPress,
        name=label,
        relx=window_width - len(label) - 15,
        when_pressed_function=self.on_done,
    )

    # This restores the selected page on return from an installation
    for _i in range(1, self.current_tab + 1):
        self.tabs.h_cursor_line_down(1)
    self._toggle_tables([self.current_tab])
############# diffusers tab ##########
def add_starter_pipelines(self) -> dict[str, npyscreen.widget]:
    """Add the widgets of the STARTERS tab.

    Only starter models whose type is "main" or "vae" are listed.

    Returns:
        A dict with the keys "label1" (heading widget), "models_selected"
        (the multi-select widget) and "models" (the list of model keys, in
        the same order as the rows of the multi-select widget).
    """
    widgets: Dict[str, npyscreen.widget] = {}

    all_models = self.all_models  # master dict of all models, indexed by key
    model_list = [x for x in self.starter_models if all_models[x].type in ["main", "vae"]]
    model_labels = [self.model_labels[x] for x in model_list]

    widgets.update(
        label1=self.add_widget_intelligent(
            CenteredTitleText,
            name="Select from a starter set of Stable Diffusion models from HuggingFace and Civitai.",
            editable=False,
            labelColor="CAUTION",
        )
    )

    self.nextrely -= 1
    # if user has already installed some initial models, then don't patronize them
    # by showing more recommendations
    show_recommended = len(self.installed_models) == 0

    # enumerate() yields each row's own position; list.index() rescanned the
    # list for every entry and would return the first match for a repeated key.
    checked = [
        position
        for position, key in enumerate(model_list)
        if (show_recommended and all_models[key].recommended) or all_models[key].installed
    ]
    widgets.update(
        models_selected=self.add_widget_intelligent(
            MultiSelectColumns,
            columns=1,
            name="Install Starter Models",
            values=model_labels,
            value=checked,
            max_height=len(model_list) + 1,
            relx=4,
            scroll_exit=True,
        ),
        models=model_list,
    )

    self.nextrely += 1
    return widgets
############# Add a set of model install widgets ########
def add_model_widgets(
    self,
    model_type: ModelType,
    window_width: int = 120,
    install_prompt: Optional[str] = None,
    exclude: Optional[Set[str]] = None,
) -> dict[str, npyscreen.widget]:
    """Create the selection widgets for all known models of one type.

    Args:
        model_type: Only models whose `type` equals this value are listed.
        window_width: Width used to decide how many columns fit.
        install_prompt: Heading text; a default heading is used when empty.
        exclude: Model keys to leave out of the list.

    Returns:
        A dict that always contains "download_ids" (free-text box). When at
        least one model matches it also contains "label1", "models_selected"
        and "models"; "warning_message" is added when the label list had to
        be cut to MAX_OTHER_MODELS entries.
    """
    excluded = set() if exclude is None else exclude
    result: Dict[str, npyscreen.widget] = {}
    catalog = self.all_models

    # Matching keys, ordered by model name (a missing name sorts as "").
    candidates = [key for key in catalog if catalog[key].type == model_type and key not in excluded]
    candidates.sort(key=lambda key: catalog[key].name or "")
    labels = [self.model_labels[key] for key in candidates]

    nothing_installed = len(self.installed_models) == 0
    too_many = False

    if candidates:
        widest = max(len(label) for label in labels)
        per_row = window_width // (widest + 8)  # 8 characters for "[x] " and padding
        per_row = min(len(candidates), per_row) or 1

        heading = install_prompt
        if not heading:
            heading = f"Select the desired {model_type.value.title()} models to install. Unchecked models will be purged from disk."

        result["label1"] = self.add_widget_intelligent(
            CenteredTitleText,
            name=heading,
            editable=False,
            labelColor="CAUTION",
        )

        if len(labels) > MAX_OTHER_MODELS:
            labels = labels[:MAX_OTHER_MODELS]
            too_many = True

        # Rows ticked initially: installed models, plus the recommended ones
        # when nothing has been installed yet.
        preselected = [
            position
            for position, key in enumerate(candidates)
            if (nothing_installed and catalog[key].recommended) or catalog[key].installed
        ]
        result["models_selected"] = self.add_widget_intelligent(
            MultiSelectColumns,
            columns=per_row,
            name=f"Install {model_type} Models",
            values=labels,
            value=preselected,
            max_height=len(candidates) // per_row + 1,
            relx=4,
            scroll_exit=True,
        )
        result["models"] = candidates

    if too_many:
        result["warning_message"] = self.add_widget_intelligent(
            npyscreen.FixedText,
            value=f"Too many models to display (max={MAX_OTHER_MODELS}). Some are not displayed.",
            editable=False,
            color="CAUTION",
        )
        self.nextrely += 1

    result["download_ids"] = self.add_widget_intelligent(
        TextBox,
        name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):",
        max_height=6,
        scroll_exit=True,
        editable=True,
    )
    return result
### Tab for arbitrary diffusers widgets ###
def add_pipeline_widgets(
    self,
    model_type: ModelType = ModelType.Main,
    window_width: int = 120,
    **kwargs,
) -> dict[str, npyscreen.widget]:
    """Build the widget set for pipeline models.

    This delegates to add_model_widgets() with a pipeline-specific heading;
    any extra keyword arguments are forwarded unchanged.
    """
    heading = f"Installed {model_type.value.title()} models. Unchecked models in the InvokeAI root directory will be deleted. Enter URLs, paths or repo_ids to import."
    return self.add_model_widgets(
        model_type=model_type,
        window_width=window_width,
        install_prompt=heading,
        **kwargs,
    )
def resize(self) -> None:
    """Run the base-class resize, then reassign the starter-tab row labels.

    The labels are taken from self.model_labels for every key stored under
    "models" in the starter widget group.
    """
    super().resize()
    selector = self.starter_pipelines.get("models_selected")
    if not selector:
        return
    starter_keys = self.starter_pipelines.get("models")
    if not starter_keys:
        return
    selector.values = [self.model_labels[key] for key in starter_keys]
def _toggle_tables(self, value: List[int]) -> None:
    """Show the widget group of the selected tab and hide all the others.

    Args:
        value: Selection reported by the tab widget; only value[0], the
            index of the selected tab, is used.
    """
    selected_tab = value[0]
    # Same order as the tab names passed to the tab selector in create().
    widgets = [
        self.starter_pipelines,
        self.pipeline_models,
        self.controlnet_models,
        self.t2i_models,
        self.ipadapter_models,
        self.lora_models,
        self.ti_models,
    ]

    for group in widgets:
        for v in group.values():
            try:
                v.hidden = True
                v.editable = False
            except AttributeError:
                # Each group also holds a plain list of model keys under
                # "models"; it is not a widget and cannot take attributes.
                pass

    for v in widgets[selected_tab].values():
        try:
            v.hidden = False
            # Static text stays non-editable so the cursor skips over it.
            if not isinstance(v, (npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)):
                v.editable = True
        except AttributeError:
            # Non-widget entry (the "models" list); nothing to show.
            pass

    self.__class__.current_tab = selected_tab  # for persistence
    self.display()
def _get_model_labels(self) -> dict[str, str]:
    """Return a dict mapping each model key to its display label.

    A label is the model name left-justified to the width of the longest
    starter-model name, one space, and the description. Descriptions wider
    than the space left on the terminal line are cut and end in "...".
    """
    window_width, _ = get_terminal_size()
    checkbox_width = 4
    spacing_width = 2

    models = self.all_models
    # default=0: an empty starter list must not make max() raise ValueError.
    label_width = max((len(models[x].name or "") for x in self.starter_models), default=0)
    description_width = window_width - label_width - checkbox_width - spacing_width
    # Characters kept in front of the "..."; clamped so that a very narrow
    # terminal cannot turn the slice bound negative (which would keep almost
    # the whole description instead of shortening it).
    kept = max(description_width - 3, 0)

    result: dict[str, str] = {}
    for key, info in models.items():
        description = info.description or ""
        if description and len(description) > description_width:
            description = description[:kept] + "..."
        # str() mirrors the previous "%s" formatting, including for a None name.
        result[key] = f"{str(info.name).ljust(label_width)} {description}"
    return result
def _get_columns(self) -> int:
    """Return a column count chosen from the terminal width.

    The count is 1 to 4 depending on the width, and is never larger than
    the number of installed models.
    """
    width, _height = get_terminal_size()
    if width > 240:
        by_width = 4
    elif width > 160:
        by_width = 3
    elif width > 80:
        by_width = 2
    else:
        by_width = 1
    return min(by_width, len(self.installed_models))
def confirm_deletions(self, selections: InstallSelections) -> bool:
    """Ask the user to confirm the removals queued in `selections`.

    Returns:
        True when nothing is queued for removal; otherwise the answer from
        the OK/Cancel dialog that lists the affected model names.
    """
    doomed = selections.remove_models
    if not doomed:
        return True

    listing = "\n".join([self.all_models[key].name or "" for key in doomed])
    answer = npyscreen.notify_ok_cancel(
        f"These unchecked models will be deleted from disk. Continue?\n---------\n{listing}"
    )
    assert isinstance(answer, bool)  # npyscreen doesn't have return type annotations
    return answer
@property
def all_models(self) -> Dict[str, UnifiedModelInfo]:
    """The `all_models` mapping held by the application's install helper."""
    # npyscreen doesn't having typing hints
    helper = self.parentApp.install_helper
    return helper.all_models  # type: ignore
@property
def starter_models(self) -> List[str]:
    """The `_starter_models` list held by the application's install helper."""
    helper = self.parentApp.install_helper
    return helper._starter_models  # type: ignore
@property
def installed_models(self) -> List[str]:
    """The `_installed_models` list held by the application's install helper."""
    helper = self.parentApp.install_helper
    return helper._installed_models  # type: ignore
def on_back(self) -> None:
    """BACK button handler: switch to the previous form and stop editing."""
    app = self.parentApp
    app.switchFormPrevious()
    self.editing = False
def on_cancel(self) -> None:
    """CANCEL button handler: end the form sequence and flag the cancellation."""
    app = self.parentApp
    app.setNextForm(None)
    app.user_cancelled = True
    self.editing = False
def on_done(self) -> None:
    """APPLY CHANGES handler.

    Collects the selections, asks for confirmation of any deletions, and
    only when confirmed ends the form with `user_cancelled` set to False.
    When the user declines, the form is left as it is.
    """
    self.marshall_arguments()
    app = self.parentApp
    confirmed = self.confirm_deletions(app.install_selections)
    if confirmed:
        app.setNextForm(None)
        app.user_cancelled = False
        self.editing = False
def marshall_arguments(self) -> None:
    """Translate the current widget state into the application's install selections.

    The application's `install_selections` object is rebuilt from scratch:
        .install_models: known models that are checked but not installed,
                         followed by one entry for every whitespace-separated
                         token typed into a "download_ids" text box
        .remove_models:  keys of installed models that are now unchecked
    """
    selections = self.parentApp.install_selections
    all_models = self.all_models

    # Start from a clean slate. This method runs on every APPLY press, and
    # the form stays open when the user declines the deletion prompt, so
    # without the reset earlier results (including removals the user has
    # since undone) would pile up and still be executed.
    selections.install_models.clear()
    selections.remove_models.clear()

    # Defined models (in INITIAL_CONFIG.yaml or models.yaml) to add/remove
    ui_sections = [
        self.starter_pipelines,
        self.pipeline_models,
        self.controlnet_models,
        self.t2i_models,
        self.ipadapter_models,
        self.lora_models,
        self.ti_models,
    ]
    for section in ui_sections:
        if "models_selected" not in section:
            continue
        section_models = section["models"]
        selected = {section_models[x] for x in section["models_selected"].value}
        # Walk the section's list (not the set) so the install order is
        # stable and follows the on-screen order.
        models_to_install = [x for x in section_models if x in selected and not all_models[x].installed]
        models_to_remove = [x for x in section_models if x not in selected and all_models[x].installed]
        selections.remove_models.extend(models_to_remove)
        selections.install_models.extend(all_models[x] for x in models_to_install)

    # models located in the 'download_ids" section
    for section in ui_sections:
        if downloads := section.get("download_ids"):
            models = [UnifiedModelInfo(source=x) for x in downloads.value.split()]
            selections.install_models.extend(models)
class AddModelApplication(npyscreen.NPSAppManaged):  # type: ignore
    """npyscreen application that hosts the model-selection form.

    The form writes its result into `install_selections` and sets
    `user_cancelled` when the user backs out.
    """

    def __init__(self, opt: Namespace, install_helper: InstallHelper):
        """Store the parsed options and the install helper; start with empty selections."""
        super().__init__()
        self.install_helper = install_helper
        self.install_selections = InstallSelections()
        self.user_cancelled = False
        self.program_opts = opt

    def onStart(self) -> None:
        """Select the default theme and register the main form under "MAIN"."""
        npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
        form_title = "Install Stable Diffusion Models"
        self.main_form = self.addForm("MAIN", addModelsForm, name=form_title, cycle_widgets=False)
def list_models(installer: ModelInstallService, model_type: ModelType):
    """Print one line (name, base, resolved path) per installed model of `model_type`."""
    records = installer.record_store.search_by_attr(model_type=model_type)
    print(f"Installed models of type `{model_type}`:")
    models_root = config.models_path
    for record in records:
        # Record paths are joined onto the configured models directory.
        path = (models_root / record.path).resolve()
        row = f"{record.name:40}{record.base.value:14}{path}"
        print(row)
# --------------------------------------------------------
def select_and_download_models(opt: Namespace) -> None:
    """Carry out the install/delete request described by the command-line options.

    Depending on `opt` this lists models, installs/deletes the models named
    on the command line, installs only the default model, installs all
    recommended models, or (when none of those options is given) runs the
    interactive form and applies the selections made there.

    Raises:
        WindowTooSmallException: The terminal could not be enlarged to the
            minimum size needed by the interactive form.
    """
    precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
    # unsure how to avoid a typing complaint in the next line: config.precision is an enumerated Literal
    config.precision = precision  # type: ignore
    install_helper = InstallHelper(config, logger)
    installer = install_helper.installer

    if opt.list_models:
        list_models(installer, opt.list_models)

    elif opt.add or opt.delete:
        selections = InstallSelections(
            install_models=[UnifiedModelInfo(source=x) for x in (opt.add or [])], remove_models=opt.delete or []
        )
        install_helper.add_or_delete(selections)

    elif opt.default_only:
        selections = InstallSelections(install_models=[install_helper.default_model()])
        install_helper.add_or_delete(selections)

    elif opt.yes_to_all:
        selections = InstallSelections(install_models=install_helper.recommended_models())
        install_helper.add_or_delete(selections)

    # this is where the TUI is called
    else:
        if not set_min_terminal_size(MIN_COLS, MIN_LINES):
            raise WindowTooSmallException(
                "Could not increase terminal size. Try running again with a larger window or smaller font size."
            )

        installApp = AddModelApplication(opt, install_helper)
        try:
            installApp.run()
        except KeyboardInterrupt:
            print("Aborted...")
            sys.exit(-1)

        # The form sets user_cancelled when CANCEL is pressed. Selections may
        # still be populated from an earlier APPLY whose deletion prompt was
        # declined, so they must not be executed after a cancel.
        if installApp.user_cancelled:
            return
        install_helper.add_or_delete(installApp.install_selections)
# -------------------------------------
def main() -> None:
    """Command-line entry point of the model installer.

    Parses the arguments, forwards --root_dir and --full-precision to the
    application configuration, runs invokeai-configure instead when the
    model configuration file does not exist yet, and otherwise calls
    select_and_download_models(), reporting the errors it may raise.
    """
    parser = argparse.ArgumentParser(description="InvokeAI model downloader")
    parser.add_argument(
        "--add",
        nargs="*",
        help="List of URLs, local paths or repo_ids of models to install",
    )
    parser.add_argument(
        "--delete",
        nargs="*",
        help="List of names of models to delete. Use type:name to disambiguate, as in `controlnet:my_model`",
    )
    # No `type=` here: BooleanOptionalAction takes no value, and passing
    # `type` to it is deprecated since Python 3.12.
    parser.add_argument(
        "--full-precision",
        dest="full_precision",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="use 32-bit weights instead of faster 16-bit weights",
    )
    parser.add_argument(
        "--yes",
        "-y",
        dest="yes_to_all",
        action="store_true",
        help='answer "yes" to all prompts',
    )
    parser.add_argument(
        "--default_only",
        action="store_true",
        help="Only install the default model",
    )
    parser.add_argument(
        "--list-models",
        choices=[x.value for x in ModelType],
        help="list installed models",
    )
    parser.add_argument(
        "--root_dir",
        dest="root",
        type=str,
        default=None,
        help="path to root of install directory",
    )
    opt = parser.parse_args()

    # Translate the relevant options into arguments for the app configuration.
    invoke_args = []
    if opt.root:
        invoke_args.extend(["--root", opt.root])
    if opt.full_precision:
        invoke_args.extend(["--precision", "float32"])
    config.parse_args(invoke_args)
    logger = InvokeAILogger().get_logger(config=config)

    if not config.model_conf_path.exists():
        logger.info("Your InvokeAI root directory is not set up. Calling invokeai-configure.")
        from invokeai.frontend.install.invokeai_configure import invokeai_configure

        invokeai_configure()
        sys.exit(0)

    try:
        select_and_download_models(opt)
    except AssertionError as e:
        logger.error(e)
        sys.exit(-1)
    except KeyboardInterrupt:
        # Put the terminal back into its normal mode before logging.
        curses.nocbreak()
        curses.echo()
        curses.endwin()
        logger.info("Goodbye! Come back soon.")
    except WindowTooSmallException as e:
        logger.error(str(e))
    except widget.NotEnoughSpaceForWidget as e:
        if str(e).startswith("Height of 1 allocated"):
            logger.error("Insufficient vertical space for the interface. Please make your window taller and try again")
        input("Press any key to continue...")
    except Exception as e:
        # Top-level boundary: an "addwstr" message is reported as a window
        # that is too narrow; anything else is printed with its traceback.
        if str(e).startswith("addwstr"):
            logger.error(
                "Insufficient horizontal space for the interface. Please make your window wider and try again."
            )
        else:
            print(f"An exception has occurred: {str(e)} Details:")
            print(traceback.format_exc(), file=sys.stderr)
        input("Press any key to continue...")
# -------------------------------------
if __name__ == "__main__":
main()

View File

@ -0,0 +1,438 @@
"""
invokeai.frontend.merge exports a single function called merge_diffusion_models().
It merges 2-3 models together and creates a new InvokeAI-registered diffusion model.
Copyright (c) 2023-24 Lincoln Stein and the InvokeAI Development Team
"""
import argparse
import curses
import re
import sys
from argparse import Namespace
from pathlib import Path
from typing import List, Optional, Tuple
import npyscreen
from npyscreen import widget
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_install import ModelInstallServiceBase
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.install.install_helper import initialize_installer
from invokeai.backend.model_manager import (
BaseModelType,
ModelFormat,
ModelType,
ModelVariantType,
)
from invokeai.backend.model_manager.merge import ModelMerger
from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox
# Application configuration object; its `root` is the default for --root_dir.
config = InvokeAIAppConfig.get_config()

# (base model type, human-readable label) pairs. The labels are the entries of
# the base-model selector in the form; the enum values are the accepted
# choices of the --base_model command-line option.
BASE_TYPES = [
    (BaseModelType.StableDiffusion1, "Models Built on SD-1.x"),
    (BaseModelType.StableDiffusion2, "Models Built on SD-2.x"),
    (BaseModelType.StableDiffusionXL, "Models Built on SDXL"),
]
def _parse_args() -> Namespace:
    """Define the merge tool's command-line options and parse the process arguments."""
    # One (flags, settings) entry per option, registered in this order.
    option_table = [
        (
            ("--root_dir",),
            {
                "type": Path,
                "default": config.root,
                "help": "Path to the invokeai runtime directory",
            },
        ),
        (
            ("--front_end", "--gui"),
            {
                "dest": "front_end",
                "action": "store_true",
                "default": False,
                "help": "Activate the text-based graphical front end for collecting parameters. Aside from --root_dir, other parameters will be ignored.",
            },
        ),
        (
            ("--models",),
            {
                "dest": "model_names",
                "type": str,
                "nargs": "+",
                "help": "Two to three model names to be merged",
            },
        ),
        (
            ("--base_model",),
            {
                "type": str,
                "choices": [x[0].value for x in BASE_TYPES],
                "help": "The base model shared by the models to be merged",
            },
        ),
        (
            ("--merged_model_name", "--destination"),
            {
                "dest": "merged_model_name",
                "type": str,
                "help": "Name of the output model. If not specified, will be the concatenation of the input model names.",
            },
        ),
        (
            ("--alpha",),
            {
                "type": float,
                "default": 0.5,
                "help": "The interpolation parameter, ranging from 0 to 1. It affects the ratio in which the checkpoints are merged. Higher values give more weight to the 2d and 3d models",
            },
        ),
        (
            ("--interpolation",),
            {
                "dest": "interp",
                "type": str,
                "choices": ["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"],
                "default": "weighted_sum",
                "help": 'Interpolation method to use. If three models are present, only "add_difference" will work.',
            },
        ),
        (
            ("--force",),
            {
                "action": "store_true",
                "help": "Try to merge models even if they are incompatible with each other",
            },
        ),
        (
            ("--clobber", "--overwrite"),
            {
                "dest": "clobber",
                "action": "store_true",
                "help": "Overwrite the merged model if --merged_model_name already exists",
            },
        ),
    ]

    cli = argparse.ArgumentParser(description="InvokeAI model merging")
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)
    return cli.parse_args()
# ------------------------- GUI HERE -------------------------
class mergeModelsForm(npyscreen.FormMultiPageAction):
interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid"]
def __init__(self, parentApp, name):
    """Set the resize options and the owning application, then initialize the base form.

    Args:
        parentApp: The npyscreen application that owns this form.
        name: Title of the form.
    """
    # These are instance attributes assigned before the base initializer runs.
    self.FIX_MINIMUM_SIZE_WHEN_CREATED = False
    self.ALLOW_RESIZE = True
    self.parentApp = parentApp
    super().__init__(parentApp, name)
@property
def model_record_store(self) -> ModelRecordServiceBase:
    """The record store of the installer held by the parent application."""
    app_installer: ModelInstallServiceBase = self.parentApp.installer
    store = app_installer.record_store
    return store
def afterEditing(self) -> None:
    """Tell the parent application that no further form follows this one."""
    app = self.parentApp
    app.setNextForm(None)
def create(self) -> None:
    """Build the merge form.

    The form has two help lines, a base-model selector, three model pickers
    (the third is optional and starts with a "None" entry), a name box, a
    force checkbox, the merge-method selector and the alpha slider.
    """
    window_height, window_width = curses.initscr().getmaxyx()
    self.current_base = 0
    # Each entry of self.models is indexed as [0] = key, [1] = display name
    # (see marshall_arguments, which reads x[0] as the model key).
    self.models = self.get_models(BASE_TYPES[self.current_base][0])
    self.model_names = [x[1] for x in self.models]
    # NOTE(review): max() raises ValueError when no model of the first base
    # type exists -- confirm that callers guarantee at least one.
    max_width = max([len(x) for x in self.model_names])
    max_width += 6
    # Place the three pickers side by side only when three columns fit.
    horizontal_layout = max_width * 3 < window_width

    self.add_widget_intelligent(
        npyscreen.FixedText,
        color="CONTROL",
        value="Select two models to merge and optionally a third.",
        editable=False,
    )
    self.add_widget_intelligent(
        npyscreen.FixedText,
        color="CONTROL",
        value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
        editable=False,
    )
    self.nextrely += 1
    self.base_select = self.add_widget_intelligent(
        SingleSelectColumns,
        values=[x[1] for x in BASE_TYPES],
        value=[self.current_base],
        columns=4,
        max_height=2,
        relx=8,
        scroll_exit=True,
    )
    self.base_select.on_changed = self._populate_models
    self.add_widget_intelligent(
        npyscreen.FixedText,
        value="MODEL 1",
        color="GOOD",
        editable=False,
        rely=6 if horizontal_layout else None,
    )
    # NOTE(review): rely=7 is unconditional here, while the other two pickers
    # only pin their row in the horizontal layout -- confirm this is intended.
    self.model1 = self.add_widget_intelligent(
        npyscreen.SelectOne,
        values=self.model_names,
        value=0,
        max_height=len(self.model_names),
        max_width=max_width,
        scroll_exit=True,
        rely=7,
    )
    self.add_widget_intelligent(
        npyscreen.FixedText,
        value="MODEL 2",
        color="GOOD",
        editable=False,
        relx=max_width + 3 if horizontal_layout else None,
        rely=6 if horizontal_layout else None,
    )
    # The second picker starts on the second entry (value=1).
    self.model2 = self.add_widget_intelligent(
        npyscreen.SelectOne,
        name="(2)",
        values=self.model_names,
        value=1,
        max_height=len(self.model_names),
        max_width=max_width,
        relx=max_width + 3 if horizontal_layout else None,
        rely=7 if horizontal_layout else None,
        scroll_exit=True,
    )
    self.add_widget_intelligent(
        npyscreen.FixedText,
        value="MODEL 3",
        color="GOOD",
        editable=False,
        relx=max_width * 2 + 3 if horizontal_layout else None,
        rely=6 if horizontal_layout else None,
    )
    # The third picker gets an extra leading "None" entry, so its indices are
    # shifted by one relative to self.model_names.
    models_plus_none = self.model_names.copy()
    models_plus_none.insert(0, "None")
    self.model3 = self.add_widget_intelligent(
        npyscreen.SelectOne,
        name="(3)",
        values=models_plus_none,
        value=0,
        max_height=len(self.model_names) + 1,
        max_width=max_width,
        scroll_exit=True,
        relx=max_width * 2 + 3 if horizontal_layout else None,
        rely=7 if horizontal_layout else None,
    )
    # Hook models_changed onto each picker's when_value_edited callback.
    for m in [self.model1, self.model2, self.model3]:
        m.when_value_edited = self.models_changed
    self.merged_model_name = self.add_widget_intelligent(
        TextBox,
        name="Name for merged model:",
        labelColor="CONTROL",
        max_height=3,
        value="",
        scroll_exit=True,
    )
    self.force = self.add_widget_intelligent(
        npyscreen.Checkbox,
        name="Force merge of models created by different diffusers library versions",
        labelColor="CONTROL",
        value=True,
        scroll_exit=True,
    )
    self.nextrely += 1
    self.merge_method = self.add_widget_intelligent(
        npyscreen.TitleSelectOne,
        name="Merge Method:",
        values=self.interpolations,
        value=0,
        labelColor="CONTROL",
        max_height=len(self.interpolations) + 1,
        scroll_exit=True,
    )
    self.alpha = self.add_widget_intelligent(
        FloatTitleSlider,
        name="Weight (alpha) to assign to second and third models:",
        out_of=1.0,
        step=0.01,
        lowest=0,
        value=0.5,
        labelColor="CONTROL",
        scroll_exit=True,
    )
    self.model1.editing = True
def models_changed(self) -> None:
    """Refresh the suggested merged-model name and the merge-method list after a selection change."""
    names = self.model1.values
    first, second, third = (picker.value[0] for picker in (self.model1, self.model2, self.model3))

    self.merged_model_name.value = f"{names[first]}+{names[second]}"

    if third <= 0:
        # Index 0 of the third selector is the "None" entry: offer every interpolation.
        self.merge_method.values = self.interpolations
    else:
        self.merge_method.values = ["add_difference ( A+(B-C) )"]
        # The third selector's list has one extra leading entry ("None"), hence the -1.
        self.merged_model_name.value += f"+{names[third - 1]}"

    # Point the method selector at the first entry of whichever list is now shown.
    self.merge_method.value = 0
def on_ok(self) -> None:
    """Handle the OK button: validate, confirm any overwrite, then hand the arguments to the app."""
    accepted = self.validate_field_values() and self.check_for_overwrite()
    if not accepted:
        # Keep the form open so the user can correct the input.
        self.editing = True
        return

    app = self.parentApp
    app.setNextForm(None)
    self.editing = False
    app.merge_arguments = self.marshall_arguments()
    npyscreen.notify("Starting the merge...")
def on_cancel(self) -> None:
    """Handle the Cancel button by terminating the process with exit status 0."""
    raise SystemExit(0)
def marshall_arguments(self) -> dict:
    """Collect the form's current values into the keyword dict stored as ``merge_arguments``.

    Returns a dict with the keys ``model_keys``, ``alpha``, ``interp``,
    ``force`` and ``merged_model_name``.
    """
    # self.models holds (key, name) pairs; the selectors index into it by position.
    key_by_position = [entry[0] for entry in self.models]
    chosen_keys = [key_by_position[picker.value[0]] for picker in (self.model1, self.model2)]

    third = self.model3.value[0]
    if third > 0:
        # Entry 0 of the third selector is "None", so real models start at 1.
        chosen_keys.append(key_by_position[third - 1])
        method = "add_difference"
    else:
        method = self.interpolations[self.merge_method.value[0]]

    return {
        "model_keys": chosen_keys,
        "alpha": self.alpha.value,
        "interp": method,
        "force": self.force.value,
        "merged_model_name": self.merged_model_name.value,
    }
def check_for_overwrite(self) -> bool:
    """Return True when the merge may proceed under the chosen output name.

    If the name matches one of the listed model names, the user is asked to
    confirm and their yes/no answer is returned.
    """
    destination = self.merged_model_name.value
    if destination in self.model_names:
        answer: bool = npyscreen.notify_yes_no(
            f"The chosen merged model destination, {destination}, is already in use. Overwrite?"
        )
        return answer
    return True
def validate_field_values(self) -> bool:
    """Check the current selections; show a dialog and return False if any problem is found."""
    names = self.model_names
    chosen = {names[self.model1.value[0]], names[self.model2.value[0]]}
    third = self.model3.value[0]
    if third > 0:
        # Entry 0 of the third selector is "None", so real models start at 1.
        chosen.add(names[third - 1])

    problems = []
    if len(chosen) < 2:
        problems.append(f"Please select two or three DIFFERENT models to compare. You selected {chosen}")

    if not problems:
        return True

    report = "\n* ".join(["The following problems were detected and must be corrected:", *problems])
    npyscreen.notify_confirm(report)
    return False
def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]:  # key to name
    """Return ``(key, name)`` pairs for the mergeable main models, sorted by name.

    A record qualifies when its format is ``diffusers`` and it has a
    ``variant`` attribute equal to ``normal``. ``base_model`` is passed
    straight through to the record store search.
    """
    wanted_format = ModelFormat("diffusers")
    wanted_variant = ModelVariantType("normal")
    records = self.model_record_store.search_by_attr(model_type=ModelType.Main, base_model=base_model)

    pairs = []
    for record in records:
        if record.format == wanted_format and hasattr(record, "variant") and record.variant == wanted_variant:
            pairs.append((record.key, record.name))

    pairs.sort(key=lambda pair: pair[1])
    return pairs
def _populate_models(self, value: List[int]) -> None:
    """Reload the three model selectors for the base type chosen at ``value[0]`` and redraw."""
    chosen_base = BASE_TYPES[value[0]][0]
    self.models = self.get_models(chosen_base)
    self.model_names = [entry[1] for entry in self.models]

    for picker in (self.model1, self.model2):
        picker.values = self.model_names
    # The third selector gets a leading "None" entry so that it can be left unset.
    self.model3.values = ["None", *self.model_names]

    self.display()
# npyscreen is untyped and causes mypy to get naggy
class Mergeapp(npyscreen.NPSAppManaged):  # type: ignore
    """npyscreen application hosting the merge-settings form."""

    def __init__(self, installer: ModelInstallServiceBase):
        """Initialize the npyscreen application."""
        super().__init__()
        self.installer = installer

    def onStart(self) -> None:
        """Select the colour theme and register the merge form under the "MAIN" id."""
        npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
        main_form = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
        self.main = main_form
def run_gui(args: Namespace) -> None:
    """Run the interactive front end, then merge the models the user selected.

    ``args`` is not consulted in this function; all merge parameters come
    from the form via ``mergeapp.merge_arguments``.
    """
    installer = initialize_installer(config)
    mergeapp = Mergeapp(installer)
    mergeapp.run()

    merge_args = mergeapp.merge_arguments
    merger = ModelMerger(installer)
    merger.merge_diffusion_models_and_save(**merge_args)

    # merge_args is the plain dict built by the form's marshall_arguments(),
    # so the name must be read by key; attribute access raises AttributeError.
    merged_name = merge_args["merged_model_name"]
    logger.info(f'Models merged into new model: "{merged_name}".')
def run_cli(args: Namespace) -> None:
    """Merge models named on the command line, without the interactive form.

    Each entry of ``args.model_names`` is either a 32-character lowercase hex
    model key, used as-is, or a model name that is looked up in the record
    store for ``args.base_model``.

    Raises:
        AssertionError: alpha out of range, not 2 to 3 models given, the
            output name already exists without ``--clobber``, or a model name
            that is unknown or ambiguous.
    """
    assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
    # A merge needs at least two inputs; the previous lower bound of 1 let a
    # single model through, contradicting the message below.
    assert (
        args.model_names and 2 <= len(args.model_names) <= 3
    ), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."

    if not args.merged_model_name:
        args.merged_model_name = "+".join(args.model_names)
        logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"')

    installer = initialize_installer(config)
    store = installer.record_store
    assert (
        len(store.search_by_attr(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber
    ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'

    merger = ModelMerger(installer)
    model_keys = []
    for name in args.model_names:
        # The old pattern r"^[0-9a-f]$" matched exactly one character, so it
        # could never accept a 32-character key; match all 32 hex digits.
        if re.fullmatch(r"[0-9a-f]{32}", name):
            model_keys.append(name)
        else:
            models = store.search_by_attr(
                model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model)
            )
            assert len(models) > 0, f"{name}: Unknown model"
            assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead."
            model_keys.append(models[0].key)

    merger.merge_diffusion_models_and_save(
        alpha=args.alpha,
        model_keys=model_keys,
        merged_model_name=args.merged_model_name,
        interp=args.interp,
        force=args.force,
    )
    logger.info(f'Models merged into new model: "{args.merged_model_name}".')
def main() -> None:
    """Entry point: parse arguments, configure the root, and run the GUI or CLI merge.

    Any failure is logged and the process exits with status -1.
    """
    args = _parse_args()
    root_args = ["--root", str(args.root_dir)] if args.root_dir else []
    config.parse_args(root_args)

    try:
        runner = run_gui if args.front_end else run_cli
        runner(args)
    except widget.NotEnoughSpaceForWidget as e:
        too_few_models = str(e).startswith("Height of 1 allocated")
        logger.error(
            "You need to have at least two diffusers models defined in models.yaml in order to merge"
            if too_few_models
            else "Not enough room for the user interface. Try making this window larger."
        )
        sys.exit(-1)
    except Exception as e:
        logger.error(str(e))
        sys.exit(-1)
    except KeyboardInterrupt:
        sys.exit(-1)


if __name__ == "__main__":
    main()

View File

@ -3,7 +3,7 @@
"""
This is the frontend to "textual_inversion_training.py".
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
Copyright (c) 2023-24 Lincoln Stein and the InvokeAI Development Team
"""
@ -14,7 +14,7 @@ import sys
import traceback
from argparse import Namespace
from pathlib import Path
from typing import List, Tuple
from typing import Dict, List, Optional, Tuple
import npyscreen
from npyscreen import widget
@ -22,8 +22,9 @@ from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from ...backend.training import do_textual_inversion_training, parse_args
from invokeai.backend.install.install_helper import initialize_installer
from invokeai.backend.model_manager import ModelType
from invokeai.backend.training import do_textual_inversion_training, parse_args
TRAINING_DATA = "text-inversion-training-data"
TRAINING_DIR = "text-inversion-output"
@ -44,19 +45,21 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
precisions = ["no", "fp16", "bf16"]
learnable_properties = ["object", "style"]
def __init__(self, parentApp, name, saved_args=None):
def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, saved_args: Optional[Dict[str, str]] = None):
self.saved_args = saved_args or {}
super().__init__(parentApp, name)
def afterEditing(self):
def afterEditing(self) -> None:
self.parentApp.setNextForm(None)
def create(self):
def create(self) -> None:
self.model_names, default = self.get_model_names()
default_initializer_token = ""
default_placeholder_token = ""
saved_args = self.saved_args
assert config is not None
try:
default = self.model_names.index(saved_args["model"])
except Exception:
@ -71,7 +74,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
self.model = self.add_widget_intelligent(
npyscreen.TitleSelectOne,
name="Model Name:",
values=self.model_names,
values=sorted(self.model_names),
value=default,
max_height=len(self.model_names) + 1,
scroll_exit=True,
@ -236,7 +239,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
)
self.model.editing = True
def initializer_changed(self):
def initializer_changed(self) -> None:
placeholder = self.placeholder_token.value
self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)"
self.train_data_dir.value = str(config.root_dir / TRAINING_DATA / placeholder)
@ -275,10 +278,13 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
return True
def get_model_names(self) -> Tuple[List[str], int]:
conf = OmegaConf.load(config.root_dir / "configs/models.yaml")
model_names = [idx for idx in sorted(conf.keys()) if conf[idx].get("format", None) == "diffusers"]
defaults = [idx for idx in range(len(model_names)) if "default" in conf[model_names[idx]]]
default = defaults[0] if len(defaults) > 0 else 0
global config
assert config is not None
installer = initialize_installer(config)
store = installer.record_store
main_models = store.search_by_attr(model_type=ModelType.Main)
model_names = [f"{x.base.value}/{x.type.value}/{x.name}" for x in main_models if x.format == "diffusers"]
default = 0
return (model_names, default)
def marshall_arguments(self) -> dict:
@ -326,7 +332,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
class MyApplication(npyscreen.NPSAppManaged):
def __init__(self, saved_args=None):
def __init__(self, saved_args: Optional[Dict[str, str]] = None):
super().__init__()
self.ti_arguments = None
self.saved_args = saved_args
@ -341,11 +347,12 @@ class MyApplication(npyscreen.NPSAppManaged):
)
def copy_to_embeddings_folder(args: dict):
def copy_to_embeddings_folder(args: Dict[str, str]) -> None:
"""
Copy learned_embeds.bin into the embeddings folder, and offer to
delete the full model and checkpoints.
"""
assert config is not None
source = Path(args["output_dir"], "learned_embeds.bin")
dest_dir_name = args["placeholder_token"].strip("<>")
destination = config.root_dir / "embeddings" / dest_dir_name
@ -358,10 +365,11 @@ def copy_to_embeddings_folder(args: dict):
logger.info(f'Keeping {args["output_dir"]}')
def save_args(args: dict):
def save_args(args: dict) -> None:
"""
Save the current argument values to an omegaconf file
"""
assert config is not None
dest_dir = config.root_dir / TRAINING_DIR
os.makedirs(dest_dir, exist_ok=True)
conf_file = dest_dir / CONF_FILE
@ -373,6 +381,7 @@ def previous_args() -> dict:
"""
Get the previous arguments used.
"""
assert config is not None
conf_file = config.root_dir / TRAINING_DIR / CONF_FILE
try:
conf = OmegaConf.load(conf_file)
@ -383,24 +392,26 @@ def previous_args() -> dict:
return conf
def do_front_end(args: Namespace):
def do_front_end() -> None:
global config
saved_args = previous_args()
myapplication = MyApplication(saved_args=saved_args)
myapplication.run()
if args := myapplication.ti_arguments:
os.makedirs(args["output_dir"], exist_ok=True)
if my_args := myapplication.ti_arguments:
os.makedirs(my_args["output_dir"], exist_ok=True)
# Automatically add angle brackets around the trigger
if not re.match("^<.+>$", args["placeholder_token"]):
args["placeholder_token"] = f"<{args['placeholder_token']}>"
if not re.match("^<.+>$", my_args["placeholder_token"]):
my_args["placeholder_token"] = f"<{my_args['placeholder_token']}>"
args["only_save_embeds"] = True
save_args(args)
my_args["only_save_embeds"] = True
save_args(my_args)
try:
do_textual_inversion_training(InvokeAIAppConfig.get_config(), **args)
copy_to_embeddings_folder(args)
print(my_args)
do_textual_inversion_training(config, **my_args)
copy_to_embeddings_folder(my_args)
except Exception as e:
logger.error("An exception occurred during training. The exception was:")
logger.error(str(e))
@ -408,11 +419,12 @@ def do_front_end(args: Namespace):
logger.error(traceback.format_exc())
def main():
def main() -> None:
global config
args = parse_args()
args: Namespace = parse_args()
config = InvokeAIAppConfig.get_config()
config.parse_args([])
# change root if needed
if args.root_dir:
@ -420,7 +432,7 @@ def main():
try:
if args.front_end:
do_front_end(args)
do_front_end()
else:
do_textual_inversion_training(config, **vars(args))
except AssertionError as e:

View File

@ -0,0 +1,454 @@
#!/usr/bin/env python
"""
This is the frontend to "textual_inversion_training.py".
Copyright (c) 2023-24 Lincoln Stein and the InvokeAI Development Team
"""
import os
import re
import shutil
import sys
import traceback
from argparse import Namespace
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import npyscreen
from npyscreen import widget
from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.install.install_helper import initialize_installer
from invokeai.backend.model_manager import ModelType
from invokeai.backend.training import do_textual_inversion_training, parse_args
# Directory under the InvokeAI root used as the default location of training images.
TRAINING_DATA = "text-inversion-training-data"
# Directory under the InvokeAI root used for training output and the saved preferences file.
TRAINING_DIR = "text-inversion-output"
# File name, inside TRAINING_DIR, in which the last-used form values are stored.
CONF_FILE = "preferences.conf"
# Application configuration; None until main() assigns it.
config = None
class textualInversionForm(npyscreen.FormMultiPageAction):
    """Form that gathers the settings for a textual-inversion training run.

    Widget defaults are taken from ``saved_args`` where present. When the
    user confirms with OK, the collected values are stored on the parent
    application as ``ti_arguments``.
    """

    resolutions = [512, 768, 1024]
    lr_schedulers = [
        "linear",
        "cosine",
        "cosine_with_restarts",
        "polynomial",
        "constant",
        "constant_with_warmup",
    ]
    precisions = ["no", "fp16", "bf16"]
    learnable_properties = ["object", "style"]

    def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, saved_args: Optional[Dict[str, str]] = None):
        """Remember the saved arguments, then let npyscreen build the form."""
        # Assigned before the base constructor runs; presumably npyscreen calls
        # create() from there and create() reads saved_args -- TODO confirm.
        self.saved_args = saved_args or {}
        super().__init__(parentApp, name)

    @staticmethod
    def _choice_index(choices: list, wanted: object, fallback: object) -> int:
        """Return the position of ``wanted`` in ``choices``, or that of ``fallback`` if it is absent.

        Guards against a stale saved preference (a value that is no longer
        one of the offered choices) aborting construction of the form.
        """
        try:
            return choices.index(wanted)
        except ValueError:
            return choices.index(fallback)

    def afterEditing(self) -> None:
        """Tell the application that no further form follows this one."""
        self.parentApp.setNextForm(None)

    def create(self) -> None:
        """Build every widget of the form, seeding values from the saved arguments."""
        # model_names is sorted by get_model_names(); the widget, the saved
        # default and marshall_arguments() must all index this same list.
        self.model_names, default = self.get_model_names()
        default_initializer_token = ""
        default_placeholder_token = ""
        saved_args = self.saved_args

        assert config is not None

        try:
            default = self.model_names.index(saved_args["model"])
        except Exception:
            # No saved model, or it is no longer installed: keep the default index.
            pass

        self.add_widget_intelligent(
            npyscreen.FixedText,
            value="Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields, cursor arrows to make a selection, and space to toggle checkboxes.",
            editable=False,
        )

        self.model = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Model Name:",
            values=self.model_names,
            value=default,
            max_height=len(self.model_names) + 1,
            scroll_exit=True,
        )
        self.placeholder_token = self.add_widget_intelligent(
            npyscreen.TitleText,
            name="Trigger Term:",
            value="",  # saved_args.get('placeholder_token',''), # to restore previous term
            scroll_exit=True,
        )
        self.placeholder_token.when_value_edited = self.initializer_changed
        # Place the hint text on the same row as the trigger term, 30 columns to the right.
        self.nextrely -= 1
        self.nextrelx += 30
        self.prompt_token = self.add_widget_intelligent(
            npyscreen.FixedText,
            name="Trigger term for use in prompt",
            value="",
            editable=False,
            scroll_exit=True,
        )
        self.nextrelx -= 30
        self.initializer_token = self.add_widget_intelligent(
            npyscreen.TitleText,
            name="Initializer:",
            value=saved_args.get("initializer_token", default_initializer_token),
            scroll_exit=True,
        )
        self.resume_from_checkpoint = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Resume from last saved checkpoint",
            value=False,
            scroll_exit=True,
        )
        self.learnable_property = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Learnable property:",
            values=self.learnable_properties,
            value=self._choice_index(
                self.learnable_properties, saved_args.get("learnable_property", "object"), "object"
            ),
            max_height=4,
            scroll_exit=True,
        )
        self.train_data_dir = self.add_widget_intelligent(
            npyscreen.TitleFilename,
            name="Data Training Directory:",
            select_dir=True,
            must_exist=False,
            value=str(
                saved_args.get(
                    "train_data_dir",
                    config.root_dir / TRAINING_DATA / default_placeholder_token,
                )
            ),
            scroll_exit=True,
        )
        self.output_dir = self.add_widget_intelligent(
            npyscreen.TitleFilename,
            name="Output Destination Directory:",
            select_dir=True,
            must_exist=False,
            value=str(
                saved_args.get(
                    "output_dir",
                    config.root_dir / TRAINING_DIR / default_placeholder_token,
                )
            ),
            scroll_exit=True,
        )
        self.resolution = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Image resolution (pixels):",
            values=self.resolutions,
            value=self._choice_index(self.resolutions, saved_args.get("resolution", 512), 512),
            max_height=4,
            scroll_exit=True,
        )
        self.center_crop = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Center crop images before resizing to resolution",
            value=saved_args.get("center_crop", False),
            scroll_exit=True,
        )
        self.mixed_precision = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Mixed Precision:",
            values=self.precisions,
            value=self._choice_index(self.precisions, saved_args.get("mixed_precision", "fp16"), "fp16"),
            max_height=4,
            scroll_exit=True,
        )
        self.num_train_epochs = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Number of training epochs:",
            out_of=1000,
            step=50,
            lowest=1,
            value=saved_args.get("num_train_epochs", 100),
            scroll_exit=True,
        )
        self.max_train_steps = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Max Training Steps:",
            out_of=10000,
            step=500,
            lowest=1,
            value=saved_args.get("max_train_steps", 3000),
            scroll_exit=True,
        )
        self.train_batch_size = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Batch Size (reduce if you run out of memory):",
            out_of=50,
            step=1,
            lowest=1,
            value=saved_args.get("train_batch_size", 8),
            scroll_exit=True,
        )
        self.gradient_accumulation_steps = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Gradient Accumulation Steps (may need to decrease this to resume from a checkpoint):",
            out_of=10,
            step=1,
            lowest=1,
            value=saved_args.get("gradient_accumulation_steps", 4),
            scroll_exit=True,
        )
        self.lr_warmup_steps = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Warmup Steps:",
            out_of=100,
            step=1,
            lowest=0,
            value=saved_args.get("lr_warmup_steps", 0),
            scroll_exit=True,
        )
        self.learning_rate = self.add_widget_intelligent(
            npyscreen.TitleText,
            name="Learning Rate:",
            value=str(
                saved_args.get("learning_rate", "5.0e-04"),
            ),
            scroll_exit=True,
        )
        self.scale_lr = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Scale learning rate by number GPUs, steps and batch size",
            value=saved_args.get("scale_lr", True),
            scroll_exit=True,
        )
        self.enable_xformers_memory_efficient_attention = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Use xformers acceleration",
            value=saved_args.get("enable_xformers_memory_efficient_attention", False),
            scroll_exit=True,
        )
        self.lr_scheduler = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Learning rate scheduler:",
            values=self.lr_schedulers,
            max_height=7,
            value=self._choice_index(self.lr_schedulers, saved_args.get("lr_scheduler", "constant"), "constant"),
            scroll_exit=True,
        )
        self.model.editing = True

    def initializer_changed(self) -> None:
        """Derive the prompt hint and both directory fields from the current trigger term."""
        placeholder = self.placeholder_token.value
        self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)"
        self.train_data_dir.value = str(config.root_dir / TRAINING_DATA / placeholder)
        self.output_dir.value = str(config.root_dir / TRAINING_DIR / placeholder)
        # Offer to resume only when an output directory for this trigger already exists.
        self.resume_from_checkpoint.value = Path(self.output_dir.value).exists()

    def on_ok(self) -> None:
        """Validate the form; on success store the arguments on the app and close."""
        if self.validate_field_values():
            self.parentApp.setNextForm(None)
            self.editing = False
            self.parentApp.ti_arguments = self.marshall_arguments()
            npyscreen.notify("Launching textual inversion training. This will take a while...")
        else:
            self.editing = True

    def on_cancel(self) -> None:
        """Exit the program with status 0 when the user cancels the form."""
        sys.exit(0)

    # The handler was previously named ``ok_cancel``; presumably npyscreen never
    # called it under that name. The old name is kept as an alias for callers.
    ok_cancel = on_cancel

    def validate_field_values(self) -> bool:
        """Return True if the fields are acceptable; otherwise show the problems and return False."""
        bad_fields = []
        if self.model.value is None:
            bad_fields.append("Model Name must correspond to a known model in models.yaml")
        if not re.match("^[a-zA-Z0-9.-]+$", self.placeholder_token.value):
            bad_fields.append("Trigger term must only contain alphanumeric characters, the dot and hyphen")
        if self.train_data_dir.value is None:
            bad_fields.append("Data Training Directory cannot be empty")
        if self.output_dir.value is None:
            bad_fields.append("The Output Destination Directory cannot be empty")
        if len(bad_fields) > 0:
            message = "The following problems were detected and must be corrected:"
            for problem in bad_fields:
                message += f"\n* {problem}"
            npyscreen.notify_confirm(message)
            return False
        else:
            return True

    def get_model_names(self) -> Tuple[List[str], int]:
        """Return the sorted ``base/type/name`` labels of the diffusers main models and a default index.

        The list is sorted here, once, so that the index chosen in the widget
        always refers to the same entry that marshall_arguments() looks up.
        """
        assert config is not None
        installer = initialize_installer(config)
        store = installer.record_store
        main_models = store.search_by_attr(model_type=ModelType.Main)
        model_names = sorted(
            f"{x.base.value}/{x.type.value}/{x.name}" for x in main_models if x.format == "diffusers"
        )
        default = 0
        return (model_names, default)

    def marshall_arguments(self) -> dict:
        """Collect every widget value into the argument dict stored as ``ti_arguments``."""
        args = {}

        # the choices
        args.update(
            model=self.model_names[self.model.value[0]],
            resolution=self.resolutions[self.resolution.value[0]],
            lr_scheduler=self.lr_schedulers[self.lr_scheduler.value[0]],
            mixed_precision=self.precisions[self.mixed_precision.value[0]],
            learnable_property=self.learnable_properties[self.learnable_property.value[0]],
        )

        # all the strings and booleans
        for attr in (
            "initializer_token",
            "placeholder_token",
            "train_data_dir",
            "output_dir",
            "scale_lr",
            "center_crop",
            "enable_xformers_memory_efficient_attention",
        ):
            args[attr] = getattr(self, attr).value

        # all the integers
        for attr in (
            "train_batch_size",
            "gradient_accumulation_steps",
            "num_train_epochs",
            "max_train_steps",
            "lr_warmup_steps",
        ):
            args[attr] = int(getattr(self, attr).value)

        # the floats (just one)
        args.update(learning_rate=float(self.learning_rate.value))

        # a special case
        if self.resume_from_checkpoint.value and Path(self.output_dir.value).exists():
            args["resume_from_checkpoint"] = "latest"

        return args
class MyApplication(npyscreen.NPSAppManaged):
    """npyscreen application hosting the textual-inversion settings form."""

    def __init__(self, saved_args: Optional[Dict[str, str]] = None):
        """Create the application; ``ti_arguments`` stays None until the form is confirmed."""
        super().__init__()
        self.ti_arguments = None
        self.saved_args = saved_args

    def onStart(self):
        """Select the colour theme and register the settings form under the "MAIN" id."""
        npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
        settings_form = self.addForm(
            "MAIN",
            textualInversionForm,
            name="Textual Inversion Settings",
            saved_args=self.saved_args,
        )
        self.main = settings_form
def copy_to_embeddings_folder(args: Dict[str, str]) -> None:
    """
    Copy learned_embeds.bin into the embeddings folder, and offer to
    delete the full model and checkpoints.

    The destination is ``<root>/embeddings/<trigger>``, where the trigger is
    ``args["placeholder_token"]`` with any angle brackets stripped.
    """
    assert config is not None
    training_output = Path(args["output_dir"])
    learned_embeds = training_output / "learned_embeds.bin"
    trigger = args["placeholder_token"].strip("<>")
    destination = config.root_dir / "embeddings" / trigger

    os.makedirs(destination, exist_ok=True)
    logger.info(f"Training completed. Copying learned_embeds.bin into {str(destination)}")
    shutil.copy(learned_embeds, destination)

    # An empty reply counts as "y".
    reply = input("Delete training logs and intermediate checkpoints? [y] ") or "y"
    if reply[0] in ("y", "Y"):
        shutil.rmtree(training_output)
    else:
        logger.info(f'Keeping {args["output_dir"]}')
def save_args(args: dict) -> None:
    """
    Save the current argument values to an omegaconf file

    The file is ``<root>/TRAINING_DIR/CONF_FILE``; the directory is created if needed.
    """
    assert config is not None
    prefs_dir = config.root_dir / TRAINING_DIR
    os.makedirs(prefs_dir, exist_ok=True)
    prefs_file = prefs_dir / CONF_FILE
    OmegaConf.save(config=OmegaConf.create(args), f=prefs_file)
def previous_args() -> dict:
    """
    Get the previous arguments used.

    Returns the loaded preferences with the angle brackets stripped from
    ``placeholder_token``, or None if loading or that adjustment fails.
    """
    assert config is not None
    prefs_file = config.root_dir / TRAINING_DIR / CONF_FILE
    try:
        loaded = OmegaConf.load(prefs_file)
        loaded["placeholder_token"] = loaded["placeholder_token"].strip("<>")
    except Exception:
        # Best effort: any failure is treated as "no saved preferences".
        return None
    return loaded
def do_front_end() -> None:
    """Run the settings form and, if the user confirmed it, launch training.

    The confirmed arguments are saved as the new preferences before training
    starts. Training errors are logged with a traceback and not re-raised.
    """
    # Same guard as the sibling helpers; config is only read here, so no
    # ``global`` declaration is needed.
    assert config is not None

    saved_args = previous_args()
    myapplication = MyApplication(saved_args=saved_args)
    myapplication.run()

    if my_args := myapplication.ti_arguments:
        os.makedirs(my_args["output_dir"], exist_ok=True)

        # Automatically add angle brackets around the trigger
        if not re.match("^<.+>$", my_args["placeholder_token"]):
            my_args["placeholder_token"] = f"<{my_args['placeholder_token']}>"

        my_args["only_save_embeds"] = True
        save_args(my_args)

        try:
            # The leftover debug print of the raw argument dict was removed;
            # this function reports through ``logger`` everywhere else.
            do_textual_inversion_training(config, **my_args)
            copy_to_embeddings_folder(my_args)
        except Exception as e:
            logger.error("An exception occurred during training. The exception was:")
            logger.error(str(e))
            logger.error("DETAILS:")
            logger.error(traceback.format_exc())
def main() -> None:
    """Entry point: set up the configuration, then run the form or train directly from the CLI arguments."""
    global config

    args: Namespace = parse_args()
    config = InvokeAIAppConfig.get_config()
    config.parse_args([])

    # change root if needed
    if args.root_dir:
        config.root = args.root_dir

    try:
        if not args.front_end:
            do_textual_inversion_training(config, **vars(args))
        else:
            do_front_end()
    except AssertionError as e:
        logger.error(e)
        sys.exit(-1)
    except KeyboardInterrupt:
        pass
    except (widget.NotEnoughSpaceForWidget, Exception) as e:
        # Map the two known failure texts to friendlier messages; log anything else as-is.
        detail = str(e)
        if detail.startswith("Height of 1 allocated"):
            logger.error("You need to have at least one diffusers models defined in models.yaml in order to train")
        elif detail.startswith("addwstr"):
            logger.error("Not enough window space for the interface. Please make your window larger and try again.")
        else:
            logger.error(e)
        sys.exit(-1)


if __name__ == "__main__":
    main()

View File

@ -137,8 +137,10 @@ dependencies = [
# full commands
"invokeai-configure" = "invokeai.frontend.install.invokeai_configure:invokeai_configure"
"invokeai-merge" = "invokeai.frontend.merge:invokeai_merge_diffusers"
"invokeai-merge2" = "invokeai.frontend.merge.merge_diffusers2:main"
"invokeai-ti" = "invokeai.frontend.training:invokeai_textual_inversion"
"invokeai-model-install" = "invokeai.frontend.install.model_install:main"
"invokeai-model-install2" = "invokeai.frontend.install.model_install2:main" # will eventually be renamed to invokeai-model-install
"invokeai-migrate3" = "invokeai.backend.install.migrate_to_3:main"
"invokeai-update" = "invokeai.frontend.install.invokeai_update:main"
"invokeai-metadata" = "invokeai.backend.image_util.invoke_metadata:main"