From e7573ac90f1bce38f7130bb5748ea183e2fa3eb4 Mon Sep 17 00:00:00 2001
From: wfng92 <43742196+wfng92@users.noreply.github.com>
Date: Sat, 22 Oct 2022 09:03:31 +0800
Subject: [PATCH 1/3] Removed duplicate fix_func for MPS

---
 ldm/generate.py | 18 ------------------
 1 file changed, 18 deletions(-)

diff --git a/ldm/generate.py b/ldm/generate.py
index b21787eb47..d60226cdcd 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -55,24 +55,6 @@ torch.randint_like = fix_func(torch.randint_like)
 torch.bernoulli = fix_func(torch.bernoulli)
 torch.multinomial = fix_func(torch.multinomial)
 
-def fix_func(orig):
-    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
-        def new_func(*args, **kw):
-            device = kw.get("device", "mps")
-            kw["device"]="cpu"
-            return orig(*args, **kw).to(device)
-        return new_func
-    return orig
-
-torch.rand = fix_func(torch.rand)
-torch.rand_like = fix_func(torch.rand_like)
-torch.randn = fix_func(torch.randn)
-torch.randn_like = fix_func(torch.randn_like)
-torch.randint = fix_func(torch.randint)
-torch.randint_like = fix_func(torch.randint_like)
-torch.bernoulli = fix_func(torch.bernoulli)
-torch.multinomial = fix_func(torch.multinomial)
-
 """Simplified text to image API for stable diffusion/latent diffusion
 
 Example Usage:

From 51fdbe22d2b74c0080a4f9eec8ce00d50023cc97 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Sat, 22 Oct 2022 13:29:45 -0400
Subject: [PATCH 2/3] add support for loading VAE autoencoders

To add a VAE autoencoder to an existing model:

1. Download the appropriate autoencoder and put it into
   models/ldm/stable-diffusion

   Note that you MUST use a VAE that was written for the original
   CompVis Stable Diffusion codebase. For v1.4, that would be the file
   named vae-ft-mse-840000-ema-pruned.ckpt that you can download from
   https://huggingface.co/stabilityai/sd-vae-ft-mse-original

2. Edit configs/models.yaml to contain the following stanza, modifying
   `weights` and `vae` as required to match the weights and VAE model
   file names. There is no requirement to rename the VAE file.

~~~
stable-diffusion-1.4:
  weights: models/ldm/stable-diffusion-v1/sd-v1-4.ckpt
  description: Stable Diffusion v1.4
  config: configs/stable-diffusion/v1-inference.yaml
  vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
  width: 512
  height: 512
~~~

3. Alternatively, from within the `invoke.py` CLI you may use the
   command `!editmodel stable-diffusion-1.4` to bring up a simple
   editor that will allow you to add the path to the VAE.

4. If you are just installing InvokeAI for the first time, you can also
   use `!import_model models/ldm/stable-diffusion/sd-v1.4.ckpt` instead
   to create the configuration from scratch.

5. That's it!
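In code form, the loading step this patch introduces amounts to the sketch below. It is illustrative rather than the committed code: the real logic runs inline in `ldm/invoke/model_cache.py` while the model is being loaded, the helper name is made up for the example, and `model` stands in for the already-loaded Stable Diffusion model object.

~~~
import os
import torch

def load_vae_weights(model, vae_path):
    """Overlay VAE autoencoder weights onto an already-loaded SD model (sketch)."""
    if not (vae_path and os.path.exists(vae_path)):
        return  # no VAE configured; keep the autoencoder bundled with the main checkpoint
    print(f'   | Loading VAE weights from: {vae_path}')
    vae_ckpt = torch.load(vae_path, map_location='cpu')
    # Keep only the autoencoder tensors; the "loss.*" keys are training-time
    # extras that the inference model does not need.
    vae_dict = {k: v for k, v in vae_ckpt['state_dict'].items() if k[0:4] != 'loss'}
    # strict=False because the VAE checkpoint does not cover every parameter
    # of the first-stage model.
    model.first_stage_model.load_state_dict(vae_dict, strict=False)
~~~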
---
 configs/models.yaml       |  1 +
 ldm/invoke/model_cache.py | 10 ++++++++++
 scripts/invoke.py         | 14 ++++++++++++--
 3 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/configs/models.yaml b/configs/models.yaml
index f3fde45d8f..67183bdd1f 100644
--- a/configs/models.yaml
+++ b/configs/models.yaml
@@ -9,6 +9,7 @@
 stable-diffusion-1.4:
   config: configs/stable-diffusion/v1-inference.yaml
   weights: models/ldm/stable-diffusion-v1/model.ckpt
+  vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
   description: Stable Diffusion inference model version 1.4
   width: 512
   height: 512
diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py
index 5e9e53cfb7..f580dfba25 100644
--- a/ldm/invoke/model_cache.py
+++ b/ldm/invoke/model_cache.py
@@ -13,6 +13,7 @@ import gc
 import hashlib
 import psutil
 import transformers
+import os
 from sys import getrefcount
 from omegaconf import OmegaConf
 from omegaconf.errors import ConfigAttributeError
@@ -193,6 +194,7 @@ class ModelCache(object):
         mconfig = self.config[model_name]
         config = mconfig.config
         weights = mconfig.weights
+        vae = mconfig.get('vae',None)
         width = mconfig.width
         height = mconfig.height
 
@@ -222,9 +224,17 @@ class ModelCache(object):
         else:
             print(' | Using more accurate float32 precision')
 
+        # look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
+        if vae and os.path.exists(vae):
+            print(f' | Loading VAE weights from: {vae}')
+            vae_ckpt = torch.load(vae, map_location="cpu")
+            vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
+            model.first_stage_model.load_state_dict(vae_dict, strict=False)
+
         model.to(self.device)
         # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
         model.cond_stage_model.device = self.device
+
         model.eval()
 
         for m in model.modules():
diff --git a/scripts/invoke.py b/scripts/invoke.py
index b7af4d6469..f4d4f3c4c0 100644
--- a/scripts/invoke.py
+++ b/scripts/invoke.py
@@ -493,6 +493,16 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
         new_config['config'] = input('Configuration file for this model: ')
         done = os.path.exists(new_config['config'])
 
+    done = False
+    completer.complete_extensions(('.vae.pt','.vae','.ckpt'))
+    while not done:
+        vae = input('VAE autoencoder file for this model [None]: ')
+        if os.path.exists(vae):
+            new_config['vae'] = vae
+            done = True
+        else:
+            done = len(vae)==0
+
     completer.complete_extensions(None)
 
     for field in ('width','height'):
@@ -537,8 +547,8 @@ def edit_config(model_name:str, gen, opt, completer):
     conf = config[model_name]
     new_config = {}
 
-    completer.complete_extensions(('.yaml','.yml','.ckpt','.vae'))
-    for field in ('description', 'weights', 'config', 'width','height'):
+    completer.complete_extensions(('.yaml','.yml','.ckpt','.vae.pt'))
+    for field in ('description', 'weights', 'vae', 'config', 'width','height'):
         completer.linebuffer = str(conf[field]) if field in conf else ''
         new_value = input(f'{field}: ')
         new_config[field] = int(new_value) if field in ('width','height') else new_value

From 93cba3fba5b31418f4b6251de5f6ebef07e02ad4 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Sat, 22 Oct 2022 23:09:38 -0400
Subject: [PATCH 3/3] Kyle0654 inpaint improvement - with refactoring from PR
 #1221 (#1)

* Removed duplicate fix_func for MPS

* add support for loading VAE autoencoders

To add a VAE autoencoder to an existing model:
1. Download the appropriate autoencoder and put it into
   models/ldm/stable-diffusion

   Note that you MUST use a VAE that was written for the original
   CompVis Stable Diffusion codebase. For v1.4, that would be the file
   named vae-ft-mse-840000-ema-pruned.ckpt that you can download from
   https://huggingface.co/stabilityai/sd-vae-ft-mse-original

2. Edit configs/models.yaml to contain the following stanza, modifying
   `weights` and `vae` as required to match the weights and VAE model
   file names. There is no requirement to rename the VAE file.

~~~
stable-diffusion-1.4:
  weights: models/ldm/stable-diffusion-v1/sd-v1-4.ckpt
  description: Stable Diffusion v1.4
  config: configs/stable-diffusion/v1-inference.yaml
  vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
  width: 512
  height: 512
~~~

3. Alternatively, from within the `invoke.py` CLI you may use the
   command `!editmodel stable-diffusion-1.4` to bring up a simple
   editor that will allow you to add the path to the VAE.

4. If you are just installing InvokeAI for the first time, you can also
   use `!import_model models/ldm/stable-diffusion/sd-v1.4.ckpt` instead
   to create the configuration from scratch.

5. That's it!

* ported code refactor changes from PR #1221

  - pass a PIL.Image to img2img and inpaint rather than tensor
  - To support clipseg, inpaint needs to accept an "L" or "1" format
    mask. Made the appropriate change.

* minor fixes to inpaint code

  1. If tensors are passed to inpaint as init_image and/or init_mask,
     then the post-generation image fixup code will be skipped.
  2. Post-generation image fixup will work with either a black and
     white "L" or "RGB" mask, or an "RGBA" mask.

Co-authored-by: wfng92 <43742196+wfng92@users.noreply.github.com>
---
 configs/models.yaml             |  26 ++++----
 ldm/generate.py                 |  72 ++++-------------------
 ldm/invoke/generator/img2img.py |  20 ++++++-
 ldm/invoke/generator/inpaint.py | 101 ++++++++++++++++++++------------
 ldm/invoke/model_cache.py       |  10 ++++
 scripts/invoke.py               |  14 ++++-
 6 files changed, 129 insertions(+), 114 deletions(-)

diff --git a/configs/models.yaml b/configs/models.yaml
index f3fde45d8f..9d6d80084f 100644
--- a/configs/models.yaml
+++ b/configs/models.yaml
@@ -1,20 +1,22 @@
 # This file describes the alternative machine learning models
-# available to the dream script.
+# available to the dream script.
 #
 # To add a new model, follow the examples below. Each
 # model requires a model config file, a weights file,
 # and the width and height of the images it
 # was trained on.
-
 stable-diffusion-1.4:
-  config: configs/stable-diffusion/v1-inference.yaml
-  weights: models/ldm/stable-diffusion-v1/model.ckpt
-  description: Stable Diffusion inference model version 1.4
-  width: 512
-  height: 512
+    config: configs/stable-diffusion/v1-inference.yaml
+    weights: models/ldm/stable-diffusion-v1/model.ckpt
+    vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
+    description: Stable Diffusion inference model version 1.4
+    width: 512
+    height: 512
 stable-diffusion-1.5:
-  config: configs/stable-diffusion/v1-inference.yaml
-  weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
-  description: Stable Diffusion inference model version 1.5
-  width: 512
-  height: 512
+    config: configs/stable-diffusion/v1-inference.yaml
+    weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
+    description: Stable Diffusion inference model version 1.5
+    width: 512
+    height: 512
+    vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
+    default: true
diff --git a/ldm/generate.py b/ldm/generate.py
index 9b3ef14b27..ce2331806c 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -58,24 +58,6 @@ torch.multinomial = fix_func(torch.multinomial)
 # this is fallback model in case no default is defined
 FALLBACK_MODEL_NAME='stable-diffusion-1.4'
 
-def fix_func(orig):
-    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
-        def new_func(*args, **kw):
-            device = kw.get("device", "mps")
-            kw["device"]="cpu"
-            return orig(*args, **kw).to(device)
-        return new_func
-    return orig
-
-torch.rand = fix_func(torch.rand)
-torch.rand_like = fix_func(torch.rand_like)
-torch.randn = fix_func(torch.randn)
-torch.randn_like = fix_func(torch.randn_like)
-torch.randint = fix_func(torch.randint)
-torch.randint_like = fix_func(torch.randint_like)
-torch.bernoulli = fix_func(torch.bernoulli)
-torch.multinomial = fix_func(torch.multinomial)
-
 """Simplified text to image API for stable diffusion/latent diffusion
 
 Example Usage:
@@ -411,7 +393,7 @@ class Generate:
                 log_tokens =self.log_tokenization
             )
 
-            init_image,mask_image,pil_image,pil_mask = self._make_images(
+            init_image, mask_image = self._make_images(
                 init_img,
                 init_mask,
                 width,
@@ -451,8 +433,6 @@ class Generate:
                 height=height,
                 init_img=init_img, # embiggen needs to manipulate from the unmodified init_img
                 init_image=init_image, # notice that init_image is different from init_img
-                pil_image=pil_image,
-                pil_mask=pil_mask,
                 mask_image=mask_image,
                 strength=strength,
                 threshold=threshold,
@@ -644,7 +624,7 @@ class Generate:
         init_image = None
         init_mask = None
         if not img:
-            return None, None, None, None
+            return None, None
 
         image = self._load_img(img)
 
         # if image has a transparent area and no mask was provided, then try to generate mask
         if self._has_transparency(image):
             self._transparency_check_and_warning(image, mask)
-            # this returns a torch tensor
             init_mask = self._create_init_mask(image, width, height, fit=fit)
 
         if (image.width * image.height) > (self.width * self.height) and self.size_matters:
             print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
             self.size_matters = False
 
-        init_image = self._create_init_image(image,width,height,fit=fit) # this returns a torch tensor
+        init_image = self._create_init_image(image,width,height,fit=fit)
 
         if mask:
-            mask_image = self._load_img(mask) # this returns an Image
+            mask_image = self._load_img(mask)
             init_mask = self._create_init_mask(mask_image,width,height,fit=fit)
 
         elif text_mask:
             init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
 
-        return init_image, init_mask, image, mask_image
+        return init_image,init_mask
 
     def _make_base(self):
         if not self.generators.get('base'):
@@ -887,33 +866,15 @@ class Generate:
 
     def _create_init_image(self, image, width, height, fit=True):
         image = image.convert('RGB')
-        if fit:
-            image = self._fit_image(image, (width, height))
-        else:
-            image = self._squeeze_image(image)
-        image = np.array(image).astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image)
-        image = 2.0 * image - 1.0
-        return image.to(self.device)
+        image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
+        return image
 
     def _create_init_mask(self, image, width, height, fit=True):
         # convert into a black/white mask
         image = self._image_to_mask(image)
         image = image.convert('RGB')
-
-        # now we adjust the size
-        if fit:
-            image = self._fit_image(image, (width, height))
-        else:
-            image = self._squeeze_image(image)
-        image = image.resize((image.width//downsampling, image.height //
-                              downsampling), resample=Image.Resampling.NEAREST)
-        image = np.array(image)
-        image = image.astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image)
-        return image.to(self.device)
+        image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
+        return image
 
     # The mask is expected to have the region to be inpainted
     # with alpha transparency. It converts it into a black/white
@@ -930,7 +891,6 @@ class Generate:
         mask = ImageOps.invert(mask)
         return mask
 
-    # TODO: The latter part of this method repeats code from _create_init_mask()
     def _txt2mask(self, image:Image, text_mask:list, width, height, fit=True) -> Image:
         prompt = text_mask[0]
         confidence_level = text_mask[1] if len(text_mask)>1 else 0.5
@@ -940,18 +900,8 @@ class Generate:
         segmented = self.txt2mask.segment(image, prompt)
         mask = segmented.to_mask(float(confidence_level))
         mask = mask.convert('RGB')
-        # now we adjust the size
-        if fit:
-            mask = self._fit_image(mask, (width, height))
-        else:
-            mask = self._squeeze_image(mask)
-        mask = mask.resize((mask.width//downsampling, mask.height //
-                            downsampling), resample=Image.Resampling.NEAREST)
-        mask = np.array(mask)
-        mask = mask.astype(np.float32) / 255.0
-        mask = mask[None].transpose(0, 3, 1, 2)
-        mask = torch.from_numpy(mask)
-        return mask.to(self.device)
+        mask = self._fit_image(mask, (width, height)) if fit else self._squeeze_image(mask)
+        return mask
 
     def _has_transparency(self, image):
         if image.info.get("transparency", None) is not None:
diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py
index 7fde1a94cf..613f1aca31 100644
--- a/ldm/invoke/generator/img2img.py
+++ b/ldm/invoke/generator/img2img.py
@@ -4,9 +4,12 @@ ldm.invoke.generator.img2img descends from ldm.invoke.generator
 import torch
 import numpy as np
-from ldm.invoke.devices import choose_autocast
-from ldm.invoke.generator.base import Generator
-from ldm.models.diffusion.ddim import DDIMSampler
+import PIL
+from torch import Tensor
+from PIL import Image
+from ldm.invoke.devices import choose_autocast
+from ldm.invoke.generator.base import Generator
+from ldm.models.diffusion.ddim import DDIMSampler
 
 
 class Img2Img(Generator):
     def __init__(self, model, precision):
@@ -25,6 +28,9 @@ class Img2Img(Generator):
             ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
         )
 
+        if isinstance(init_image, PIL.Image.Image):
+            init_image = self._image_to_tensor(init_image)
+
         scope = choose_autocast(self.precision)
         with scope(self.model.device.type):
             self.init_latent = self.model.get_first_stage_encoding(
@@ -68,3 +74,11 @@ class Img2Img(Generator):
         shape = init_latent.shape
         x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
         return x
+
+    def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image[None].transpose(0, 3, 1, 2)
+        image = torch.from_numpy(image)
+        if normalize:
+            image = 2.0 * image - 1.0
+        return image.to(self.model.device)
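The `_image_to_tensor` helper added above takes over the PIL-to-tensor conversion that the earlier hunks removed from `ldm/generate.py`. In isolation it behaves roughly like the sketch below; the standalone function name, the `device` argument, and the `convert('RGB')` call are illustrative additions, not part of the patch.

~~~
import numpy as np
import torch
from PIL import Image

def pil_to_model_tensor(image: Image.Image, device='cpu', normalize: bool = True) -> torch.Tensor:
    """Convert an HWC uint8 PIL image into a 1xCxHxW float tensor.

    Init images are scaled to [-1, 1] (normalize=True); masks stay in
    [0, 1] by passing normalize=False, mirroring _image_to_tensor above.
    """
    arr = np.array(image.convert('RGB')).astype(np.float32) / 255.0  # HWC, [0, 1]
    arr = arr[None].transpose(0, 3, 1, 2)                            # -> NCHW
    tensor = torch.from_numpy(arr)
    if normalize:
        tensor = 2.0 * tensor - 1.0
    return tensor.to(device)
~~~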
diff --git a/ldm/invoke/generator/inpaint.py b/ldm/invoke/generator/inpaint.py
index ee67b90c46..f524f3d236 100644
--- a/ldm/invoke/generator/inpaint.py
+++ b/ldm/invoke/generator/inpaint.py
@@ -6,6 +6,7 @@ import torch
 import torchvision.transforms as T
 import numpy as np
 import cv2 as cv
+import PIL
 from PIL import Image, ImageFilter
 from skimage.exposure.histogram_matching import match_histograms
 from einops import rearrange, repeat
@@ -13,16 +14,19 @@ from ldm.invoke.devices import choose_autocast
 from ldm.invoke.generator.img2img import Img2Img
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.ksampler import KSampler
+from ldm.invoke.generator.base import downsampling
 
 class Inpaint(Img2Img):
     def __init__(self, model, precision):
         self.init_latent = None
+        self.pil_image = None
+        self.pil_mask = None
+        self.mask_blur_radius = 0
         super().__init__(model, precision)
 
     @torch.no_grad()
     def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
                        conditioning,init_image,mask_image,strength,
-                       pil_image: Image.Image, pil_mask: Image.Image,
                        mask_blur_radius: int = 8,
                        step_callback=None,inpaint_replace=False, **kwargs):
         """
@@ -31,17 +35,22 @@ class Inpaint(Img2Img):
        the time you call it.  kwargs are 'init_latent' and 'strength'
        """
 
-        # Get the alpha channel of the mask
-        pil_init_mask = pil_mask.getchannel('A')
-        pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist
+        if isinstance(init_image, PIL.Image.Image):
+            self.pil_image = init_image
+            init_image = self._image_to_tensor(init_image)
 
-        # Build an image with only visible pixels from source to use as reference for color-matching.
-        # Note that this doesn't use the mask, which would exclude some source image pixels from the
-        # histogram and cause slight color changes.
-        init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3)
-        init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height)
-        init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]
-        init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram
+        if isinstance(mask_image, PIL.Image.Image):
+            self.pil_mask = mask_image
+            mask_image = mask_image.resize(
+                (
+                    mask_image.width // downsampling,
+                    mask_image.height // downsampling
+                ),
+                resample=Image.Resampling.NEAREST
+            )
+            mask_image = self._image_to_tensor(mask_image,normalize=False)
+
+        self.mask_blur_radius = mask_blur_radius
 
         # klms samplers not supported yet, so ignore previous sampler
         if isinstance(sampler,KSampler):
@@ -96,30 +105,50 @@
                 mask = mask_image,
                 init_latent = self.init_latent
             )
-
-            # Get PIL result
-            gen_result = self.sample_to_image(samples).convert('RGB')
-
-            # Get numpy version
-            np_gen_result = np.asarray(gen_result, dtype=np.uint8)
-
-            # Color correct
-            np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1)
-            matched_result = Image.fromarray(np_matched_result, mode='RGB')
-
-
-            # Blur the mask out (into init image) by specified amount
-            if mask_blur_radius > 0:
-                nm = np.asarray(pil_init_mask, dtype=np.uint8)
-                nmd = cv.dilate(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
-                pmd = Image.fromarray(nmd, mode='L')
-                blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
-            else:
-                blurred_init_mask = pil_init_mask
-
-            # Paste original on color-corrected generation (using blurred mask)
-            matched_result.paste(pil_image, (0,0), mask = blurred_init_mask)
-
-            return matched_result
+            return self.sample_to_image(samples)
 
         return make_image
+
+    def sample_to_image(self, samples)->Image:
+        gen_result = super().sample_to_image(samples).convert('RGB')
+
+        if self.pil_image is None or self.pil_mask is None:
+            return gen_result
+
+        pil_mask = self.pil_mask
+        pil_image = self.pil_image
+        mask_blur_radius = self.mask_blur_radius
+
+        # Get the original alpha channel of the mask if there is one.
+        # Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
+        pil_init_mask = pil_mask.getchannel('A') if pil_mask.mode == 'RGBA' else pil_mask.convert('L')
+        pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist
+
+        # Build an image with only visible pixels from source to use as reference for color-matching.
+        # Note that this doesn't use the mask, which would exclude some source image pixels from the
+        # histogram and cause slight color changes.
+        init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3)
+        init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height)
+        init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]
+        init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram
+
+        # Get numpy version
+        np_gen_result = np.asarray(gen_result, dtype=np.uint8)
+
+        # Color correct
+        np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1)
+        matched_result = Image.fromarray(np_matched_result, mode='RGB')
+
+        # Blur the mask out (into init image) by specified amount
+        if mask_blur_radius > 0:
+            nm = np.asarray(pil_init_mask, dtype=np.uint8)
+            nmd = cv.dilate(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
+            pmd = Image.fromarray(nmd, mode='L')
+            blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
+        else:
+            blurred_init_mask = pil_init_mask
+
+        # Paste original on color-corrected generation (using blurred mask)
+        matched_result.paste(pil_image, (0,0), mask = blurred_init_mask)
+        return matched_result
+
diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py
index 5e9e53cfb7..f580dfba25 100644
--- a/ldm/invoke/model_cache.py
+++ b/ldm/invoke/model_cache.py
@@ -13,6 +13,7 @@ import gc
 import hashlib
 import psutil
 import transformers
+import os
 from sys import getrefcount
 from omegaconf import OmegaConf
 from omegaconf.errors import ConfigAttributeError
@@ -193,6 +194,7 @@ class ModelCache(object):
         mconfig = self.config[model_name]
         config = mconfig.config
         weights = mconfig.weights
+        vae = mconfig.get('vae',None)
         width = mconfig.width
         height = mconfig.height
 
@@ -222,9 +224,17 @@ class ModelCache(object):
         else:
             print(' | Using more accurate float32 precision')
 
+        # look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
+        if vae and os.path.exists(vae):
+            print(f' | Loading VAE weights from: {vae}')
+            vae_ckpt = torch.load(vae, map_location="cpu")
+            vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
+            model.first_stage_model.load_state_dict(vae_dict, strict=False)
+
         model.to(self.device)
         # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
         model.cond_stage_model.device = self.device
+
         model.eval()
 
         for m in model.modules():
diff --git a/scripts/invoke.py b/scripts/invoke.py
index b7af4d6469..f4d4f3c4c0 100644
--- a/scripts/invoke.py
+++ b/scripts/invoke.py
@@ -493,6 +493,16 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
         new_config['config'] = input('Configuration file for this model: ')
         done = os.path.exists(new_config['config'])
 
+    done = False
+    completer.complete_extensions(('.vae.pt','.vae','.ckpt'))
+    while not done:
+        vae = input('VAE autoencoder file for this model [None]: ')
+        if os.path.exists(vae):
+            new_config['vae'] = vae
+            done = True
+        else:
+            done = len(vae)==0
+
     completer.complete_extensions(None)
 
     for field in ('width','height'):
@@ -537,8 +547,8 @@ def edit_config(model_name:str, gen, opt, completer):
     conf = config[model_name]
     new_config = {}
 
-    completer.complete_extensions(('.yaml','.yml','.ckpt','.vae'))
-    for field in ('description', 'weights', 'config', 'width','height'):
+    completer.complete_extensions(('.yaml','.yml','.ckpt','.vae.pt'))
+    for field in ('description', 'weights', 'vae', 'config', 'width','height'):
         completer.linebuffer = str(conf[field]) if field in conf else ''
         new_value = input(f'{field}: ')
         new_config[field] = int(new_value) if field in ('width','height') else new_value
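Taken together, the three patches mean callers hand the generators file paths or PIL images rather than pre-built tensors, and the inpaint color-correction and mask-blur fixup runs automatically whenever PIL data is available. Below is a hypothetical driver assuming the `prompt2image` API documented at the top of `ldm/generate.py`; the file names and prompt are placeholders.

~~~
from ldm.generate import Generate

# Assumes configs/models.yaml points at valid weights (and, after this PR,
# optionally a vae: entry for the model being loaded).
gr = Generate()

# init_img / init_mask are ordinary image files; Generate loads them as PIL
# images, so Inpaint.sample_to_image can color-correct and blend the result.
results = gr.prompt2image(
    prompt    = 'a red barn in a snowy field',
    init_img  = 'photo.png',       # source image; RGBA transparency also works as an implicit mask
    init_mask = 'photo-mask.png',  # "L", "1" or RGBA mask marking the region to inpaint
    strength  = 0.75,
)
for image, seed in results:
    image.save(f'inpainted-{seed}.png')
~~~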