Merge pull request #58 from nicolai256/main

Init image (init_img) didn't work with textual inversion; now it does :)
This commit is contained in:
Lincoln Stein 2022-08-24 11:24:30 -04:00 committed by GitHub
commit 73901a2777
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -58,6 +58,7 @@ import sys
import os
from omegaconf import OmegaConf
from PIL import Image
import PIL
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange, repeat
@ -96,6 +97,7 @@ class T2I:
downsampling_factor
precision
strength
embedding_path
The vast majority of these arguments default to reasonable values.
"""
@ -120,6 +122,7 @@ The vast majority of these arguments default to reasonable values.
precision='autocast',
full_precision=False,
strength=0.75, # default in scripts/img2img.py
embedding_path=None,
latent_diffusion_weights=False # just to keep track of this parameter when regenerating prompt
):
self.outdir = outdir
@ -140,6 +143,7 @@ The vast majority of these arguments default to reasonable values.
self.precision = precision
self.full_precision = full_precision
self.strength = strength
self.embedding_path = embedding_path
self.model = None # empty for now
self.sampler = None
self.latent_diffusion_weights=latent_diffusion_weights
@ -150,7 +154,7 @@ The vast majority of these arguments default to reasonable values.
def txt2img(self,prompt,outdir=None,batch_size=None,iterations=None,
steps=None,seed=None,grid=None,individual=None,width=None,height=None,
cfg_scale=None,ddim_eta=None,strength=None,init_img=None,skip_normalize=False):
cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None,init_img=None,skip_normalize=False):
"""
Generate an image from the prompt, writing iteration images into the outdir
The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
@ -165,6 +169,7 @@ The vast majority of these arguments default to reasonable values.
batch_size = batch_size or self.batch_size
iterations = iterations or self.iterations
strength = strength or self.strength # not actually used here, but preserved for code refactoring
embedding_path = embedding_path or self.embedding_path
model = self.load_model() # will instantiate the model or return it from cache
@ -278,7 +283,7 @@ The vast majority of these arguments default to reasonable values.
# There is a lot of shared code between this and txt2img; it should be refactored.
def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None,
steps=None,seed=None,grid=None,individual=None,width=None,height=None,
cfg_scale=None,ddim_eta=None,strength=None,skip_normalize=False):
cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None,skip_normalize=False):
"""
Generate an image from the prompt and the initial image, writing iteration images into the outdir
The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
@ -291,6 +296,7 @@ The vast majority of these arguments default to reasonable values.
batch_size = batch_size or self.batch_size
iterations = iterations or self.iterations
strength = strength or self.strength
embedding_path = embedding_path or self.embedding_path
assert strength<1.0 and strength>=0.0, "strength (-f) must be >=0.0 and <1.0"
assert cfg_scale>1.0, "CFG_Scale (-C) must be >1.0"
@ -444,6 +450,7 @@ The vast majority of these arguments default to reasonable values.
config = OmegaConf.load(self.config)
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = self._load_model_from_config(config,self.weights)
model.embedding_manager.load(self.embedding_path)
self.model = model.to(self.device)
except AttributeError:
raise SystemExit
@ -495,7 +502,7 @@ The vast majority of these arguments default to reasonable values.
w, h = image.size
print(f"loaded input image of size ({w}, {h}) from {path}")
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)