mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

start support for 1.5 inpainting model, not complete

commit 83a3cc9eb4
parent 3081b6b7dd
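
For context: the 1.5 inpainting checkpoint targeted here is a 'concat'-conditioned model, meaning the UNet input is the noisy latent concatenated channel-wise with a downsampled mask and the VAE-encoded masked image (4 + 1 + 4 = 9 channels). A minimal sketch of that input assembly (shapes and names are illustrative, not taken from this commit):

    import torch
    import torch.nn.functional as F

    # Illustrative shapes for a 512x512 image and a 4-channel VAE latent space.
    B, C_lat, H, W = 1, 4, 64, 64

    z_noisy       = torch.randn(B, C_lat, H, W)   # noisy latent being denoised
    mask          = torch.ones(B, 1, 512, 512)    # pixel-space inpainting mask
    masked_latent = torch.randn(B, C_lat, H, W)   # VAE encoding of image * mask

    # The mask is resized to latent resolution before concatenation.
    mask_lat = F.interpolate(mask, size=(H, W))

    # 'concat' conditioning: stack along channels to build the 9-channel
    # input the inpainting UNet was trained on.
    unet_in = torch.cat([z_noisy, mask_lat, masked_latent], dim=1)
    assert unet_in.shape == (B, 9, H, W)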
@@ -404,7 +404,10 @@ class Generate:
             )
 
             # TODO: Hacky selection of operation to perform. Needs to be refactored.
-            if (init_image is not None) and (mask_image is not None):
+            if self.sampler.conditioning_key() in ('hybrid','concat'):
+                print(f'** Inpainting model detected. Will try it! **')
+                generator = self._make_omnibus()
+            elif (init_image is not None) and (mask_image is not None):
                 generator = self._make_inpaint()
             elif (embiggen != None or embiggen_tiles != None):
                 generator = self._make_embiggen()
@@ -690,6 +693,12 @@ class Generate:
             self.generators['inpaint'] = Inpaint(self.model, self.precision)
         return self.generators['inpaint']
 
+    def _make_omnibus(self):
+        if not self.generators.get('omnibus'):
+            from ldm.invoke.generator.omnibus import Omnibus
+            self.generators['omnibus'] = Omnibus(self.model, self.precision)
+        return self.generators['omnibus']
+
     def load_model(self):
         '''
         preload model identified in self.model_name
@@ -40,12 +40,13 @@ class Generator():
         self.variation_amount = variation_amount
         self.with_variations = with_variations
 
-    def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
+    def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                  **kwargs):
         scope = choose_autocast(self.precision)
         make_image = self.get_make_image(
             prompt,
+            sampler = sampler,
             init_image = init_image,
             width = width,
             height = height,
@@ -54,13 +55,16 @@ class Generator():
             perlin = perlin,
             **kwargs
         )
 
         results = []
         seed = seed if seed is not None else self.new_seed()
         first_seed = seed
         seed, initial_noise = self.generate_initial_noise(seed, width, height)
-        with scope(self.model.device.type), self.model.ema_scope():
+        scope = (scope(self.model.device.type), self.model.ema_scope()) if sampler.conditioning_key() not in ('hybrid','concat') else scope(self.model.device.type)
+
+        with scope:
             for n in trange(iterations, desc='Generating'):
+                print('DEBUG: in iterations loop() called')
                 x_T = None
                 if self.variation_amount > 0:
                     seed_everything(seed)
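
Note that on the non-inpainting path the new scope becomes a tuple of two context managers, which a bare `with scope:` cannot enter (a tuple has no __enter__); the commit message flags the work as incomplete. One conventional way to enter a variable collection of context managers is contextlib.ExitStack. A sketch, not part of this commit:

    from contextlib import ExitStack, nullcontext

    def run_under(scopes):
        # Enter zero or more context managers, then run the sampling loop.
        with ExitStack() as stack:
            for cm in scopes:
                stack.enter_context(cm)
            # ... iteration loop would run here ...

    # A single context manager or several can now be handled uniformly:
    run_under([nullcontext()])
    run_under([nullcontext(), nullcontext()])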
@@ -75,7 +79,6 @@ class Generator():
                         x_T = self.get_noise(width,height)
                     except:
                         pass
-
                 image = make_image(x_T)
                 results.append([image, seed])
                 if image_callback is not None:
@@ -83,10 +86,10 @@ class Generator():
                     seed = self.new_seed()
         return results
 
-    def sample_to_image(self,samples):
+    def sample_to_image(self,samples)->Image.Image:
         """
-        Returns a function returning an image derived from the prompt and the initial image
-        Return value depends on the seed at the time you call it
+        Given samples returned from a sampler, converts
+        it into a PIL Image
         """
         x_samples = self.model.decode_first_stage(samples)
         x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
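
The clamped [0, 1] tensor produced above is conventionally finished off by scaling to uint8 and wrapping in a PIL image; a self-contained sketch of that last step (the helper name is hypothetical):

    import torch
    from PIL import Image
    from einops import rearrange

    def to_pil(x_sample: torch.Tensor) -> Image.Image:
        # x_sample: (C, H, W) float tensor already clamped to [0, 1]
        x = 255.0 * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
        return Image.fromarray(x.astype('uint8'))

    img = to_pil(torch.rand(3, 64, 64))
    print(img.size)  # (64, 64)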
@@ -13,6 +13,7 @@ import gc
 import hashlib
 import psutil
 import transformers
+import traceback
 import os
 from sys import getrefcount
 from omegaconf import OmegaConf
@@ -73,6 +74,7 @@ class ModelCache(object):
                 self.models[model_name]['hash'] = hash
         except Exception as e:
             print(f'** model {model_name} could not be loaded: {str(e)}')
+            print(traceback.format_exc())
             print(f'** restoring {self.current_model}')
             self.get_model(self.current_model)
             return None
@@ -66,7 +66,7 @@ class VQModel(pl.LightningModule):
         self.use_ema = use_ema
         if self.use_ema:
             self.model_ema = LitEma(self)
-            print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
+            print(f'>> Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
 
         if ckpt_path is not None:
             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
@@ -41,7 +41,19 @@ class DDIMSampler(Sampler):
         else:
             x_in = torch.cat([x] * 2)
             t_in = torch.cat([t] * 2)
-            c_in = torch.cat([unconditional_conditioning, c])
+            if isinstance(c, dict):
+                assert isinstance(unconditional_conditioning, dict)
+                c_in = dict()
+                for k in c:
+                    if isinstance(c[k], list):
+                        c_in[k] = [
+                            torch.cat([unconditional_conditioning[k][i], c[k][i]])
+                            for i in range(len(c[k]))
+                        ]
+                    else:
+                        c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
+            else:
+                c_in = torch.cat([unconditional_conditioning, c])
             e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
             e_t = e_t_uncond + unconditional_guidance_scale * (
                 e_t - e_t_uncond
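
The new dict branch mirrors the existing tensor case: each entry of the unconditional and conditional dicts is batched together so a single model call serves both halves of classifier-free guidance, which chunk(2) then splits. A toy, runnable illustration with a stand-in for apply_model (all shapes illustrative):

    import torch

    def apply_model(x_in, t_in, c_in):
        # Stand-in for the UNet: ignores conditioning, returns noise-shaped output.
        return torch.randn_like(x_in)

    x, t = torch.randn(2, 4, 8, 8), torch.tensor([10, 10])
    c  = {'c_crossattn': [torch.randn(2, 77, 768)]}   # conditional inputs
    uc = {'c_crossattn': [torch.randn(2, 77, 768)]}   # unconditional inputs

    x_in, t_in = torch.cat([x] * 2), torch.cat([t] * 2)
    c_in = {k: [torch.cat([uc[k][i], c[k][i]]) for i in range(len(c[k]))]
            for k in c}

    e_t_uncond, e_t = apply_model(x_in, t_in, c_in).chunk(2)
    e_t = e_t_uncond + 7.5 * (e_t - e_t_uncond)   # classifier-free guidance
    print(e_t.shape)  # torch.Size([2, 4, 8, 8])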
@@ -19,6 +19,7 @@ from functools import partial
 from tqdm import tqdm
 from torchvision.utils import make_grid
 from pytorch_lightning.utilities.distributed import rank_zero_only
+from omegaconf import ListConfig
 import urllib
 
 from ldm.util import (
@@ -120,7 +121,7 @@ class DDPM(pl.LightningModule):
         self.use_ema = use_ema
         if self.use_ema:
             self.model_ema = LitEma(self.model)
-            print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
+            print(f' | Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
 
         self.use_scheduler = scheduler_config is not None
         if self.use_scheduler:
@@ -1883,6 +1884,24 @@ class LatentDiffusion(DDPM):
 
         return samples, intermediates
 
+    @torch.no_grad()
+    def get_unconditional_conditioning(self, batch_size, null_label=None):
+        if null_label is not None:
+            xc = null_label
+            if isinstance(xc, ListConfig):
+                xc = list(xc)
+            if isinstance(xc, dict) or isinstance(xc, list):
+                c = self.get_learned_conditioning(xc)
+            else:
+                if hasattr(xc, "to"):
+                    xc = xc.to(self.device)
+                c = self.get_learned_conditioning(xc)
+        else:
+            # todo: get null label from cond_stage_model
+            raise NotImplementedError()
+        c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
+        return c
+
     @torch.no_grad()
     def log_images(
         self,
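
The final repeat(...) call broadcasts a single null-prompt conditioning across the batch; in isolation (shapes are illustrative):

    import torch
    from einops import repeat

    c = torch.randn(1, 77, 768)               # conditioning for one null prompt
    batch = repeat(c, '1 ... -> b ...', b=4)  # tile it across a batch of 4
    print(batch.shape)                        # torch.Size([4, 77, 768])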
@@ -2138,6 +2157,7 @@ class DiffusionWrapper(pl.LightningModule):
         ]
 
     def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
+        print(f'DEBUG (ddpm) c_concat = {c_concat}')
         if self.conditioning_key is None:
             out = self.diffusion_model(x, t)
         elif self.conditioning_key == 'concat':
@@ -2147,8 +2167,8 @@ class DiffusionWrapper(pl.LightningModule):
             cc = torch.cat(c_crossattn, 1)
             out = self.diffusion_model(x, t, context=cc)
         elif self.conditioning_key == 'hybrid':
-            xc = torch.cat([x] + c_concat, dim=1)
             cc = torch.cat(c_crossattn, 1)
+            xc = torch.cat([x] + c_concat, dim=1)
             out = self.diffusion_model(xc, t, context=cc)
         elif self.conditioning_key == 'adm':
             cc = c_crossattn[0]
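
For the 'hybrid' key, image-like conditionings join the UNet input along the channel axis while the cross-attention context stays separate; the reordering above only computes cc first. Schematically, with illustrative shapes:

    import torch

    x = torch.randn(1, 4, 64, 64)              # noisy latent
    c_concat = [torch.randn(1, 5, 64, 64)]     # e.g. mask + masked-image latent
    c_crossattn = [torch.randn(1, 77, 768)]    # text embedding

    cc = torch.cat(c_crossattn, 1)             # cross-attention context
    xc = torch.cat([x] + c_concat, dim=1)      # channel concat: 4 + 5 = 9

    print(xc.shape, cc.shape)  # (1, 9, 64, 64) and (1, 77, 768)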
@@ -2187,3 +2207,58 @@ class Layout2ImgDiffusion(LatentDiffusion):
             cond_img = torch.stack(bbox_imgs, dim=0)
             logs['bbox_image'] = cond_img
         return logs
+
+
+class LatentInpaintDiffusion(LatentDiffusion):
+    def __init__(
+        self,
+        concat_keys=("mask", "masked_image"),
+        masked_image_key="masked_image",
+        finetune_keys=None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.masked_image_key = masked_image_key
+        assert self.masked_image_key in concat_keys
+        self.concat_keys = concat_keys
+
+    @torch.no_grad()
+    def get_input(
+        self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
+    ):
+        # note: restricted to non-trainable encoders currently
+        assert (
+            not self.cond_stage_trainable
+        ), "trainable cond stages not yet supported for inpainting"
+        z, c, x, xrec, xc = super().get_input(
+            batch,
+            self.first_stage_key,
+            return_first_stage_outputs=True,
+            force_c_encode=True,
+            return_original_cond=True,
+            bs=bs,
+        )
+
+        assert exists(self.concat_keys)
+        c_cat = list()
+        for ck in self.concat_keys:
+            cc = (
+                rearrange(batch[ck], "b h w c -> b c h w")
+                .to(memory_format=torch.contiguous_format)
+                .float()
+            )
+            if bs is not None:
+                cc = cc[:bs]
+                cc = cc.to(self.device)
+            bchw = z.shape
+            if ck != self.masked_image_key:
+                cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+            else:
+                cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds
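
get_input() above resizes each mask-like key to latent resolution and first-stage-encodes the masked image before concatenating. A self-contained sketch of that loop, with a fake pooling "encoder" standing in for the VAE:

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    def fake_encode(img):
        # Stand-in for the first-stage VAE: pool to 1/8 resolution, 4 channels.
        return F.avg_pool2d(img, 8)[:, :1].repeat(1, 4, 1, 1)

    batch = {
        'mask':         torch.rand(1, 512, 512, 1),   # channels-last, as in get_input()
        'masked_image': torch.rand(1, 512, 512, 3),
    }
    z_shape = (1, 4, 64, 64)                          # latent b,c,h,w

    c_cat = []
    for ck in ('mask', 'masked_image'):
        cc = rearrange(batch[ck], 'b h w c -> b c h w').float()
        if ck != 'masked_image':
            cc = F.interpolate(cc, size=z_shape[-2:])  # mask: plain resize
        else:
            cc = fake_encode(cc)                       # image: encode to latents
        c_cat.append(cc)
    c_cat = torch.cat(c_cat, dim=1)
    print(c_cat.shape)  # torch.Size([1, 5, 64, 64])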
@@ -281,3 +281,5 @@ class KSampler(Sampler):
         '''
         return self.model.inner_model.q_sample(x0,ts)
 
+    def conditioning_key(self)->str:
+        return self.model.inner_model.model.conditioning_key
@@ -158,6 +158,18 @@ class Sampler(object):
         **kwargs,
     ):
+
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                ctmp = conditioning[list(conditioning.keys())[0]]
+                while isinstance(ctmp, list):
+                    ctmp = ctmp[0]
+                cbs = ctmp.shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
 
         # check to see if make_schedule() has run, and if not, run it
         if self.ddim_timesteps is None:
             self.make_schedule(
@@ -193,7 +205,7 @@ class Sampler(object):
         )
         return samples, intermediates
 
-    #torch.no_grad()
+    @torch.no_grad()
     def do_sampling(
         self,
         cond,
@@ -307,6 +319,19 @@ class Sampler(object):
         mask = None,
     ):
+
+        print(f'DEBUG(sampler): cond = {cond}')
+        if cond is not None:
+            if isinstance(cond, dict):
+                ctmp = cond[list(cond.keys())[0]]
+                while isinstance(ctmp, list):
+                    ctmp = ctmp[0]
+                cbs = ctmp.shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conds but batch-size is {batch_size}")
+            else:
+                if cond.shape[0] != batch_size:
+                    print(f"Warning: Got {cond.shape[0]} conditionings but batch-size is {batch_size}")
 
         timesteps = (
             np.arange(self.ddpm_num_timesteps)
             if use_original_steps
@@ -411,3 +436,6 @@ class Sampler(object):
         return self.model.inner_model.q_sample(x0,ts)
         '''
         return self.model.q_sample(x0,ts)
+
+    def conditioning_key(self)->str:
+        return self.model.model.conditioning_key
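
conditioning_key() gives the front end a cheap way to recognize an inpainting-style checkpoint, which is what the first hunk in Generate branches on. A toy sketch of that dispatch (names are stand-ins mirroring the diff, not the real classes):

    class FakeSampler:
        def conditioning_key(self) -> str:
            return 'concat'   # what a 1.5 inpainting checkpoint reports

    def choose_generator(sampler, init_image, mask_image):
        if sampler.conditioning_key() in ('hybrid', 'concat'):
            return 'omnibus'                  # inpainting-aware generator
        elif init_image is not None and mask_image is not None:
            return 'inpaint'                  # classic mask-based inpainting
        return 'txt2img'

    print(choose_generator(FakeSampler(), None, None))  # omnibus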