diff --git a/ldm/generate.py b/ldm/generate.py
index 8ffb7110a3..a834844e7e 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -404,7 +404,10 @@ class Generate:
             )

             # TODO: Hacky selection of operation to perform. Needs to be refactored.
-            if (init_image is not None) and (mask_image is not None):
+            if self.sampler.conditioning_key() in ('hybrid','concat'):
+                print('** Inpainting model detected. Will try it! **')
+                generator = self._make_omnibus()
+            elif (init_image is not None) and (mask_image is not None):
                 generator = self._make_inpaint()
             elif (embiggen != None or embiggen_tiles != None):
                 generator = self._make_embiggen()
@@ -690,6 +693,12 @@ class Generate:
             self.generators['inpaint'] = Inpaint(self.model, self.precision)
         return self.generators['inpaint']

+    def _make_omnibus(self):
+        if not self.generators.get('omnibus'):
+            from ldm.invoke.generator.omnibus import Omnibus
+            self.generators['omnibus'] = Omnibus(self.model, self.precision)
+        return self.generators['omnibus']
+
     def load_model(self):
         '''
         preload model identified in self.model_name
diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py
index 89476cd216..c70924449b 100644
--- a/ldm/invoke/generator/base.py
+++ b/ldm/invoke/generator/base.py
@@ -40,12 +40,13 @@ class Generator():
         self.variation_amount = variation_amount
         self.with_variations = with_variations

-    def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
+    def generate(self,prompt,init_image,width,height,sampler,iterations=1,seed=None,
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                  **kwargs):
         scope = choose_autocast(self.precision)
         make_image = self.get_make_image(
             prompt,
+            sampler       = sampler,
             init_image    = init_image,
             width         = width,
             height        = height,
@@ -54,13 +55,18 @@ class Generator():
             perlin        = perlin,
             **kwargs
         )
         results          = []
         seed   = seed if seed is not None else self.new_seed()
         first_seed = seed
         seed, initial_noise = self.generate_initial_noise(seed, width, height)
-        with scope(self.model.device.type), self.model.ema_scope():
+
+        # Inpainting (hybrid/concat-conditioned) models misbehave inside
+        # ema_scope(), so apply it to standard models only.
+        from contextlib import nullcontext
+        ema = self.model.ema_scope() if sampler.conditioning_key() not in ('hybrid','concat') else nullcontext()
+        with scope(self.model.device.type), ema:
             for n in trange(iterations, desc='Generating'):
                 x_T = None
                 if self.variation_amount > 0:
                     seed_everything(seed)
@@ -75,7 +81,6 @@ class Generator():
                     x_T = self.get_noise(width,height)
                 except:
                     pass
-
                 image = make_image(x_T)
                 results.append([image, seed])
                 if image_callback is not None:
@@ -83,10 +88,10 @@ class Generator():
             seed = self.new_seed()
         return results

-    def sample_to_image(self,samples):
+    def sample_to_image(self,samples)->Image.Image:
         """
-        Returns a function returning an image derived from the prompt and the initial image
-        Return value depends on the seed at the time you call it
+        Given samples returned from a sampler, converts
+        them into a PIL Image.
         """
         x_samples = self.model.decode_first_stage(samples)
         x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
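Note: the `ema_scope()` bypass above relies on `contextlib.nullcontext` as a no-op context manager for the inpainting path. A minimal standalone sketch of that pattern (stand-in names, not code from this patch):

```python
# Sketch of the conditional context-manager pattern used in Generator.generate().
# ema_scope() here is a stand-in for LatentDiffusion.ema_scope(), which swaps
# EMA weights in and out around sampling.
from contextlib import contextmanager, nullcontext

@contextmanager
def ema_scope():
    print('swapping in EMA weights')
    try:
        yield
    finally:
        print('restoring training weights')

def sample(conditioning_key: str):
    # hybrid/concat (inpainting) models skip the EMA swap entirely
    ema = ema_scope() if conditioning_key not in ('hybrid', 'concat') else nullcontext()
    with ema:
        print(f'sampling with conditioning_key={conditioning_key}')

sample('crossattn')  # runs inside the EMA scope
sample('hybrid')     # EMA scope skipped
```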
diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py
index f580dfba25..f972a9eb16 100644
--- a/ldm/invoke/model_cache.py
+++ b/ldm/invoke/model_cache.py
@@ -13,6 +13,7 @@ import gc
 import hashlib
 import psutil
 import transformers
+import traceback
 import os
 from sys import getrefcount
 from omegaconf import OmegaConf
@@ -73,6 +74,7 @@ class ModelCache(object):
                 self.models[model_name]['hash'] = hash
         except Exception as e:
             print(f'** model {model_name} could not be loaded: {str(e)}')
+            print(traceback.format_exc())
             print(f'** restoring {self.current_model}')
             self.get_model(self.current_model)
             return None
diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py
index 359f5688d1..3db7b6fd73 100644
--- a/ldm/models/autoencoder.py
+++ b/ldm/models/autoencoder.py
@@ -66,7 +66,7 @@ class VQModel(pl.LightningModule):
         self.use_ema = use_ema
         if self.use_ema:
             self.model_ema = LitEma(self)
-            print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
+            print(f'>> Keeping EMAs of {len(list(self.model_ema.buffers()))}.')

         if ckpt_path is not None:
             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
index f5dada8627..bfb78c1397 100644
--- a/ldm/models/diffusion/ddim.py
+++ b/ldm/models/diffusion/ddim.py
@@ -41,7 +41,19 @@ class DDIMSampler(Sampler):
         else:
             x_in = torch.cat([x] * 2)
             t_in = torch.cat([t] * 2)
-            c_in = torch.cat([unconditional_conditioning, c])
+            if isinstance(c, dict):
+                assert isinstance(unconditional_conditioning, dict)
+                c_in = dict()
+                for k in c:
+                    if isinstance(c[k], list):
+                        c_in[k] = [
+                            torch.cat([unconditional_conditioning[k][i], c[k][i]])
+                            for i in range(len(c[k]))
+                        ]
+                    else:
+                        c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
+            else:
+                c_in = torch.cat([unconditional_conditioning, c])
             e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
             e_t = e_t_uncond + unconditional_guidance_scale * (
                 e_t - e_t_uncond
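Note: the `DDIMSampler` change generalizes classifier-free-guidance batching from plain tensors to the `{'c_concat': [...], 'c_crossattn': [...]}` dicts that hybrid-conditioned models produce. A self-contained illustration with toy shapes (`merge_for_cfg` is an illustrative name, not part of the patch):

```python
import torch

def merge_for_cfg(uncond, cond):
    """Batch unconditional and conditional inputs so one U-Net call serves both."""
    if isinstance(cond, dict):
        assert isinstance(uncond, dict)
        merged = {}
        for k in cond:
            if isinstance(cond[k], list):
                merged[k] = [torch.cat([uncond[k][i], cond[k][i]])
                             for i in range(len(cond[k]))]
            else:
                merged[k] = torch.cat([uncond[k], cond[k]])
        return merged
    return torch.cat([uncond, cond])

# plain-tensor path (standard txt2img models)
c, uc = torch.randn(1, 77, 768), torch.randn(1, 77, 768)
print(merge_for_cfg(uc, c).shape)                 # torch.Size([2, 77, 768])

# dict path (inpainting models): 1 mask channel + 4 masked-image latent channels
c  = {'c_concat': [torch.randn(1, 5, 64, 64)], 'c_crossattn': [torch.randn(1, 77, 768)]}
uc = {'c_concat': [torch.randn(1, 5, 64, 64)], 'c_crossattn': [torch.randn(1, 77, 768)]}
print(merge_for_cfg(uc, c)['c_concat'][0].shape)  # torch.Size([2, 5, 64, 64])
```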
diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py
index 4b62b5e393..fd3b6688f3 100644
--- a/ldm/models/diffusion/ddpm.py
+++ b/ldm/models/diffusion/ddpm.py
@@ -19,6 +19,7 @@ from functools import partial
 from tqdm import tqdm
 from torchvision.utils import make_grid
 from pytorch_lightning.utilities.distributed import rank_zero_only
+from omegaconf import ListConfig
 import urllib

 from ldm.util import (
@@ -120,7 +121,7 @@ class DDPM(pl.LightningModule):
         self.use_ema = use_ema
         if self.use_ema:
             self.model_ema = LitEma(self.model)
-            print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
+            print(f'   | Keeping EMAs of {len(list(self.model_ema.buffers()))}.')

         self.use_scheduler = scheduler_config is not None
         if self.use_scheduler:
@@ -1883,6 +1884,24 @@ class LatentDiffusion(DDPM):

         return samples, intermediates

+    @torch.no_grad()
+    def get_unconditional_conditioning(self, batch_size, null_label=None):
+        if null_label is not None:
+            xc = null_label
+            if isinstance(xc, ListConfig):
+                xc = list(xc)
+            if isinstance(xc, dict) or isinstance(xc, list):
+                c = self.get_learned_conditioning(xc)
+            else:
+                if hasattr(xc, "to"):
+                    xc = xc.to(self.device)
+                c = self.get_learned_conditioning(xc)
+        else:
+            # todo: get null label from cond_stage_model
+            raise NotImplementedError()
+        c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
+        return c
+
     @torch.no_grad()
     def log_images(
         self,
@@ -2147,8 +2166,8 @@ class DiffusionWrapper(pl.LightningModule):
             cc = torch.cat(c_crossattn, 1)
             out = self.diffusion_model(x, t, context=cc)
         elif self.conditioning_key == 'hybrid':
-            xc = torch.cat([x] + c_concat, dim=1)
             cc = torch.cat(c_crossattn, 1)
+            xc = torch.cat([x] + c_concat, dim=1)
             out = self.diffusion_model(xc, t, context=cc)
         elif self.conditioning_key == 'adm':
             cc = c_crossattn[0]
@@ -2187,3 +2206,57 @@ class Layout2ImgDiffusion(LatentDiffusion):
         cond_img = torch.stack(bbox_imgs, dim=0)
         logs['bbox_image'] = cond_img
         return logs
+
+class LatentInpaintDiffusion(LatentDiffusion):
+    def __init__(
+        self,
+        concat_keys=("mask", "masked_image"),
+        masked_image_key="masked_image",
+        finetune_keys=None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.masked_image_key = masked_image_key
+        assert self.masked_image_key in concat_keys
+        self.concat_keys = concat_keys
+
+    @torch.no_grad()
+    def get_input(
+        self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
+    ):
+        # note: restricted to non-trainable encoders currently
+        assert (
+            not self.cond_stage_trainable
+        ), "trainable cond stages not yet supported for inpainting"
+        z, c, x, xrec, xc = super().get_input(
+            batch,
+            self.first_stage_key,
+            return_first_stage_outputs=True,
+            force_c_encode=True,
+            return_original_cond=True,
+            bs=bs,
+        )
+
+        assert exists(self.concat_keys)
+        c_cat = list()
+        for ck in self.concat_keys:
+            cc = (
+                rearrange(batch[ck], "b h w c -> b c h w")
+                .to(memory_format=torch.contiguous_format)
+                .float()
+            )
+            if bs is not None:
+                cc = cc[:bs]
+                cc = cc.to(self.device)
+            bchw = z.shape
+            if ck != self.masked_image_key:
+                cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+            else:
+                cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds
diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py
index ac0615b30c..68c26b5d6c 100644
--- a/ldm/models/diffusion/ksampler.py
+++ b/ldm/models/diffusion/ksampler.py
@@ -281,3 +281,5 @@ class KSampler(Sampler):
         '''
         return self.model.inner_model.q_sample(x0,ts)

+    def conditioning_key(self)->str:
+        return self.model.inner_model.model.conditioning_key
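Note: `LatentInpaintDiffusion.get_input()` feeds the sampler through the `'hybrid'` branch of `DiffusionWrapper.forward()`, where the mask and masked-image latents are stacked onto the noisy latent along the channel axis while the text embedding still enters via cross-attention. A shape-only sketch with standard SD dimensions (assumed, not taken from the patch):

```python
import torch

b, h, w = 1, 64, 64
x = torch.randn(b, 4, h, w)                # noisy latent from the scheduler
c_concat = [torch.cat([
    torch.ones(b, 1, h, w),                # inpainting mask, resized to latent size
    torch.randn(b, 4, h, w),               # VAE-encoded masked image
], dim=1)]
c_crossattn = [torch.randn(b, 77, 768)]    # CLIP text embedding

# the two torch.cat calls from the 'hybrid' branch above
cc = torch.cat(c_crossattn, 1)             # cross-attention context
xc = torch.cat([x] + c_concat, dim=1)      # (b, 9, h, w) U-Net input
print(xc.shape, cc.shape)  # torch.Size([1, 9, 64, 64]) torch.Size([1, 77, 768])
```

This nine-channel input is why the RunwayML inpainting checkpoint configures its U-Net with `in_channels: 9` rather than the usual 4.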
diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py
index ff705513f8..9e57bc25d4 100644
--- a/ldm/models/diffusion/sampler.py
+++ b/ldm/models/diffusion/sampler.py
@@ -158,6 +158,18 @@ class Sampler(object):
         **kwargs,
     ):

+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                ctmp = conditioning[list(conditioning.keys())[0]]
+                while isinstance(ctmp, list):
+                    ctmp = ctmp[0]
+                cbs = ctmp.shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
         # check to see if make_schedule() has run, and if not, run it
         if self.ddim_timesteps is None:
             self.make_schedule(
@@ -193,7 +205,7 @@ class Sampler(object):
         )
         return samples, intermediates

-    #torch.no_grad()
+    @torch.no_grad()
     def do_sampling(
         self,
         cond,
@@ -307,6 +319,18 @@ class Sampler(object):
         mask = None,
     ):

+        if cond is not None:
+            if isinstance(cond, dict):
+                ctmp = cond[list(cond.keys())[0]]
+                while isinstance(ctmp, list):
+                    ctmp = ctmp[0]
+                cbs = ctmp.shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+            else:
+                if cond.shape[0] != batch_size:
+                    print(f"Warning: Got {cond.shape[0]} conditionings but batch-size is {batch_size}")
+
         timesteps = (
             np.arange(self.ddpm_num_timesteps)
             if use_original_steps
@@ -411,3 +435,6 @@ class Sampler(object):
         return self.model.inner_model.q_sample(x0,ts)
         '''
         return self.model.q_sample(x0,ts)
+
+    def conditioning_key(self)->str:
+        return self.model.model.conditioning_key
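Note: the same batch-size sanity check now appears verbatim in both `sample()` and `do_sampling()`; it could be factored into a small helper. A sketch of that refactor (`conditioning_batch_size` is an illustrative name, not part of the patch):

```python
import torch

def conditioning_batch_size(cond):
    # dict conditioning nests tensors inside lists
    # ({'c_concat': [tensor], 'c_crossattn': [tensor]}), so drill
    # down to the first tensor before reading its batch dimension
    if isinstance(cond, dict):
        ctmp = cond[list(cond.keys())[0]]
        while isinstance(ctmp, list):
            ctmp = ctmp[0]
        return ctmp.shape[0]
    return cond.shape[0]

print(conditioning_batch_size(torch.randn(2, 77, 768)))                     # 2
print(conditioning_batch_size({'c_crossattn': [torch.randn(3, 77, 768)]}))  # 3
```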