diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py
index 0ec3d60d98..21f973988b 100644
--- a/ldm/simplet2i.py
+++ b/ldm/simplet2i.py
@@ -52,7 +52,7 @@ t2i = T2I(model = <path> // models/ldm/stable-diffusion-v1/model.ckpt
 # do the slow model initialization
 t2i.load_model()

-# Do the fast inference & image generation. Any options passed here 
+# Do the fast inference & image generation. Any options passed here
 # override the default values assigned during class initialization
 # Will call load_model() if the model was not previously loaded and so
 # may be slow at first.
@@ -70,7 +70,7 @@ results = t2i.prompt2png(prompt = "an astronaut riding a horse",
                          outdir = "./outputs/",
                          iterations = 3,
                          init_img = "./sketches/horse+rider.png")
-
+
 for row in results:
     print(f'filename={row[0]}')
     print(f'seed ={row[1]}')
@@ -181,7 +181,7 @@ The vast majority of these arguments default to reasonable values.
         outdir = kwargs.get('outdir','outputs/img-samples')
         assert 'init_img' in kwargs,'call to img2img() must include the init_img argument'
         return self.prompt2png(prompt,outdir,**kwargs)
-
+
     def prompt2image(self,
                      # these are common
                      prompt,
@@ -216,10 +216,10 @@ The vast majority of these arguments default to reasonable values.
           strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
           ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
           variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants
-          callback // a function or method that will be called each time an image is generated
+          image_callback // a function or method that will be called each time an image is generated

           To use the callback, define a function or method that receives two arguments, an Image object
-          and the seed. You can then do whatever you like with the image, including converting it to 
+          and the seed. You can then do whatever you like with the image, including converting it to
           different formats and manipulating it. For example:

               def process_image(image,seed):
@@ -249,116 +249,86 @@ The vast majority of these arguments default to reasonable values.
         height = h
         width = w

-        data = [batch_size * [prompt]]
         scope = autocast if self.precision=="autocast" else nullcontext

-        tic = time.time()
-        if init_img:
-            assert os.path.exists(init_img),f'{init_img}: File not found'
-            results = self._img2img(prompt,
-                                    data=data,precision_scope=scope,
-                                    batch_size=batch_size,iterations=iterations,
-                                    steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
-                                    skip_normalize=skip_normalize,
-                                    init_img=init_img,strength=strength,variants=variants,
-                                    callback=image_callback)
-        else:
-            results = self._txt2img(prompt,
-                                    data=data,precision_scope=scope,
-                                    batch_size=batch_size,iterations=iterations,
-                                    steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
-                                    skip_normalize=skip_normalize,
-                                    width=width,height=height,
-                                    callback=image_callback)
-        toc = time.time()
-        print(f'{len(results)} images generated in',"%4.2fs"% (toc-tic))
-        return results
-
-    @torch.no_grad()
-    def _txt2img(self,prompt,
-                 data,precision_scope,
-                 batch_size,iterations,
-                 steps,seed,cfg_scale,ddim_eta,
-                 skip_normalize,
-                 width,height,
-                 callback):    # the callback is called each time a new Image is generated
-        """
-        Generate an image from the prompt, writing iteration images into the outdir
-        The output is a list of lists in the format: [[image1,seed1], [image2,seed2],...]
- """ + tic = time.time() + results = list() - sampler = self.sampler - images = list() - image_count = 0 - - # Gawd. Too many levels of indent here. Need to refactor into smaller routines! try: - with precision_scope(self.device.type), self.model.ema_scope(): - all_samples = list() + if init_img: + assert os.path.exists(init_img),f'{init_img}: File not found' + images_iterator = self._img2img(prompt, + precision_scope=scope, + batch_size=batch_size, + steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta, + skip_normalize=skip_normalize, + init_img=init_img,strength=strength) + else: + images_iterator = self._txt2img(prompt, + precision_scope=scope, + batch_size=batch_size, + steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta, + skip_normalize=skip_normalize, + width=width,height=height) + + with scope(self.device.type), self.model.ema_scope(): for n in trange(iterations, desc="Sampling"): seed_everything(seed) - for prompts in tqdm(data, desc="data", dynamic_ncols=True): - uc = None - if cfg_scale != 1.0: - uc = self.model.get_learned_conditioning(batch_size * [""]) - if isinstance(prompts, tuple): - prompts = list(prompts) - - # weighted sub-prompts - subprompts,weights = T2I._split_weighted_subprompts(prompts[0]) - if len(subprompts) > 1: - # i dont know if this is correct.. but it works - c = torch.zeros_like(uc) - # get total weight for normalizing - totalWeight = sum(weights) - # normalize each "sub prompt" and add it - for i in range(0,len(subprompts)): - weight = weights[i] - if not skip_normalize: - weight = weight / totalWeight - c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight) - else: # just standard 1 prompt - c = self.model.get_learned_conditioning(prompts) - - shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor] - samples_ddim, _ = sampler.sample(S=steps, - conditioning=c, - batch_size=batch_size, - shape=shape, - verbose=False, - unconditional_guidance_scale=cfg_scale, - unconditional_conditioning=uc, - eta=ddim_eta) - - x_samples_ddim = self.model.decode_first_stage(samples_ddim) - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - for x_sample in x_samples_ddim: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') - image = Image.fromarray(x_sample.astype(np.uint8)) - images.append([image,seed]) - if callback is not None: - callback(image,seed) - + iter_images = next(images_iterator) + for image in iter_images: + results.append([image, seed]) + if image_callback is not None: + image_callback(image,seed) seed = self._new_seed() + except KeyboardInterrupt: print('*interrupted*') print('Partial results will be returned; if --grid was requested, nothing will be returned.') except RuntimeError as e: print(str(e)) + print('Are you sure your system has an adequate NVIDIA GPU?') + + toc = time.time() + print(f'{len(results)} images generated in',"%4.2fs"% (toc-tic)) + return results - return images - @torch.no_grad() - def _img2img(self,prompt, - data,precision_scope, - batch_size,iterations, - steps,seed,cfg_scale,ddim_eta, + def _txt2img(self, + prompt, + precision_scope, + batch_size, + steps,cfg_scale,ddim_eta, skip_normalize, - init_img,strength,variants, - callback): + width,height): """ - Generate an image from the prompt and the initial image, writing iteration images into the outdir - The output is a list of lists in the format: [[image,seed1], [image,seed2],...] + An infinite iterator of images from the prompt. 
+ """ + + sampler = self.sampler + + while True: + uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize) + shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor] + samples, _ = sampler.sample(S=steps, + conditioning=c, + batch_size=batch_size, + shape=shape, + verbose=False, + unconditional_guidance_scale=cfg_scale, + unconditional_conditioning=uc, + eta=ddim_eta) + yield self._samples_to_images(samples) + + @torch.no_grad() + def _img2img(self, + prompt, + precision_scope, + batch_size, + steps,cfg_scale,ddim_eta, + skip_normalize, + init_img,strength): + """ + An infinite iterator of images from the prompt and the initial image """ # PLMS sampler not supported yet, so ignore previous sampler @@ -374,62 +344,50 @@ The vast majority of these arguments default to reasonable values. init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(init_image)) # move to latent space sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False) - + t_enc = int(strength * steps) # print(f"target t_enc is {t_enc} steps") + + while True: + uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize) + + # encode (scaled latent) + z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device)) + # decode it + samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale, + unconditional_conditioning=uc,) + yield self._samples_to_images(samples) + + # TODO: does this actually need to run every loop? does anything in it vary by random seed? + def _get_uc_and_c(self, prompt, batch_size, skip_normalize): + + uc = self.model.get_learned_conditioning(batch_size * [""]) + + # weighted sub-prompts + subprompts,weights = T2I._split_weighted_subprompts(prompt) + if len(subprompts) > 1: + # i dont know if this is correct.. but it works + c = torch.zeros_like(uc) + # get total weight for normalizing + totalWeight = sum(weights) + # normalize each "sub prompt" and add it + for i in range(0,len(subprompts)): + weight = weights[i] + if not skip_normalize: + weight = weight / totalWeight + c = torch.add(c, self.model.get_learned_conditioning(batch_size * [subprompts[i]]), alpha=weight) + else: # just standard 1 prompt + c = self.model.get_learned_conditioning(batch_size * [prompt]) + return (uc, c) + + def _samples_to_images(self, samples): + x_samples = self.model.decode_first_stage(samples) + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) images = list() - - try: - with precision_scope(self.device.type), self.model.ema_scope(): - all_samples = list() - for n in trange(iterations, desc="Sampling"): - seed_everything(seed) - for prompts in tqdm(data, desc="data", dynamic_ncols=True): - uc = None - if cfg_scale != 1.0: - uc = self.model.get_learned_conditioning(batch_size * [""]) - if isinstance(prompts, tuple): - prompts = list(prompts) - - # weighted sub-prompts - subprompts,weights = T2I._split_weighted_subprompts(prompts[0]) - if len(subprompts) > 1: - # i dont know if this is correct.. 
-                            c = torch.zeros_like(uc)
-                            # get total weight for normalizing
-                            totalWeight = sum(weights)
-                            # normalize each "sub prompt" and add it
-                            for i in range(0,len(subprompts)):
-                                weight = weights[i]
-                                if not skip_normalize:
-                                    weight = weight / totalWeight
-                                c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
-                        else: # just standard 1 prompt
-                            c = self.model.get_learned_conditioning(prompts)
-
-                        # encode (scaled latent)
-                        z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
-                        # decode it
-                        samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
-                                                 unconditional_conditioning=uc,)
-
-                        x_samples = self.model.decode_first_stage(samples)
-                        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
-
-                        for x_sample in x_samples:
-                            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                            image = Image.fromarray(x_sample.astype(np.uint8))
-                            images.append([image,seed])
-                            if callback is not None:
-                                callback(image,seed)
-                    seed = self._new_seed()
-
-        except KeyboardInterrupt:
-            print('*interrupted*')
-            print('Partial results will be returned; if --grid was requested, nothing will be returned.')
-        except RuntimeError as e:
-            print("Oops! A runtime error has occurred. If this is unexpected, please copy-and-paste this stack trace and post it as an Issue to http://github.com/lstein/stable-diffusion")
-            traceback.print_exc()
+        for x_sample in x_samples:
+            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+            image = Image.fromarray(x_sample.astype(np.uint8))
+            images.append(image)
         return images

     def _new_seed(self):
@@ -476,7 +434,7 @@ The vast majority of these arguments default to reasonable values.
             print(msg)

         return self.model
-
+
     def _load_model_from_config(self, config, ckpt):
         print(f"Loading model from {ckpt}")
         pl_sd = torch.load(ckpt, map_location="cpu")
@@ -507,7 +465,7 @@ The vast majority of these arguments default to reasonable values.

     def _split_weighted_subprompts(text):
         """
-        grabs all text up to the first occurrence of ':' 
+        grabs all text up to the first occurrence of ':'
         uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
         if ':' has no value defined, defaults to 1.0
         repeats until no text remaining
@@ -523,7 +481,7 @@ The vast majority of these arguments default to reasonable values.
             remaining -= idx
             # remove from main text
             text = text[idx+1:]
-            # find value for weight 
+            # find value for weight
             if " " in text:
                 idx = text.index(" ") # first occurrence
             else: # no space, read to end
diff --git a/scripts/dream.py b/scripts/dream.py
index 24dac5b927..b0a31b63e0 100755
--- a/scripts/dream.py
+++ b/scripts/dream.py
@@ -252,7 +252,7 @@ def create_argv_parser():
                         '-o',
                         type=str,
                         default="outputs/img-samples",
-                        help="directory in which to place generated images and a log of prompts and seeds")
+                        help="directory in which to place generated images and a log of prompts and seeds (outputs/img-samples)")
     parser.add_argument('--embedding_path',
                         type=str,
                         help="Path to a pre-trained embedding manager checkpoint - can only be set on command line")
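
Note on the refactor above: _txt2img() and _img2img() become infinite generators that yield one batch of images per next() call, while prompt2image() keeps ownership of the iteration count, per-iteration seeding, and interrupt handling. The sketch below is illustrative only, not code from the patch: fake_batch_generator() and drive() are hypothetical stand-ins for the real sampler-backed generators and for prompt2image(), but they exercise the same driver/generator split and run as-is.

    import random

    def fake_batch_generator(prompt, batch_size):
        # Stand-in for the new _txt2img()/_img2img(): an infinite generator
        # that yields one batch of "images" per next() call.
        while True:
            yield [f'{prompt} [{i}]' for i in range(batch_size)]

    def drive(prompt, iterations=3, batch_size=2, image_callback=None):
        # Stand-in for the new prompt2image(): owns iteration count,
        # per-iteration seeding, and result/callback bookkeeping.
        results = []
        seed = 42
        images_iterator = fake_batch_generator(prompt, batch_size)
        try:
            for _ in range(iterations):
                random.seed(seed)                  # seed_everything(seed) in the patch
                for image in next(images_iterator):
                    results.append([image, seed])
                    if image_callback is not None:
                        image_callback(image, seed)
                seed = random.randrange(2**32)     # self._new_seed() in the patch
        except KeyboardInterrupt:
            pass                                   # partial results are still returned
        return results

    for image, seed in drive('an astronaut riding a horse'):
        print(image, seed)

One consequence of this split is that the iterators never touch the seed themselves, which is what makes the TODO above _get_uc_and_c() a fair question: only the sampling step depends on the random state.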
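Note on the weighted sub-prompts consolidated into _get_uc_and_c(): the conditioning is c = sum_i w_i * embed(p_i), with each w_i divided by sum(weights) unless skip_normalize is set. The following self-contained sketch approximates the parse-and-blend pipeline with plain lists; embed() and DIM are hypothetical stand-ins for model.get_learned_conditioning() and the embedding size, and the parser is a simplification of the loop in _split_weighted_subprompts(), not the patch's exact code.

    DIM = 4

    def embed(prompt):
        # Hypothetical stand-in for model.get_learned_conditioning():
        # a deterministic toy "embedding" so the example runs anywhere.
        return [float(ord(ch)) for ch in (prompt + '    ')[:DIM]]

    def split_weighted_subprompts(text):
        # Parse "tabby cat:2 watercolor" -> (['tabby cat', 'watercolor'], [2.0, 1.0]).
        prompts, weights = [], []
        text = text.strip()
        while text:
            if ':' in text:
                prompt, _, rest = text.partition(':')
                weight_str, _, text = rest.partition(' ')  # weight runs to the next space
                try:
                    weight = float(weight_str)
                except ValueError:
                    weight = 1.0                           # missing/bad weight defaults to 1.0
            else:
                prompt, weight, text = text, 1.0, ''
            prompts.append(prompt.strip())
            weights.append(weight)
            text = text.strip()
        return prompts, weights

    def blend(subprompts, weights, skip_normalize=False):
        # c = sum_i w_i * embed(p_i); w_i is divided by sum(weights)
        # unless skip_normalize is set, mirroring _get_uc_and_c().
        total = sum(weights)
        c = [0.0] * DIM
        for p, w in zip(subprompts, weights):
            if not skip_normalize:
                w = w / total
            c = [ci + w * ei for ci, ei in zip(c, embed(p))]
        return c

    subprompts, weights = split_weighted_subprompts('tabby cat:2 watercolor')
    print(subprompts, weights)   # ['tabby cat', 'watercolor'] [2.0, 1.0]
    print(blend(subprompts, weights))

With the default normalization, 'tabby cat:2 watercolor' blends the two embeddings at weights 2/3 and 1/3; passing skip_normalize uses the raw 2.0 and 1.0 instead.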