diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index 0e31c185ad..a28670fc05 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -483,22 +483,19 @@ class T2I: uc = self.model.get_learned_conditioning(['']) - # weighted sub-prompts - subprompts, weights = T2I._split_weighted_subprompts(prompt) - if len(subprompts) > 1: + # get weighted sub-prompts + weighted_subprompts = T2I._split_weighted_subprompts(prompt, skip_normalize) + + if len(weighted_subprompts) > 1: # i dont know if this is correct.. but it works c = torch.zeros_like(uc) - # get total weight for normalizing - totalWeight = sum(weights) # normalize each "sub prompt" and add it - for i in range(0, len(subprompts)): - weight = weights[i] - if not skip_normalize: - weight = weight / totalWeight - self._log_tokenization(subprompts[i]) + for i in range(0, len(weighted_subprompts)): + subprompt, weight = weighted_subprompts[i] + self._log_tokenization(subprompt) c = torch.add( c, - self.model.get_learned_conditioning([subprompts[i]]), + self.model.get_learned_conditioning([subprompt]), alpha=weight, ) else: # just standard 1 prompt @@ -630,55 +627,39 @@ class T2I: image = torch.from_numpy(image) return 2.0 * image - 1.0 - def _split_weighted_subprompts(text): + def _split_weighted_subprompts(text, skip_normalize=False): """ grabs all text up to the first occurrence of ':' uses the grabbed text as a sub-prompt, and takes the value following ':' as weight if ':' has no value defined, defaults to 1.0 repeats until no text remaining """ - remaining = len(text) - prompts = [] - weights = [] - while remaining > 0: - if ':' in text: - idx = text.index(':') # first occurrence from start - # grab up to index as sub-prompt - prompt = text[:idx] - remaining -= idx - # remove from main text - text = text[idx + 1:] - # find value for weight - if ' ' in text: - idx = text.index(' ') # first occurence - else: # no space, read to end - idx = len(text) - if idx != 0: - try: - weight = float(text[:idx]) - except: # couldn't treat as float - print( - f"Warning: '{text[:idx]}' is not a value, are you missing a space?" - ) - weight = 1.0 - else: # no value found - weight = 1.0 - # remove from main text - remaining -= idx - text = text[idx + 1:] - # append the sub-prompt and its weight - prompts.append(prompt) - weights.append(weight) - else: # no : found - if len(text) > 0: # there is still text though - # take remainder as weight 1 - prompts.append(text) - weights.append(1.0) - remaining = 0 - return prompts, weights - - # shows how the prompt is tokenized - # usually tokens have '' to indicate end-of-word, + prompt_parser = re.compile(""" + (?P # capture group for 'prompt' + (?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:' + ) # end 'prompt' + (?: # non-capture group + :+ # match one or more ':' characters + (?P # capture group for 'weight' + -?\d+(?:\.\d+)? # match positive or negative integer or decimal number + )? # end weight capture group, make optional + \s* # strip spaces after weight + | # OR + $ # else, if no ':' then match end of line + ) # end non-capture group + """, re.VERBOSE) + parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)] + if skip_normalize: + return parsed_prompts + weight_sum = sum(map(lambda x: x[1], parsed_prompts)) + if weight_sum == 0: + print("Warning: Subprompt weights add up to zero. Discarding and using even weights instead.") + equal_weight = 1 / len(parsed_prompts) + return [(x[0], equal_weight) for x in parsed_prompts] + return [(x[0], x[1] / weight_sum) for x in parsed_prompts] + + # shows how the prompt is tokenized + # usually tokens have '' to indicate end-of-word, # but for readability it has been replaced with ' ' def _log_tokenization(self, text): if not self.log_tokenization: diff --git a/scripts/dream.py b/scripts/dream.py index 50be6dfa7c..2911e8847a 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -29,7 +29,10 @@ def main(): width = 512 height = 512 config = 'configs/stable-diffusion/v1-inference.yaml' - weights = 'models/ldm/stable-diffusion-v1/model.ckpt' + if '.ckpt' in opt.weights: + weights = opt.weights + else: + weights = f'models/ldm/stable-diffusion-v1/{opt.weights}.ckpt' print('* Initializing, be patient...\n') sys.path.append('.') @@ -418,6 +421,11 @@ def create_argv_parser(): action='store_true', help='Start in web server mode.', ) + parser.add_argument( + '--weights', + default='model', + help='Indicates the Stable Diffusion model to use.', + ) return parser diff --git a/static/dream_web/index.js b/static/dream_web/index.js index 76a76a53a3..cbd66366f4 100644 --- a/static/dream_web/index.js +++ b/static/dream_web/index.js @@ -98,7 +98,6 @@ async function generateSubmit(form) { appendOutput(data.url, data.seed, data.config); progressEle.setAttribute('value', 0); progressEle.setAttribute('max', totalSteps); - progressImageEle.src = BLANK_IMAGE_URL; } else if (data.event === 'upscaling-started') { document.getElementById("processing_cnt").textContent=data.processed_file_cnt; document.getElementById("scaling-inprocess-message").style.display = "block";