diff --git a/ldm/generate.py b/ldm/generate.py index 8cb3058694..aa04b24df3 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -11,6 +11,7 @@ import re import sys import time import traceback +from typing import List import cv2 import diffusers @@ -18,19 +19,19 @@ import numpy as np import skimage import torch import transformers -from PIL import Image, ImageOps from diffusers.pipeline_utils import DiffusionPipeline from diffusers.utils.import_utils import is_xformers_available from omegaconf import OmegaConf -from pytorch_lightning import seed_everything, logging +from PIL import Image, ImageOps +from pytorch_lightning import logging, seed_everything import ldm.invoke.conditioning from ldm.invoke.args import metadata_from_png from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary from ldm.invoke.conditioning import get_uc_and_c_and_ec -from ldm.invoke.devices import choose_torch_device, choose_precision +from ldm.invoke.devices import choose_precision, choose_torch_device from ldm.invoke.generator.inpaint import infill_methods -from ldm.invoke.globals import global_cache_dir, Globals +from ldm.invoke.globals import Globals, global_cache_dir from ldm.invoke.image_util import InitImageResizer from ldm.invoke.model_manager import ModelManager from ldm.invoke.pngwriter import PngWriter @@ -42,14 +43,17 @@ from ldm.models.diffusion.plms import PLMSSampler def fix_func(orig): - if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + def new_func(*args, **kw): device = kw.get("device", "mps") - kw["device"]="cpu" + kw["device"] = "cpu" return orig(*args, **kw).to(device) + return new_func return orig + torch.rand = fix_func(torch.rand) torch.rand_like = fix_func(torch.rand_like) torch.randn = fix_func(torch.randn) @@ -60,7 +64,7 @@ torch.bernoulli = fix_func(torch.bernoulli) torch.multinomial = fix_func(torch.multinomial) # this is fallback model in case no default is defined -FALLBACK_MODEL_NAME='stable-diffusion-1.5' +FALLBACK_MODEL_NAME = "stable-diffusion-1.5" """Simplified text to image API for stable diffusion/latent diffusion @@ -129,59 +133,60 @@ gr = Generate( """ + class Generate: """Generate class Stores default values for multiple configuration items """ def __init__( - self, - model = None, - conf = 'configs/models.yaml', - embedding_path = None, - sampler_name = 'k_lms', - ddim_eta = 0.0, # deterministic - full_precision = False, - precision = 'auto', - outdir = 'outputs/img-samples', - gfpgan=None, - codeformer=None, - esrgan=None, - free_gpu_mem: bool=False, - safety_checker:bool=False, - max_loaded_models:int=2, - # these are deprecated; if present they override values in the conf file - weights = None, - config = None, + self, + model=None, + conf="configs/models.yaml", + embedding_path=None, + sampler_name="k_lms", + ddim_eta=0.0, # deterministic + full_precision=False, + precision="auto", + outdir="outputs/img-samples", + gfpgan=None, + codeformer=None, + esrgan=None, + free_gpu_mem: bool = False, + safety_checker: bool = False, + max_loaded_models: int = 2, + # these are deprecated; if present they override values in the conf file + weights=None, + config=None, ): - mconfig = OmegaConf.load(conf) - self.height = None - self.width = None - self.model_manager = None - self.iterations = 1 - self.steps = 50 - self.cfg_scale = 7.5 - self.sampler_name = sampler_name - self.ddim_eta = ddim_eta # same seed always produces same image - self.precision = precision - self.strength = 
0.75 - self.seamless = False - self.seamless_axes = {'x','y'} - self.hires_fix = False + mconfig = OmegaConf.load(conf) + self.height = None + self.width = None + self.model_manager = None + self.iterations = 1 + self.steps = 50 + self.cfg_scale = 7.5 + self.sampler_name = sampler_name + self.ddim_eta = ddim_eta # same seed always produces same image + self.precision = precision + self.strength = 0.75 + self.seamless = False + self.seamless_axes = {"x", "y"} + self.hires_fix = False self.embedding_path = embedding_path - self.model = None # empty for now - self.model_hash = None - self.sampler = None - self.device = None + self.model = None # empty for now + self.model_hash = None + self.sampler = None + self.device = None self.session_peakmem = None self.base_generator = None - self.seed = None + self.seed = None self.outdir = outdir self.gfpgan = gfpgan self.codeformer = codeformer self.esrgan = esrgan self.free_gpu_mem = free_gpu_mem - self.max_loaded_models = max_loaded_models, + self.max_loaded_models = (max_loaded_models,) self.size_matters = True # used to warn once about large image sizes and VRAM self.txt2mask = None self.safety_checker = None @@ -192,62 +197,77 @@ class Generate: # device to Generate(). However the device was then ignored, so # it wasn't actually doing anything. This logic could be reinstated. device_type = choose_torch_device() - print(f'>> Using device_type {device_type}') + print(f">> Using device_type {device_type}") self.device = torch.device(device_type) if full_precision: - if self.precision != 'auto': - raise ValueError('Remove --full_precision / -F if using --precision') - print('Please remove deprecated --full_precision / -F') - print('If auto config does not work you can use --precision=float32') - self.precision = 'float32' - if self.precision == 'auto': + if self.precision != "auto": + raise ValueError("Remove --full_precision / -F if using --precision") + print("Please remove deprecated --full_precision / -F") + print("If auto config does not work you can use --precision=float32") + self.precision = "float32" + if self.precision == "auto": self.precision = choose_precision(self.device) - Globals.full_precision = self.precision=='float32' + Globals.full_precision = self.precision == "float32" if is_xformers_available(): if not Globals.disable_xformers: - print('>> xformers memory-efficient attention is available and enabled') + print(">> xformers memory-efficient attention is available and enabled") else: - print('>> xformers memory-efficient attention is available but disabled') + print( + ">> xformers memory-efficient attention is available but disabled" + ) else: - print('>> xformers not installed') + print(">> xformers not installed") # model caching system for fast switching - self.model_manager = ModelManager(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models) + self.model_manager = ModelManager( + mconfig, self.device, self.precision, max_loaded_models=max_loaded_models + ) # don't accept invalid models fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME model = model or fallback if not self.model_manager.valid_model(model): - print(f'** "{model}" is not a known model name; falling back to {fallback}.') + print( + f'** "{model}" is not a known model name; falling back to {fallback}.' 
+ ) model = None - self.model_name = model or fallback + self.model_name = model or fallback # for VRAM usage statistics - self.session_peakmem = torch.cuda.max_memory_allocated(self.device) if self._has_cuda else None + self.session_peakmem = ( + torch.cuda.max_memory_allocated(self.device) if self._has_cuda else None + ) transformers.logging.set_verbosity_error() # gets rid of annoying messages about random seed - logging.getLogger('pytorch_lightning').setLevel(logging.ERROR) + logging.getLogger("pytorch_lightning").setLevel(logging.ERROR) # load safety checker if requested if safety_checker: try: - print('>> Initializing safety checker') - from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker + print(">> Initializing safety checker") + from diffusers.pipelines.stable_diffusion.safety_checker import ( + StableDiffusionSafetyChecker, + ) from transformers import AutoFeatureExtractor + safety_model_id = "CompVis/stable-diffusion-safety-checker" safety_model_path = global_cache_dir("hub") - self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, - local_files_only=True, - cache_dir=safety_model_path, + self.safety_checker = StableDiffusionSafetyChecker.from_pretrained( + safety_model_id, + local_files_only=True, + cache_dir=safety_model_path, ) - self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, - local_files_only=True, - cache_dir=safety_model_path, + self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained( + safety_model_id, + local_files_only=True, + cache_dir=safety_model_path, ) self.safety_checker.to(self.device) except Exception: - print('** An error was encountered while installing the safety checker:') + print( + "** An error was encountered while installing the safety checker:" + ) print(traceback.format_exc()) def prompt2png(self, prompt, outdir, **kwargs): @@ -256,95 +276,95 @@ class Generate: of PNG files, and returns an array of [[filename,seed],[filename,seed]...] 
Optional named arguments are the same as those passed to Generate and prompt2image() """ - results = self.prompt2image(prompt, **kwargs) + results = self.prompt2image(prompt, **kwargs) pngwriter = PngWriter(outdir) - prefix = pngwriter.unique_prefix() - outputs = [] + prefix = pngwriter.unique_prefix() + outputs = [] for image, seed in results: - name = f'{prefix}.{seed}.png' + name = f"{prefix}.{seed}.png" path = pngwriter.save_image_and_prompt_to_png( - image, dream_prompt=f'{prompt} -S{seed}', name=name) + image, dream_prompt=f"{prompt} -S{seed}", name=name + ) outputs.append([path, seed]) return outputs def txt2img(self, prompt, **kwargs): - outdir = kwargs.pop('outdir', self.outdir) + outdir = kwargs.pop("outdir", self.outdir) return self.prompt2png(prompt, outdir, **kwargs) def img2img(self, prompt, **kwargs): - outdir = kwargs.pop('outdir', self.outdir) + outdir = kwargs.pop("outdir", self.outdir) assert ( - 'init_img' in kwargs - ), 'call to img2img() must include the init_img argument' + "init_img" in kwargs + ), "call to img2img() must include the init_img argument" return self.prompt2png(prompt, outdir, **kwargs) def prompt2image( - self, - # these are common - prompt, - iterations = None, - steps = None, - seed = None, - cfg_scale = None, - ddim_eta = None, - skip_normalize = False, - image_callback = None, - step_callback = None, - width = None, - height = None, - sampler_name = None, - seamless = False, - seamless_axes = {'x','y'}, - log_tokenization = False, - with_variations = None, - variation_amount = 0.0, - threshold = 0.0, - perlin = 0.0, - karras_max = None, - outdir = None, - # these are specific to img2img and inpaint - init_img = None, - init_mask = None, - text_mask = None, - invert_mask = False, - fit = False, - strength = None, - init_color = None, - # these are specific to embiggen (which also relies on img2img args) - embiggen = None, - embiggen_tiles = None, - embiggen_strength = None, - # these are specific to GFPGAN/ESRGAN - gfpgan_strength= 0, - facetool = None, - facetool_strength = 0, - codeformer_fidelity = None, - save_original = False, - upscale = None, - upscale_denoise_str = 0.75, - # this is specific to inpainting and causes more extreme inpainting - inpaint_replace = 0.0, - # This controls the size at which inpaint occurs (scaled up for inpaint, then back down for the result) - inpaint_width = None, - inpaint_height = None, - # This will help match inpainted areas to the original image more smoothly - mask_blur_radius: int = 8, - # Set this True to handle KeyboardInterrupt internally - catch_interrupts = False, - hires_fix = False, - use_mps_noise = False, - # Seam settings for outpainting - seam_size: int = 0, - seam_blur: int = 0, - seam_strength: float = 0.7, - seam_steps: int = 10, - tile_size: int = 32, - infill_method = None, - force_outpaint: bool = False, - enable_image_debugging = False, - - **args, - ): # eat up additional cruft + self, + # these are common + prompt, + iterations=None, + steps=None, + seed=None, + cfg_scale=None, + ddim_eta=None, + skip_normalize=False, + image_callback=None, + step_callback=None, + width=None, + height=None, + sampler_name=None, + seamless=False, + seamless_axes={"x", "y"}, + log_tokenization=False, + with_variations=None, + variation_amount=0.0, + threshold=0.0, + perlin=0.0, + karras_max=None, + outdir=None, + # these are specific to img2img and inpaint + init_img=None, + init_mask=None, + text_mask=None, + invert_mask=False, + fit=False, + strength=None, + init_color=None, + # these are specific to 
embiggen (which also relies on img2img args) + embiggen=None, + embiggen_tiles=None, + embiggen_strength=None, + # these are specific to GFPGAN/ESRGAN + gfpgan_strength=0, + facetool=None, + facetool_strength=0, + codeformer_fidelity=None, + save_original=False, + upscale=None, + upscale_denoise_str=0.75, + # this is specific to inpainting and causes more extreme inpainting + inpaint_replace=0.0, + # This controls the size at which inpaint occurs (scaled up for inpaint, then back down for the result) + inpaint_width=None, + inpaint_height=None, + # This will help match inpainted areas to the original image more smoothly + mask_blur_radius: int = 8, + # Set this True to handle KeyboardInterrupt internally + catch_interrupts=False, + hires_fix=False, + use_mps_noise=False, + # Seam settings for outpainting + seam_size: int = 0, + seam_blur: int = 0, + seam_strength: float = 0.7, + seam_steps: int = 10, + tile_size: int = 32, + infill_method=None, + force_outpaint: bool = False, + enable_image_debugging=False, + **args, + ): # eat up additional cruft self.clear_cuda_stats() """ ldm.generate.prompt2image() is the common entry point for txt2img() and img2img() @@ -401,12 +421,14 @@ class Generate: ddim_eta = ddim_eta or self.ddim_eta iterations = iterations or self.iterations strength = strength or self.strength - outdir = outdir or self.outdir + outdir = outdir or self.outdir self.seed = seed self.log_tokenization = log_tokenization self.step_callback = step_callback self.karras_max = karras_max - self.infill_method = infill_method or infill_methods()[0], # The infill method to use + self.infill_method = ( + infill_method or infill_methods()[0], + ) # The infill method to use with_variations = [] if with_variations is None else with_variations # will instantiate the model or return it from cache @@ -423,33 +445,33 @@ class Generate: else: configure_model_padding(model, seamless, seamless_axes) - assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0' - assert threshold >= 0.0, '--threshold must be >=0.0' + assert cfg_scale > 1.0, "CFG_Scale (-C) must be >1.0" + assert threshold >= 0.0, "--threshold must be >=0.0" assert ( 0.0 < strength <= 1.0 - ), 'img2img and inpaint strength can only work with 0.0 < strength < 1.0' + ), "img2img and inpaint strength can only work with 0.0 < strength < 1.0" assert ( 0.0 <= variation_amount <= 1.0 - ), '-v --variation_amount must be in [0.0, 1.0]' - assert ( - 0.0 <= perlin <= 1.0 - ), '--perlin must be in [0.0, 1.0]' - assert ( - (embiggen == None and embiggen_tiles == None) or ( - (embiggen != None or embiggen_tiles != None) and init_img != None) - ), 'Embiggen requires an init/input image to be specified' + ), "-v --variation_amount must be in [0.0, 1.0]" + assert 0.0 <= perlin <= 1.0, "--perlin must be in [0.0, 1.0]" + assert (embiggen == None and embiggen_tiles == None) or ( + (embiggen != None or embiggen_tiles != None) and init_img != None + ), "Embiggen requires an init/input image to be specified" if len(with_variations) > 0 or variation_amount > 1.0: - assert seed is not None,\ - 'seed must be specified when using with_variations' + assert seed is not None, "seed must be specified when using with_variations" if variation_amount == 0.0: - assert iterations == 1,\ - 'when using --with_variations, multiple iterations are only possible when using --variation_amount' - assert all(0 <= weight <= 1 for _, weight in with_variations),\ - f'variation weights must be in [0.0, 1.0]: got {[weight for _, weight in with_variations]}' + assert ( + iterations == 1 + 
), "when using --with_variations, multiple iterations are only possible when using --variation_amount" + assert all( + 0 <= weight <= 1 for _, weight in with_variations + ), f"variation weights must be in [0.0, 1.0]: got {[weight for _, weight in with_variations]}" width, height, _ = self._resolution_check(width, height, log=True) - assert inpaint_replace >=0.0 and inpaint_replace <= 1.0,'inpaint_replace must be between 0.0 and 1.0' + assert ( + inpaint_replace >= 0.0 and inpaint_replace <= 1.0 + ), "inpaint_replace must be between 0.0 and 1.0" if sampler_name and (sampler_name != self.sampler_name): self.sampler_name = sampler_name @@ -459,12 +481,12 @@ class Generate: prompt = self.huggingface_concepts_library.replace_concepts_with_triggers( prompt, lambda concepts: self.load_huggingface_concepts(concepts), - self.model.textual_inversion_manager.get_all_trigger_strings() + self.model.textual_inversion_manager.get_all_trigger_strings(), ) # bit of a hack to change the cached sampler's karras threshold to # whatever the user asked for - if karras_max is not None and isinstance(self.sampler,KSampler): + if karras_max is not None and isinstance(self.sampler, KSampler): self.sampler.adjust_settings(karras_max=karras_max) tic = time.time() @@ -476,18 +498,24 @@ class Generate: mask_image = None try: - if self.free_gpu_mem and self.model.cond_stage_model.device != self.model.device: + if ( + self.free_gpu_mem + and self.model.cond_stage_model.device != self.model.device + ): self.model.cond_stage_model.device = self.model.device self.model.cond_stage_model.to(self.model.device) except AttributeError: - print(">> Warning: '--free_gpu_mem' is not yet supported when generating image using model based on HuggingFace Diffuser.") + print( + ">> Warning: '--free_gpu_mem' is not yet supported when generating image using model based on HuggingFace Diffuser." + ) pass try: uc, c, extra_conditioning_info = get_uc_and_c_and_ec( - prompt, model =self.model, + prompt, + model=self.model, skip_normalize_legacy_blend=skip_normalize, - log_tokens =self.log_tokenization + log_tokens=self.log_tokenization, ) init_image, mask_image = self._make_images( @@ -502,17 +530,21 @@ class Generate: ) # TODO: Hacky selection of operation to perform. Needs to be refactored. 
- generator = self.select_generator(init_image, mask_image, embiggen, hires_fix, force_outpaint) - - generator.set_variation( - self.seed, variation_amount, with_variations + generator = self.select_generator( + init_image, mask_image, embiggen, hires_fix, force_outpaint ) + + generator.set_variation(self.seed, variation_amount, with_variations) generator.use_mps_noise = use_mps_noise - checker = { - 'checker':self.safety_checker, - 'extractor':self.safety_feature_extractor - } if self.safety_checker else None + checker = ( + { + "checker": self.safety_checker, + "extractor": self.safety_feature_extractor, + } + if self.safety_checker + else None + ) results = generator.generate( prompt, @@ -524,11 +556,11 @@ class Generate: conditioning=(uc, c, extra_conditioning_info), ddim_eta=ddim_eta, image_callback=image_callback, # called after the final image is generated - step_callback=step_callback, # called after each intermediate image is generated + step_callback=step_callback, # called after each intermediate image is generated width=width, height=height, - init_img=init_img, # embiggen needs to manipulate from the unmodified init_img - init_image=init_image, # notice that init_image is different from init_img + init_img=init_img, # embiggen needs to manipulate from the unmodified init_img + init_image=init_image, # notice that init_image is different from init_img mask_image=mask_image, strength=strength, threshold=threshold, @@ -539,41 +571,45 @@ class Generate: inpaint_replace=inpaint_replace, mask_blur_radius=mask_blur_radius, safety_checker=checker, - seam_size = seam_size, - seam_blur = seam_blur, - seam_strength = seam_strength, - seam_steps = seam_steps, - tile_size = tile_size, - infill_method = infill_method, - force_outpaint = force_outpaint, - inpaint_height = inpaint_height, - inpaint_width = inpaint_width, - enable_image_debugging = enable_image_debugging, + seam_size=seam_size, + seam_blur=seam_blur, + seam_strength=seam_strength, + seam_steps=seam_steps, + tile_size=tile_size, + infill_method=infill_method, + force_outpaint=force_outpaint, + inpaint_height=inpaint_height, + inpaint_width=inpaint_width, + enable_image_debugging=enable_image_debugging, free_gpu_mem=self.free_gpu_mem, - clear_cuda_cache=self.clear_cuda_cache + clear_cuda_cache=self.clear_cuda_cache, ) if init_color: - self.correct_colors(image_list = results, - reference_image_path = init_color, - image_callback = image_callback) + self.correct_colors( + image_list=results, + reference_image_path=init_color, + image_callback=image_callback, + ) if upscale is not None or facetool_strength > 0: - self.upscale_and_reconstruct(results, - upscale = upscale, - upscale_denoise_str = upscale_denoise_str, - facetool = facetool, - strength = facetool_strength, - codeformer_fidelity = codeformer_fidelity, - save_original = save_original, - image_callback = image_callback) + self.upscale_and_reconstruct( + results, + upscale=upscale, + upscale_denoise_str=upscale_denoise_str, + facetool=facetool, + strength=facetool_strength, + codeformer_fidelity=codeformer_fidelity, + save_original=save_original, + image_callback=image_callback, + ) except KeyboardInterrupt: # Clear the CUDA cache on an exception self.clear_cuda_cache() if catch_interrupts: - print('**Interrupted** Partial results will be returned.') + print("**Interrupted** Partial results will be returned.") else: raise KeyboardInterrupt except RuntimeError: @@ -581,30 +617,24 @@ class Generate: self.clear_cuda_cache() print(traceback.format_exc(), file=sys.stderr) - 
print('>> Could not generate image.') + print(">> Could not generate image.") toc = time.time() - print('\n>> Usage stats:') - print( - f'>> {len(results)} image(s) generated in', '%4.2fs' % ( - toc - tic) - ) + print("\n>> Usage stats:") + print(f">> {len(results)} image(s) generated in", "%4.2fs" % (toc - tic)) self.print_cuda_stats() return results def gather_cuda_stats(self): if self._has_cuda(): self.max_memory_allocated = max( - self.max_memory_allocated, - torch.cuda.max_memory_allocated(self.device) + self.max_memory_allocated, torch.cuda.max_memory_allocated(self.device) ) self.memory_allocated = max( - self.memory_allocated, - torch.cuda.memory_allocated(self.device) + self.memory_allocated, torch.cuda.memory_allocated(self.device) ) self.session_peakmem = max( - self.session_peakmem, - torch.cuda.max_memory_allocated(self.device) + self.session_peakmem, torch.cuda.max_memory_allocated(self.device) ) def clear_cuda_cache(self): @@ -620,35 +650,35 @@ class Generate: if self._has_cuda(): self.gather_cuda_stats() print( - '>> Max VRAM used for this generation:', - '%4.2fG.' % (self.max_memory_allocated / 1e9), - 'Current VRAM utilization:', - '%4.2fG' % (self.memory_allocated / 1e9), + ">> Max VRAM used for this generation:", + "%4.2fG." % (self.max_memory_allocated / 1e9), + "Current VRAM utilization:", + "%4.2fG" % (self.memory_allocated / 1e9), ) print( - '>> Max VRAM used since script start: ', - '%4.2fG' % (self.session_peakmem / 1e9), + ">> Max VRAM used since script start: ", + "%4.2fG" % (self.session_peakmem / 1e9), ) # this needs to be generalized to all sorts of postprocessors, which should be wrapped # in a nice harmonized call signature. For now we have a bunch of if/elses! def apply_postprocessor( - self, - image_path, - tool = 'gfpgan', # one of 'upscale', 'gfpgan', 'codeformer', 'outpaint', or 'embiggen' - facetool_strength = 0.0, - codeformer_fidelity = 0.75, - upscale = None, - upscale_denoise_str = 0.75, - out_direction = None, - outcrop = [], - save_original = True, # to get new name - callback = None, - opt = None, - ): + self, + image_path, + tool="gfpgan", # one of 'upscale', 'gfpgan', 'codeformer', 'outpaint', or 'embiggen' + facetool_strength=0.0, + codeformer_fidelity=0.75, + upscale=None, + upscale_denoise_str=0.75, + out_direction=None, + outcrop=[], + save_original=True, # to get new name + callback=None, + opt=None, + ): # retrieve the seed from the image; - seed = None + seed = None prompt = None args = metadata_from_png(image_path) @@ -656,13 +686,13 @@ class Generate: if seed is None or seed < 0: seed = random.randrange(0, np.iinfo(np.uint32).max) - prompt = opt.prompt or args.prompt or '' + prompt = opt.prompt or args.prompt or "" print(f'>> using seed {seed} and prompt "{prompt}" for {image_path}') # try to reuse the same filename prefix as the original file. 
# we take everything up to the first period prefix = None - m = re.match(r'^([^.]+)\.',os.path.basename(image_path)) + m = re.match(r"^([^.]+)\.", os.path.basename(image_path)) if m: prefix = m.groups()[0] @@ -672,99 +702,106 @@ class Generate: # used by multiple postfixers # todo: cross-attention control uc, c, extra_conditioning_info = get_uc_and_c_and_ec( - prompt, model=self.model, + prompt, + model=self.model, skip_normalize_legacy_blend=opt.skip_normalize, - log_tokens=ldm.invoke.conditioning.log_tokenization + log_tokens=ldm.invoke.conditioning.log_tokenization, ) - if tool in ('gfpgan','codeformer','upscale'): - if tool == 'gfpgan': - facetool = 'gfpgan' - elif tool == 'codeformer': - facetool = 'codeformer' - elif tool == 'upscale': - facetool = 'gfpgan' # but won't be run + if tool in ("gfpgan", "codeformer", "upscale"): + if tool == "gfpgan": + facetool = "gfpgan" + elif tool == "codeformer": + facetool = "codeformer" + elif tool == "upscale": + facetool = "gfpgan" # but won't be run facetool_strength = 0 return self.upscale_and_reconstruct( - [[image,seed]], - facetool = facetool, - strength = facetool_strength, - codeformer_fidelity = codeformer_fidelity, - save_original = save_original, - upscale = upscale, - upscale_denoise_str = upscale_denoise_str, - image_callback = callback, - prefix = prefix, + [[image, seed]], + facetool=facetool, + strength=facetool_strength, + codeformer_fidelity=codeformer_fidelity, + save_original=save_original, + upscale=upscale, + upscale_denoise_str=upscale_denoise_str, + image_callback=callback, + prefix=prefix, ) - elif tool == 'outcrop': + elif tool == "outcrop": from ldm.invoke.restoration.outcrop import Outcrop + extend_instructions = {} - for direction,pixels in _pairwise(opt.outcrop): + for direction, pixels in _pairwise(opt.outcrop): try: - extend_instructions[direction]=int(pixels) + extend_instructions[direction] = int(pixels) except ValueError: - print('** invalid extension instruction. Use ..., as in "top 64 left 128 right 64 bottom 64"') + print( + '** invalid extension instruction. Use ..., as in "top 64 left 128 right 64 bottom 64"' + ) opt.seed = seed opt.prompt = prompt if len(extend_instructions) > 0: - restorer = Outcrop(image,self,) - return restorer.process ( + restorer = Outcrop( + image, + self, + ) + return restorer.process( extend_instructions, - opt = opt, - orig_opt = args, - image_callback = callback, - prefix = prefix, + opt=opt, + orig_opt=args, + image_callback=callback, + prefix=prefix, ) - elif tool == 'embiggen': + elif tool == "embiggen": # fetch the metadata from the image generator = self.select_generator(embiggen=True) opt.strength = opt.embiggen_strength or 0.40 - print(f'>> Setting img2img strength to {opt.strength} for happy embiggening') + print( + f">> Setting img2img strength to {opt.strength} for happy embiggening" + ) generator.generate( prompt, - sampler = self.sampler, - steps = opt.steps, - cfg_scale = opt.cfg_scale, - ddim_eta = self.ddim_eta, - conditioning= (uc, c, extra_conditioning_info), - init_img = image_path, # not the Image! (sigh) - init_image = image, # embiggen wants both! (sigh) - strength = opt.strength, - width = opt.width, - height = opt.height, - embiggen = opt.embiggen, - embiggen_tiles = opt.embiggen_tiles, - embiggen_strength = opt.embiggen_strength, - image_callback = callback, + sampler=self.sampler, + steps=opt.steps, + cfg_scale=opt.cfg_scale, + ddim_eta=self.ddim_eta, + conditioning=(uc, c, extra_conditioning_info), + init_img=image_path, # not the Image! 
(sigh) + init_image=image, # embiggen wants both! (sigh) + strength=opt.strength, + width=opt.width, + height=opt.height, + embiggen=opt.embiggen, + embiggen_tiles=opt.embiggen_tiles, + embiggen_strength=opt.embiggen_strength, + image_callback=callback, ) - elif tool == 'outpaint': + elif tool == "outpaint": from ldm.invoke.restoration.outpaint import Outpaint - restorer = Outpaint(image,self) - return restorer.process( - opt, - args, - image_callback = callback, - prefix = prefix - ) + + restorer = Outpaint(image, self) + return restorer.process(opt, args, image_callback=callback, prefix=prefix) elif tool is None: - print('* please provide at least one postprocessing option, such as -G or -U') + print( + "* please provide at least one postprocessing option, such as -G or -U" + ) return None else: - print(f'* postprocessing tool {tool} is not yet supported') + print(f"* postprocessing tool {tool} is not yet supported") return None def select_generator( - self, - init_image:Image.Image=None, - mask_image:Image.Image=None, - embiggen:bool=False, - hires_fix:bool=False, - force_outpaint:bool=False, + self, + init_image: Image.Image = None, + mask_image: Image.Image = None, + embiggen: bool = False, + hires_fix: bool = False, + force_outpaint: bool = False, ): inpainting_model_in_use = self.sampler.uses_inpainting_model() @@ -786,40 +823,46 @@ class Generate: return self._make_txt2img() def _make_images( - self, - img, - mask, - width, - height, - fit=False, - text_mask=None, - invert_mask=False, - force_outpaint=False, + self, + img, + mask, + width, + height, + fit=False, + text_mask=None, + invert_mask=False, + force_outpaint=False, ): - init_image = None - init_mask = None + init_image = None + init_mask = None if not img: return None, None image = self._load_img(img) if image.width < self.width and image.height < self.height: - print(f'>> WARNING: img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions') + print( + f">> WARNING: img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions" + ) # if image has a transparent area and no mask was provided, then try to generate mask if self._has_transparency(image): self._transparency_check_and_warning(image, mask, force_outpaint) init_mask = self._create_init_mask(image, width, height, fit=fit) - if (image.width * image.height) > (self.width * self.height) and self.size_matters: - print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.") + if (image.width * image.height) > ( + self.width * self.height + ) and self.size_matters: + print( + ">> This input is larger than your defaults. If you run out of memory, please use a smaller image." 
+ ) self.size_matters = False - init_image = self._create_init_image(image,width,height,fit=fit) + init_image = self._create_init_image(image, width, height, fit=fit) if mask: mask_image = self._load_img(mask) - init_mask = self._create_init_mask(mask_image,width,height,fit=fit) + init_mask = self._create_init_mask(mask_image, width, height, fit=fit) elif text_mask: init_mask = self._txt2mask(image, text_mask, width, height, fit=fit) @@ -827,47 +870,47 @@ class Generate: if init_mask and invert_mask: init_mask = ImageOps.invert(init_mask) - return init_image,init_mask + return init_image, init_mask def _make_base(self): - return self._load_generator('','Generator') + return self._load_generator("", "Generator") def _make_txt2img(self): - return self._load_generator('.txt2img','Txt2Img') + return self._load_generator(".txt2img", "Txt2Img") def _make_img2img(self): - return self._load_generator('.img2img','Img2Img') + return self._load_generator(".img2img", "Img2Img") def _make_embiggen(self): - return self._load_generator('.embiggen','Embiggen') + return self._load_generator(".embiggen", "Embiggen") def _make_txt2img2img(self): - return self._load_generator('.txt2img2img','Txt2Img2Img') + return self._load_generator(".txt2img2img", "Txt2Img2Img") def _make_inpaint(self): - return self._load_generator('.inpaint','Inpaint') + return self._load_generator(".inpaint", "Inpaint") def _make_omnibus(self): - return self._load_generator('.omnibus','Omnibus') + return self._load_generator(".omnibus", "Omnibus") def _load_generator(self, module, class_name): if self.is_legacy_model(self.model_name): - mn = f'ldm.invoke.ckpt_generator{module}' - cn = f'Ckpt{class_name}' + mn = f"ldm.invoke.ckpt_generator{module}" + cn = f"Ckpt{class_name}" else: - mn = f'ldm.invoke.generator{module}' + mn = f"ldm.invoke.generator{module}" cn = class_name module = importlib.import_module(mn) - constructor = getattr(module,cn) + constructor = getattr(module, cn) return constructor(self.model, self.precision) def load_model(self): - ''' + """ preload model identified in self.model_name - ''' + """ return self.set_model(self.model_name) - def set_model(self,model_name): + def set_model(self, model_name): """ Given the name of a model defined in models.yaml, will load and initialize it and return the model object. Previously-used models will be cached. @@ -884,7 +927,9 @@ class Generate: # the model cache does the loading and offloading cache = self.model_manager if not cache.valid_model(model_name): - raise KeyError(f'** "{model_name}" is not a known model name. Cannot change.') + raise KeyError( + f'** "{model_name}" is not a known model name. Cannot change.' 
+ ) cache.print_vram_usage() @@ -897,20 +942,20 @@ class Generate: try: model_data = cache.get_model(model_name) except Exception as e: - print(f'** model {model_name} could not be loaded: {str(e)}') + print(f"** model {model_name} could not be loaded: {str(e)}") print(traceback.format_exc(), file=sys.stderr) if previous_model_name is None: raise e - print(f'** trying to reload previous model') - model_data = cache.get_model(previous_model_name) # load previous + print("** trying to reload previous model") + model_data = cache.get_model(previous_model_name) # load previous if model_data is None: raise e model_name = previous_model_name - self.model = model_data['model'] - self.width = model_data['width'] - self.height= model_data['height'] - self.model_hash = model_data['hash'] + self.model = model_data["model"] + self.width = model_data["width"] + self.height = model_data["height"] + self.model_hash = model_data["hash"] # uncache generators so they pick up new models self.generators = {} @@ -920,35 +965,37 @@ class Generate: for root, _, files in os.walk(self.embedding_path): for name in files: ti_path = os.path.join(root, name) - self.model.textual_inversion_manager.load_textual_inversion(ti_path, - defer_injecting_tokens=True) - print(f'>> Textual inversions available: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}') + self.model.textual_inversion_manager.load_textual_inversion( + ti_path, defer_injecting_tokens=True + ) + print( + f'>> Textual inversions available: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}' + ) self.model_name = model_name self._set_sampler() # requires self.model_name to be set first return self.model - def load_huggingface_concepts(self, concepts:list[str]): + def load_huggingface_concepts(self, concepts: list[str]): self.model.textual_inversion_manager.load_huggingface_concepts(concepts) @property def huggingface_concepts_library(self) -> HuggingFaceConceptsLibrary: return self.model.textual_inversion_manager.hf_concepts_library - def correct_colors(self, - image_list, - reference_image_path, - image_callback = None): + @property + def embedding_trigger_strings(self) -> List[str]: + return self.model.textual_inversion_manager.get_all_trigger_strings() + + def correct_colors(self, image_list, reference_image_path, image_callback=None): reference_image = Image.open(reference_image_path) - correction_target = cv2.cvtColor(np.asarray(reference_image), - cv2.COLOR_RGB2LAB) + correction_target = cv2.cvtColor(np.asarray(reference_image), cv2.COLOR_RGB2LAB) for r in image_list: image, seed = r - image = cv2.cvtColor(np.asarray(image), - cv2.COLOR_RGB2LAB) - image = skimage.exposure.match_histograms(image, - correction_target, - channel_axis=2) + image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2LAB) + image = skimage.exposure.match_histograms( + image, correction_target, channel_axis=2 + ) image = Image.fromarray( cv2.cvtColor(image, cv2.COLOR_LAB2RGB).astype("uint8") ) @@ -957,34 +1004,46 @@ class Generate: else: r[0] = image - def upscale_and_reconstruct(self, - image_list, - facetool = 'gfpgan', - upscale = None, - upscale_denoise_str = 0.75, - strength = 0.0, - codeformer_fidelity = 0.75, - save_original = False, - image_callback = None, - prefix = None, + def upscale_and_reconstruct( + self, + image_list, + facetool="gfpgan", + upscale=None, + upscale_denoise_str=0.75, + strength=0.0, + codeformer_fidelity=0.75, + save_original=False, + image_callback=None, + prefix=None, ): - for r in image_list: 
image, seed = r try: if strength > 0: if self.gfpgan is not None or self.codeformer is not None: - if facetool == 'gfpgan': + if facetool == "gfpgan": if self.gfpgan is None: - print('>> GFPGAN not found. Face restoration is disabled.') + print( + ">> GFPGAN not found. Face restoration is disabled." + ) else: - image = self.gfpgan.process(image, strength, seed) - if facetool == 'codeformer': + image = self.gfpgan.process(image, strength, seed) + if facetool == "codeformer": if self.codeformer is None: - print('>> CodeFormer not found. Face restoration is disabled.') + print( + ">> CodeFormer not found. Face restoration is disabled." + ) else: - cf_device = 'cpu' if str(self.device) == 'mps' else self.device - image = self.codeformer.process(image=image, strength=strength, device=cf_device, seed=seed, fidelity=codeformer_fidelity) + cf_device = ( + "cpu" if str(self.device) == "mps" else self.device + ) + image = self.codeformer.process( + image=image, + strength=strength, + device=cf_device, + seed=seed, + fidelity=codeformer_fidelity, + ) else: print(">> Face Restoration is disabled.") if upscale is not None: @@ -992,12 +1051,17 @@ class Generate: if len(upscale) < 2: upscale.append(0.75) image = self.esrgan.process( - image, upscale[1], seed, int(upscale[0]), denoise_str=upscale_denoise_str) + image, + upscale[1], + seed, + int(upscale[0]), + denoise_str=upscale_denoise_str, + ) else: print(">> ESRGAN is disabled. Image not upscaled.") except Exception as e: print( - f'>> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}' + f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}" ) if image_callback is not None: @@ -1005,22 +1069,26 @@ class Generate: else: r[0] = image - def apply_textmask(self, image_path:str, prompt:str, callback, threshold:float=0.5): - assert os.path.exists(image_path), f'** "{image_path}" not found. Please enter the name of an existing image file to mask **' - basename,_ = os.path.splitext(os.path.basename(image_path)) + def apply_textmask( + self, image_path: str, prompt: str, callback, threshold: float = 0.5 + ): + assert os.path.exists( + image_path + ), f'** "{image_path}" not found. 
Please enter the name of an existing image file to mask **' + basename, _ = os.path.splitext(os.path.basename(image_path)) if self.txt2mask is None: - self.txt2mask = Txt2Mask(device = self.device, refined=True) - segmented = self.txt2mask.segment(image_path,prompt) + self.txt2mask = Txt2Mask(device=self.device, refined=True) + segmented = self.txt2mask.segment(image_path, prompt) trans = segmented.to_transparent() inverse = segmented.to_transparent(invert=True) mask = segmented.to_mask(threshold) path_filter = re.compile(r'[<>:"/\\|?*]') - safe_prompt = path_filter.sub('_', prompt)[:50].rstrip(' .') + safe_prompt = path_filter.sub("_", prompt)[:50].rstrip(" .") - callback(trans,f'{safe_prompt}.deselected',use_prefix=basename) - callback(inverse,f'{safe_prompt}.selected',use_prefix=basename) - callback(mask,f'{safe_prompt}.masked',use_prefix=basename) + callback(trans, f"{safe_prompt}.deselected", use_prefix=basename) + callback(inverse, f"{safe_prompt}.selected", use_prefix=basename) + callback(mask, f"{safe_prompt}.masked", use_prefix=basename) # to help WebGUI - front end to generator util function def sample_to_image(self, samples): @@ -1029,7 +1097,7 @@ class Generate: def sample_to_lowres_estimated_image(self, samples): return self._make_base().sample_to_lowres_estimated_image(samples) - def is_legacy_model(self,model_name)->bool: + def is_legacy_model(self, model_name) -> bool: return self.model_manager.is_legacy(model_name) def _set_sampler(self): @@ -1041,29 +1109,31 @@ class Generate: # very repetitive code - can this be simplified? The KSampler names are # consistent, at least def _set_sampler_legacy(self): - msg = f'>> Setting Sampler to {self.sampler_name}' - if self.sampler_name == 'plms': + msg = f">> Setting Sampler to {self.sampler_name}" + if self.sampler_name == "plms": self.sampler = PLMSSampler(self.model, device=self.device) - elif self.sampler_name == 'ddim': + elif self.sampler_name == "ddim": self.sampler = DDIMSampler(self.model, device=self.device) - elif self.sampler_name == 'k_dpm_2_a': - self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device) - elif self.sampler_name == 'k_dpm_2': - self.sampler = KSampler(self.model, 'dpm_2', device=self.device) - elif self.sampler_name == 'k_dpmpp_2_a': - self.sampler = KSampler(self.model, 'dpmpp_2s_ancestral', device=self.device) - elif self.sampler_name == 'k_dpmpp_2': - self.sampler = KSampler(self.model, 'dpmpp_2m', device=self.device) - elif self.sampler_name == 'k_euler_a': - self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device) - elif self.sampler_name == 'k_euler': - self.sampler = KSampler(self.model, 'euler', device=self.device) - elif self.sampler_name == 'k_heun': - self.sampler = KSampler(self.model, 'heun', device=self.device) - elif self.sampler_name == 'k_lms': - self.sampler = KSampler(self.model, 'lms', device=self.device) + elif self.sampler_name == "k_dpm_2_a": + self.sampler = KSampler(self.model, "dpm_2_ancestral", device=self.device) + elif self.sampler_name == "k_dpm_2": + self.sampler = KSampler(self.model, "dpm_2", device=self.device) + elif self.sampler_name == "k_dpmpp_2_a": + self.sampler = KSampler( + self.model, "dpmpp_2s_ancestral", device=self.device + ) + elif self.sampler_name == "k_dpmpp_2": + self.sampler = KSampler(self.model, "dpmpp_2m", device=self.device) + elif self.sampler_name == "k_euler_a": + self.sampler = KSampler(self.model, "euler_ancestral", device=self.device) + elif self.sampler_name == "k_euler": + self.sampler = 
KSampler(self.model, "euler", device=self.device) + elif self.sampler_name == "k_heun": + self.sampler = KSampler(self.model, "heun", device=self.device) + elif self.sampler_name == "k_lms": + self.sampler = KSampler(self.model, "lms", device=self.device) else: - msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to plms' + msg = f">> Unsupported Sampler: {self.sampler_name}, Defaulting to plms" self.sampler = PLMSSampler(self.model, device=self.device) print(msg) @@ -1090,51 +1160,59 @@ class Generate: if self.sampler_name in scheduler_map: sampler_class = scheduler_map[self.sampler_name] - msg = f'>> Setting Sampler to {self.sampler_name} ({sampler_class.__name__})' + msg = ( + f">> Setting Sampler to {self.sampler_name} ({sampler_class.__name__})" + ) self.sampler = sampler_class.from_config(self.model.scheduler.config) else: - msg = (f'>> Unsupported Sampler: {self.sampler_name} ' - f'Defaulting to {default}') + msg = ( + f">> Unsupported Sampler: {self.sampler_name} " + f"Defaulting to {default}" + ) self.sampler = default print(msg) - if not hasattr(self.sampler, 'uses_inpainting_model'): + if not hasattr(self.sampler, "uses_inpainting_model"): # FIXME: terrible kludge! self.sampler.uses_inpainting_model = lambda: False - def _load_img(self, img)->Image: + def _load_img(self, img) -> Image: if isinstance(img, Image.Image): image = img - print( - f'>> using provided input image of size {image.width}x{image.height}' - ) + print(f">> using provided input image of size {image.width}x{image.height}") elif isinstance(img, str): - assert os.path.exists(img), f'>> {img}: File not found' + assert os.path.exists(img), f">> {img}: File not found" image = Image.open(img) print( - f'>> loaded input image of size {image.width}x{image.height} from {img}' + f">> loaded input image of size {image.width}x{image.height} from {img}" ) else: image = Image.open(img) - print( - f'>> loaded input image of size {image.width}x{image.height}' - ) + print(f">> loaded input image of size {image.width}x{image.height}") image = ImageOps.exif_transpose(image) return image def _create_init_image(self, image: Image.Image, width, height, fit=True): - if image.mode != 'RGBA': - image = image.convert('RGBA') - image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image) + if image.mode != "RGBA": + image = image.convert("RGBA") + image = ( + self._fit_image(image, (width, height)) + if fit + else self._squeeze_image(image) + ) return image def _create_init_mask(self, image, width, height, fit=True): # convert into a black/white mask image = self._image_to_mask(image) - image = image.convert('RGB') - image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image) + image = image.convert("RGB") + image = ( + self._fit_image(image, (width, height)) + if fit + else self._squeeze_image(image) + ) return image # The mask is expected to have the region to be inpainted @@ -1142,10 +1220,10 @@ class Generate: # image with the transparent part black. 
def _image_to_mask(self, mask_image: Image.Image, invert=False) -> Image: # Obtain the mask from the transparency channel - if mask_image.mode == 'L': + if mask_image.mode == "L": mask = mask_image - elif mask_image.mode in ('RGB', 'P'): - mask = mask_image.convert('L') + elif mask_image.mode in ("RGB", "P"): + mask = mask_image.convert("L") else: # Obtain the mask from the transparency channel mask = Image.new(mode="L", size=mask_image.size, color=255) @@ -1154,16 +1232,20 @@ class Generate: mask = ImageOps.invert(mask) return mask - def _txt2mask(self, image:Image, text_mask:list, width, height, fit=True) -> Image: + def _txt2mask( + self, image: Image, text_mask: list, width, height, fit=True + ) -> Image: prompt = text_mask[0] - confidence_level = text_mask[1] if len(text_mask)>1 else 0.5 + confidence_level = text_mask[1] if len(text_mask) > 1 else 0.5 if self.txt2mask is None: - self.txt2mask = Txt2Mask(device = self.device) + self.txt2mask = Txt2Mask(device=self.device) segmented = self.txt2mask.segment(image, prompt) mask = segmented.to_mask(float(confidence_level)) - mask = mask.convert('RGB') - mask = self._fit_image(mask, (width, height)) if fit else self._squeeze_image(mask) + mask = mask.convert("RGB") + mask = ( + self._fit_image(mask, (width, height)) if fit else self._squeeze_image(mask) + ) return mask def _has_transparency(self, image): @@ -1180,8 +1262,8 @@ class Generate: return True return False - def _check_for_erasure(self, image:Image.Image)->bool: - if image.mode not in ('RGBA','RGB'): + def _check_for_erasure(self, image: Image.Image) -> bool: + if image.mode not in ("RGBA", "RGB"): return False width, height = image.size pixdata = image.load() @@ -1190,20 +1272,20 @@ class Generate: for x in range(width): if pixdata[x, y][3] == 0: r, g, b, _ = pixdata[x, y] - if (r, g, b) != (0, 0, 0) and \ - (r, g, b) != (255, 255, 255): + if (r, g, b) != (0, 0, 0) and (r, g, b) != (255, 255, 255): colored += 1 return colored == 0 - def _transparency_check_and_warning(self,image, mask, force_outpaint=False): + def _transparency_check_and_warning(self, image, mask, force_outpaint=False): if not mask: print( - '>> Initial image has transparent areas. Will inpaint in these regions.') + ">> Initial image has transparent areas. Will inpaint in these regions." + ) if (not force_outpaint) and self._check_for_erasure(image): print( - '>> WARNING: Colors underneath the transparent region seem to have been erased.\n', - '>> Inpainting will be suboptimal. Please preserve the colors when making\n', - '>> a transparency mask, or provide mask explicitly using --init_mask (-M).' + ">> WARNING: Colors underneath the transparent region seem to have been erased.\n", + ">> Inpainting will be suboptimal. Please preserve the colors when making\n", + ">> a transparency mask, or provide mask explicitly using --init_mask (-M).", ) def _squeeze_image(self, image): @@ -1214,13 +1296,11 @@ class Generate: def _fit_image(self, image, max_dimensions): w, h = max_dimensions - print( - f'>> image will be resized to fit inside a box {w}x{h} in size.' 
- ) + print(f">> image will be resized to fit inside a box {w}x{h} in size.") # note that InitImageResizer does the multiple of 64 truncation internally image = InitImageResizer(image).resize(width=w, height=h) print( - f'>> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}' + f">> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}" ) return image @@ -1232,30 +1312,32 @@ class Generate: if h != height or w != width: if log: print( - f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}' + f">> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}" ) height = h width = w resize_needed = True return width, height, resize_needed - def _has_cuda(self): - return self.device.type == 'cuda' + return self.device.type == "cuda" - def write_intermediate_images(self,modulus,path): + def write_intermediate_images(self, modulus, path): counter = -1 if not os.path.exists(path): os.makedirs(path) + def callback(img): nonlocal counter counter += 1 if counter % modulus != 0: - return; + return image = self.sample_to_image(img) - image.save(os.path.join(path,f'{counter:03}.png'),'PNG') + image.save(os.path.join(path, f"{counter:03}.png"), "PNG") + return callback + def _pairwise(iterable): "s -> (s0, s1), (s2, s3), (s4, s5), ..." a = iter(iterable) diff --git a/ldm/invoke/CLI.py b/ldm/invoke/CLI.py index 32c6d816be..d56984caf3 100644 --- a/ldm/invoke/CLI.py +++ b/ldm/invoke/CLI.py @@ -1,9 +1,8 @@ import os import re -import sys import shlex +import sys import traceback - from argparse import Namespace from pathlib import Path from typing import Optional, Union @@ -11,41 +10,47 @@ from typing import Optional, Union if sys.platform == "darwin": os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" -from ldm.invoke.globals import Globals +import click # type: ignore +import pyparsing # type: ignore + +import ldm.invoke from ldm.generate import Generate -from ldm.invoke.prompt_parser import PromptParser -from ldm.invoke.readline import get_completer, Completer -from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_from_png -from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata +from ldm.invoke.args import (Args, dream_cmd_from_png, metadata_dumps, + metadata_from_png) +from ldm.invoke.globals import Globals from ldm.invoke.image_util import make_grid from ldm.invoke.log import write_log from ldm.invoke.model_manager import ModelManager - -import click # type: ignore -import ldm.invoke -import pyparsing # type: ignore +from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata +from ldm.invoke.prompt_parser import PromptParser +from ldm.invoke.readline import Completer, get_completer # global used in multiple functions (fix) infile = None + def main(): """Initialize command-line parsers and the diffusion model""" global infile - opt = Args() + opt = Args() args = opt.parse_args() if not args: sys.exit(-1) if args.laion400m: - print('--laion400m flag has been deprecated. Please use --model laion400m instead.') + print( + "--laion400m flag has been deprecated. Please use --model laion400m instead." + ) sys.exit(-1) if args.weights: - print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.') + print( + "--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead." 
+ ) sys.exit(-1) if args.max_loaded_models is not None: if args.max_loaded_models <= 0: - print('--max_loaded_models must be >= 1; using 1') + print("--max_loaded_models must be >= 1; using 1") args.max_loaded_models = 1 # alert - setting a few globals here @@ -55,36 +60,42 @@ def main(): Globals.disable_xformers = not args.xformers Globals.ckpt_convert = args.ckpt_convert - print(f'>> Internet connectivity is {Globals.internet_available}') + print(f">> Internet connectivity is {Globals.internet_available}") if not args.conf: - config_file = os.path.join(Globals.root,'configs','models.yaml') + config_file = os.path.join(Globals.root, "configs", "models.yaml") if not os.path.exists(config_file): - report_model_error(opt, FileNotFoundError(f"The file {config_file} could not be found.")) + report_model_error( + opt, FileNotFoundError(f"The file {config_file} could not be found.") + ) - print(f'>> {ldm.invoke.__app_name__}, version {ldm.invoke.__version__}') + print(f">> {ldm.invoke.__app_name__}, version {ldm.invoke.__version__}") print(f'>> InvokeAI runtime directory is "{Globals.root}"') # loading here to avoid long delays on startup - from ldm.generate import Generate - # these two lines prevent a horrible warning message from appearing # when the frozen CLIP tokenizer is imported import transformers # type: ignore + + from ldm.generate import Generate + transformers.logging.set_verbosity_error() import diffusers + diffusers.logging.set_verbosity_error() # Loading Face Restoration and ESRGAN Modules - gfpgan,codeformer,esrgan = load_face_restoration(opt) + gfpgan, codeformer, esrgan = load_face_restoration(opt) # normalize the config directory relative to root if not os.path.isabs(opt.conf): - opt.conf = os.path.normpath(os.path.join(Globals.root,opt.conf)) + opt.conf = os.path.normpath(os.path.join(Globals.root, opt.conf)) if opt.embeddings: if not os.path.isabs(opt.embedding_path): - embedding_path = os.path.normpath(os.path.join(Globals.root,opt.embedding_path)) + embedding_path = os.path.normpath( + os.path.join(Globals.root, opt.embedding_path) + ) else: embedding_path = opt.embedding_path else: @@ -97,35 +108,35 @@ def main(): if opt.infile: try: if os.path.isfile(opt.infile): - infile = open(opt.infile, 'r', encoding='utf-8') - elif opt.infile == '-': # stdin + infile = open(opt.infile, "r", encoding="utf-8") + elif opt.infile == "-": # stdin infile = sys.stdin else: - raise FileNotFoundError(f'{opt.infile} not found.') + raise FileNotFoundError(f"{opt.infile} not found.") except (FileNotFoundError, IOError) as e: - print(f'{e}. Aborting.') + print(f"{e}. Aborting.") sys.exit(-1) # creating a Generate object: try: gen = Generate( - conf = opt.conf, - model = opt.model, - sampler_name = opt.sampler_name, - embedding_path = embedding_path, - full_precision = opt.full_precision, - precision = opt.precision, + conf=opt.conf, + model=opt.model, + sampler_name=opt.sampler_name, + embedding_path=embedding_path, + full_precision=opt.full_precision, + precision=opt.precision, gfpgan=gfpgan, codeformer=codeformer, esrgan=esrgan, free_gpu_mem=opt.free_gpu_mem, safety_checker=opt.safety_checker, max_loaded_models=opt.max_loaded_models, - ) + ) except (FileNotFoundError, TypeError, AssertionError) as e: - report_model_error(opt,e) + report_model_error(opt, e) except (IOError, KeyError) as e: - print(f'{e}. Aborting.') + print(f"{e}. 
Aborting.") sys.exit(-1) if opt.seamless: @@ -160,11 +171,14 @@ def main(): try: main_loop(gen, opt) except KeyboardInterrupt: - print(f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}') + print( + f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}' + ) except Exception: print(">> An error occurred:") traceback.print_exc() + # TODO: main_loop() has gotten busy. Needs to be refactored. def main_loop(gen, opt): """prompt/read/execute loop""" @@ -177,23 +191,22 @@ def main_loop(gen, opt): # The readline completer reads history from the .dream_history file located in the # output directory specified at the time of script launch. We do not currently support # changing the history file midstream when the output directory is changed. - completer = get_completer(opt, models=gen.model_manager.list_models()) + completer = get_completer(opt, models=gen.model_manager.list_models()) set_default_output_dir(opt, completer) if gen.model: add_embedding_terms(gen, completer) - output_cntr = completer.get_current_history_length()+1 + output_cntr = completer.get_current_history_length() + 1 # os.pathconf is not available on Windows - if hasattr(os, 'pathconf'): - path_max = os.pathconf(opt.outdir, 'PC_PATH_MAX') - name_max = os.pathconf(opt.outdir, 'PC_NAME_MAX') + if hasattr(os, "pathconf"): + path_max = os.pathconf(opt.outdir, "PC_PATH_MAX") + name_max = os.pathconf(opt.outdir, "PC_NAME_MAX") else: path_max = 260 name_max = 255 while not done: - - operation = 'generate' + operation = "generate" try: command = get_next_command(infile, gen.model_name) @@ -206,17 +219,17 @@ def main_loop(gen, opt): if not command.strip(): continue - if command.startswith(('#', '//')): + if command.startswith(("#", "//")): continue - if len(command.strip()) == 1 and command.startswith('q'): + if len(command.strip()) == 1 and command.startswith("q"): done = True break - if not command.startswith('!history'): + if not command.startswith("!history"): completer.add_history(command) - if command.startswith('!'): + if command.startswith("!"): command, operation = do_command(command, gen, opt, completer) if operation is None: @@ -228,14 +241,14 @@ def main_loop(gen, opt): if opt.init_img: try: if not opt.prompt: - oldargs = metadata_from_png(opt.init_img) + oldargs = metadata_from_png(opt.init_img) opt.prompt = oldargs.prompt print(f'>> Retrieved old prompt "{opt.prompt}" from {opt.init_img}') except (OSError, AttributeError, KeyError): pass if len(opt.prompt) == 0: - opt.prompt = '' + opt.prompt = "" # width and height are set by model if not specified if not opt.width: @@ -244,36 +257,35 @@ def main_loop(gen, opt): opt.height = gen.height # retrieve previous value of init image if requested - if opt.init_img is not None and re.match('^-\\d+$', opt.init_img): + if opt.init_img is not None and re.match("^-\\d+$", opt.init_img): try: opt.init_img = last_results[int(opt.init_img)][0] - print(f'>> Reusing previous image {opt.init_img}') + print(f">> Reusing previous image {opt.init_img}") except IndexError: - print( - f'>> No previous initial image at position {opt.init_img} found') + print(f">> No previous initial image at position {opt.init_img} found") opt.init_img = None continue # the outdir can change with each command, so we adjust it here - set_default_output_dir(opt,completer) + set_default_output_dir(opt, completer) # try to relativize pathnames - for attr in ('init_img','init_mask','init_color'): - 
if getattr(opt,attr) and not os.path.exists(getattr(opt,attr)): - basename = getattr(opt,attr) - path = os.path.join(opt.outdir,basename) - setattr(opt,attr,path) + for attr in ("init_img", "init_mask", "init_color"): + if getattr(opt, attr) and not os.path.exists(getattr(opt, attr)): + basename = getattr(opt, attr) + path = os.path.join(opt.outdir, basename) + setattr(opt, attr, path) # retrieve previous value of seed if requested # Exception: for postprocess operations negative seed values # mean "discard the original seed and generate a new one" # (this is a non-obvious hack and needs to be reworked) - if opt.seed is not None and opt.seed < 0 and operation != 'postprocess': + if opt.seed is not None and opt.seed < 0 and operation != "postprocess": try: opt.seed = last_results[opt.seed][1] - print(f'>> Reusing previous seed {opt.seed}') + print(f">> Reusing previous seed {opt.seed}") except IndexError: - print(f'>> No previous seed at position {opt.seed} found') + print(f">> No previous seed at position {opt.seed} found") opt.seed = None continue @@ -283,13 +295,13 @@ def main_loop(gen, opt): if opt.with_variations is not None: opt.with_variations = split_variations(opt.with_variations) - if opt.prompt_as_dir and operation == 'generate': + if opt.prompt_as_dir and operation == "generate": # sanitize the prompt to a valid folder name - subdir = path_filter.sub('_', opt.prompt)[:name_max].rstrip(' .') + subdir = path_filter.sub("_", opt.prompt)[:name_max].rstrip(" .") # truncate path to maximum allowed length # 39 is the length of '######.##########.##########-##.png', plus two separators and a NUL - subdir = subdir[:(path_max - 39 - len(os.path.abspath(opt.outdir)))] + subdir = subdir[: (path_max - 39 - len(os.path.abspath(opt.outdir)))] current_outdir = os.path.join(opt.outdir, subdir) print('Writing files to directory: "' + current_outdir + '"') @@ -305,14 +317,26 @@ def main_loop(gen, opt): # Here is where the images are actually generated! 
last_results = [] try: - file_writer = PngWriter(current_outdir) - results = [] # list of filename, prompt pairs - grid_images = dict() # seed -> Image, only used if `opt.grid` + file_writer = PngWriter(current_outdir) + results = [] # list of filename, prompt pairs + grid_images = dict() # seed -> Image, only used if `opt.grid` prior_variations = opt.with_variations or [] prefix = file_writer.unique_prefix() - step_callback = make_step_callback(gen, opt, prefix) if opt.save_intermediates > 0 else None + step_callback = ( + make_step_callback(gen, opt, prefix) + if opt.save_intermediates > 0 + else None + ) - def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None, prompt_in=None, attention_maps_image=None): + def image_writer( + image, + seed, + upscaled=False, + first_seed=None, + use_prefix=None, + prompt_in=None, + attention_maps_image=None, + ): # note the seed is the seed of the current image # the first_seed is the original seed that noise is added to # when the -v switch is used to generate variations @@ -323,25 +347,31 @@ def main_loop(gen, opt): if opt.grid: grid_images[seed] = image - elif operation == 'mask': - filename = f'{prefix}.{use_prefix}.{seed}.png' + elif operation == "mask": + filename = f"{prefix}.{use_prefix}.{seed}.png" tm = opt.text_mask[0] - th = opt.text_mask[1] if len(opt.text_mask)>1 else 0.5 - formatted_dream_prompt = f'!mask {opt.input_file_path} -tm {tm} {th}' + th = opt.text_mask[1] if len(opt.text_mask) > 1 else 0.5 + formatted_dream_prompt = ( + f"!mask {opt.input_file_path} -tm {tm} {th}" + ) path = file_writer.save_image_and_prompt_to_png( - image = image, - dream_prompt = formatted_dream_prompt, - metadata = {}, - name = filename, - compress_level = opt.png_compression, + image=image, + dream_prompt=formatted_dream_prompt, + metadata={}, + name=filename, + compress_level=opt.png_compression, ) results.append([path, formatted_dream_prompt]) else: if use_prefix is not None: prefix = use_prefix - postprocessed = upscaled if upscaled else operation=='postprocess' - opt.prompt = gen.huggingface_concepts_library.replace_triggers_with_concepts(opt.prompt or prompt_in) # to avoid the problem of non-unique concept triggers + postprocessed = upscaled if upscaled else operation == "postprocess" + opt.prompt = ( + gen.huggingface_concepts_library.replace_triggers_with_concepts( + opt.prompt or prompt_in + ) + ) # to avoid the problem of non-unique concept triggers filename, formatted_dream_prompt = prepare_image_metadata( opt, prefix, @@ -349,23 +379,30 @@ def main_loop(gen, opt): operation, prior_variations, postprocessed, - first_seed + first_seed, ) path = file_writer.save_image_and_prompt_to_png( - image = image, - dream_prompt = formatted_dream_prompt, - metadata = metadata_dumps( + image=image, + dream_prompt=formatted_dream_prompt, + metadata=metadata_dumps( opt, - seeds = [seed if opt.variation_amount==0 and len(prior_variations)==0 else first_seed], - model_hash = gen.model_hash, + seeds=[ + seed + if opt.variation_amount == 0 + and len(prior_variations) == 0 + else first_seed + ], + model_hash=gen.model_hash, ), - name = filename, - compress_level = opt.png_compression, + name=filename, + compress_level=opt.png_compression, ) # update rfc metadata - if operation == 'postprocess': - tool = re.match('postprocess:(\w+)',opt.last_operation).groups()[0] + if operation == "postprocess": + tool = re.match( + "postprocess:(\w+)", opt.last_operation + ).groups()[0] add_postprocessing_to_metadata( opt, opt.input_file_path, @@ -379,49 
+416,51 @@ def main_loop(gen, opt): results.append([path, formatted_dream_prompt]) # so that the seed autocompletes (on linux|mac when -S or --seed specified - if completer and operation == 'generate': + if completer and operation == "generate": completer.add_seed(seed) completer.add_seed(first_seed) last_results.append([path, seed]) - if operation == 'generate': - catch_ctrl_c = infile is None # if running interactively, we catch keyboard interrupts - opt.last_operation='generate' + if operation == "generate": + catch_ctrl_c = ( + infile is None + ) # if running interactively, we catch keyboard interrupts + opt.last_operation = "generate" try: gen.prompt2image( image_callback=image_writer, step_callback=step_callback, catch_interrupts=catch_ctrl_c, - **vars(opt) + **vars(opt), ) except (PromptParser.ParsingException, pyparsing.ParseException) as e: - print('** An error occurred while processing your prompt **') - print(f'** {str(e)} **') - elif operation == 'postprocess': - print(f'>> fixing {opt.prompt}') - opt.last_operation = do_postprocess(gen,opt,image_writer) + print("** An error occurred while processing your prompt **") + print(f"** {str(e)} **") + elif operation == "postprocess": + print(f">> fixing {opt.prompt}") + opt.last_operation = do_postprocess(gen, opt, image_writer) - elif operation == 'mask': - print(f'>> generating masks from {opt.prompt}') + elif operation == "mask": + print(f">> generating masks from {opt.prompt}") do_textmask(gen, opt, image_writer) if opt.grid and len(grid_images) > 0: - grid_img = make_grid(list(grid_images.values())) + grid_img = make_grid(list(grid_images.values())) grid_seeds = list(grid_images.keys()) first_seed = last_results[0][1] - filename = f'{prefix}.{first_seed}.png' - formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed,grid=True,iterations=len(grid_images)) - formatted_dream_prompt += f' # {grid_seeds}' + filename = f"{prefix}.{first_seed}.png" + formatted_dream_prompt = opt.dream_prompt_str( + seed=first_seed, grid=True, iterations=len(grid_images) + ) + formatted_dream_prompt += f" # {grid_seeds}" metadata = metadata_dumps( - opt, - seeds = grid_seeds, - model_hash = gen.model_hash - ) + opt, seeds=grid_seeds, model_hash=gen.model_hash + ) path = file_writer.save_image_and_prompt_to_png( - image = grid_img, - dream_prompt = formatted_dream_prompt, - metadata = metadata, - name = filename + image=grid_img, + dream_prompt=formatted_dream_prompt, + metadata=metadata, + name=filename, ) results = [[path, formatted_dream_prompt]] @@ -433,286 +472,321 @@ def main_loop(gen, opt): print(e) continue - print('Outputs:') - log_path = os.path.join(current_outdir, 'invoke_log') - output_cntr = write_log(results, log_path ,('txt', 'md'), output_cntr) + print("Outputs:") + log_path = os.path.join(current_outdir, "invoke_log") + output_cntr = write_log(results, log_path, ("txt", "md"), output_cntr) print() + print( + f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}' + ) - print(f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}') # TO DO: remove repetitive code and the awkward command.replace() trope # Just do a simple parse of the command! 
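# A minimal sketch of the table-driven parse that the TO DO above hints at.
# Everything below (COMMAND_TABLE, parse_bang_command) is hypothetical and is
# not part of this patch; it only illustrates how the repeated
# startswith()/replace() calls in do_command() could collapse into one lookup.
COMMAND_TABLE = {
    # prefix    -> (operation selected, strip the prefix from the command?)
    "!dream": ("generate", True),   # stored prompts may still carry !dream
    "!fix": ("postprocess", True),
    "!mask": ("mask", True),
}

def parse_bang_command(command: str) -> tuple:
    """Return (remaining_command, operation); default to 'generate'."""
    for prefix, (operation, strip_prefix) in COMMAND_TABLE.items():
        if command.startswith(prefix):
            if strip_prefix:
                command = command.replace(prefix + " ", "", 1)
            return command, operation
    return command, "generate"

# Model-management commands (!switch, !import, !models, ...) would still need
# their own handlers; the sketch only covers the prefix-stripping cases.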
-def do_command(command:str, gen, opt:Args, completer) -> tuple: +def do_command(command: str, gen, opt: Args, completer) -> tuple: global infile - operation = 'generate' # default operation, alternative is 'postprocess' + operation = "generate" # default operation, alternative is 'postprocess' - if command.startswith('!dream'): # in case a stored prompt still contains the !dream command - command = command.replace('!dream ','',1) + if command.startswith( + "!dream" + ): # in case a stored prompt still contains the !dream command + command = command.replace("!dream ", "", 1) - elif command.startswith('!fix'): - command = command.replace('!fix ','',1) - operation = 'postprocess' + elif command.startswith("!fix"): + command = command.replace("!fix ", "", 1) + operation = "postprocess" - elif command.startswith('!mask'): - command = command.replace('!mask ','',1) - operation = 'mask' + elif command.startswith("!mask"): + command = command.replace("!mask ", "", 1) + operation = "mask" - elif command.startswith('!switch'): - model_name = command.replace('!switch ','',1) + elif command.startswith("!switch"): + model_name = command.replace("!switch ", "", 1) try: gen.set_model(model_name) add_embedding_terms(gen, completer) except KeyError as e: print(str(e)) except Exception as e: - report_model_error(opt,e) + report_model_error(opt, e) completer.add_history(command) operation = None - elif command.startswith('!models'): + elif command.startswith("!models"): gen.model_manager.print_models() completer.add_history(command) operation = None - elif command.startswith('!import'): + elif command.startswith("!import"): path = shlex.split(command) if len(path) < 2: - print('** please provide (1) a URL to a .ckpt file to import; (2) a local path to a .ckpt file; or (3) a diffusers repository id in the form stabilityai/stable-diffusion-2-1') + print( + "** please provide (1) a URL to a .ckpt file to import; (2) a local path to a .ckpt file; or (3) a diffusers repository id in the form stabilityai/stable-diffusion-2-1" + ) else: import_model(path[1], gen, opt, completer) completer.add_history(command) operation = None - elif command.startswith('!convert'): + elif command.startswith("!convert"): path = shlex.split(command) if len(path) < 2: - print('** please provide the path to a .ckpt or .safetensors model') + print("** please provide the path to a .ckpt or .safetensors model") elif not os.path.exists(path[1]): - print(f'** {path[1]}: model not found') + print(f"** {path[1]}: model not found") else: optimize_model(path[1], gen, opt, completer) completer.add_history(command) operation = None - - elif command.startswith('!optimize'): + elif command.startswith("!optimize"): path = shlex.split(command) if len(path) < 2: - print('** please provide an installed model name') + print("** please provide an installed model name") elif not path[1] in gen.model_manager.list_models(): - print(f'** {path[1]}: model not found') + print(f"** {path[1]}: model not found") else: optimize_model(path[1], gen, opt, completer) completer.add_history(command) operation = None - elif command.startswith('!edit'): + elif command.startswith("!edit"): path = shlex.split(command) if len(path) < 2: - print('** please provide the name of a model') + print("** please provide the name of a model") else: edit_model(path[1], gen, opt, completer) completer.add_history(command) operation = None - elif command.startswith('!del'): + elif command.startswith("!del"): path = shlex.split(command) if len(path) < 2: - print('** please provide the name 
of a model') + print("** please provide the name of a model") else: del_config(path[1], gen, opt, completer) completer.add_history(command) operation = None - elif command.startswith('!fetch'): - file_path = command.replace('!fetch','',1).strip() - retrieve_dream_command(opt,file_path,completer) + elif command.startswith("!fetch"): + file_path = command.replace("!fetch", "", 1).strip() + retrieve_dream_command(opt, file_path, completer) completer.add_history(command) operation = None - elif command.startswith('!replay'): - file_path = command.replace('!replay','',1).strip() + elif command.startswith("!replay"): + file_path = command.replace("!replay", "", 1).strip() if infile is None and os.path.isfile(file_path): - infile = open(file_path, 'r', encoding='utf-8') + infile = open(file_path, "r", encoding="utf-8") completer.add_history(command) operation = None - elif command.startswith('!history'): + elif command.startswith("!trigger"): + print("Embedding trigger strings: ", ", ".join(gen.embedding_trigger_strings)) + operation = None + + elif command.startswith("!history"): completer.show_history() operation = None - elif command.startswith('!search'): - search_str = command.replace('!search','',1).strip() + elif command.startswith("!search"): + search_str = command.replace("!search", "", 1).strip() completer.show_history(search_str) operation = None - elif command.startswith('!clear'): + elif command.startswith("!clear"): completer.clear_history() operation = None - elif re.match('^!(\d+)',command): - command_no = re.match('^!(\d+)',command).groups()[0] - command = completer.get_line(int(command_no)) + elif re.match("^!(\d+)", command): + command_no = re.match("^!(\d+)", command).groups()[0] + command = completer.get_line(int(command_no)) completer.set_line(command) operation = None else: # not a recognized command, so give the --help text - command = '-h' + command = "-h" return command, operation -def set_default_output_dir(opt:Args, completer:Completer): - ''' + +def set_default_output_dir(opt: Args, completer: Completer): + """ If opt.outdir is relative, we add the root directory to it normalize the outdir relative to root and make sure it exists. - ''' + """ if not os.path.isabs(opt.outdir): - opt.outdir=os.path.normpath(os.path.join(Globals.root,opt.outdir)) + opt.outdir = os.path.normpath(os.path.join(Globals.root, opt.outdir)) if not os.path.exists(opt.outdir): os.makedirs(opt.outdir) completer.set_default_dir(opt.outdir) def import_model(model_path: str, gen, opt, completer): - ''' + """ model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path; or (3) a huggingface repository id - ''' + """ model_name = None - if model_path.startswith(('http:','https:','ftp:')): + if model_path.startswith(("http:", "https:", "ftp:")): model_name = import_ckpt_model(model_path, gen, opt, completer) - elif os.path.exists(model_path) and model_path.endswith(('.ckpt','.safetensors')) and os.path.isfile(model_path): + elif ( + os.path.exists(model_path) + and model_path.endswith((".ckpt", ".safetensors")) + and os.path.isfile(model_path) + ): model_name = import_ckpt_model(model_path, gen, opt, completer) elif os.path.isdir(model_path): - # Allow for a directory containing multiple models. - models = list(Path(model_path).rglob('*.ckpt')) + list(Path(model_path).rglob('*.safetensors')) + models = list(Path(model_path).rglob("*.ckpt")) + list( + Path(model_path).rglob("*.safetensors") + ) if models: # Only the last model name will be used below. 
for model in sorted(models): - - if click.confirm(f'Import {model.stem} ?', default=True): + if click.confirm(f"Import {model.stem} ?", default=True): model_name = import_ckpt_model(model, gen, opt, completer) print() else: model_name = import_diffuser_model(Path(model_path), gen, opt, completer) - elif re.match(r'^[\w.+-]+/[\w.+-]+$', model_path): + elif re.match(r"^[\w.+-]+/[\w.+-]+$", model_path): model_name = import_diffuser_model(model_path, gen, opt, completer) else: - print(f'** {model_path} is neither the path to a .ckpt file nor a diffusers repository id. Can\'t import.') + print( + f"** {model_path} is neither the path to a .ckpt file nor a diffusers repository id. Can't import." + ) if not model_name: return if not _verify_load(model_name, gen): - print('** model failed to load. Discarding configuration entry') + print("** model failed to load. Discarding configuration entry") gen.model_manager.del_model(model_name) return - if input('Make this the default model? [n] ').strip() in ('y','Y'): + if input("Make this the default model? [n] ").strip() in ("y", "Y"): gen.model_manager.set_default_model(model_name) gen.model_manager.commit(opt.conf) completer.update_models(gen.model_manager.list_models()) - print(f'>> {model_name} successfully installed') + print(f">> {model_name} successfully installed") -def import_diffuser_model(path_or_repo: Union[Path, str], gen, _, completer) -> Optional[str]: + +def import_diffuser_model( + path_or_repo: Union[Path, str], gen, _, completer +) -> Optional[str]: manager = gen.model_manager default_name = Path(path_or_repo).stem - default_description = f'Imported model {default_name}' + default_description = f"Imported model {default_name}" model_name, model_description = _get_model_name_and_desc( manager, completer, model_name=default_name, - model_description=default_description + model_description=default_description, ) vae = None - if input('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"? [n] ').strip() in ('y','Y'): - vae = dict(repo_id='stabilityai/sd-vae-ft-mse') + if input( + 'Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"? 
[n] ' + ).strip() in ("y", "Y"): + vae = dict(repo_id="stabilityai/sd-vae-ft-mse") if not manager.import_diffuser_model( - path_or_repo, - model_name = model_name, - vae = vae, - description = model_description): - print('** model failed to import') + path_or_repo, model_name=model_name, vae=vae, description=model_description + ): + print("** model failed to import") return None return model_name -def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Optional[str]: + +def import_ckpt_model( + path_or_url: Union[Path, str], gen, opt, completer +) -> Optional[str]: manager = gen.model_manager default_name = Path(path_or_url).stem - default_description = f'Imported model {default_name}' + default_description = f"Imported model {default_name}" model_name, model_description = _get_model_name_and_desc( manager, completer, model_name=default_name, - model_description=default_description + model_description=default_description, ) config_file = None - default = Path(Globals.root,'configs/stable-diffusion/v1-inpainting-inference.yaml') \ - if re.search('inpaint',default_name, flags=re.IGNORECASE) \ - else Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml') + default = ( + Path(Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml") + if re.search("inpaint", default_name, flags=re.IGNORECASE) + else Path(Globals.root, "configs/stable-diffusion/v1-inference.yaml") + ) - completer.complete_extensions(('.yaml','.yml')) + completer.complete_extensions((".yaml", ".yml")) completer.set_line(str(default)) done = False while not done: - config_file = input('Configuration file for this model: ').strip() + config_file = input("Configuration file for this model: ").strip() done = os.path.exists(config_file) - completer.complete_extensions(('.ckpt','.safetensors')) + completer.complete_extensions((".ckpt", ".safetensors")) vae = None - default = Path(Globals.root,'models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt') + default = Path( + Globals.root, "models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt" + ) completer.set_line(str(default)) done = False while not done: - vae = input('VAE file for this model (leave blank for none): ').strip() or None + vae = input("VAE file for this model (leave blank for none): ").strip() or None done = (not vae) or os.path.exists(vae) completer.complete_extensions(None) if not manager.import_ckpt_model( - path_or_url, - config = config_file, - vae = vae, - model_name = model_name, - model_description = model_description, - commit_to_conf = opt.conf, + path_or_url, + config=config_file, + vae=vae, + model_name=model_name, + model_description=model_description, + commit_to_conf=opt.conf, ): - print('** model failed to import') + print("** model failed to import") return None return model_name -def _verify_load(model_name:str, gen)->bool: - print('>> Verifying that new model loads...') + +def _verify_load(model_name: str, gen) -> bool: + print(">> Verifying that new model loads...") current_model = gen.model_name if not gen.model_manager.get_model(model_name): return False - do_switch = input('Keep model loaded? [y] ') - if len(do_switch)==0 or do_switch[0] in ('y','Y'): + do_switch = input("Keep model loaded? 
[y] ") + if len(do_switch) == 0 or do_switch[0] in ("y", "Y"): gen.set_model(model_name) else: - print('>> Restoring previous model') + print(">> Restoring previous model") gen.set_model(current_model) return True -def _get_model_name_and_desc(model_manager,completer,model_name:str='',model_description:str=''): - model_name = _get_model_name(model_manager.list_models(),completer,model_name) + +def _get_model_name_and_desc( + model_manager, completer, model_name: str = "", model_description: str = "" +): + model_name = _get_model_name(model_manager.list_models(), completer, model_name) completer.set_line(model_description) - model_description = input(f'Description for this model [{model_description}]: ').strip() or model_description + model_description = ( + input(f"Description for this model [{model_description}]: ").strip() + or model_description + ) return model_name, model_description -def _is_inpainting(model_name_or_path: str)->bool: - if re.search('inpaint',model_name_or_path, flags=re.IGNORECASE): - return not input('Is this an inpainting model? [y] ').startswith(('n','N')) + +def _is_inpainting(model_name_or_path: str) -> bool: + if re.search("inpaint", model_name_or_path, flags=re.IGNORECASE): + return not input("Is this an inpainting model? [y] ").startswith(("n", "N")) else: - return not input('Is this an inpainting model? [n] ').startswith(('y','Y')) + return not input("Is this an inpainting model? [n] ").startswith(("y", "Y")) + def optimize_model(model_name_or_path: str, gen, opt, completer): manager = gen.model_manager @@ -722,70 +796,76 @@ def optimize_model(model_name_or_path: str, gen, opt, completer): if model_name_or_path == gen.model_name: print("** Can't convert the active model. !switch to another model first. **") return - elif (model_info := manager.model_info(model_name_or_path)): - if 'weights' in model_info: - ckpt_path = Path(model_info['weights']) - original_config_file = Path(model_info['config']) + elif model_info := manager.model_info(model_name_or_path): + if "weights" in model_info: + ckpt_path = Path(model_info["weights"]) + original_config_file = Path(model_info["config"]) model_name = model_name_or_path - model_description = model_info['description'] + model_description = model_info["description"] else: - print(f'** {model_name_or_path} is not a legacy .ckpt weights file') + print(f"** {model_name_or_path} is not a legacy .ckpt weights file") return elif os.path.exists(model_name_or_path): ckpt_path = Path(model_name_or_path) model_name, model_description = _get_model_name_and_desc( - manager, - completer, - ckpt_path.stem, - f'Converted model {ckpt_path.stem}' + manager, completer, ckpt_path.stem, f"Converted model {ckpt_path.stem}" ) is_inpainting = _is_inpainting(model_name_or_path) original_config_file = Path( - 'configs', - 'stable-diffusion', - 'v1-inpainting-inference.yaml' if is_inpainting else 'v1-inference.yaml' + "configs", + "stable-diffusion", + "v1-inpainting-inference.yaml" if is_inpainting else "v1-inference.yaml", ) else: - print(f'** {model_name_or_path} is neither an existing model nor the path to a .ckpt file') + print( + f"** {model_name_or_path} is neither an existing model nor the path to a .ckpt file" + ) return if not ckpt_path.is_absolute(): - ckpt_path = Path(Globals.root,ckpt_path) + ckpt_path = Path(Globals.root, ckpt_path) if original_config_file and not original_config_file.is_absolute(): - original_config_file = Path(Globals.root,original_config_file) + original_config_file = Path(Globals.root, original_config_file) - 
diffuser_path = Path(Globals.root, 'models',Globals.converted_ckpts_dir,model_name) + diffuser_path = Path( + Globals.root, "models", Globals.converted_ckpts_dir, model_name + ) if diffuser_path.exists(): - print(f'** {model_name_or_path} is already optimized. Will not overwrite. If this is an error, please remove the directory {diffuser_path} and try again.') + print( + f"** {model_name_or_path} is already optimized. Will not overwrite. If this is an error, please remove the directory {diffuser_path} and try again." + ) return vae = None - if input('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"? [n] ').strip() in ('y','Y'): - vae = dict(repo_id='stabilityai/sd-vae-ft-mse') + if input( + 'Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"? [n] ' + ).strip() in ("y", "Y"): + vae = dict(repo_id="stabilityai/sd-vae-ft-mse") new_config = gen.model_manager.convert_and_import( ckpt_path, diffuser_path, model_name=model_name, model_description=model_description, - vae = vae, - original_config_file = original_config_file, + vae=vae, + original_config_file=original_config_file, commit_to_conf=opt.conf, ) if not new_config: return completer.update_models(gen.model_manager.list_models()) - if input(f'Load optimized model {model_name}? [y] ').strip() not in ('n','N'): + if input(f"Load optimized model {model_name}? [y] ").strip() not in ("n", "N"): gen.set_model(model_name) - response = input(f'Delete the original .ckpt file at ({ckpt_path} ? [n] ') - if response.startswith(('y','Y')): + response = input(f"Delete the original .ckpt file at ({ckpt_path} ? [n] ") + if response.startswith(("y", "Y")): ckpt_path.unlink(missing_ok=True) - print(f'{ckpt_path} deleted') + print(f"{ckpt_path} deleted") -def del_config(model_name:str, gen, opt, completer): + +def del_config(model_name: str, gen, opt, completer): current_model = gen.model_name if model_name == current_model: print("** Can't delete active model. !switch to another model first. **") @@ -794,31 +874,38 @@ def del_config(model_name:str, gen, opt, completer): print(f"** Unknown model {model_name}") return - if input(f'Remove {model_name} from the list of models known to InvokeAI? [y] ').strip().startswith(('n','N')): + if ( + input(f"Remove {model_name} from the list of models known to InvokeAI? [y] ") + .strip() + .startswith(("n", "N")) + ): return - delete_completely = input('Completely remove the model file or directory from disk? [n] ').startswith(('y','Y')) - gen.model_manager.del_model(model_name,delete_files=delete_completely) + delete_completely = input( + "Completely remove the model file or directory from disk? 
[n] " + ).startswith(("y", "Y")) + gen.model_manager.del_model(model_name, delete_files=delete_completely) gen.model_manager.commit(opt.conf) - print(f'** {model_name} deleted') + print(f"** {model_name} deleted") completer.update_models(gen.model_manager.list_models()) -def edit_model(model_name:str, gen, opt, completer): + +def edit_model(model_name: str, gen, opt, completer): manager = gen.model_manager if not (info := manager.model_info(model_name)): - print(f'** Unknown model {model_name}') + print(f"** Unknown model {model_name}") return - print(f'\n>> Editing model {model_name} from configuration file {opt.conf}') - new_name = _get_model_name(manager.list_models(),completer,model_name) + print(f"\n>> Editing model {model_name} from configuration file {opt.conf}") + new_name = _get_model_name(manager.list_models(), completer, model_name) for attribute in info.keys(): if type(info[attribute]) != str: continue - if attribute == 'format': + if attribute == "format": continue completer.set_line(info[attribute]) - info[attribute] = input(f'{attribute}: ') or info[attribute] + info[attribute] = input(f"{attribute}: ") or info[attribute] if new_name != model_name: manager.del_model(model_name) @@ -826,23 +913,26 @@ def edit_model(model_name:str, gen, opt, completer): # this does the update manager.add_model(new_name, info, True) - if input('Make this the default model? [n] ').startswith(('y','Y')): + if input("Make this the default model? [n] ").startswith(("y", "Y")): manager.set_default_model(new_name) manager.commit(opt.conf) completer.update_models(manager.list_models()) - print('>> Model successfully updated') + print(">> Model successfully updated") -def _get_model_name(existing_names,completer,default_name:str='')->str: + +def _get_model_name(existing_names, completer, default_name: str = "") -> str: done = False completer.set_line(default_name) while not done: - model_name = input(f'Short name for this model [{default_name}]: ').strip() - if len(model_name)==0: + model_name = input(f"Short name for this model [{default_name}]: ").strip() + if len(model_name) == 0: model_name = default_name - if not re.match('^[\w._+:/-]+$',model_name): - print('** model name must contain only words, digits and the characters "._+:/-" **') + if not re.match("^[\w._+:/-]+$", model_name): + print( + '** model name must contain only words, digits and the characters "._+:/-" **' + ) elif model_name != default_name and model_name in existing_names: - print(f'** the name {model_name} is already in use. Pick another.') + print(f"** the name {model_name} is already in use. Pick another.") else: done = True return model_name @@ -851,197 +941,223 @@ def _get_model_name(existing_names,completer,default_name:str='')->str: def do_textmask(gen, opt, callback): image_path = opt.prompt if not os.path.exists(image_path): - image_path = os.path.join(opt.outdir,image_path) - assert os.path.exists(image_path), '** "{opt.prompt}" not found. Please enter the name of an existing image file to mask **' - assert opt.text_mask is not None and len(opt.text_mask) >= 1, '** Please provide a text mask with -tm **' + image_path = os.path.join(opt.outdir, image_path) + assert os.path.exists( + image_path + ), '** "{opt.prompt}" not found. 
Please enter the name of an existing image file to mask **' + assert ( + opt.text_mask is not None and len(opt.text_mask) >= 1 + ), "** Please provide a text mask with -tm **" opt.input_file_path = image_path tm = opt.text_mask[0] - threshold = float(opt.text_mask[1]) if len(opt.text_mask) > 1 else 0.5 + threshold = float(opt.text_mask[1]) if len(opt.text_mask) > 1 else 0.5 gen.apply_textmask( - image_path = image_path, - prompt = tm, - threshold = threshold, - callback = callback, + image_path=image_path, + prompt=tm, + threshold=threshold, + callback=callback, ) -def do_postprocess (gen, opt, callback): - file_path = opt.prompt # treat the prompt as the file pathname + +def do_postprocess(gen, opt, callback): + file_path = opt.prompt # treat the prompt as the file pathname if opt.new_prompt is not None: opt.prompt = opt.new_prompt else: opt.prompt = None - if os.path.dirname(file_path) == '': #basename given - file_path = os.path.join(opt.outdir,file_path) + if os.path.dirname(file_path) == "": # basename given + file_path = os.path.join(opt.outdir, file_path) opt.input_file_path = file_path - tool=None + tool = None if opt.facetool_strength > 0: tool = opt.facetool elif opt.embiggen: - tool = 'embiggen' + tool = "embiggen" elif opt.upscale: - tool = 'upscale' + tool = "upscale" elif opt.out_direction: - tool = 'outpaint' + tool = "outpaint" elif opt.outcrop: - tool = 'outcrop' - opt.save_original = True # do not overwrite old image! - opt.last_operation = f'postprocess:{tool}' + tool = "outcrop" + opt.save_original = True # do not overwrite old image! + opt.last_operation = f"postprocess:{tool}" try: gen.apply_postprocessor( - image_path = file_path, - tool = tool, - facetool_strength = opt.facetool_strength, - codeformer_fidelity = opt.codeformer_fidelity, - save_original = opt.save_original, - upscale = opt.upscale, - upscale_denoise_str = opt.esrgan_denoise_str, - out_direction = opt.out_direction, - outcrop = opt.outcrop, - callback = callback, - opt = opt, + image_path=file_path, + tool=tool, + facetool_strength=opt.facetool_strength, + codeformer_fidelity=opt.codeformer_fidelity, + save_original=opt.save_original, + upscale=opt.upscale, + upscale_denoise_str=opt.esrgan_denoise_str, + out_direction=opt.out_direction, + outcrop=opt.outcrop, + callback=callback, + opt=opt, ) except OSError: print(traceback.format_exc(), file=sys.stderr) - print(f'** {file_path}: file could not be read') + print(f"** {file_path}: file could not be read") return except (KeyError, AttributeError): print(traceback.format_exc(), file=sys.stderr) return return opt.last_operation -def add_postprocessing_to_metadata(opt,original_file,new_file,tool,command): - original_file = original_file if os.path.exists(original_file) else os.path.join(opt.outdir,original_file) - new_file = new_file if os.path.exists(new_file) else os.path.join(opt.outdir,new_file) + +def add_postprocessing_to_metadata(opt, original_file, new_file, tool, command): + original_file = ( + original_file + if os.path.exists(original_file) + else os.path.join(opt.outdir, original_file) + ) + new_file = ( + new_file if os.path.exists(new_file) else os.path.join(opt.outdir, new_file) + ) try: - meta = retrieve_metadata(original_file)['sd-metadata'] + meta = retrieve_metadata(original_file)["sd-metadata"] except AttributeError: try: - meta = retrieve_metadata(new_file)['sd-metadata'] + meta = retrieve_metadata(new_file)["sd-metadata"] except AttributeError: meta = {} - if 'image' not in meta: - meta = metadata_dumps(opt,seeds=[opt.seed])['image'] 
- meta['image'] = {} - img_data = meta.get('image') - pp = img_data.get('postprocessing',[]) or [] + if "image" not in meta: + meta = metadata_dumps(opt, seeds=[opt.seed])["image"] + meta["image"] = {} + img_data = meta.get("image") + pp = img_data.get("postprocessing", []) or [] pp.append( { - 'tool':tool, - 'dream_command':command, + "tool": tool, + "dream_command": command, } ) - meta['image']['postprocessing'] = pp - write_metadata(new_file,meta) + meta["image"]["postprocessing"] = pp + write_metadata(new_file, meta) + def prepare_image_metadata( - opt, - prefix, - seed, - operation='generate', - prior_variations=[], - postprocessed=False, - first_seed=None + opt, + prefix, + seed, + operation="generate", + prior_variations=[], + postprocessed=False, + first_seed=None, ): - if postprocessed and opt.save_original: - filename = choose_postprocess_name(opt,prefix,seed) + filename = choose_postprocess_name(opt, prefix, seed) else: wildcards = dict(opt.__dict__) - wildcards['prefix'] = prefix - wildcards['seed'] = seed + wildcards["prefix"] = prefix + wildcards["seed"] = seed try: filename = opt.fnformat.format(**wildcards) except KeyError as e: - print(f'** The filename format contains an unknown key \'{e.args[0]}\'. Will use {{prefix}}.{{seed}}.png\' instead') - filename = f'{prefix}.{seed}.png' + print( + f"** The filename format contains an unknown key '{e.args[0]}'. Will use {{prefix}}.{{seed}}.png' instead" + ) + filename = f"{prefix}.{seed}.png" except IndexError: - print("** The filename format is broken or complete. Will use '{prefix}.{seed}.png' instead") - filename = f'{prefix}.{seed}.png' + print( + "** The filename format is broken or complete. Will use '{prefix}.{seed}.png' instead" + ) + filename = f"{prefix}.{seed}.png" if opt.variation_amount > 0: - first_seed = first_seed or seed - this_variation = [[seed, opt.variation_amount]] - opt.with_variations = prior_variations + this_variation + first_seed = first_seed or seed + this_variation = [[seed, opt.variation_amount]] + opt.with_variations = prior_variations + this_variation formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed) elif len(prior_variations) > 0: formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed) - elif operation == 'postprocess': - formatted_dream_prompt = '!fix '+opt.dream_prompt_str(seed=seed,prompt=opt.input_file_path) + elif operation == "postprocess": + formatted_dream_prompt = "!fix " + opt.dream_prompt_str( + seed=seed, prompt=opt.input_file_path + ) else: formatted_dream_prompt = opt.dream_prompt_str(seed=seed) - return filename,formatted_dream_prompt + return filename, formatted_dream_prompt -def choose_postprocess_name(opt,prefix,seed) -> str: - match = re.search('postprocess:(\w+)',opt.last_operation) + +def choose_postprocess_name(opt, prefix, seed) -> str: + match = re.search("postprocess:(\w+)", opt.last_operation) if match: - modifier = match.group(1) # will look like "gfpgan", "upscale", "outpaint" or "embiggen" + modifier = match.group( + 1 + ) # will look like "gfpgan", "upscale", "outpaint" or "embiggen" else: - modifier = 'postprocessed' + modifier = "postprocessed" - counter = 0 - filename = None + counter = 0 + filename = None available = False while not available: if counter == 0: - filename = f'{prefix}.{seed}.{modifier}.png' + filename = f"{prefix}.{seed}.{modifier}.png" else: - filename = f'{prefix}.{seed}.{modifier}-{counter:02d}.png' - available = not os.path.exists(os.path.join(opt.outdir,filename)) + filename = f"{prefix}.{seed}.{modifier}-{counter:02d}.png" 
+ available = not os.path.exists(os.path.join(opt.outdir, filename)) counter += 1 return filename -def get_next_command(infile=None, model_name='no model') -> str: # command string + +def get_next_command(infile=None, model_name="no model") -> str: # command string if infile is None: - command = input(f'({model_name}) invoke> ').strip() + command = input(f"({model_name}) invoke> ").strip() else: command = infile.readline() if not command: raise EOFError else: command = command.strip() - if len(command)>0: - print(f'#{command}') + if len(command) > 0: + print(f"#{command}") return command -def invoke_ai_web_server_loop(gen: Generate, gfpgan, codeformer, esrgan): - print('\n* --web was specified, starting web server...') - from invokeai.backend import InvokeAIWebServer - # Change working directory to the stable-diffusion directory - os.chdir( - os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) - ) - invoke_ai_web_server = InvokeAIWebServer(generate=gen, gfpgan=gfpgan, codeformer=codeformer, esrgan=esrgan) +def invoke_ai_web_server_loop(gen: Generate, gfpgan, codeformer, esrgan): + print("\n* --web was specified, starting web server...") + from invokeai.backend import InvokeAIWebServer + + # Change working directory to the stable-diffusion directory + os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + + invoke_ai_web_server = InvokeAIWebServer( + generate=gen, gfpgan=gfpgan, codeformer=codeformer, esrgan=esrgan + ) try: invoke_ai_web_server.run() except KeyboardInterrupt: pass -def add_embedding_terms(gen,completer): - ''' + +def add_embedding_terms(gen, completer): + """ Called after setting the model, updates the autocompleter with any terms loaded by the embedding manager. - ''' + """ trigger_strings = gen.model.textual_inversion_manager.get_all_trigger_strings() completer.add_embedding_terms(trigger_strings) + def split_variations(variations_string) -> list: # shotgun parsing, woo parts = [] broken = False # python doesn't have labeled loops... 
- for part in variations_string.split(','): - seed_and_weight = part.split(':') + for part in variations_string.split(","): + seed_and_weight = part.split(":") if len(seed_and_weight) != 2: print(f'** Could not parse with_variation part "{part}"') broken = True break try: - seed = int(seed_and_weight[0]) + seed = int(seed_and_weight[0]) weight = float(seed_and_weight[1]) except ValueError: print(f'** Could not parse with_variation part "{part}"') @@ -1055,40 +1171,48 @@ def split_variations(variations_string) -> list: else: return parts + def load_face_restoration(opt): try: gfpgan, codeformer, esrgan = None, None, None if opt.restore or opt.esrgan: from ldm.invoke.restoration import Restoration + restoration = Restoration() if opt.restore: - gfpgan, codeformer = restoration.load_face_restore_models(opt.gfpgan_model_path) + gfpgan, codeformer = restoration.load_face_restore_models( + opt.gfpgan_model_path + ) else: - print('>> Face restoration disabled') + print(">> Face restoration disabled") if opt.esrgan: esrgan = restoration.load_esrgan(opt.esrgan_bg_tile) else: - print('>> Upscaling disabled') + print(">> Upscaling disabled") else: - print('>> Face restoration and upscaling disabled') + print(">> Face restoration and upscaling disabled") except (ModuleNotFoundError, ImportError): print(traceback.format_exc(), file=sys.stderr) - print('>> You may need to install the ESRGAN and/or GFPGAN modules') - return gfpgan,codeformer,esrgan + print(">> You may need to install the ESRGAN and/or GFPGAN modules") + return gfpgan, codeformer, esrgan + def make_step_callback(gen, opt, prefix): - destination = os.path.join(opt.outdir,'intermediates',prefix) - os.makedirs(destination,exist_ok=True) - print(f'>> Intermediate images will be written into {destination}') + destination = os.path.join(opt.outdir, "intermediates", prefix) + os.makedirs(destination, exist_ok=True) + print(f">> Intermediate images will be written into {destination}") + def callback(img, step): - if step % opt.save_intermediates == 0 or step == opt.steps-1: - filename = os.path.join(destination,f'{step:04}.png') + if step % opt.save_intermediates == 0 or step == opt.steps - 1: + filename = os.path.join(destination, f"{step:04}.png") image = gen.sample_to_image(img) - image.save(filename,'PNG') + image.save(filename, "PNG") + return callback -def retrieve_dream_command(opt,command,completer): - ''' + +def retrieve_dream_command(opt, command, completer): + """ Given a full or partial path to a previously-generated image file, will retrieve and format the dream command used to generate the image, and pop it into the readline buffer (linux, Mac), or print out a comment @@ -1097,34 +1221,35 @@ def retrieve_dream_command(opt,command,completer): Given a wildcard path to a folder with image png files, will retrieve and format the dream command used to generate the images, and save them to a file commands.txt for further processing - ''' + """ if len(command) == 0: return tokens = command.split() - dir,basename = os.path.split(tokens[0]) + dir, basename = os.path.split(tokens[0]) if len(dir) == 0: - path = os.path.join(opt.outdir,basename) + path = os.path.join(opt.outdir, basename) else: path = tokens[0] if len(tokens) > 1: return write_commands(opt, path, tokens[1]) - cmd = '' + cmd = "" try: cmd = dream_cmd_from_png(path) except OSError: - print(f'## {tokens[0]}: file could not be read') + print(f"## {tokens[0]}: file could not be read") except (KeyError, AttributeError, IndexError): - print(f'## {tokens[0]}: file has no metadata') + 
print(f"## {tokens[0]}: file has no metadata") except: - print(f'## {tokens[0]}: file could not be processed') - if len(cmd)>0: + print(f"## {tokens[0]}: file could not be processed") + if len(cmd) > 0: completer.set_line(cmd) -def write_commands(opt, file_path:str, outfilepath:str): - dir,basename = os.path.split(file_path) + +def write_commands(opt, file_path: str, outfilepath: str): + dir, basename = os.path.split(file_path) try: paths = sorted(list(Path(dir).glob(basename))) except ValueError: @@ -1137,39 +1262,46 @@ def write_commands(opt, file_path:str, outfilepath:str): try: cmd = dream_cmd_from_png(path) except (KeyError, AttributeError, IndexError): - print(f'## {path}: file has no metadata') + print(f"## {path}: file has no metadata") except: - print(f'## {path}: file could not be processed') + print(f"## {path}: file could not be processed") if cmd: - commands.append(f'# {path}') + commands.append(f"# {path}") commands.append(cmd) - if len(commands)>0: - dir,basename = os.path.split(outfilepath) - if len(dir)==0: - outfilepath = os.path.join(opt.outdir,basename) - with open(outfilepath, 'w', encoding='utf-8') as f: - f.write('\n'.join(commands)) - print(f'>> File {outfilepath} with commands created') + if len(commands) > 0: + dir, basename = os.path.split(outfilepath) + if len(dir) == 0: + outfilepath = os.path.join(opt.outdir, basename) + with open(outfilepath, "w", encoding="utf-8") as f: + f.write("\n".join(commands)) + print(f">> File {outfilepath} with commands created") -def report_model_error(opt:Namespace, e:Exception): + +def report_model_error(opt: Namespace, e: Exception): print(f'** An error occurred while attempting to initialize the model: "{str(e)}"') - print('** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models.') - yes_to_all = os.environ.get('INVOKE_MODEL_RECONFIGURE') + print( + "** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models." + ) + yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE") if yes_to_all: - print('** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE') + print( + "** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE" + ) else: - response = input('Do you want to run invokeai-configure script to select and/or reinstall models? [y] ') - if response.startswith(('n', 'N')): + response = input( + "Do you want to run invokeai-configure script to select and/or reinstall models? [y] " + ) + if response.startswith(("n", "N")): return - print('invokeai-configure is launching....\n') + print("invokeai-configure is launching....\n") # Match arguments that were set on the CLI # only the arguments accepted by the configuration script are parsed root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else [] config = ["--config", opt.conf] if opt.conf is not None else [] previous_args = sys.argv - sys.argv = [ 'invokeai-configure' ] + sys.argv = ["invokeai-configure"] sys.argv.extend(root_dir) sys.argv.extend(config) if yes_to_all is not None: @@ -1177,21 +1309,24 @@ def report_model_error(opt:Namespace, e:Exception): sys.argv.append(arg) from ldm.invoke.config import invokeai_configure + invokeai_configure.main() - print('** InvokeAI will now restart') + print("** InvokeAI will now restart") sys.argv = previous_args - main() # would rather do a os.exec(), but doesn't exist? + main() # would rather do a os.exec(), but doesn't exist? 
sys.exit(0) -def check_internet()->bool: - ''' + +def check_internet() -> bool: + """ Return true if the internet is reachable. It does this by pinging huggingface.co. - ''' + """ import urllib.request - host = 'http://huggingface.co' + + host = "http://huggingface.co" try: - urllib.request.urlopen(host,timeout=1) + urllib.request.urlopen(host, timeout=1) return True except: return False diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py index d81de4f1ca..1bd1aa46ab 100644 --- a/ldm/invoke/args.py +++ b/ldm/invoke/args.py @@ -751,6 +751,9 @@ class Args(object): !fix applies upscaling/facefixing to a previously-generated image. invoke> !fix 0000045.4829112.png -G1 -U4 -ft codeformer + *embeddings* + invoke> !triggers -- return all trigger phrases contained in loaded embedding files + *History manipulation* !fetch retrieves the command used to generate an earlier image. Provide a directory wildcard and the name of a file to write and all the commands diff --git a/ldm/invoke/readline.py b/ldm/invoke/readline.py index f14af0714f..1e9b31ea8d 100644 --- a/ldm/invoke/readline.py +++ b/ldm/invoke/readline.py @@ -60,7 +60,7 @@ COMMANDS = ( '--text_mask','-tm', '!fix','!fetch','!replay','!history','!search','!clear', '!models','!switch','!import_model','!optimize_model','!convert_model','!edit_model','!del_model', - '!mask', + '!mask','!triggers', ) MODEL_COMMANDS = ( '!switch', diff --git a/ldm/modules/textual_inversion_manager.py b/ldm/modules/textual_inversion_manager.py index 2e61be6b12..8ca1a0bf5e 100644 --- a/ldm/modules/textual_inversion_manager.py +++ b/ldm/modules/textual_inversion_manager.py @@ -1,11 +1,12 @@ import os import traceback +from dataclasses import dataclass +from pathlib import Path from typing import Optional import torch -from dataclasses import dataclass from picklescan.scanner import scan_file_path -from transformers import CLIPTokenizer, CLIPTextModel +from transformers import CLIPTextModel, CLIPTokenizer from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary @@ -21,11 +22,14 @@ class TextualInversion: def embedding_vector_length(self) -> int: return self.embedding.shape[0] -class TextualInversionManager(): - def __init__(self, - tokenizer: CLIPTokenizer, - text_encoder: CLIPTextModel, - full_precision: bool=True): + +class TextualInversionManager: + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + full_precision: bool = True, + ): self.tokenizer = tokenizer self.text_encoder = text_encoder self.full_precision = full_precision @@ -38,47 +42,60 @@ class TextualInversionManager(): if concept_name in self.hf_concepts_library.concepts_loaded: continue trigger = self.hf_concepts_library.concept_to_trigger(concept_name) - if self.has_textual_inversion_for_trigger_string(trigger) \ - or self.has_textual_inversion_for_trigger_string(concept_name) \ - or self.has_textual_inversion_for_trigger_string(f'<{concept_name}>'): # in case a token with literal angle brackets encountered - print(f'>> Loaded local embedding for trigger {concept_name}') + if ( + self.has_textual_inversion_for_trigger_string(trigger) + or self.has_textual_inversion_for_trigger_string(concept_name) + or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>") + ): # in case a token with literal angle brackets encountered + print(f">> Loaded local embedding for trigger {concept_name}") continue bin_file = self.hf_concepts_library.get_concept_model_path(concept_name) if not bin_file: continue - print(f'>> Loaded remote embedding for trigger {concept_name}') + 
print(f">> Loaded remote embedding for trigger {concept_name}") self.load_textual_inversion(bin_file) - self.hf_concepts_library.concepts_loaded[concept_name]=True + self.hf_concepts_library.concepts_loaded[concept_name] = True def get_all_trigger_strings(self) -> list[str]: return [ti.trigger_string for ti in self.textual_inversions] - def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool=False): - if str(ckpt_path).endswith('.DS_Store'): + def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool = False): + if str(ckpt_path).endswith(".DS_Store"): return try: scan_result = scan_file_path(ckpt_path) if scan_result.infected_files == 1: - print(f'\n### Security Issues Found in Model: {scan_result.issues_count}') - print('### For your safety, InvokeAI will not load this embed.') + print( + f"\n### Security Issues Found in Model: {scan_result.issues_count}" + ) + print("### For your safety, InvokeAI will not load this embed.") return except Exception: - print(f"### WARNING::: Invalid or corrupt embeddings found. Ignoring: {ckpt_path}") + ckpt_path = Path(ckpt_path) + print( + f"** Notice: {ckpt_path.parents[0].stem}/{ckpt_path.stem} is incompatible with this model" + ) return embedding_info = self._parse_embedding(ckpt_path) if embedding_info: try: - self._add_textual_inversion(embedding_info['name'], - embedding_info['embedding'], - defer_injecting_tokens=defer_injecting_tokens) + self._add_textual_inversion( + embedding_info["name"], + embedding_info["embedding"], + defer_injecting_tokens=defer_injecting_tokens, + ) except ValueError as e: print(f' | Ignoring incompatible embedding {embedding_info["name"]}') - print(f' | The error was {str(e)}') + print(f" | The error was {str(e)}") else: - print(f'>> Failed to load embedding located at {ckpt_path}. Unsupported file.') + print( + f">> Failed to load embedding located at {ckpt_path}. Unsupported file." + ) - def _add_textual_inversion(self, trigger_str, embedding, defer_injecting_tokens=False) -> TextualInversion: + def _add_textual_inversion( + self, trigger_str, embedding, defer_injecting_tokens=False + ) -> TextualInversion: """ Add a textual inversion to be recognised. :param trigger_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added. @@ -86,46 +103,59 @@ class TextualInversionManager(): :return: The token id for the added embedding, either existing or newly-added. """ if trigger_str in [ti.trigger_string for ti in self.textual_inversions]: - print(f">> TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'") + print( + f">> TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'" + ) return if not self.full_precision: embedding = embedding.half() if len(embedding.shape) == 1: embedding = embedding.unsqueeze(0) elif len(embedding.shape) > 2: - raise ValueError(f"TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2.") + raise ValueError( + f"TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2." 
+ ) try: - ti = TextualInversion( - trigger_string=trigger_str, - embedding=embedding - ) + ti = TextualInversion(trigger_string=trigger_str, embedding=embedding) if not defer_injecting_tokens: self._inject_tokens_and_assign_embeddings(ti) self.textual_inversions.append(ti) return ti except ValueError as e: - if str(e).startswith('Warning'): + if str(e).startswith("Warning"): print(f">> {str(e)}") else: traceback.print_exc() - print(f">> TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}.") + print( + f">> TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}." + ) raise def _inject_tokens_and_assign_embeddings(self, ti: TextualInversion) -> int: - if ti.trigger_token_id is not None: - raise ValueError(f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'") + raise ValueError( + f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'" + ) - trigger_token_id = self._get_or_create_token_id_and_assign_embedding(ti.trigger_string, ti.embedding[0]) + trigger_token_id = self._get_or_create_token_id_and_assign_embedding( + ti.trigger_string, ti.embedding[0] + ) if ti.embedding_vector_length > 1: # for embeddings with vector length > 1 - pad_token_strings = [ti.trigger_string + "-!pad-" + str(pad_index) for pad_index in range(1, ti.embedding_vector_length)] + pad_token_strings = [ + ti.trigger_string + "-!pad-" + str(pad_index) + for pad_index in range(1, ti.embedding_vector_length) + ] # todo: batched UI for faster loading when vector length >2 - pad_token_ids = [self._get_or_create_token_id_and_assign_embedding(pad_token_str, ti.embedding[1 + i]) \ - for (i, pad_token_str) in enumerate(pad_token_strings)] + pad_token_ids = [ + self._get_or_create_token_id_and_assign_embedding( + pad_token_str, ti.embedding[1 + i] + ) + for (i, pad_token_str) in enumerate(pad_token_strings) + ] else: pad_token_ids = [] @@ -133,7 +163,6 @@ class TextualInversionManager(): ti.pad_token_ids = pad_token_ids return ti.trigger_token_id - def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool: try: ti = self.get_textual_inversion_for_trigger_string(trigger_string) @@ -141,32 +170,43 @@ class TextualInversionManager(): except StopIteration: return False - - def get_textual_inversion_for_trigger_string(self, trigger_string: str) -> TextualInversion: - return next(ti for ti in self.textual_inversions if ti.trigger_string == trigger_string) - + def get_textual_inversion_for_trigger_string( + self, trigger_string: str + ) -> TextualInversion: + return next( + ti for ti in self.textual_inversions if ti.trigger_string == trigger_string + ) def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion: - return next(ti for ti in self.textual_inversions if ti.trigger_token_id == token_id) + return next( + ti for ti in self.textual_inversions if ti.trigger_token_id == token_id + ) - def create_deferred_token_ids_for_any_trigger_terms(self, prompt_string: str) -> list[int]: + def create_deferred_token_ids_for_any_trigger_terms( + self, prompt_string: str + ) -> list[int]: injected_token_ids = [] for ti in self.textual_inversions: if ti.trigger_token_id is None and ti.trigger_string in prompt_string: if ti.embedding_vector_length > 1: - print(f">> Preparing tokens for textual inversion {ti.trigger_string}...") + print( + f">> Preparing tokens for textual inversion {ti.trigger_string}..." 
+ ) try: self._inject_tokens_and_assign_embeddings(ti) except ValueError as e: - print(f' | Ignoring incompatible embedding trigger {ti.trigger_string}') - print(f' | The error was {str(e)}') + print( + f" | Ignoring incompatible embedding trigger {ti.trigger_string}" + ) + print(f" | The error was {str(e)}") continue injected_token_ids.append(ti.trigger_token_id) injected_token_ids.extend(ti.pad_token_ids) return injected_token_ids - - def expand_textual_inversion_token_ids_if_necessary(self, prompt_token_ids: list[int]) -> list[int]: + def expand_textual_inversion_token_ids_if_necessary( + self, prompt_token_ids: list[int] + ) -> list[int]: """ Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes. @@ -181,20 +221,31 @@ class TextualInversionManager(): raise ValueError("prompt_token_ids must not start with bos_token_id") if prompt_token_ids[-1] == self.tokenizer.eos_token_id: raise ValueError("prompt_token_ids must not end with eos_token_id") - textual_inversion_trigger_token_ids = [ti.trigger_token_id for ti in self.textual_inversions] + textual_inversion_trigger_token_ids = [ + ti.trigger_token_id for ti in self.textual_inversions + ] prompt_token_ids = prompt_token_ids.copy() for i, token_id in reversed(list(enumerate(prompt_token_ids))): if token_id in textual_inversion_trigger_token_ids: - textual_inversion = next(ti for ti in self.textual_inversions if ti.trigger_token_id == token_id) - for pad_idx in range(0, textual_inversion.embedding_vector_length-1): - prompt_token_ids.insert(i+pad_idx+1, textual_inversion.pad_token_ids[pad_idx]) + textual_inversion = next( + ti + for ti in self.textual_inversions + if ti.trigger_token_id == token_id + ) + for pad_idx in range(0, textual_inversion.embedding_vector_length - 1): + prompt_token_ids.insert( + i + pad_idx + 1, textual_inversion.pad_token_ids[pad_idx] + ) return prompt_token_ids - - def _get_or_create_token_id_and_assign_embedding(self, token_str: str, embedding: torch.Tensor) -> int: + def _get_or_create_token_id_and_assign_embedding( + self, token_str: str, embedding: torch.Tensor + ) -> int: if len(embedding.shape) != 1: - raise ValueError("Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2") + raise ValueError( + "Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2" + ) existing_token_id = self.tokenizer.convert_tokens_to_ids(token_str) if existing_token_id == self.tokenizer.unk_token_id: num_tokens_added = self.tokenizer.add_tokens(token_str) @@ -207,66 +258,78 @@ class TextualInversionManager(): token_id = self.tokenizer.convert_tokens_to_ids(token_str) if token_id == self.tokenizer.unk_token_id: raise RuntimeError(f"Unable to find token id for token '{token_str}'") - if self.text_encoder.get_input_embeddings().weight.data[token_id].shape != embedding.shape: - raise ValueError(f"Warning. Cannot load embedding for {token_str}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {self.text_encoder.get_input_embeddings().weight.data[token_id].shape[0]}.") + if ( + self.text_encoder.get_input_embeddings().weight.data[token_id].shape + != embedding.shape + ): + raise ValueError( + f"Warning. Cannot load embedding for {token_str}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {self.text_encoder.get_input_embeddings().weight.data[token_id].shape[0]}." 
+ ) self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding return token_id def _parse_embedding(self, embedding_file: str): - file_type = embedding_file.split('.')[-1] - if file_type == 'pt': + file_type = embedding_file.split(".")[-1] + if file_type == "pt": return self._parse_embedding_pt(embedding_file) - elif file_type == 'bin': + elif file_type == "bin": return self._parse_embedding_bin(embedding_file) else: - print(f'>> Not a recognized embedding file: {embedding_file}') + print(f">> Not a recognized embedding file: {embedding_file}") def _parse_embedding_pt(self, embedding_file): - embedding_ckpt = torch.load(embedding_file, map_location='cpu') + embedding_ckpt = torch.load(embedding_file, map_location="cpu") embedding_info = {} # Check if valid embedding file - if 'string_to_token' and 'string_to_param' in embedding_ckpt: - + if "string_to_token" and "string_to_param" in embedding_ckpt: # Catch variants that do not have the expected keys or values. try: - embedding_info['name'] = embedding_ckpt['name'] or os.path.basename(os.path.splitext(embedding_file)[0]) + embedding_info["name"] = embedding_ckpt["name"] or os.path.basename( + os.path.splitext(embedding_file)[0] + ) # Check num of embeddings and warn user only the first will be used - embedding_info['num_of_embeddings'] = len(embedding_ckpt["string_to_token"]) - if embedding_info['num_of_embeddings'] > 1: - print('>> More than 1 embedding found. Will use the first one') + embedding_info["num_of_embeddings"] = len( + embedding_ckpt["string_to_token"] + ) + if embedding_info["num_of_embeddings"] > 1: + print(">> More than 1 embedding found. Will use the first one") - embedding = list(embedding_ckpt['string_to_param'].values())[0] - except (AttributeError,KeyError): + embedding = list(embedding_ckpt["string_to_param"].values())[0] + except (AttributeError, KeyError): return self._handle_broken_pt_variants(embedding_ckpt, embedding_file) - embedding_info['embedding'] = embedding - embedding_info['num_vectors_per_token'] = embedding.size()[0] - embedding_info['token_dim'] = embedding.size()[1] + embedding_info["embedding"] = embedding + embedding_info["num_vectors_per_token"] = embedding.size()[0] + embedding_info["token_dim"] = embedding.size()[1] try: - embedding_info['trained_steps'] = embedding_ckpt['step'] - embedding_info['trained_model_name'] = embedding_ckpt['sd_checkpoint_name'] - embedding_info['trained_model_checksum'] = embedding_ckpt['sd_checkpoint'] + embedding_info["trained_steps"] = embedding_ckpt["step"] + embedding_info["trained_model_name"] = embedding_ckpt[ + "sd_checkpoint_name" + ] + embedding_info["trained_model_checksum"] = embedding_ckpt[ + "sd_checkpoint" + ] except AttributeError: print(">> No Training Details Found. 
Passing ...") # .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/ # They are actually .bin files - elif len(embedding_ckpt.keys())==1: - print('>> Detected .bin file masquerading as .pt file') + elif len(embedding_ckpt.keys()) == 1: + print(">> Detected .bin file masquerading as .pt file") embedding_info = self._parse_embedding_bin(embedding_file) else: - print('>> Invalid embedding format') + print(">> Invalid embedding format") embedding_info = None return embedding_info def _parse_embedding_bin(self, embedding_file): - embedding_ckpt = torch.load(embedding_file, map_location='cpu') + embedding_ckpt = torch.load(embedding_file, map_location="cpu") embedding_info = {} if list(embedding_ckpt.keys()) == 0: @@ -274,27 +337,45 @@ class TextualInversionManager(): embedding_info = None else: for token in list(embedding_ckpt.keys()): - embedding_info['name'] = token or os.path.basename(os.path.splitext(embedding_file)[0]) - embedding_info['embedding'] = embedding_ckpt[token] - embedding_info['num_vectors_per_token'] = 1 # All Concepts seem to default to 1 - embedding_info['token_dim'] = embedding_info['embedding'].size()[0] + embedding_info["name"] = token or os.path.basename( + os.path.splitext(embedding_file)[0] + ) + embedding_info["embedding"] = embedding_ckpt[token] + embedding_info[ + "num_vectors_per_token" + ] = 1 # All Concepts seem to default to 1 + embedding_info["token_dim"] = embedding_info["embedding"].size()[0] return embedding_info - def _handle_broken_pt_variants(self, embedding_ckpt:dict, embedding_file:str)->dict: - ''' + def _handle_broken_pt_variants( + self, embedding_ckpt: dict, embedding_file: str + ) -> dict: + """ This handles the broken .pt file variants. We only know of one at present. - ''' + """ embedding_info = {} - if isinstance(list(embedding_ckpt['string_to_token'].values())[0],torch.Tensor): - print('>> Detected .pt file variant 1') # example at https://github.com/invoke-ai/InvokeAI/issues/1829 - for token in list(embedding_ckpt['string_to_token'].keys()): - embedding_info['name'] = token if token != '*' else os.path.basename(os.path.splitext(embedding_file)[0]) - embedding_info['embedding'] = embedding_ckpt['string_to_param'].state_dict()[token] - embedding_info['num_vectors_per_token'] = embedding_info['embedding'].shape[0] - embedding_info['token_dim'] = embedding_info['embedding'].size()[0] + if isinstance( + list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor + ): + print( + ">> Detected .pt file variant 1" + ) # example at https://github.com/invoke-ai/InvokeAI/issues/1829 + for token in list(embedding_ckpt["string_to_token"].keys()): + embedding_info["name"] = ( + token + if token != "*" + else os.path.basename(os.path.splitext(embedding_file)[0]) + ) + embedding_info["embedding"] = embedding_ckpt[ + "string_to_param" + ].state_dict()[token] + embedding_info["num_vectors_per_token"] = embedding_info[ + "embedding" + ].shape[0] + embedding_info["token_dim"] = embedding_info["embedding"].size()[0] else: - print('>> Invalid embedding format') + print(">> Invalid embedding format") embedding_info = None return embedding_info