From a10baf58082defe05311c2d90614c387651301bd Mon Sep 17 00:00:00 2001
From: Kevin Gibbons
Date: Thu, 25 Aug 2022 15:13:07 -0700
Subject: [PATCH 1/4] factor out exception handler

---
 ldm/simplet2i.py | 248 +++++++++++++++++++++++------------------------
 1 file changed, 124 insertions(+), 124 deletions(-)

diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py
index 0ec3d60d98..157f55fbcb 100644
--- a/ldm/simplet2i.py
+++ b/ldm/simplet2i.py
@@ -52,7 +52,7 @@ t2i = T2I(model = <path> // models/ldm/stable-diffusion-v1/model.ckpt
 # do the slow model initialization
 t2i.load_model()
 
-# Do the fast inference & image generation. Any options passed here 
+# Do the fast inference & image generation. Any options passed here
 # override the default values assigned during class initialization
 # Will call load_model() if the model was not previously loaded and so
 # may be slow at first.
@@ -70,7 +70,7 @@ results = t2i.prompt2png(prompt = "an astronaut riding a horse",
                          outdir = "./outputs/",
                          iterations = 3,
                          init_img = "./sketches/horse+rider.png")
-
+
 for row in results:
     print(f'filename={row[0]}')
     print(f'seed    ={row[1]}')
@@ -181,7 +181,7 @@ The vast majority of these arguments default to reasonable values.
         outdir = kwargs.get('outdir','outputs/img-samples')
         assert 'init_img' in kwargs,'call to img2img() must include the init_img argument'
         return self.prompt2png(prompt,outdir,**kwargs)
-
+
     def prompt2image(self,
                      # these are common
                      prompt,
@@ -219,7 +219,7 @@ The vast majority of these arguments default to reasonable values.
            callback                        // a function or method that will be called each time an image is generated
 
         To use the callback, define a function or method that receives two arguments, an Image object
-        and the seed. You can then do whatever you like with the image, including converting it to 
+        and the seed. You can then do whatever you like with the image, including converting it to
         different formats and manipulating it. For example:
 
             def process_image(image,seed):
@@ -252,28 +252,41 @@ The vast majority of these arguments default to reasonable values.
         data = [batch_size * [prompt]]
         scope = autocast if self.precision=="autocast" else nullcontext
-        tic = time.time()
-        if init_img:
-            assert os.path.exists(init_img),f'{init_img}: File not found'
-            results = self._img2img(prompt,
-                              data=data,precision_scope=scope,
-                              batch_size=batch_size,iterations=iterations,
-                              steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
-                              skip_normalize=skip_normalize,
-                              init_img=init_img,strength=strength,variants=variants,
-                              callback=image_callback)
-        else:
-            results = self._txt2img(prompt,
-                              data=data,precision_scope=scope,
-                              batch_size=batch_size,iterations=iterations,
-                              steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
-                              skip_normalize=skip_normalize,
-                              width=width,height=height,
-                              callback=image_callback)
+        tic = time.time()
+        results = list()
+        def prompt_callback(image, seed):
+            results.append([image, seed])
+            if image_callback is not None:
+                image_callback(image, seed)
+
+        try:
+            if init_img:
+                assert os.path.exists(init_img),f'{init_img}: File not found'
+                self._img2img(prompt,
+                              data=data,precision_scope=scope,
+                              batch_size=batch_size,iterations=iterations,
+                              steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
+                              skip_normalize=skip_normalize,
+                              init_img=init_img,strength=strength,variants=variants,
+                              callback=prompt_callback)
+            else:
+                self._txt2img(prompt,
+                              data=data,precision_scope=scope,
+                              batch_size=batch_size,iterations=iterations,
+                              steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
+                              skip_normalize=skip_normalize,
+                              width=width,height=height,
+                              callback=prompt_callback)
+        except KeyboardInterrupt:
+            print('*interrupted*')
+            print('Partial results will be returned; if --grid was requested, nothing will be returned.')
+        except RuntimeError as e:
+            print(str(e))
+
         toc = time.time()
         print(f'{len(results)} images generated in',"%4.2fs"% (toc-tic))
         return results
-
+
     @torch.no_grad()
     def _txt2img(self,prompt,
                  data,precision_scope,
@@ -292,62 +305,56 @@ The vast majority of these arguments default to reasonable values.
         image_count = 0
 
         # Gawd. Too many levels of indent here. Need to refactor into smaller routines!
-        try:
-            with precision_scope(self.device.type), self.model.ema_scope():
-                all_samples = list()
-                for n in trange(iterations, desc="Sampling"):
-                    seed_everything(seed)
-                    for prompts in tqdm(data, desc="data", dynamic_ncols=True):
-                        uc = None
-                        if cfg_scale != 1.0:
-                            uc = self.model.get_learned_conditioning(batch_size * [""])
-                        if isinstance(prompts, tuple):
-                            prompts = list(prompts)
+        with precision_scope(self.device.type), self.model.ema_scope():
+            all_samples = list()
+            for n in trange(iterations, desc="Sampling"):
+                seed_everything(seed)
+                for prompts in tqdm(data, desc="data", dynamic_ncols=True):
+                    uc = None
+                    if cfg_scale != 1.0:
+                        uc = self.model.get_learned_conditioning(batch_size * [""])
+                    if isinstance(prompts, tuple):
+                        prompts = list(prompts)
 
-                        # weighted sub-prompts
-                        subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
-                        if len(subprompts) > 1:
-                            # i dont know if this is correct.. but it works
-                            c = torch.zeros_like(uc)
-                            # get total weight for normalizing
-                            totalWeight = sum(weights)
-                            # normalize each "sub prompt" and add it
-                            for i in range(0,len(subprompts)):
-                                weight = weights[i]
-                                if not skip_normalize:
-                                    weight = weight / totalWeight
-                                c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
-                        else: # just standard 1 prompt
-                            c = self.model.get_learned_conditioning(prompts)
+                    # weighted sub-prompts
+                    subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
+                    if len(subprompts) > 1:
+                        # i dont know if this is correct.. but it works
+                        c = torch.zeros_like(uc)
+                        # get total weight for normalizing
+                        totalWeight = sum(weights)
+                        # normalize each "sub prompt" and add it
+                        for i in range(0,len(subprompts)):
+                            weight = weights[i]
+                            if not skip_normalize:
+                                weight = weight / totalWeight
+                            c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
+                    else: # just standard 1 prompt
+                        c = self.model.get_learned_conditioning(prompts)
 
-                        shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
-                        samples_ddim, _ = sampler.sample(S=steps,
-                                                         conditioning=c,
-                                                         batch_size=batch_size,
-                                                         shape=shape,
-                                                         verbose=False,
-                                                         unconditional_guidance_scale=cfg_scale,
-                                                         unconditional_conditioning=uc,
-                                                         eta=ddim_eta)
+                    shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
+                    samples_ddim, _ = sampler.sample(S=steps,
+                                                     conditioning=c,
+                                                     batch_size=batch_size,
+                                                     shape=shape,
+                                                     verbose=False,
+                                                     unconditional_guidance_scale=cfg_scale,
+                                                     unconditional_conditioning=uc,
+                                                     eta=ddim_eta)
 
-                        x_samples_ddim = self.model.decode_first_stage(samples_ddim)
-                        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
-                        for x_sample in x_samples_ddim:
-                            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                            image = Image.fromarray(x_sample.astype(np.uint8))
-                            images.append([image,seed])
-                            if callback is not None:
-                                callback(image,seed)
-
-                    seed = self._new_seed()
-        except KeyboardInterrupt:
-            print('*interrupted*')
-            print('Partial results will be returned; if --grid was requested, nothing will be returned.')
-        except RuntimeError as e:
-            print(str(e))
+                    x_samples_ddim = self.model.decode_first_stage(samples_ddim)
+                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+                    for x_sample in x_samples_ddim:
+                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                        image = Image.fromarray(x_sample.astype(np.uint8))
+                        images.append([image,seed])
+                        if callback is not None:
+                            callback(image,seed)
+
+                seed = self._new_seed()
 
         return images
-
+
     @torch.no_grad()
     def _img2img(self,prompt,
                  data,precision_scope,
@@ -374,62 +381,55 @@ The vast majority of these arguments default to reasonable values.
         init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(init_image))  # move to latent space
 
         sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
-
+
         t_enc = int(strength * steps)
         # print(f"target t_enc is {t_enc} steps")
         images = list()
 
-        try:
-            with precision_scope(self.device.type), self.model.ema_scope():
-                all_samples = list()
-                for n in trange(iterations, desc="Sampling"):
-                    seed_everything(seed)
-                    for prompts in tqdm(data, desc="data", dynamic_ncols=True):
-                        uc = None
-                        if cfg_scale != 1.0:
-                            uc = self.model.get_learned_conditioning(batch_size * [""])
-                        if isinstance(prompts, tuple):
-                            prompts = list(prompts)
+        with precision_scope(self.device.type), self.model.ema_scope():
+            all_samples = list()
+            for n in trange(iterations, desc="Sampling"):
+                seed_everything(seed)
+                for prompts in tqdm(data, desc="data", dynamic_ncols=True):
+                    uc = None
+                    if cfg_scale != 1.0:
+                        uc = self.model.get_learned_conditioning(batch_size * [""])
+                    if isinstance(prompts, tuple):
+                        prompts = list(prompts)
 
-                        # weighted sub-prompts
-                        subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
-                        if len(subprompts) > 1:
-                            # i dont know if this is correct.. but it works
-                            c = torch.zeros_like(uc)
-                            # get total weight for normalizing
-                            totalWeight = sum(weights)
-                            # normalize each "sub prompt" and add it
-                            for i in range(0,len(subprompts)):
-                                weight = weights[i]
-                                if not skip_normalize:
-                                    weight = weight / totalWeight
-                                c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
-                        else: # just standard 1 prompt
-                            c = self.model.get_learned_conditioning(prompts)
+                    # weighted sub-prompts
+                    subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
+                    if len(subprompts) > 1:
+                        # i dont know if this is correct.. but it works
+                        c = torch.zeros_like(uc)
+                        # get total weight for normalizing
+                        totalWeight = sum(weights)
+                        # normalize each "sub prompt" and add it
+                        for i in range(0,len(subprompts)):
+                            weight = weights[i]
+                            if not skip_normalize:
+                                weight = weight / totalWeight
+                            c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
+                    else: # just standard 1 prompt
+                        c = self.model.get_learned_conditioning(prompts)
 
-                        # encode (scaled latent)
-                        z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
-                        # decode it
-                        samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
-                                                 unconditional_conditioning=uc,)
+                    # encode (scaled latent)
+                    z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
+                    # decode it
+                    samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
+                                             unconditional_conditioning=uc,)
 
-                        x_samples = self.model.decode_first_stage(samples)
-                        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
+                    x_samples = self.model.decode_first_stage(samples)
+                    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
 
-                        for x_sample in x_samples:
-                            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                            image = Image.fromarray(x_sample.astype(np.uint8))
-                            images.append([image,seed])
-                            if callback is not None:
-                                callback(image,seed)
-                        seed = self._new_seed()
+                    for x_sample in x_samples:
+                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                        image = Image.fromarray(x_sample.astype(np.uint8))
+                        images.append([image,seed])
+                        if callback is not None:
+                            callback(image,seed)
+                    seed = self._new_seed()
 
-        except KeyboardInterrupt:
-            print('*interrupted*')
-            print('Partial results will be returned; if --grid was requested, nothing will be returned.')
-        except RuntimeError as e:
-            print("Oops! A runtime error has occurred. If this is unexpected, please copy-and-paste this stack trace and post it as an Issue to http://github.com/lstein/stable-diffusion")
-            traceback.print_exc()
         return images
 
     def _new_seed(self):
@@ -476,7 +476,7 @@ The vast majority of these arguments default to reasonable values.
             print(msg)
 
         return self.model
-
+
     def _load_model_from_config(self, config, ckpt):
         print(f"Loading model from {ckpt}")
         pl_sd = torch.load(ckpt, map_location="cpu")
@@ -507,7 +507,7 @@ The vast majority of these arguments default to reasonable values.
     def _split_weighted_subprompts(text):
         """
-        grabs all text up to the first occurrence of ':' 
+        grabs all text up to the first occurrence of ':'
         uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
         if ':' has no value defined, defaults to 1.0
         repeats until no text remaining
@@ -523,7 +523,7 @@ The vast majority of these arguments default to reasonable values.
             remaining -= idx
             # remove from main text
            text = text[idx+1:]
-            # find value for weight 
+            # find value for weight
            if " " in text:
                idx = text.index(" ") # first occurence
            else: # no space, read to end
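
The shape of the refactor in PATCH 1/4, restated as a minimal standalone sketch:
the try/except that used to be duplicated inside _txt2img and _img2img now lives
once in prompt2image, and a small closure accumulates results so a
KeyboardInterrupt still returns the images completed so far. The names
generate_all, generate_one, and collect below are illustrative stand-ins, not
names from the patch:

    import time

    def generate_all(iterations, generate_one, image_callback=None):
        results = []

        # every finished image lands in results immediately, so an
        # interrupt still leaves the completed ones available
        def collect(image, seed):
            results.append([image, seed])
            if image_callback is not None:
                image_callback(image, seed)

        tic = time.time()
        try:
            for seed in range(iterations):
                collect(generate_one(seed), seed)
        except KeyboardInterrupt:
            print('*interrupted*')
        toc = time.time()
        print(f'{len(results)} images generated in', '%4.2fs' % (toc - tic))
        return results
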
From 078859207df4a6149cb8cdf7d4d9b4bb1fef1ae6 Mon Sep 17 00:00:00 2001
From: Kevin Gibbons
Date: Thu, 25 Aug 2022 15:19:44 -0700
Subject: [PATCH 2/4] factor out loop

---
 ldm/simplet2i.py | 226 ++++++++++++++++++++---------------------------
 1 file changed, 94 insertions(+), 132 deletions(-)

diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py
index 157f55fbcb..31721906d7 100644
--- a/ldm/simplet2i.py
+++ b/ldm/simplet2i.py
@@ -216,7 +216,7 @@ The vast majority of these arguments default to reasonable values.
            strength                        // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
            ddim_eta                        // image randomness (eta=0.0 means the same seed always produces the same image)
            variants                        // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants
-           callback                        // a function or method that will be called each time an image is generated
+           image_callback                  // a function or method that will be called each time an image is generated
 
         To use the callback, define a function or method that receives two arguments, an Image object
         and the seed. You can then do whatever you like with the image, including converting it to
@@ -249,34 +249,40 @@ The vast majority of these arguments default to reasonable values.
         height = h
         width = w
 
-        data = [batch_size * [prompt]]
         scope = autocast if self.precision=="autocast" else nullcontext
 
         tic = time.time()
         results = list()
-        def prompt_callback(image, seed):
-            results.append([image, seed])
-            if image_callback is not None:
-                image_callback(image, seed)
 
         try:
             if init_img:
                 assert os.path.exists(init_img),f'{init_img}: File not found'
-                self._img2img(prompt,
-                              data=data,precision_scope=scope,
-                              batch_size=batch_size,iterations=iterations,
-                              steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
+                get_images = self._img2img(
+                              precision_scope=scope,
+                              batch_size=batch_size,
+                              steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
                               skip_normalize=skip_normalize,
-                              init_img=init_img,strength=strength,variants=variants,
-                              callback=prompt_callback)
+                              init_img=init_img,strength=strength)
             else:
-                self._txt2img(prompt,
-                              data=data,precision_scope=scope,
-                              batch_size=batch_size,iterations=iterations,
-                              steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
+                get_images = self._txt2img(
+                              precision_scope=scope,
+                              batch_size=batch_size,
+                              steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
                               skip_normalize=skip_normalize,
-                              width=width,height=height,
-                              callback=prompt_callback)
+                              width=width,height=height)
+
+            data = [batch_size * [prompt]]
+            with scope(self.device.type), self.model.ema_scope():
+                for n in trange(iterations, desc="Sampling"):
+                    seed_everything(seed)
+                    for prompts in tqdm(data, desc="data", dynamic_ncols=True):
+                        iter_images = get_images(prompts)
+                        for image in iter_images:
+                            results.append([image, seed])
+                            if image_callback is not None:
+                                image_callback(image,seed)
+                    seed = self._new_seed()
+
         except KeyboardInterrupt:
             print('*interrupted*')
             print('Partial results will be returned; if --grid was requested, nothing will be returned.')
@@ -288,84 +294,41 @@ The vast majority of these arguments default to reasonable values.
         return results
 
     @torch.no_grad()
-    def _txt2img(self,prompt,
-                 data,precision_scope,
-                 batch_size,iterations,
-                 steps,seed,cfg_scale,ddim_eta,
+    def _txt2img(self,
+                 precision_scope,
+                 batch_size,
+                 steps,cfg_scale,ddim_eta,
                  skip_normalize,
-                 width,height,
-                 callback):    # the callback is called each time a new Image is generated
+                 width,height):
         """
-        Generate an image from the prompt, writing iteration images into the outdir
-        The output is a list of lists in the format: [[image1,seed1], [image2,seed2],...]
+        Generate an image from the prompt
         """
-        sampler = self.sampler
-        images = list()
-        image_count = 0
+        sampler = self.sampler
 
-        # Gawd. Too many levels of indent here. Need to refactor into smaller routines!
-        with precision_scope(self.device.type), self.model.ema_scope():
-            all_samples = list()
-            for n in trange(iterations, desc="Sampling"):
-                seed_everything(seed)
-                for prompts in tqdm(data, desc="data", dynamic_ncols=True):
-                    uc = None
-                    if cfg_scale != 1.0:
-                        uc = self.model.get_learned_conditioning(batch_size * [""])
-                    if isinstance(prompts, tuple):
-                        prompts = list(prompts)
-
-                    # weighted sub-prompts
-                    subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
-                    if len(subprompts) > 1:
-                        # i dont know if this is correct.. but it works
-                        c = torch.zeros_like(uc)
-                        # get total weight for normalizing
-                        totalWeight = sum(weights)
-                        # normalize each "sub prompt" and add it
-                        for i in range(0,len(subprompts)):
-                            weight = weights[i]
-                            if not skip_normalize:
-                                weight = weight / totalWeight
-                            c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
-                    else: # just standard 1 prompt
-                        c = self.model.get_learned_conditioning(prompts)
-
-                    shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
-                    samples_ddim, _ = sampler.sample(S=steps,
-                                                     conditioning=c,
-                                                     batch_size=batch_size,
-                                                     shape=shape,
-                                                     verbose=False,
-                                                     unconditional_guidance_scale=cfg_scale,
-                                                     unconditional_conditioning=uc,
-                                                     eta=ddim_eta)
-
-                    x_samples_ddim = self.model.decode_first_stage(samples_ddim)
-                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
-                    for x_sample in x_samples_ddim:
-                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                        image = Image.fromarray(x_sample.astype(np.uint8))
-                        images.append([image,seed])
-                        if callback is not None:
-                            callback(image,seed)
-
-                seed = self._new_seed()
-
-        return images
+        def get_images(prompts):
+            uc, c = self._get_uc_and_c(prompts, batch_size, skip_normalize)
+            shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
+            samples, _ = sampler.sample(S=steps,
+                                        conditioning=c,
+                                        batch_size=batch_size,
+                                        shape=shape,
+                                        verbose=False,
+                                        unconditional_guidance_scale=cfg_scale,
+                                        unconditional_conditioning=uc,
+                                        eta=ddim_eta)
+            return self._samples_to_images(samples)
+        return get_images
 
     @torch.no_grad()
-    def _img2img(self,prompt,
-                 data,precision_scope,
-                 batch_size,iterations,
-                 steps,seed,cfg_scale,ddim_eta,
+    def _img2img(self,
+                 precision_scope,
+                 batch_size,
+                 steps,cfg_scale,ddim_eta,
                  skip_normalize,
-                 init_img,strength,variants,
-                 callback):
         """
-        Generate an image from the prompt and the initial image, writing iteration images into the outdir
-        The output is a list of lists in the format: [[image,seed1], [image,seed2],...]
+                 init_img,strength):
+        """
+        Generate an image from the prompt and the initial image
         """
 
         # PLMS sampler not supported yet, so ignore previous sampler
@@ -384,54 +347,53 @@ The vast majority of these arguments default to reasonable values.
 
         t_enc = int(strength * steps)
         # print(f"target t_enc is {t_enc} steps")
+
+        def get_images(prompts):
+            uc, c = self._get_uc_and_c(prompts, batch_size, skip_normalize)
+
+            # encode (scaled latent)
+            z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
+            # decode it
+            samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
+                                     unconditional_conditioning=uc,)
+            return self._samples_to_images(samples)
+        return get_images
+
+    # TODO: does this actually need to run every loop? does anything in it vary by random seed?
+    def _get_uc_and_c(self, prompts, batch_size, skip_normalize):
+        if isinstance(prompts, tuple):
+            prompts = list(prompts)
+
+        uc = self.model.get_learned_conditioning(batch_size * [""])
+
+        # weighted sub-prompts
+        subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
+        if len(subprompts) > 1:
+            # i dont know if this is correct.. but it works
+            c = torch.zeros_like(uc)
+            # get total weight for normalizing
+            totalWeight = sum(weights)
+            # normalize each "sub prompt" and add it
+            for i in range(0,len(subprompts)):
+                weight = weights[i]
+                if not skip_normalize:
+                    weight = weight / totalWeight
+                c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
+        else: # just standard 1 prompt
+            c = self.model.get_learned_conditioning(prompts)
+        return (uc, c)
+
+    def _samples_to_images(self, samples):
+        x_samples = self.model.decode_first_stage(samples)
+        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
         images = list()
-
-        with precision_scope(self.device.type), self.model.ema_scope():
-            all_samples = list()
-            for n in trange(iterations, desc="Sampling"):
-                seed_everything(seed)
-                for prompts in tqdm(data, desc="data", dynamic_ncols=True):
-                    uc = None
-                    if cfg_scale != 1.0:
-                        uc = self.model.get_learned_conditioning(batch_size * [""])
-                    if isinstance(prompts, tuple):
-                        prompts = list(prompts)
-
-                    # weighted sub-prompts
-                    subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
-                    if len(subprompts) > 1:
-                        # i dont know if this is correct.. but it works
-                        c = torch.zeros_like(uc)
-                        # get total weight for normalizing
-                        totalWeight = sum(weights)
-                        # normalize each "sub prompt" and add it
-                        for i in range(0,len(subprompts)):
-                            weight = weights[i]
-                            if not skip_normalize:
-                                weight = weight / totalWeight
-                            c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
-                    else: # just standard 1 prompt
-                        c = self.model.get_learned_conditioning(prompts)
-
-                    # encode (scaled latent)
-                    z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
-                    # decode it
-                    samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
-                                             unconditional_conditioning=uc,)
-
-                    x_samples = self.model.decode_first_stage(samples)
-                    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
-
-                    for x_sample in x_samples:
-                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                        image = Image.fromarray(x_sample.astype(np.uint8))
-                        images.append([image,seed])
-                        if callback is not None:
-                            callback(image,seed)
-                    seed = self._new_seed()
-
+        for x_sample in x_samples:
+            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+            image = Image.fromarray(x_sample.astype(np.uint8))
+            images.append(image)
         return images
 
+
     def _new_seed(self):
         self.seed = random.randrange(0,np.iinfo(np.uint32).max)
         return self.seed
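
PATCH 2/4 centralizes the conditioning math in _get_uc_and_c: uc is the
unconditioned embedding used for classifier-free guidance, and c is a weighted,
optionally normalized, sum of sub-prompt embeddings. A sketch of just that
combination, with a hypothetical embed() standing in for
model.get_learned_conditioning:

    import torch

    def combine_weighted(subprompts, weights, embed, skip_normalize=False):
        # weighted sum of sub-prompt embeddings; with normalization the
        # weights behave like fractions of the total conditioning
        total = sum(weights)
        c = torch.zeros_like(embed(subprompts[0]))
        for subprompt, weight in zip(subprompts, weights):
            if not skip_normalize:
                weight = weight / total
            c = torch.add(c, embed(subprompt), alpha=weight)
        return c
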
From 31b22e057d12daddd9a9b79c8d156288dfad3b95 Mon Sep 17 00:00:00 2001
From: Kevin Gibbons
Date: Thu, 25 Aug 2022 17:01:17 -0700
Subject: [PATCH 3/4] switch to generators

---
 ldm/simplet2i.py | 30 ++++++++++++++----------------
 1 file changed, 14 insertions(+), 16 deletions(-)

diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py
index 31721906d7..d63502831d 100644
--- a/ldm/simplet2i.py
+++ b/ldm/simplet2i.py
@@ -257,26 +257,25 @@ The vast majority of these arguments default to reasonable values.
         try:
             if init_img:
                 assert os.path.exists(init_img),f'{init_img}: File not found'
-                get_images = self._img2img(
+                images_iterator = self._img2img(prompt,
                               precision_scope=scope,
                               batch_size=batch_size,
                               steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
                               skip_normalize=skip_normalize,
                               init_img=init_img,strength=strength)
             else:
-                get_images = self._txt2img(
+                images_iterator = self._txt2img(prompt,
                               precision_scope=scope,
                               batch_size=batch_size,
                               steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
                               skip_normalize=skip_normalize,
                               width=width,height=height)
 
-            data = [batch_size * [prompt]]
             with scope(self.device.type), self.model.ema_scope():
                 for n in trange(iterations, desc="Sampling"):
                     seed_everything(seed)
-                    for prompts in tqdm(data, desc="data", dynamic_ncols=True):
-                        iter_images = get_images(prompts)
+                    for batch_item in tqdm(range(batch_size), desc="data", dynamic_ncols=True):
+                        iter_images = next(images_iterator)
                         for image in iter_images:
                             results.append([image, seed])
                             if image_callback is not None:
@@ -295,19 +294,20 @@ The vast majority of these arguments default to reasonable values.
 
     @torch.no_grad()
     def _txt2img(self,
+                 prompt,
                  precision_scope,
                  batch_size,
                  steps,cfg_scale,ddim_eta,
                  skip_normalize,
                  width,height):
         """
-        Generate an image from the prompt
+        An infinite iterator of images from the prompt.
         """
 
         sampler = self.sampler
 
-        def get_images(prompts):
-            uc, c = self._get_uc_and_c(prompts, batch_size, skip_normalize)
+        while True:
+            uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize)
             shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
             samples, _ = sampler.sample(S=steps,
                                         conditioning=c,
                                         batch_size=batch_size,
                                         shape=shape,
                                         verbose=False,
                                         unconditional_guidance_scale=cfg_scale,
                                         unconditional_conditioning=uc,
                                         eta=ddim_eta)
-            return self._samples_to_images(samples)
-        return get_images
+            yield self._samples_to_images(samples)
 
     @torch.no_grad()
     def _img2img(self,
+                 prompt,
                  precision_scope,
                  batch_size,
                  steps,cfg_scale,ddim_eta,
                  skip_normalize,
                  init_img,strength):
         """
-        Generate an image from the prompt and the initial image
+        An infinite iterator of images from the prompt and the initial image
         """
 
         # PLMS sampler not supported yet, so ignore previous sampler
@@ -348,16 +348,15 @@ The vast majority of these arguments default to reasonable values.
 
         t_enc = int(strength * steps)
         # print(f"target t_enc is {t_enc} steps")
 
-        def get_images(prompts):
-            uc, c = self._get_uc_and_c(prompts, batch_size, skip_normalize)
+        while True:
+            uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize)
 
             # encode (scaled latent)
             z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
             # decode it
             samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
                                      unconditional_conditioning=uc,)
-            return self._samples_to_images(samples)
-        return get_images
+            yield self._samples_to_images(samples)
 
     # TODO: does this actually need to run every loop? does anything in it vary by random seed?
     def _get_uc_and_c(self, prompts, batch_size, skip_normalize):
@@ -393,7 +392,6 @@ The vast majority of these arguments default to reasonable values.
             images.append(image)
         return images
-
     def _new_seed(self):
         self.seed = random.randrange(0,np.iinfo(np.uint32).max)
         return self.seed
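
PATCH 3/4 swaps the returned get_images closures for infinite generators, so
each while True ... yield body runs only when the caller asks for the next
batch, and the loop in prompt2image keeps control of the iteration count and
seeds. The same protocol in miniature, with a stand-in sample_batch function
(the prompt string and lambda below are only for demonstration):

    def image_batches(prompt, sample_batch):
        # an infinite generator: work happens only when next() is called
        while True:
            yield sample_batch(prompt)

    batches = image_batches('an astronaut riding a horse',
                            lambda p: [f'<image of {p}>'])
    for _ in range(3):  # like the iterations loop in prompt2image
        for image in next(batches):
            print(image)
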
From 797de3257cf3507dba8672f4387995821a42adcc Mon Sep 17 00:00:00 2001
From: Kevin Gibbons
Date: Thu, 25 Aug 2022 17:16:07 -0700
Subject: [PATCH 4/4] fix batch_size

---
 ldm/simplet2i.py | 21 +++++++++------------
 1 file changed, 9 insertions(+), 12 deletions(-)

diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py
index d63502831d..3187bee090 100644
--- a/ldm/simplet2i.py
+++ b/ldm/simplet2i.py
@@ -274,12 +274,11 @@ The vast majority of these arguments default to reasonable values.
             with scope(self.device.type), self.model.ema_scope():
                 for n in trange(iterations, desc="Sampling"):
                     seed_everything(seed)
-                    for batch_item in tqdm(range(batch_size), desc="data", dynamic_ncols=True):
-                        iter_images = next(images_iterator)
-                        for image in iter_images:
-                            results.append([image, seed])
-                            if image_callback is not None:
-                                image_callback(image,seed)
+                    iter_images = next(images_iterator)
+                    for image in iter_images:
+                        results.append([image, seed])
+                        if image_callback is not None:
+                            image_callback(image,seed)
                     seed = self._new_seed()
 
         except KeyboardInterrupt:
@@ -359,14 +358,12 @@ The vast majority of these arguments default to reasonable values.
 
     # TODO: does this actually need to run every loop? does anything in it vary by random seed?
-    def _get_uc_and_c(self, prompts, batch_size, skip_normalize):
-        if isinstance(prompts, tuple):
-            prompts = list(prompts)
-
+    def _get_uc_and_c(self, prompt, batch_size, skip_normalize):
         uc = self.model.get_learned_conditioning(batch_size * [""])
 
         # weighted sub-prompts
-        subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
+        subprompts,weights = T2I._split_weighted_subprompts(prompt)
         if len(subprompts) > 1:
             # i dont know if this is correct.. but it works
             c = torch.zeros_like(uc)
             # get total weight for normalizing
             totalWeight = sum(weights)
             # normalize each "sub prompt" and add it
             for i in range(0,len(subprompts)):
                 weight = weights[i]
                 if not skip_normalize:
                     weight = weight / totalWeight
-                c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
+                c = torch.add(c, self.model.get_learned_conditioning(batch_size * [subprompts[i]]), alpha=weight)
         else: # just standard 1 prompt
-            c = self.model.get_learned_conditioning(prompts)
+            c = self.model.get_learned_conditioning(batch_size * [prompt])
         return (uc, c)
 
     def _samples_to_images(self, samples):
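
The batch_size fix in PATCH 4/4 is at bottom a shape rule: the conditioning
passed to the sampler needs one row per image in the batch, hence
batch_size * [prompt] and batch_size * [subprompts[i]] in the final version of
_get_uc_and_c. A toy illustration of the list-replication idiom involved; the
values below are only for demonstration:

    batch_size = 3
    prompt = 'an astronaut riding a horse'

    # batch_size * [prompt] repeats the prompt once per batch item, giving
    # the embedded conditioning a leading dimension that matches the batch
    data = batch_size * [prompt]
    assert data == [prompt, prompt, prompt]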