Merge branch 'inpaint-improvement' of https://github.com/Kyle0654/InvokeAI into inpaint-improvement

Commit 0c34554170
@@ -1,21 +1,21 @@
 # This file describes the alternative machine learning models
 # available to the dream script.
 #
 # To add a new model, follow the examples below. Each
 # model requires a model config file, a weights file,
 # and the width and height of the images it
 # was trained on.

 stable-diffusion-1.4:
     config: configs/stable-diffusion/v1-inference.yaml
     weights: models/ldm/stable-diffusion-v1/model.ckpt
     description: Stable Diffusion inference model version 1.4
     width: 512
     height: 512
     default: true
 stable-diffusion-1.5:
     config: configs/stable-diffusion/v1-inference.yaml
     weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
     description: Stable Diffusion inference model version 1.5
     width: 512
     height: 512
+    vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
@@ -58,24 +58,6 @@ torch.multinomial = fix_func(torch.multinomial)
 # this is fallback model in case no default is defined
 FALLBACK_MODEL_NAME='stable-diffusion-1.4'

-def fix_func(orig):
-    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
-        def new_func(*args, **kw):
-            device = kw.get("device", "mps")
-            kw["device"]="cpu"
-            return orig(*args, **kw).to(device)
-        return new_func
-    return orig
-
-torch.rand = fix_func(torch.rand)
-torch.rand_like = fix_func(torch.rand_like)
-torch.randn = fix_func(torch.randn)
-torch.randn_like = fix_func(torch.randn_like)
-torch.randint = fix_func(torch.randint)
-torch.randint_like = fix_func(torch.randint_like)
-torch.bernoulli = fix_func(torch.bernoulli)
-torch.multinomial = fix_func(torch.multinomial)
-
 """Simplified text to image API for stable diffusion/latent diffusion

 Example Usage:
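Note: the hunk above removes an accidentally duplicated copy of the MPS random-number workaround; the surviving copy earlier in the file (visible in the hunk's context line) is unchanged. A minimal sketch of that pattern, taken from the removed lines, for readers wondering what the deleted code did:

    import torch

    def fix_func(orig):
        # On Apple-silicon MPS some RNG ops are unsupported, so the wrapper runs
        # them on CPU and moves the result back to the requested device.
        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            def new_func(*args, **kw):
                device = kw.get("device", "mps")
                kw["device"] = "cpu"
                return orig(*args, **kw).to(device)
            return new_func
        return orig

    torch.randn = fix_func(torch.randn)   # likewise for rand, randint, bernoulli, multinomial, ...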
@@ -411,7 +393,7 @@ class Generate:
                 log_tokens =self.log_tokenization
             )

-            init_image,mask_image,pil_image,pil_mask = self._make_images(
+            init_image, mask_image = self._make_images(
                 init_img,
                 init_mask,
                 width,
@@ -451,8 +433,6 @@ class Generate:
                 height=height,
                 init_img=init_img, # embiggen needs to manipulate from the unmodified init_img
                 init_image=init_image, # notice that init_image is different from init_img
-                pil_image=pil_image,
-                pil_mask=pil_mask,
                 mask_image=mask_image,
                 strength=strength,
                 threshold=threshold,
@@ -644,7 +624,7 @@ class Generate:
         init_image = None
         init_mask = None
         if not img:
-            return None, None, None, None
+            return None, None

         image = self._load_img(img)

@@ -654,23 +634,22 @@ class Generate:
         # if image has a transparent area and no mask was provided, then try to generate mask
         if self._has_transparency(image):
             self._transparency_check_and_warning(image, mask)
-            # this returns a torch tensor
             init_mask = self._create_init_mask(image, width, height, fit=fit)

         if (image.width * image.height) > (self.width * self.height) and self.size_matters:
             print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
             self.size_matters = False

-        init_image = self._create_init_image(image,width,height,fit=fit) # this returns a torch tensor
+        init_image = self._create_init_image(image,width,height,fit=fit)

         if mask:
-            mask_image = self._load_img(mask) # this returns an Image
+            mask_image = self._load_img(mask)
             init_mask = self._create_init_mask(mask_image,width,height,fit=fit)

         elif text_mask:
             init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)

-        return init_image, init_mask, image, mask_image
+        return init_image,init_mask

     def _make_base(self):
         if not self.generators.get('base'):
@@ -887,33 +866,15 @@ class Generate:

     def _create_init_image(self, image, width, height, fit=True):
         image = image.convert('RGB')
-        if fit:
-            image = self._fit_image(image, (width, height))
-        else:
-            image = self._squeeze_image(image)
-        image = np.array(image).astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image)
-        image = 2.0 * image - 1.0
-        return image.to(self.device)
+        image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
+        return image

     def _create_init_mask(self, image, width, height, fit=True):
         # convert into a black/white mask
         image = self._image_to_mask(image)
         image = image.convert('RGB')
-        # now we adjust the size
-        if fit:
-            image = self._fit_image(image, (width, height))
-        else:
-            image = self._squeeze_image(image)
-        image = image.resize((image.width//downsampling, image.height //
-                              downsampling), resample=Image.Resampling.NEAREST)
-        image = np.array(image)
-        image = image.astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image)
-        return image.to(self.device)
+        image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
+        return image

     # The mask is expected to have the region to be inpainted
     # with alpha transparency. It converts it into a black/white
@@ -930,7 +891,6 @@ class Generate:
         mask = ImageOps.invert(mask)
         return mask

-    # TODO: The latter part of this method repeats code from _create_init_mask()
     def _txt2mask(self, image:Image, text_mask:list, width, height, fit=True) -> Image:
         prompt = text_mask[0]
         confidence_level = text_mask[1] if len(text_mask)>1 else 0.5
@@ -940,18 +900,8 @@ class Generate:
         segmented = self.txt2mask.segment(image, prompt)
         mask = segmented.to_mask(float(confidence_level))
         mask = mask.convert('RGB')
-        # now we adjust the size
-        if fit:
-            mask = self._fit_image(mask, (width, height))
-        else:
-            mask = self._squeeze_image(mask)
-        mask = mask.resize((mask.width//downsampling, mask.height //
-                            downsampling), resample=Image.Resampling.NEAREST)
-        mask = np.array(mask)
-        mask = mask.astype(np.float32) / 255.0
-        mask = mask[None].transpose(0, 3, 1, 2)
-        mask = torch.from_numpy(mask)
-        return mask.to(self.device)
+        mask = self._fit_image(mask, (width, height)) if fit else self._squeeze_image(mask)
+        return mask

     def _has_transparency(self, image):
         if image.info.get("transparency", None) is not None:
@@ -4,9 +4,12 @@ ldm.invoke.generator.img2img descends from ldm.invoke.generator

 import torch
 import numpy as np
-from ldm.invoke.devices import choose_autocast
-from ldm.invoke.generator.base import Generator
-from ldm.models.diffusion.ddim import DDIMSampler
+import PIL
+from torch import Tensor
+from PIL import Image
+from ldm.invoke.devices import choose_autocast
+from ldm.invoke.generator.base import Generator
+from ldm.models.diffusion.ddim import DDIMSampler

 class Img2Img(Generator):
     def __init__(self, model, precision):
@@ -25,6 +28,9 @@ class Img2Img(Generator):
             ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
         )

+        if isinstance(init_image, PIL.Image.Image):
+            init_image = self._image_to_tensor(init_image)
+
         scope = choose_autocast(self.precision)
         with scope(self.model.device.type):
             self.init_latent = self.model.get_first_stage_encoding(
@@ -68,3 +74,11 @@ class Img2Img(Generator):
         shape = init_latent.shape
         x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
         return x
+
+    def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image[None].transpose(0, 3, 1, 2)
+        image = torch.from_numpy(image)
+        if normalize:
+            image = 2.0 * image - 1.0
+        return image.to(self.model.device)
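Note: a small self-contained sketch of what the new _image_to_tensor helper produces. The arithmetic (divide by 255, optional 2*x - 1 shift) comes from the method body above; the sample image, size, and variable names are illustrative only:

    import numpy as np
    import torch
    from PIL import Image

    img = Image.new('RGB', (512, 512), color=(128, 128, 128))

    x = np.array(img).astype(np.float32) / 255.0   # HWC array in [0, 1]
    x = x[None].transpose(0, 3, 1, 2)              # NCHW, shape (1, 3, 512, 512)
    x = torch.from_numpy(x)

    latent_input = 2.0 * x - 1.0   # normalize=True path: values in [-1, 1], fed to the first-stage encoder
    mask_input = x                 # normalize=False path: stays in [0, 1], as used for the inpainting mask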
@@ -6,6 +6,7 @@ import torch
 import torchvision.transforms as T
 import numpy as np
 import cv2 as cv
+import PIL
 from PIL import Image, ImageFilter
 from skimage.exposure.histogram_matching import match_histograms
 from einops import rearrange, repeat
@@ -13,16 +14,19 @@ from ldm.invoke.devices import choose_autocast
 from ldm.invoke.generator.img2img import Img2Img
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.ksampler import KSampler
+from ldm.invoke.generator.base import downsampling

 class Inpaint(Img2Img):
     def __init__(self, model, precision):
         self.init_latent = None
+        self.pil_image = None
+        self.pil_mask = None
+        self.mask_blur_radius = 0
         super().__init__(model, precision)

     @torch.no_grad()
     def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
                        conditioning,init_image,mask_image,strength,
-                       pil_image: Image.Image, pil_mask: Image.Image,
                        mask_blur_radius: int = 8,
                        step_callback=None,inpaint_replace=False, **kwargs):
         """
@@ -31,17 +35,22 @@ class Inpaint(Img2Img):
         the time you call it. kwargs are 'init_latent' and 'strength'
         """

-        # Get the alpha channel of the mask
-        pil_init_mask = pil_mask.getchannel('A')
-        pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist
-
-        # Build an image with only visible pixels from source to use as reference for color-matching.
-        # Note that this doesn't use the mask, which would exclude some source image pixels from the
-        # histogram and cause slight color changes.
-        init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3)
-        init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height)
-        init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]
-        init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram
+        if isinstance(init_image, PIL.Image.Image):
+            self.pil_image = init_image
+            init_image = self._image_to_tensor(init_image)
+
+        if isinstance(mask_image, PIL.Image.Image):
+            self.pil_mask = mask_image
+            mask_image = mask_image.resize(
+                (
+                    mask_image.width // downsampling,
+                    mask_image.height // downsampling
+                ),
+                resample=Image.Resampling.NEAREST
+            )
+            mask_image = self._image_to_tensor(mask_image,normalize=False)
+
+        self.mask_blur_radius = mask_blur_radius

         # klms samplers not supported yet, so ignore previous sampler
         if isinstance(sampler,KSampler):
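Note: a worked example of the mask resize added above. It assumes the downsampling constant imported from ldm.invoke.generator.base is the usual latent-space factor of 8 (an assumption; check base.py for the actual value); the image size and variable names are illustrative:

    from PIL import Image

    downsampling = 8   # assumed latent downsampling factor
    mask = Image.new('L', (512, 512))

    latent_mask = mask.resize(
        (mask.width // downsampling, mask.height // downsampling),   # 512 // 8 == 64
        resample=Image.Resampling.NEAREST,
    )
    print(latent_mask.size)   # (64, 64), i.e. one mask pixel per latent cell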
@@ -96,30 +105,50 @@ class Inpaint(Img2Img):
                 mask = mask_image,
                 init_latent = self.init_latent
             )
-            # Get PIL result
-            gen_result = self.sample_to_image(samples).convert('RGB')
-
-            # Get numpy version
-            np_gen_result = np.asarray(gen_result, dtype=np.uint8)
-
-            # Color correct
-            np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1)
-            matched_result = Image.fromarray(np_matched_result, mode='RGB')
-
-
-            # Blur the mask out (into init image) by specified amount
-            if mask_blur_radius > 0:
-                nm = np.asarray(pil_init_mask, dtype=np.uint8)
-                nmd = cv.erode(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
-                pmd = Image.fromarray(nmd, mode='L')
-                blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
-            else:
-                blurred_init_mask = pil_init_mask
-
-            # Paste original on color-corrected generation (using blurred mask)
-            matched_result.paste(pil_image, (0,0), mask = blurred_init_mask)
-
-            return matched_result
+            return self.sample_to_image(samples)

         return make_image

+    def sample_to_image(self, samples)->Image:
+        gen_result = super().sample_to_image(samples).convert('RGB')
+
+        if self.pil_image is None or self.pil_mask is None:
+            return gen_result
+
+        pil_mask = self.pil_mask
+        pil_image = self.pil_image
+        mask_blur_radius = self.mask_blur_radius
+
+        # Get the original alpha channel of the mask if there is one.
+        # Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
+        pil_init_mask = pil_mask.getchannel('A') if pil_mask.mode == 'RGBA' else pil_mask.convert('L')
+        pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist
+
+        # Build an image with only visible pixels from source to use as reference for color-matching.
+        # Note that this doesn't use the mask, which would exclude some source image pixels from the
+        # histogram and cause slight color changes.
+        init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3)
+        init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height)
+        init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]
+        init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram
+
+        # Get numpy version
+        np_gen_result = np.asarray(gen_result, dtype=np.uint8)
+
+        # Color correct
+        np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1)
+        matched_result = Image.fromarray(np_matched_result, mode='RGB')
+
+        # Blur the mask out (into init image) by specified amount
+        if mask_blur_radius > 0:
+            nm = np.asarray(pil_init_mask, dtype=np.uint8)
+            nmd = cv.erode(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
+            pmd = Image.fromarray(nmd, mode='L')
+            blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
+        else:
+            blurred_init_mask = pil_init_mask
+
+        # Paste original on color-corrected generation (using blurred mask)
+        matched_result.paste(pil_image, (0,0), mask = blurred_init_mask)
+        return matched_result
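Note: the final paste in the new sample_to_image is equivalent to a per-pixel blend controlled by the blurred mask. This sketch restates that step with illustrative stand-in arrays (names, sizes, and values are not from the source) to show why eroding and box-blurring the mask feathers the seam between original and generated pixels:

    import numpy as np

    gen = np.zeros((512, 512, 3), dtype=np.float32)    # stand-in: color-corrected generation
    orig = np.ones((512, 512, 3), dtype=np.float32)    # stand-in: original init image
    m = np.zeros((512, 512, 1), dtype=np.float32)      # stand-in: blurred mask, scaled to [0, 1]

    # Image.paste(orig, (0, 0), mask=blurred_init_mask) behaves like this blend:
    # fully masked pixels keep the original, unmasked pixels take the generation,
    # and the blurred in-between values give a soft transition at the seam.
    blended = orig * m + gen * (1.0 - m)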
@@ -13,6 +13,7 @@ import gc
 import hashlib
 import psutil
 import transformers
+import os
 from sys import getrefcount
 from omegaconf import OmegaConf
 from omegaconf.errors import ConfigAttributeError
@@ -193,6 +194,7 @@ class ModelCache(object):
         mconfig = self.config[model_name]
         config = mconfig.config
         weights = mconfig.weights
+        vae = mconfig.get('vae',None)
         width = mconfig.width
         height = mconfig.height

@@ -222,9 +224,17 @@ class ModelCache(object):
         else:
             print(' | Using more accurate float32 precision')

+        # look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
+        if vae and os.path.exists(vae):
+            print(f' | Loading VAE weights from: {vae}')
+            vae_ckpt = torch.load(vae, map_location="cpu")
+            vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
+            model.first_stage_model.load_state_dict(vae_dict, strict=False)
+
         model.to(self.device)
         # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
         model.cond_stage_model.device = self.device

         model.eval()

         for m in model.modules():
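Note: a sketch of how the new vae key travels from the models config into the loading code above. The .get('vae', None) fallback and the strict=False state-dict load mirror the hunks in this commit; the config path, model name, and standalone variable names are illustrative, and the final load is commented out because no model object exists in this snippet:

    import torch
    from omegaconf import OmegaConf

    config = OmegaConf.load('configs/models.yaml')   # illustrative path
    mconfig = config['stable-diffusion-1.5']          # illustrative model name

    vae_path = mconfig.get('vae', None)   # missing key -> None, so older configs keep working
    if vae_path:
        vae_ckpt = torch.load(vae_path, map_location="cpu")
        vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
        # model.first_stage_model.load_state_dict(vae_dict, strict=False)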
@@ -493,6 +493,16 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
         new_config['config'] = input('Configuration file for this model: ')
         done = os.path.exists(new_config['config'])

+    done = False
+    completer.complete_extensions(('.vae.pt','.vae','.ckpt'))
+    while not done:
+        vae = input('VAE autoencoder file for this model [None]: ')
+        if os.path.exists(vae):
+            new_config['vae'] = vae
+            done = True
+        else:
+            done = len(vae)==0
+
     completer.complete_extensions(None)

     for field in ('width','height'):
@@ -537,8 +547,8 @@ def edit_config(model_name:str, gen, opt, completer):

     conf = config[model_name]
     new_config = {}
-    completer.complete_extensions(('.yaml','.yml','.ckpt','.vae'))
-    for field in ('description', 'weights', 'config', 'width','height'):
+    completer.complete_extensions(('.yaml','.yml','.ckpt','.vae.pt'))
+    for field in ('description', 'weights', 'vae', 'config', 'width','height'):
         completer.linebuffer = str(conf[field]) if field in conf else ''
         new_value = input(f'{field}: ')
         new_config[field] = int(new_value) if field in ('width','height') else new_value