start support for 1.5 inpainting model, not complete

This commit is contained in:
Lincoln Stein 2022-10-25 00:30:48 -04:00
parent 3081b6b7dd
commit 83a3cc9eb4
8 changed files with 145 additions and 14 deletions

View File

@@ -404,7 +404,10 @@ class Generate:
             )
         # TODO: Hacky selection of operation to perform. Needs to be refactored.
-        if (init_image is not None) and (mask_image is not None):
+        if self.sampler.conditioning_key() in ('hybrid','concat'):
+            print(f'** Inpainting model detected. Will try it! **')
+            generator = self._make_omnibus()
+        elif (init_image is not None) and (mask_image is not None):
             generator = self._make_inpaint()
         elif (embiggen != None or embiggen_tiles != None):
             generator = self._make_embiggen()
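For context: `conditioning_key` is the attribute of the LDM `DiffusionWrapper` that says how conditioning enters the UNet. Standard SD 1.x checkpoints use 'crossattn'; the 1.5 inpainting checkpoint is configured with 'hybrid' (cross-attention plus extra channels concatenated to the latent). A minimal sketch, assuming a loaded `LatentDiffusion` instance, of the check the new sampler method performs:

def is_inpainting_model(ldm_model) -> bool:
    # ldm_model.model is the DiffusionWrapper; its conditioning_key is
    # 'crossattn' for standard SD 1.x and 'hybrid' for sd-1.5-inpainting.
    return ldm_model.model.conditioning_key in ('hybrid', 'concat')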
@@ -690,6 +693,12 @@ class Generate:
             self.generators['inpaint'] = Inpaint(self.model, self.precision)
         return self.generators['inpaint']
 
+    def _make_omnibus(self):
+        if not self.generators.get('omnibus'):
+            from ldm.invoke.generator.omnibus import Omnibus
+            self.generators['omnibus'] = Omnibus(self.model, self.precision)
+        return self.generators['omnibus']
+
     def load_model(self):
         '''
         preload model identified in self.model_name

View File

@@ -40,12 +40,13 @@ class Generator():
         self.variation_amount = variation_amount
         self.with_variations = with_variations
 
-    def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
+    def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                  **kwargs):
         scope = choose_autocast(self.precision)
         make_image = self.get_make_image(
             prompt,
+            sampler = sampler,
             init_image = init_image,
             width = width,
             height = height,
@@ -54,13 +55,16 @@ class Generator():
             perlin = perlin,
             **kwargs
         )
 
         results = []
         seed = seed if seed is not None else self.new_seed()
         first_seed = seed
         seed, initial_noise = self.generate_initial_noise(seed, width, height)
-        with scope(self.model.device.type), self.model.ema_scope():
+        scope = (scope(self.model.device.type), self.model.ema_scope()) if sampler.conditioning_key() not in ('hybrid','concat') else scope(self.model.device.type)
+        with scope:
             for n in trange(iterations, desc='Generating'):
+                print('DEBUG: in iterations loop() called')
                 x_T = None
                 if self.variation_amount > 0:
                     seed_everything(seed)
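Note that the replacement line builds a tuple of context managers for non-inpainting models, and `with scope:` on a tuple fails at runtime (a tuple is not a context manager); the commit message's "not complete" presumably covers this. A sketch of one way to compose the managers conditionally, using `contextlib.ExitStack` (names mirror the diff; this is not the committed code):

from contextlib import ExitStack

def sampling_scope(model, precision_scope, sampler):
    # Compose a variable number of context managers: always the autocast
    # precision scope, and the EMA weight swap only for non-hybrid/concat
    # (non-inpainting) models, matching the intent of the diff above.
    stack = ExitStack()
    stack.enter_context(precision_scope(model.device.type))
    if sampler.conditioning_key() not in ('hybrid', 'concat'):
        stack.enter_context(model.ema_scope())
    return stack

# usage:
#   with sampling_scope(self.model, scope, sampler):
#       for n in trange(iterations, desc='Generating'):
#           ...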
@@ -75,7 +79,6 @@ class Generator():
                         x_T = self.get_noise(width,height)
                     except:
                         pass
-
                 image = make_image(x_T)
                 results.append([image, seed])
                 if image_callback is not None:
@@ -83,10 +86,10 @@ class Generator():
                 seed = self.new_seed()
         return results
 
-    def sample_to_image(self,samples):
+    def sample_to_image(self,samples)->Image.Image:
         """
-        Returns a function returning an image derived from the prompt and the initial image
-        Return value depends on the seed at the time you call it
+        Given samples returned from a sampler, converts
+        it into a PIL Image
         """
         x_samples = self.model.decode_first_stage(samples)
         x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
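The hunk truncates here; for reference, the usual tail of `sample_to_image()` converts the clamped [0,1] tensor to a PIL image. A sketch of that standard LDM post-processing (not part of this diff):

import numpy as np
import torch
from PIL import Image
from einops import rearrange

def tensor_to_pil(x_samples: torch.Tensor) -> Image.Image:
    # x_samples: (1, 3, H, W), values already clamped to [0, 1]
    x_sample = 255.0 * rearrange(x_samples[0].cpu().numpy(), 'c h w -> h w c')
    return Image.fromarray(x_sample.astype(np.uint8))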

View File

@@ -13,6 +13,7 @@ import gc
 import hashlib
 import psutil
 import transformers
+import traceback
 import os
 from sys import getrefcount
 from omegaconf import OmegaConf
@@ -73,6 +74,7 @@ class ModelCache(object):
             self.models[model_name]['hash'] = hash
         except Exception as e:
             print(f'** model {model_name} could not be loaded: {str(e)}')
+            print(traceback.format_exc())
             print(f'** restoring {self.current_model}')
             self.get_model(self.current_model)
             return None

View File

@@ -66,7 +66,7 @@ class VQModel(pl.LightningModule):
         self.use_ema = use_ema
         if self.use_ema:
             self.model_ema = LitEma(self)
-            print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
+            print(f'>> Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
 
         if ckpt_path is not None:
             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

View File

@@ -41,7 +41,19 @@ class DDIMSampler(Sampler):
         else:
             x_in = torch.cat([x] * 2)
             t_in = torch.cat([t] * 2)
-            c_in = torch.cat([unconditional_conditioning, c])
+            if isinstance(c, dict):
+                assert isinstance(unconditional_conditioning, dict)
+                c_in = dict()
+                for k in c:
+                    if isinstance(c[k], list):
+                        c_in[k] = [
+                            torch.cat([unconditional_conditioning[k][i], c[k][i]])
+                            for i in range(len(c[k]))
+                        ]
+                    else:
+                        c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
+            else:
+                c_in = torch.cat([unconditional_conditioning, c])
             e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
             e_t = e_t_uncond + unconditional_guidance_scale * (
                 e_t - e_t_uncond
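What the new dict branch buys: with hybrid conditioning, `c` and `unconditional_conditioning` are both dicts of tensor lists, and classifier-free guidance needs them concatenated key-by-key so that one batched `apply_model` call evaluates the unconditional and conditional halves together. A shape sketch (dimensions assume a 512x512 SD 1.x model):

import torch

# uncond / cond for a hybrid (inpainting) model: 5 concat channels
# (1 mask + 4 masked-image latent) and a 77x768 CLIP text embedding.
uc = {'c_concat': [torch.zeros(1, 5, 64, 64)],
      'c_crossattn': [torch.zeros(1, 77, 768)]}
c  = {'c_concat': [torch.ones(1, 5, 64, 64)],
      'c_crossattn': [torch.ones(1, 77, 768)]}

c_in = {k: [torch.cat([uc[k][i], c[k][i]]) for i in range(len(c[k]))]
        for k in c}
assert c_in['c_concat'][0].shape == (2, 5, 64, 64)  # uncond + cond halves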

View File

@@ -19,6 +19,7 @@ from functools import partial
 from tqdm import tqdm
 from torchvision.utils import make_grid
 from pytorch_lightning.utilities.distributed import rank_zero_only
+from omegaconf import ListConfig
 import urllib
 
 from ldm.util import (
@@ -120,7 +121,7 @@ class DDPM(pl.LightningModule):
         self.use_ema = use_ema
         if self.use_ema:
             self.model_ema = LitEma(self.model)
-            print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
+            print(f' | Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
 
         self.use_scheduler = scheduler_config is not None
         if self.use_scheduler:
@@ -1883,6 +1884,24 @@ class LatentDiffusion(DDPM):
         return samples, intermediates
 
+    @torch.no_grad()
+    def get_unconditional_conditioning(self, batch_size, null_label=None):
+        if null_label is not None:
+            xc = null_label
+            if isinstance(xc, ListConfig):
+                xc = list(xc)
+            if isinstance(xc, dict) or isinstance(xc, list):
+                c = self.get_learned_conditioning(xc)
+            else:
+                if hasattr(xc, "to"):
+                    xc = xc.to(self.device)
+                c = self.get_learned_conditioning(xc)
+        else:
+            # todo: get null label from cond_stage_model
+            raise NotImplementedError()
+        c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
+        return c
+
     @torch.no_grad()
     def log_images(
         self,
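A hypothetical use of the new helper, for building the unconditional embedding used by classifier-free guidance. This assumes the cond stage is the SD 1.x CLIP text encoder; passing the empty prompt as a one-element list keeps the leading batch dimension at 1, which the `repeat` above requires:

# uc: (4, 77, 768) for the SD 1.x text encoder -- one copy per batch item
uc = model.get_unconditional_conditioning(batch_size=4, null_label=[""])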
@@ -2138,6 +2157,7 @@ class DiffusionWrapper(pl.LightningModule):
         ]
 
     def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
+        print(f'DEBUG (ddpm) c_concat = {c_concat}')
         if self.conditioning_key is None:
             out = self.diffusion_model(x, t)
         elif self.conditioning_key == 'concat':
@@ -2147,8 +2167,8 @@ class DiffusionWrapper(pl.LightningModule):
             cc = torch.cat(c_crossattn, 1)
             out = self.diffusion_model(x, t, context=cc)
         elif self.conditioning_key == 'hybrid':
-            xc = torch.cat([x] + c_concat, dim=1)
             cc = torch.cat(c_crossattn, 1)
+            xc = torch.cat([x] + c_concat, dim=1)
             out = self.diffusion_model(xc, t, context=cc)
         elif self.conditioning_key == 'adm':
             cc = c_crossattn[0]
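The reordering here is cosmetic, but the 'hybrid' branch is the heart of 1.5 inpainting support: the inpainting UNet was trained with 9 input channels. A shape sketch of what this branch computes:

import torch

x        = torch.randn(1, 4, 64, 64)    # noisy latent
c_concat = [torch.randn(1, 5, 64, 64)]  # mask (1ch) + masked-image latent (4ch)
c_cross  = [torch.randn(1, 77, 768)]    # CLIP text embedding

cc = torch.cat(c_cross, 1)              # cross-attention context
xc = torch.cat([x] + c_concat, dim=1)   # (1, 9, 64, 64) UNet input
assert xc.shape[1] == 9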
@@ -2187,3 +2207,58 @@ class Layout2ImgDiffusion(LatentDiffusion):
         cond_img = torch.stack(bbox_imgs, dim=0)
         logs['bbox_image'] = cond_img
         return logs
+
+
+class LatentInpaintDiffusion(LatentDiffusion):
+    def __init__(
+        self,
+        concat_keys=("mask", "masked_image"),
+        masked_image_key="masked_image",
+        finetune_keys=None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.masked_image_key = masked_image_key
+        assert self.masked_image_key in concat_keys
+        self.concat_keys = concat_keys
+
+    @torch.no_grad()
+    def get_input(
+        self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
+    ):
+        # note: restricted to non-trainable encoders currently
+        assert (
+            not self.cond_stage_trainable
+        ), "trainable cond stages not yet supported for inpainting"
+        z, c, x, xrec, xc = super().get_input(
+            batch,
+            self.first_stage_key,
+            return_first_stage_outputs=True,
+            force_c_encode=True,
+            return_original_cond=True,
+            bs=bs,
+        )
+
+        assert exists(self.concat_keys)
+        c_cat = list()
+        for ck in self.concat_keys:
+            cc = (
+                rearrange(batch[ck], "b h w c -> b c h w")
+                .to(memory_format=torch.contiguous_format)
+                .float()
+            )
+            if bs is not None:
+                cc = cc[:bs]
+                cc = cc.to(self.device)
+            bchw = z.shape
+            if ck != self.masked_image_key:
+                cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+            else:
+                cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds
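A sketch of the batch this `get_input()` expects and the channel arithmetic it performs (key names come from `concat_keys` above; `'image'` as the first-stage key is the LDM default and an assumption here):

import torch

B, H, W = 1, 512, 512
batch = {
    'image':        torch.randn(B, H, W, 3),  # b h w c, as rearrange() expects
    'mask':         torch.rand(B, H, W, 1),   # region to inpaint
    'masked_image': torch.randn(B, H, W, 3),  # image with masked region zeroed
}
# Inside get_input():
#   'mask'         -> F.interpolate to the latent grid (H/8, W/8) -> 1 channel
#   'masked_image' -> encode_first_stage -> 4 latent channels
#   c_cat = torch.cat([...], dim=1)      -> 5 channels, passed as c_concat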

View File

@@ -281,3 +281,5 @@ class KSampler(Sampler):
         '''
         return self.model.inner_model.q_sample(x0,ts)
 
+    def conditioning_key(self)->str:
+        return self.model.inner_model.model.conditioning_key

View File

@@ -158,6 +158,18 @@ class Sampler(object):
         **kwargs,
     ):
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                ctmp = conditioning[list(conditioning.keys())[0]]
+                while isinstance(ctmp, list):
+                    ctmp = ctmp[0]
+                cbs = ctmp.shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
         # check to see if make_schedule() has run, and if not, run it
         if self.ddim_timesteps is None:
             self.make_schedule(
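A tiny illustration of the unwrap loop in that hunk: hybrid conditioning arrives as a dict of tensor lists, so the batch size must be read off the innermost tensor:

import torch

conditioning = {'c_crossattn': [torch.zeros(2, 77, 768)]}
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
    ctmp = ctmp[0]
assert ctmp.shape[0] == 2   # cbs, compared against batch_size above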
@@ -193,7 +205,7 @@ class Sampler(object):
         )
         return samples, intermediates
 
-    #torch.no_grad()
+    @torch.no_grad()
     def do_sampling(
         self,
         cond,
@@ -307,6 +319,19 @@ class Sampler(object):
         mask = None,
     ):
+        print(f'DEBUG(sampler): cond = {cond}')
+        if cond is not None:
+            if isinstance(cond, dict):
+                ctmp = cond[list(cond.keys())[0]]
+                while isinstance(ctmp, list):
+                    ctmp = ctmp[0]
+                cbs = ctmp.shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conds but batch-size is {batch_size}")
+            else:
+                if cond.shape[0] != batch_size:
+                    print(f"Warning: Got {cond.shape[0]} conditionings but batch-size is {batch_size}")
+
         timesteps = (
             np.arange(self.ddpm_num_timesteps)
             if use_original_steps
@@ -411,3 +436,6 @@ class Sampler(object):
         return self.model.inner_model.q_sample(x0,ts)
         '''
         return self.model.q_sample(x0,ts)
+
+    def conditioning_key(self)->str:
+        return self.model.model.conditioning_key
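Why the two accessors differ: a plain `Sampler` holds the `LatentDiffusion` directly, while `KSampler` wraps it in a k-diffusion denoiser whose `inner_model` is the `LatentDiffusion` (consistent with the `self.model.inner_model.q_sample` calls shown above); in both cases the `DiffusionWrapper` (`.model`) carries `conditioning_key`. A sketch of the assumed attribute chains:

# Sampler:   self.model                    -> LatentDiffusion
#            self.model.model              -> DiffusionWrapper (.conditioning_key)
# KSampler:  self.model                    -> k-diffusion denoiser wrapper
#            self.model.inner_model        -> LatentDiffusion
#            self.model.inner_model.model  -> DiffusionWrapper (.conditioning_key)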