Optional weighting for creative blending of prompts

example: "an apple: a banana:0 a watermelon:0.5"
        The above example turns into 3 sub-prompts:
        "an apple"     1.0  (the default when no value is given)
        "a banana"     0.0
        "a watermelon" 0.5
        The weights are summed, and each sub-prompt's weight is normalized against the total.
        The resulting image will be: apple ~67% (1.0/1.5), banana 0%, watermelon ~33% (0.5/1.5)
xra 2022-08-22 22:59:06 +09:00
parent 7cb5149a02
commit 2736d7e15e
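
A minimal sketch of the normalization arithmetic described in the commit message, in plain Python (values taken from the example; nothing here touches the model):

    # Sketch: normalize the example's parsed weights so they sum to 1.
    weights = [1.0, 0.0, 0.5]                  # apple, banana, watermelon
    total = sum(weights)                       # 1.5
    normalized = [w / total for w in weights]
    print(normalized)                          # [0.666..., 0.0, 0.333...]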


@@ -200,7 +200,21 @@ The vast majority of these arguments default to reasonable values.
             uc = model.get_learned_conditioning(batch_size * [""])
             if isinstance(prompts, tuple):
                 prompts = list(prompts)
-            c = model.get_learned_conditioning(prompts)
+            # weighted sub-prompts
+            subprompts, weights = T2I.split_weighted_subprompts(prompts[0])
+            if len(subprompts) > 1:
+                # I don't know if this is correct... but it works
+                c = torch.zeros_like(uc)
+                # get total weight for normalizing
+                totalWeight = sum(weights)
+                # normalize each sub-prompt's weight and add its conditioning
+                for i in range(0, len(subprompts)):
+                    weight = weights[i] / totalWeight
+                    c = torch.add(c, model.get_learned_conditioning(subprompts[i]), alpha=weight)
+            else:  # just standard 1 prompt
+                c = model.get_learned_conditioning(prompts)
             shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
             samples_ddim, _ = sampler.sample(S=steps,
                                              conditioning=c,
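
For intuition, the blend above is just a normalized linear combination of conditioning tensors. A self-contained sketch with stand-in embeddings (fake_embed and the tensor shape are illustrative assumptions, not the model's API):

    import torch

    def fake_embed(prompt: str) -> torch.Tensor:
        # Stand-in for model.get_learned_conditioning(): deterministic per prompt.
        g = torch.Generator().manual_seed(sum(map(ord, prompt)))
        return torch.randn(1, 77, 768, generator=g)  # CLIP-like shape, assumed

    subprompts = ["an apple", "a banana", "a watermelon"]
    weights = [1.0, 0.0, 0.5]

    c = torch.zeros(1, 77, 768)
    totalWeight = sum(weights)
    for prompt, w in zip(subprompts, weights):
        # accumulate each embedding, scaled by its normalized weight
        c = torch.add(c, fake_embed(prompt), alpha=w / totalWeight)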
@@ -319,7 +333,20 @@ The vast majority of these arguments default to reasonable values.
             uc = model.get_learned_conditioning(batch_size * [""])
             if isinstance(prompts, tuple):
                 prompts = list(prompts)
-            c = model.get_learned_conditioning(prompts)
+            # weighted sub-prompts
+            subprompts, weights = T2I.split_weighted_subprompts(prompts[0])
+            if len(subprompts) > 1:
+                # I don't know if this is correct... but it works
+                c = torch.zeros_like(uc)
+                # get total weight for normalizing
+                totalWeight = sum(weights)
+                # normalize each sub-prompt's weight and add its conditioning
+                for i in range(0, len(subprompts)):
+                    weight = weights[i] / totalWeight
+                    c = torch.add(c, model.get_learned_conditioning(subprompts[i]), alpha=weight)
+            else:  # just standard 1 prompt
+                c = model.get_learned_conditioning(prompts)
             # encode (scaled latent)
             z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
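
The same weighted-blend block now appears verbatim in both the txt2img and img2img paths, so a shared helper would remove the duplication. A sketch of one way to factor it out (the name get_weighted_conditioning is hypothetical, not part of this commit):

    def get_weighted_conditioning(model, prompt, uc):
        # Blend conditioning for a weighted multi-part prompt; fall back to
        # plain conditioning when there is only a single sub-prompt.
        subprompts, weights = T2I.split_weighted_subprompts(prompt)
        if len(subprompts) <= 1:
            return model.get_learned_conditioning([prompt])
        c = torch.zeros_like(uc)
        total = sum(weights)
        for subprompt, weight in zip(subprompts, weights):
            c = torch.add(c, model.get_learned_conditioning(subprompt), alpha=weight / total)
        return c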
@@ -430,3 +457,53 @@ The vast majority of these arguments default to reasonable values.
         image = image[None].transpose(0, 3, 1, 2)
         image = torch.from_numpy(image)
         return 2.*image - 1.
"""
example: "an apple: a banana:0 a watermelon:0.5"
grabs all text up to the first occurance of ':'
then removes the text, repeating until no characters left.
if ':' has no weight defined, defaults to 1.0
the above example turns into 3 sub-prompts:
"an apple" 1.0
"a banana" 0.0
"a watermelon" 0.5
The weights are added and normalized
The resulting image will be: apple 66% (1.0 / 1.5), banana 0%, watermelon 33% (0.5 / 1.5)
"""
+def split_weighted_subprompts(text):
+    # Very simple: uses ':' to separate sub-prompts.
+    # Assumes a number follows each ':' and a space follows the number;
+    # if no number is found, the weight defaults to 1.0.
+    remaining = len(text)
+    prompts = []
+    weights = []
+    while remaining > 0:
+        if ":" in text:
+            idx = text.index(":")  # first occurrence from the start
+            # snip off the sub-prompt
+            prompt = text[:idx]
+            remaining -= idx
+            # remove it from the main text
+            text = text[idx+1:]
+            # grab the number that follows
+            if " " in text:
+                idx = text.index(" ")  # first occurrence
+            else:  # no space, read to the end
+                idx = len(text)
+            if idx != 0:
+                try:
+                    weight = float(text[:idx])
+                except ValueError:  # no parsable number; guard malformed input
+                    weight = 1.0
+            else:  # no number to grab
+                weight = 1.0
+            # remove the number from the main text
+            remaining -= idx
+            text = text[idx+1:]
+            prompts.append(prompt)
+            weights.append(weight)
+        else:
+            if len(text) > 0:
+                # take what remains as a sub-prompt with weight 1
+                prompts.append(text)
+                weights.append(1.0)
+            remaining = 0
+    return prompts, weights
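
A quick sanity check of the parser against the docstring's example (expected output shown in comments):

    prompts, weights = split_weighted_subprompts("an apple: a banana:0 a watermelon:0.5")
    print(prompts)  # ['an apple', 'a banana', 'a watermelon']
    print(weights)  # [1.0, 0.0, 0.5]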