Merge branch 'main' of github.com:lstein/stable-diffusion into main

Lincoln Stein 2022-08-31 14:44:18 -04:00
commit 9ad79207c2
3 changed files with 44 additions and 56 deletions

View File

@@ -483,22 +483,19 @@ class T2I:
         uc = self.model.get_learned_conditioning([''])
-        # weighted sub-prompts
-        subprompts, weights = T2I._split_weighted_subprompts(prompt)
-        if len(subprompts) > 1:
+        # get weighted sub-prompts
+        weighted_subprompts = T2I._split_weighted_subprompts(prompt, skip_normalize)
+        if len(weighted_subprompts) > 1:
             # i dont know if this is correct.. but it works
             c = torch.zeros_like(uc)
-            # get total weight for normalizing
-            totalWeight = sum(weights)
             # normalize each "sub prompt" and add it
-            for i in range(0, len(subprompts)):
-                weight = weights[i]
-                if not skip_normalize:
-                    weight = weight / totalWeight
-                self._log_tokenization(subprompts[i])
+            for i in range(0, len(weighted_subprompts)):
+                subprompt, weight = weighted_subprompts[i]
+                self._log_tokenization(subprompt)
                 c = torch.add(
                     c,
-                    self.model.get_learned_conditioning([subprompts[i]]),
+                    self.model.get_learned_conditioning([subprompt]),
                     alpha=weight,
                 )
         else:   # just standard 1 prompt
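Note: the loop above builds the conditioning as a weighted sum of per-sub-prompt embeddings. A minimal sketch of that blend, with random tensors standing in for the learned conditioning; the (1, 77, 768) shape and the example weights are assumptions, not part of this commit:

import torch

# Stand-ins for self.model.get_learned_conditioning([...]) outputs; shape is assumed.
embeddings = [torch.randn(1, 77, 768) for _ in range(2)]
weights = [2 / 3, 1 / 3]            # already normalized by _split_weighted_subprompts

c = torch.zeros_like(embeddings[0])
for emb, weight in zip(embeddings, weights):
    c = torch.add(c, emb, alpha=weight)   # c = c + weight * emb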
@@ -630,52 +627,36 @@ class T2I:
         image = torch.from_numpy(image)
         return 2.0 * image - 1.0
 
-    def _split_weighted_subprompts(text):
+    def _split_weighted_subprompts(text, skip_normalize=False):
         """
         grabs all text up to the first occurrence of ':'
         uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
         if ':' has no value defined, defaults to 1.0
         repeats until no text remaining
         """
-        remaining = len(text)
-        prompts = []
-        weights = []
-        while remaining > 0:
-            if ':' in text:
-                idx = text.index(':')   # first occurrence from start
-                # grab up to index as sub-prompt
-                prompt = text[:idx]
-                remaining -= idx
-                # remove from main text
-                text = text[idx + 1:]
-                # find value for weight
-                if ' ' in text:
-                    idx = text.index(' ')   # first occurence
-                else:   # no space, read to end
-                    idx = len(text)
-                if idx != 0:
-                    try:
-                        weight = float(text[:idx])
-                    except:   # couldn't treat as float
-                        print(
-                            f"Warning: '{text[:idx]}' is not a value, are you missing a space?"
-                        )
-                        weight = 1.0
-                else:   # no value found
-                    weight = 1.0
-                # remove from main text
-                remaining -= idx
-                text = text[idx + 1:]
-                # append the sub-prompt and its weight
-                prompts.append(prompt)
-                weights.append(weight)
-            else:   # no : found
-                if len(text) > 0:   # there is still text though
-                    # take remainder as weight 1
-                    prompts.append(text)
-                    weights.append(1.0)
-                remaining = 0
-        return prompts, weights
+        prompt_parser = re.compile("""
+            (?P<prompt>         # capture group for 'prompt'
+            (?:\\\:|[^:])+      # match one or more non ':' characters or escaped colons '\:'
+            )                   # end 'prompt'
+            (?:                 # non-capture group
+            :+                  # match one or more ':' characters
+            (?P<weight>         # capture group for 'weight'
+            -?\d+(?:\.\d+)?     # match positive or negative integer or decimal number
+            )?                  # end weight capture group, make optional
+            \s*                 # strip spaces after weight
+            |                   # OR
+            $                   # else, if no ':' then match end of line
+            )                   # end non-capture group
+            """, re.VERBOSE)
+        parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
+        if skip_normalize:
+            return parsed_prompts
+        weight_sum = sum(map(lambda x: x[1], parsed_prompts))
+        if weight_sum == 0:
+            print("Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
+            equal_weight = 1 / len(parsed_prompts)
+            return [(x[0], equal_weight) for x in parsed_prompts]
+        return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
 
     # shows how the prompt is tokenized
     # usually tokens have '</w>' to indicate end-of-word,
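Note: a minimal, self-contained sketch of the new parsing behaviour. The parse helper name and the sample prompts are illustrative rather than part of this commit, and the zero-weight-sum fallback is omitted for brevity:

import re

# Same pattern as above, written as a raw string.
prompt_parser = re.compile(
    r"(?P<prompt>(?:\\:|[^:])+)(?::+(?P<weight>-?\d+(?:\.\d+)?)?\s*|$)"
)

def parse(text, skip_normalize=False):
    # (sub-prompt, weight) pairs; '\:' unescapes to ':' and a missing weight defaults to 1.
    pairs = [(m.group('prompt').replace('\\:', ':'), float(m.group('weight') or 1))
             for m in prompt_parser.finditer(text)]
    if skip_normalize:
        return pairs
    total = sum(weight for _, weight in pairs)
    return [(prompt, weight / total) for prompt, weight in pairs]

print(parse('mountain lake:2 fog:1'))         # ~[('mountain lake', 0.667), ('fog', 0.333)]
print(parse('mountain lake:2 fog:1', True))   # [('mountain lake', 2.0), ('fog', 1.0)]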

View File

@@ -29,7 +29,10 @@ def main():
     width = 512
     height = 512
     config = 'configs/stable-diffusion/v1-inference.yaml'
-    weights = 'models/ldm/stable-diffusion-v1/model.ckpt'
+    if '.ckpt' in opt.weights:
+        weights = opt.weights
+    else:
+        weights = f'models/ldm/stable-diffusion-v1/{opt.weights}.ckpt'
     print('* Initializing, be patient...\n')
     sys.path.append('.')
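Note: a minimal sketch of the checkpoint-path resolution introduced above; the resolve_weights helper name and the example paths are illustrative, not part of this commit:

def resolve_weights(weights: str) -> str:
    # A value that already names a .ckpt file is used as-is;
    # a bare model name is mapped into the default checkpoint directory.
    if '.ckpt' in weights:
        return weights
    return f'models/ldm/stable-diffusion-v1/{weights}.ckpt'

assert resolve_weights('model') == 'models/ldm/stable-diffusion-v1/model.ckpt'
assert resolve_weights('/some/path/custom.ckpt') == '/some/path/custom.ckpt'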
@@ -418,6 +421,11 @@ def create_argv_parser():
         action='store_true',
         help='Start in web server mode.',
     )
+    parser.add_argument(
+        '--weights',
+        default='model',
+        help='Indicates the Stable Diffusion model to use.',
+    )
     return parser
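Note: together with the resolution logic above, the new flag is meant to be used roughly as follows; the stand-alone parser and example values below are illustrative, not part of this commit:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--weights',
    default='model',
    help='Indicates the Stable Diffusion model to use.',
)

# No flag given: the default 'model' later resolves to models/ldm/stable-diffusion-v1/model.ckpt
print(parser.parse_args([]).weights)                            # -> 'model'
# An explicit .ckpt value is used verbatim by the resolution logic above.
print(parser.parse_args(['--weights', 'custom.ckpt']).weights)  # -> 'custom.ckpt'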

View File

@@ -98,7 +98,6 @@ async function generateSubmit(form) {
            appendOutput(data.url, data.seed, data.config);
            progressEle.setAttribute('value', 0);
            progressEle.setAttribute('max', totalSteps);
-           progressImageEle.src = BLANK_IMAGE_URL;
        } else if (data.event === 'upscaling-started') {
            document.getElementById("processing_cnt").textContent=data.processed_file_cnt;
            document.getElementById("scaling-inprocess-message").style.display = "block";