mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'sdxl-support' of github.com:invoke-ai/InvokeAI into sdxl-support
This commit is contained in:
commit
ab840742b0
@ -24,7 +24,7 @@ import torch.utils.checkpoint
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from accelerate.utils import set_seed, ProjectConfiguration
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
@ -35,7 +35,6 @@ from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from huggingface_hub import HfFolder, Repository, whoami
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||
from packaging import version
|
||||
@ -47,6 +46,8 @@ from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
# invokeai stuff
|
||||
from invokeai.app.services.config import InvokeAIAppConfig,PagingArgumentParser
|
||||
from invokeai.app.services.model_manager_service import ModelManagerService
|
||||
from invokeai.backend.model_management.models import SubModelType
|
||||
|
||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||
PIL_INTERPOLATION = {
|
||||
@ -132,7 +133,7 @@ def parse_args():
|
||||
model_group.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="stable-diffusion-1.5",
|
||||
default="sd-1/main/stable-diffusion-v1-5",
|
||||
help="Name of the diffusers model to train against, as defined in configs/models.yaml.",
|
||||
)
|
||||
model_group.add_argument(
|
||||
@ -565,7 +566,6 @@ def do_textual_inversion_training(
|
||||
checkpointing_steps: int = 500,
|
||||
resume_from_checkpoint: Path = None,
|
||||
enable_xformers_memory_efficient_attention: bool = False,
|
||||
root_dir: Path = None,
|
||||
hub_model_id: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
@ -584,13 +584,17 @@ def do_textual_inversion_training(
|
||||
|
||||
logging_dir = output_dir / logging_dir
|
||||
|
||||
accelerator_config = ProjectConfiguration()
|
||||
accelerator_config.logging_dir = logging_dir
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
mixed_precision=mixed_precision,
|
||||
log_with=report_to,
|
||||
logging_dir=logging_dir,
|
||||
project_config=accelerator_config,
|
||||
)
|
||||
|
||||
model_manager = ModelManagerService(config,logger)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
@ -628,46 +632,46 @@ def do_textual_inversion_training(
|
||||
elif output_dir is not None:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
models_conf = OmegaConf.load(config.model_conf_path)
|
||||
model_conf = models_conf.get(model, None)
|
||||
assert model_conf is not None, f"Unknown model: {model}"
|
||||
known_models = model_manager.model_names()
|
||||
model_name = model.split('/')[-1]
|
||||
model_meta = next((mm for mm in known_models if mm[0].endswith(model_name)), None)
|
||||
assert model_meta is not None, f"Unknown model: {model}"
|
||||
model_info = model_manager.model_info(*model_meta)
|
||||
assert (
|
||||
model_conf.get("format", "diffusers") == "diffusers"
|
||||
model_info['model_format'] == "diffusers"
|
||||
), "This script only works with models of type 'diffusers'"
|
||||
pretrained_model_name_or_path = model_conf.get("repo_id", None) or Path(
|
||||
model_conf.get("path")
|
||||
)
|
||||
assert (
|
||||
pretrained_model_name_or_path
|
||||
), f"models.yaml error: neither 'repo_id' nor 'path' is defined for {model}"
|
||||
pipeline_args = dict(cache_dir=config.cache_dir)
|
||||
tokenizer_info = model_manager.get_model(*model_meta, submodel=SubModelType.Tokenizer)
|
||||
noise_scheduler_info = model_manager.get_model(*model_meta, submodel=SubModelType.Scheduler)
|
||||
text_encoder_info = model_manager.get_model(*model_meta, submodel=SubModelType.TextEncoder)
|
||||
vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae)
|
||||
unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet)
|
||||
|
||||
# Load tokenizer
|
||||
pipeline_args = dict(local_files_only=True)
|
||||
if tokenizer_name:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
|
||||
else:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
pretrained_model_name_or_path, subfolder="tokenizer", **pipeline_args
|
||||
tokenizer_info.location, subfolder='tokenizer', **pipeline_args
|
||||
)
|
||||
|
||||
# Load scheduler and models
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(
|
||||
pretrained_model_name_or_path, subfolder="scheduler", **pipeline_args
|
||||
noise_scheduler_info.location, subfolder="scheduler", **pipeline_args
|
||||
)
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
text_encoder_info.location,
|
||||
subfolder="text_encoder",
|
||||
revision=revision,
|
||||
**pipeline_args,
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
vae_info.location,
|
||||
subfolder="vae",
|
||||
revision=revision,
|
||||
**pipeline_args,
|
||||
)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
unet_info.location,
|
||||
subfolder="unet",
|
||||
revision=revision,
|
||||
**pipeline_args,
|
||||
@ -989,7 +993,7 @@ def do_textual_inversion_training(
|
||||
save_full_model = not only_save_embeds
|
||||
if save_full_model:
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
unet_info.location,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
|
Loading…
Reference in New Issue
Block a user