optional weighting for creative blending of prompts

example: "an apple: a banana:0 a watermelon:0.5"

The above example turns into 3 sub-prompts:
    "an apple"      1.0  (default if no value)
    "a banana"      0.0
    "a watermelon"  0.5

The weights are added and normalized. The resulting image will be: apple 66%, banana 0%, watermelon 33%.
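As a quick worked version of the normalization described above, here is a minimal standalone sketch (not the committed code; the real parser is the split_weighted_subprompts() added further down in this diff, and the pairs below are hard-coded from the example prompt):

    # Sub-prompts and weights as parsed from "an apple: a banana:0 a watermelon:0.5"
    subprompts = [("an apple", 1.0), ("a banana", 0.0), ("a watermelon", 0.5)]

    total = sum(w for _, w in subprompts)        # 1.5
    for prompt, w in subprompts:
        print(f"{prompt!r}: {w / total:.1%}")    # 66.7%, 0.0%, 33.3%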
commit 2736d7e15e
parent 7cb5149a02
@@ -200,7 +200,21 @@ The vast majority of these arguments default to reasonable values.
                 uc = model.get_learned_conditioning(batch_size * [""])
             if isinstance(prompts, tuple):
                 prompts = list(prompts)
-            c = model.get_learned_conditioning(prompts)
+
+            # weighted sub-prompts
+            subprompts,weights = T2I.split_weighted_subprompts(prompts[0])
+            if len(subprompts) > 1:
+                # i dont know if this is correct.. but it works
+                c = torch.zeros_like(uc)
+                # get total weight for normalizing
+                totalWeight = sum(weights)
+                # normalize each "sub prompt" and add it
+                for i in range(0,len(subprompts)):
+                    weight = weights[i] / totalWeight
+                    c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)
+            else: # just standard 1 prompt
+                c = model.get_learned_conditioning(prompts)
+
             shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
             samples_ddim, _ = sampler.sample(S=steps,
                                              conditioning=c,
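For context on what the blending loop above does: each sub-prompt is embedded separately and the embeddings are accumulated with normalized weights. A minimal sketch with dummy tensors standing in for the conditioning vectors (the (1, 77, 768) shape and the embed() helper are assumptions for illustration, not part of the commit):

    import torch

    def embed(prompt: str) -> torch.Tensor:
        # Stand-in for model.get_learned_conditioning(); shape is assumed.
        torch.manual_seed(abs(hash(prompt)) % (2**31))
        return torch.randn(1, 77, 768)

    subprompts = ["an apple", "a banana", "a watermelon"]
    weights    = [1.0, 0.0, 0.5]

    c = torch.zeros_like(embed(""))      # plays the role of torch.zeros_like(uc)
    totalWeight = sum(weights)
    for prompt, w in zip(subprompts, weights):
        # weighted, normalized accumulation, as in the diff above
        c = torch.add(c, embed(prompt), alpha=w / totalWeight)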
@@ -319,6 +333,19 @@ The vast majority of these arguments default to reasonable values.
                 uc = model.get_learned_conditioning(batch_size * [""])
             if isinstance(prompts, tuple):
                 prompts = list(prompts)
-            c = model.get_learned_conditioning(prompts)
+
+            # weighted sub-prompts
+            subprompts,weights = T2I.split_weighted_subprompts(prompts[0])
+            if len(subprompts) > 1:
+                # i dont know if this is correct.. but it works
+                c = torch.zeros_like(uc)
+                # get total weight for normalizing
+                totalWeight = sum(weights)
+                # normalize each "sub prompt" and add it
+                for i in range(0,len(subprompts)):
+                    weight = weights[i] / totalWeight
+                    c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)
+            else: # just standard 1 prompt
+                c = model.get_learned_conditioning(prompts)

             # encode (scaled latent)
@@ -430,3 +457,53 @@ The vast majority of these arguments default to reasonable values.
         image = image[None].transpose(0, 3, 1, 2)
         image = torch.from_numpy(image)
         return 2.*image - 1.
+
+    """
+    example: "an apple: a banana:0 a watermelon:0.5"
+    grabs all text up to the first occurance of ':'
+    then removes the text, repeating until no characters left.
+    if ':' has no weight defined, defaults to 1.0
+
+    the above example turns into 3 sub-prompts:
+    "an apple" 1.0
+    "a banana" 0.0
+    "a watermelon" 0.5
+    The weights are added and normalized
+    The resulting image will be: apple 66% (1.0 / 1.5), banana 0%, watermelon 33% (0.5 / 1.5)
+    """
+    def split_weighted_subprompts(text):
+        # very simple, uses : to separate sub-prompts
+        # assumes number following : and space after number
+        # if no number found, defaults to 1.0
+        remaining = len(text)
+        prompts = []
+        weights = []
+        while remaining > 0:
+            # find :
+            if ":" in text:
+                idx = text.index(":") # first occurrance from start
+                # snip sub prompt
+                prompt = text[:idx]
+                remaining -= idx
+                # remove from main text
+                text = text[idx+1:]
+                # get number
+                if " " in text:
+                    idx = text.index(" ") # first occurance
+                else: # no space, read to end
+                    idx = len(text)
+                if idx != 0:
+                    weight = float(text[:idx])
+                else: # no number to grab
+                    weight = 1.0
+                # remove
+                remaining -= idx
+                text = text[idx+1:]
+                prompts.append(prompt)
+                weights.append(weight)
+            else:
+                if len(text) > 0:
+                    # take what remains as weight 1
+                    prompts.append(text)
+                    weights.append(1.0)
+                remaining = 0
+        return prompts, weights
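Running the new parser on the prompt from the commit message (assuming split_weighted_subprompts is pasted into a REPL or imported from wherever this diff lands in your checkout) gives:

    prompts, weights = split_weighted_subprompts("an apple: a banana:0 a watermelon:0.5")
    print(prompts)   # ['an apple', 'a banana', 'a watermelon']
    print(weights)   # [1.0, 0.0, 0.5]

Note that a ':' followed directly by a space, as after "an apple", carries no number, so that sub-prompt falls back to the default weight of 1.0.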