Merge branch 'main' into lstein/default-model-install

commit 7fa394912d
Lincoln Stein, 2023-07-15 18:26:35 -04:00 (committed by GitHub)


@@ -24,7 +24,7 @@ import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import set_seed
+from accelerate.utils import set_seed, ProjectConfiguration
 from diffusers import (
     AutoencoderKL,
     DDPMScheduler,
@@ -35,7 +35,6 @@ from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version
 from diffusers.utils.import_utils import is_xformers_available
 from huggingface_hub import HfFolder, Repository, whoami
-from omegaconf import OmegaConf
 
 # TODO: remove and import from diffusers.utils when the new version of diffusers is released
 from packaging import version
@@ -47,6 +46,8 @@ from transformers import CLIPTextModel, CLIPTokenizer
 
 # invokeai stuff
 from invokeai.app.services.config import InvokeAIAppConfig,PagingArgumentParser
+from invokeai.app.services.model_manager_service import ModelManagerService
+from invokeai.backend.model_management.models import SubModelType
 
 if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
     PIL_INTERPOLATION = {
@@ -132,7 +133,7 @@ def parse_args():
     model_group.add_argument(
         "--model",
         type=str,
-        default="stable-diffusion-1.5",
+        default="sd-1/main/stable-diffusion-v1-5",
         help="Name of the diffusers model to train against, as defined in configs/models.yaml.",
     )
     model_group.add_argument(
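
Note on the new default: InvokeAI model keys take the three-part form base/type/name, so "sd-1/main/stable-diffusion-v1-5" names the Stable Diffusion 1.5 "main" model for the sd-1 base. A minimal standalone sketch of splitting such a key; the ModelKey class and parse_model_key helper are illustrative, not InvokeAI APIs:

from dataclasses import dataclass

# Illustrative only: ModelKey and parse_model_key are not InvokeAI APIs.
@dataclass
class ModelKey:
    base: str        # e.g. "sd-1"
    model_type: str  # e.g. "main"
    name: str        # e.g. "stable-diffusion-v1-5"

def parse_model_key(key: str) -> ModelKey:
    base, model_type, name = key.split("/", 2)
    return ModelKey(base, model_type, name)

print(parse_model_key("sd-1/main/stable-diffusion-v1-5"))
# -> ModelKey(base='sd-1', model_type='main', name='stable-diffusion-v1-5')
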
@@ -565,7 +566,6 @@ def do_textual_inversion_training(
     checkpointing_steps: int = 500,
     resume_from_checkpoint: Path = None,
     enable_xformers_memory_efficient_attention: bool = False,
-    root_dir: Path = None,
     hub_model_id: str = None,
     **kwargs,
 ):
@@ -584,13 +584,17 @@ def do_textual_inversion_training(
     logging_dir = output_dir / logging_dir
 
+    accelerator_config = ProjectConfiguration()
+    accelerator_config.logging_dir = logging_dir
     accelerator = Accelerator(
         gradient_accumulation_steps=gradient_accumulation_steps,
         mixed_precision=mixed_precision,
         log_with=report_to,
-        logging_dir=logging_dir,
+        project_config=accelerator_config,
     )
 
+    model_manager = ModelManagerService(config,logger)
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -628,46 +632,46 @@ def do_textual_inversion_training(
     elif output_dir is not None:
         os.makedirs(output_dir, exist_ok=True)
 
-    models_conf = OmegaConf.load(config.model_conf_path)
-    model_conf = models_conf.get(model, None)
-    assert model_conf is not None, f"Unknown model: {model}"
+    known_models = model_manager.model_names()
+    model_name = model.split('/')[-1]
+    model_meta = next((mm for mm in known_models if mm[0].endswith(model_name)), None)
+    assert model_meta is not None, f"Unknown model: {model}"
+    model_info = model_manager.model_info(*model_meta)
     assert (
-        model_conf.get("format", "diffusers") == "diffusers"
+        model_info['model_format'] == "diffusers"
     ), "This script only works with models of type 'diffusers'"
-    pretrained_model_name_or_path = model_conf.get("repo_id", None) or Path(
-        model_conf.get("path")
-    )
-    assert (
-        pretrained_model_name_or_path
-    ), f"models.yaml error: neither 'repo_id' nor 'path' is defined for {model}"
-    pipeline_args = dict(cache_dir=config.cache_dir)
-
-    # Load tokenizer
+    tokenizer_info = model_manager.get_model(*model_meta, submodel=SubModelType.Tokenizer)
+    noise_scheduler_info = model_manager.get_model(*model_meta, submodel=SubModelType.Scheduler)
+    text_encoder_info = model_manager.get_model(*model_meta, submodel=SubModelType.TextEncoder)
+    vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae)
+    unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet)
+
+    pipeline_args = dict(local_files_only=True)
     if tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
     else:
         tokenizer = CLIPTokenizer.from_pretrained(
-            pretrained_model_name_or_path, subfolder="tokenizer", **pipeline_args
+            tokenizer_info.location, subfolder='tokenizer', **pipeline_args
         )
 
     # Load scheduler and models
     noise_scheduler = DDPMScheduler.from_pretrained(
-        pretrained_model_name_or_path, subfolder="scheduler", **pipeline_args
+        noise_scheduler_info.location, subfolder="scheduler", **pipeline_args
     )
     text_encoder = CLIPTextModel.from_pretrained(
-        pretrained_model_name_or_path,
+        text_encoder_info.location,
         subfolder="text_encoder",
         revision=revision,
         **pipeline_args,
     )
     vae = AutoencoderKL.from_pretrained(
-        pretrained_model_name_or_path,
+        vae_info.location,
         subfolder="vae",
         revision=revision,
         **pipeline_args,
     )
     unet = UNet2DConditionModel.from_pretrained(
-        pretrained_model_name_or_path,
+        unet_info.location,
         subfolder="unet",
         revision=revision,
         **pipeline_args,
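
How the new lookup resolves the CLI's model key: only the model name (the last path segment) is matched against the names the model manager knows. A standalone sketch with a stubbed model list; the exact tuple layout returned by model_names() is an assumption here, since the diff relies only on mm[0] being the name:

# Stubbed stand-in for model_manager.model_names(); the real tuple layout is
# an assumption -- the diff relies only on mm[0] being the model's name.
known_models = [
    ("stable-diffusion-v1-5", "sd-1", "main"),
    ("stable-diffusion-2-1", "sd-2", "main"),
]

model = "sd-1/main/stable-diffusion-v1-5"
model_name = model.split("/")[-1]  # match on the name segment only
model_meta = next((mm for mm in known_models if mm[0].endswith(model_name)), None)
assert model_meta is not None, f"Unknown model: {model}"
print(model_meta)  # ('stable-diffusion-v1-5', 'sd-1', 'main')
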
@@ -989,7 +993,7 @@ def do_textual_inversion_training(
     save_full_model = not only_save_embeds
     if save_full_model:
         pipeline = StableDiffusionPipeline.from_pretrained(
-            pretrained_model_name_or_path,
+            unet_info.location,
             text_encoder=accelerator.unwrap_model(text_encoder),
             vae=vae,
             unet=unet,
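
For context, this last hunk is the full-model save path: it rebuilds a StableDiffusionPipeline around the trained components before writing it out. A standalone sketch of that pattern with diffusers; "path/to/model" is a placeholder for a local diffusers-format model directory:

from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

root = "path/to/model"  # placeholder: a local diffusers-format model directory

# Load the individual components, as the training script does.
text_encoder = CLIPTextModel.from_pretrained(root, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(root, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(root, subfolder="unet")
tokenizer = CLIPTokenizer.from_pretrained(root, subfolder="tokenizer")

# Passing components explicitly overrides whatever from_pretrained would load,
# so the trained text encoder ends up inside the saved pipeline.
pipeline = StableDiffusionPipeline.from_pretrained(
    root,
    text_encoder=text_encoder,
    vae=vae,
    unet=unet,
    tokenizer=tokenizer,
)
pipeline.save_pretrained("textual-inversion-output")
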