Port the command-line tools to use model_manager2 (#5546)

* Port the command-line tools to use model_manager2

1. Reimplement the following:

  - invokeai-model-install
  - invokeai-merge
  - invokeai-ti

  To avoid breaking the original model manager, the updated tools
  have been renamed invokeai-model-install2 and invokeai-merge2. The
  textual inversion training script should continue to work with
  existing installations. The "starter" models now live in
  `invokeai/configs/INITIAL_MODELS2.yaml`.

  When the full model manager 2 is in place and working, I'll rename
  these files and commands.

2. Add the `merge` route to the web API. This merges two or three models,
   producing a new one (a sketch of calling this route appears after this list).

   - Note that because the model installer selectively installs the `fp16` variant
     of models (rather than both the 16- and 32-bit versions, as it did previously),
     the diffusers merge script will choke on any HuggingFace diffusers models
     that were downloaded with the new installer. Previously-downloaded models
     should continue to merge correctly. I have a PR upstream,
     https://github.com/huggingface/diffusers/pull/6670, to fix
     this.

3. (More important!) During implementation of the CLI tools, I found and fixed
   a number of small runtime bugs in the model_manager2 implementation:

  - During model database migration, if a registered model's file was
    not found on disk, the migration would be aborted. Now the
    offending model is skipped with a log warning (sketched after this
    list).

  - Caught and fixed a condition in which the installer would download the
    entire diffusers repo when the user provided a single `.safetensors`
    file URL.

  - Caught and fixed a condition in which the installer would raise an
    exception and stop the app when a request for an unknown model's metadata
    was passed to Civitai. Now an error is logged and the installer continues.

  - Replaced the LoWRA starter LoRA with FlatColor. The former has been removed
    from Civitai.
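
Below is a sketch of calling the new merge route over HTTP. The route path,
HTTP verb, and payload field names are illustrative assumptions (this commit
only states that a `merge` route was added); the fields mirror the parameters
of `ModelMerger.merge_diffusion_models_and_save()` in the diff below.

```python
import requests

BASE_URL = "http://localhost:9090"  # assumed local InvokeAI server address

# Hypothetical route and payload shape; the field names follow the merge
# method's signature (model keys, merged_model_name, alpha, interp, force).
response = requests.put(
    f"{BASE_URL}/api/v2/models/merge",
    json={
        "keys": ["<key-of-model-A>", "<key-of-model-B>"],  # two or three models
        "merged_model_name": "my-merged-model",
        "alpha": 0.5,  # interpolation weight, 0..1
        "interp": "weighted_sum",  # or "sigmoid", "inv_sigmoid", "add_difference"
        "force": False,
    },
    timeout=600,  # merging large checkpoints can be slow
)
response.raise_for_status()
print(response.json())  # config record of the newly registered merged model
```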
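
And a minimal sketch of the skip-and-warn migration behavior described in the
first bullet above; every name here is illustrative, not the actual migration
code:

```python
from logging import Logger
from pathlib import Path
from typing import Any, Dict, List


def migrate_models(records: List[Dict[str, Any]], models_dir: Path, logger: Logger) -> None:
    """Illustrative migration loop: skip missing model files instead of aborting."""
    migrated: List[Dict[str, Any]] = []
    for record in records:
        model_file = models_dir / record["path"]
        if not model_file.exists():
            # Previously a missing file aborted the whole migration; now the
            # offending model is skipped with a log warning.
            logger.warning(f"{model_file}: file not found on disk; skipping")
            continue
        migrated.append(record)
    logger.info(f"Migrated {len(migrated)} of {len(records)} models")
```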

* fix ruff issue

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Lincoln Stein, 2024-02-02 12:18:47 -05:00, committed by GitHub
parent d3320dc4ee, commit f2777f5096
17 changed files with 2297 additions and 78 deletions


@@ -0,0 +1,281 @@
"""Utility (backend) functions used by model_install.py"""
import re
from logging import Logger
from pathlib import Path
from typing import Any, Dict, List, Optional
import omegaconf
from huggingface_hub import HfFolder
from pydantic import BaseModel, Field
from pydantic.dataclasses import dataclass
from pydantic.networks import AnyHttpUrl
from requests import HTTPError
from tqdm import tqdm
import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
from invokeai.app.services.model_install import (
HFModelSource,
LocalModelSource,
ModelInstallService,
ModelInstallServiceBase,
ModelSource,
URLModelSource,
)
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager import (
BaseModelType,
InvalidModelConfigException,
ModelType,
)
from invokeai.backend.model_manager.metadata import UnknownMetadataException
from invokeai.backend.util.logging import InvokeAILogger
# name of the starter models file
INITIAL_MODELS = "INITIAL_MODELS2.yaml"


def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
    """Return an initialized ModelConfigRecordServiceBase object."""
    logger = InvokeAILogger.get_logger(config=app_config)
    image_files = DiskImageFileStorage(f"{app_config.output_path}/images")
    db = init_db(config=app_config, logger=logger, image_files=image_files)
    obj: ModelRecordServiceBase = ModelRecordServiceSQL(db)
    return obj


def initialize_installer(
    app_config: InvokeAIAppConfig, event_bus: Optional[EventServiceBase] = None
) -> ModelInstallServiceBase:
    """Return an initialized ModelInstallService object."""
    record_store = initialize_record_store(app_config)
    metadata_store = record_store.metadata_store
    download_queue = DownloadQueueService()
    installer = ModelInstallService(
        app_config=app_config,
        record_store=record_store,
        metadata_store=metadata_store,
        download_queue=download_queue,
        event_bus=event_bus,
    )
    download_queue.start()
    installer.start()
    return installer


class UnifiedModelInfo(BaseModel):
    """Catchall class for information in INITIAL_MODELS2.yaml."""

    name: Optional[str] = None
    base: Optional[BaseModelType] = None
    type: Optional[ModelType] = None
    source: Optional[str] = None
    subfolder: Optional[str] = None
    description: Optional[str] = None
    recommended: bool = False
    installed: bool = False
    default: bool = False
    requires: List[str] = Field(default_factory=list)


@dataclass
class InstallSelections:
    """Lists of models to install and remove."""

    install_models: List[UnifiedModelInfo] = Field(default_factory=list)
    remove_models: List[str] = Field(default_factory=list)


class TqdmEventService(EventServiceBase):
    """An event service to track downloads."""

    def __init__(self) -> None:
        """Create a new TqdmEventService object."""
        super().__init__()
        self._bars: Dict[str, tqdm] = {}
        self._last: Dict[str, int] = {}

    def dispatch(self, event_name: str, payload: Any) -> None:
        """Dispatch an event by appending it to self.events."""
        if payload["event"] == "model_install_downloading":
            data = payload["data"]
            dest = data["local_path"]
            total_bytes = data["total_bytes"]
            bytes = data["bytes"]
            if dest not in self._bars:
                self._bars[dest] = tqdm(desc=Path(dest).name, initial=0, total=total_bytes, unit="iB", unit_scale=True)
                self._last[dest] = 0
            self._bars[dest].update(bytes - self._last[dest])
            self._last[dest] = bytes


class InstallHelper(object):
    """Capture information stored jointly in INITIAL_MODELS.yaml and the installed models db."""

    def __init__(self, app_config: InvokeAIAppConfig, logger: Logger):
        """Create new InstallHelper object."""
        self._app_config = app_config
        self.all_models: Dict[str, UnifiedModelInfo] = {}

        omega = omegaconf.OmegaConf.load(Path(configs.__path__[0]) / INITIAL_MODELS)
        assert isinstance(omega, omegaconf.dictconfig.DictConfig)

        self._installer = initialize_installer(app_config, TqdmEventService())
        self._initial_models = omega
        self._installed_models: List[str] = []
        self._starter_models: List[str] = []
        self._default_model: Optional[str] = None
        self._logger = logger
        self._initialize_model_lists()

    @property
    def installer(self) -> ModelInstallServiceBase:
        """Return the installer object used internally."""
        return self._installer

    def _initialize_model_lists(self) -> None:
        """
        Initialize our model slots.

        Set up the following:
            installed_models -- list of installed model keys
            starter_models -- list of starter model keys from INITIAL_MODELS
            all_models -- dict of key => UnifiedModelInfo
            default_model -- key to default model
        """
        # previously-installed models
        for model in self._installer.record_store.all_models():
            info = UnifiedModelInfo.parse_obj(model.dict())
            info.installed = True
            model_key = f"{model.base.value}/{model.type.value}/{model.name}"
            self.all_models[model_key] = info
            self._installed_models.append(model_key)

        for key in self._initial_models.keys():
            assert isinstance(key, str)
            if key in self.all_models:
                # we want to preserve the description
                description = self.all_models[key].description or self._initial_models[key].get("description")
                self.all_models[key].description = description
            else:
                base_model, model_type, model_name = key.split("/")
                info = UnifiedModelInfo(
                    name=model_name,
                    type=ModelType(model_type),
                    base=BaseModelType(base_model),
                    source=self._initial_models[key].source,
                    description=self._initial_models[key].get("description"),
                    recommended=self._initial_models[key].get("recommended", False),
                    default=self._initial_models[key].get("default", False),
                    subfolder=self._initial_models[key].get("subfolder"),
                    requires=list(self._initial_models[key].get("requires", [])),
                )
                self.all_models[key] = info
            if not self.default_model():
                self._default_model = key
            elif self._initial_models[key].get("default", False):
                self._default_model = key
            self._starter_models.append(key)

    def recommended_models(self) -> List[UnifiedModelInfo]:
        """List of the models recommended in INITIAL_MODELS.yaml."""
        return [self._to_model(x) for x in self._starter_models if self._to_model(x).recommended]

    def installed_models(self) -> List[UnifiedModelInfo]:
        """List of models already installed."""
        return [self._to_model(x) for x in self._installed_models]

    def starter_models(self) -> List[UnifiedModelInfo]:
        """List of starter models."""
        return [self._to_model(x) for x in self._starter_models]

    def default_model(self) -> Optional[UnifiedModelInfo]:
        """Return the default model."""
        return self._to_model(self._default_model) if self._default_model else None

    def _to_model(self, key: str) -> UnifiedModelInfo:
        return self.all_models[key]

    def _add_required_models(self, model_list: List[UnifiedModelInfo]) -> None:
        installed = {x.source for x in self.installed_models()}
        reverse_source = {x.source: x for x in self.all_models.values()}
        additional_models: List[UnifiedModelInfo] = []
        for model_info in model_list:
            for requirement in model_info.requires:
                if requirement not in installed and reverse_source.get(requirement):
                    additional_models.append(reverse_source[requirement])
        model_list.extend(additional_models)

    def _make_install_source(self, model_info: UnifiedModelInfo) -> ModelSource:
        assert model_info.source
        model_path_id_or_url = model_info.source.strip("\"' ")
        model_path = Path(model_path_id_or_url)
        if model_path.exists():  # local file on disk
            return LocalModelSource(path=model_path.absolute(), inplace=True)
        if re.match(r"^[^/]+/[^/]+$", model_path_id_or_url):  # hugging face repo_id
            return HFModelSource(
                repo_id=model_path_id_or_url,
                access_token=HfFolder.get_token(),
                subfolder=model_info.subfolder,
            )
        if re.match(r"^(http|https):", model_path_id_or_url):
            return URLModelSource(url=AnyHttpUrl(model_path_id_or_url))
        raise ValueError(f"Unsupported model source: {model_path_id_or_url}")

    def add_or_delete(self, selections: InstallSelections) -> None:
        """Add or delete selected models."""
        installer = self._installer
        self._add_required_models(selections.install_models)
        for model in selections.install_models:
            source = self._make_install_source(model)
            config = (
                {
                    "description": model.description,
                    "name": model.name,
                }
                if model.name
                else None
            )
            try:
                installer.import_model(
                    source=source,
                    config=config,
                )
            except (UnknownMetadataException, InvalidModelConfigException, HTTPError, OSError) as e:
                self._logger.warning(f"{source}: {e}")

        for model_to_remove in selections.remove_models:
            parts = model_to_remove.split("/")
            if len(parts) == 1:
                base_model, model_type, model_name = (None, None, model_to_remove)
            else:
                base_model, model_type, model_name = parts
            matches = installer.record_store.search_by_attr(
                base_model=BaseModelType(base_model) if base_model else None,
                model_type=ModelType(model_type) if model_type else None,
                model_name=model_name,
            )
            if len(matches) > 1:
                print(
                    f"{model_to_remove} is ambiguous. Please use model_type:model_name (e.g. main:my_model) to disambiguate."
                )
            elif not matches:
                print(f"{model_to_remove}: unknown model")
            else:
                for m in matches:
                    print(f"Deleting {m.type}:{m.name}")
                    installer.delete(m.key)

        installer.wait_for_installs()
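
For orientation, a minimal usage sketch of the helper above (not part of the
diff): install every recommended starter model that is not already present,
assuming a normally initialized InvokeAI root.

```python
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections
from invokeai.backend.util.logging import InvokeAILogger

config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.get_logger(config=config)
helper = InstallHelper(config, logger)

# Queue all recommended starter models that are not yet installed,
# then block until the installs complete.
to_install = [m for m in helper.recommended_models() if not m.installed]
helper.add_or_delete(InstallSelections(install_models=to_install))
```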


@@ -849,7 +849,7 @@ def migrate_if_needed(opt: Namespace, root: Path) -> bool:
 # -------------------------------------
-def main():
+def main() -> None:
     parser = argparse.ArgumentParser(description="InvokeAI model downloader")
     parser.add_argument(
         "--skip-sd-weights",


@@ -0,0 +1,177 @@
"""
invokeai.backend.model_manager.merge exports:
merge_diffusion_models() -- combine multiple models by location and return a pipeline object
merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""

import warnings
from enum import Enum
from pathlib import Path
from typing import Any, List, Optional, Set

import torch
from diffusers import AutoPipelineForText2Image
from diffusers import logging as dlogging

from invokeai.app.services.model_install import ModelInstallServiceBase
from invokeai.backend.util.devices import choose_torch_device, torch_dtype

from . import (
    AnyModelConfig,
    BaseModelType,
    ModelType,
    ModelVariantType,
)
from .config import MainDiffusersConfig


class MergeInterpolationMethod(str, Enum):
    WeightedSum = "weighted_sum"
    Sigmoid = "sigmoid"
    InvSigmoid = "inv_sigmoid"
    AddDifference = "add_difference"


class ModelMerger(object):
    """Wrapper class for model merge function."""

    def __init__(self, installer: ModelInstallServiceBase):
        """
        Initialize a ModelMerger object.

        :param installer: ModelInstallServiceBase instance used to register and store the merged model.
        """
        self._installer = installer

    def merge_diffusion_models(
        self,
        model_paths: List[Path],
        alpha: float = 0.5,
        interp: Optional[MergeInterpolationMethod] = None,
        force: bool = False,
        variant: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:  # pipe.merge is an untyped function.
        """
        :param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
        :param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints
            are merged. A 0.8 alpha means the first model's checkpoints affect the final result far less than an
            alpha of 0.2 would.
        :param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid",
            "add_difference" and None. Passing None uses the default interpolation, which is weighted-sum
            interpolation. For merging three checkpoints, only "add_difference" is supported.
        :param force: Whether to ignore mismatches in model_config.json for the current models. Defaults to False.

        **kwargs - the default DiffusionPipeline.get_config_dict kwargs:
            cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision,
            torch_dtype, device_map
        """
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            verbosity = dlogging.get_verbosity()
            dlogging.set_verbosity_error()
            dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())

            # Note that checkpoint_merger will not work with downloaded HuggingFace fp16 models
            # until upstream https://github.com/huggingface/diffusers/pull/6670 is merged and released.
            pipe = AutoPipelineForText2Image.from_pretrained(
                model_paths[0],
                custom_pipeline="checkpoint_merger",
                torch_dtype=dtype,
                variant=variant,
            )
            merged_pipe = pipe.merge(
                pretrained_model_name_or_path_list=model_paths,
                alpha=alpha,
                interp=interp.value if interp else None,  # diffusers API treats None as "weighted sum"
                force=force,
                torch_dtype=dtype,
                variant=variant,
                **kwargs,
            )
            dlogging.set_verbosity(verbosity)
        return merged_pipe

    def merge_diffusion_models_and_save(
        self,
        model_keys: List[str],
        merged_model_name: str,
        alpha: float = 0.5,
        force: bool = False,
        interp: Optional[MergeInterpolationMethod] = None,
        merge_dest_directory: Optional[Path] = None,
        variant: Optional[str] = None,
        **kwargs: Any,
    ) -> AnyModelConfig:
        """
        :param model_keys: up to three models, designated by their database keys
        :param merged_model_name: name for the new merged model
        :param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints
            are merged. A 0.8 alpha means the first model's checkpoints affect the final result far less than an
            alpha of 0.2 would.
        :param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid",
            "inv_sigmoid", "add_difference" and None. Passing None uses the default interpolation, which is
            weighted-sum interpolation. For merging three checkpoints, only "add_difference" is supported;
            add_difference is A+(B-C).
        :param force: Whether to ignore mismatches in model_config.json for the current models. Defaults to False.
        :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)

        **kwargs - the default DiffusionPipeline.get_config_dict kwargs:
            cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision,
            torch_dtype, device_map
        """
        model_paths: List[Path] = []
        model_names: List[str] = []
        config = self._installer.app_config
        store = self._installer.record_store
        base_models: Set[BaseModelType] = set()
        vae = None
        variant = None if self._installer.app_config.full_precision else "fp16"

        assert (
            len(model_keys) <= 2 or interp == MergeInterpolationMethod.AddDifference
        ), "When merging three models, only the 'add_difference' merge method is supported"

        for key in model_keys:
            info = store.get_model(key)
            model_names.append(info.name)
            assert isinstance(
                info, MainDiffusersConfig
            ), f"{info.name} ({info.key}) is not a diffusers model. It must be optimized before merging"
            assert info.variant == ModelVariantType(
                "normal"
            ), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"

            # pick up the first model's vae
            if key == model_keys[0]:
                vae = info.vae

            # tally base models used
            base_models.add(info.base)
            model_paths.extend([config.models_path / info.path])

        assert len(base_models) == 1, f"All models to merge must have same base model, but found bases {base_models}"
        base_model = base_models.pop()

        merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp)
        merged_pipe = self.merge_diffusion_models(model_paths, alpha, merge_method, force, variant=variant, **kwargs)
        dump_path = (
            Path(merge_dest_directory)
            if merge_dest_directory
            else config.models_path / base_model.value / ModelType.Main.value
        )
        dump_path.mkdir(parents=True, exist_ok=True)
        dump_path = dump_path / merged_model_name

        dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
        merged_pipe.save_pretrained(dump_path.as_posix(), safe_serialization=True, torch_dtype=dtype, variant=variant)

        # register model and get its unique key
        key = self._installer.register_path(dump_path)

        # update model's config
        model_config = self._installer.record_store.get_model(key)
        model_config.update(
            {
                "name": merged_model_name,
                "description": f"Merge of models {', '.join(model_names)}",
                "vae": vae,
            }
        )
        self._installer.record_store.update_model(key, model_config)
        return model_config
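
A minimal usage sketch of the merger above (not part of the diff); the two
model keys are placeholders for records in the model database:

```python
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.install.install_helper import initialize_installer
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger

config = InvokeAIAppConfig.get_config()
installer = initialize_installer(config)
merger = ModelMerger(installer)

# Weighted-sum merge of two models; the result is saved under
# models/<base>/main/<merged_model_name> and registered in the database.
merged_config = merger.merge_diffusion_models_and_save(
    model_keys=["<key-of-model-A>", "<key-of-model-B>"],
    merged_model_name="my-weighted-sum-merge",
    alpha=0.5,
    interp=MergeInterpolationMethod.WeightedSum,
)
print(merged_config.name)
```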


@@ -170,6 +170,8 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
         if model_id is None:
             version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
             version = self._requests.get(version_url).json()
+            if error := version.get("error"):
+                raise UnknownMetadataException(error)
             model_id = version["modelId"]
         model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)


@@ -11,6 +11,7 @@ import logging
 import math
 import os
 import random
+from argparse import Namespace
 from pathlib import Path
 from typing import Optional
@@ -30,8 +31,6 @@ from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version
 from diffusers.utils.import_utils import is_xformers_available
-from huggingface_hub import HfFolder, Repository, whoami
-# TODO: remove and import from diffusers.utils when the new version of diffusers is released
 from packaging import version
 from PIL import Image
 from torch.utils.data import Dataset
@@ -41,8 +40,8 @@ from transformers import CLIPTextModel, CLIPTokenizer
 # invokeai stuff
 from invokeai.app.services.config import InvokeAIAppConfig, PagingArgumentParser
-from invokeai.app.services.model_manager import ModelManagerService
-from invokeai.backend.model_management.models import SubModelType
+from invokeai.backend.install.install_helper import initialize_record_store
+from invokeai.backend.model_manager import BaseModelType, ModelType
 
 if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
     PIL_INTERPOLATION = {
@@ -77,7 +76,7 @@ def save_progress(text_encoder, placeholder_token_id, accelerator, placeholder_t
     torch.save(learned_embeds_dict, save_path)
 
-def parse_args():
+def parse_args() -> Namespace:
     config = InvokeAIAppConfig.get_config()
     parser = PagingArgumentParser(description="Textual inversion training")
     general_group = parser.add_argument_group("General")
@@ -444,7 +443,7 @@ class TextualInversionDataset(Dataset):
         self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
         self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self._length
 
     def __getitem__(self, i):
@@ -509,11 +508,10 @@ def do_textual_inversion_training(
     initializer_token: str,
     save_steps: int = 500,
     only_save_embeds: bool = False,
-    revision: str = None,
-    tokenizer_name: str = None,
+    tokenizer_name: Optional[str] = None,
     learnable_property: str = "object",
     repeats: int = 100,
-    seed: int = None,
+    seed: Optional[int] = None,
     resolution: int = 512,
     center_crop: bool = False,
     train_batch_size: int = 16,
@@ -530,18 +528,18 @@ def do_textual_inversion_training(
     adam_weight_decay: float = 1e-02,
     adam_epsilon: float = 1e-08,
     push_to_hub: bool = False,
-    hub_token: str = None,
+    hub_token: Optional[str] = None,
     logging_dir: Path = Path("logs"),
     mixed_precision: str = "fp16",
     allow_tf32: bool = False,
     report_to: str = "tensorboard",
     local_rank: int = -1,
     checkpointing_steps: int = 500,
-    resume_from_checkpoint: Path = None,
+    resume_from_checkpoint: Optional[Path] = None,
     enable_xformers_memory_efficient_attention: bool = False,
-    hub_model_id: str = None,
+    hub_model_id: Optional[str] = None,
     **kwargs,
-):
+) -> None:
     assert model, "Please specify a base model with --model"
     assert train_data_dir, "Please specify a directory containing the training images using --train_data_dir"
     assert placeholder_token, "Please specify a trigger term using --placeholder_token"
@@ -564,8 +562,6 @@ def do_textual_inversion_training(
         project_config=accelerator_config,
     )
-    model_manager = ModelManagerService(config, logger)
 
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -603,44 +599,37 @@ def do_textual_inversion_training(
     elif output_dir is not None:
         os.makedirs(output_dir, exist_ok=True)
-    known_models = model_manager.model_names()
-    model_name = model.split("/")[-1]
-    model_meta = next((mm for mm in known_models if mm[0].endswith(model_name)), None)
-    assert model_meta is not None, f"Unknown model: {model}"
-    model_info = model_manager.model_info(*model_meta)
-    assert model_info["model_format"] == "diffusers", "This script only works with models of type 'diffusers'"
-    tokenizer_info = model_manager.get_model(*model_meta, submodel=SubModelType.Tokenizer)
-    noise_scheduler_info = model_manager.get_model(*model_meta, submodel=SubModelType.Scheduler)
-    text_encoder_info = model_manager.get_model(*model_meta, submodel=SubModelType.TextEncoder)
-    vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae)
-    unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet)
+    model_records = initialize_record_store(config)
+    base, type, name = model.split("/")  # note frontend still returns old-style keys
+    try:
+        model_config = model_records.search_by_attr(
+            model_name=name, model_type=ModelType(type), base_model=BaseModelType(base)
+        )[0]
+    except IndexError:
+        raise Exception(f"Unknown model {model}")
+    model_path = config.models_path / model_config.path
 
     pipeline_args = {"local_files_only": True}
     if tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
     else:
-        tokenizer = CLIPTokenizer.from_pretrained(tokenizer_info.location, subfolder="tokenizer", **pipeline_args)
+        tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer", **pipeline_args)
 
     # Load scheduler and models
-    noise_scheduler = DDPMScheduler.from_pretrained(
-        noise_scheduler_info.location, subfolder="scheduler", **pipeline_args
-    )
+    noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler", **pipeline_args)
     text_encoder = CLIPTextModel.from_pretrained(
-        text_encoder_info.location,
+        model_path,
         subfolder="text_encoder",
-        revision=revision,
         **pipeline_args,
     )
     vae = AutoencoderKL.from_pretrained(
-        vae_info.location,
+        model_path,
         subfolder="vae",
-        revision=revision,
         **pipeline_args,
     )
     unet = UNet2DConditionModel.from_pretrained(
-        unet_info.location,
+        model_path,
         subfolder="unet",
-        revision=revision,
         **pipeline_args,
     )
@@ -728,7 +717,7 @@ def do_textual_inversion_training(
         max_train_steps = num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
-    lr_scheduler = get_scheduler(
+    scheduler = get_scheduler(
         lr_scheduler,
         optimizer=optimizer,
         num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
@@ -737,7 +726,7 @@ def do_textual_inversion_training(
     # Prepare everything with our `accelerator`.
     text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, lr_scheduler
+        text_encoder, optimizer, train_dataloader, scheduler
    )
 
     # For mixed precision training we cast the unet and vae weights to half-precision
@@ -863,7 +852,7 @@ def do_textual_inversion_training(
                 accelerator.backward(loss)
 
                 optimizer.step()
-                lr_scheduler.step()
+                scheduler.step()
                 optimizer.zero_grad()
 
             # Let's make sure we don't update any embedding weights besides the newly added token
@@ -893,7 +882,7 @@ def do_textual_inversion_training(
                 accelerator.save_state(save_path)
                 logger.info(f"Saved state to {save_path}")
 
-            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+            logs = {"loss": loss.detach().item(), "lr": scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
@@ -910,7 +899,7 @@ def do_textual_inversion_training(
         save_full_model = not only_save_embeds
         if save_full_model:
             pipeline = StableDiffusionPipeline.from_pretrained(
-                unet_info.location,
+                model_path,
                 text_encoder=accelerator.unwrap_model(text_encoder),
                 vae=vae,
                 unet=unet,