Remove accelerate library

This library is not required to use k-diffusion
Make k-diffusion wrapper closer to the other samplers
This commit is contained in:
BlueAmulet 2022-08-25 11:04:57 -06:00
parent 1c8ecacddf
commit 39b55ae016
5 changed files with 13 additions and 26 deletions

View File

@ -10,7 +10,6 @@ dependencies:
- torchvision=0.12.0 - torchvision=0.12.0
- numpy=1.19.2 - numpy=1.19.2
- pip: - pip:
- accelerate==0.12.0
- albumentations==0.4.3 - albumentations==0.4.3
- opencv-python==4.1.2.30 - opencv-python==4.1.2.30
- pudb==2019.2 - pudb==2019.2

View File

@ -2,7 +2,6 @@
import k_diffusion as K import k_diffusion as K
import torch import torch
import torch.nn as nn import torch.nn as nn
import accelerate
class CFGDenoiser(nn.Module): class CFGDenoiser(nn.Module):
def __init__(self, model): def __init__(self, model):
@ -17,12 +16,11 @@ class CFGDenoiser(nn.Module):
return uncond + (cond - uncond) * cond_scale return uncond + (cond - uncond) * cond_scale
class KSampler(object): class KSampler(object):
def __init__(self,model,schedule="lms", **kwargs): def __init__(self, model, schedule="lms", device="cuda", **kwargs):
super().__init__() super().__init__()
self.model = K.external.CompVisDenoiser(model) self.model = K.external.CompVisDenoiser(model)
self.accelerator = accelerate.Accelerator()
self.device = self.accelerator.device
self.schedule = schedule self.schedule = schedule
self.device = device
def forward(self, x, sigma, uncond, cond, cond_scale): def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2) x_in = torch.cat([x] * 2)
@ -67,8 +65,5 @@ class KSampler(object):
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw
model_wrap_cfg = CFGDenoiser(self.model) model_wrap_cfg = CFGDenoiser(self.model)
extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale} extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}
return (K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process), return (K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args=extra_args),
None) None)
def gather(samples_ddim):
return self.accelerator.gather(samples_ddim)

View File

@ -467,17 +467,17 @@ The vast majority of these arguments default to reasonable values.
elif self.sampler_name == 'ddim': elif self.sampler_name == 'ddim':
self.sampler = DDIMSampler(self.model, device=self.device) self.sampler = DDIMSampler(self.model, device=self.device)
elif self.sampler_name == 'k_dpm_2_a': elif self.sampler_name == 'k_dpm_2_a':
self.sampler = KSampler(self.model,'dpm_2_ancestral') self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device)
elif self.sampler_name == 'k_dpm_2': elif self.sampler_name == 'k_dpm_2':
self.sampler = KSampler(self.model,'dpm_2') self.sampler = KSampler(self.model, 'dpm_2', device=self.device)
elif self.sampler_name == 'k_euler_a': elif self.sampler_name == 'k_euler_a':
self.sampler = KSampler(self.model,'euler_ancestral') self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device)
elif self.sampler_name == 'k_euler': elif self.sampler_name == 'k_euler':
self.sampler = KSampler(self.model,'euler') self.sampler = KSampler(self.model, 'euler', device=self.device)
elif self.sampler_name == 'k_heun': elif self.sampler_name == 'k_heun':
self.sampler = KSampler(self.model,'heun') self.sampler = KSampler(self.model, 'heun', device=self.device)
elif self.sampler_name == 'k_lms': elif self.sampler_name == 'k_lms':
self.sampler = KSampler(self.model,'lms') self.sampler = KSampler(self.model, 'lms', device=self.device)
else: else:
msg = f'unsupported sampler {self.sampler_name}, defaulting to plms' msg = f'unsupported sampler {self.sampler_name}, defaulting to plms'
self.sampler = PLMSSampler(self.model, device=self.device) self.sampler = PLMSSampler(self.model, device=self.device)

View File

@ -1,4 +1,3 @@
accelerate==0.12.0
albumentations==0.4.3 albumentations==0.4.3
einops==0.3.0 einops==0.3.0
huggingface-hub==0.8.1 huggingface-hub==0.8.1

View File

@ -12,7 +12,6 @@ from pytorch_lightning import seed_everything
from torch import autocast from torch import autocast
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
import accelerate
import k_diffusion as K import k_diffusion as K
import torch.nn as nn import torch.nn as nn
@ -201,8 +200,6 @@ def main():
#for klms #for klms
model_wrap = K.external.CompVisDenoiser(model) model_wrap = K.external.CompVisDenoiser(model)
accelerator = accelerate.Accelerator()
device = accelerator.device
class CFGDenoiser(nn.Module): class CFGDenoiser(nn.Module):
def __init__(self, model): def __init__(self, model):
super().__init__() super().__init__()
@ -251,8 +248,8 @@ def main():
with model.ema_scope(): with model.ema_scope():
tic = time.time() tic = time.time()
all_samples = list() all_samples = list()
for n in trange(opt.n_iter, desc="Sampling", disable =not accelerator.is_main_process): for n in trange(opt.n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data", disable =not accelerator.is_main_process): for prompts in tqdm(data, desc="data"):
uc = None uc = None
if opt.scale != 1.0: if opt.scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""]) uc = model.get_learned_conditioning(batch_size * [""])
@ -279,13 +276,10 @@ def main():
x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0] # for GPU draw x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0] # for GPU draw
model_wrap_cfg = CFGDenoiser(model_wrap) model_wrap_cfg = CFGDenoiser(model_wrap)
extra_args = {'cond': c, 'uncond': uc, 'cond_scale': opt.scale} extra_args = {'cond': c, 'uncond': uc, 'cond_scale': opt.scale}
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process) samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args)
x_samples_ddim = model.decode_first_stage(samples_ddim) x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
if opt.klms:
x_sample = accelerator.gather(x_samples_ddim)
if not opt.skip_save: if not opt.skip_save:
for x_sample in x_samples_ddim: for x_sample in x_samples_ddim: