From 39b55ae016c52c82de5158170ecc9205affff7ef Mon Sep 17 00:00:00 2001
From: BlueAmulet <43395286+BlueAmulet@users.noreply.github.com>
Date: Thu, 25 Aug 2022 11:04:57 -0600
Subject: [PATCH] Remove accelerate library

This library is not required to use k-diffusion

Make k-diffusion wrapper closer to the other samplers
---
 environment.yaml                 |  1 -
 ldm/models/diffusion/ksampler.py | 13 ++++---------
 ldm/simplet2i.py                 | 12 ++++++------
 requirements.txt                 |  1 -
 scripts/orig_scripts/txt2img.py  | 12 +++---------
 5 files changed, 13 insertions(+), 26 deletions(-)

diff --git a/environment.yaml b/environment.yaml
index b554cfc035..7d5b4fe9e3 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -10,7 +10,6 @@ dependencies:
   - torchvision=0.12.0
   - numpy=1.19.2
   - pip:
-    - accelerate==0.12.0
     - albumentations==0.4.3
     - opencv-python==4.1.2.30
     - pudb==2019.2
diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py
index c48e533410..cea77dac1a 100644
--- a/ldm/models/diffusion/ksampler.py
+++ b/ldm/models/diffusion/ksampler.py
@@ -2,7 +2,6 @@
 import k_diffusion as K
 import torch
 import torch.nn as nn
-import accelerate

 class CFGDenoiser(nn.Module):
     def __init__(self, model):
@@ -17,12 +16,11 @@ class CFGDenoiser(nn.Module):
         return uncond + (cond - uncond) * cond_scale

 class KSampler(object):
-    def __init__(self,model,schedule="lms", **kwargs):
+    def __init__(self, model, schedule="lms", device="cuda", **kwargs):
         super().__init__()
-        self.model = K.external.CompVisDenoiser(model)
-        self.accelerator = accelerate.Accelerator()
-        self.device = self.accelerator.device
+        self.model = K.external.CompVisDenoiser(model)
         self.schedule = schedule
+        self.device = device

     def forward(self, x, sigma, uncond, cond, cond_scale):
         x_in = torch.cat([x] * 2)
@@ -67,8 +65,5 @@ class KSampler(object):
         x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw
         model_wrap_cfg = CFGDenoiser(self.model)
         extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}
-        return (K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process),
+        return (K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args=extra_args),
                 None)
-
-    def gather(samples_ddim):
-        return self.accelerator.gather(samples_ddim)
diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py
index 4737d90ba7..7e5d5d6e7c 100644
--- a/ldm/simplet2i.py
+++ b/ldm/simplet2i.py
@@ -467,17 +467,17 @@ The vast majority of these arguments default to reasonable values.
         elif self.sampler_name == 'ddim':
             self.sampler = DDIMSampler(self.model, device=self.device)
         elif self.sampler_name == 'k_dpm_2_a':
-            self.sampler = KSampler(self.model,'dpm_2_ancestral')
+            self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device)
         elif self.sampler_name == 'k_dpm_2':
-            self.sampler = KSampler(self.model,'dpm_2')
+            self.sampler = KSampler(self.model, 'dpm_2', device=self.device)
         elif self.sampler_name == 'k_euler_a':
-            self.sampler = KSampler(self.model,'euler_ancestral')
+            self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device)
         elif self.sampler_name == 'k_euler':
-            self.sampler = KSampler(self.model,'euler')
+            self.sampler = KSampler(self.model, 'euler', device=self.device)
         elif self.sampler_name == 'k_heun':
-            self.sampler = KSampler(self.model,'heun')
+            self.sampler = KSampler(self.model, 'heun', device=self.device)
         elif self.sampler_name == 'k_lms':
-            self.sampler = KSampler(self.model,'lms')
+            self.sampler = KSampler(self.model, 'lms', device=self.device)
         else:
             msg = f'unsupported sampler {self.sampler_name}, defaulting to plms'
             self.sampler = PLMSSampler(self.model, device=self.device)
diff --git a/requirements.txt b/requirements.txt
index 30b2251a1c..a94a4a5382 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,3 @@
-accelerate==0.12.0
 albumentations==0.4.3
 einops==0.3.0
 huggingface-hub==0.8.1
diff --git a/scripts/orig_scripts/txt2img.py b/scripts/orig_scripts/txt2img.py
index 42d5e83496..1edc531309 100644
--- a/scripts/orig_scripts/txt2img.py
+++ b/scripts/orig_scripts/txt2img.py
@@ -12,7 +12,6 @@
 from pytorch_lightning import seed_everything
 from torch import autocast
 from contextlib import contextmanager, nullcontext

-import accelerate
 import k_diffusion as K
 import torch.nn as nn
@@ -201,8 +200,6 @@ def main():
     #for klms
     model_wrap = K.external.CompVisDenoiser(model)
-    accelerator = accelerate.Accelerator()
-    device = accelerator.device

     class CFGDenoiser(nn.Module):
         def __init__(self, model):
             super().__init__()
@@ -251,8 +248,8 @@ def main():
             with model.ema_scope():
                 tic = time.time()
                 all_samples = list()
-                for n in trange(opt.n_iter, desc="Sampling", disable =not accelerator.is_main_process):
-                    for prompts in tqdm(data, desc="data", disable =not accelerator.is_main_process):
+                for n in trange(opt.n_iter, desc="Sampling"):
+                    for prompts in tqdm(data, desc="data"):
                         uc = None
                         if opt.scale != 1.0:
                             uc = model.get_learned_conditioning(batch_size * [""])
@@ -279,13 +276,10 @@ def main():
                             x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0] # for GPU draw
                             model_wrap_cfg = CFGDenoiser(model_wrap)
                             extra_args = {'cond': c, 'uncond': uc, 'cond_scale': opt.scale}
-                            samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process)
+                            samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args)
                         x_samples_ddim = model.decode_first_stage(samples_ddim)
                         x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
-
-                        if opt.klms:
-                            x_sample = accelerator.gather(x_samples_ddim)

                         if not opt.skip_save:
                             for x_sample in x_samples_ddim: