mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: simplify and enhance prompt weight splitting (#258)
* feat: simplify and enhance prompt weight splitting * fix: don't shadow the prompt variable * feat: enable backslash-escaped colons in prompts
This commit is contained in:
parent
d022d0dd11
commit
c52ba1b022
@ -487,22 +487,19 @@ class T2I:
|
|||||||
|
|
||||||
uc = self.model.get_learned_conditioning([''])
|
uc = self.model.get_learned_conditioning([''])
|
||||||
|
|
||||||
# weighted sub-prompts
|
# get weighted sub-prompts
|
||||||
subprompts, weights = T2I._split_weighted_subprompts(prompt)
|
weighted_subprompts = T2I._split_weighted_subprompts(prompt, skip_normalize)
|
||||||
if len(subprompts) > 1:
|
|
||||||
|
if len(weighted_subprompts) > 1:
|
||||||
# i dont know if this is correct.. but it works
|
# i dont know if this is correct.. but it works
|
||||||
c = torch.zeros_like(uc)
|
c = torch.zeros_like(uc)
|
||||||
# get total weight for normalizing
|
|
||||||
totalWeight = sum(weights)
|
|
||||||
# normalize each "sub prompt" and add it
|
# normalize each "sub prompt" and add it
|
||||||
for i in range(0, len(subprompts)):
|
for i in range(0, len(weighted_subprompts)):
|
||||||
weight = weights[i]
|
subprompt, weight = weighted_subprompts[i]
|
||||||
if not skip_normalize:
|
self._log_tokenization(subprompt)
|
||||||
weight = weight / totalWeight
|
|
||||||
self._log_tokenization(subprompts[i])
|
|
||||||
c = torch.add(
|
c = torch.add(
|
||||||
c,
|
c,
|
||||||
self.model.get_learned_conditioning([subprompts[i]]),
|
self.model.get_learned_conditioning([subprompt]),
|
||||||
alpha=weight,
|
alpha=weight,
|
||||||
)
|
)
|
||||||
else: # just standard 1 prompt
|
else: # just standard 1 prompt
|
||||||
@ -616,52 +613,36 @@ class T2I:
|
|||||||
image = torch.from_numpy(image)
|
image = torch.from_numpy(image)
|
||||||
return 2.0 * image - 1.0
|
return 2.0 * image - 1.0
|
||||||
|
|
||||||
def _split_weighted_subprompts(text):
|
def _split_weighted_subprompts(text, skip_normalize=False):
|
||||||
"""
|
"""
|
||||||
grabs all text up to the first occurrence of ':'
|
grabs all text up to the first occurrence of ':'
|
||||||
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
|
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
|
||||||
if ':' has no value defined, defaults to 1.0
|
if ':' has no value defined, defaults to 1.0
|
||||||
repeats until no text remaining
|
repeats until no text remaining
|
||||||
"""
|
"""
|
||||||
remaining = len(text)
|
prompt_parser = re.compile("""
|
||||||
prompts = []
|
(?P<prompt> # capture group for 'prompt'
|
||||||
weights = []
|
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
|
||||||
while remaining > 0:
|
) # end 'prompt'
|
||||||
if ':' in text:
|
(?: # non-capture group
|
||||||
idx = text.index(':') # first occurrence from start
|
:+ # match one or more ':' characters
|
||||||
# grab up to index as sub-prompt
|
(?P<weight> # capture group for 'weight'
|
||||||
prompt = text[:idx]
|
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
|
||||||
remaining -= idx
|
)? # end weight capture group, make optional
|
||||||
# remove from main text
|
\s* # strip spaces after weight
|
||||||
text = text[idx + 1 :]
|
| # OR
|
||||||
# find value for weight
|
$ # else, if no ':' then match end of line
|
||||||
if ' ' in text:
|
) # end non-capture group
|
||||||
idx = text.index(' ') # first occurence
|
""", re.VERBOSE)
|
||||||
else: # no space, read to end
|
parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
|
||||||
idx = len(text)
|
if skip_normalize:
|
||||||
if idx != 0:
|
return parsed_prompts
|
||||||
try:
|
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
||||||
weight = float(text[:idx])
|
if weight_sum == 0:
|
||||||
except: # couldn't treat as float
|
print("Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
|
||||||
print(
|
equal_weight = 1 / len(parsed_prompts)
|
||||||
f"Warning: '{text[:idx]}' is not a value, are you missing a space?"
|
return [(x[0], equal_weight) for x in parsed_prompts]
|
||||||
)
|
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
|
||||||
weight = 1.0
|
|
||||||
else: # no value found
|
|
||||||
weight = 1.0
|
|
||||||
# remove from main text
|
|
||||||
remaining -= idx
|
|
||||||
text = text[idx + 1 :]
|
|
||||||
# append the sub-prompt and its weight
|
|
||||||
prompts.append(prompt)
|
|
||||||
weights.append(weight)
|
|
||||||
else: # no : found
|
|
||||||
if len(text) > 0: # there is still text though
|
|
||||||
# take remainder as weight 1
|
|
||||||
prompts.append(text)
|
|
||||||
weights.append(1.0)
|
|
||||||
remaining = 0
|
|
||||||
return prompts, weights
|
|
||||||
|
|
||||||
# shows how the prompt is tokenized
|
# shows how the prompt is tokenized
|
||||||
# usually tokens have '</w>' to indicate end-of-word,
|
# usually tokens have '</w>' to indicate end-of-word,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user