Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
@@ -13,6 +13,13 @@ stable-diffusion-1.4:
     width: 512
     height: 512
     default: true
+inpainting-1.5:
+    description: runwayML tuned inpainting model v1.5
+    weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
+    config: configs/stable-diffusion/v1-inpainting-inference.yaml
+#    vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
+    width: 512
+    height: 512
 stable-diffusion-1.5:
     config: configs/stable-diffusion/v1-inference.yaml
     weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
configs/stable-diffusion/v1-inpainting-inference.yaml (new file, 79 lines)
@@ -0,0 +1,79 @@
+model:
+  base_learning_rate: 7.5e-05
+  target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false   # Note: different from the one we trained before
+    conditioning_key: hybrid   # important
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    finetune_keys: null
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    personalization_config:
+      target: ldm.modules.embedding_manager.EmbeddingManager
+      params:
+        placeholder_strings: ["*"]
+        initializer_words: ['face', 'man', 'photo', 'africanmale']
+        per_image_tokens: false
+        num_vectors_per_token: 1
+        progressive_words: False
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 9  # 4 data + 4 downscaled image + 1 mask
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
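For orientation, here is a minimal sketch (not part of this diff) of how an LDM-style config such as the one above is normally consumed: OmegaConf loads the YAML and instantiate_from_config builds the object named in `target`. Loading checkpoint weights is a separate step and is omitted; the attribute access at the end mirrors the model.model.conditioning_key lookups used elsewhere in this PR.

    # hedged sketch; assumes the InvokeAI/ldm source tree is importable
    from omegaconf import OmegaConf
    from ldm.util import instantiate_from_config

    config = OmegaConf.load('configs/stable-diffusion/v1-inpainting-inference.yaml')
    model = instantiate_from_config(config.model)        # -> LatentInpaintDiffusion (weights not yet loaded)
    print(model.model.conditioning_key)                  # expect 'hybrid' for this config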
@@ -421,7 +421,10 @@ class Generate:
            )

            # TODO: Hacky selection of operation to perform. Needs to be refactored.
-            if (init_image is not None) and (mask_image is not None):
+            if self.sampler.conditioning_key() in ('hybrid','concat'):
+                print(f'** Inpainting model detected. Will try it! **')
+                generator = self._make_omnibus()
+            elif (init_image is not None) and (mask_image is not None):
                generator = self._make_inpaint()
            elif (embiggen != None or embiggen_tiles != None):
                generator = self._make_embiggen()
@@ -677,6 +680,7 @@ class Generate:

        return init_image,init_mask

+    # lots o' repeated code here! Turn into a make_func()
    def _make_base(self):
        if not self.generators.get('base'):
            from ldm.invoke.generator import Generator
@@ -687,6 +691,7 @@ class Generate:
        if not self.generators.get('img2img'):
            from ldm.invoke.generator.img2img import Img2Img
            self.generators['img2img'] = Img2Img(self.model, self.precision)
+            self.generators['img2img'].free_gpu_mem = self.free_gpu_mem
        return self.generators['img2img']

    def _make_embiggen(self):
@@ -715,6 +720,15 @@ class Generate:
            self.generators['inpaint'] = Inpaint(self.model, self.precision)
        return self.generators['inpaint']

+    # "omnibus" supports the runwayML custom inpainting model, which does
+    # txt2img, img2img and inpainting using slight variations on the same code
+    def _make_omnibus(self):
+        if not self.generators.get('omnibus'):
+            from ldm.invoke.generator.omnibus import Omnibus
+            self.generators['omnibus'] = Omnibus(self.model, self.precision)
+            self.generators['omnibus'].free_gpu_mem = self.free_gpu_mem
+        return self.generators['omnibus']
+
    def load_model(self):
        '''
        preload model identified in self.model_name
@@ -181,7 +181,9 @@ class Args(object):
        switches_started = False

        for element in elements:
-            if element[0] == '-' and not switches_started:
+            if len(element) == 0: # empty prompt
+                pass
+            elif element[0] == '-' and not switches_started:
                switches_started = True
            if switches_started:
                switches.append(element)
@@ -123,8 +123,8 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
    else:
        conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
-

    unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt, log_tokens=log_tokens)
+    conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
    return (
        unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
            cross_attention_control_args=cac_args
@@ -166,4 +166,25 @@ def get_tokens_length(model, fragments: list[Fragment]):
    tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
    return sum([len(x) for x in tokens])

+
+def flatten_hybrid_conditioning(uncond, cond):
+    '''
+    This handles the choice between a conditional conditioning
+    that is a tensor (used by cross attention) vs one that has additional
+    dimensions as well, as used by 'hybrid'
+    '''
+    if isinstance(cond, dict):
+        assert isinstance(uncond, dict)
+        cond_in = dict()
+        for k in cond:
+            if isinstance(cond[k], list):
+                cond_in[k] = [
+                    torch.cat([uncond[k][i], cond[k][i]])
+                    for i in range(len(cond[k]))
+                ]
+            else:
+                cond_in[k] = torch.cat([uncond[k], cond[k]])
+        return cond_in
+    else:
+        return cond
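A hedged illustration (not taken from the diff) of what flatten_hybrid_conditioning does with the dict-style conditioning produced by the hybrid inpainting model; the tensor shapes are invented for the example, and the function added in the hunk above is assumed to be in scope.

    import torch

    uncond = {"c_concat": [torch.zeros(1, 5, 64, 64)], "c_crossattn": [torch.zeros(1, 77, 768)]}
    cond   = {"c_concat": [torch.ones(1, 5, 64, 64)],  "c_crossattn": [torch.ones(1, 77, 768)]}

    flat = flatten_hybrid_conditioning(uncond, cond)
    # each list entry now stacks the unconditioned and conditioned rows into one batch of 2
    assert flat["c_crossattn"][0].shape[0] == 2

    # a plain tensor conditioning (cross-attention-only models) is passed through unchanged
    plain = flatten_hybrid_conditioning(torch.zeros(1, 77, 768), torch.ones(1, 77, 768))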
@@ -6,6 +6,7 @@ import torch
import numpy as np
import random
import os
+import traceback
from tqdm import tqdm, trange
from PIL import Image, ImageFilter
from einops import rearrange, repeat
@@ -43,7 +44,7 @@ class Generator():
        self.variation_amount = variation_amount
        self.with_variations = with_variations

-    def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
+    def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
                 image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                 safety_checker:dict=None,
                 **kwargs):
@@ -51,6 +52,7 @@ class Generator():
        self.safety_checker = safety_checker
        make_image = self.get_make_image(
            prompt,
+            sampler = sampler,
            init_image = init_image,
            width = width,
            height = height,
@@ -59,12 +61,14 @@ class Generator():
            perlin = perlin,
            **kwargs
        )

        results = []
        seed = seed if seed is not None else self.new_seed()
        first_seed = seed
        seed, initial_noise = self.generate_initial_noise(seed, width, height)
-        with scope(self.model.device.type), self.model.ema_scope():
+        # There used to be an additional self.model.ema_scope() here, but it breaks
+        # the inpaint-1.5 model. Not sure what it did.... ?
+        with scope(self.model.device.type):
            for n in trange(iterations, desc='Generating'):
                x_T = None
                if self.variation_amount > 0:
@@ -79,7 +83,8 @@ class Generator():
                    try:
                        x_T = self.get_noise(width,height)
                    except:
-                        pass
+                        print('** An error occurred while getting initial noise **')
+                        print(traceback.format_exc())

                image = make_image(x_T)

@@ -95,10 +100,10 @@ class Generator():

        return results

-    def sample_to_image(self,samples):
+    def sample_to_image(self,samples)->Image.Image:
        """
-        Returns a function returning an image derived from the prompt and the initial image
-        Return value depends on the seed at the time you call it
+        Given samples returned from a sampler, converts
+        it into a PIL Image
        """
        x_samples = self.model.decode_first_stage(samples)
        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
@@ -15,7 +15,7 @@ from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserCompo
class Img2Img(Generator):
    def __init__(self, model, precision):
        super().__init__(model, precision)
        self.init_latent = None    # by get_noise()

    def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
                       conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
@@ -80,7 +80,10 @@ class Img2Img(Generator):

    def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
        image = np.array(image).astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
+        if len(image.shape) == 2:  # 'L' image, as in a mask
+            image = image[None,None]
+        else:                      # 'RGB' image
+            image = image[None].transpose(0, 3, 1, 2)
        image = torch.from_numpy(image)
        if normalize:
            image = 2.0 * image - 1.0
ldm/invoke/generator/omnibus.py (new file, 151 lines)
@@ -0,0 +1,151 @@
+"""omnibus module to be used with the runwayml 9-channel custom inpainting model"""
+
+import torch
+import numpy as np
+from einops import repeat
+from PIL import Image, ImageOps
+from ldm.invoke.devices import choose_autocast
+from ldm.invoke.generator.base import downsampling
+from ldm.invoke.generator.img2img import Img2Img
+from ldm.invoke.generator.txt2img import Txt2Img
+
+class Omnibus(Img2Img,Txt2Img):
+    def __init__(self, model, precision):
+        super().__init__(model, precision)
+
+    def get_make_image(
+            self,
+            prompt,
+            sampler,
+            steps,
+            cfg_scale,
+            ddim_eta,
+            conditioning,
+            width,
+            height,
+            init_image = None,
+            mask_image = None,
+            strength = None,
+            step_callback=None,
+            threshold=0.0,
+            perlin=0.0,
+            **kwargs):
+        """
+        Returns a function returning an image derived from the prompt and the initial image
+        Return value depends on the seed at the time you call it.
+        """
+        self.perlin = perlin
+        num_samples = 1
+
+        sampler.make_schedule(
+            ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
+        )
+
+        if isinstance(init_image, Image.Image):
+            init_image = self._image_to_tensor(init_image)
+
+        if isinstance(mask_image, Image.Image):
+            mask_image = self._image_to_tensor(ImageOps.invert(mask_image).convert('L'),normalize=False)
+
+        t_enc = steps
+
+        if init_image is not None and mask_image is not None: # inpainting
+            masked_image = init_image * (1 - mask_image)  # masked image is the image masked by mask - masked regions zero
+
+        elif init_image is not None: # img2img
+            scope = choose_autocast(self.precision)
+
+            with scope(self.model.device.type):
+                self.init_latent = self.model.get_first_stage_encoding(
+                    self.model.encode_first_stage(init_image)
+                ) # move to latent space
+
+            # create a completely black mask (1s)
+            mask_image = torch.ones(1, 1, init_image.shape[2], init_image.shape[3], device=self.model.device)
+            # and the masked image is just a copy of the original
+            masked_image = init_image
+
+        else: # txt2img
+            init_image = torch.zeros(1, 3, height, width, device=self.model.device)
+            mask_image = torch.ones(1, 1, height, width, device=self.model.device)
+            masked_image = init_image
+
+        self.init_latent = init_image
+        height = init_image.shape[2]
+        width = init_image.shape[3]
+        model = self.model
+
+        def make_image(x_T):
+            with torch.no_grad():
+                scope = choose_autocast(self.precision)
+                with scope(self.model.device.type):
+
+                    batch = self.make_batch_sd(
+                        init_image,
+                        mask_image,
+                        masked_image,
+                        prompt=prompt,
+                        device=model.device,
+                        num_samples=num_samples,
+                    )
+
+                    c = model.cond_stage_model.encode(batch["txt"])
+                    c_cat = list()
+                    for ck in model.concat_keys:
+                        cc = batch[ck].float()
+                        if ck != model.masked_image_key:
+                            bchw = [num_samples, 4, height//8, width//8]
+                            cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+                        else:
+                            cc = model.get_first_stage_encoding(model.encode_first_stage(cc))
+                        c_cat.append(cc)
+                    c_cat = torch.cat(c_cat, dim=1)
+
+                    # cond
+                    cond={"c_concat": [c_cat], "c_crossattn": [c]}
+
+                    # uncond cond
+                    uc_cross = model.get_unconditional_conditioning(num_samples, "")
+                    uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
+                    shape = [model.channels, height//8, width//8]
+
+                    samples, _ = sampler.sample(
+                        batch_size = 1,
+                        S = steps,
+                        x_T = x_T,
+                        conditioning = cond,
+                        shape = shape,
+                        verbose = False,
+                        unconditional_guidance_scale = cfg_scale,
+                        unconditional_conditioning = uc_full,
+                        eta = 1.0,
+                        img_callback = step_callback,
+                        threshold = threshold,
+                    )
+                    if self.free_gpu_mem:
+                        self.model.model.to("cpu")
+            return self.sample_to_image(samples)
+
+        return make_image
+
+    def make_batch_sd(
+            self,
+            image,
+            mask,
+            masked_image,
+            prompt,
+            device,
+            num_samples=1):
+        batch = {
+            "image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
+            "txt": num_samples * [prompt],
+            "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
+            "masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
+        }
+        return batch
+
+    def get_noise(self, width:int, height:int):
+        if self.init_latent is not None:
+            height = self.init_latent.shape[2]
+            width = self.init_latent.shape[3]
+        return Txt2Img.get_noise(self,width,height)
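To make the data flow concrete, here is a hedged, self-contained sketch (not part of the diff) of why the inpainting UNet is configured with in_channels: 9. At every denoising step the 4-channel noisy latent is concatenated with the 5-channel c_concat block (1-channel downsampled mask plus the 4-channel VAE encoding of the masked image), which is what the 'hybrid' branch of DiffusionWrapper, patched later in this PR, feeds the UNet. Shapes assume a 512x512 image.

    import torch

    latent       = torch.randn(1, 4, 64, 64)   # noisy latent being denoised
    mask         = torch.ones(1, 1, 64, 64)    # mask, interpolated to latent resolution
    masked_image = torch.randn(1, 4, 64, 64)   # VAE encoding of the masked input image

    c_concat   = torch.cat([mask, masked_image], dim=1)   # 5 channels, order follows concat_keys
    unet_input = torch.cat([latent, c_concat], dim=1)     # what the 'hybrid' branch passes to the UNet
    assert unet_input.shape[1] == 9                        # matches in_channels: 9 in the YAML config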
@@ -13,6 +13,7 @@ import gc
import hashlib
import psutil
import transformers
+import traceback
import os
from sys import getrefcount
from omegaconf import OmegaConf
@@ -73,6 +74,7 @@ class ModelCache(object):
            self.models[model_name]['hash'] = hash
        except Exception as e:
            print(f'** model {model_name} could not be loaded: {str(e)}')
+            print(traceback.format_exc())
            print(f'** restoring {self.current_model}')
            self.get_model(self.current_model)
            return None
@@ -89,6 +89,9 @@ class Outcrop(object):
    def _extend(self,image:Image,pixels:int)-> Image:
        extended_img = Image.new('RGBA',(image.width,image.height+pixels))

+        mask_height = pixels if self.generate.model.model.conditioning_key in ('hybrid','concat') \
+            else pixels *2
+
        # first paste places old image at top of extended image, stretch
        # it, and applies a gaussian blur to it
        # take the top half region, stretch and paste it
@@ -105,7 +108,9 @@ class Outcrop(object):

        # now make the top part transparent to use as a mask
        alpha = extended_img.getchannel('A')
-        alpha.paste(0,(0,0,extended_img.width,pixels*2))
+        alpha.paste(0,(0,0,extended_img.width,mask_height))
        extended_img.putalpha(alpha)

+        extended_img.save('outputs/curly_extended.png')
+
        return extended_img
@@ -66,7 +66,7 @@ class VQModel(pl.LightningModule):
        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self)
-            print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
+            print(f'>> Keeping EMAs of {len(list(self.model_ema.buffers()))}.')

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
@@ -53,12 +53,14 @@ class DDIMSampler(Sampler):
            # damian0815 would like to know when/if this code path is used
            e_t = self.model.apply_model(x, t, c)
        else:
+            # step_index counts in the opposite direction to index
            step_index = step_count-(index+1)
-            e_t = self.invokeai_diffuser.do_diffusion_step(x, t,
-                                                           unconditional_conditioning, c,
-                                                           unconditional_guidance_scale,
-                                                           step_index=step_index)
+            e_t = self.invokeai_diffuser.do_diffusion_step(
+                x, t,
+                unconditional_conditioning, c,
+                unconditional_guidance_scale,
+                step_index=step_index
+            )
        if score_corrector is not None:
            assert self.model.parameterization == 'eps'
            e_t = score_corrector.modify_score(
@@ -19,6 +19,7 @@ from functools import partial
from tqdm import tqdm
from torchvision.utils import make_grid
from pytorch_lightning.utilities.distributed import rank_zero_only
+from omegaconf import ListConfig
import urllib

from ldm.util import (
@@ -120,7 +121,7 @@ class DDPM(pl.LightningModule):
        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self.model)
-            print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
+            print(f' | Keeping EMAs of {len(list(self.model_ema.buffers()))}.')

        self.use_scheduler = scheduler_config is not None
        if self.use_scheduler:
@@ -1883,6 +1884,24 @@ class LatentDiffusion(DDPM):

        return samples, intermediates

+    @torch.no_grad()
+    def get_unconditional_conditioning(self, batch_size, null_label=None):
+        if null_label is not None:
+            xc = null_label
+            if isinstance(xc, ListConfig):
+                xc = list(xc)
+            if isinstance(xc, dict) or isinstance(xc, list):
+                c = self.get_learned_conditioning(xc)
+            else:
+                if hasattr(xc, "to"):
+                    xc = xc.to(self.device)
+                c = self.get_learned_conditioning(xc)
+        else:
+            # todo: get null label from cond_stage_model
+            raise NotImplementedError()
+        c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
+        return c
+
    @torch.no_grad()
    def log_images(
        self,
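A hedged, stand-alone sketch (not from the diff) of the final repeat in get_unconditional_conditioning above: a single empty-prompt embedding is tiled out to the requested batch size. The (1, 77, 768) shape is the usual CLIP text-embedding shape and is only illustrative.

    import torch
    from einops import repeat

    c = torch.randn(1, 77, 768)              # stand-in for get_learned_conditioning("")
    batch_size = 4
    c = repeat(c, "1 ... -> b ...", b=batch_size)
    assert c.shape == (4, 77, 768)           # same null conditioning for every sample in the batch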
@@ -2147,8 +2166,8 @@ class DiffusionWrapper(pl.LightningModule):
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc)
        elif self.conditioning_key == 'hybrid':
-            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
+            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t, context=cc)
        elif self.conditioning_key == 'adm':
            cc = c_crossattn[0]
@@ -2187,3 +2206,58 @@ class Layout2ImgDiffusion(LatentDiffusion):
        cond_img = torch.stack(bbox_imgs, dim=0)
        logs['bbox_image'] = cond_img
        return logs
+
+
+class LatentInpaintDiffusion(LatentDiffusion):
+    def __init__(
+        self,
+        concat_keys=("mask", "masked_image"),
+        masked_image_key="masked_image",
+        finetune_keys=None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.masked_image_key = masked_image_key
+        assert self.masked_image_key in concat_keys
+        self.concat_keys = concat_keys
+
+    @torch.no_grad()
+    def get_input(
+        self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
+    ):
+        # note: restricted to non-trainable encoders currently
+        assert (
+            not self.cond_stage_trainable
+        ), "trainable cond stages not yet supported for inpainting"
+        z, c, x, xrec, xc = super().get_input(
+            batch,
+            self.first_stage_key,
+            return_first_stage_outputs=True,
+            force_c_encode=True,
+            return_original_cond=True,
+            bs=bs,
+        )
+
+        assert exists(self.concat_keys)
+        c_cat = list()
+        for ck in self.concat_keys:
+            cc = (
+                rearrange(batch[ck], "b h w c -> b c h w")
+                .to(memory_format=torch.contiguous_format)
+                .float()
+            )
+            if bs is not None:
+                cc = cc[:bs]
+                cc = cc.to(self.device)
+            bchw = z.shape
+            if ck != self.masked_image_key:
+                cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+            else:
+                cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds
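A hedged, stand-alone sketch (not from the diff) of the mask-resizing step used both in get_input above and in Omnibus.make_image: any c_concat input that is not the masked image itself is interpolated down to the latent grid, i.e. one eighth of the pixel resolution for these models.

    import torch

    mask = torch.ones(1, 1, 512, 512)                      # full-resolution inpainting mask
    latent_hw = (64, 64)                                   # z.shape[-2:] for a 512x512 image
    mask_small = torch.nn.functional.interpolate(mask, size=latent_hw)
    assert mask_small.shape == (1, 1, 64, 64)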
@@ -23,9 +23,10 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):


class CFGDenoiser(nn.Module):
-    def __init__(self, model, threshold = 0, warmup = 0):
+    def __init__(self, sampler, threshold = 0, warmup = 0):
        super().__init__()
-        self.inner_model = model
+        self.inner_model = sampler.model
+        self.sampler = sampler
        self.threshold = threshold
        self.warmup_max = warmup
        self.warmup = max(warmup / 10, 1)
@@ -43,10 +44,14 @@ class CFGDenoiser(nn.Module):


    def forward(self, x, sigma, uncond, cond, cond_scale):
-        next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
-
-        # apply threshold
+        if isinstance(cond,dict): # hybrid model
+            x_in = torch.cat([x] * 2)
+            sigma_in = torch.cat([sigma] * 2)
+            cond_in = self.sampler.make_cond_in(uncond,cond)
+            uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+            next_x = uncond + (cond - uncond) * cond_scale
+        else: # cross attention model
+            next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
        if self.warmup < self.warmup_max:
            thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
            self.warmup += 1
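A hedged numeric illustration (not in the diff) of the classifier-free guidance combination next_x = uncond + (cond - uncond) * cond_scale used by both branches above; the numbers are arbitrary.

    import torch

    uncond = torch.tensor([0.0, 1.0])   # denoiser output without the prompt
    cond   = torch.tensor([1.0, 3.0])   # denoiser output with the prompt
    cond_scale = 7.5

    next_x = uncond + (cond - uncond) * cond_scale
    print(next_x)    # tensor([ 7.5000, 16.0000]) -- pushed past the conditioned prediction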
@@ -56,8 +61,6 @@ class CFGDenoiser(nn.Module):
            thresh = self.threshold
        return cfg_apply_threshold(next_x, thresh)

-
-
class KSampler(Sampler):
    def __init__(self, model, schedule='lms', device=None, **kwargs):
        denoiser = K.external.CompVisDenoiser(model)
@@ -286,3 +289,6 @@ class KSampler(Sampler):
        '''
        return self.model.inner_model.q_sample(x0,ts)

+    def conditioning_key(self)->str:
+        return self.model.inner_model.model.conditioning_key
+
@@ -14,9 +14,6 @@ class PLMSSampler(Sampler):
    def __init__(self, model, schedule='linear', device=None, **kwargs):
        super().__init__(model,schedule,model.num_timesteps, device)
-
-        self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
-                                                           model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))

    def prepare_to_sample(self, t_enc, **kwargs):
        super().prepare_to_sample(t_enc, **kwargs)
@@ -67,7 +64,6 @@ class PLMSSampler(Sampler):
                                                           unconditional_conditioning, c,
                                                           unconditional_guidance_scale,
                                                           step_index=step_index)
-
        if score_corrector is not None:
            assert self.model.parameterization == 'eps'
            e_t = score_corrector.modify_score(
@@ -11,6 +11,7 @@ import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.invoke.devices import choose_torch_device
+from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent

from ldm.modules.diffusionmodules.util import (
    make_ddim_sampling_parameters,
@@ -26,6 +27,8 @@ class Sampler(object):
        self.ddpm_num_timesteps = steps
        self.schedule = schedule
        self.device = device or choose_torch_device()
+        self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
+                                                           model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))

    def register_buffer(self, name, attr):
        if type(attr) == torch.Tensor:
@@ -160,6 +163,18 @@ class Sampler(object):
        **kwargs,
    ):
+
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                ctmp = conditioning[list(conditioning.keys())[0]]
+                while isinstance(ctmp, list):
+                    ctmp = ctmp[0]
+                cbs = ctmp.shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
        # check to see if make_schedule() has run, and if not, run it
        if self.ddim_timesteps is None:
            self.make_schedule(
@@ -196,7 +211,7 @@ class Sampler(object):
        )
        return samples, intermediates

-    #torch.no_grad()
+    @torch.no_grad()
    def do_sampling(
        self,
        cond,
@@ -257,6 +272,7 @@ class Sampler(object):
        )

        if mask is not None:
+            print('DEBUG: in masking routine')
            assert x0 is not None
            img_orig = self.model.q_sample(
                x0, ts
@@ -313,7 +329,6 @@ class Sampler(object):
        all_timesteps_count = None,
        **kwargs
    ):
-
        timesteps = (
            np.arange(self.ddpm_num_timesteps)
            if use_original_steps
@@ -420,3 +435,27 @@ class Sampler(object):
        '''
        return self.model.q_sample(x0,ts)

+    def conditioning_key(self)->str:
+        return self.model.model.conditioning_key
+
+    # def make_cond_in(self, uncond, cond):
+    #     '''
+    #     This handles the choice between a conditional conditioning
+    #     that is a tensor (used by cross attention) vs one that is a dict
+    #     used by 'hybrid'
+    #     '''
+    #     if isinstance(cond, dict):
+    #         assert isinstance(uncond, dict)
+    #         cond_in = dict()
+    #         for k in cond:
+    #             if isinstance(cond[k], list):
+    #                 cond_in[k] = [
+    #                     torch.cat([uncond[k][i], cond[k][i]])
+    #                     for i in range(len(cond[k]))
+    #                 ]
+    #             else:
+    #                 cond_in[k] = torch.cat([uncond[k], cond[k]])
+    #     else:
+    #         cond_in = torch.cat([uncond, cond])
+    #     return cond_in
@@ -171,9 +171,9 @@ def main_loop(gen, opt):
        except (OSError, AttributeError, KeyError):
            pass

-        if len(opt.prompt) == 0:
-            print('\nTry again with a prompt!')
-            continue
+        # if len(opt.prompt) == 0:
+        #     print('\nTry again with a prompt!')
+        #     continue

        # width and height are set by model if not specified
        if not opt.width: