preparing for merge into main

Lincoln Stein 2022-08-21 19:57:48 -04:00
parent bb91ca0462
commit 78aba5b770
6 changed files with 47 additions and 15 deletions

View File

@@ -73,6 +73,14 @@ The --init_img (-I) option gives the path to the seed picture. --strength (-f) controls how much
 the original will be modified, ranging from 0.0 (keep the original intact) to 1.0 (ignore the original
 completely). The default is 0.75, and values in the range 0.25-0.75 give interesting results.
+## Changes
+* v1.01 (21 August 2022)
+  - added k_lms sampling. **Please run "conda env update -f environment.yaml" to load the k_lms dependencies.**
+  - use half precision arithmetic by default, resulting in faster execution and lower memory requirements.
+    Pass the argument --full_precision to dream.py to get slower but more accurate image generation.
 ## Installation
 ### Linux/Mac
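Mid-range --strength values are the interesting ones because, in CompVis-style img2img (which this option follows, per the scripts/img2img.py default noted above), strength sets the fraction of sampler steps that re-run on top of the noised seed image. A minimal sketch of that arithmetic; the `int(strength * steps)` convention is an assumption carried over from the CompVis scripts, not something this diff shows:

```python
# Rough sketch: strength decides how many of the sampler's steps
# re-run on the noised seed image (img2img convention, assumed here).
steps = 50
for strength in (0.0, 0.25, 0.75, 1.0):
    t_enc = int(strength * steps)   # denoising steps applied to the seed
    print(f"strength={strength}: {t_enc}/{steps} steps re-run")
```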

View File

@@ -25,7 +25,7 @@ dependencies:
 - torchmetrics==0.6.0
 - kornia==0.6
 - accelerate==0.12.0
-- git+https://github.com/crowsonkb/k-diffusion.git@master
-- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
 - -e git+https://github.com/openai/CLIP.git@main#egg=clip
+- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
+- -e git+https://github.com/lstein/k-diffusion.git@master#egg=k-diffusion
 - -e .

View File

@@ -1,12 +1,29 @@
 '''wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers'''
 import k_diffusion as K
+import torch
 import torch.nn as nn
+import accelerate

 class CFGDenoiser(nn.Module):
     def __init__(self, model):
         super().__init__()
         self.inner_model = model

+    def forward(self, x, sigma, uncond, cond, cond_scale):
+        x_in = torch.cat([x] * 2)
+        sigma_in = torch.cat([sigma] * 2)
+        cond_in = torch.cat([uncond, cond])
+        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+        return uncond + (cond - uncond) * cond_scale
+
+class KSampler(object):
+    def __init__(self, model, schedule="lms", **kwargs):
+        super().__init__()
+        self.model = K.external.CompVisDenoiser(model)
+        self.accelerator = accelerate.Accelerator()
+        self.device = self.accelerator.device
+        self.schedule = schedule
+
     def forward(self, x, sigma, uncond, cond, cond_scale):
         x_in = torch.cat([x] * 2)
         sigma_in = torch.cat([sigma] * 2)
@@ -14,13 +31,6 @@ class CFGDenoiser(nn.Module):
         uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
         return uncond + (cond - uncond) * cond_scale

-class KSampler(object):
-    def __init__(self,model,schedule="lms", **kwargs):
-        super().__init__()
-        self.model = K.external.CompVisDenoiser(model)
-        self.accelerator = accelerate.Accelerator()
-        self.device = accelerator.device
-        self.schedule = schedule

     # most of these arguments are ignored and are only present for compatibility with
     # other samplers
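The removed __init__ read `self.device = accelerator.device`, which raises NameError because only `self.accelerator` is ever bound; the relocated version above stores the handle first. A minimal sketch of the accelerate calls involved (standard accelerate API):

```python
import accelerate

# Bind the Accelerator to a name before reading .device; the removed
# code referenced a bare `accelerator` that was never defined.
accelerator = accelerate.Accelerator()
print(accelerator.device)           # e.g. cuda, mps, or cpu
print(accelerator.is_main_process)  # True in a single-process run
```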
@@ -54,10 +64,10 @@ class KSampler(object):
         if x_T:
             x = x_T
         else:
-            x = torch.randn([batch_size, *shape], device=device) * sigmas[0] # for GPU draw
+            x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw
         model_wrap_cfg = CFGDenoiser(self.model)
         extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}
-        return (K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process),
+        return (K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process),
                 None)

     def gather(samples_ddim):
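CFGDenoiser is the classifier-free guidance trick: one batched forward computes the unconditioned and conditioned predictions, and the output is extrapolated away from the unconditioned branch by cond_scale. A self-contained sketch with a toy denoiser standing in for the wrapped model (FakeDenoiser is hypothetical; the tensor plumbing mirrors the code above):

```python
import torch
import torch.nn as nn

class FakeDenoiser(nn.Module):
    '''Toy stand-in for the wrapped diffusion model.'''
    def forward(self, x, sigma, cond=None):
        # Any function of (x, cond) works for demonstration purposes.
        return x * 0.5 + cond.mean(dim=1, keepdim=True)

x = torch.randn(1, 4)           # current latent
sigma = torch.ones(1)           # current noise level
uncond_emb = torch.zeros(1, 4)  # empty-prompt conditioning
cond_emb = torch.ones(1, 4)     # prompt conditioning

# Same plumbing as CFGDenoiser.forward: batch both branches into one call.
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond_emb, cond_emb])
uncond, cond = FakeDenoiser()(x_in, sigma_in, cond=cond_in).chunk(2)
guided = uncond + (cond - uncond) * 7.5   # cond_scale = 7.5
print(guided)
```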

View File

@@ -108,6 +108,7 @@ class T2I:
         ddim_eta=0.0,    # deterministic
         fixed_code=False,
         precision='autocast',
+        full_precision=False,
         strength=0.75    # default in scripts/img2img.py
     ):
         self.outdir = outdir
@@ -126,6 +127,7 @@ class T2I:
         self.downsampling_factor = downsampling_factor
         self.ddim_eta = ddim_eta
         self.precision = precision
+        self.full_precision = full_precision
         self.strength = strength
         self.model = None    # empty for now
         self.sampler = None
@@ -407,7 +409,12 @@ class T2I:
         m, u = model.load_state_dict(sd, strict=False)
         model.cuda()
         model.eval()
-        model.half()
+        if self.full_precision:
+            print('Using slower but more accurate full precision math')
+            model.float()
+        else:
+            print('Using half precision math. Call with --full_precision to use full precision')
+            model.half()
         return model

     def _load_img(self, path):
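The new full_precision switch is the standard PyTorch dtype toggle on a loaded module. A minimal sketch with a toy module in place of the loaded Stable Diffusion weights:

```python
import torch.nn as nn

model = nn.Linear(4, 4)   # stand-in for the loaded checkpoint
full_precision = False

if full_precision:
    model.float()  # keep fp32 weights: slower, more accurate
else:
    model.half()   # fp16 weights: faster, roughly half the memory

print(next(model.parameters()).dtype)  # torch.float16 unless full_precision
```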

View File

@@ -1,3 +1,4 @@
+#!/usr/bin/env python3
 import argparse
 import shlex
 import atexit
@@ -11,7 +12,7 @@ try:
 except:
     readline_available = False

-debugging = True
+debugging = False

 def main():
     ''' Initialize command-line parsers and the diffusion model '''
@@ -49,6 +50,7 @@ def main():
         outdir=opt.outdir,
         sampler=opt.sampler,
         weights=weights,
+        full_precision=opt.full_precision,
         config=config)

     # make sure the output directory exists
@@ -165,14 +167,18 @@ def create_argv_parser():
                         type=int,
                         default=1,
                         help="number of images to generate")
+    parser.add_argument('-F', '--full_precision',
+                        dest='full_precision',
+                        action='store_true',
+                        help="use slower full precision math for calculations")
     parser.add_argument('-b', '--batch_size',
                         type=int,
                         default=1,
                         help="number of images to produce per iteration (currently not working properly - producing too many images)")
     parser.add_argument('--sampler',
                         choices=['plms','ddim', 'klms'],
-                        default='plms',
-                        help="which sampler to use")
+                        default='klms',
+                        help="which sampler to use (default: klms)")
     parser.add_argument('-o',
                         '--outdir',
                         type=str,
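Since the new flag uses action='store_true', it defaults to False, which is what keeps half precision the default behavior. A quick check of the flag in isolation:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-F', '--full_precision',
                    dest='full_precision',
                    action='store_true',
                    help="use slower full precision math for calculations")

print(parser.parse_args([]).full_precision)      # False -> half precision path
print(parser.parse_args(['-F']).full_precision)  # True  -> model stays fp32
```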

View File

@@ -1,3 +1,4 @@
+#!/usr/bin/env python3
 # Before running stable-diffusion on an internet-isolated machine,
 # run this script from one with internet connectivity. The
 # two machines must share a common .cache directory.
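The preload idea depends on both machines resolving model downloads to the same cache directory. The script body is not shown in this hunk; purely as an illustration of the shared-cache mechanism, the snippet below points the standard XDG cache location (which torch hub and the huggingface libraries fall back to) at a shared mount. The path is hypothetical:

```python
import os

# Hypothetical shared mount visible to both machines; torch hub and
# huggingface fall back to $XDG_CACHE_HOME/{torch,huggingface} when set.
os.environ['XDG_CACHE_HOME'] = '/mnt/shared/.cache'
```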