Remove accelerate library

This library is not required to use k-diffusion
Make k-diffusion wrapper closer to the other samplers
This commit is contained in:
BlueAmulet
2022-08-25 11:04:57 -06:00
parent 1c8ecacddf
commit 39b55ae016
5 changed files with 13 additions and 26 deletions

View File

@ -12,7 +12,6 @@ from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
import accelerate
import k_diffusion as K
import torch.nn as nn
@ -201,8 +200,6 @@ def main():
#for klms
model_wrap = K.external.CompVisDenoiser(model)
accelerator = accelerate.Accelerator()
device = accelerator.device
class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
@ -251,8 +248,8 @@ def main():
with model.ema_scope():
tic = time.time()
all_samples = list()
for n in trange(opt.n_iter, desc="Sampling", disable =not accelerator.is_main_process):
for prompts in tqdm(data, desc="data", disable =not accelerator.is_main_process):
for n in trange(opt.n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"):
uc = None
if opt.scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""])
@ -279,13 +276,10 @@ def main():
x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0] # for GPU draw
model_wrap_cfg = CFGDenoiser(model_wrap)
extra_args = {'cond': c, 'uncond': uc, 'cond_scale': opt.scale}
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process)
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
if opt.klms:
x_sample = accelerator.gather(x_samples_ddim)
if not opt.skip_save:
for x_sample in x_samples_ddim: