Mirror of https://github.com/invoke-ai/InvokeAI
Remove accelerate library

This library is not required to use k-diffusion. Make the k-diffusion wrapper closer to the other samplers.
parent 1c8ecacddf
commit 39b55ae016
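In short, the wrapper no longer derives its compute device from an accelerate.Accelerator; the caller passes a device explicitly, the same way the DDIM and PLMS samplers already take one. A minimal before/after sketch of that idea, with names taken from the diff below:

# Before (removed): the device was resolved through accelerate.
#     self.accelerator = accelerate.Accelerator()
#     self.device = self.accelerator.device
# After: the caller passes the device explicitly, like the other samplers.
#     def __init__(self, model, schedule="lms", device="cuda", **kwargs):
#         self.device = device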
@@ -10,7 +10,6 @@ dependencies:
   - torchvision=0.12.0
   - numpy=1.19.2
   - pip:
-    - accelerate==0.12.0
     - albumentations==0.4.3
     - opencv-python==4.1.2.30
     - pudb==2019.2
@@ -2,7 +2,6 @@
 import k_diffusion as K
 import torch
 import torch.nn as nn
-import accelerate
 
 class CFGDenoiser(nn.Module):
     def __init__(self, model):
@@ -17,12 +16,11 @@ class CFGDenoiser(nn.Module):
         return uncond + (cond - uncond) * cond_scale
 
 class KSampler(object):
-    def __init__(self,model,schedule="lms", **kwargs):
+    def __init__(self, model, schedule="lms", device="cuda", **kwargs):
         super().__init__()
         self.model = K.external.CompVisDenoiser(model)
-        self.accelerator = accelerate.Accelerator()
-        self.device = self.accelerator.device
         self.schedule = schedule
+        self.device = device
 
     def forward(self, x, sigma, uncond, cond, cond_scale):
         x_in = torch.cat([x] * 2)
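After this change, constructing the k-diffusion wrapper looks like constructing the other samplers. A usage sketch based on the new signature; the model loader here is hypothetical and only for illustration:

model = load_model_somehow()   # hypothetical loader, not part of the commit
sampler = KSampler(model, schedule='euler_ancestral', device='cuda')
default_sampler = KSampler(model, 'lms')   # omitting device falls back to the new default, "cuda"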
@@ -67,8 +65,5 @@ class KSampler(object):
         x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw
         model_wrap_cfg = CFGDenoiser(self.model)
         extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}
-        return (K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process),
+        return (K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args=extra_args),
                 None)
-
-    def gather(samples_ddim):
-        return self.accelerator.gather(samples_ddim)
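Dropping the gather() helper also simplifies the wrapper's contract: it returns a (samples, None) tuple directly, mirroring the (samples, intermediates) tuple of DDIMSampler, so call sites can keep unpacking two values. A sketch of that pattern, assuming the enclosing method is the sampler's sample() entry point:

# Call-site pattern shared with the other samplers (arguments abridged; sketch only):
#     samples, _ = sampler.sample(...)   # second element is None for KSampler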
@@ -467,17 +467,17 @@ The vast majority of these arguments default to reasonable values.
         elif self.sampler_name == 'ddim':
             self.sampler = DDIMSampler(self.model, device=self.device)
         elif self.sampler_name == 'k_dpm_2_a':
-            self.sampler = KSampler(self.model,'dpm_2_ancestral')
+            self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device)
         elif self.sampler_name == 'k_dpm_2':
-            self.sampler = KSampler(self.model,'dpm_2')
+            self.sampler = KSampler(self.model, 'dpm_2', device=self.device)
         elif self.sampler_name == 'k_euler_a':
-            self.sampler = KSampler(self.model,'euler_ancestral')
+            self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device)
         elif self.sampler_name == 'k_euler':
-            self.sampler = KSampler(self.model,'euler')
+            self.sampler = KSampler(self.model, 'euler', device=self.device)
         elif self.sampler_name == 'k_heun':
-            self.sampler = KSampler(self.model,'heun')
+            self.sampler = KSampler(self.model, 'heun', device=self.device)
         elif self.sampler_name == 'k_lms':
-            self.sampler = KSampler(self.model,'lms')
+            self.sampler = KSampler(self.model, 'lms', device=self.device)
         else:
             msg = f'unsupported sampler {self.sampler_name}, defaulting to plms'
             self.sampler = PLMSSampler(self.model, device=self.device)
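Every k_* branch now passes device=self.device, matching DDIMSampler and PLMSSampler. The branches differ only in the schedule name, so a compact equivalent would map names to schedules. This is a hypothetical simplification for illustration only; the commit keeps the elif chain:

K_SCHEDULES = {
    'k_dpm_2_a': 'dpm_2_ancestral',
    'k_dpm_2':   'dpm_2',
    'k_euler_a': 'euler_ancestral',
    'k_euler':   'euler',
    'k_heun':    'heun',
    'k_lms':     'lms',
}
if self.sampler_name in K_SCHEDULES:
    self.sampler = KSampler(self.model, K_SCHEDULES[self.sampler_name], device=self.device)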
@@ -1,4 +1,3 @@
-accelerate==0.12.0
 albumentations==0.4.3
 einops==0.3.0
 huggingface-hub==0.8.1
@@ -12,7 +12,6 @@ from pytorch_lightning import seed_everything
 from torch import autocast
 from contextlib import contextmanager, nullcontext
 
-import accelerate
 import k_diffusion as K
 import torch.nn as nn
 
@@ -201,8 +200,6 @@ def main():
 
     #for klms
     model_wrap = K.external.CompVisDenoiser(model)
-    accelerator = accelerate.Accelerator()
-    device = accelerator.device
     class CFGDenoiser(nn.Module):
         def __init__(self, model):
             super().__init__()
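With the Accelerator gone, the script has to get its device from somewhere else. A common pattern, assumed here because this hunk does not show where device is now defined, is plain torch device selection:

# Assumed replacement for accelerator.device; not shown in this hunk.
import torch
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")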
@@ -251,8 +248,8 @@ def main():
         with model.ema_scope():
             tic = time.time()
             all_samples = list()
-            for n in trange(opt.n_iter, desc="Sampling", disable =not accelerator.is_main_process):
-                for prompts in tqdm(data, desc="data", disable =not accelerator.is_main_process):
+            for n in trange(opt.n_iter, desc="Sampling"):
+                for prompts in tqdm(data, desc="data"):
                     uc = None
                     if opt.scale != 1.0:
                         uc = model.get_learned_conditioning(batch_size * [""])
@@ -279,14 +276,11 @@ def main():
                         x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0] # for GPU draw
                         model_wrap_cfg = CFGDenoiser(model_wrap)
                         extra_args = {'cond': c, 'uncond': uc, 'cond_scale': opt.scale}
-                        samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process)
+                        samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args)
 
                     x_samples_ddim = model.decode_first_stage(samples_ddim)
                     x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
-                    if opt.klms:
-                        x_sample = accelerator.gather(x_samples_ddim)
-
                     if not opt.skip_save:
                         for x_sample in x_samples_ddim:
                             x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
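With accelerator.gather() removed, decoding and saving happen in a single process on the tensors the sampler returns. A minimal sketch of that post-processing path, assuming PIL is used for writing images (model and samples_ddim come from the surrounding script, and the file name is illustrative):

import numpy as np
import torch
from einops import rearrange
from PIL import Image

# Decode latents, map to [0, 1], then write each image in the batch.
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
for x_sample in x_samples_ddim:
    x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
    Image.fromarray(x_sample.astype(np.uint8)).save("sample.png")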