correct fail-to-resume error

- applied https://github.com/huggingface/diffusers/pull/2072 to fix an
  error in the epoch calculation that caused the script not to resume
  from the latest checkpoint when asked to.
Lincoln Stein 2023-01-23 21:04:07 -05:00
parent 292b0d70d8
commit 98e9721101
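
The core of the fix is the bookkeeping that maps a saved global_step back to a starting epoch and an in-epoch offset. A minimal sketch of the corrected arithmetic, using the variable names from the script (the surrounding checkpoint-loading code is omitted):

    # global_step counts optimizer updates; each update consumes
    # gradient_accumulation_steps batches from the dataloader.
    resume_global_step = global_step * gradient_accumulation_steps
    # Epochs are measured in optimizer updates, so divide global_step
    # (not resume_global_step) by the number of updates per epoch.
    first_epoch = global_step // num_update_steps_per_epoch
    # The in-epoch offset is measured in batches, so take the remainder
    # against batches per epoch, not updates per epoch.
    resume_step = resume_global_step % (num_update_steps_per_epoch * gradient_accumulation_steps)

The old code divided resume_global_step by num_update_steps_per_epoch, overstating first_epoch by roughly a factor of gradient_accumulation_steps, which is what kept the resume from taking effect.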


@@ -4,7 +4,6 @@
# and modified slightly by Lincoln Stein (@lstein) to work with InvokeAI
import argparse
from argparse import Namespace
import logging
import math
import os
@@ -207,6 +206,12 @@ def parse_args():
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=Path,
@@ -455,7 +460,8 @@ def do_textual_inversion_training(
checkpointing_steps:int=500,
resume_from_checkpoint:Path=None,
enable_xformers_memory_efficient_attention:bool=False,
root_dir:Path=None
root_dir:Path=None,
hub_model_id:str=None,
):
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != local_rank:
@@ -631,7 +637,7 @@ def do_textual_inversion_training(
text_encoder, optimizer, train_dataloader, lr_scheduler
)
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# For mixed precision training we cast the unet and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
@@ -674,25 +680,29 @@ def do_textual_inversion_training(
# Potentially load in the weights and states from a previous save
if resume_from_checkpoint:
try:
if resume_from_checkpoint != "latest":
path = os.path.basename(resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = os.listdir(output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1]
if resume_from_checkpoint != "latest":
path = os.path.basename(resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = os.listdir(output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(
f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run."
)
resume_from_checkpoint = None
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(output_dir, path))
global_step = int(path.split("-")[1])
resume_global_step = global_step * gradient_accumulation_steps
first_epoch = resume_global_step // num_update_steps_per_epoch
resume_step = resume_global_step % num_update_steps_per_epoch
except:
logger.warn("No checkpoint available to resume from")
first_epoch = global_step // num_update_steps_per_epoch
resume_step = resume_global_step % (num_update_steps_per_epoch * gradient_accumulation_steps)
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
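
To see the effect of the fix on a resumed run, a hypothetical worked example (the numbers are chosen for illustration and are not the script's defaults): 400 batches per epoch with gradient_accumulation_steps = 4 gives num_update_steps_per_epoch = 100; suppose the checkpoint was saved at global_step = 250.

    resume_global_step = 250 * 4           # 1000 batches consumed before the save
    first_epoch_old    = 1000 // 100       # 10 -- far past the epoch actually reached
    first_epoch_new    = 250 // 100        # 2  -- the epoch training was really in
    resume_step_new    = 1000 % (100 * 4)  # 200 batches into epoch 2

With the old value, a run configured for, say, 3 epochs would start at an epoch already past num_train_epochs, so the epoch loop further down the file had nothing left to iterate and the checkpoint was effectively ignored; the corrected values pick training up where it stopped.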