mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Port the command-line tools to use model_manager2 (#5546)
* Port the command-line tools to use model_manager2 1.Reimplement the following: - invokeai-model-install - invokeai-merge - invokeai-ti To avoid breaking the original modeal manager, the udpated tools have been renamed invokeai-model-install2 and invokeai-merge2. The textual inversion training script should continue to work with existing installations. The "starter" models now live in `invokeai/configs/INITIAL_MODELS2.yaml`. When the full model manager 2 is in place and working, I'll rename these files and commands. 2. Add the `merge` route to the web API. This will merge two or three models, resulting a new one. - Note that because the model installer selectively installs the `fp16` variant of models (rather than both 16- and 32-bit versions as previous), the diffusers merge script will choke on any huggingface diffuserse models that were downloaded with the new installer. Previously-downloaded models should continue to merge correctly. I have a PR upstream https://github.com/huggingface/diffusers/pull/6670 to fix this. 3. (more important!) During implementation of the CLI tools, found and fixed a number of small runtime bugs in the model_manager2 implementation: - During model database migration, if a registered models file was not found on disk, the migration would be aborted. Now the offending model is skipped with a log warning. - Caught and fixed a condition in which the installer would download the entire diffusers repo when the user provided a single `.safetensors` file URL. - Caught and fixed a condition in which the installer would raise an exception and stop the app when a request for an unknown model's metadata was passed to Civitai. Now an error is logged and the installer continues. - Replaced the LoWRA starter LoRA with FlatColor. The former has been removed from Civitai. * fix ruff issue --------- Co-authored-by: Lincoln Stein <lstein@gmail.com>
This commit is contained in:
parent
d3320dc4ee
commit
f2777f5096
@ -1,7 +1,7 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein
|
# Copyright (c) 2023 Lincoln D. Stein
|
||||||
"""FastAPI route for model configuration records."""
|
"""FastAPI route for model configuration records."""
|
||||||
|
|
||||||
|
import pathlib
|
||||||
from hashlib import sha1
|
from hashlib import sha1
|
||||||
from random import randbytes
|
from random import randbytes
|
||||||
from typing import Any, Dict, List, Optional, Set
|
from typing import Any, Dict, List, Optional, Set
|
||||||
@ -27,6 +27,7 @@ from invokeai.backend.model_manager.config import (
|
|||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelType,
|
ModelType,
|
||||||
)
|
)
|
||||||
|
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
@ -415,3 +416,57 @@ async def sync_models_to_config() -> Response:
|
|||||||
"""
|
"""
|
||||||
ApiDependencies.invoker.services.model_install.sync_to_config()
|
ApiDependencies.invoker.services.model_install.sync_to_config()
|
||||||
return Response(status_code=204)
|
return Response(status_code=204)
|
||||||
|
|
||||||
|
|
||||||
|
@model_records_router.put(
|
||||||
|
"/merge",
|
||||||
|
operation_id="merge",
|
||||||
|
)
|
||||||
|
async def merge(
|
||||||
|
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
|
||||||
|
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
|
||||||
|
alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||||
|
force: bool = Body(
|
||||||
|
description="Force merging of models created with different versions of diffusers",
|
||||||
|
default=False,
|
||||||
|
),
|
||||||
|
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
|
||||||
|
merge_dest_directory: Optional[str] = Body(
|
||||||
|
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||||
|
default=None,
|
||||||
|
),
|
||||||
|
) -> AnyModelConfig:
|
||||||
|
"""
|
||||||
|
Merge diffusers models.
|
||||||
|
|
||||||
|
keys: List of 2-3 model keys to merge together. All models must use the same base type.
|
||||||
|
merged_model_name: Name for the merged model [Concat model names]
|
||||||
|
alpha: Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
|
||||||
|
force: If true, force the merge even if the models were generated by different versions of the diffusers library [False]
|
||||||
|
interp: Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
|
||||||
|
merge_dest_directory: Specify a directory to store the merged model in [models directory]
|
||||||
|
"""
|
||||||
|
print(f"here i am, keys={keys}")
|
||||||
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
try:
|
||||||
|
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||||
|
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||||
|
installer = ApiDependencies.invoker.services.model_install
|
||||||
|
merger = ModelMerger(installer)
|
||||||
|
model_names = [installer.record_store.get_model(x).name for x in keys]
|
||||||
|
response = merger.merge_diffusion_models_and_save(
|
||||||
|
model_keys=keys,
|
||||||
|
merged_model_name=merged_model_name or "+".join(model_names),
|
||||||
|
alpha=alpha,
|
||||||
|
interp=interp,
|
||||||
|
force=force,
|
||||||
|
merge_dest_directory=dest,
|
||||||
|
)
|
||||||
|
except UnknownModelException:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"One or more of the models '{keys}' not found",
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
return response
|
||||||
|
@ -208,7 +208,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
job = self._queue.get(timeout=1)
|
job = self._queue.get(timeout=1)
|
||||||
except Empty:
|
except Empty:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
job.job_started = get_iso_timestamp()
|
job.job_started = get_iso_timestamp()
|
||||||
self._do_download(job)
|
self._do_download(job)
|
||||||
|
@ -165,8 +165,8 @@ class ModelInstallJob(BaseModel):
|
|||||||
)
|
)
|
||||||
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
|
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
|
||||||
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
|
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
|
||||||
bytes: Optional[int] = Field(
|
bytes: int = Field(
|
||||||
default=None, description="For a remote model, the number of bytes downloaded so far (may not be available)"
|
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
|
||||||
)
|
)
|
||||||
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
|
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
|
||||||
source_metadata: Optional[AnyModelRepoMetadata] = Field(
|
source_metadata: Optional[AnyModelRepoMetadata] = Field(
|
||||||
|
@ -535,19 +535,19 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||||
# URLs from Civitai or HuggingFace will be handled specially
|
# URLs from Civitai or HuggingFace will be handled specially
|
||||||
url_patterns = {
|
url_patterns = {
|
||||||
r"https?://civitai.com/": CivitaiMetadataFetch,
|
r"^https?://civitai.com/": CivitaiMetadataFetch,
|
||||||
r"https?://huggingface.co/": HuggingFaceMetadataFetch,
|
r"^https?://huggingface.co/[^/]+/[^/]+$": HuggingFaceMetadataFetch,
|
||||||
}
|
}
|
||||||
metadata = None
|
metadata = None
|
||||||
for pattern, fetcher in url_patterns.items():
|
for pattern, fetcher in url_patterns.items():
|
||||||
if re.match(pattern, str(source.url), re.IGNORECASE):
|
if re.match(pattern, str(source.url), re.IGNORECASE):
|
||||||
metadata = fetcher(self._session).from_url(source.url)
|
metadata = fetcher(self._session).from_url(source.url)
|
||||||
break
|
break
|
||||||
|
self._logger.debug(f"metadata={metadata}")
|
||||||
if metadata and isinstance(metadata, ModelMetadataWithFiles):
|
if metadata and isinstance(metadata, ModelMetadataWithFiles):
|
||||||
remote_files = metadata.download_urls(session=self._session)
|
remote_files = metadata.download_urls(session=self._session)
|
||||||
else:
|
else:
|
||||||
remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
|
remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
|
||||||
|
|
||||||
return self._import_remote_model(
|
return self._import_remote_model(
|
||||||
source=source,
|
source=source,
|
||||||
config=config,
|
config=config,
|
||||||
@ -586,6 +586,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below
|
assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below
|
||||||
|
|
||||||
self._logger.info(f"Queuing {source} for downloading")
|
self._logger.info(f"Queuing {source} for downloading")
|
||||||
|
self._logger.debug(f"remote_files={remote_files}")
|
||||||
for model_file in remote_files:
|
for model_file in remote_files:
|
||||||
url = model_file.url
|
url = model_file.url
|
||||||
path = model_file.path
|
path = model_file.path
|
||||||
|
@ -72,7 +72,12 @@ class MigrateModelYamlToDb1:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
base_type, model_type, model_name = str(model_key).split("/")
|
base_type, model_type, model_name = str(model_key).split("/")
|
||||||
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
try:
|
||||||
|
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
||||||
|
except OSError:
|
||||||
|
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
|
||||||
|
continue
|
||||||
|
|
||||||
assert isinstance(model_key, str)
|
assert isinstance(model_key, str)
|
||||||
new_key = sha1(model_key.encode("utf-8")).hexdigest()
|
new_key = sha1(model_key.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
281
invokeai/backend/install/install_helper.py
Normal file
281
invokeai/backend/install/install_helper.py
Normal file
@ -0,0 +1,281 @@
|
|||||||
|
"""Utility (backend) functions used by model_install.py"""
|
||||||
|
import re
|
||||||
|
from logging import Logger
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import omegaconf
|
||||||
|
from huggingface_hub import HfFolder
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from pydantic.dataclasses import dataclass
|
||||||
|
from pydantic.networks import AnyHttpUrl
|
||||||
|
from requests import HTTPError
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import invokeai.configs as configs
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
from invokeai.app.services.download import DownloadQueueService
|
||||||
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
|
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
|
||||||
|
from invokeai.app.services.model_install import (
|
||||||
|
HFModelSource,
|
||||||
|
LocalModelSource,
|
||||||
|
ModelInstallService,
|
||||||
|
ModelInstallServiceBase,
|
||||||
|
ModelSource,
|
||||||
|
URLModelSource,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||||
|
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||||
|
from invokeai.backend.model_manager import (
|
||||||
|
BaseModelType,
|
||||||
|
InvalidModelConfigException,
|
||||||
|
ModelType,
|
||||||
|
)
|
||||||
|
from invokeai.backend.model_manager.metadata import UnknownMetadataException
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
|
# name of the starter models file
|
||||||
|
INITIAL_MODELS = "INITIAL_MODELS2.yaml"
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
|
||||||
|
"""Return an initialized ModelConfigRecordServiceBase object."""
|
||||||
|
logger = InvokeAILogger.get_logger(config=app_config)
|
||||||
|
image_files = DiskImageFileStorage(f"{app_config.output_path}/images")
|
||||||
|
db = init_db(config=app_config, logger=logger, image_files=image_files)
|
||||||
|
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_installer(
|
||||||
|
app_config: InvokeAIAppConfig, event_bus: Optional[EventServiceBase] = None
|
||||||
|
) -> ModelInstallServiceBase:
|
||||||
|
"""Return an initialized ModelInstallService object."""
|
||||||
|
record_store = initialize_record_store(app_config)
|
||||||
|
metadata_store = record_store.metadata_store
|
||||||
|
download_queue = DownloadQueueService()
|
||||||
|
installer = ModelInstallService(
|
||||||
|
app_config=app_config,
|
||||||
|
record_store=record_store,
|
||||||
|
metadata_store=metadata_store,
|
||||||
|
download_queue=download_queue,
|
||||||
|
event_bus=event_bus,
|
||||||
|
)
|
||||||
|
download_queue.start()
|
||||||
|
installer.start()
|
||||||
|
return installer
|
||||||
|
|
||||||
|
|
||||||
|
class UnifiedModelInfo(BaseModel):
|
||||||
|
"""Catchall class for information in INITIAL_MODELS2.yaml."""
|
||||||
|
|
||||||
|
name: Optional[str] = None
|
||||||
|
base: Optional[BaseModelType] = None
|
||||||
|
type: Optional[ModelType] = None
|
||||||
|
source: Optional[str] = None
|
||||||
|
subfolder: Optional[str] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
recommended: bool = False
|
||||||
|
installed: bool = False
|
||||||
|
default: bool = False
|
||||||
|
requires: List[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InstallSelections:
|
||||||
|
"""Lists of models to install and remove."""
|
||||||
|
|
||||||
|
install_models: List[UnifiedModelInfo] = Field(default_factory=list)
|
||||||
|
remove_models: List[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class TqdmEventService(EventServiceBase):
|
||||||
|
"""An event service to track downloads."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Create a new TqdmEventService object."""
|
||||||
|
super().__init__()
|
||||||
|
self._bars: Dict[str, tqdm] = {}
|
||||||
|
self._last: Dict[str, int] = {}
|
||||||
|
|
||||||
|
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||||
|
"""Dispatch an event by appending it to self.events."""
|
||||||
|
if payload["event"] == "model_install_downloading":
|
||||||
|
data = payload["data"]
|
||||||
|
dest = data["local_path"]
|
||||||
|
total_bytes = data["total_bytes"]
|
||||||
|
bytes = data["bytes"]
|
||||||
|
if dest not in self._bars:
|
||||||
|
self._bars[dest] = tqdm(desc=Path(dest).name, initial=0, total=total_bytes, unit="iB", unit_scale=True)
|
||||||
|
self._last[dest] = 0
|
||||||
|
self._bars[dest].update(bytes - self._last[dest])
|
||||||
|
self._last[dest] = bytes
|
||||||
|
|
||||||
|
|
||||||
|
class InstallHelper(object):
|
||||||
|
"""Capture information stored jointly in INITIAL_MODELS.yaml and the installed models db."""
|
||||||
|
|
||||||
|
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger):
|
||||||
|
"""Create new InstallHelper object."""
|
||||||
|
self._app_config = app_config
|
||||||
|
self.all_models: Dict[str, UnifiedModelInfo] = {}
|
||||||
|
|
||||||
|
omega = omegaconf.OmegaConf.load(Path(configs.__path__[0]) / INITIAL_MODELS)
|
||||||
|
assert isinstance(omega, omegaconf.dictconfig.DictConfig)
|
||||||
|
|
||||||
|
self._installer = initialize_installer(app_config, TqdmEventService())
|
||||||
|
self._initial_models = omega
|
||||||
|
self._installed_models: List[str] = []
|
||||||
|
self._starter_models: List[str] = []
|
||||||
|
self._default_model: Optional[str] = None
|
||||||
|
self._logger = logger
|
||||||
|
self._initialize_model_lists()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def installer(self) -> ModelInstallServiceBase:
|
||||||
|
"""Return the installer object used internally."""
|
||||||
|
return self._installer
|
||||||
|
|
||||||
|
def _initialize_model_lists(self) -> None:
|
||||||
|
"""
|
||||||
|
Initialize our model slots.
|
||||||
|
|
||||||
|
Set up the following:
|
||||||
|
installed_models -- list of installed model keys
|
||||||
|
starter_models -- list of starter model keys from INITIAL_MODELS
|
||||||
|
all_models -- dict of key => UnifiedModelInfo
|
||||||
|
default_model -- key to default model
|
||||||
|
"""
|
||||||
|
# previously-installed models
|
||||||
|
for model in self._installer.record_store.all_models():
|
||||||
|
info = UnifiedModelInfo.parse_obj(model.dict())
|
||||||
|
info.installed = True
|
||||||
|
model_key = f"{model.base.value}/{model.type.value}/{model.name}"
|
||||||
|
self.all_models[model_key] = info
|
||||||
|
self._installed_models.append(model_key)
|
||||||
|
|
||||||
|
for key in self._initial_models.keys():
|
||||||
|
assert isinstance(key, str)
|
||||||
|
if key in self.all_models:
|
||||||
|
# we want to preserve the description
|
||||||
|
description = self.all_models[key].description or self._initial_models[key].get("description")
|
||||||
|
self.all_models[key].description = description
|
||||||
|
else:
|
||||||
|
base_model, model_type, model_name = key.split("/")
|
||||||
|
info = UnifiedModelInfo(
|
||||||
|
name=model_name,
|
||||||
|
type=ModelType(model_type),
|
||||||
|
base=BaseModelType(base_model),
|
||||||
|
source=self._initial_models[key].source,
|
||||||
|
description=self._initial_models[key].get("description"),
|
||||||
|
recommended=self._initial_models[key].get("recommended", False),
|
||||||
|
default=self._initial_models[key].get("default", False),
|
||||||
|
subfolder=self._initial_models[key].get("subfolder"),
|
||||||
|
requires=list(self._initial_models[key].get("requires", [])),
|
||||||
|
)
|
||||||
|
self.all_models[key] = info
|
||||||
|
if not self.default_model():
|
||||||
|
self._default_model = key
|
||||||
|
elif self._initial_models[key].get("default", False):
|
||||||
|
self._default_model = key
|
||||||
|
self._starter_models.append(key)
|
||||||
|
|
||||||
|
# previously-installed models
|
||||||
|
for model in self._installer.record_store.all_models():
|
||||||
|
info = UnifiedModelInfo.parse_obj(model.dict())
|
||||||
|
info.installed = True
|
||||||
|
model_key = f"{model.base.value}/{model.type.value}/{model.name}"
|
||||||
|
self.all_models[model_key] = info
|
||||||
|
self._installed_models.append(model_key)
|
||||||
|
|
||||||
|
def recommended_models(self) -> List[UnifiedModelInfo]:
|
||||||
|
"""List of the models recommended in INITIAL_MODELS.yaml."""
|
||||||
|
return [self._to_model(x) for x in self._starter_models if self._to_model(x).recommended]
|
||||||
|
|
||||||
|
def installed_models(self) -> List[UnifiedModelInfo]:
|
||||||
|
"""List of models already installed."""
|
||||||
|
return [self._to_model(x) for x in self._installed_models]
|
||||||
|
|
||||||
|
def starter_models(self) -> List[UnifiedModelInfo]:
|
||||||
|
"""List of starter models."""
|
||||||
|
return [self._to_model(x) for x in self._starter_models]
|
||||||
|
|
||||||
|
def default_model(self) -> Optional[UnifiedModelInfo]:
|
||||||
|
"""Return the default model."""
|
||||||
|
return self._to_model(self._default_model) if self._default_model else None
|
||||||
|
|
||||||
|
def _to_model(self, key: str) -> UnifiedModelInfo:
|
||||||
|
return self.all_models[key]
|
||||||
|
|
||||||
|
def _add_required_models(self, model_list: List[UnifiedModelInfo]) -> None:
|
||||||
|
installed = {x.source for x in self.installed_models()}
|
||||||
|
reverse_source = {x.source: x for x in self.all_models.values()}
|
||||||
|
additional_models: List[UnifiedModelInfo] = []
|
||||||
|
for model_info in model_list:
|
||||||
|
for requirement in model_info.requires:
|
||||||
|
if requirement not in installed and reverse_source.get(requirement):
|
||||||
|
additional_models.append(reverse_source[requirement])
|
||||||
|
model_list.extend(additional_models)
|
||||||
|
|
||||||
|
def _make_install_source(self, model_info: UnifiedModelInfo) -> ModelSource:
|
||||||
|
assert model_info.source
|
||||||
|
model_path_id_or_url = model_info.source.strip("\"' ")
|
||||||
|
model_path = Path(model_path_id_or_url)
|
||||||
|
|
||||||
|
if model_path.exists(): # local file on disk
|
||||||
|
return LocalModelSource(path=model_path.absolute(), inplace=True)
|
||||||
|
if re.match(r"^[^/]+/[^/]+$", model_path_id_or_url): # hugging face repo_id
|
||||||
|
return HFModelSource(
|
||||||
|
repo_id=model_path_id_or_url,
|
||||||
|
access_token=HfFolder.get_token(),
|
||||||
|
subfolder=model_info.subfolder,
|
||||||
|
)
|
||||||
|
if re.match(r"^(http|https):", model_path_id_or_url):
|
||||||
|
return URLModelSource(url=AnyHttpUrl(model_path_id_or_url))
|
||||||
|
raise ValueError(f"Unsupported model source: {model_path_id_or_url}")
|
||||||
|
|
||||||
|
def add_or_delete(self, selections: InstallSelections) -> None:
|
||||||
|
"""Add or delete selected models."""
|
||||||
|
installer = self._installer
|
||||||
|
self._add_required_models(selections.install_models)
|
||||||
|
for model in selections.install_models:
|
||||||
|
source = self._make_install_source(model)
|
||||||
|
config = (
|
||||||
|
{
|
||||||
|
"description": model.description,
|
||||||
|
"name": model.name,
|
||||||
|
}
|
||||||
|
if model.name
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
installer.import_model(
|
||||||
|
source=source,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
except (UnknownMetadataException, InvalidModelConfigException, HTTPError, OSError) as e:
|
||||||
|
self._logger.warning(f"{source}: {e}")
|
||||||
|
|
||||||
|
for model_to_remove in selections.remove_models:
|
||||||
|
parts = model_to_remove.split("/")
|
||||||
|
if len(parts) == 1:
|
||||||
|
base_model, model_type, model_name = (None, None, model_to_remove)
|
||||||
|
else:
|
||||||
|
base_model, model_type, model_name = parts
|
||||||
|
matches = installer.record_store.search_by_attr(
|
||||||
|
base_model=BaseModelType(base_model) if base_model else None,
|
||||||
|
model_type=ModelType(model_type) if model_type else None,
|
||||||
|
model_name=model_name,
|
||||||
|
)
|
||||||
|
if len(matches) > 1:
|
||||||
|
print(f"{model} is ambiguous. Please use model_type:model_name (e.g. main:my_model) to disambiguate.")
|
||||||
|
elif not matches:
|
||||||
|
print(f"{model}: unknown model")
|
||||||
|
else:
|
||||||
|
for m in matches:
|
||||||
|
print(f"Deleting {m.type}:{m.name}")
|
||||||
|
installer.delete(m.key)
|
||||||
|
|
||||||
|
installer.wait_for_installs()
|
@ -849,7 +849,7 @@ def migrate_if_needed(opt: Namespace, root: Path) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def main():
|
def main() -> None:
|
||||||
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
|
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--skip-sd-weights",
|
"--skip-sd-weights",
|
||||||
|
177
invokeai/backend/model_manager/merge.py
Normal file
177
invokeai/backend/model_manager/merge.py
Normal file
@ -0,0 +1,177 @@
|
|||||||
|
"""
|
||||||
|
invokeai.backend.model_manager.merge exports:
|
||||||
|
merge_diffusion_models() -- combine multiple models by location and return a pipeline object
|
||||||
|
merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml
|
||||||
|
|
||||||
|
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||||
|
"""
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, List, Optional, Set
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from diffusers import AutoPipelineForText2Image
|
||||||
|
from diffusers import logging as dlogging
|
||||||
|
|
||||||
|
from invokeai.app.services.model_install import ModelInstallServiceBase
|
||||||
|
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
||||||
|
|
||||||
|
from . import (
|
||||||
|
AnyModelConfig,
|
||||||
|
BaseModelType,
|
||||||
|
ModelType,
|
||||||
|
ModelVariantType,
|
||||||
|
)
|
||||||
|
from .config import MainDiffusersConfig
|
||||||
|
|
||||||
|
|
||||||
|
class MergeInterpolationMethod(str, Enum):
|
||||||
|
WeightedSum = "weighted_sum"
|
||||||
|
Sigmoid = "sigmoid"
|
||||||
|
InvSigmoid = "inv_sigmoid"
|
||||||
|
AddDifference = "add_difference"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelMerger(object):
|
||||||
|
"""Wrapper class for model merge function."""
|
||||||
|
|
||||||
|
def __init__(self, installer: ModelInstallServiceBase):
|
||||||
|
"""
|
||||||
|
Initialize a ModelMerger object.
|
||||||
|
|
||||||
|
:param store: Underlying storage manager for the running process.
|
||||||
|
:param config: InvokeAIAppConfig object (if not provided, default will be selected).
|
||||||
|
"""
|
||||||
|
self._installer = installer
|
||||||
|
|
||||||
|
def merge_diffusion_models(
|
||||||
|
self,
|
||||||
|
model_paths: List[Path],
|
||||||
|
alpha: float = 0.5,
|
||||||
|
interp: Optional[MergeInterpolationMethod] = None,
|
||||||
|
force: bool = False,
|
||||||
|
variant: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any: # pipe.merge is an untyped function.
|
||||||
|
"""
|
||||||
|
:param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
|
||||||
|
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||||
|
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||||
|
:param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||||
|
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
|
||||||
|
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||||
|
|
||||||
|
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||||
|
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||||
|
"""
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
verbosity = dlogging.get_verbosity()
|
||||||
|
dlogging.set_verbosity_error()
|
||||||
|
dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
|
||||||
|
|
||||||
|
# Note that checkpoint_merger will not work with downloaded HuggingFace fp16 models
|
||||||
|
# until upstream https://github.com/huggingface/diffusers/pull/6670 is merged and released.
|
||||||
|
pipe = AutoPipelineForText2Image.from_pretrained(
|
||||||
|
model_paths[0],
|
||||||
|
custom_pipeline="checkpoint_merger",
|
||||||
|
torch_dtype=dtype,
|
||||||
|
variant=variant,
|
||||||
|
)
|
||||||
|
merged_pipe = pipe.merge(
|
||||||
|
pretrained_model_name_or_path_list=model_paths,
|
||||||
|
alpha=alpha,
|
||||||
|
interp=interp.value if interp else None, # diffusers API treats None as "weighted sum"
|
||||||
|
force=force,
|
||||||
|
torch_dtype=dtype,
|
||||||
|
variant=variant,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
dlogging.set_verbosity(verbosity)
|
||||||
|
return merged_pipe
|
||||||
|
|
||||||
|
def merge_diffusion_models_and_save(
|
||||||
|
self,
|
||||||
|
model_keys: List[str],
|
||||||
|
merged_model_name: str,
|
||||||
|
alpha: float = 0.5,
|
||||||
|
force: bool = False,
|
||||||
|
interp: Optional[MergeInterpolationMethod] = None,
|
||||||
|
merge_dest_directory: Optional[Path] = None,
|
||||||
|
variant: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> AnyModelConfig:
|
||||||
|
"""
|
||||||
|
:param models: up to three models, designated by their InvokeAI models.yaml model name
|
||||||
|
:param merged_model_name: name for new model
|
||||||
|
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||||
|
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||||
|
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||||
|
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
||||||
|
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||||
|
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||||
|
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||||
|
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||||
|
"""
|
||||||
|
model_paths: List[Path] = []
|
||||||
|
model_names: List[str] = []
|
||||||
|
config = self._installer.app_config
|
||||||
|
store = self._installer.record_store
|
||||||
|
base_models: Set[BaseModelType] = set()
|
||||||
|
vae = None
|
||||||
|
variant = None if self._installer.app_config.full_precision else "fp16"
|
||||||
|
|
||||||
|
assert (
|
||||||
|
len(model_keys) <= 2 or interp == MergeInterpolationMethod.AddDifference
|
||||||
|
), "When merging three models, only the 'add_difference' merge method is supported"
|
||||||
|
|
||||||
|
for key in model_keys:
|
||||||
|
info = store.get_model(key)
|
||||||
|
model_names.append(info.name)
|
||||||
|
assert isinstance(
|
||||||
|
info, MainDiffusersConfig
|
||||||
|
), f"{info.name} ({info.key}) is not a diffusers model. It must be optimized before merging"
|
||||||
|
assert info.variant == ModelVariantType(
|
||||||
|
"normal"
|
||||||
|
), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"
|
||||||
|
|
||||||
|
# pick up the first model's vae
|
||||||
|
if key == model_keys[0]:
|
||||||
|
vae = info.vae
|
||||||
|
|
||||||
|
# tally base models used
|
||||||
|
base_models.add(info.base)
|
||||||
|
model_paths.extend([config.models_path / info.path])
|
||||||
|
|
||||||
|
assert len(base_models) == 1, f"All models to merge must have same base model, but found bases {base_models}"
|
||||||
|
base_model = base_models.pop()
|
||||||
|
|
||||||
|
merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp)
|
||||||
|
merged_pipe = self.merge_diffusion_models(model_paths, alpha, merge_method, force, variant=variant, **kwargs)
|
||||||
|
dump_path = (
|
||||||
|
Path(merge_dest_directory)
|
||||||
|
if merge_dest_directory
|
||||||
|
else config.models_path / base_model.value / ModelType.Main.value
|
||||||
|
)
|
||||||
|
dump_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
dump_path = dump_path / merged_model_name
|
||||||
|
|
||||||
|
dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
|
||||||
|
merged_pipe.save_pretrained(dump_path.as_posix(), safe_serialization=True, torch_dtype=dtype, variant=variant)
|
||||||
|
|
||||||
|
# register model and get its unique key
|
||||||
|
key = self._installer.register_path(dump_path)
|
||||||
|
|
||||||
|
# update model's config
|
||||||
|
model_config = self._installer.record_store.get_model(key)
|
||||||
|
model_config.update(
|
||||||
|
{
|
||||||
|
"name": merged_model_name,
|
||||||
|
"description": f"Merge of models {', '.join(model_names)}",
|
||||||
|
"vae": vae,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self._installer.record_store.update_model(key, model_config)
|
||||||
|
return model_config
|
@ -170,6 +170,8 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
|||||||
if model_id is None:
|
if model_id is None:
|
||||||
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
|
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
|
||||||
version = self._requests.get(version_url).json()
|
version = self._requests.get(version_url).json()
|
||||||
|
if error := version.get("error"):
|
||||||
|
raise UnknownMetadataException(error)
|
||||||
model_id = version["modelId"]
|
model_id = version["modelId"]
|
||||||
|
|
||||||
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
||||||
|
@ -11,6 +11,7 @@ import logging
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
from argparse import Namespace
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@ -30,8 +31,6 @@ from diffusers.optimization import get_scheduler
|
|||||||
from diffusers.utils import check_min_version
|
from diffusers.utils import check_min_version
|
||||||
from diffusers.utils.import_utils import is_xformers_available
|
from diffusers.utils.import_utils import is_xformers_available
|
||||||
from huggingface_hub import HfFolder, Repository, whoami
|
from huggingface_hub import HfFolder, Repository, whoami
|
||||||
|
|
||||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
@ -41,8 +40,8 @@ from transformers import CLIPTextModel, CLIPTokenizer
|
|||||||
|
|
||||||
# invokeai stuff
|
# invokeai stuff
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig, PagingArgumentParser
|
from invokeai.app.services.config import InvokeAIAppConfig, PagingArgumentParser
|
||||||
from invokeai.app.services.model_manager import ModelManagerService
|
from invokeai.backend.install.install_helper import initialize_record_store
|
||||||
from invokeai.backend.model_management.models import SubModelType
|
from invokeai.backend.model_manager import BaseModelType, ModelType
|
||||||
|
|
||||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||||
PIL_INTERPOLATION = {
|
PIL_INTERPOLATION = {
|
||||||
@ -77,7 +76,7 @@ def save_progress(text_encoder, placeholder_token_id, accelerator, placeholder_t
|
|||||||
torch.save(learned_embeds_dict, save_path)
|
torch.save(learned_embeds_dict, save_path)
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args() -> Namespace:
|
||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
parser = PagingArgumentParser(description="Textual inversion training")
|
parser = PagingArgumentParser(description="Textual inversion training")
|
||||||
general_group = parser.add_argument_group("General")
|
general_group = parser.add_argument_group("General")
|
||||||
@ -444,7 +443,7 @@ class TextualInversionDataset(Dataset):
|
|||||||
self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
|
self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
|
||||||
self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
|
self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self) -> int:
|
||||||
return self._length
|
return self._length
|
||||||
|
|
||||||
def __getitem__(self, i):
|
def __getitem__(self, i):
|
||||||
@ -509,11 +508,10 @@ def do_textual_inversion_training(
|
|||||||
initializer_token: str,
|
initializer_token: str,
|
||||||
save_steps: int = 500,
|
save_steps: int = 500,
|
||||||
only_save_embeds: bool = False,
|
only_save_embeds: bool = False,
|
||||||
revision: str = None,
|
tokenizer_name: Optional[str] = None,
|
||||||
tokenizer_name: str = None,
|
|
||||||
learnable_property: str = "object",
|
learnable_property: str = "object",
|
||||||
repeats: int = 100,
|
repeats: int = 100,
|
||||||
seed: int = None,
|
seed: Optional[int] = None,
|
||||||
resolution: int = 512,
|
resolution: int = 512,
|
||||||
center_crop: bool = False,
|
center_crop: bool = False,
|
||||||
train_batch_size: int = 16,
|
train_batch_size: int = 16,
|
||||||
@ -530,18 +528,18 @@ def do_textual_inversion_training(
|
|||||||
adam_weight_decay: float = 1e-02,
|
adam_weight_decay: float = 1e-02,
|
||||||
adam_epsilon: float = 1e-08,
|
adam_epsilon: float = 1e-08,
|
||||||
push_to_hub: bool = False,
|
push_to_hub: bool = False,
|
||||||
hub_token: str = None,
|
hub_token: Optional[str] = None,
|
||||||
logging_dir: Path = Path("logs"),
|
logging_dir: Path = Path("logs"),
|
||||||
mixed_precision: str = "fp16",
|
mixed_precision: str = "fp16",
|
||||||
allow_tf32: bool = False,
|
allow_tf32: bool = False,
|
||||||
report_to: str = "tensorboard",
|
report_to: str = "tensorboard",
|
||||||
local_rank: int = -1,
|
local_rank: int = -1,
|
||||||
checkpointing_steps: int = 500,
|
checkpointing_steps: int = 500,
|
||||||
resume_from_checkpoint: Path = None,
|
resume_from_checkpoint: Optional[Path] = None,
|
||||||
enable_xformers_memory_efficient_attention: bool = False,
|
enable_xformers_memory_efficient_attention: bool = False,
|
||||||
hub_model_id: str = None,
|
hub_model_id: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
) -> None:
|
||||||
assert model, "Please specify a base model with --model"
|
assert model, "Please specify a base model with --model"
|
||||||
assert train_data_dir, "Please specify a directory containing the training images using --train_data_dir"
|
assert train_data_dir, "Please specify a directory containing the training images using --train_data_dir"
|
||||||
assert placeholder_token, "Please specify a trigger term using --placeholder_token"
|
assert placeholder_token, "Please specify a trigger term using --placeholder_token"
|
||||||
@ -564,8 +562,6 @@ def do_textual_inversion_training(
|
|||||||
project_config=accelerator_config,
|
project_config=accelerator_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_manager = ModelManagerService(config, logger)
|
|
||||||
|
|
||||||
# Make one log on every process with the configuration for debugging.
|
# Make one log on every process with the configuration for debugging.
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
@ -603,44 +599,37 @@ def do_textual_inversion_training(
|
|||||||
elif output_dir is not None:
|
elif output_dir is not None:
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
known_models = model_manager.model_names()
|
model_records = initialize_record_store(config)
|
||||||
model_name = model.split("/")[-1]
|
base, type, name = model.split("/") # note frontend still returns old-style keys
|
||||||
model_meta = next((mm for mm in known_models if mm[0].endswith(model_name)), None)
|
try:
|
||||||
assert model_meta is not None, f"Unknown model: {model}"
|
model_config = model_records.search_by_attr(
|
||||||
model_info = model_manager.model_info(*model_meta)
|
model_name=name, model_type=ModelType(type), base_model=BaseModelType(base)
|
||||||
assert model_info["model_format"] == "diffusers", "This script only works with models of type 'diffusers'"
|
)[0]
|
||||||
tokenizer_info = model_manager.get_model(*model_meta, submodel=SubModelType.Tokenizer)
|
except IndexError:
|
||||||
noise_scheduler_info = model_manager.get_model(*model_meta, submodel=SubModelType.Scheduler)
|
raise Exception(f"Unknown model {model}")
|
||||||
text_encoder_info = model_manager.get_model(*model_meta, submodel=SubModelType.TextEncoder)
|
model_path = config.models_path / model_config.path
|
||||||
vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae)
|
|
||||||
unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet)
|
|
||||||
|
|
||||||
pipeline_args = {"local_files_only": True}
|
pipeline_args = {"local_files_only": True}
|
||||||
if tokenizer_name:
|
if tokenizer_name:
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
|
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
|
||||||
else:
|
else:
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_info.location, subfolder="tokenizer", **pipeline_args)
|
tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer", **pipeline_args)
|
||||||
|
|
||||||
# Load scheduler and models
|
# Load scheduler and models
|
||||||
noise_scheduler = DDPMScheduler.from_pretrained(
|
noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler", **pipeline_args)
|
||||||
noise_scheduler_info.location, subfolder="scheduler", **pipeline_args
|
|
||||||
)
|
|
||||||
text_encoder = CLIPTextModel.from_pretrained(
|
text_encoder = CLIPTextModel.from_pretrained(
|
||||||
text_encoder_info.location,
|
model_path,
|
||||||
subfolder="text_encoder",
|
subfolder="text_encoder",
|
||||||
revision=revision,
|
|
||||||
**pipeline_args,
|
**pipeline_args,
|
||||||
)
|
)
|
||||||
vae = AutoencoderKL.from_pretrained(
|
vae = AutoencoderKL.from_pretrained(
|
||||||
vae_info.location,
|
model_path,
|
||||||
subfolder="vae",
|
subfolder="vae",
|
||||||
revision=revision,
|
|
||||||
**pipeline_args,
|
**pipeline_args,
|
||||||
)
|
)
|
||||||
unet = UNet2DConditionModel.from_pretrained(
|
unet = UNet2DConditionModel.from_pretrained(
|
||||||
unet_info.location,
|
model_path,
|
||||||
subfolder="unet",
|
subfolder="unet",
|
||||||
revision=revision,
|
|
||||||
**pipeline_args,
|
**pipeline_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -728,7 +717,7 @@ def do_textual_inversion_training(
|
|||||||
max_train_steps = num_train_epochs * num_update_steps_per_epoch
|
max_train_steps = num_train_epochs * num_update_steps_per_epoch
|
||||||
overrode_max_train_steps = True
|
overrode_max_train_steps = True
|
||||||
|
|
||||||
lr_scheduler = get_scheduler(
|
scheduler = get_scheduler(
|
||||||
lr_scheduler,
|
lr_scheduler,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
|
num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
|
||||||
@ -737,7 +726,7 @@ def do_textual_inversion_training(
|
|||||||
|
|
||||||
# Prepare everything with our `accelerator`.
|
# Prepare everything with our `accelerator`.
|
||||||
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
text_encoder, optimizer, train_dataloader, lr_scheduler
|
text_encoder, optimizer, train_dataloader, scheduler
|
||||||
)
|
)
|
||||||
|
|
||||||
# For mixed precision training we cast the unet and vae weights to half-precision
|
# For mixed precision training we cast the unet and vae weights to half-precision
|
||||||
@ -863,7 +852,7 @@ def do_textual_inversion_training(
|
|||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
lr_scheduler.step()
|
scheduler.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||||
@ -893,7 +882,7 @@ def do_textual_inversion_training(
|
|||||||
accelerator.save_state(save_path)
|
accelerator.save_state(save_path)
|
||||||
logger.info(f"Saved state to {save_path}")
|
logger.info(f"Saved state to {save_path}")
|
||||||
|
|
||||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
logs = {"loss": loss.detach().item(), "lr": scheduler.get_last_lr()[0]}
|
||||||
progress_bar.set_postfix(**logs)
|
progress_bar.set_postfix(**logs)
|
||||||
accelerator.log(logs, step=global_step)
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
@ -910,7 +899,7 @@ def do_textual_inversion_training(
|
|||||||
save_full_model = not only_save_embeds
|
save_full_model = not only_save_embeds
|
||||||
if save_full_model:
|
if save_full_model:
|
||||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||||
unet_info.location,
|
model_path,
|
||||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||||
vae=vae,
|
vae=vae,
|
||||||
unet=unet,
|
unet=unet,
|
||||||
|
157
invokeai/configs/INITIAL_MODELS2.yaml
Normal file
157
invokeai/configs/INITIAL_MODELS2.yaml
Normal file
@ -0,0 +1,157 @@
|
|||||||
|
# This file predefines a few models that the user may want to install.
|
||||||
|
sd-1/main/stable-diffusion-v1-5:
|
||||||
|
description: Stable Diffusion version 1.5 diffusers model (4.27 GB)
|
||||||
|
source: runwayml/stable-diffusion-v1-5
|
||||||
|
recommended: True
|
||||||
|
default: True
|
||||||
|
sd-1/main/stable-diffusion-v1-5-inpainting:
|
||||||
|
description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)
|
||||||
|
source: runwayml/stable-diffusion-inpainting
|
||||||
|
recommended: True
|
||||||
|
sd-2/main/stable-diffusion-2-1:
|
||||||
|
description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)
|
||||||
|
source: stabilityai/stable-diffusion-2-1
|
||||||
|
recommended: False
|
||||||
|
sd-2/main/stable-diffusion-2-inpainting:
|
||||||
|
description: Stable Diffusion version 2.0 inpainting model (5.21 GB)
|
||||||
|
source: stabilityai/stable-diffusion-2-inpainting
|
||||||
|
recommended: False
|
||||||
|
sdxl/main/stable-diffusion-xl-base-1-0:
|
||||||
|
description: Stable Diffusion XL base model (12 GB)
|
||||||
|
source: stabilityai/stable-diffusion-xl-base-1.0
|
||||||
|
recommended: True
|
||||||
|
sdxl-refiner/main/stable-diffusion-xl-refiner-1-0:
|
||||||
|
description: Stable Diffusion XL refiner model (12 GB)
|
||||||
|
source: stabilityai/stable-diffusion-xl-refiner-1.0
|
||||||
|
recommended: False
|
||||||
|
sdxl/vae/sdxl-vae-fp16-fix:
|
||||||
|
description: Version of the SDXL-1.0 VAE that works in half precision mode
|
||||||
|
source: madebyollin/sdxl-vae-fp16-fix
|
||||||
|
recommended: True
|
||||||
|
sd-1/main/Analog-Diffusion:
|
||||||
|
description: An SD-1.5 model trained on diverse analog photographs (2.13 GB)
|
||||||
|
source: wavymulder/Analog-Diffusion
|
||||||
|
recommended: False
|
||||||
|
sd-1/main/Deliberate:
|
||||||
|
description: Versatile model that produces detailed images up to 768px (4.27 GB)
|
||||||
|
source: XpucT/Deliberate
|
||||||
|
recommended: False
|
||||||
|
sd-1/main/Dungeons-and-Diffusion:
|
||||||
|
description: Dungeons & Dragons characters (2.13 GB)
|
||||||
|
source: 0xJustin/Dungeons-and-Diffusion
|
||||||
|
recommended: False
|
||||||
|
sd-1/main/dreamlike-photoreal-2:
|
||||||
|
description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB)
|
||||||
|
source: dreamlike-art/dreamlike-photoreal-2.0
|
||||||
|
recommended: False
|
||||||
|
sd-1/main/Inkpunk-Diffusion:
|
||||||
|
description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB)
|
||||||
|
source: Envvi/Inkpunk-Diffusion
|
||||||
|
recommended: False
|
||||||
|
sd-1/main/openjourney:
|
||||||
|
description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB)
|
||||||
|
source: prompthero/openjourney
|
||||||
|
recommended: False
|
||||||
|
sd-1/main/seek.art_MEGA:
|
||||||
|
source: coreco/seek.art_MEGA
|
||||||
|
description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB)
|
||||||
|
recommended: False
|
||||||
|
sd-1/main/trinart_stable_diffusion_v2:
|
||||||
|
description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB)
|
||||||
|
source: naclbit/trinart_stable_diffusion_v2
|
||||||
|
recommended: False
|
||||||
|
sd-1/controlnet/qrcode_monster:
|
||||||
|
source: monster-labs/control_v1p_sd15_qrcode_monster
|
||||||
|
subfolder: v2
|
||||||
|
sd-1/controlnet/canny:
|
||||||
|
source: lllyasviel/control_v11p_sd15_canny
|
||||||
|
recommended: True
|
||||||
|
sd-1/controlnet/inpaint:
|
||||||
|
source: lllyasviel/control_v11p_sd15_inpaint
|
||||||
|
sd-1/controlnet/mlsd:
|
||||||
|
source: lllyasviel/control_v11p_sd15_mlsd
|
||||||
|
sd-1/controlnet/depth:
|
||||||
|
source: lllyasviel/control_v11f1p_sd15_depth
|
||||||
|
recommended: True
|
||||||
|
sd-1/controlnet/normal_bae:
|
||||||
|
source: lllyasviel/control_v11p_sd15_normalbae
|
||||||
|
sd-1/controlnet/seg:
|
||||||
|
source: lllyasviel/control_v11p_sd15_seg
|
||||||
|
sd-1/controlnet/lineart:
|
||||||
|
source: lllyasviel/control_v11p_sd15_lineart
|
||||||
|
recommended: True
|
||||||
|
sd-1/controlnet/lineart_anime:
|
||||||
|
source: lllyasviel/control_v11p_sd15s2_lineart_anime
|
||||||
|
sd-1/controlnet/openpose:
|
||||||
|
source: lllyasviel/control_v11p_sd15_openpose
|
||||||
|
recommended: True
|
||||||
|
sd-1/controlnet/scribble:
|
||||||
|
source: lllyasviel/control_v11p_sd15_scribble
|
||||||
|
recommended: False
|
||||||
|
sd-1/controlnet/softedge:
|
||||||
|
source: lllyasviel/control_v11p_sd15_softedge
|
||||||
|
sd-1/controlnet/shuffle:
|
||||||
|
source: lllyasviel/control_v11e_sd15_shuffle
|
||||||
|
sd-1/controlnet/tile:
|
||||||
|
source: lllyasviel/control_v11f1e_sd15_tile
|
||||||
|
sd-1/controlnet/ip2p:
|
||||||
|
source: lllyasviel/control_v11e_sd15_ip2p
|
||||||
|
sd-1/t2i_adapter/canny-sd15:
|
||||||
|
source: TencentARC/t2iadapter_canny_sd15v2
|
||||||
|
sd-1/t2i_adapter/sketch-sd15:
|
||||||
|
source: TencentARC/t2iadapter_sketch_sd15v2
|
||||||
|
sd-1/t2i_adapter/depth-sd15:
|
||||||
|
source: TencentARC/t2iadapter_depth_sd15v2
|
||||||
|
sd-1/t2i_adapter/zoedepth-sd15:
|
||||||
|
source: TencentARC/t2iadapter_zoedepth_sd15v1
|
||||||
|
sdxl/t2i_adapter/canny-sdxl:
|
||||||
|
source: TencentARC/t2i-adapter-canny-sdxl-1.0
|
||||||
|
sdxl/t2i_adapter/zoedepth-sdxl:
|
||||||
|
source: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0
|
||||||
|
sdxl/t2i_adapter/lineart-sdxl:
|
||||||
|
source: TencentARC/t2i-adapter-lineart-sdxl-1.0
|
||||||
|
sdxl/t2i_adapter/sketch-sdxl:
|
||||||
|
source: TencentARC/t2i-adapter-sketch-sdxl-1.0
|
||||||
|
sd-1/embedding/EasyNegative:
|
||||||
|
source: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
|
||||||
|
recommended: True
|
||||||
|
description: A textual inversion to use in the negative prompt to reduce bad anatomy
|
||||||
|
sd-1/lora/FlatColor:
|
||||||
|
source: https://civitai.com/models/6433/loraflatcolor
|
||||||
|
recommended: True
|
||||||
|
description: A LoRA that generates scenery using solid blocks of color
|
||||||
|
sd-1/lora/Ink scenery:
|
||||||
|
source: https://civitai.com/api/download/models/83390
|
||||||
|
description: Generate india ink-like landscapes
|
||||||
|
sd-1/ip_adapter/ip_adapter_sd15:
|
||||||
|
source: InvokeAI/ip_adapter_sd15
|
||||||
|
recommended: True
|
||||||
|
requires:
|
||||||
|
- InvokeAI/ip_adapter_sd_image_encoder
|
||||||
|
description: IP-Adapter for SD 1.5 models
|
||||||
|
sd-1/ip_adapter/ip_adapter_plus_sd15:
|
||||||
|
source: InvokeAI/ip_adapter_plus_sd15
|
||||||
|
recommended: False
|
||||||
|
requires:
|
||||||
|
- InvokeAI/ip_adapter_sd_image_encoder
|
||||||
|
description: Refined IP-Adapter for SD 1.5 models
|
||||||
|
sd-1/ip_adapter/ip_adapter_plus_face_sd15:
|
||||||
|
source: InvokeAI/ip_adapter_plus_face_sd15
|
||||||
|
recommended: False
|
||||||
|
requires:
|
||||||
|
- InvokeAI/ip_adapter_sd_image_encoder
|
||||||
|
description: Refined IP-Adapter for SD 1.5 models, adapted for faces
|
||||||
|
sdxl/ip_adapter/ip_adapter_sdxl:
|
||||||
|
source: InvokeAI/ip_adapter_sdxl
|
||||||
|
recommended: False
|
||||||
|
requires:
|
||||||
|
- InvokeAI/ip_adapter_sdxl_image_encoder
|
||||||
|
description: IP-Adapter for SDXL models
|
||||||
|
any/clip_vision/ip_adapter_sd_image_encoder:
|
||||||
|
source: InvokeAI/ip_adapter_sd_image_encoder
|
||||||
|
recommended: False
|
||||||
|
description: Required model for using IP-Adapters with SD-1/2 models
|
||||||
|
any/clip_vision/ip_adapter_sdxl_image_encoder:
|
||||||
|
source: InvokeAI/ip_adapter_sdxl_image_encoder
|
||||||
|
recommended: False
|
||||||
|
description: Required model for using IP-Adapters with SDXL models
|
@ -2,3 +2,5 @@
|
|||||||
Wrapper for invokeai.backend.configure.invokeai_configure
|
Wrapper for invokeai.backend.configure.invokeai_configure
|
||||||
"""
|
"""
|
||||||
from ...backend.install.invokeai_configure import main as invokeai_configure # noqa: F401
|
from ...backend.install.invokeai_configure import main as invokeai_configure # noqa: F401
|
||||||
|
|
||||||
|
__all__ = ["invokeai_configure"]
|
||||||
|
645
invokeai/frontend/install/model_install2.py
Normal file
645
invokeai/frontend/install/model_install2.py
Normal file
@ -0,0 +1,645 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
|
||||||
|
# Before running stable-diffusion on an internet-isolated machine,
|
||||||
|
# run this script from one with internet connectivity. The
|
||||||
|
# two machines must share a common .cache directory.
|
||||||
|
|
||||||
|
"""
|
||||||
|
This is the npyscreen frontend to the model installation application.
|
||||||
|
It is currently named model_install2.py, but will ultimately replace model_install.py.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import curses
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
import warnings
|
||||||
|
from argparse import Namespace
|
||||||
|
from shutil import get_terminal_size
|
||||||
|
from typing import Any, Dict, List, Optional, Set
|
||||||
|
|
||||||
|
import npyscreen
|
||||||
|
import torch
|
||||||
|
from npyscreen import widget
|
||||||
|
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
from invokeai.app.services.model_install import ModelInstallService
|
||||||
|
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections, UnifiedModelInfo
|
||||||
|
from invokeai.backend.model_manager import ModelType
|
||||||
|
from invokeai.backend.util import choose_precision, choose_torch_device
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
from invokeai.frontend.install.widgets import (
|
||||||
|
MIN_COLS,
|
||||||
|
MIN_LINES,
|
||||||
|
CenteredTitleText,
|
||||||
|
CyclingForm,
|
||||||
|
MultiSelectColumns,
|
||||||
|
SingleSelectColumns,
|
||||||
|
TextBox,
|
||||||
|
WindowTooSmallException,
|
||||||
|
set_min_terminal_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
warnings.filterwarnings("ignore", category=UserWarning) # noqa: E402
|
||||||
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
logger = InvokeAILogger.get_logger("ModelInstallService")
|
||||||
|
logger.setLevel("WARNING")
|
||||||
|
# logger.setLevel('DEBUG')
|
||||||
|
|
||||||
|
# build a table mapping all non-printable characters to None
|
||||||
|
# for stripping control characters
|
||||||
|
# from https://stackoverflow.com/questions/92438/stripping-non-printable-characters-from-a-string-in-python
|
||||||
|
NOPRINT_TRANS_TABLE = {i: None for i in range(0, sys.maxunicode + 1) if not chr(i).isprintable()}
|
||||||
|
|
||||||
|
# maximum number of installed models we can display before overflowing vertically
|
||||||
|
MAX_OTHER_MODELS = 72
|
||||||
|
|
||||||
|
|
||||||
|
def make_printable(s: str) -> str:
|
||||||
|
"""Replace non-printable characters in a string."""
|
||||||
|
return s.translate(NOPRINT_TRANS_TABLE)
|
||||||
|
|
||||||
|
|
||||||
|
class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||||
|
"""Main form for interactive TUI."""
|
||||||
|
|
||||||
|
# for responsive resizing set to False, but this seems to cause a crash!
|
||||||
|
FIX_MINIMUM_SIZE_WHEN_CREATED = True
|
||||||
|
|
||||||
|
# for persistence
|
||||||
|
current_tab = 0
|
||||||
|
|
||||||
|
def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, multipage: bool = False, **keywords: Any):
|
||||||
|
self.multipage = multipage
|
||||||
|
self.subprocess = None
|
||||||
|
super().__init__(parentApp=parentApp, name=name, **keywords)
|
||||||
|
|
||||||
|
def create(self) -> None:
|
||||||
|
self.installer = self.parentApp.install_helper.installer
|
||||||
|
self.model_labels = self._get_model_labels()
|
||||||
|
self.keypress_timeout = 10
|
||||||
|
self.counter = 0
|
||||||
|
self.subprocess_connection = None
|
||||||
|
|
||||||
|
window_width, window_height = get_terminal_size()
|
||||||
|
|
||||||
|
# npyscreen has no typing hints
|
||||||
|
self.nextrely -= 1 # type: ignore
|
||||||
|
self.add_widget_intelligent(
|
||||||
|
npyscreen.FixedText,
|
||||||
|
value="Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields. Cursor keys navigate, and <space> selects.",
|
||||||
|
editable=False,
|
||||||
|
color="CAUTION",
|
||||||
|
)
|
||||||
|
self.nextrely += 1 # type: ignore
|
||||||
|
self.tabs = self.add_widget_intelligent(
|
||||||
|
SingleSelectColumns,
|
||||||
|
values=[
|
||||||
|
"STARTERS",
|
||||||
|
"MAINS",
|
||||||
|
"CONTROLNETS",
|
||||||
|
"T2I-ADAPTERS",
|
||||||
|
"IP-ADAPTERS",
|
||||||
|
"LORAS",
|
||||||
|
"TI EMBEDDINGS",
|
||||||
|
],
|
||||||
|
value=[self.current_tab],
|
||||||
|
columns=7,
|
||||||
|
max_height=2,
|
||||||
|
relx=8,
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.tabs.on_changed = self._toggle_tables
|
||||||
|
|
||||||
|
top_of_table = self.nextrely # type: ignore
|
||||||
|
self.starter_pipelines = self.add_starter_pipelines()
|
||||||
|
bottom_of_table = self.nextrely # type: ignore
|
||||||
|
|
||||||
|
self.nextrely = top_of_table
|
||||||
|
self.pipeline_models = self.add_pipeline_widgets(
|
||||||
|
model_type=ModelType.Main, window_width=window_width, exclude=self.starter_models
|
||||||
|
)
|
||||||
|
# self.pipeline_models['autoload_pending'] = True
|
||||||
|
bottom_of_table = max(bottom_of_table, self.nextrely)
|
||||||
|
|
||||||
|
self.nextrely = top_of_table
|
||||||
|
self.controlnet_models = self.add_model_widgets(
|
||||||
|
model_type=ModelType.ControlNet,
|
||||||
|
window_width=window_width,
|
||||||
|
)
|
||||||
|
bottom_of_table = max(bottom_of_table, self.nextrely)
|
||||||
|
|
||||||
|
self.nextrely = top_of_table
|
||||||
|
self.t2i_models = self.add_model_widgets(
|
||||||
|
model_type=ModelType.T2IAdapter,
|
||||||
|
window_width=window_width,
|
||||||
|
)
|
||||||
|
bottom_of_table = max(bottom_of_table, self.nextrely)
|
||||||
|
self.nextrely = top_of_table
|
||||||
|
self.ipadapter_models = self.add_model_widgets(
|
||||||
|
model_type=ModelType.IPAdapter,
|
||||||
|
window_width=window_width,
|
||||||
|
)
|
||||||
|
bottom_of_table = max(bottom_of_table, self.nextrely)
|
||||||
|
|
||||||
|
self.nextrely = top_of_table
|
||||||
|
self.lora_models = self.add_model_widgets(
|
||||||
|
model_type=ModelType.Lora,
|
||||||
|
window_width=window_width,
|
||||||
|
)
|
||||||
|
bottom_of_table = max(bottom_of_table, self.nextrely)
|
||||||
|
|
||||||
|
self.nextrely = top_of_table
|
||||||
|
self.ti_models = self.add_model_widgets(
|
||||||
|
model_type=ModelType.TextualInversion,
|
||||||
|
window_width=window_width,
|
||||||
|
)
|
||||||
|
bottom_of_table = max(bottom_of_table, self.nextrely)
|
||||||
|
|
||||||
|
self.nextrely = bottom_of_table + 1
|
||||||
|
|
||||||
|
self.nextrely += 1
|
||||||
|
back_label = "BACK"
|
||||||
|
cancel_label = "CANCEL"
|
||||||
|
current_position = self.nextrely
|
||||||
|
if self.multipage:
|
||||||
|
self.back_button = self.add_widget_intelligent(
|
||||||
|
npyscreen.ButtonPress,
|
||||||
|
name=back_label,
|
||||||
|
when_pressed_function=self.on_back,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.nextrely = current_position
|
||||||
|
self.cancel_button = self.add_widget_intelligent(
|
||||||
|
npyscreen.ButtonPress, name=cancel_label, when_pressed_function=self.on_cancel
|
||||||
|
)
|
||||||
|
self.nextrely = current_position
|
||||||
|
|
||||||
|
label = "APPLY CHANGES"
|
||||||
|
self.nextrely = current_position
|
||||||
|
self.done = self.add_widget_intelligent(
|
||||||
|
npyscreen.ButtonPress,
|
||||||
|
name=label,
|
||||||
|
relx=window_width - len(label) - 15,
|
||||||
|
when_pressed_function=self.on_done,
|
||||||
|
)
|
||||||
|
|
||||||
|
# This restores the selected page on return from an installation
|
||||||
|
for _i in range(1, self.current_tab + 1):
|
||||||
|
self.tabs.h_cursor_line_down(1)
|
||||||
|
self._toggle_tables([self.current_tab])
|
||||||
|
|
||||||
|
############# diffusers tab ##########
|
||||||
|
def add_starter_pipelines(self) -> dict[str, npyscreen.widget]:
|
||||||
|
"""Add widgets responsible for selecting diffusers models"""
|
||||||
|
widgets: Dict[str, npyscreen.widget] = {}
|
||||||
|
|
||||||
|
all_models = self.all_models # master dict of all models, indexed by key
|
||||||
|
model_list = [x for x in self.starter_models if all_models[x].type in ["main", "vae"]]
|
||||||
|
model_labels = [self.model_labels[x] for x in model_list]
|
||||||
|
|
||||||
|
widgets.update(
|
||||||
|
label1=self.add_widget_intelligent(
|
||||||
|
CenteredTitleText,
|
||||||
|
name="Select from a starter set of Stable Diffusion models from HuggingFace and Civitae.",
|
||||||
|
editable=False,
|
||||||
|
labelColor="CAUTION",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.nextrely -= 1
|
||||||
|
# if user has already installed some initial models, then don't patronize them
|
||||||
|
# by showing more recommendations
|
||||||
|
show_recommended = len(self.installed_models) == 0
|
||||||
|
|
||||||
|
checked = [
|
||||||
|
model_list.index(x)
|
||||||
|
for x in model_list
|
||||||
|
if (show_recommended and all_models[x].recommended) or all_models[x].installed
|
||||||
|
]
|
||||||
|
widgets.update(
|
||||||
|
models_selected=self.add_widget_intelligent(
|
||||||
|
MultiSelectColumns,
|
||||||
|
columns=1,
|
||||||
|
name="Install Starter Models",
|
||||||
|
values=model_labels,
|
||||||
|
value=checked,
|
||||||
|
max_height=len(model_list) + 1,
|
||||||
|
relx=4,
|
||||||
|
scroll_exit=True,
|
||||||
|
),
|
||||||
|
models=model_list,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.nextrely += 1
|
||||||
|
return widgets
|
||||||
|
|
||||||
|
############# Add a set of model install widgets ########
|
||||||
|
def add_model_widgets(
|
||||||
|
self,
|
||||||
|
model_type: ModelType,
|
||||||
|
window_width: int = 120,
|
||||||
|
install_prompt: Optional[str] = None,
|
||||||
|
exclude: Optional[Set[str]] = None,
|
||||||
|
) -> dict[str, npyscreen.widget]:
|
||||||
|
"""Generic code to create model selection widgets"""
|
||||||
|
if exclude is None:
|
||||||
|
exclude = set()
|
||||||
|
widgets: Dict[str, npyscreen.widget] = {}
|
||||||
|
all_models = self.all_models
|
||||||
|
model_list = sorted(
|
||||||
|
[x for x in all_models if all_models[x].type == model_type and x not in exclude],
|
||||||
|
key=lambda x: all_models[x].name or "",
|
||||||
|
)
|
||||||
|
model_labels = [self.model_labels[x] for x in model_list]
|
||||||
|
|
||||||
|
show_recommended = len(self.installed_models) == 0
|
||||||
|
truncated = False
|
||||||
|
if len(model_list) > 0:
|
||||||
|
max_width = max([len(x) for x in model_labels])
|
||||||
|
columns = window_width // (max_width + 8) # 8 characters for "[x] " and padding
|
||||||
|
columns = min(len(model_list), columns) or 1
|
||||||
|
prompt = (
|
||||||
|
install_prompt
|
||||||
|
or f"Select the desired {model_type.value.title()} models to install. Unchecked models will be purged from disk."
|
||||||
|
)
|
||||||
|
|
||||||
|
widgets.update(
|
||||||
|
label1=self.add_widget_intelligent(
|
||||||
|
CenteredTitleText,
|
||||||
|
name=prompt,
|
||||||
|
editable=False,
|
||||||
|
labelColor="CAUTION",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(model_labels) > MAX_OTHER_MODELS:
|
||||||
|
model_labels = model_labels[0:MAX_OTHER_MODELS]
|
||||||
|
truncated = True
|
||||||
|
|
||||||
|
widgets.update(
|
||||||
|
models_selected=self.add_widget_intelligent(
|
||||||
|
MultiSelectColumns,
|
||||||
|
columns=columns,
|
||||||
|
name=f"Install {model_type} Models",
|
||||||
|
values=model_labels,
|
||||||
|
value=[
|
||||||
|
model_list.index(x)
|
||||||
|
for x in model_list
|
||||||
|
if (show_recommended and all_models[x].recommended) or all_models[x].installed
|
||||||
|
],
|
||||||
|
max_height=len(model_list) // columns + 1,
|
||||||
|
relx=4,
|
||||||
|
scroll_exit=True,
|
||||||
|
),
|
||||||
|
models=model_list,
|
||||||
|
)
|
||||||
|
|
||||||
|
if truncated:
|
||||||
|
widgets.update(
|
||||||
|
warning_message=self.add_widget_intelligent(
|
||||||
|
npyscreen.FixedText,
|
||||||
|
value=f"Too many models to display (max={MAX_OTHER_MODELS}). Some are not displayed.",
|
||||||
|
editable=False,
|
||||||
|
color="CAUTION",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.nextrely += 1
|
||||||
|
widgets.update(
|
||||||
|
download_ids=self.add_widget_intelligent(
|
||||||
|
TextBox,
|
||||||
|
name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):",
|
||||||
|
max_height=6,
|
||||||
|
scroll_exit=True,
|
||||||
|
editable=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return widgets
|
||||||
|
|
||||||
|
### Tab for arbitrary diffusers widgets ###
|
||||||
|
def add_pipeline_widgets(
|
||||||
|
self,
|
||||||
|
model_type: ModelType = ModelType.Main,
|
||||||
|
window_width: int = 120,
|
||||||
|
**kwargs,
|
||||||
|
) -> dict[str, npyscreen.widget]:
|
||||||
|
"""Similar to add_model_widgets() but adds some additional widgets at the bottom
|
||||||
|
to support the autoload directory"""
|
||||||
|
widgets = self.add_model_widgets(
|
||||||
|
model_type=model_type,
|
||||||
|
window_width=window_width,
|
||||||
|
install_prompt=f"Installed {model_type.value.title()} models. Unchecked models in the InvokeAI root directory will be deleted. Enter URLs, paths or repo_ids to import.",
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return widgets
|
||||||
|
|
||||||
|
def resize(self) -> None:
|
||||||
|
super().resize()
|
||||||
|
if s := self.starter_pipelines.get("models_selected"):
|
||||||
|
if model_list := self.starter_pipelines.get("models"):
|
||||||
|
s.values = [self.model_labels[x] for x in model_list]
|
||||||
|
|
||||||
|
def _toggle_tables(self, value: List[int]) -> None:
|
||||||
|
selected_tab = value[0]
|
||||||
|
widgets = [
|
||||||
|
self.starter_pipelines,
|
||||||
|
self.pipeline_models,
|
||||||
|
self.controlnet_models,
|
||||||
|
self.t2i_models,
|
||||||
|
self.ipadapter_models,
|
||||||
|
self.lora_models,
|
||||||
|
self.ti_models,
|
||||||
|
]
|
||||||
|
|
||||||
|
for group in widgets:
|
||||||
|
for _k, v in group.items():
|
||||||
|
try:
|
||||||
|
v.hidden = True
|
||||||
|
v.editable = False
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
for _k, v in widgets[selected_tab].items():
|
||||||
|
try:
|
||||||
|
v.hidden = False
|
||||||
|
if not isinstance(v, (npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)):
|
||||||
|
v.editable = True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self.__class__.current_tab = selected_tab # for persistence
|
||||||
|
self.display()
|
||||||
|
|
||||||
|
def _get_model_labels(self) -> dict[str, str]:
|
||||||
|
"""Return a list of trimmed labels for all models."""
|
||||||
|
window_width, window_height = get_terminal_size()
|
||||||
|
checkbox_width = 4
|
||||||
|
spacing_width = 2
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
models = self.all_models
|
||||||
|
label_width = max([len(models[x].name or "") for x in self.starter_models])
|
||||||
|
description_width = window_width - label_width - checkbox_width - spacing_width
|
||||||
|
|
||||||
|
for key in self.all_models:
|
||||||
|
description = models[key].description
|
||||||
|
description = (
|
||||||
|
description[0 : description_width - 3] + "..."
|
||||||
|
if description and len(description) > description_width
|
||||||
|
else description
|
||||||
|
if description
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
result[key] = f"%-{label_width}s %s" % (models[key].name, description)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _get_columns(self) -> int:
|
||||||
|
window_width, window_height = get_terminal_size()
|
||||||
|
cols = 4 if window_width > 240 else 3 if window_width > 160 else 2 if window_width > 80 else 1
|
||||||
|
return min(cols, len(self.installed_models))
|
||||||
|
|
||||||
|
def confirm_deletions(self, selections: InstallSelections) -> bool:
|
||||||
|
remove_models = selections.remove_models
|
||||||
|
if remove_models:
|
||||||
|
model_names = [self.all_models[x].name or "" for x in remove_models]
|
||||||
|
mods = "\n".join(model_names)
|
||||||
|
is_ok = npyscreen.notify_ok_cancel(
|
||||||
|
f"These unchecked models will be deleted from disk. Continue?\n---------\n{mods}"
|
||||||
|
)
|
||||||
|
assert isinstance(is_ok, bool) # npyscreen doesn't have return type annotations
|
||||||
|
return is_ok
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def all_models(self) -> Dict[str, UnifiedModelInfo]:
|
||||||
|
# npyscreen doesn't having typing hints
|
||||||
|
return self.parentApp.install_helper.all_models # type: ignore
|
||||||
|
|
||||||
|
@property
|
||||||
|
def starter_models(self) -> List[str]:
|
||||||
|
return self.parentApp.install_helper._starter_models # type: ignore
|
||||||
|
|
||||||
|
@property
|
||||||
|
def installed_models(self) -> List[str]:
|
||||||
|
return self.parentApp.install_helper._installed_models # type: ignore
|
||||||
|
|
||||||
|
def on_back(self) -> None:
|
||||||
|
self.parentApp.switchFormPrevious()
|
||||||
|
self.editing = False
|
||||||
|
|
||||||
|
def on_cancel(self) -> None:
|
||||||
|
self.parentApp.setNextForm(None)
|
||||||
|
self.parentApp.user_cancelled = True
|
||||||
|
self.editing = False
|
||||||
|
|
||||||
|
def on_done(self) -> None:
|
||||||
|
self.marshall_arguments()
|
||||||
|
if not self.confirm_deletions(self.parentApp.install_selections):
|
||||||
|
return
|
||||||
|
self.parentApp.setNextForm(None)
|
||||||
|
self.parentApp.user_cancelled = False
|
||||||
|
self.editing = False
|
||||||
|
|
||||||
|
def marshall_arguments(self) -> None:
|
||||||
|
"""
|
||||||
|
Assemble arguments and store as attributes of the application:
|
||||||
|
.starter_models: dict of model names to install from INITIAL_CONFIGURE.yaml
|
||||||
|
True => Install
|
||||||
|
False => Remove
|
||||||
|
.scan_directory: Path to a directory of models to scan and import
|
||||||
|
.autoscan_on_startup: True if invokeai should scan and import at startup time
|
||||||
|
.import_model_paths: list of URLs, repo_ids and file paths to import
|
||||||
|
"""
|
||||||
|
selections = self.parentApp.install_selections
|
||||||
|
all_models = self.all_models
|
||||||
|
|
||||||
|
# Defined models (in INITIAL_CONFIG.yaml or models.yaml) to add/remove
|
||||||
|
ui_sections = [
|
||||||
|
self.starter_pipelines,
|
||||||
|
self.pipeline_models,
|
||||||
|
self.controlnet_models,
|
||||||
|
self.t2i_models,
|
||||||
|
self.ipadapter_models,
|
||||||
|
self.lora_models,
|
||||||
|
self.ti_models,
|
||||||
|
]
|
||||||
|
for section in ui_sections:
|
||||||
|
if "models_selected" not in section:
|
||||||
|
continue
|
||||||
|
selected = {section["models"][x] for x in section["models_selected"].value}
|
||||||
|
models_to_install = [x for x in selected if not self.all_models[x].installed]
|
||||||
|
models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed]
|
||||||
|
selections.remove_models.extend(models_to_remove)
|
||||||
|
selections.install_models.extend([all_models[x] for x in models_to_install])
|
||||||
|
|
||||||
|
# models located in the 'download_ids" section
|
||||||
|
for section in ui_sections:
|
||||||
|
if downloads := section.get("download_ids"):
|
||||||
|
models = [UnifiedModelInfo(source=x) for x in downloads.value.split()]
|
||||||
|
selections.install_models.extend(models)
|
||||||
|
|
||||||
|
|
||||||
|
class AddModelApplication(npyscreen.NPSAppManaged): # type: ignore
|
||||||
|
def __init__(self, opt: Namespace, install_helper: InstallHelper):
|
||||||
|
super().__init__()
|
||||||
|
self.program_opts = opt
|
||||||
|
self.user_cancelled = False
|
||||||
|
self.install_selections = InstallSelections()
|
||||||
|
self.install_helper = install_helper
|
||||||
|
|
||||||
|
def onStart(self) -> None:
|
||||||
|
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||||
|
self.main_form = self.addForm(
|
||||||
|
"MAIN",
|
||||||
|
addModelsForm,
|
||||||
|
name="Install Stable Diffusion Models",
|
||||||
|
cycle_widgets=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def list_models(installer: ModelInstallService, model_type: ModelType):
|
||||||
|
"""Print out all models of type model_type."""
|
||||||
|
models = installer.record_store.search_by_attr(model_type=model_type)
|
||||||
|
print(f"Installed models of type `{model_type}`:")
|
||||||
|
for model in models:
|
||||||
|
path = (config.models_path / model.path).resolve()
|
||||||
|
print(f"{model.name:40}{model.base.value:14}{path}")
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------
|
||||||
|
def select_and_download_models(opt: Namespace) -> None:
|
||||||
|
"""Prompt user for install/delete selections and execute."""
|
||||||
|
precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
|
||||||
|
# unsure how to avoid a typing complaint in the next line: config.precision is an enumerated Literal
|
||||||
|
config.precision = precision # type: ignore
|
||||||
|
install_helper = InstallHelper(config, logger)
|
||||||
|
installer = install_helper.installer
|
||||||
|
|
||||||
|
if opt.list_models:
|
||||||
|
list_models(installer, opt.list_models)
|
||||||
|
|
||||||
|
elif opt.add or opt.delete:
|
||||||
|
selections = InstallSelections(
|
||||||
|
install_models=[UnifiedModelInfo(source=x) for x in (opt.add or [])], remove_models=opt.delete or []
|
||||||
|
)
|
||||||
|
install_helper.add_or_delete(selections)
|
||||||
|
|
||||||
|
elif opt.default_only:
|
||||||
|
selections = InstallSelections(install_models=[install_helper.default_model()])
|
||||||
|
install_helper.add_or_delete(selections)
|
||||||
|
|
||||||
|
elif opt.yes_to_all:
|
||||||
|
selections = InstallSelections(install_models=install_helper.recommended_models())
|
||||||
|
install_helper.add_or_delete(selections)
|
||||||
|
|
||||||
|
# this is where the TUI is called
|
||||||
|
else:
|
||||||
|
if not set_min_terminal_size(MIN_COLS, MIN_LINES):
|
||||||
|
raise WindowTooSmallException(
|
||||||
|
"Could not increase terminal size. Try running again with a larger window or smaller font size."
|
||||||
|
)
|
||||||
|
|
||||||
|
installApp = AddModelApplication(opt, install_helper)
|
||||||
|
try:
|
||||||
|
installApp.run()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("Aborted...")
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
install_helper.add_or_delete(installApp.install_selections)
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------------------
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
|
||||||
|
parser.add_argument(
|
||||||
|
"--add",
|
||||||
|
nargs="*",
|
||||||
|
help="List of URLs, local paths or repo_ids of models to install",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--delete",
|
||||||
|
nargs="*",
|
||||||
|
help="List of names of models to delete. Use type:name to disambiguate, as in `controlnet:my_model`",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--full-precision",
|
||||||
|
dest="full_precision",
|
||||||
|
action=argparse.BooleanOptionalAction,
|
||||||
|
type=bool,
|
||||||
|
default=False,
|
||||||
|
help="use 32-bit weights instead of faster 16-bit weights",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--yes",
|
||||||
|
"-y",
|
||||||
|
dest="yes_to_all",
|
||||||
|
action="store_true",
|
||||||
|
help='answer "yes" to all prompts',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--default_only",
|
||||||
|
action="store_true",
|
||||||
|
help="Only install the default model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--list-models",
|
||||||
|
choices=[x.value for x in ModelType],
|
||||||
|
help="list installed models",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--root_dir",
|
||||||
|
dest="root",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="path to root of install directory",
|
||||||
|
)
|
||||||
|
opt = parser.parse_args()
|
||||||
|
|
||||||
|
invoke_args = []
|
||||||
|
if opt.root:
|
||||||
|
invoke_args.extend(["--root", opt.root])
|
||||||
|
if opt.full_precision:
|
||||||
|
invoke_args.extend(["--precision", "float32"])
|
||||||
|
config.parse_args(invoke_args)
|
||||||
|
logger = InvokeAILogger().get_logger(config=config)
|
||||||
|
|
||||||
|
if not config.model_conf_path.exists():
|
||||||
|
logger.info("Your InvokeAI root directory is not set up. Calling invokeai-configure.")
|
||||||
|
from invokeai.frontend.install.invokeai_configure import invokeai_configure
|
||||||
|
|
||||||
|
invokeai_configure()
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
try:
|
||||||
|
select_and_download_models(opt)
|
||||||
|
except AssertionError as e:
|
||||||
|
logger.error(e)
|
||||||
|
sys.exit(-1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
curses.nocbreak()
|
||||||
|
curses.echo()
|
||||||
|
curses.endwin()
|
||||||
|
logger.info("Goodbye! Come back soon.")
|
||||||
|
except WindowTooSmallException as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
except widget.NotEnoughSpaceForWidget as e:
|
||||||
|
if str(e).startswith("Height of 1 allocated"):
|
||||||
|
logger.error("Insufficient vertical space for the interface. Please make your window taller and try again")
|
||||||
|
input("Press any key to continue...")
|
||||||
|
except Exception as e:
|
||||||
|
if str(e).startswith("addwstr"):
|
||||||
|
logger.error(
|
||||||
|
"Insufficient horizontal space for the interface. Please make your window wider and try again."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(f"An exception has occurred: {str(e)} Details:")
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
input("Press any key to continue...")
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------------------
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
438
invokeai/frontend/merge/merge_diffusers2.py
Normal file
438
invokeai/frontend/merge/merge_diffusers2.py
Normal file
@ -0,0 +1,438 @@
|
|||||||
|
"""
|
||||||
|
invokeai.frontend.merge exports a single function called merge_diffusion_models().
|
||||||
|
|
||||||
|
It merges 2-3 models together and create a new InvokeAI-registered diffusion model.
|
||||||
|
|
||||||
|
Copyright (c) 2023-24 Lincoln Stein and the InvokeAI Development Team
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import curses
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from argparse import Namespace
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import npyscreen
|
||||||
|
from npyscreen import widget
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as logger
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
from invokeai.app.services.model_install import ModelInstallServiceBase
|
||||||
|
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||||
|
from invokeai.backend.install.install_helper import initialize_installer
|
||||||
|
from invokeai.backend.model_manager import (
|
||||||
|
BaseModelType,
|
||||||
|
ModelFormat,
|
||||||
|
ModelType,
|
||||||
|
ModelVariantType,
|
||||||
|
)
|
||||||
|
from invokeai.backend.model_manager.merge import ModelMerger
|
||||||
|
from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox
|
||||||
|
|
||||||
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
|
BASE_TYPES = [
|
||||||
|
(BaseModelType.StableDiffusion1, "Models Built on SD-1.x"),
|
||||||
|
(BaseModelType.StableDiffusion2, "Models Built on SD-2.x"),
|
||||||
|
(BaseModelType.StableDiffusionXL, "Models Built on SDXL"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_args() -> Namespace:
|
||||||
|
parser = argparse.ArgumentParser(description="InvokeAI model merging")
|
||||||
|
parser.add_argument(
|
||||||
|
"--root_dir",
|
||||||
|
type=Path,
|
||||||
|
default=config.root,
|
||||||
|
help="Path to the invokeai runtime directory",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--front_end",
|
||||||
|
"--gui",
|
||||||
|
dest="front_end",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Activate the text-based graphical front end for collecting parameters. Aside from --root_dir, other parameters will be ignored.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--models",
|
||||||
|
dest="model_names",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
help="Two to three model names to be merged",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--base_model",
|
||||||
|
type=str,
|
||||||
|
choices=[x[0].value for x in BASE_TYPES],
|
||||||
|
help="The base model shared by the models to be merged",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--merged_model_name",
|
||||||
|
"--destination",
|
||||||
|
dest="merged_model_name",
|
||||||
|
type=str,
|
||||||
|
help="Name of the output model. If not specified, will be the concatenation of the input model names.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--alpha",
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help="The interpolation parameter, ranging from 0 to 1. It affects the ratio in which the checkpoints are merged. Higher values give more weight to the 2d and 3d models",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--interpolation",
|
||||||
|
dest="interp",
|
||||||
|
type=str,
|
||||||
|
choices=["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"],
|
||||||
|
default="weighted_sum",
|
||||||
|
help='Interpolation method to use. If three models are present, only "add_difference" will work.',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--force",
|
||||||
|
action="store_true",
|
||||||
|
help="Try to merge models even if they are incompatible with each other",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--clobber",
|
||||||
|
"--overwrite",
|
||||||
|
dest="clobber",
|
||||||
|
action="store_true",
|
||||||
|
help="Overwrite the merged model if --merged_model_name already exists",
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------- GUI HERE -------------------------
|
||||||
|
class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||||
|
interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid"]
|
||||||
|
|
||||||
|
def __init__(self, parentApp, name):
|
||||||
|
self.parentApp = parentApp
|
||||||
|
self.ALLOW_RESIZE = True
|
||||||
|
self.FIX_MINIMUM_SIZE_WHEN_CREATED = False
|
||||||
|
super().__init__(parentApp, name)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_record_store(self) -> ModelRecordServiceBase:
|
||||||
|
installer: ModelInstallServiceBase = self.parentApp.installer
|
||||||
|
return installer.record_store
|
||||||
|
|
||||||
|
def afterEditing(self) -> None:
|
||||||
|
self.parentApp.setNextForm(None)
|
||||||
|
|
||||||
|
def create(self) -> None:
|
||||||
|
window_height, window_width = curses.initscr().getmaxyx()
|
||||||
|
self.current_base = 0
|
||||||
|
self.models = self.get_models(BASE_TYPES[self.current_base][0])
|
||||||
|
self.model_names = [x[1] for x in self.models]
|
||||||
|
max_width = max([len(x) for x in self.model_names])
|
||||||
|
max_width += 6
|
||||||
|
horizontal_layout = max_width * 3 < window_width
|
||||||
|
|
||||||
|
self.add_widget_intelligent(
|
||||||
|
npyscreen.FixedText,
|
||||||
|
color="CONTROL",
|
||||||
|
value="Select two models to merge and optionally a third.",
|
||||||
|
editable=False,
|
||||||
|
)
|
||||||
|
self.add_widget_intelligent(
|
||||||
|
npyscreen.FixedText,
|
||||||
|
color="CONTROL",
|
||||||
|
value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
|
||||||
|
editable=False,
|
||||||
|
)
|
||||||
|
self.nextrely += 1
|
||||||
|
self.base_select = self.add_widget_intelligent(
|
||||||
|
SingleSelectColumns,
|
||||||
|
values=[x[1] for x in BASE_TYPES],
|
||||||
|
value=[self.current_base],
|
||||||
|
columns=4,
|
||||||
|
max_height=2,
|
||||||
|
relx=8,
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.base_select.on_changed = self._populate_models
|
||||||
|
self.add_widget_intelligent(
|
||||||
|
npyscreen.FixedText,
|
||||||
|
value="MODEL 1",
|
||||||
|
color="GOOD",
|
||||||
|
editable=False,
|
||||||
|
rely=6 if horizontal_layout else None,
|
||||||
|
)
|
||||||
|
self.model1 = self.add_widget_intelligent(
|
||||||
|
npyscreen.SelectOne,
|
||||||
|
values=self.model_names,
|
||||||
|
value=0,
|
||||||
|
max_height=len(self.model_names),
|
||||||
|
max_width=max_width,
|
||||||
|
scroll_exit=True,
|
||||||
|
rely=7,
|
||||||
|
)
|
||||||
|
self.add_widget_intelligent(
|
||||||
|
npyscreen.FixedText,
|
||||||
|
value="MODEL 2",
|
||||||
|
color="GOOD",
|
||||||
|
editable=False,
|
||||||
|
relx=max_width + 3 if horizontal_layout else None,
|
||||||
|
rely=6 if horizontal_layout else None,
|
||||||
|
)
|
||||||
|
self.model2 = self.add_widget_intelligent(
|
||||||
|
npyscreen.SelectOne,
|
||||||
|
name="(2)",
|
||||||
|
values=self.model_names,
|
||||||
|
value=1,
|
||||||
|
max_height=len(self.model_names),
|
||||||
|
max_width=max_width,
|
||||||
|
relx=max_width + 3 if horizontal_layout else None,
|
||||||
|
rely=7 if horizontal_layout else None,
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.add_widget_intelligent(
|
||||||
|
npyscreen.FixedText,
|
||||||
|
value="MODEL 3",
|
||||||
|
color="GOOD",
|
||||||
|
editable=False,
|
||||||
|
relx=max_width * 2 + 3 if horizontal_layout else None,
|
||||||
|
rely=6 if horizontal_layout else None,
|
||||||
|
)
|
||||||
|
models_plus_none = self.model_names.copy()
|
||||||
|
models_plus_none.insert(0, "None")
|
||||||
|
self.model3 = self.add_widget_intelligent(
|
||||||
|
npyscreen.SelectOne,
|
||||||
|
name="(3)",
|
||||||
|
values=models_plus_none,
|
||||||
|
value=0,
|
||||||
|
max_height=len(self.model_names) + 1,
|
||||||
|
max_width=max_width,
|
||||||
|
scroll_exit=True,
|
||||||
|
relx=max_width * 2 + 3 if horizontal_layout else None,
|
||||||
|
rely=7 if horizontal_layout else None,
|
||||||
|
)
|
||||||
|
for m in [self.model1, self.model2, self.model3]:
|
||||||
|
m.when_value_edited = self.models_changed
|
||||||
|
self.merged_model_name = self.add_widget_intelligent(
|
||||||
|
TextBox,
|
||||||
|
name="Name for merged model:",
|
||||||
|
labelColor="CONTROL",
|
||||||
|
max_height=3,
|
||||||
|
value="",
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.force = self.add_widget_intelligent(
|
||||||
|
npyscreen.Checkbox,
|
||||||
|
name="Force merge of models created by different diffusers library versions",
|
||||||
|
labelColor="CONTROL",
|
||||||
|
value=True,
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.nextrely += 1
|
||||||
|
self.merge_method = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSelectOne,
|
||||||
|
name="Merge Method:",
|
||||||
|
values=self.interpolations,
|
||||||
|
value=0,
|
||||||
|
labelColor="CONTROL",
|
||||||
|
max_height=len(self.interpolations) + 1,
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.alpha = self.add_widget_intelligent(
|
||||||
|
FloatTitleSlider,
|
||||||
|
name="Weight (alpha) to assign to second and third models:",
|
||||||
|
out_of=1.0,
|
||||||
|
step=0.01,
|
||||||
|
lowest=0,
|
||||||
|
value=0.5,
|
||||||
|
labelColor="CONTROL",
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.model1.editing = True
|
||||||
|
|
||||||
|
def models_changed(self) -> None:
|
||||||
|
models = self.model1.values
|
||||||
|
selected_model1 = self.model1.value[0]
|
||||||
|
selected_model2 = self.model2.value[0]
|
||||||
|
selected_model3 = self.model3.value[0]
|
||||||
|
merged_model_name = f"{models[selected_model1]}+{models[selected_model2]}"
|
||||||
|
self.merged_model_name.value = merged_model_name
|
||||||
|
|
||||||
|
if selected_model3 > 0:
|
||||||
|
self.merge_method.values = ["add_difference ( A+(B-C) )"]
|
||||||
|
self.merged_model_name.value += f"+{models[selected_model3 -1]}" # In model3 there is one more element in the list (None). So we have to subtract one.
|
||||||
|
else:
|
||||||
|
self.merge_method.values = self.interpolations
|
||||||
|
self.merge_method.value = 0
|
||||||
|
|
||||||
|
def on_ok(self) -> None:
|
||||||
|
if self.validate_field_values() and self.check_for_overwrite():
|
||||||
|
self.parentApp.setNextForm(None)
|
||||||
|
self.editing = False
|
||||||
|
self.parentApp.merge_arguments = self.marshall_arguments()
|
||||||
|
npyscreen.notify("Starting the merge...")
|
||||||
|
else:
|
||||||
|
self.editing = True
|
||||||
|
|
||||||
|
def on_cancel(self) -> None:
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
def marshall_arguments(self) -> dict:
|
||||||
|
model_keys = [x[0] for x in self.models]
|
||||||
|
models = [
|
||||||
|
model_keys[self.model1.value[0]],
|
||||||
|
model_keys[self.model2.value[0]],
|
||||||
|
]
|
||||||
|
if self.model3.value[0] > 0:
|
||||||
|
models.append(model_keys[self.model3.value[0] - 1])
|
||||||
|
interp = "add_difference"
|
||||||
|
else:
|
||||||
|
interp = self.interpolations[self.merge_method.value[0]]
|
||||||
|
|
||||||
|
args = {
|
||||||
|
"model_keys": models,
|
||||||
|
"alpha": self.alpha.value,
|
||||||
|
"interp": interp,
|
||||||
|
"force": self.force.value,
|
||||||
|
"merged_model_name": self.merged_model_name.value,
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
|
||||||
|
def check_for_overwrite(self) -> bool:
|
||||||
|
model_out = self.merged_model_name.value
|
||||||
|
if model_out not in self.model_names:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
result: bool = npyscreen.notify_yes_no(
|
||||||
|
f"The chosen merged model destination, {model_out}, is already in use. Overwrite?"
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def validate_field_values(self) -> bool:
|
||||||
|
bad_fields = []
|
||||||
|
model_names = self.model_names
|
||||||
|
selected_models = {model_names[self.model1.value[0]], model_names[self.model2.value[0]]}
|
||||||
|
if self.model3.value[0] > 0:
|
||||||
|
selected_models.add(model_names[self.model3.value[0] - 1])
|
||||||
|
if len(selected_models) < 2:
|
||||||
|
bad_fields.append(f"Please select two or three DIFFERENT models to compare. You selected {selected_models}")
|
||||||
|
if len(bad_fields) > 0:
|
||||||
|
message = "The following problems were detected and must be corrected:"
|
||||||
|
for problem in bad_fields:
|
||||||
|
message += f"\n* {problem}"
|
||||||
|
npyscreen.notify_confirm(message)
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name
|
||||||
|
models = [
|
||||||
|
(x.key, x.name)
|
||||||
|
for x in self.model_record_store.search_by_attr(model_type=ModelType.Main, base_model=base_model)
|
||||||
|
if x.format == ModelFormat("diffusers")
|
||||||
|
and hasattr(x, "variant")
|
||||||
|
and x.variant == ModelVariantType("normal")
|
||||||
|
]
|
||||||
|
return sorted(models, key=lambda x: x[1])
|
||||||
|
|
||||||
|
def _populate_models(self, value: List[int]) -> None:
|
||||||
|
base_model = BASE_TYPES[value[0]][0]
|
||||||
|
self.models = self.get_models(base_model)
|
||||||
|
self.model_names = [x[1] for x in self.models]
|
||||||
|
|
||||||
|
models_plus_none = self.model_names.copy()
|
||||||
|
models_plus_none.insert(0, "None")
|
||||||
|
self.model1.values = self.model_names
|
||||||
|
self.model2.values = self.model_names
|
||||||
|
self.model3.values = models_plus_none
|
||||||
|
|
||||||
|
self.display()
|
||||||
|
|
||||||
|
|
||||||
|
# npyscreen is untyped and causes mypy to get naggy
|
||||||
|
class Mergeapp(npyscreen.NPSAppManaged): # type: ignore
|
||||||
|
def __init__(self, installer: ModelInstallServiceBase):
|
||||||
|
"""Initialize the npyscreen application."""
|
||||||
|
super().__init__()
|
||||||
|
self.installer = installer
|
||||||
|
|
||||||
|
def onStart(self) -> None:
|
||||||
|
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
|
||||||
|
self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
|
||||||
|
|
||||||
|
|
||||||
|
def run_gui(args: Namespace) -> None:
|
||||||
|
installer = initialize_installer(config)
|
||||||
|
mergeapp = Mergeapp(installer)
|
||||||
|
mergeapp.run()
|
||||||
|
merge_args = mergeapp.merge_arguments
|
||||||
|
merger = ModelMerger(installer)
|
||||||
|
merger.merge_diffusion_models_and_save(**merge_args)
|
||||||
|
logger.info(f'Models merged into new model: "{merge_args.merged_model_name}".')
|
||||||
|
|
||||||
|
|
||||||
|
def run_cli(args: Namespace) -> None:
|
||||||
|
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
|
||||||
|
assert (
|
||||||
|
args.model_names and len(args.model_names) >= 1 and len(args.model_names) <= 3
|
||||||
|
), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."
|
||||||
|
|
||||||
|
if not args.merged_model_name:
|
||||||
|
args.merged_model_name = "+".join(args.model_names)
|
||||||
|
logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"')
|
||||||
|
|
||||||
|
installer = initialize_installer(config)
|
||||||
|
store = installer.record_store
|
||||||
|
assert (
|
||||||
|
len(store.search_by_attr(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber
|
||||||
|
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
||||||
|
|
||||||
|
merger = ModelMerger(installer)
|
||||||
|
model_keys = []
|
||||||
|
for name in args.model_names:
|
||||||
|
if len(name) == 32 and re.match(r"^[0-9a-f]$", name):
|
||||||
|
model_keys.append(name)
|
||||||
|
else:
|
||||||
|
models = store.search_by_attr(
|
||||||
|
model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model)
|
||||||
|
)
|
||||||
|
assert len(models) > 0, f"{name}: Unknown model"
|
||||||
|
assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead."
|
||||||
|
model_keys.append(models[0].key)
|
||||||
|
|
||||||
|
merger.merge_diffusion_models_and_save(
|
||||||
|
alpha=args.alpha,
|
||||||
|
model_keys=model_keys,
|
||||||
|
merged_model_name=args.merged_model_name,
|
||||||
|
interp=args.interp,
|
||||||
|
force=args.force,
|
||||||
|
)
|
||||||
|
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
args = _parse_args()
|
||||||
|
if args.root_dir:
|
||||||
|
config.parse_args(["--root", str(args.root_dir)])
|
||||||
|
else:
|
||||||
|
config.parse_args([])
|
||||||
|
|
||||||
|
try:
|
||||||
|
if args.front_end:
|
||||||
|
run_gui(args)
|
||||||
|
else:
|
||||||
|
run_cli(args)
|
||||||
|
except widget.NotEnoughSpaceForWidget as e:
|
||||||
|
if str(e).startswith("Height of 1 allocated"):
|
||||||
|
logger.error("You need to have at least two diffusers models defined in models.yaml in order to merge")
|
||||||
|
else:
|
||||||
|
logger.error("Not enough room for the user interface. Try making this window larger.")
|
||||||
|
sys.exit(-1)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
sys.exit(-1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -3,7 +3,7 @@
|
|||||||
"""
|
"""
|
||||||
This is the frontend to "textual_inversion_training.py".
|
This is the frontend to "textual_inversion_training.py".
|
||||||
|
|
||||||
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
Copyright (c) 2023-24 Lincoln Stein and the InvokeAI Development Team
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -14,7 +14,7 @@ import sys
|
|||||||
import traceback
|
import traceback
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import npyscreen
|
import npyscreen
|
||||||
from npyscreen import widget
|
from npyscreen import widget
|
||||||
@ -22,8 +22,9 @@ from omegaconf import OmegaConf
|
|||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
from invokeai.backend.install.install_helper import initialize_installer
|
||||||
from ...backend.training import do_textual_inversion_training, parse_args
|
from invokeai.backend.model_manager import ModelType
|
||||||
|
from invokeai.backend.training import do_textual_inversion_training, parse_args
|
||||||
|
|
||||||
TRAINING_DATA = "text-inversion-training-data"
|
TRAINING_DATA = "text-inversion-training-data"
|
||||||
TRAINING_DIR = "text-inversion-output"
|
TRAINING_DIR = "text-inversion-output"
|
||||||
@ -44,19 +45,21 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
|||||||
precisions = ["no", "fp16", "bf16"]
|
precisions = ["no", "fp16", "bf16"]
|
||||||
learnable_properties = ["object", "style"]
|
learnable_properties = ["object", "style"]
|
||||||
|
|
||||||
def __init__(self, parentApp, name, saved_args=None):
|
def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, saved_args: Optional[Dict[str, str]] = None):
|
||||||
self.saved_args = saved_args or {}
|
self.saved_args = saved_args or {}
|
||||||
super().__init__(parentApp, name)
|
super().__init__(parentApp, name)
|
||||||
|
|
||||||
def afterEditing(self):
|
def afterEditing(self) -> None:
|
||||||
self.parentApp.setNextForm(None)
|
self.parentApp.setNextForm(None)
|
||||||
|
|
||||||
def create(self):
|
def create(self) -> None:
|
||||||
self.model_names, default = self.get_model_names()
|
self.model_names, default = self.get_model_names()
|
||||||
default_initializer_token = "★"
|
default_initializer_token = "★"
|
||||||
default_placeholder_token = ""
|
default_placeholder_token = ""
|
||||||
saved_args = self.saved_args
|
saved_args = self.saved_args
|
||||||
|
|
||||||
|
assert config is not None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
default = self.model_names.index(saved_args["model"])
|
default = self.model_names.index(saved_args["model"])
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -71,7 +74,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
|||||||
self.model = self.add_widget_intelligent(
|
self.model = self.add_widget_intelligent(
|
||||||
npyscreen.TitleSelectOne,
|
npyscreen.TitleSelectOne,
|
||||||
name="Model Name:",
|
name="Model Name:",
|
||||||
values=self.model_names,
|
values=sorted(self.model_names),
|
||||||
value=default,
|
value=default,
|
||||||
max_height=len(self.model_names) + 1,
|
max_height=len(self.model_names) + 1,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
@ -236,7 +239,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
|||||||
)
|
)
|
||||||
self.model.editing = True
|
self.model.editing = True
|
||||||
|
|
||||||
def initializer_changed(self):
|
def initializer_changed(self) -> None:
|
||||||
placeholder = self.placeholder_token.value
|
placeholder = self.placeholder_token.value
|
||||||
self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)"
|
self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)"
|
||||||
self.train_data_dir.value = str(config.root_dir / TRAINING_DATA / placeholder)
|
self.train_data_dir.value = str(config.root_dir / TRAINING_DATA / placeholder)
|
||||||
@ -275,10 +278,13 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def get_model_names(self) -> Tuple[List[str], int]:
|
def get_model_names(self) -> Tuple[List[str], int]:
|
||||||
conf = OmegaConf.load(config.root_dir / "configs/models.yaml")
|
global config
|
||||||
model_names = [idx for idx in sorted(conf.keys()) if conf[idx].get("format", None) == "diffusers"]
|
assert config is not None
|
||||||
defaults = [idx for idx in range(len(model_names)) if "default" in conf[model_names[idx]]]
|
installer = initialize_installer(config)
|
||||||
default = defaults[0] if len(defaults) > 0 else 0
|
store = installer.record_store
|
||||||
|
main_models = store.search_by_attr(model_type=ModelType.Main)
|
||||||
|
model_names = [f"{x.base.value}/{x.type.value}/{x.name}" for x in main_models if x.format == "diffusers"]
|
||||||
|
default = 0
|
||||||
return (model_names, default)
|
return (model_names, default)
|
||||||
|
|
||||||
def marshall_arguments(self) -> dict:
|
def marshall_arguments(self) -> dict:
|
||||||
@ -326,7 +332,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
|||||||
|
|
||||||
|
|
||||||
class MyApplication(npyscreen.NPSAppManaged):
|
class MyApplication(npyscreen.NPSAppManaged):
|
||||||
def __init__(self, saved_args=None):
|
def __init__(self, saved_args: Optional[Dict[str, str]] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.ti_arguments = None
|
self.ti_arguments = None
|
||||||
self.saved_args = saved_args
|
self.saved_args = saved_args
|
||||||
@ -341,11 +347,12 @@ class MyApplication(npyscreen.NPSAppManaged):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def copy_to_embeddings_folder(args: dict):
|
def copy_to_embeddings_folder(args: Dict[str, str]) -> None:
|
||||||
"""
|
"""
|
||||||
Copy learned_embeds.bin into the embeddings folder, and offer to
|
Copy learned_embeds.bin into the embeddings folder, and offer to
|
||||||
delete the full model and checkpoints.
|
delete the full model and checkpoints.
|
||||||
"""
|
"""
|
||||||
|
assert config is not None
|
||||||
source = Path(args["output_dir"], "learned_embeds.bin")
|
source = Path(args["output_dir"], "learned_embeds.bin")
|
||||||
dest_dir_name = args["placeholder_token"].strip("<>")
|
dest_dir_name = args["placeholder_token"].strip("<>")
|
||||||
destination = config.root_dir / "embeddings" / dest_dir_name
|
destination = config.root_dir / "embeddings" / dest_dir_name
|
||||||
@ -358,10 +365,11 @@ def copy_to_embeddings_folder(args: dict):
|
|||||||
logger.info(f'Keeping {args["output_dir"]}')
|
logger.info(f'Keeping {args["output_dir"]}')
|
||||||
|
|
||||||
|
|
||||||
def save_args(args: dict):
|
def save_args(args: dict) -> None:
|
||||||
"""
|
"""
|
||||||
Save the current argument values to an omegaconf file
|
Save the current argument values to an omegaconf file
|
||||||
"""
|
"""
|
||||||
|
assert config is not None
|
||||||
dest_dir = config.root_dir / TRAINING_DIR
|
dest_dir = config.root_dir / TRAINING_DIR
|
||||||
os.makedirs(dest_dir, exist_ok=True)
|
os.makedirs(dest_dir, exist_ok=True)
|
||||||
conf_file = dest_dir / CONF_FILE
|
conf_file = dest_dir / CONF_FILE
|
||||||
@ -373,6 +381,7 @@ def previous_args() -> dict:
|
|||||||
"""
|
"""
|
||||||
Get the previous arguments used.
|
Get the previous arguments used.
|
||||||
"""
|
"""
|
||||||
|
assert config is not None
|
||||||
conf_file = config.root_dir / TRAINING_DIR / CONF_FILE
|
conf_file = config.root_dir / TRAINING_DIR / CONF_FILE
|
||||||
try:
|
try:
|
||||||
conf = OmegaConf.load(conf_file)
|
conf = OmegaConf.load(conf_file)
|
||||||
@ -383,24 +392,26 @@ def previous_args() -> dict:
|
|||||||
return conf
|
return conf
|
||||||
|
|
||||||
|
|
||||||
def do_front_end(args: Namespace):
|
def do_front_end() -> None:
|
||||||
|
global config
|
||||||
saved_args = previous_args()
|
saved_args = previous_args()
|
||||||
myapplication = MyApplication(saved_args=saved_args)
|
myapplication = MyApplication(saved_args=saved_args)
|
||||||
myapplication.run()
|
myapplication.run()
|
||||||
|
|
||||||
if args := myapplication.ti_arguments:
|
if my_args := myapplication.ti_arguments:
|
||||||
os.makedirs(args["output_dir"], exist_ok=True)
|
os.makedirs(my_args["output_dir"], exist_ok=True)
|
||||||
|
|
||||||
# Automatically add angle brackets around the trigger
|
# Automatically add angle brackets around the trigger
|
||||||
if not re.match("^<.+>$", args["placeholder_token"]):
|
if not re.match("^<.+>$", my_args["placeholder_token"]):
|
||||||
args["placeholder_token"] = f"<{args['placeholder_token']}>"
|
my_args["placeholder_token"] = f"<{my_args['placeholder_token']}>"
|
||||||
|
|
||||||
args["only_save_embeds"] = True
|
my_args["only_save_embeds"] = True
|
||||||
save_args(args)
|
save_args(my_args)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
do_textual_inversion_training(InvokeAIAppConfig.get_config(), **args)
|
print(my_args)
|
||||||
copy_to_embeddings_folder(args)
|
do_textual_inversion_training(config, **my_args)
|
||||||
|
copy_to_embeddings_folder(my_args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("An exception occurred during training. The exception was:")
|
logger.error("An exception occurred during training. The exception was:")
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
@ -408,11 +419,12 @@ def do_front_end(args: Namespace):
|
|||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
global config
|
global config
|
||||||
|
|
||||||
args = parse_args()
|
args: Namespace = parse_args()
|
||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
config.parse_args([])
|
||||||
|
|
||||||
# change root if needed
|
# change root if needed
|
||||||
if args.root_dir:
|
if args.root_dir:
|
||||||
@ -420,7 +432,7 @@ def main():
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if args.front_end:
|
if args.front_end:
|
||||||
do_front_end(args)
|
do_front_end()
|
||||||
else:
|
else:
|
||||||
do_textual_inversion_training(config, **vars(args))
|
do_textual_inversion_training(config, **vars(args))
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
|
454
invokeai/frontend/training/textual_inversion2.py
Normal file
454
invokeai/frontend/training/textual_inversion2.py
Normal file
@ -0,0 +1,454 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
"""
|
||||||
|
This is the frontend to "textual_inversion_training.py".
|
||||||
|
|
||||||
|
Copyright (c) 2023-24 Lincoln Stein and the InvokeAI Development Team
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
from argparse import Namespace
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import npyscreen
|
||||||
|
from npyscreen import widget
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as logger
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
from invokeai.backend.install.install_helper import initialize_installer
|
||||||
|
from invokeai.backend.model_manager import ModelType
|
||||||
|
from invokeai.backend.training import do_textual_inversion_training, parse_args
|
||||||
|
|
||||||
|
TRAINING_DATA = "text-inversion-training-data"
|
||||||
|
TRAINING_DIR = "text-inversion-output"
|
||||||
|
CONF_FILE = "preferences.conf"
|
||||||
|
config = None
|
||||||
|
|
||||||
|
|
||||||
|
class textualInversionForm(npyscreen.FormMultiPageAction):
|
||||||
|
resolutions = [512, 768, 1024]
|
||||||
|
lr_schedulers = [
|
||||||
|
"linear",
|
||||||
|
"cosine",
|
||||||
|
"cosine_with_restarts",
|
||||||
|
"polynomial",
|
||||||
|
"constant",
|
||||||
|
"constant_with_warmup",
|
||||||
|
]
|
||||||
|
precisions = ["no", "fp16", "bf16"]
|
||||||
|
learnable_properties = ["object", "style"]
|
||||||
|
|
||||||
|
def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, saved_args: Optional[Dict[str, str]] = None):
|
||||||
|
self.saved_args = saved_args or {}
|
||||||
|
super().__init__(parentApp, name)
|
||||||
|
|
||||||
|
def afterEditing(self) -> None:
|
||||||
|
self.parentApp.setNextForm(None)
|
||||||
|
|
||||||
|
def create(self) -> None:
|
||||||
|
self.model_names, default = self.get_model_names()
|
||||||
|
default_initializer_token = "★"
|
||||||
|
default_placeholder_token = ""
|
||||||
|
saved_args = self.saved_args
|
||||||
|
|
||||||
|
assert config is not None
|
||||||
|
|
||||||
|
try:
|
||||||
|
default = self.model_names.index(saved_args["model"])
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.add_widget_intelligent(
|
||||||
|
npyscreen.FixedText,
|
||||||
|
value="Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields, cursor arrows to make a selection, and space to toggle checkboxes.",
|
||||||
|
editable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.model = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSelectOne,
|
||||||
|
name="Model Name:",
|
||||||
|
values=sorted(self.model_names),
|
||||||
|
value=default,
|
||||||
|
max_height=len(self.model_names) + 1,
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.placeholder_token = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleText,
|
||||||
|
name="Trigger Term:",
|
||||||
|
value="", # saved_args.get('placeholder_token',''), # to restore previous term
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.placeholder_token.when_value_edited = self.initializer_changed
|
||||||
|
self.nextrely -= 1
|
||||||
|
self.nextrelx += 30
|
||||||
|
self.prompt_token = self.add_widget_intelligent(
|
||||||
|
npyscreen.FixedText,
|
||||||
|
name="Trigger term for use in prompt",
|
||||||
|
value="",
|
||||||
|
editable=False,
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.nextrelx -= 30
|
||||||
|
self.initializer_token = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleText,
|
||||||
|
name="Initializer:",
|
||||||
|
value=saved_args.get("initializer_token", default_initializer_token),
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.resume_from_checkpoint = self.add_widget_intelligent(
|
||||||
|
npyscreen.Checkbox,
|
||||||
|
name="Resume from last saved checkpoint",
|
||||||
|
value=False,
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.learnable_property = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSelectOne,
|
||||||
|
name="Learnable property:",
|
||||||
|
values=self.learnable_properties,
|
||||||
|
value=self.learnable_properties.index(saved_args.get("learnable_property", "object")),
|
||||||
|
max_height=4,
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.train_data_dir = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleFilename,
|
||||||
|
name="Data Training Directory:",
|
||||||
|
select_dir=True,
|
||||||
|
must_exist=False,
|
||||||
|
value=str(
|
||||||
|
saved_args.get(
|
||||||
|
"train_data_dir",
|
||||||
|
config.root_dir / TRAINING_DATA / default_placeholder_token,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.output_dir = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleFilename,
|
||||||
|
name="Output Destination Directory:",
|
||||||
|
select_dir=True,
|
||||||
|
must_exist=False,
|
||||||
|
value=str(
|
||||||
|
saved_args.get(
|
||||||
|
"output_dir",
|
||||||
|
config.root_dir / TRAINING_DIR / default_placeholder_token,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.resolution = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSelectOne,
|
||||||
|
name="Image resolution (pixels):",
|
||||||
|
values=self.resolutions,
|
||||||
|
value=self.resolutions.index(saved_args.get("resolution", 512)),
|
||||||
|
max_height=4,
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.center_crop = self.add_widget_intelligent(
|
||||||
|
npyscreen.Checkbox,
|
||||||
|
name="Center crop images before resizing to resolution",
|
||||||
|
value=saved_args.get("center_crop", False),
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.mixed_precision = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSelectOne,
|
||||||
|
name="Mixed Precision:",
|
||||||
|
values=self.precisions,
|
||||||
|
value=self.precisions.index(saved_args.get("mixed_precision", "fp16")),
|
||||||
|
max_height=4,
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.num_train_epochs = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSlider,
|
||||||
|
name="Number of training epochs:",
|
||||||
|
out_of=1000,
|
||||||
|
step=50,
|
||||||
|
lowest=1,
|
||||||
|
value=saved_args.get("num_train_epochs", 100),
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.max_train_steps = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSlider,
|
||||||
|
name="Max Training Steps:",
|
||||||
|
out_of=10000,
|
||||||
|
step=500,
|
||||||
|
lowest=1,
|
||||||
|
value=saved_args.get("max_train_steps", 3000),
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.train_batch_size = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSlider,
|
||||||
|
name="Batch Size (reduce if you run out of memory):",
|
||||||
|
out_of=50,
|
||||||
|
step=1,
|
||||||
|
lowest=1,
|
||||||
|
value=saved_args.get("train_batch_size", 8),
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.gradient_accumulation_steps = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSlider,
|
||||||
|
name="Gradient Accumulation Steps (may need to decrease this to resume from a checkpoint):",
|
||||||
|
out_of=10,
|
||||||
|
step=1,
|
||||||
|
lowest=1,
|
||||||
|
value=saved_args.get("gradient_accumulation_steps", 4),
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.lr_warmup_steps = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSlider,
|
||||||
|
name="Warmup Steps:",
|
||||||
|
out_of=100,
|
||||||
|
step=1,
|
||||||
|
lowest=0,
|
||||||
|
value=saved_args.get("lr_warmup_steps", 0),
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.learning_rate = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleText,
|
||||||
|
name="Learning Rate:",
|
||||||
|
value=str(
|
||||||
|
saved_args.get("learning_rate", "5.0e-04"),
|
||||||
|
),
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.scale_lr = self.add_widget_intelligent(
|
||||||
|
npyscreen.Checkbox,
|
||||||
|
name="Scale learning rate by number GPUs, steps and batch size",
|
||||||
|
value=saved_args.get("scale_lr", True),
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.enable_xformers_memory_efficient_attention = self.add_widget_intelligent(
|
||||||
|
npyscreen.Checkbox,
|
||||||
|
name="Use xformers acceleration",
|
||||||
|
value=saved_args.get("enable_xformers_memory_efficient_attention", False),
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.lr_scheduler = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSelectOne,
|
||||||
|
name="Learning rate scheduler:",
|
||||||
|
values=self.lr_schedulers,
|
||||||
|
max_height=7,
|
||||||
|
value=self.lr_schedulers.index(saved_args.get("lr_scheduler", "constant")),
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.model.editing = True
|
||||||
|
|
||||||
|
def initializer_changed(self) -> None:
|
||||||
|
placeholder = self.placeholder_token.value
|
||||||
|
self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)"
|
||||||
|
self.train_data_dir.value = str(config.root_dir / TRAINING_DATA / placeholder)
|
||||||
|
self.output_dir.value = str(config.root_dir / TRAINING_DIR / placeholder)
|
||||||
|
self.resume_from_checkpoint.value = Path(self.output_dir.value).exists()
|
||||||
|
|
||||||
|
def on_ok(self):
|
||||||
|
if self.validate_field_values():
|
||||||
|
self.parentApp.setNextForm(None)
|
||||||
|
self.editing = False
|
||||||
|
self.parentApp.ti_arguments = self.marshall_arguments()
|
||||||
|
npyscreen.notify("Launching textual inversion training. This will take a while...")
|
||||||
|
else:
|
||||||
|
self.editing = True
|
||||||
|
|
||||||
|
def ok_cancel(self):
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
def validate_field_values(self) -> bool:
|
||||||
|
bad_fields = []
|
||||||
|
if self.model.value is None:
|
||||||
|
bad_fields.append("Model Name must correspond to a known model in models.yaml")
|
||||||
|
if not re.match("^[a-zA-Z0-9.-]+$", self.placeholder_token.value):
|
||||||
|
bad_fields.append("Trigger term must only contain alphanumeric characters, the dot and hyphen")
|
||||||
|
if self.train_data_dir.value is None:
|
||||||
|
bad_fields.append("Data Training Directory cannot be empty")
|
||||||
|
if self.output_dir.value is None:
|
||||||
|
bad_fields.append("The Output Destination Directory cannot be empty")
|
||||||
|
if len(bad_fields) > 0:
|
||||||
|
message = "The following problems were detected and must be corrected:"
|
||||||
|
for problem in bad_fields:
|
||||||
|
message += f"\n* {problem}"
|
||||||
|
npyscreen.notify_confirm(message)
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_model_names(self) -> Tuple[List[str], int]:
|
||||||
|
global config
|
||||||
|
assert config is not None
|
||||||
|
installer = initialize_installer(config)
|
||||||
|
store = installer.record_store
|
||||||
|
main_models = store.search_by_attr(model_type=ModelType.Main)
|
||||||
|
model_names = [f"{x.base.value}/{x.type.value}/{x.name}" for x in main_models if x.format == "diffusers"]
|
||||||
|
default = 0
|
||||||
|
return (model_names, default)
|
||||||
|
|
||||||
|
def marshall_arguments(self) -> dict:
|
||||||
|
args = {}
|
||||||
|
|
||||||
|
# the choices
|
||||||
|
args.update(
|
||||||
|
model=self.model_names[self.model.value[0]],
|
||||||
|
resolution=self.resolutions[self.resolution.value[0]],
|
||||||
|
lr_scheduler=self.lr_schedulers[self.lr_scheduler.value[0]],
|
||||||
|
mixed_precision=self.precisions[self.mixed_precision.value[0]],
|
||||||
|
learnable_property=self.learnable_properties[self.learnable_property.value[0]],
|
||||||
|
)
|
||||||
|
|
||||||
|
# all the strings and booleans
|
||||||
|
for attr in (
|
||||||
|
"initializer_token",
|
||||||
|
"placeholder_token",
|
||||||
|
"train_data_dir",
|
||||||
|
"output_dir",
|
||||||
|
"scale_lr",
|
||||||
|
"center_crop",
|
||||||
|
"enable_xformers_memory_efficient_attention",
|
||||||
|
):
|
||||||
|
args[attr] = getattr(self, attr).value
|
||||||
|
|
||||||
|
# all the integers
|
||||||
|
for attr in (
|
||||||
|
"train_batch_size",
|
||||||
|
"gradient_accumulation_steps",
|
||||||
|
"num_train_epochs",
|
||||||
|
"max_train_steps",
|
||||||
|
"lr_warmup_steps",
|
||||||
|
):
|
||||||
|
args[attr] = int(getattr(self, attr).value)
|
||||||
|
|
||||||
|
# the floats (just one)
|
||||||
|
args.update(learning_rate=float(self.learning_rate.value))
|
||||||
|
|
||||||
|
# a special case
|
||||||
|
if self.resume_from_checkpoint.value and Path(self.output_dir.value).exists():
|
||||||
|
args["resume_from_checkpoint"] = "latest"
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
class MyApplication(npyscreen.NPSAppManaged):
|
||||||
|
def __init__(self, saved_args: Optional[Dict[str, str]] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.ti_arguments = None
|
||||||
|
self.saved_args = saved_args
|
||||||
|
|
||||||
|
def onStart(self):
|
||||||
|
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||||
|
self.main = self.addForm(
|
||||||
|
"MAIN",
|
||||||
|
textualInversionForm,
|
||||||
|
name="Textual Inversion Settings",
|
||||||
|
saved_args=self.saved_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def copy_to_embeddings_folder(args: Dict[str, str]) -> None:
|
||||||
|
"""
|
||||||
|
Copy learned_embeds.bin into the embeddings folder, and offer to
|
||||||
|
delete the full model and checkpoints.
|
||||||
|
"""
|
||||||
|
assert config is not None
|
||||||
|
source = Path(args["output_dir"], "learned_embeds.bin")
|
||||||
|
dest_dir_name = args["placeholder_token"].strip("<>")
|
||||||
|
destination = config.root_dir / "embeddings" / dest_dir_name
|
||||||
|
os.makedirs(destination, exist_ok=True)
|
||||||
|
logger.info(f"Training completed. Copying learned_embeds.bin into {str(destination)}")
|
||||||
|
shutil.copy(source, destination)
|
||||||
|
if (input("Delete training logs and intermediate checkpoints? [y] ") or "y").startswith(("y", "Y")):
|
||||||
|
shutil.rmtree(Path(args["output_dir"]))
|
||||||
|
else:
|
||||||
|
logger.info(f'Keeping {args["output_dir"]}')
|
||||||
|
|
||||||
|
|
||||||
|
def save_args(args: dict) -> None:
|
||||||
|
"""
|
||||||
|
Save the current argument values to an omegaconf file
|
||||||
|
"""
|
||||||
|
assert config is not None
|
||||||
|
dest_dir = config.root_dir / TRAINING_DIR
|
||||||
|
os.makedirs(dest_dir, exist_ok=True)
|
||||||
|
conf_file = dest_dir / CONF_FILE
|
||||||
|
conf = OmegaConf.create(args)
|
||||||
|
OmegaConf.save(config=conf, f=conf_file)
|
||||||
|
|
||||||
|
|
||||||
|
def previous_args() -> dict:
|
||||||
|
"""
|
||||||
|
Get the previous arguments used.
|
||||||
|
"""
|
||||||
|
assert config is not None
|
||||||
|
conf_file = config.root_dir / TRAINING_DIR / CONF_FILE
|
||||||
|
try:
|
||||||
|
conf = OmegaConf.load(conf_file)
|
||||||
|
conf["placeholder_token"] = conf["placeholder_token"].strip("<>")
|
||||||
|
except Exception:
|
||||||
|
conf = None
|
||||||
|
|
||||||
|
return conf
|
||||||
|
|
||||||
|
|
||||||
|
def do_front_end() -> None:
|
||||||
|
global config
|
||||||
|
saved_args = previous_args()
|
||||||
|
myapplication = MyApplication(saved_args=saved_args)
|
||||||
|
myapplication.run()
|
||||||
|
|
||||||
|
if my_args := myapplication.ti_arguments:
|
||||||
|
os.makedirs(my_args["output_dir"], exist_ok=True)
|
||||||
|
|
||||||
|
# Automatically add angle brackets around the trigger
|
||||||
|
if not re.match("^<.+>$", my_args["placeholder_token"]):
|
||||||
|
my_args["placeholder_token"] = f"<{my_args['placeholder_token']}>"
|
||||||
|
|
||||||
|
my_args["only_save_embeds"] = True
|
||||||
|
save_args(my_args)
|
||||||
|
|
||||||
|
try:
|
||||||
|
print(my_args)
|
||||||
|
do_textual_inversion_training(config, **my_args)
|
||||||
|
copy_to_embeddings_folder(my_args)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("An exception occurred during training. The exception was:")
|
||||||
|
logger.error(str(e))
|
||||||
|
logger.error("DETAILS:")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
global config
|
||||||
|
|
||||||
|
args: Namespace = parse_args()
|
||||||
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
config.parse_args([])
|
||||||
|
|
||||||
|
# change root if needed
|
||||||
|
if args.root_dir:
|
||||||
|
config.root = args.root_dir
|
||||||
|
|
||||||
|
try:
|
||||||
|
if args.front_end:
|
||||||
|
do_front_end()
|
||||||
|
else:
|
||||||
|
do_textual_inversion_training(config, **vars(args))
|
||||||
|
except AssertionError as e:
|
||||||
|
logger.error(e)
|
||||||
|
sys.exit(-1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
except (widget.NotEnoughSpaceForWidget, Exception) as e:
|
||||||
|
if str(e).startswith("Height of 1 allocated"):
|
||||||
|
logger.error("You need to have at least one diffusers models defined in models.yaml in order to train")
|
||||||
|
elif str(e).startswith("addwstr"):
|
||||||
|
logger.error("Not enough window space for the interface. Please make your window larger and try again.")
|
||||||
|
else:
|
||||||
|
logger.error(e)
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -137,8 +137,10 @@ dependencies = [
|
|||||||
# full commands
|
# full commands
|
||||||
"invokeai-configure" = "invokeai.frontend.install.invokeai_configure:invokeai_configure"
|
"invokeai-configure" = "invokeai.frontend.install.invokeai_configure:invokeai_configure"
|
||||||
"invokeai-merge" = "invokeai.frontend.merge:invokeai_merge_diffusers"
|
"invokeai-merge" = "invokeai.frontend.merge:invokeai_merge_diffusers"
|
||||||
|
"invokeai-merge2" = "invokeai.frontend.merge.merge_diffusers2:main"
|
||||||
"invokeai-ti" = "invokeai.frontend.training:invokeai_textual_inversion"
|
"invokeai-ti" = "invokeai.frontend.training:invokeai_textual_inversion"
|
||||||
"invokeai-model-install" = "invokeai.frontend.install.model_install:main"
|
"invokeai-model-install" = "invokeai.frontend.install.model_install:main"
|
||||||
|
"invokeai-model-install2" = "invokeai.frontend.install.model_install2:main" # will eventually be renamed to invokeai-model-install
|
||||||
"invokeai-migrate3" = "invokeai.backend.install.migrate_to_3:main"
|
"invokeai-migrate3" = "invokeai.backend.install.migrate_to_3:main"
|
||||||
"invokeai-update" = "invokeai.frontend.install.invokeai_update:main"
|
"invokeai-update" = "invokeai.frontend.install.invokeai_update:main"
|
||||||
"invokeai-metadata" = "invokeai.backend.image_util.invoke_metadata:main"
|
"invokeai-metadata" = "invokeai.backend.image_util.invoke_metadata:main"
|
||||||
|
Loading…
Reference in New Issue
Block a user