diff --git a/README.md b/README.md
index 5215530a9e..131289bab1 100644
--- a/README.md
+++ b/README.md
@@ -73,6 +73,14 @@ The --init_img (-I) option gives the path to the seed picture. --strength (-f) c
 the original will be modified, ranging from 0.0 (keep the original intact), to 1.0 (ignore the original
 completely). The default is 0.75, and ranges from 0.25-0.75 give interesting results.
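+
+For example, using the interactive prompt (the seed-image path here is illustrative):
+
+```
+dream> "a fantasy landscape" -I ./seed.png -f 0.5
+```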
 
+## Changes
+
+- v1.01 (21 August 2022)
+  * added k_lms sampling. **Please run `conda env update -f environment.yaml` to load the k_lms dependencies.**
+  * use half-precision arithmetic by default, resulting in faster execution and lower memory requirements.
+    Pass the argument --full_precision to dream.py to get slower but more accurate image generation.
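+
+    For example, to launch the interactive script in full-precision mode from the repository root:
+
+    ```
+    python3 scripts/dream.py --full_precision
+    ```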
+
+
 ## Installation
 
 ### Linux/Mac
diff --git a/environment.yaml b/environment.yaml
index 7f25da800a..0de05e815a 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -24,6 +24,8 @@ dependencies:
     - transformers==4.19.2
     - torchmetrics==0.6.0
     - kornia==0.6
-    - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
+    - accelerate==0.12.0
     - -e git+https://github.com/openai/CLIP.git@main#egg=clip
+    - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
+    - -e git+https://github.com/lstein/k-diffusion.git@master#egg=k-diffusion
     - -e .
diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py
new file mode 100644
index 0000000000..cc4677f47e
--- /dev/null
+++ b/ldm/models/diffusion/ksampler.py
@@ -0,0 +1,74 @@
+'''Wrapper around part of Katherine Crowson's k-diffusion library, making it call-compatible with the other samplers.'''
+import k_diffusion as K
+import torch
+import torch.nn as nn
+import accelerate
+
+class CFGDenoiser(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.inner_model = model
+
+    def forward(self, x, sigma, uncond, cond, cond_scale):
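+        # Classifier-free guidance: run the unconditional and conditional
+        # prompts through the model in a single batched pass, then
+        # extrapolate from the unconditional prediction toward the
+        # conditional one by cond_scale.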
+        x_in = torch.cat([x] * 2)
+        sigma_in = torch.cat([sigma] * 2)
+        cond_in = torch.cat([uncond, cond])
+        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+        return uncond + (cond - uncond) * cond_scale
+
+class KSampler(object):
+    def __init__(self,model,schedule="lms", **kwargs):
+        super().__init__()
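+        # CompVisDenoiser adapts the CompVis latent-diffusion model to
+        # k-diffusion's sigma-parameterized denoiser interface.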
+        self.model        = K.external.CompVisDenoiser(model)
+        self.accelerator  = accelerate.Accelerator()
+        self.device       = self.accelerator.device
+        self.schedule = schedule
+
+    # most of these arguments are ignored and are only present for
+    # compatibility with the other samplers
+    @torch.no_grad()
+    def sample(self,
+               S,
+               batch_size,
+               shape,
+               conditioning=None,
+               callback=None,
+               normals_sequence=None,
+               img_callback=None,
+               quantize_x0=False,
+               eta=0.,
+               mask=None,
+               x0=None,
+               temperature=1.,
+               noise_dropout=0.,
+               score_corrector=None,
+               corrector_kwargs=None,
+               verbose=True,
+               x_T=None,
+               log_every_t=100,
+               unconditional_guidance_scale=1.,
+               unconditional_conditioning=None,
+               # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
+               **kwargs
+               ):
+
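+        # build the sigma (noise level) schedule for S steps; if no starting
+        # latent x_T is supplied, begin from Gaussian noise scaled by the
+        # largest sigma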
+        sigmas = self.model.get_sigmas(S)
+        if x_T is not None:
+            x = x_T
+        else:
+            x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw
+        model_wrap_cfg = CFGDenoiser(self.model)
+        extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}
+        return (K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process),
+                None)
+
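+    # collect sampled latents from all Accelerate processes onto the main process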
+    def gather(self, samples_ddim):
+        return self.accelerator.gather(samples_ddim)
diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py
index 796a99396b..e99660a8ab 100644
--- a/ldm/simplet2i.py
+++ b/ldm/simplet2i.py
@@ -11,7 +11,7 @@ t2i = T2I(outdir      = <path>        // outputs/txt2img-samples
           batch_size       = <integer>     // how many images to generate per sampling (1)
           steps       = <integer>     // 50
           seed        = <integer>     // current system time
-          sampler     = ['ddim','plms']  // ddim
+          sampler     = ['ddim','plms','klms']  // klms
           grid        = <boolean>     // false
           width       = <integer>     // image width, multiple of 64 (512)
           height      = <integer>     // image height, multiple of 64 (512)
@@ -62,8 +62,9 @@ import time
 import math
 
 from ldm.util import instantiate_from_config
-from ldm.models.diffusion.ddim import DDIMSampler
-from ldm.models.diffusion.plms import PLMSSampler
+from ldm.models.diffusion.ddim     import DDIMSampler
+from ldm.models.diffusion.plms     import PLMSSampler
+from ldm.models.diffusion.ksampler import KSampler
 
 class T2I:
     """T2I class
@@ -101,12 +102,13 @@ class T2I:
                  cfg_scale=7.5,
                  weights="models/ldm/stable-diffusion-v1/model.ckpt",
                  config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml",
-                 sampler="plms",
+                 sampler="klms",
                  latent_channels=4,
                  downsampling_factor=8,
                  ddim_eta=0.0,  # deterministic
                  fixed_code=False,
                  precision='autocast',
+                 full_precision=False,
                  strength=0.75 # default in scripts/img2img.py
     ):
         self.outdir     = outdir
@@ -125,6 +127,7 @@ class T2I:
         self.downsampling_factor = downsampling_factor
         self.ddim_eta            = ddim_eta
         self.precision           = precision
+        self.full_precision      = full_precision
         self.strength            = strength
         self.model      = None     # empty for now
         self.sampler    = None
@@ -387,6 +390,9 @@ class T2I:
             elif self.sampler_name == 'ddim':
                 print("setting sampler to ddim")
                 self.sampler = DDIMSampler(self.model)
+            elif self.sampler_name == 'klms':
+                print("setting sampler to klms")
+                self.sampler = KSampler(self.model, 'lms')
             else:
                 print(f"unsupported sampler {self.sampler_name}, defaulting to plms")
                 self.sampler = PLMSSampler(self.model)
@@ -403,7 +409,11 @@ class T2I:
         m, u = model.load_state_dict(sd, strict=False)
         model.cuda()
         model.eval()
-        model.half()
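+        # half precision halves the memory footprint of the model weights;
+        # full precision is slower but avoids occasional numeric artifacts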
+        if self.full_precision:
+            print('Using slower but more accurate full-precision math (--full_precision)')
+        else:
+            print('Using half-precision math. Call with --full_precision to use slower but more accurate full precision.')
+            model.half()
         return model
 
     def _load_img(self,path):
diff --git a/scripts/dream.py b/scripts/dream.py
index 44b7d9978a..b8abb780fd 100755
--- a/scripts/dream.py
+++ b/scripts/dream.py
@@ -1,3 +1,4 @@
+#!/usr/bin/env python3
 import argparse
 import shlex
 import atexit
@@ -49,6 +50,7 @@ def main():
               outdir=opt.outdir,
               sampler=opt.sampler,
               weights=weights,
+              full_precision=opt.full_precision,
               config=config)
 
     # make sure the output directory exists
@@ -165,14 +167,18 @@ def create_argv_parser():
                         type=int,
                         default=1,
                         help="number of images to generate")
+    parser.add_argument('-F','--full_precision',
+                        dest='full_precision',
+                        action='store_true',
+                        help="use slower full precision math for calculations")
     parser.add_argument('-b','--batch_size',
                         type=int,
                         default=1,
                         help="number of images to produce per iteration (currently not working properly - producing too many images)")
     parser.add_argument('--sampler',
-                        choices=['plms','ddim'],
-                        default='plms',
-                        help="which sampler to use")
+                        choices=['plms','ddim','klms'],
+                        default='klms',
+                        help="which sampler to use (default: klms)")
     parser.add_argument('-o',
                         '--outdir',
                         type=str,
diff --git a/scripts/preload_models.py b/scripts/preload_models.py
index 7db461bec2..ad1a1eecc5 100644
--- a/scripts/preload_models.py
+++ b/scripts/preload_models.py
@@ -1,3 +1,4 @@
+#!/usr/bin/env python3
 # Before running stable-diffusion on an internet-isolated machine,
 # run this script from one with internet connectivity. The
 # two machines must share a common .cache directory.
diff --git a/scripts/txt2img.py b/scripts/txt2img.py
index da77e1a03e..42d5e83496 100644
--- a/scripts/txt2img.py
+++ b/scripts/txt2img.py
@@ -12,6 +12,10 @@ from pytorch_lightning import seed_everything
 from torch import autocast
 from contextlib import contextmanager, nullcontext
 
+import accelerate
+import k_diffusion as K
+import torch.nn as nn
+
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
@@ -80,6 +84,11 @@ def main():
         action='store_true',
         help="use plms sampling",
     )
+    parser.add_argument(
+        "--klms",
+        action='store_true',
+        help="use klms sampling",
+    )
     parser.add_argument(
         "--laion400m",
         action='store_true',
@@ -190,6 +199,22 @@ def main():
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     model = model.to(device)
 
+    # for klms: wrap the model in k-diffusion's denoiser interface and set up Accelerate
+    model_wrap = K.external.CompVisDenoiser(model)
+    accelerator = accelerate.Accelerator()
+    device = accelerator.device
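+    # note: Accelerate chooses the device here, overriding the torch.device
+    # selected above; on a single-GPU machine the two agree
+
+    # same classifier-free guidance wrapper as ldm/models/diffusion/ksampler.py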
+    class CFGDenoiser(nn.Module):
+        def __init__(self, model):
+            super().__init__()
+            self.inner_model = model
+
+        def forward(self, x, sigma, uncond, cond, cond_scale):
+            x_in = torch.cat([x] * 2)
+            sigma_in = torch.cat([sigma] * 2)
+            cond_in = torch.cat([uncond, cond])
+            uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+            return uncond + (cond - uncond) * cond_scale
+
     if opt.plms:
         sampler = PLMSSampler(model)
     else:
@@ -226,8 +251,8 @@ def main():
             with model.ema_scope():
                 tic = time.time()
                 all_samples = list()
-                for n in trange(opt.n_iter, desc="Sampling"):
-                    for prompts in tqdm(data, desc="data"):
+                for n in trange(opt.n_iter, desc="Sampling", disable=not accelerator.is_main_process):
+                    for prompts in tqdm(data, desc="data", disable=not accelerator.is_main_process):
                         uc = None
                         if opt.scale != 1.0:
                             uc = model.get_learned_conditioning(batch_size * [""])
@@ -235,18 +260,32 @@ def main():
                             prompts = list(prompts)
                         c = model.get_learned_conditioning(prompts)
                         shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
-                        samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
-                                                         conditioning=c,
-                                                         batch_size=opt.n_samples,
-                                                         shape=shape,
-                                                         verbose=False,
-                                                         unconditional_guidance_scale=opt.scale,
-                                                         unconditional_conditioning=uc,
-                                                         eta=opt.ddim_eta,
-                                                         x_T=start_code)
-
+
+                        if not opt.klms:
+                            samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
+                                                            conditioning=c,
+                                                            batch_size=opt.n_samples,
+                                                            shape=shape,
+                                                            verbose=False,
+                                                            unconditional_guidance_scale=opt.scale,
+                                                            unconditional_conditioning=uc,
+                                                            eta=opt.ddim_eta,
+                                                            x_T=start_code)
+                        else:
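+                            # k-diffusion LMS sampling: build the sigma schedule,
+                            # start from Gaussian noise scaled by the largest
+                            # sigma, and denoise with classifier-free guidance
+                            # via CFGDenoiser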
+                            sigmas = model_wrap.get_sigmas(opt.ddim_steps)
+                            if start_code is not None:
+                                x = start_code
+                            else:
+                                x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0] # for GPU draw
+                            model_wrap_cfg = CFGDenoiser(model_wrap)
+                            extra_args = {'cond': c, 'uncond': uc, 'cond_scale': opt.scale}
+                            samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process)
+
                         x_samples_ddim = model.decode_first_stage(samples_ddim)
                         x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+
+                        if opt.klms:
+                            x_samples_ddim = accelerator.gather(x_samples_ddim)
 
                         if not opt.skip_save:
                             for x_sample in x_samples_ddim: