From b5cdbd3b0b7ee6d7cadb8ba4fbc5cdce2820c46d Mon Sep 17 00:00:00 2001 From: Sean McLellan Date: Wed, 24 Aug 2022 13:14:08 -0400 Subject: [PATCH] Fixes issue with cuda/current mismatch --- ldm/models/diffusion/ddim.py | 6 +++++- ldm/models/diffusion/plms.py | 7 ++++++- ldm/simplet2i.py | 14 +++++++------- scripts/dream.py | 4 ++-- src/k-diffusion | 2 +- 5 files changed, 21 insertions(+), 12 deletions(-) diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index 2ebaeabd22..ddf786b5a8 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -10,13 +10,17 @@ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, mak class DDIMSampler(object): - def __init__(self, model, schedule="linear", **kwargs): + def __init__(self, model, schedule="linear", device="cuda", **kwargs): super().__init__() self.model = model self.ddpm_num_timesteps = model.num_timesteps self.schedule = schedule + self.device = device def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device(self.device): + attr = attr.to(torch.device(self.device)) setattr(self, name, attr) def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index 5d09f023f3..5eafe1d7ce 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -9,13 +9,18 @@ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, mak class PLMSSampler(object): - def __init__(self, model, schedule="linear", **kwargs): + def __init__(self, model, schedule="linear", device="cuda", **kwargs): super().__init__() self.model = model self.ddpm_num_timesteps = model.num_timesteps self.schedule = schedule + self.device = device def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device(self.device): + attr = attr.to(torch.device(self.device)) + setattr(self, name, attr) def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index 183e2448eb..09141c8ebb 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -58,7 +58,6 @@ import sys import os from omegaconf import OmegaConf from PIL import Image -import PIL from tqdm import tqdm, trange from itertools import islice from einops import rearrange, repeat @@ -286,7 +285,8 @@ The vast majority of these arguments default to reasonable values. @torch.no_grad() def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None, steps=None,seed=None,grid=None,individual=None,width=None,height=None, - cfg_scale=None,ddim_eta=None,strength=None,skip_normalize=False): + cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None, + skip_normalize=False,variants=None): """ Generate an image from the prompt and the initial image, writing iteration images into the outdir The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...] @@ -324,7 +324,7 @@ The vast majority of these arguments default to reasonable values. # PLMS sampler not supported yet, so ignore previous sampler if self.sampler_name!='ddim': print(f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler") - sampler = DDIMSampler(model) + sampler = DDIMSampler(model, device=self.device) else: sampler = self.sampler @@ -461,9 +461,9 @@ The vast majority of these arguments default to reasonable values. msg = f'setting sampler to {self.sampler_name}' if self.sampler_name=='plms': - self.sampler = PLMSSampler(self.model) + self.sampler = PLMSSampler(self.model, device=self.device) elif self.sampler_name == 'ddim': - self.sampler = DDIMSampler(self.model) + self.sampler = DDIMSampler(self.model, device=self.device) elif self.sampler_name == 'k_dpm_2_a': self.sampler = KSampler(self.model,'dpm_2_ancestral') elif self.sampler_name == 'k_dpm_2': @@ -478,7 +478,7 @@ The vast majority of these arguments default to reasonable values. self.sampler = KSampler(self.model,'lms') else: msg = f'unsupported sampler {self.sampler_name}, defaulting to plms' - self.sampler = PLMSSampler(self.model) + self.sampler = PLMSSampler(self.model, device=self.device) print(msg) @@ -505,7 +505,7 @@ The vast majority of these arguments default to reasonable values. w, h = image.size print(f"loaded input image of size ({w}, {h}) from {path}") w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = image.resize((w, h), resample=Image.Resampling.LANCZOS) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) diff --git a/scripts/dream.py b/scripts/dream.py index 94d5e9a1a6..886d8bd682 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -191,16 +191,16 @@ def main_loop(t2i,parser,log,infile): print(f"{newopt.init_img}") try: variantResults = t2i.img2img(**vars(newopt)) + allVariantResults.append([newopt,variantResults]) except AssertionError as e: print(e) continue - allVariantResults.append([newopt,variantResults]) print(f"{opt.variants} Variants generated!") print("Outputs:") write_log_message(t2i,opt,results,log) - if len(allVariantResults)>0: + if allVariantResults: print("Variant outputs:") for vr in allVariantResults: write_log_message(t2i,vr[0],vr[1],log) diff --git a/src/k-diffusion b/src/k-diffusion index db57990687..ef1bf07627 160000 --- a/src/k-diffusion +++ b/src/k-diffusion @@ -1 +1 @@ -Subproject commit db5799068749bf3a6d5845120ed32df16b7d883b +Subproject commit ef1bf07627c9a10ba9137e68a0206b844544a7d9