diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index daae2293d9..95e64f24fb 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -143,7 +143,7 @@ The vast majority of these arguments default to reasonable values. def txt2img(self,prompt,outdir=None,batch_size=None,iterations=None, steps=None,seed=None,grid=None,individual=None,width=None,height=None, - cfg_scale=None,ddim_eta=None,strength=None,init_img=None): + cfg_scale=None,ddim_eta=None,strength=None,init_img=None,skip_normalize=False): """ Generate an image from the prompt, writing iteration images into the outdir The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...] @@ -189,6 +189,7 @@ The vast majority of these arguments default to reasonable values. image_count = 0 tic = time.time() + # Gawd. Too many levels of indent here. Need to refactor into smaller routines! try: with torch.no_grad(): with precision_scope("cuda"): @@ -202,7 +203,23 @@ The vast majority of these arguments default to reasonable values. uc = model.get_learned_conditioning(batch_size * [""]) if isinstance(prompts, tuple): prompts = list(prompts) - c = model.get_learned_conditioning(prompts) + + # weighted sub-prompts + subprompts,weights = T2I._split_weighted_subprompts(prompts[0]) + if len(subprompts) > 1: + # i dont know if this is correct.. but it works + c = torch.zeros_like(uc) + # get total weight for normalizing + totalWeight = sum(weights) + # normalize each "sub prompt" and add it + for i in range(0,len(subprompts)): + weight = weights[i] + if not skip_normalize: + weight = weight / totalWeight + c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight) + else: # just standard 1 prompt + c = model.get_learned_conditioning(prompts) + shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor] samples_ddim, _ = sampler.sample(S=steps, conditioning=c, @@ -220,24 +237,22 @@ The vast majority of these arguments default to reasonable values. if not grid: for x_sample in x_samples_ddim: x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') - filename = self._unique_filename(outdir,previousname=filename, - seed=seed,isbatch=(batch_size>1)) - assert not os.path.exists(filename) + filename = os.path.join(outdir, f"{base_count:05}.png") Image.fromarray(x_sample.astype(np.uint8)).save(filename) images.append([filename,seed]) + base_count += 1 else: all_samples.append(x_samples_ddim) seeds.append(seed) + image_count += 1 seed = self._new_seed() - if grid: images = self._make_grid(samples=all_samples, seeds=seeds, batch_size=batch_size, iterations=iterations, outdir=outdir) - except KeyboardInterrupt: print('*interrupted*') print('Partial results will be returned; if --grid was requested, nothing will be returned.') @@ -252,7 +267,7 @@ The vast majority of these arguments default to reasonable values. # There is lots of shared code between this and txt2img and should be refactored. def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None, steps=None,seed=None,grid=None,individual=None,width=None,height=None, - cfg_scale=None,ddim_eta=None,strength=None): + cfg_scale=None,ddim_eta=None,strength=None,skip_normalize=False): """ Generate an image from the prompt and the initial image, writing iteration images into the outdir The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...] @@ -314,9 +329,8 @@ The vast majority of these arguments default to reasonable values. seeds = list() filename = None image_count = 0 # actual number of iterations performed - - tic = time.time() + # Gawd. Too many levels of indent here. Need to refactor into smaller routines! try: with torch.no_grad(): with precision_scope("cuda"): @@ -330,7 +344,22 @@ The vast majority of these arguments default to reasonable values. uc = model.get_learned_conditioning(batch_size * [""]) if isinstance(prompts, tuple): prompts = list(prompts) - c = model.get_learned_conditioning(prompts) + + # weighted sub-prompts + subprompts,weights = T2I._split_weighted_subprompts(prompts[0]) + if len(subprompts) > 1: + # i dont know if this is correct.. but it works + c = torch.zeros_like(uc) + # get total weight for normalizing + totalWeight = sum(weights) + # normalize each "sub prompt" and add it + for i in range(0,len(subprompts)): + weight = weights[i] + if not skip_normalize: + weight = weight / totalWeight + c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight) + else: # just standard 1 prompt + c = model.get_learned_conditioning(prompts) # encode (scaled latent) z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device)) @@ -344,14 +373,14 @@ The vast majority of these arguments default to reasonable values. if not grid: for x_sample in x_samples: x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') - filename = self._unique_filename(outdir,filename,seed=seed,isbatch=(batch_size>1)) - assert not os.path.exists(filename) + filename = os.path.join(outdir, f"{base_count:05}.png") Image.fromarray(x_sample.astype(np.uint8)).save(filename) images.append([filename,seed]) + base_count += 1 else: all_samples.append(x_samples) seeds.append(seed) - image_count += 1 + image_count +=1 seed = self._new_seed() if grid: images = self._make_grid(samples=all_samples, @@ -361,6 +390,7 @@ The vast majority of these arguments default to reasonable values. outdir=outdir) except KeyboardInterrupt: + print('*interrupted*') print('Partial results will be returned; if --grid was requested, nothing will be returned.') except RuntimeError as e: print(str(e)) @@ -481,3 +511,48 @@ The vast majority of these arguments default to reasonable values. filename = f'{basecount:06}.{seed}.{series:02}.png' finished = not os.path.exists(os.path.join(outdir,filename)) return os.path.join(outdir,filename) + + def _split_weighted_subprompts(text): + """ + grabs all text up to the first occurrence of ':' + uses the grabbed text as a sub-prompt, and takes the value following ':' as weight + if ':' has no value defined, defaults to 1.0 + repeats until no text remaining + """ + remaining = len(text) + prompts = [] + weights = [] + while remaining > 0: + if ":" in text: + idx = text.index(":") # first occurrence from start + # grab up to index as sub-prompt + prompt = text[:idx] + remaining -= idx + # remove from main text + text = text[idx+1:] + # find value for weight + if " " in text: + idx = text.index(" ") # first occurence + else: # no space, read to end + idx = len(text) + if idx != 0: + try: + weight = float(text[:idx]) + except: # couldn't treat as float + print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?") + weight = 1.0 + else: # no value found + weight = 1.0 + # remove from main text + remaining -= idx + text = text[idx+1:] + # append the sub-prompt and its weight + prompts.append(prompt) + weights.append(weight) + else: # no : found + if len(text) > 0: # there is still text though + # take remainder as weight 1 + prompts.append(text) + weights.append(1.0) + remaining = 0 + return prompts, weights diff --git a/scripts/dream.py b/scripts/dream.py index 0e511f7789..3ce0e3a88e 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -285,6 +285,7 @@ def create_cmd_parser(): parser.add_argument('-i','--individual',action='store_true',help="generate individual files (default)") parser.add_argument('-I','--init_img',type=str,help="path to input image (supersedes width and height)") parser.add_argument('-f','--strength',default=0.75,type=float,help="strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely") + parser.add_argument('-x','--skip_normalize',action='store_true',help="skip subprompt weight normalization") return parser if readline_available: