mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Port the command-line tools to use model_manager2 (#5546)
* Port the command-line tools to use model_manager2 1.Reimplement the following: - invokeai-model-install - invokeai-merge - invokeai-ti To avoid breaking the original modeal manager, the udpated tools have been renamed invokeai-model-install2 and invokeai-merge2. The textual inversion training script should continue to work with existing installations. The "starter" models now live in `invokeai/configs/INITIAL_MODELS2.yaml`. When the full model manager 2 is in place and working, I'll rename these files and commands. 2. Add the `merge` route to the web API. This will merge two or three models, resulting a new one. - Note that because the model installer selectively installs the `fp16` variant of models (rather than both 16- and 32-bit versions as previous), the diffusers merge script will choke on any huggingface diffuserse models that were downloaded with the new installer. Previously-downloaded models should continue to merge correctly. I have a PR upstream https://github.com/huggingface/diffusers/pull/6670 to fix this. 3. (more important!) During implementation of the CLI tools, found and fixed a number of small runtime bugs in the model_manager2 implementation: - During model database migration, if a registered models file was not found on disk, the migration would be aborted. Now the offending model is skipped with a log warning. - Caught and fixed a condition in which the installer would download the entire diffusers repo when the user provided a single `.safetensors` file URL. - Caught and fixed a condition in which the installer would raise an exception and stop the app when a request for an unknown model's metadata was passed to Civitai. Now an error is logged and the installer continues. - Replaced the LoWRA starter LoRA with FlatColor. The former has been removed from Civitai. * fix ruff issue --------- Co-authored-by: Lincoln Stein <lstein@gmail.com>
This commit is contained in:
281
invokeai/backend/install/install_helper.py
Normal file
281
invokeai/backend/install/install_helper.py
Normal file
@ -0,0 +1,281 @@
|
||||
"""Utility (backend) functions used by model_install.py"""
|
||||
import re
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import omegaconf
|
||||
from huggingface_hub import HfFolder
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.dataclasses import dataclass
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests import HTTPError
|
||||
from tqdm import tqdm
|
||||
|
||||
import invokeai.configs as configs
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadQueueService
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
|
||||
from invokeai.app.services.model_install import (
|
||||
HFModelSource,
|
||||
LocalModelSource,
|
||||
ModelInstallService,
|
||||
ModelInstallServiceBase,
|
||||
ModelSource,
|
||||
URLModelSource,
|
||||
)
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.model_manager import (
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import UnknownMetadataException
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
# name of the starter models file
|
||||
INITIAL_MODELS = "INITIAL_MODELS2.yaml"
|
||||
|
||||
|
||||
def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
|
||||
"""Return an initialized ModelConfigRecordServiceBase object."""
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
image_files = DiskImageFileStorage(f"{app_config.output_path}/images")
|
||||
db = init_db(config=app_config, logger=logger, image_files=image_files)
|
||||
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db)
|
||||
return obj
|
||||
|
||||
|
||||
def initialize_installer(
|
||||
app_config: InvokeAIAppConfig, event_bus: Optional[EventServiceBase] = None
|
||||
) -> ModelInstallServiceBase:
|
||||
"""Return an initialized ModelInstallService object."""
|
||||
record_store = initialize_record_store(app_config)
|
||||
metadata_store = record_store.metadata_store
|
||||
download_queue = DownloadQueueService()
|
||||
installer = ModelInstallService(
|
||||
app_config=app_config,
|
||||
record_store=record_store,
|
||||
metadata_store=metadata_store,
|
||||
download_queue=download_queue,
|
||||
event_bus=event_bus,
|
||||
)
|
||||
download_queue.start()
|
||||
installer.start()
|
||||
return installer
|
||||
|
||||
|
||||
class UnifiedModelInfo(BaseModel):
|
||||
"""Catchall class for information in INITIAL_MODELS2.yaml."""
|
||||
|
||||
name: Optional[str] = None
|
||||
base: Optional[BaseModelType] = None
|
||||
type: Optional[ModelType] = None
|
||||
source: Optional[str] = None
|
||||
subfolder: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
recommended: bool = False
|
||||
installed: bool = False
|
||||
default: bool = False
|
||||
requires: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InstallSelections:
|
||||
"""Lists of models to install and remove."""
|
||||
|
||||
install_models: List[UnifiedModelInfo] = Field(default_factory=list)
|
||||
remove_models: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TqdmEventService(EventServiceBase):
|
||||
"""An event service to track downloads."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Create a new TqdmEventService object."""
|
||||
super().__init__()
|
||||
self._bars: Dict[str, tqdm] = {}
|
||||
self._last: Dict[str, int] = {}
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
"""Dispatch an event by appending it to self.events."""
|
||||
if payload["event"] == "model_install_downloading":
|
||||
data = payload["data"]
|
||||
dest = data["local_path"]
|
||||
total_bytes = data["total_bytes"]
|
||||
bytes = data["bytes"]
|
||||
if dest not in self._bars:
|
||||
self._bars[dest] = tqdm(desc=Path(dest).name, initial=0, total=total_bytes, unit="iB", unit_scale=True)
|
||||
self._last[dest] = 0
|
||||
self._bars[dest].update(bytes - self._last[dest])
|
||||
self._last[dest] = bytes
|
||||
|
||||
|
||||
class InstallHelper(object):
|
||||
"""Capture information stored jointly in INITIAL_MODELS.yaml and the installed models db."""
|
||||
|
||||
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger):
|
||||
"""Create new InstallHelper object."""
|
||||
self._app_config = app_config
|
||||
self.all_models: Dict[str, UnifiedModelInfo] = {}
|
||||
|
||||
omega = omegaconf.OmegaConf.load(Path(configs.__path__[0]) / INITIAL_MODELS)
|
||||
assert isinstance(omega, omegaconf.dictconfig.DictConfig)
|
||||
|
||||
self._installer = initialize_installer(app_config, TqdmEventService())
|
||||
self._initial_models = omega
|
||||
self._installed_models: List[str] = []
|
||||
self._starter_models: List[str] = []
|
||||
self._default_model: Optional[str] = None
|
||||
self._logger = logger
|
||||
self._initialize_model_lists()
|
||||
|
||||
@property
|
||||
def installer(self) -> ModelInstallServiceBase:
|
||||
"""Return the installer object used internally."""
|
||||
return self._installer
|
||||
|
||||
def _initialize_model_lists(self) -> None:
|
||||
"""
|
||||
Initialize our model slots.
|
||||
|
||||
Set up the following:
|
||||
installed_models -- list of installed model keys
|
||||
starter_models -- list of starter model keys from INITIAL_MODELS
|
||||
all_models -- dict of key => UnifiedModelInfo
|
||||
default_model -- key to default model
|
||||
"""
|
||||
# previously-installed models
|
||||
for model in self._installer.record_store.all_models():
|
||||
info = UnifiedModelInfo.parse_obj(model.dict())
|
||||
info.installed = True
|
||||
model_key = f"{model.base.value}/{model.type.value}/{model.name}"
|
||||
self.all_models[model_key] = info
|
||||
self._installed_models.append(model_key)
|
||||
|
||||
for key in self._initial_models.keys():
|
||||
assert isinstance(key, str)
|
||||
if key in self.all_models:
|
||||
# we want to preserve the description
|
||||
description = self.all_models[key].description or self._initial_models[key].get("description")
|
||||
self.all_models[key].description = description
|
||||
else:
|
||||
base_model, model_type, model_name = key.split("/")
|
||||
info = UnifiedModelInfo(
|
||||
name=model_name,
|
||||
type=ModelType(model_type),
|
||||
base=BaseModelType(base_model),
|
||||
source=self._initial_models[key].source,
|
||||
description=self._initial_models[key].get("description"),
|
||||
recommended=self._initial_models[key].get("recommended", False),
|
||||
default=self._initial_models[key].get("default", False),
|
||||
subfolder=self._initial_models[key].get("subfolder"),
|
||||
requires=list(self._initial_models[key].get("requires", [])),
|
||||
)
|
||||
self.all_models[key] = info
|
||||
if not self.default_model():
|
||||
self._default_model = key
|
||||
elif self._initial_models[key].get("default", False):
|
||||
self._default_model = key
|
||||
self._starter_models.append(key)
|
||||
|
||||
# previously-installed models
|
||||
for model in self._installer.record_store.all_models():
|
||||
info = UnifiedModelInfo.parse_obj(model.dict())
|
||||
info.installed = True
|
||||
model_key = f"{model.base.value}/{model.type.value}/{model.name}"
|
||||
self.all_models[model_key] = info
|
||||
self._installed_models.append(model_key)
|
||||
|
||||
def recommended_models(self) -> List[UnifiedModelInfo]:
|
||||
"""List of the models recommended in INITIAL_MODELS.yaml."""
|
||||
return [self._to_model(x) for x in self._starter_models if self._to_model(x).recommended]
|
||||
|
||||
def installed_models(self) -> List[UnifiedModelInfo]:
|
||||
"""List of models already installed."""
|
||||
return [self._to_model(x) for x in self._installed_models]
|
||||
|
||||
def starter_models(self) -> List[UnifiedModelInfo]:
|
||||
"""List of starter models."""
|
||||
return [self._to_model(x) for x in self._starter_models]
|
||||
|
||||
def default_model(self) -> Optional[UnifiedModelInfo]:
|
||||
"""Return the default model."""
|
||||
return self._to_model(self._default_model) if self._default_model else None
|
||||
|
||||
def _to_model(self, key: str) -> UnifiedModelInfo:
|
||||
return self.all_models[key]
|
||||
|
||||
def _add_required_models(self, model_list: List[UnifiedModelInfo]) -> None:
|
||||
installed = {x.source for x in self.installed_models()}
|
||||
reverse_source = {x.source: x for x in self.all_models.values()}
|
||||
additional_models: List[UnifiedModelInfo] = []
|
||||
for model_info in model_list:
|
||||
for requirement in model_info.requires:
|
||||
if requirement not in installed and reverse_source.get(requirement):
|
||||
additional_models.append(reverse_source[requirement])
|
||||
model_list.extend(additional_models)
|
||||
|
||||
def _make_install_source(self, model_info: UnifiedModelInfo) -> ModelSource:
|
||||
assert model_info.source
|
||||
model_path_id_or_url = model_info.source.strip("\"' ")
|
||||
model_path = Path(model_path_id_or_url)
|
||||
|
||||
if model_path.exists(): # local file on disk
|
||||
return LocalModelSource(path=model_path.absolute(), inplace=True)
|
||||
if re.match(r"^[^/]+/[^/]+$", model_path_id_or_url): # hugging face repo_id
|
||||
return HFModelSource(
|
||||
repo_id=model_path_id_or_url,
|
||||
access_token=HfFolder.get_token(),
|
||||
subfolder=model_info.subfolder,
|
||||
)
|
||||
if re.match(r"^(http|https):", model_path_id_or_url):
|
||||
return URLModelSource(url=AnyHttpUrl(model_path_id_or_url))
|
||||
raise ValueError(f"Unsupported model source: {model_path_id_or_url}")
|
||||
|
||||
def add_or_delete(self, selections: InstallSelections) -> None:
|
||||
"""Add or delete selected models."""
|
||||
installer = self._installer
|
||||
self._add_required_models(selections.install_models)
|
||||
for model in selections.install_models:
|
||||
source = self._make_install_source(model)
|
||||
config = (
|
||||
{
|
||||
"description": model.description,
|
||||
"name": model.name,
|
||||
}
|
||||
if model.name
|
||||
else None
|
||||
)
|
||||
|
||||
try:
|
||||
installer.import_model(
|
||||
source=source,
|
||||
config=config,
|
||||
)
|
||||
except (UnknownMetadataException, InvalidModelConfigException, HTTPError, OSError) as e:
|
||||
self._logger.warning(f"{source}: {e}")
|
||||
|
||||
for model_to_remove in selections.remove_models:
|
||||
parts = model_to_remove.split("/")
|
||||
if len(parts) == 1:
|
||||
base_model, model_type, model_name = (None, None, model_to_remove)
|
||||
else:
|
||||
base_model, model_type, model_name = parts
|
||||
matches = installer.record_store.search_by_attr(
|
||||
base_model=BaseModelType(base_model) if base_model else None,
|
||||
model_type=ModelType(model_type) if model_type else None,
|
||||
model_name=model_name,
|
||||
)
|
||||
if len(matches) > 1:
|
||||
print(f"{model} is ambiguous. Please use model_type:model_name (e.g. main:my_model) to disambiguate.")
|
||||
elif not matches:
|
||||
print(f"{model}: unknown model")
|
||||
else:
|
||||
for m in matches:
|
||||
print(f"Deleting {m.type}:{m.name}")
|
||||
installer.delete(m.key)
|
||||
|
||||
installer.wait_for_installs()
|
@ -849,7 +849,7 @@ def migrate_if_needed(opt: Namespace, root: Path) -> bool:
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def main():
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
|
||||
parser.add_argument(
|
||||
"--skip-sd-weights",
|
||||
|
177
invokeai/backend/model_manager/merge.py
Normal file
177
invokeai/backend/model_manager/merge.py
Normal file
@ -0,0 +1,177 @@
|
||||
"""
|
||||
invokeai.backend.model_manager.merge exports:
|
||||
merge_diffusion_models() -- combine multiple models by location and return a pipeline object
|
||||
merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml
|
||||
|
||||
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional, Set
|
||||
|
||||
import torch
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
from diffusers import logging as dlogging
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallServiceBase
|
||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
||||
|
||||
from . import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
)
|
||||
from .config import MainDiffusersConfig
|
||||
|
||||
|
||||
class MergeInterpolationMethod(str, Enum):
|
||||
WeightedSum = "weighted_sum"
|
||||
Sigmoid = "sigmoid"
|
||||
InvSigmoid = "inv_sigmoid"
|
||||
AddDifference = "add_difference"
|
||||
|
||||
|
||||
class ModelMerger(object):
|
||||
"""Wrapper class for model merge function."""
|
||||
|
||||
def __init__(self, installer: ModelInstallServiceBase):
|
||||
"""
|
||||
Initialize a ModelMerger object.
|
||||
|
||||
:param store: Underlying storage manager for the running process.
|
||||
:param config: InvokeAIAppConfig object (if not provided, default will be selected).
|
||||
"""
|
||||
self._installer = installer
|
||||
|
||||
def merge_diffusion_models(
|
||||
self,
|
||||
model_paths: List[Path],
|
||||
alpha: float = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: bool = False,
|
||||
variant: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any: # pipe.merge is an untyped function.
|
||||
"""
|
||||
:param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
|
||||
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||
:param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
|
||||
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
verbosity = dlogging.get_verbosity()
|
||||
dlogging.set_verbosity_error()
|
||||
dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
|
||||
|
||||
# Note that checkpoint_merger will not work with downloaded HuggingFace fp16 models
|
||||
# until upstream https://github.com/huggingface/diffusers/pull/6670 is merged and released.
|
||||
pipe = AutoPipelineForText2Image.from_pretrained(
|
||||
model_paths[0],
|
||||
custom_pipeline="checkpoint_merger",
|
||||
torch_dtype=dtype,
|
||||
variant=variant,
|
||||
)
|
||||
merged_pipe = pipe.merge(
|
||||
pretrained_model_name_or_path_list=model_paths,
|
||||
alpha=alpha,
|
||||
interp=interp.value if interp else None, # diffusers API treats None as "weighted sum"
|
||||
force=force,
|
||||
torch_dtype=dtype,
|
||||
variant=variant,
|
||||
**kwargs,
|
||||
)
|
||||
dlogging.set_verbosity(verbosity)
|
||||
return merged_pipe
|
||||
|
||||
def merge_diffusion_models_and_save(
|
||||
self,
|
||||
model_keys: List[str],
|
||||
merged_model_name: str,
|
||||
alpha: float = 0.5,
|
||||
force: bool = False,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
variant: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
:param models: up to three models, designated by their InvokeAI models.yaml model name
|
||||
:param merged_model_name: name for new model
|
||||
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
||||
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
model_paths: List[Path] = []
|
||||
model_names: List[str] = []
|
||||
config = self._installer.app_config
|
||||
store = self._installer.record_store
|
||||
base_models: Set[BaseModelType] = set()
|
||||
vae = None
|
||||
variant = None if self._installer.app_config.full_precision else "fp16"
|
||||
|
||||
assert (
|
||||
len(model_keys) <= 2 or interp == MergeInterpolationMethod.AddDifference
|
||||
), "When merging three models, only the 'add_difference' merge method is supported"
|
||||
|
||||
for key in model_keys:
|
||||
info = store.get_model(key)
|
||||
model_names.append(info.name)
|
||||
assert isinstance(
|
||||
info, MainDiffusersConfig
|
||||
), f"{info.name} ({info.key}) is not a diffusers model. It must be optimized before merging"
|
||||
assert info.variant == ModelVariantType(
|
||||
"normal"
|
||||
), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"
|
||||
|
||||
# pick up the first model's vae
|
||||
if key == model_keys[0]:
|
||||
vae = info.vae
|
||||
|
||||
# tally base models used
|
||||
base_models.add(info.base)
|
||||
model_paths.extend([config.models_path / info.path])
|
||||
|
||||
assert len(base_models) == 1, f"All models to merge must have same base model, but found bases {base_models}"
|
||||
base_model = base_models.pop()
|
||||
|
||||
merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp)
|
||||
merged_pipe = self.merge_diffusion_models(model_paths, alpha, merge_method, force, variant=variant, **kwargs)
|
||||
dump_path = (
|
||||
Path(merge_dest_directory)
|
||||
if merge_dest_directory
|
||||
else config.models_path / base_model.value / ModelType.Main.value
|
||||
)
|
||||
dump_path.mkdir(parents=True, exist_ok=True)
|
||||
dump_path = dump_path / merged_model_name
|
||||
|
||||
dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
|
||||
merged_pipe.save_pretrained(dump_path.as_posix(), safe_serialization=True, torch_dtype=dtype, variant=variant)
|
||||
|
||||
# register model and get its unique key
|
||||
key = self._installer.register_path(dump_path)
|
||||
|
||||
# update model's config
|
||||
model_config = self._installer.record_store.get_model(key)
|
||||
model_config.update(
|
||||
{
|
||||
"name": merged_model_name,
|
||||
"description": f"Merge of models {', '.join(model_names)}",
|
||||
"vae": vae,
|
||||
}
|
||||
)
|
||||
self._installer.record_store.update_model(key, model_config)
|
||||
return model_config
|
@ -170,6 +170,8 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
if model_id is None:
|
||||
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
|
||||
version = self._requests.get(version_url).json()
|
||||
if error := version.get("error"):
|
||||
raise UnknownMetadataException(error)
|
||||
model_id = version["modelId"]
|
||||
|
||||
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
||||
|
@ -11,6 +11,7 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@ -30,8 +31,6 @@ from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from huggingface_hub import HfFolder, Repository, whoami
|
||||
|
||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||
from packaging import version
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
@ -41,8 +40,8 @@ from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
# invokeai stuff
|
||||
from invokeai.app.services.config import InvokeAIAppConfig, PagingArgumentParser
|
||||
from invokeai.app.services.model_manager import ModelManagerService
|
||||
from invokeai.backend.model_management.models import SubModelType
|
||||
from invokeai.backend.install.install_helper import initialize_record_store
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType
|
||||
|
||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||
PIL_INTERPOLATION = {
|
||||
@ -77,7 +76,7 @@ def save_progress(text_encoder, placeholder_token_id, accelerator, placeholder_t
|
||||
torch.save(learned_embeds_dict, save_path)
|
||||
|
||||
|
||||
def parse_args():
|
||||
def parse_args() -> Namespace:
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
parser = PagingArgumentParser(description="Textual inversion training")
|
||||
general_group = parser.add_argument_group("General")
|
||||
@ -444,7 +443,7 @@ class TextualInversionDataset(Dataset):
|
||||
self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
|
||||
self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return self._length
|
||||
|
||||
def __getitem__(self, i):
|
||||
@ -509,11 +508,10 @@ def do_textual_inversion_training(
|
||||
initializer_token: str,
|
||||
save_steps: int = 500,
|
||||
only_save_embeds: bool = False,
|
||||
revision: str = None,
|
||||
tokenizer_name: str = None,
|
||||
tokenizer_name: Optional[str] = None,
|
||||
learnable_property: str = "object",
|
||||
repeats: int = 100,
|
||||
seed: int = None,
|
||||
seed: Optional[int] = None,
|
||||
resolution: int = 512,
|
||||
center_crop: bool = False,
|
||||
train_batch_size: int = 16,
|
||||
@ -530,18 +528,18 @@ def do_textual_inversion_training(
|
||||
adam_weight_decay: float = 1e-02,
|
||||
adam_epsilon: float = 1e-08,
|
||||
push_to_hub: bool = False,
|
||||
hub_token: str = None,
|
||||
hub_token: Optional[str] = None,
|
||||
logging_dir: Path = Path("logs"),
|
||||
mixed_precision: str = "fp16",
|
||||
allow_tf32: bool = False,
|
||||
report_to: str = "tensorboard",
|
||||
local_rank: int = -1,
|
||||
checkpointing_steps: int = 500,
|
||||
resume_from_checkpoint: Path = None,
|
||||
resume_from_checkpoint: Optional[Path] = None,
|
||||
enable_xformers_memory_efficient_attention: bool = False,
|
||||
hub_model_id: str = None,
|
||||
hub_model_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
assert model, "Please specify a base model with --model"
|
||||
assert train_data_dir, "Please specify a directory containing the training images using --train_data_dir"
|
||||
assert placeholder_token, "Please specify a trigger term using --placeholder_token"
|
||||
@ -564,8 +562,6 @@ def do_textual_inversion_training(
|
||||
project_config=accelerator_config,
|
||||
)
|
||||
|
||||
model_manager = ModelManagerService(config, logger)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
@ -603,44 +599,37 @@ def do_textual_inversion_training(
|
||||
elif output_dir is not None:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
known_models = model_manager.model_names()
|
||||
model_name = model.split("/")[-1]
|
||||
model_meta = next((mm for mm in known_models if mm[0].endswith(model_name)), None)
|
||||
assert model_meta is not None, f"Unknown model: {model}"
|
||||
model_info = model_manager.model_info(*model_meta)
|
||||
assert model_info["model_format"] == "diffusers", "This script only works with models of type 'diffusers'"
|
||||
tokenizer_info = model_manager.get_model(*model_meta, submodel=SubModelType.Tokenizer)
|
||||
noise_scheduler_info = model_manager.get_model(*model_meta, submodel=SubModelType.Scheduler)
|
||||
text_encoder_info = model_manager.get_model(*model_meta, submodel=SubModelType.TextEncoder)
|
||||
vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae)
|
||||
unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet)
|
||||
model_records = initialize_record_store(config)
|
||||
base, type, name = model.split("/") # note frontend still returns old-style keys
|
||||
try:
|
||||
model_config = model_records.search_by_attr(
|
||||
model_name=name, model_type=ModelType(type), base_model=BaseModelType(base)
|
||||
)[0]
|
||||
except IndexError:
|
||||
raise Exception(f"Unknown model {model}")
|
||||
model_path = config.models_path / model_config.path
|
||||
|
||||
pipeline_args = {"local_files_only": True}
|
||||
if tokenizer_name:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
|
||||
else:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_info.location, subfolder="tokenizer", **pipeline_args)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer", **pipeline_args)
|
||||
|
||||
# Load scheduler and models
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(
|
||||
noise_scheduler_info.location, subfolder="scheduler", **pipeline_args
|
||||
)
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler", **pipeline_args)
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
text_encoder_info.location,
|
||||
model_path,
|
||||
subfolder="text_encoder",
|
||||
revision=revision,
|
||||
**pipeline_args,
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
vae_info.location,
|
||||
model_path,
|
||||
subfolder="vae",
|
||||
revision=revision,
|
||||
**pipeline_args,
|
||||
)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
unet_info.location,
|
||||
model_path,
|
||||
subfolder="unet",
|
||||
revision=revision,
|
||||
**pipeline_args,
|
||||
)
|
||||
|
||||
@ -728,7 +717,7 @@ def do_textual_inversion_training(
|
||||
max_train_steps = num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
scheduler = get_scheduler(
|
||||
lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
|
||||
@ -737,7 +726,7 @@ def do_textual_inversion_training(
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
text_encoder, optimizer, train_dataloader, scheduler
|
||||
)
|
||||
|
||||
# For mixed precision training we cast the unet and vae weights to half-precision
|
||||
@ -863,7 +852,7 @@ def do_textual_inversion_training(
|
||||
accelerator.backward(loss)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
@ -893,7 +882,7 @@ def do_textual_inversion_training(
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
logs = {"loss": loss.detach().item(), "lr": scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
@ -910,7 +899,7 @@ def do_textual_inversion_training(
|
||||
save_full_model = not only_save_embeds
|
||||
if save_full_model:
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
unet_info.location,
|
||||
model_path,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
|
Reference in New Issue
Block a user