Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
add klms sampling
commit 41f0afbcb6 (parent 1ca3dc553c)
@@ -24,6 +24,8 @@ dependencies:
     - transformers==4.19.2
     - torchmetrics==0.6.0
     - kornia==0.6
+    - accelerate==0.12.0
+    - git+https://github.com/crowsonkb/k-diffusion.git@master
     - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
     - -e git+https://github.com/openai/CLIP.git@main#egg=clip
     - -e .
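The two new pip entries pull in Katherine Crowson's k-diffusion samplers and Hugging Face accelerate for multi-process handling. A minimal sanity check after rebuilding the environment (a sketch, not part of this commit):

    # Confirms the two new dependencies resolve; sample_lms and CompVisDenoiser
    # are the pieces this commit uses further down.
    import accelerate
    import k_diffusion as K

    print(accelerate.__version__)                  # pinned to 0.12.0 above
    print(callable(K.sampling.sample_lms))         # True
    print(callable(K.external.CompVisDenoiser))    # True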
@@ -169,7 +169,7 @@ def create_argv_parser():
                         default=1,
                         help="number of images to produce per iteration (currently not working properly - producing too many images)")
     parser.add_argument('--sampler',
-                        choices=['plms','ddim'],
+                        choices=['plms','ddim', 'klms'],
                         default='plms',
                         help="which sampler to use")
     parser.add_argument('-o',
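This hunk only widens the accepted values for --sampler. A standalone sketch of the same option (the argument names mirror the hunk; the real parser lives in the repo's CLI script):

    import argparse

    # 'klms' is now accepted alongside 'plms' and 'ddim'.
    parser = argparse.ArgumentParser()
    parser.add_argument('--sampler',
                        choices=['plms', 'ddim', 'klms'],
                        default='plms',
                        help="which sampler to use")

    print(parser.parse_args(['--sampler', 'klms']).sampler)   # -> 'klms'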
@@ -12,6 +12,10 @@ from pytorch_lightning import seed_everything
 from torch import autocast
 from contextlib import contextmanager, nullcontext
 
+import accelerate
+import k_diffusion as K
+import torch.nn as nn
+
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
@@ -80,6 +84,11 @@ def main():
         action='store_true',
         help="use plms sampling",
     )
+    parser.add_argument(
+        "--klms",
+        action='store_true',
+        help="use klms sampling",
+    )
     parser.add_argument(
         "--laion400m",
         action='store_true',
@@ -190,6 +199,22 @@ def main():
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     model = model.to(device)
 
+    #for klms
+    model_wrap = K.external.CompVisDenoiser(model)
+    accelerator = accelerate.Accelerator()
+    device = accelerator.device
+    class CFGDenoiser(nn.Module):
+        def __init__(self, model):
+            super().__init__()
+            self.inner_model = model
+
+        def forward(self, x, sigma, uncond, cond, cond_scale):
+            x_in = torch.cat([x] * 2)
+            sigma_in = torch.cat([sigma] * 2)
+            cond_in = torch.cat([uncond, cond])
+            uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+            return uncond + (cond - uncond) * cond_scale
+
     if opt.plms:
         sampler = PLMSSampler(model)
     else:
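CFGDenoiser implements classifier-free guidance around the k-diffusion denoiser: one batched pass over the concatenated unconditional and conditional inputs, split the result, then blend the two predictions by cond_scale. A toy usage sketch, assuming the CFGDenoiser class from the hunk above is in scope (the dummy inner model is an assumption purely for shape illustration; the real inner model is K.external.CompVisDenoiser(model)):

    import torch
    import torch.nn as nn

    # Shape-preserving stand-in for the wrapped Stable Diffusion denoiser, so the
    # CFG arithmetic can be exercised without loading any weights.
    class DummyInner(nn.Module):
        def forward(self, x, sigma, cond=None):
            return x + cond.view(-1, 1, 1, 1)

    cfg = CFGDenoiser(DummyInner())

    x = torch.randn(2, 4, 8, 8)   # latent batch
    sigma = torch.ones(2)         # one noise level per sample
    uncond = torch.zeros(2)       # stand-in for empty-prompt conditioning
    cond = torch.ones(2)          # stand-in for prompt conditioning
    out = cfg(x, sigma, uncond, cond, cond_scale=7.5)
    print(out.shape)              # torch.Size([2, 4, 8, 8])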
@@ -226,8 +251,8 @@ def main():
             with model.ema_scope():
                 tic = time.time()
                 all_samples = list()
-                for n in trange(opt.n_iter, desc="Sampling"):
-                    for prompts in tqdm(data, desc="data"):
+                for n in trange(opt.n_iter, desc="Sampling", disable =not accelerator.is_main_process):
+                    for prompts in tqdm(data, desc="data", disable =not accelerator.is_main_process):
                         uc = None
                         if opt.scale != 1.0:
                             uc = model.get_learned_conditioning(batch_size * [""])
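The only change here is that the trange/tqdm progress bars are silenced on every process except accelerate's main one. A minimal sketch of that pattern:

    import accelerate
    from tqdm import trange

    # Only the main accelerate process draws the bar; the others pass
    # disable=True so multi-process runs don't interleave bar output.
    accelerator = accelerate.Accelerator()
    for n in trange(4, desc="Sampling", disable=not accelerator.is_main_process):
        pass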
@@ -235,18 +260,32 @@ def main():
                             prompts = list(prompts)
                         c = model.get_learned_conditioning(prompts)
                         shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
-                        samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
-                                                         conditioning=c,
-                                                         batch_size=opt.n_samples,
-                                                         shape=shape,
-                                                         verbose=False,
-                                                         unconditional_guidance_scale=opt.scale,
-                                                         unconditional_conditioning=uc,
-                                                         eta=opt.ddim_eta,
-                                                         x_T=start_code)
+
+                        if not opt.klms:
+                            samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
+                                                             conditioning=c,
+                                                             batch_size=opt.n_samples,
+                                                             shape=shape,
+                                                             verbose=False,
+                                                             unconditional_guidance_scale=opt.scale,
+                                                             unconditional_conditioning=uc,
+                                                             eta=opt.ddim_eta,
+                                                             x_T=start_code)
+                        else:
+                            sigmas = model_wrap.get_sigmas(opt.ddim_steps)
+                            if start_code:
+                                x = start_code
+                            else:
+                                x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0] # for GPU draw
+                            model_wrap_cfg = CFGDenoiser(model_wrap)
+                            extra_args = {'cond': c, 'uncond': uc, 'cond_scale': opt.scale}
+                            samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process)
 
                         x_samples_ddim = model.decode_first_stage(samples_ddim)
                         x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
+                        if opt.klms:
+                            x_sample = accelerator.gather(x_samples_ddim)
+
                         if not opt.skip_save:
                             for x_sample in x_samples_ddim:
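Condensed, the new klms branch builds a sigma schedule from the wrapped model, starts from Gaussian noise scaled to sigmas[0] (unless a fixed start code is supplied), and hands the CFG-wrapped denoiser to k-diffusion's LMS sampler. A sketch of that path, assuming model, c, uc, device and opt from the surrounding script are in scope (not runnable standalone):

    import torch
    import k_diffusion as K

    # klms path in brief; CFGDenoiser is the class added earlier in this commit.
    model_wrap = K.external.CompVisDenoiser(model)      # eps-model -> k-diffusion denoiser
    sigmas = model_wrap.get_sigmas(opt.ddim_steps)      # decreasing noise levels
    x = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f],
                    device=device) * sigmas[0]          # start at the highest noise level
    samples = K.sampling.sample_lms(CFGDenoiser(model_wrap), x, sigmas,
                                    extra_args={'cond': c, 'uncond': uc, 'cond_scale': opt.scale})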