Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit 11c0df07b7: prompt weighting not working

Changed file: ldm/simplet2i.py (103)
@@ -143,7 +143,7 @@ The vast majority of these arguments default to reasonable values.
 
     def txt2img(self,prompt,outdir=None,batch_size=None,iterations=None,
                 steps=None,seed=None,grid=None,individual=None,width=None,height=None,
-                cfg_scale=None,ddim_eta=None,strength=None,init_img=None):
+                cfg_scale=None,ddim_eta=None,strength=None,init_img=None,skip_normalize=False):
         """
         Generate an image from the prompt, writing iteration images into the outdir
         The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
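The only signature change here is the new skip_normalize flag. A minimal call sketch, assuming an already-constructed T2I instance named t2i (the instance name, prompt string, and outdir value are illustrative, not from the diff):

# Hypothetical usage of the new keyword argument; everything except the
# parameter names visible in the signature above is an assumption.
results = t2i.txt2img(
    "blue sphere:0.25 red cube:0.75",   # weighted sub-prompt syntax, parsed further down
    outdir="outputs",
    skip_normalize=False,               # True would keep the raw weights
)
# per the docstring, results is [[filename1, seed1], [filename2, seed2], ...]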
@@ -189,6 +189,7 @@ The vast majority of these arguments default to reasonable values.
         image_count = 0
         tic = time.time()
 
+        # Gawd. Too many levels of indent here. Need to refactor into smaller routines!
         try:
             with torch.no_grad():
                 with precision_scope("cuda"):
@@ -202,7 +203,23 @@ The vast majority of these arguments default to reasonable values.
                         uc = model.get_learned_conditioning(batch_size * [""])
                         if isinstance(prompts, tuple):
                             prompts = list(prompts)
-                        c = model.get_learned_conditioning(prompts)
+
+                        # weighted sub-prompts
+                        subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
+                        if len(subprompts) > 1:
+                            # i dont know if this is correct.. but it works
+                            c = torch.zeros_like(uc)
+                            # get total weight for normalizing
+                            totalWeight = sum(weights)
+                            # normalize each "sub prompt" and add it
+                            for i in range(0,len(subprompts)):
+                                weight = weights[i]
+                                if not skip_normalize:
+                                    weight = weight / totalWeight
+                                c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)
+                        else: # just standard 1 prompt
+                            c = model.get_learned_conditioning(prompts)
+
                         shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
                         samples_ddim, _ = sampler.sample(S=steps,
                                                          conditioning=c,
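Functionally, the new branch builds the conditioning as a weighted sum of per-sub-prompt conditionings, dividing each weight by the total unless skip_normalize is set (so weights 2 and 1 become 2/3 and 1/3). A standalone sketch of the same arithmetic, with illustrative names that are not taken from the diff:

import torch

def blend_conditionings(conds, weights, skip_normalize=False):
    # conds: conditioning tensors of identical shape, one per sub-prompt
    # weights: one float per tensor, e.g. [2.0, 1.0] -> effective [2/3, 1/3]
    total = sum(weights)
    c = torch.zeros_like(conds[0])
    for cond, weight in zip(conds, weights):
        if not skip_normalize:
            weight = weight / total
        c = torch.add(c, cond, alpha=weight)  # c = c + weight * cond
    return c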
@@ -220,24 +237,22 @@ The vast majority of these arguments default to reasonable values.
                         if not grid:
                             for x_sample in x_samples_ddim:
                                 x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                                filename = self._unique_filename(outdir,previousname=filename,
-                                                                 seed=seed,isbatch=(batch_size>1))
-                                assert not os.path.exists(filename)
+                                filename = os.path.join(outdir, f"{base_count:05}.png")
                                 Image.fromarray(x_sample.astype(np.uint8)).save(filename)
                                 images.append([filename,seed])
+                                base_count += 1
                         else:
                             all_samples.append(x_samples_ddim)
                             seeds.append(seed)
-
                         image_count += 1
                         seed = self._new_seed()
 
                     if grid:
                         images = self._make_grid(samples=all_samples,
                                                  seeds=seeds,
                                                  batch_size=batch_size,
                                                  iterations=iterations,
                                                  outdir=outdir)
         except KeyboardInterrupt:
             print('*interrupted*')
             print('Partial results will be returned; if --grid was requested, nothing will be returned.')
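The save path also switches from _unique_filename() to a plain zero-padded counter. A quick illustration of the naming scheme only; the outdir and base_count values are made up, and the diff does not show where base_count is initialized:

import os

outdir, base_count = "outputs/img-samples", 7
print(os.path.join(outdir, f"{base_count:05}.png"))   # outputs/img-samples/00007.png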
@@ -252,7 +267,7 @@ The vast majority of these arguments default to reasonable values.
     # There is lots of shared code between this and txt2img and should be refactored.
     def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None,
                 steps=None,seed=None,grid=None,individual=None,width=None,height=None,
-                cfg_scale=None,ddim_eta=None,strength=None):
+                cfg_scale=None,ddim_eta=None,strength=None,skip_normalize=False):
         """
         Generate an image from the prompt and the initial image, writing iteration images into the outdir
         The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
@@ -314,9 +329,8 @@ The vast majority of these arguments default to reasonable values.
         seeds = list()
         filename = None
         image_count = 0 # actual number of iterations performed
-
-        tic = time.time()
 
+        # Gawd. Too many levels of indent here. Need to refactor into smaller routines!
         try:
             with torch.no_grad():
                 with precision_scope("cuda"):
@@ -330,7 +344,22 @@ The vast majority of these arguments default to reasonable values.
                     uc = model.get_learned_conditioning(batch_size * [""])
                     if isinstance(prompts, tuple):
                         prompts = list(prompts)
-                    c = model.get_learned_conditioning(prompts)
+
+                    # weighted sub-prompts
+                    subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
+                    if len(subprompts) > 1:
+                        # i dont know if this is correct.. but it works
+                        c = torch.zeros_like(uc)
+                        # get total weight for normalizing
+                        totalWeight = sum(weights)
+                        # normalize each "sub prompt" and add it
+                        for i in range(0,len(subprompts)):
+                            weight = weights[i]
+                            if not skip_normalize:
+                                weight = weight / totalWeight
+                            c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)
+                    else: # just standard 1 prompt
+                        c = model.get_learned_conditioning(prompts)
 
                     # encode (scaled latent)
                     z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
@@ -344,14 +373,14 @@ The vast majority of these arguments default to reasonable values.
                     if not grid:
                         for x_sample in x_samples:
                             x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                            filename = self._unique_filename(outdir,filename,seed=seed,isbatch=(batch_size>1))
-                            assert not os.path.exists(filename)
+                            filename = os.path.join(outdir, f"{base_count:05}.png")
                             Image.fromarray(x_sample.astype(np.uint8)).save(filename)
                             images.append([filename,seed])
+                            base_count += 1
                     else:
                         all_samples.append(x_samples)
                         seeds.append(seed)
-                    image_count += 1
+                    image_count +=1
                     seed = self._new_seed()
                 if grid:
                     images = self._make_grid(samples=all_samples,
@@ -361,6 +390,7 @@ The vast majority of these arguments default to reasonable values.
                                              outdir=outdir)
 
             except KeyboardInterrupt:
+                print('*interrupted*')
                 print('Partial results will be returned; if --grid was requested, nothing will be returned.')
             except RuntimeError as e:
                 print(str(e))
@@ -481,3 +511,48 @@ The vast majority of these arguments default to reasonable values.
             filename = f'{basecount:06}.{seed}.{series:02}.png'
             finished = not os.path.exists(os.path.join(outdir,filename))
         return os.path.join(outdir,filename)
+
+    def _split_weighted_subprompts(text):
+        """
+        grabs all text up to the first occurrence of ':'
+        uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
+        if ':' has no value defined, defaults to 1.0
+        repeats until no text remaining
+        """
+        remaining = len(text)
+        prompts = []
+        weights = []
+        while remaining > 0:
+            if ":" in text:
+                idx = text.index(":") # first occurrence from start
+                # grab up to index as sub-prompt
+                prompt = text[:idx]
+                remaining -= idx
+                # remove from main text
+                text = text[idx+1:]
+                # find value for weight
+                if " " in text:
+                    idx = text.index(" ") # first occurence
+                else: # no space, read to end
+                    idx = len(text)
+                if idx != 0:
+                    try:
+                        weight = float(text[:idx])
+                    except: # couldn't treat as float
+                        print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?")
+                        weight = 1.0
+                else: # no value found
+                    weight = 1.0
+                # remove from main text
+                remaining -= idx
+                text = text[idx+1:]
+                # append the sub-prompt and its weight
+                prompts.append(prompt)
+                weights.append(weight)
+            else: # no : found
+                if len(text) > 0: # there is still text though
+                    # take remainder as weight 1
+                    prompts.append(text)
+                    weights.append(1.0)
+                remaining = 0
+        return prompts, weights
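Tracing the parser above on a typical weighted prompt gives results like the following; note that sub-prompts after the first keep their leading space, and a trailing fragment without ':' gets a default weight of 1.0:

# Example trace of _split_weighted_subprompts (called on the class, as the
# txt2img/img2img hunks do); the input string is illustrative.
subprompts, weights = T2I._split_weighted_subprompts("blue sphere:0.25 red cube:0.75")
print(subprompts)   # ['blue sphere', ' red cube']
print(weights)      # [0.25, 0.75]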
@@ -285,6 +285,7 @@ def create_cmd_parser():
     parser.add_argument('-i','--individual',action='store_true',help="generate individual files (default)")
     parser.add_argument('-I','--init_img',type=str,help="path to input image (supersedes width and height)")
     parser.add_argument('-f','--strength',default=0.75,type=float,help="strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely")
+    parser.add_argument('-x','--skip_normalize',action='store_true',help="skip subprompt weight normalization")
     return parser
 
 if readline_available:
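Assuming this parser backs the interactive dream> prompt in the accompanying script (an assumption; only the -x flag itself comes from the hunk above), usage would look roughly like:

dream> blue sphere:0.25 red cube:0.75        # weights normalized to sum to 1
dream> blue sphere:3 red cube:1 -x           # -x / --skip_normalize keeps the raw weights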