resolve conflicts between PR #1108 and #1243

Author: Lincoln Stein
Date:   2022-10-26 15:37:24 -04:00
17 changed files with 444 additions and 38 deletions

View File

@@ -13,6 +13,13 @@ stable-diffusion-1.4:
     width: 512
     height: 512
     default: true
+inpainting-1.5:
+    description: runwayML tuned inpainting model v1.5
+    weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
+    config: configs/stable-diffusion/v1-inpainting-inference.yaml
+    # vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
+    width: 512
+    height: 512
 stable-diffusion-1.5:
     config: configs/stable-diffusion/v1-inference.yaml
     weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
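The new stanza is looked up by name like any other entry in the models file. A minimal sketch of reading it back with OmegaConf (not part of the commit; the models.yaml path and variable names are assumptions for illustration):

from omegaconf import OmegaConf

models = OmegaConf.load('configs/models.yaml')   # assumed location of the models file
entry  = models['inpainting-1.5']
print(entry.weights)   # models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
print(entry.config)    # configs/stable-diffusion/v1-inpainting-inference.yaml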

View File

@@ -0,0 +1,79 @@
model:
  base_learning_rate: 7.5e-05
  target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: hybrid   # important
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    finetune_keys: null

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    personalization_config:
      target: ldm.modules.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: ['face', 'man', 'photo', 'africanmale']
        per_image_tokens: false
        num_vectors_per_token: 1
        progressive_words: False

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 9  # 4 data + 4 downscaled image + 1 mask
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
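The key lines in this config are in_channels: 9 together with conditioning_key: hybrid: the UNet sees the 4-channel noisy latent concatenated with a 1-channel downscaled mask and the 4-channel latent of the masked image. A small sketch with dummy tensors (shapes only, no real model) showing the channel arithmetic:

import torch

h, w = 512 // 8, 512 // 8                  # latent-space resolution for a 512x512 image
x          = torch.randn(1, 4, h, w)        # noisy image latent
mask       = torch.randn(1, 1, h, w)        # inpainting mask, resized to latent resolution
masked_img = torch.randn(1, 4, h, w)        # VAE encoding of the masked image

c_concat = torch.cat([mask, masked_img], dim=1)   # the 'c_concat' half of the hybrid conditioning
xc = torch.cat([x, c_concat], dim=1)              # what the hybrid branch feeds the UNet
assert xc.shape[1] == 9                           # matches in_channels: 9 above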

View File

@@ -421,7 +421,10 @@ class Generate:
         )
         # TODO: Hacky selection of operation to perform. Needs to be refactored.
-        if (init_image is not None) and (mask_image is not None):
+        if self.sampler.conditioning_key() in ('hybrid','concat'):
+            print(f'** Inpainting model detected. Will try it! **')
+            generator = self._make_omnibus()
+        elif (init_image is not None) and (mask_image is not None):
             generator = self._make_inpaint()
         elif (embiggen != None or embiggen_tiles != None):
             generator = self._make_embiggen()
@@ -677,6 +680,7 @@ class Generate:
         return init_image,init_mask

+    # lots o' repeated code here! Turn into a make_func()
     def _make_base(self):
         if not self.generators.get('base'):
             from ldm.invoke.generator import Generator
@@ -687,6 +691,7 @@ class Generate:
         if not self.generators.get('img2img'):
             from ldm.invoke.generator.img2img import Img2Img
             self.generators['img2img'] = Img2Img(self.model, self.precision)
+            self.generators['img2img'].free_gpu_mem = self.free_gpu_mem
         return self.generators['img2img']

     def _make_embiggen(self):
@@ -715,6 +720,15 @@ class Generate:
             self.generators['inpaint'] = Inpaint(self.model, self.precision)
         return self.generators['inpaint']

+    # "omnibus" supports the runwayML custom inpainting model, which does
+    # txt2img, img2img and inpainting using slight variations on the same code
+    def _make_omnibus(self):
+        if not self.generators.get('omnibus'):
+            from ldm.invoke.generator.omnibus import Omnibus
+            self.generators['omnibus'] = Omnibus(self.model, self.precision)
+            self.generators['omnibus'].free_gpu_mem = self.free_gpu_mem
+        return self.generators['omnibus']
+
     def load_model(self):
         '''
         preload model identified in self.model_name
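The dispatch above keys off the loaded checkpoint rather than the request: any model whose conditioning key is 'hybrid' or 'concat' is routed through the new Omnibus generator for txt2img, img2img and inpainting alike. A hedged sketch of probing this on a Generate instance (gen is assumed to exist; conditioning_key() is the Sampler helper added later in this commit):

key = gen.sampler.conditioning_key()
if key in ('hybrid', 'concat'):
    print('9-channel inpainting checkpoint: Omnibus handles every request')
else:
    print('standard 4-channel checkpoint: txt2img/img2img/inpaint generators chosen per request')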

View File

@@ -181,7 +181,9 @@ class Args(object):
         switches_started = False
         for element in elements:
-            if element[0] == '-' and not switches_started:
+            if len(element) == 0: # empty prompt
+                pass
+            elif element[0] == '-' and not switches_started:
                 switches_started = True
             if switches_started:
                 switches.append(element)
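The extra branch guards against an empty prompt element: with the old code, element[0] raises IndexError on an empty string. A self-contained sketch of the same parsing idea (split_switches is a hypothetical helper, not the method in args.py; the prompt-accumulation branch is an assumption for illustration):

def split_switches(elements):
    prompt, switches, switches_started = [], [], False
    for element in elements:
        if len(element) == 0:        # empty prompt element: skip the '-' test
            pass
        elif element[0] == '-' and not switches_started:
            switches_started = True
        if switches_started:
            switches.append(element)
        else:
            prompt.append(element)
    return ' '.join(prompt), switches

print(split_switches(['', '-s50']))   # ('', ['-s50'])  -- no IndexError on the empty element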

View File

@@ -123,8 +123,8 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
     else:
         conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
     unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt, log_tokens=log_tokens)
+    conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
     return (
         unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
             cross_attention_control_args=cac_args
@@ -166,4 +166,25 @@ def get_tokens_length(model, fragments: list[Fragment]):
     tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
     return sum([len(x) for x in tokens])
+
+def flatten_hybrid_conditioning(uncond, cond):
+    '''
+    This handles the choice between a conditional conditioning
+    that is a tensor (used by cross attention) vs one that has additional
+    dimensions as well, as used by 'hybrid'
+    '''
+    if isinstance(cond, dict):
+        assert isinstance(uncond, dict)
+        cond_in = dict()
+        for k in cond:
+            if isinstance(cond[k], list):
+                cond_in[k] = [
+                    torch.cat([uncond[k][i], cond[k][i]])
+                    for i in range(len(cond[k]))
+                ]
+            else:
+                cond_in[k] = torch.cat([uncond[k], cond[k]])
+        return cond_in
+    else:
+        return cond
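A quick usage sketch with dummy tensors (not from the diff): for a hybrid model the conditioning is a dict, and each unconditioned/conditioned pair is concatenated so both passes can run in one batch; plain tensor conditioning passes through unchanged.

import torch

uncond = {'c_concat': [torch.zeros(1, 5, 64, 64)], 'c_crossattn': [torch.zeros(1, 77, 768)]}
cond   = {'c_concat': [torch.ones(1, 5, 64, 64)],  'c_crossattn': [torch.ones(1, 77, 768)]}

merged = flatten_hybrid_conditioning(uncond, cond)
print(merged['c_crossattn'][0].shape)   # torch.Size([2, 77, 768])

plain = flatten_hybrid_conditioning(torch.zeros(1, 77, 768), torch.ones(1, 77, 768))
print(plain.shape)                      # torch.Size([1, 77, 768]) -- returned unchanged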

View File

@@ -6,6 +6,7 @@ import torch
 import numpy as np
 import random
 import os
+import traceback
 from tqdm import tqdm, trange
 from PIL import Image, ImageFilter
 from einops import rearrange, repeat
@@ -43,7 +44,7 @@ class Generator():
         self.variation_amount = variation_amount
         self.with_variations = with_variations

-    def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
+    def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                  safety_checker:dict=None,
                  **kwargs):
@@ -51,6 +52,7 @@ class Generator():
         self.safety_checker = safety_checker
         make_image = self.get_make_image(
             prompt,
+            sampler = sampler,
             init_image = init_image,
             width = width,
             height = height,
@@ -59,12 +61,14 @@ class Generator():
             perlin = perlin,
             **kwargs
         )

         results = []
         seed = seed if seed is not None else self.new_seed()
         first_seed = seed
         seed, initial_noise = self.generate_initial_noise(seed, width, height)
-        with scope(self.model.device.type), self.model.ema_scope():
+
+        # There used to be an additional self.model.ema_scope() here, but it breaks
+        # the inpaint-1.5 model. Not sure what it did.... ?
+        with scope(self.model.device.type):
             for n in trange(iterations, desc='Generating'):
                 x_T = None
                 if self.variation_amount > 0:
@@ -79,7 +83,8 @@ class Generator():
                     try:
                         x_T = self.get_noise(width,height)
                     except:
-                        pass
+                        print('** An error occurred while getting initial noise **')
+                        print(traceback.format_exc())

                 image = make_image(x_T)
@@ -95,10 +100,10 @@ class Generator():
         return results

-    def sample_to_image(self,samples):
+    def sample_to_image(self,samples)->Image.Image:
         """
-        Returns a function returning an image derived from the prompt and the initial image
-        Return value depends on the seed at the time you call it
+        Given samples returned from a sampler, converts
+        it into a PIL Image
         """
         x_samples = self.model.decode_first_stage(samples)
         x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

View File

@@ -15,7 +15,7 @@ from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserCompo
 class Img2Img(Generator):
     def __init__(self, model, precision):
         super().__init__(model, precision)
         self.init_latent = None    # by get_noise()

     def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
                        conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
@@ -80,7 +80,10 @@ class Img2Img(Generator):
     def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
         image = np.array(image).astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
+        if len(image.shape) == 2:  # 'L' image, as in a mask
+            image = image[None,None]
+        else:                      # 'RGB' image
+            image = image[None].transpose(0, 3, 1, 2)
         image = torch.from_numpy(image)
         if normalize:
             image = 2.0 * image - 1.0
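A small sketch of what the branch above produces (in-memory PIL images, no files): an 'L' mask becomes a (1, 1, H, W) tensor while an RGB image becomes (1, 3, H, W), which is the layout the omnibus batch construction expects.

import numpy as np
from PIL import Image

def tensor_shape(image):
    arr = np.array(image).astype(np.float32) / 255.0
    if arr.ndim == 2:                       # 'L' image, as in a mask
        return arr[None, None].shape
    return arr[None].transpose(0, 3, 1, 2).shape

print(tensor_shape(Image.new('RGB', (64, 64))))   # (1, 3, 64, 64)
print(tensor_shape(Image.new('L',   (64, 64))))   # (1, 1, 64, 64)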

View File

@@ -0,0 +1,151 @@
"""omnibus module to be used with the runwayml 9-channel custom inpainting model"""

import torch
import numpy as np
from einops import repeat
from PIL import Image, ImageOps
from ldm.invoke.devices import choose_autocast
from ldm.invoke.generator.base import downsampling
from ldm.invoke.generator.img2img import Img2Img
from ldm.invoke.generator.txt2img import Txt2Img

class Omnibus(Img2Img,Txt2Img):
    def __init__(self, model, precision):
        super().__init__(model, precision)

    def get_make_image(
            self,
            prompt,
            sampler,
            steps,
            cfg_scale,
            ddim_eta,
            conditioning,
            width,
            height,
            init_image = None,
            mask_image = None,
            strength = None,
            step_callback=None,
            threshold=0.0,
            perlin=0.0,
            **kwargs):
        """
        Returns a function returning an image derived from the prompt and the initial image
        Return value depends on the seed at the time you call it.
        """
        self.perlin = perlin
        num_samples = 1

        sampler.make_schedule(
            ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
        )

        if isinstance(init_image, Image.Image):
            init_image = self._image_to_tensor(init_image)

        if isinstance(mask_image, Image.Image):
            mask_image = self._image_to_tensor(ImageOps.invert(mask_image).convert('L'),normalize=False)

        t_enc = steps

        if init_image is not None and mask_image is not None: # inpainting
            masked_image = init_image * (1 - mask_image)  # masked image is the image masked by mask - masked regions zero

        elif init_image is not None: # img2img
            scope = choose_autocast(self.precision)

            with scope(self.model.device.type):
                self.init_latent = self.model.get_first_stage_encoding(
                    self.model.encode_first_stage(init_image)
                ) # move to latent space

            # create a completely black mask (1s)
            mask_image = torch.ones(1, 1, init_image.shape[2], init_image.shape[3], device=self.model.device)
            # and the masked image is just a copy of the original
            masked_image = init_image

        else: # txt2img
            init_image = torch.zeros(1, 3, height, width, device=self.model.device)
            mask_image = torch.ones(1, 1, height, width, device=self.model.device)
            masked_image = init_image

        self.init_latent = init_image
        height = init_image.shape[2]
        width = init_image.shape[3]
        model = self.model

        def make_image(x_T):
            with torch.no_grad():
                scope = choose_autocast(self.precision)
                with scope(self.model.device.type):

                    batch = self.make_batch_sd(
                        init_image,
                        mask_image,
                        masked_image,
                        prompt=prompt,
                        device=model.device,
                        num_samples=num_samples,
                    )

                    c = model.cond_stage_model.encode(batch["txt"])
                    c_cat = list()
                    for ck in model.concat_keys:
                        cc = batch[ck].float()
                        if ck != model.masked_image_key:
                            bchw = [num_samples, 4, height//8, width//8]
                            cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
                        else:
                            cc = model.get_first_stage_encoding(model.encode_first_stage(cc))
                        c_cat.append(cc)
                    c_cat = torch.cat(c_cat, dim=1)

                    # cond
                    cond = {"c_concat": [c_cat], "c_crossattn": [c]}

                    # uncond cond
                    uc_cross = model.get_unconditional_conditioning(num_samples, "")
                    uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
                    shape = [model.channels, height//8, width//8]

                    samples, _ = sampler.sample(
                        batch_size = 1,
                        S = steps,
                        x_T = x_T,
                        conditioning = cond,
                        shape = shape,
                        verbose = False,
                        unconditional_guidance_scale = cfg_scale,
                        unconditional_conditioning = uc_full,
                        eta = 1.0,
                        img_callback = step_callback,
                        threshold = threshold,
                    )
                    if self.free_gpu_mem:
                        self.model.model.to("cpu")

            return self.sample_to_image(samples)

        return make_image

    def make_batch_sd(
            self,
            image,
            mask,
            masked_image,
            prompt,
            device,
            num_samples=1):
        batch = {
            "image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
            "txt": num_samples * [prompt],
            "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
            "masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
        }
        return batch

    def get_noise(self, width:int, height:int):
        if self.init_latent is not None:
            height = self.init_latent.shape[2]
            width = self.init_latent.shape[3]
        return Txt2Img.get_noise(self,width,height)
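For the txt2img fallback in get_make_image() the "inpainting" inputs are synthesized: a zero image and an all-ones mask, so the model effectively inpaints the whole canvas. A sketch of the resulting batch with dummy tensors (stand-ins only; no model or VAE involved):

import torch
from einops import repeat

height = width = 512
init_image   = torch.zeros(1, 3, height, width)
mask_image   = torch.ones(1, 1, height, width)
masked_image = init_image

batch = {
    "image":        repeat(init_image,   "1 ... -> n ...", n=1),
    "txt":          1 * ["a fantasy landscape"],
    "mask":         repeat(mask_image,   "1 ... -> n ...", n=1),
    "masked_image": repeat(masked_image, "1 ... -> n ...", n=1),
}
print(batch["image"].shape, batch["mask"].shape)   # torch.Size([1, 3, 512, 512]) torch.Size([1, 1, 512, 512])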

View File

@@ -13,6 +13,7 @@ import gc
 import hashlib
 import psutil
 import transformers
+import traceback
 import os
 from sys import getrefcount
 from omegaconf import OmegaConf
@@ -73,6 +74,7 @@ class ModelCache(object):
                 self.models[model_name]['hash'] = hash
             except Exception as e:
                 print(f'** model {model_name} could not be loaded: {str(e)}')
+                print(traceback.format_exc())
                 print(f'** restoring {self.current_model}')
                 self.get_model(self.current_model)
                 return None

View File

@@ -89,6 +89,9 @@ class Outcrop(object):
     def _extend(self,image:Image,pixels:int)-> Image:
         extended_img = Image.new('RGBA',(image.width,image.height+pixels))

+        mask_height = pixels if self.generate.model.model.conditioning_key in ('hybrid','concat') \
+                      else pixels *2
+
         # first paste places old image at top of extended image, stretch
         # it, and applies a gaussian blur to it
         # take the top half region, stretch and paste it
@@ -105,7 +108,9 @@ class Outcrop(object):
         # now make the top part transparent to use as a mask
         alpha = extended_img.getchannel('A')
-        alpha.paste(0,(0,0,extended_img.width,pixels*2))
+        alpha.paste(0,(0,0,extended_img.width,mask_height))
         extended_img.putalpha(alpha)

+        extended_img.save('outputs/curly_extended.png')
         return extended_img

View File

@@ -66,7 +66,7 @@ class VQModel(pl.LightningModule):
         self.use_ema = use_ema
         if self.use_ema:
             self.model_ema = LitEma(self)
-            print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
+            print(f'>> Keeping EMAs of {len(list(self.model_ema.buffers()))}.')

         if ckpt_path is not None:
             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

View File

@@ -53,12 +53,14 @@ class DDIMSampler(Sampler):
             # damian0815 would like to know when/if this code path is used
             e_t = self.model.apply_model(x, t, c)
         else:
+            # step_index counts in the opposite direction to index
             step_index = step_count-(index+1)
-            e_t = self.invokeai_diffuser.do_diffusion_step(x, t,
-                                                           unconditional_conditioning, c,
-                                                           unconditional_guidance_scale,
-                                                           step_index=step_index)
+            e_t = self.invokeai_diffuser.do_diffusion_step(
+                x, t,
+                unconditional_conditioning, c,
+                unconditional_guidance_scale,
+                step_index=step_index
+            )
         if score_corrector is not None:
             assert self.model.parameterization == 'eps'
             e_t = score_corrector.modify_score(

View File

@@ -19,6 +19,7 @@ from functools import partial
 from tqdm import tqdm
 from torchvision.utils import make_grid
 from pytorch_lightning.utilities.distributed import rank_zero_only
+from omegaconf import ListConfig
 import urllib

 from ldm.util import (
@@ -120,7 +121,7 @@ class DDPM(pl.LightningModule):
         self.use_ema = use_ema
         if self.use_ema:
             self.model_ema = LitEma(self.model)
-            print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
+            print(f' | Keeping EMAs of {len(list(self.model_ema.buffers()))}.')

         self.use_scheduler = scheduler_config is not None
         if self.use_scheduler:
@@ -1883,6 +1884,24 @@ class LatentDiffusion(DDPM):
         return samples, intermediates

+    @torch.no_grad()
+    def get_unconditional_conditioning(self, batch_size, null_label=None):
+        if null_label is not None:
+            xc = null_label
+            if isinstance(xc, ListConfig):
+                xc = list(xc)
+            if isinstance(xc, dict) or isinstance(xc, list):
+                c = self.get_learned_conditioning(xc)
+            else:
+                if hasattr(xc, "to"):
+                    xc = xc.to(self.device)
+                c = self.get_learned_conditioning(xc)
+        else:
+            # todo: get null label from cond_stage_model
+            raise NotImplementedError()
+        c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
+        return c
+
     @torch.no_grad()
     def log_images(
         self,
@@ -2147,8 +2166,8 @@ class DiffusionWrapper(pl.LightningModule):
             cc = torch.cat(c_crossattn, 1)
             out = self.diffusion_model(x, t, context=cc)
         elif self.conditioning_key == 'hybrid':
-            xc = torch.cat([x] + c_concat, dim=1)
             cc = torch.cat(c_crossattn, 1)
+            xc = torch.cat([x] + c_concat, dim=1)
             out = self.diffusion_model(xc, t, context=cc)
         elif self.conditioning_key == 'adm':
             cc = c_crossattn[0]
@@ -2187,3 +2206,58 @@ class Layout2ImgDiffusion(LatentDiffusion):
         cond_img = torch.stack(bbox_imgs, dim=0)
         logs['bbox_image'] = cond_img
         return logs
+
+class LatentInpaintDiffusion(LatentDiffusion):
+    def __init__(
+        self,
+        concat_keys=("mask", "masked_image"),
+        masked_image_key="masked_image",
+        finetune_keys=None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.masked_image_key = masked_image_key
+        assert self.masked_image_key in concat_keys
+        self.concat_keys = concat_keys
+
+    @torch.no_grad()
+    def get_input(
+        self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
+    ):
+        # note: restricted to non-trainable encoders currently
+        assert (
+            not self.cond_stage_trainable
+        ), "trainable cond stages not yet supported for inpainting"
+        z, c, x, xrec, xc = super().get_input(
+            batch,
+            self.first_stage_key,
+            return_first_stage_outputs=True,
+            force_c_encode=True,
+            return_original_cond=True,
+            bs=bs,
+        )
+
+        assert exists(self.concat_keys)
+        c_cat = list()
+        for ck in self.concat_keys:
+            cc = (
+                rearrange(batch[ck], "b h w c -> b c h w")
+                .to(memory_format=torch.contiguous_format)
+                .float()
+            )
+            if bs is not None:
+                cc = cc[:bs]
+                cc = cc.to(self.device)
+            bchw = z.shape
+            if ck != self.masked_image_key:
+                cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+            else:
+                cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds

View File

@@ -23,9 +23,10 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):

 class CFGDenoiser(nn.Module):
-    def __init__(self, model, threshold = 0, warmup = 0):
+    def __init__(self, sampler, threshold = 0, warmup = 0):
         super().__init__()
-        self.inner_model = model
+        self.inner_model = sampler.model
+        self.sampler = sampler
         self.threshold = threshold
         self.warmup_max = warmup
         self.warmup = max(warmup / 10, 1)
@@ -43,10 +44,14 @@ class CFGDenoiser(nn.Module):
     def forward(self, x, sigma, uncond, cond, cond_scale):
-        next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
-
-        # apply threshold
+        if isinstance(cond,dict):   # hybrid model
+            x_in = torch.cat([x] * 2)
+            sigma_in = torch.cat([sigma] * 2)
+            cond_in = self.sampler.make_cond_in(uncond,cond)
+            uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+            next_x = uncond + (cond - uncond) * cond_scale
+        else:                       # cross attention model
+            next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
         if self.warmup < self.warmup_max:
             thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
             self.warmup += 1
@@ -56,8 +61,6 @@ class CFGDenoiser(nn.Module):
             thresh = self.threshold
         return cfg_apply_threshold(next_x, thresh)

 class KSampler(Sampler):
     def __init__(self, model, schedule='lms', device=None, **kwargs):
         denoiser = K.external.CompVisDenoiser(model)
@@ -286,3 +289,6 @@ class KSampler(Sampler):
         '''
         return self.model.inner_model.q_sample(x0,ts)
+
+    def conditioning_key(self)->str:
+        return self.model.inner_model.model.conditioning_key
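The hybrid branch in CFGDenoiser.forward() batches the unconditioned and conditioned passes together and then recombines them with the usual classifier-free guidance formula. A sketch of just that arithmetic with dummy tensors (no sampler or model involved):

import torch

cond_scale = 7.5
denoised = torch.randn(2, 4, 64, 64)     # stand-in for inner_model(x_in, sigma_in, cond=cond_in)
uncond, cond = denoised.chunk(2)
next_x = uncond + (cond - uncond) * cond_scale
print(next_x.shape)                      # torch.Size([1, 4, 64, 64])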

View File

@@ -14,9 +14,6 @@ class PLMSSampler(Sampler):
     def __init__(self, model, schedule='linear', device=None, **kwargs):
         super().__init__(model,schedule,model.num_timesteps, device)
-        self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
-                                                           model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))

     def prepare_to_sample(self, t_enc, **kwargs):
         super().prepare_to_sample(t_enc, **kwargs)
@@ -67,7 +64,6 @@ class PLMSSampler(Sampler):
                 unconditional_conditioning, c,
                 unconditional_guidance_scale,
                 step_index=step_index)
         if score_corrector is not None:
             assert self.model.parameterization == 'eps'
             e_t = score_corrector.modify_score(

View File

@@ -11,6 +11,7 @@ import numpy as np
 from tqdm import tqdm
 from functools import partial
 from ldm.invoke.devices import choose_torch_device
+from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from ldm.modules.diffusionmodules.util import (
     make_ddim_sampling_parameters,
@@ -26,6 +27,8 @@ class Sampler(object):
         self.ddpm_num_timesteps = steps
         self.schedule = schedule
         self.device = device or choose_torch_device()
+        self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
+                                                           model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))

     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
@@ -160,6 +163,18 @@ class Sampler(object):
         **kwargs,
     ):

+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                ctmp = conditioning[list(conditioning.keys())[0]]
+                while isinstance(ctmp, list):
+                    ctmp = ctmp[0]
+                cbs = ctmp.shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
         # check to see if make_schedule() has run, and if not, run it
         if self.ddim_timesteps is None:
             self.make_schedule(
@@ -196,7 +211,7 @@ class Sampler(object):
         )
         return samples, intermediates

-    #torch.no_grad()
+    @torch.no_grad()
     def do_sampling(
         self,
         cond,
@@ -257,6 +272,7 @@ class Sampler(object):
             )

             if mask is not None:
+                print('DEBUG: in masking routine')
                 assert x0 is not None
                 img_orig = self.model.q_sample(
                     x0, ts
@@ -313,7 +329,6 @@ class Sampler(object):
         all_timesteps_count = None,
         **kwargs
     ):
         timesteps = (
             np.arange(self.ddpm_num_timesteps)
             if use_original_steps
@@ -420,3 +435,27 @@ class Sampler(object):
         '''
         return self.model.q_sample(x0,ts)
+
+    def conditioning_key(self)->str:
+        return self.model.model.conditioning_key
+
+    # def make_cond_in(self, uncond, cond):
+    #     '''
+    #     This handles the choice between a conditional conditioning
+    #     that is a tensor (used by cross attention) vs one that is a dict
+    #     used by 'hybrid'
+    #     '''
+    #     if isinstance(cond, dict):
+    #         assert isinstance(uncond, dict)
+    #         cond_in = dict()
+    #         for k in cond:
+    #             if isinstance(cond[k], list):
+    #                 cond_in[k] = [
+    #                     torch.cat([uncond[k][i], cond[k][i]])
+    #                     for i in range(len(cond[k]))
+    #                 ]
+    #             else:
+    #                 cond_in[k] = torch.cat([uncond[k], cond[k]])
+    #     else:
+    #         cond_in = torch.cat([uncond, cond])
+    #     return cond_in

View File

@@ -171,9 +171,9 @@ def main_loop(gen, opt):
         except (OSError, AttributeError, KeyError):
             pass

-        if len(opt.prompt) == 0:
-            print('\nTry again with a prompt!')
-            continue
+        # if len(opt.prompt) == 0:
+        #     print('\nTry again with a prompt!')
+        #     continue

         # width and height are set by model if not specified
         if not opt.width: