Merge pull request #89 from BlueAmulet/remove-accelerate

Remove accelerate library
This commit is contained in:
Lincoln Stein 2022-08-25 13:48:48 -04:00 committed by GitHub
commit 650ae3eb13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 14 additions and 27 deletions

View File

@ -10,7 +10,6 @@ dependencies:
- torchvision=0.12.0
- numpy=1.19.2
- pip:
- accelerate==0.12.0
- albumentations==0.4.3
- opencv-python==4.1.2.30
- pudb==2019.2

View File

@ -1,8 +1,7 @@
'''wrapper around part of Karen Crownson's k-duffsion library, making it call compatible with other Samplers'''
'''wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers'''
import k_diffusion as K
import torch
import torch.nn as nn
import accelerate
class CFGDenoiser(nn.Module):
def __init__(self, model):
@ -17,12 +16,11 @@ class CFGDenoiser(nn.Module):
return uncond + (cond - uncond) * cond_scale
class KSampler(object):
def __init__(self,model,schedule="lms", **kwargs):
def __init__(self, model, schedule="lms", device="cuda", **kwargs):
super().__init__()
self.model = K.external.CompVisDenoiser(model)
self.accelerator = accelerate.Accelerator()
self.device = self.accelerator.device
self.model = K.external.CompVisDenoiser(model)
self.schedule = schedule
self.device = device
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
@ -67,8 +65,5 @@ class KSampler(object):
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw
model_wrap_cfg = CFGDenoiser(self.model)
extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}
return (K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process),
return (K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args=extra_args),
None)
def gather(samples_ddim):
return self.accelerator.gather(samples_ddim)

View File

@ -467,17 +467,17 @@ The vast majority of these arguments default to reasonable values.
elif self.sampler_name == 'ddim':
self.sampler = DDIMSampler(self.model, device=self.device)
elif self.sampler_name == 'k_dpm_2_a':
self.sampler = KSampler(self.model,'dpm_2_ancestral')
self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device)
elif self.sampler_name == 'k_dpm_2':
self.sampler = KSampler(self.model,'dpm_2')
self.sampler = KSampler(self.model, 'dpm_2', device=self.device)
elif self.sampler_name == 'k_euler_a':
self.sampler = KSampler(self.model,'euler_ancestral')
self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device)
elif self.sampler_name == 'k_euler':
self.sampler = KSampler(self.model,'euler')
self.sampler = KSampler(self.model, 'euler', device=self.device)
elif self.sampler_name == 'k_heun':
self.sampler = KSampler(self.model,'heun')
self.sampler = KSampler(self.model, 'heun', device=self.device)
elif self.sampler_name == 'k_lms':
self.sampler = KSampler(self.model,'lms')
self.sampler = KSampler(self.model, 'lms', device=self.device)
else:
msg = f'unsupported sampler {self.sampler_name}, defaulting to plms'
self.sampler = PLMSSampler(self.model, device=self.device)

View File

@ -1,4 +1,3 @@
accelerate==0.12.0
albumentations==0.4.3
einops==0.3.0
huggingface-hub==0.8.1

View File

@ -12,7 +12,6 @@ from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
import accelerate
import k_diffusion as K
import torch.nn as nn
@ -201,8 +200,6 @@ def main():
#for klms
model_wrap = K.external.CompVisDenoiser(model)
accelerator = accelerate.Accelerator()
device = accelerator.device
class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
@ -251,8 +248,8 @@ def main():
with model.ema_scope():
tic = time.time()
all_samples = list()
for n in trange(opt.n_iter, desc="Sampling", disable =not accelerator.is_main_process):
for prompts in tqdm(data, desc="data", disable =not accelerator.is_main_process):
for n in trange(opt.n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"):
uc = None
if opt.scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""])
@ -279,13 +276,10 @@ def main():
x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0] # for GPU draw
model_wrap_cfg = CFGDenoiser(model_wrap)
extra_args = {'cond': c, 'uncond': uc, 'cond_scale': opt.scale}
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process)
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
if opt.klms:
x_sample = accelerator.gather(x_samples_ddim)
if not opt.skip_save:
for x_sample in x_samples_ddim: