mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into lstein/default-model-install
This commit is contained in:
commit
7fa394912d
@ -24,7 +24,7 @@ import torch.utils.checkpoint
|
|||||||
import transformers
|
import transformers
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed, ProjectConfiguration
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
DDPMScheduler,
|
DDPMScheduler,
|
||||||
@ -35,7 +35,6 @@ from diffusers.optimization import get_scheduler
|
|||||||
from diffusers.utils import check_min_version
|
from diffusers.utils import check_min_version
|
||||||
from diffusers.utils.import_utils import is_xformers_available
|
from diffusers.utils.import_utils import is_xformers_available
|
||||||
from huggingface_hub import HfFolder, Repository, whoami
|
from huggingface_hub import HfFolder, Repository, whoami
|
||||||
from omegaconf import OmegaConf
|
|
||||||
|
|
||||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||||
from packaging import version
|
from packaging import version
|
||||||
@ -47,6 +46,8 @@ from transformers import CLIPTextModel, CLIPTokenizer
|
|||||||
|
|
||||||
# invokeai stuff
|
# invokeai stuff
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig,PagingArgumentParser
|
from invokeai.app.services.config import InvokeAIAppConfig,PagingArgumentParser
|
||||||
|
from invokeai.app.services.model_manager_service import ModelManagerService
|
||||||
|
from invokeai.backend.model_management.models import SubModelType
|
||||||
|
|
||||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||||
PIL_INTERPOLATION = {
|
PIL_INTERPOLATION = {
|
||||||
@ -132,7 +133,7 @@ def parse_args():
|
|||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--model",
|
"--model",
|
||||||
type=str,
|
type=str,
|
||||||
default="stable-diffusion-1.5",
|
default="sd-1/main/stable-diffusion-v1-5",
|
||||||
help="Name of the diffusers model to train against, as defined in configs/models.yaml.",
|
help="Name of the diffusers model to train against, as defined in configs/models.yaml.",
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
@ -565,7 +566,6 @@ def do_textual_inversion_training(
|
|||||||
checkpointing_steps: int = 500,
|
checkpointing_steps: int = 500,
|
||||||
resume_from_checkpoint: Path = None,
|
resume_from_checkpoint: Path = None,
|
||||||
enable_xformers_memory_efficient_attention: bool = False,
|
enable_xformers_memory_efficient_attention: bool = False,
|
||||||
root_dir: Path = None,
|
|
||||||
hub_model_id: str = None,
|
hub_model_id: str = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@ -584,13 +584,17 @@ def do_textual_inversion_training(
|
|||||||
|
|
||||||
logging_dir = output_dir / logging_dir
|
logging_dir = output_dir / logging_dir
|
||||||
|
|
||||||
|
accelerator_config = ProjectConfiguration()
|
||||||
|
accelerator_config.logging_dir = logging_dir
|
||||||
accelerator = Accelerator(
|
accelerator = Accelerator(
|
||||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||||
mixed_precision=mixed_precision,
|
mixed_precision=mixed_precision,
|
||||||
log_with=report_to,
|
log_with=report_to,
|
||||||
logging_dir=logging_dir,
|
project_config=accelerator_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
model_manager = ModelManagerService(config,logger)
|
||||||
|
|
||||||
# Make one log on every process with the configuration for debugging.
|
# Make one log on every process with the configuration for debugging.
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
@ -628,46 +632,46 @@ def do_textual_inversion_training(
|
|||||||
elif output_dir is not None:
|
elif output_dir is not None:
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
models_conf = OmegaConf.load(config.model_conf_path)
|
known_models = model_manager.model_names()
|
||||||
model_conf = models_conf.get(model, None)
|
model_name = model.split('/')[-1]
|
||||||
assert model_conf is not None, f"Unknown model: {model}"
|
model_meta = next((mm for mm in known_models if mm[0].endswith(model_name)), None)
|
||||||
|
assert model_meta is not None, f"Unknown model: {model}"
|
||||||
|
model_info = model_manager.model_info(*model_meta)
|
||||||
assert (
|
assert (
|
||||||
model_conf.get("format", "diffusers") == "diffusers"
|
model_info['model_format'] == "diffusers"
|
||||||
), "This script only works with models of type 'diffusers'"
|
), "This script only works with models of type 'diffusers'"
|
||||||
pretrained_model_name_or_path = model_conf.get("repo_id", None) or Path(
|
tokenizer_info = model_manager.get_model(*model_meta, submodel=SubModelType.Tokenizer)
|
||||||
model_conf.get("path")
|
noise_scheduler_info = model_manager.get_model(*model_meta, submodel=SubModelType.Scheduler)
|
||||||
)
|
text_encoder_info = model_manager.get_model(*model_meta, submodel=SubModelType.TextEncoder)
|
||||||
assert (
|
vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae)
|
||||||
pretrained_model_name_or_path
|
unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet)
|
||||||
), f"models.yaml error: neither 'repo_id' nor 'path' is defined for {model}"
|
|
||||||
pipeline_args = dict(cache_dir=config.cache_dir)
|
|
||||||
|
|
||||||
# Load tokenizer
|
pipeline_args = dict(local_files_only=True)
|
||||||
if tokenizer_name:
|
if tokenizer_name:
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
|
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
|
||||||
else:
|
else:
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(
|
tokenizer = CLIPTokenizer.from_pretrained(
|
||||||
pretrained_model_name_or_path, subfolder="tokenizer", **pipeline_args
|
tokenizer_info.location, subfolder='tokenizer', **pipeline_args
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load scheduler and models
|
# Load scheduler and models
|
||||||
noise_scheduler = DDPMScheduler.from_pretrained(
|
noise_scheduler = DDPMScheduler.from_pretrained(
|
||||||
pretrained_model_name_or_path, subfolder="scheduler", **pipeline_args
|
noise_scheduler_info.location, subfolder="scheduler", **pipeline_args
|
||||||
)
|
)
|
||||||
text_encoder = CLIPTextModel.from_pretrained(
|
text_encoder = CLIPTextModel.from_pretrained(
|
||||||
pretrained_model_name_or_path,
|
text_encoder_info.location,
|
||||||
subfolder="text_encoder",
|
subfolder="text_encoder",
|
||||||
revision=revision,
|
revision=revision,
|
||||||
**pipeline_args,
|
**pipeline_args,
|
||||||
)
|
)
|
||||||
vae = AutoencoderKL.from_pretrained(
|
vae = AutoencoderKL.from_pretrained(
|
||||||
pretrained_model_name_or_path,
|
vae_info.location,
|
||||||
subfolder="vae",
|
subfolder="vae",
|
||||||
revision=revision,
|
revision=revision,
|
||||||
**pipeline_args,
|
**pipeline_args,
|
||||||
)
|
)
|
||||||
unet = UNet2DConditionModel.from_pretrained(
|
unet = UNet2DConditionModel.from_pretrained(
|
||||||
pretrained_model_name_or_path,
|
unet_info.location,
|
||||||
subfolder="unet",
|
subfolder="unet",
|
||||||
revision=revision,
|
revision=revision,
|
||||||
**pipeline_args,
|
**pipeline_args,
|
||||||
@ -989,7 +993,7 @@ def do_textual_inversion_training(
|
|||||||
save_full_model = not only_save_embeds
|
save_full_model = not only_save_embeds
|
||||||
if save_full_model:
|
if save_full_model:
|
||||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||||
pretrained_model_name_or_path,
|
unet_info.location,
|
||||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||||
vae=vae,
|
vae=vae,
|
||||||
unet=unet,
|
unet=unet,
|
||||||
|
Loading…
Reference in New Issue
Block a user