textual inversion + init img fix

This commit is contained in:
nicolai256 2022-08-24 05:16:01 +02:00 committed by GitHub
parent 0cdf5e61b0
commit 9588444f0e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -51,6 +51,7 @@ import sys
import os import os
from omegaconf import OmegaConf from omegaconf import OmegaConf
from PIL import Image from PIL import Image
import PIL
from tqdm import tqdm, trange from tqdm import tqdm, trange
from itertools import islice from itertools import islice
from einops import rearrange, repeat from einops import rearrange, repeat
@ -89,6 +90,7 @@ class T2I:
downsampling_factor downsampling_factor
precision precision
strength strength
embedding_path
The vast majority of these arguments default to reasonable values. The vast majority of these arguments default to reasonable values.
""" """
@ -113,6 +115,7 @@ The vast majority of these arguments default to reasonable values.
precision='autocast', precision='autocast',
full_precision=False, full_precision=False,
strength=0.75, # default in scripts/img2img.py strength=0.75, # default in scripts/img2img.py
embedding_path=None,
latent_diffusion_weights=False # just to keep track of this parameter when regenerating prompt latent_diffusion_weights=False # just to keep track of this parameter when regenerating prompt
): ):
self.outdir = outdir self.outdir = outdir
@ -133,6 +136,7 @@ The vast majority of these arguments default to reasonable values.
self.precision = precision self.precision = precision
self.full_precision = full_precision self.full_precision = full_precision
self.strength = strength self.strength = strength
self.embedding_path = embedding_path
self.model = None # empty for now self.model = None # empty for now
self.sampler = None self.sampler = None
self.latent_diffusion_weights=latent_diffusion_weights self.latent_diffusion_weights=latent_diffusion_weights
@ -143,7 +147,7 @@ The vast majority of these arguments default to reasonable values.
def txt2img(self,prompt,outdir=None,batch_size=None,iterations=None, def txt2img(self,prompt,outdir=None,batch_size=None,iterations=None,
steps=None,seed=None,grid=None,individual=None,width=None,height=None, steps=None,seed=None,grid=None,individual=None,width=None,height=None,
cfg_scale=None,ddim_eta=None,strength=None,init_img=None,skip_normalize=False): cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None,init_img=None,skip_normalize=False):
""" """
Generate an image from the prompt, writing iteration images into the outdir Generate an image from the prompt, writing iteration images into the outdir
The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...] The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
@ -158,6 +162,7 @@ The vast majority of these arguments default to reasonable values.
batch_size = batch_size or self.batch_size batch_size = batch_size or self.batch_size
iterations = iterations or self.iterations iterations = iterations or self.iterations
strength = strength or self.strength # not actually used here, but preserved for code refactoring strength = strength or self.strength # not actually used here, but preserved for code refactoring
embedding_path = embedding_path or self.embedding_path
model = self.load_model() # will instantiate the model or return it from cache model = self.load_model() # will instantiate the model or return it from cache
@ -268,7 +273,7 @@ The vast majority of these arguments default to reasonable values.
# There is lots of shared code between this and txt2img and should be refactored. # There is lots of shared code between this and txt2img and should be refactored.
def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None, def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None,
steps=None,seed=None,grid=None,individual=None,width=None,height=None, steps=None,seed=None,grid=None,individual=None,width=None,height=None,
cfg_scale=None,ddim_eta=None,strength=None,skip_normalize=False): cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None,skip_normalize=False):
""" """
Generate an image from the prompt and the initial image, writing iteration images into the outdir Generate an image from the prompt and the initial image, writing iteration images into the outdir
The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...] The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
@ -281,6 +286,7 @@ The vast majority of these arguments default to reasonable values.
batch_size = batch_size or self.batch_size batch_size = batch_size or self.batch_size
iterations = iterations or self.iterations iterations = iterations or self.iterations
strength = strength or self.strength strength = strength or self.strength
embedding_path = embedding_path or self.embedding_path
if init_img is None: if init_img is None:
print("no init_img provided!") print("no init_img provided!")
@ -431,6 +437,7 @@ The vast majority of these arguments default to reasonable values.
config = OmegaConf.load(self.config) config = OmegaConf.load(self.config)
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = self._load_model_from_config(config,self.weights) model = self._load_model_from_config(config,self.weights)
model.embedding_manager.load(self.embedding_path)
self.model = model.to(self.device) self.model = model.to(self.device)
except AttributeError: except AttributeError:
raise SystemExit raise SystemExit
@ -472,7 +479,7 @@ The vast majority of these arguments default to reasonable values.
w, h = image.size w, h = image.size
print(f"loaded input image of size ({w}, {h}) from {path}") print(f"loaded input image of size ({w}, {h}) from {path}")
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=Image.Resampling.LANCZOS) image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2) image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image) image = torch.from_numpy(image)