prompt weighting not working

Lincoln Stein 2022-08-23 01:23:14 -04:00
commit 11c0df07b7
2 changed files with 90 additions and 14 deletions

View File

@@ -143,7 +143,7 @@ The vast majority of these arguments default to reasonable values.
 def txt2img(self,prompt,outdir=None,batch_size=None,iterations=None,
             steps=None,seed=None,grid=None,individual=None,width=None,height=None,
-            cfg_scale=None,ddim_eta=None,strength=None,init_img=None):
+            cfg_scale=None,ddim_eta=None,strength=None,init_img=None,skip_normalize=False):
     """
     Generate an image from the prompt, writing iteration images into the outdir
     The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
@@ -189,6 +189,7 @@ The vast majority of these arguments default to reasonable values.
 image_count = 0
 tic = time.time()
+# Gawd. Too many levels of indent here. Need to refactor into smaller routines!
 try:
     with torch.no_grad():
         with precision_scope("cuda"):
@@ -202,7 +203,23 @@ The vast majority of these arguments default to reasonable values.
 uc = model.get_learned_conditioning(batch_size * [""])
 if isinstance(prompts, tuple):
     prompts = list(prompts)
-c = model.get_learned_conditioning(prompts)
+# weighted sub-prompts
+subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
+if len(subprompts) > 1:
+    # i dont know if this is correct.. but it works
+    c = torch.zeros_like(uc)
+    # get total weight for normalizing
+    totalWeight = sum(weights)
+    # normalize each "sub prompt" and add it
+    for i in range(0,len(subprompts)):
+        weight = weights[i]
+        if not skip_normalize:
+            weight = weight / totalWeight
+        c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)
+else: # just standard 1 prompt
+    c = model.get_learned_conditioning(prompts)
 shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
 samples_ddim, _ = sampler.sample(S=steps,
                                  conditioning=c,
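
For reference, the weighted sub-prompt block above (mirrored in img2img further down) builds a single conditioning tensor by summing one embedding per sub-prompt, each scaled by its (optionally normalized) weight. A minimal standalone sketch of the same arithmetic, with stand-in tensors in place of model.get_learned_conditioning() — names and shapes below are illustrative assumptions, not part of the commit:

    import torch

    def blend_conditionings(embeddings, weights, skip_normalize=False):
        # Scale each sub-prompt embedding by its (optionally normalized) weight
        # and accumulate, mirroring the torch.add(..., alpha=weight) loop above.
        total_weight = sum(weights)
        c = torch.zeros_like(embeddings[0])
        for emb, weight in zip(embeddings, weights):
            if not skip_normalize:
                weight = weight / total_weight
            c = torch.add(c, emb, alpha=weight)
        return c

    # Two stand-in embeddings blended 3:1; the shape is arbitrary for the example.
    a, b = torch.randn(1, 77, 768), torch.randn(1, 77, 768)
    c = blend_conditionings([a, b], [3.0, 1.0])   # equals 0.75*a + 0.25*b

Because the weights are divided by their sum, "a:1 b:1" and "a:2 b:2" blend identically; the new skip_normalize flag bypasses that division so absolute weights matter.
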
@@ -220,24 +237,22 @@ The vast majority of these arguments default to reasonable values.
 if not grid:
     for x_sample in x_samples_ddim:
         x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-        filename = self._unique_filename(outdir,previousname=filename,
-                                         seed=seed,isbatch=(batch_size>1))
-        assert not os.path.exists(filename)
+        filename = os.path.join(outdir, f"{base_count:05}.png")
         Image.fromarray(x_sample.astype(np.uint8)).save(filename)
         images.append([filename,seed])
+        base_count += 1
 else:
     all_samples.append(x_samples_ddim)
     seeds.append(seed)
 image_count += 1
 seed = self._new_seed()
 if grid:
     images = self._make_grid(samples=all_samples,
                              seeds=seeds,
                              batch_size=batch_size,
                              iterations=iterations,
                              outdir=outdir)
 except KeyboardInterrupt:
     print('*interrupted*')
     print('Partial results will be returned; if --grid was requested, nothing will be returned.')
@@ -252,7 +267,7 @@ The vast majority of these arguments default to reasonable values.
 # There is lots of shared code between this and txt2img and should be refactored.
 def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None,
             steps=None,seed=None,grid=None,individual=None,width=None,height=None,
-            cfg_scale=None,ddim_eta=None,strength=None):
+            cfg_scale=None,ddim_eta=None,strength=None,skip_normalize=False):
     """
     Generate an image from the prompt and the initial image, writing iteration images into the outdir
     The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
@@ -314,9 +329,8 @@ The vast majority of these arguments default to reasonable values.
 seeds = list()
 filename = None
 image_count = 0 # actual number of iterations performed
-tic = time.time()
+# Gawd. Too many levels of indent here. Need to refactor into smaller routines!
 try:
     with torch.no_grad():
         with precision_scope("cuda"):
@@ -330,7 +344,22 @@ The vast majority of these arguments default to reasonable values.
 uc = model.get_learned_conditioning(batch_size * [""])
 if isinstance(prompts, tuple):
     prompts = list(prompts)
-c = model.get_learned_conditioning(prompts)
+# weighted sub-prompts
+subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
+if len(subprompts) > 1:
+    # i dont know if this is correct.. but it works
+    c = torch.zeros_like(uc)
+    # get total weight for normalizing
+    totalWeight = sum(weights)
+    # normalize each "sub prompt" and add it
+    for i in range(0,len(subprompts)):
+        weight = weights[i]
+        if not skip_normalize:
+            weight = weight / totalWeight
+        c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)
+else: # just standard 1 prompt
+    c = model.get_learned_conditioning(prompts)
 # encode (scaled latent)
 z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
@@ -344,14 +373,14 @@ The vast majority of these arguments default to reasonable values.
 if not grid:
     for x_sample in x_samples:
         x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-        filename = self._unique_filename(outdir,filename,seed=seed,isbatch=(batch_size>1))
-        assert not os.path.exists(filename)
+        filename = os.path.join(outdir, f"{base_count:05}.png")
         Image.fromarray(x_sample.astype(np.uint8)).save(filename)
         images.append([filename,seed])
+        base_count += 1
 else:
     all_samples.append(x_samples)
     seeds.append(seed)
-image_count += 1
+image_count +=1
 seed = self._new_seed()
 if grid:
     images = self._make_grid(samples=all_samples,
@@ -361,6 +390,7 @@ The vast majority of these arguments default to reasonable values.
                              outdir=outdir)
 except KeyboardInterrupt:
+    print('*interrupted*')
     print('Partial results will be returned; if --grid was requested, nothing will be returned.')
 except RuntimeError as e:
     print(str(e))
@@ -481,3 +511,48 @@ The vast majority of these arguments default to reasonable values.
     filename = f'{basecount:06}.{seed}.{series:02}.png'
     finished = not os.path.exists(os.path.join(outdir,filename))
 return os.path.join(outdir,filename)
+
+def _split_weighted_subprompts(text):
+    """
+    grabs all text up to the first occurrence of ':'
+    uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
+    if ':' has no value defined, defaults to 1.0
+    repeats until no text remaining
+    """
+    remaining = len(text)
+    prompts = []
+    weights = []
+    while remaining > 0:
+        if ":" in text:
+            idx = text.index(":") # first occurrence from start
+            # grab up to index as sub-prompt
+            prompt = text[:idx]
+            remaining -= idx
+            # remove from main text
+            text = text[idx+1:]
+            # find value for weight
+            if " " in text:
+                idx = text.index(" ") # first occurence
+            else: # no space, read to end
+                idx = len(text)
+            if idx != 0:
+                try:
+                    weight = float(text[:idx])
+                except: # couldn't treat as float
+                    print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?")
+                    weight = 1.0
+            else: # no value found
+                weight = 1.0
+            # remove from main text
+            remaining -= idx
+            text = text[idx+1:]
+            # append the sub-prompt and its weight
+            prompts.append(prompt)
+            weights.append(weight)
+        else: # no : found
+            if len(text) > 0: # there is still text though
+                # take remainder as weight 1
+                prompts.append(text)
+                weights.append(1.0)
+            remaining = 0
+    return prompts, weights
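
Note that _split_weighted_subprompts is defined without a self parameter and is called as T2I._split_weighted_subprompts(prompts[0]) in both generators above. A usage sketch of the parser as documented (the import of T2I is assumed; the module path is not shown in this diff):

    # Assumed: T2I has been imported from the module patched above.
    subprompts, weights = T2I._split_weighted_subprompts("blue sphere:0.25 red cube:0.75 hybrid")
    # subprompts -> ['blue sphere', 'red cube', 'hybrid']
    # weights    -> [0.25, 0.75, 1.0]   # 'hybrid' has no ':<number>', so it defaults to 1.0

Since every ':' is treated as a weight separator, a colon used for any other purpose in a prompt will split it into an extra sub-prompt.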

View File

@@ -285,6 +285,7 @@ def create_cmd_parser():
     parser.add_argument('-i','--individual',action='store_true',help="generate individual files (default)")
     parser.add_argument('-I','--init_img',type=str,help="path to input image (supersedes width and height)")
     parser.add_argument('-f','--strength',default=0.75,type=float,help="strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely")
+    parser.add_argument('-x','--skip_normalize',action='store_true',help="skip subprompt weight normalization")
     return parser

 if readline_available:
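
The new -x/--skip_normalize switch is a plain store_true flag; txt2img() and img2img() accept a matching skip_normalize argument, although the call-site wiring is not shown in this diff. A minimal, self-contained sketch of the flag's behavior using a throwaway parser (not the script's own create_cmd_parser):

    import argparse

    # Stand-in parser that holds only the flag added above.
    parser = argparse.ArgumentParser()
    parser.add_argument('-x', '--skip_normalize', action='store_true',
                        help="skip subprompt weight normalization")

    opt = parser.parse_args(['-x'])
    print(opt.skip_normalize)   # True: sub-prompt weights are used as given, not divided by their sum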