Fixes issue with cuda/current mismatch

This commit is contained in:
Sean McLellan 2022-08-24 13:14:08 -04:00
parent c6b5e930dc
commit b5cdbd3b0b
5 changed files with 21 additions and 12 deletions

View File

@ -10,13 +10,17 @@ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, mak
class DDIMSampler(object): class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs): def __init__(self, model, schedule="linear", device="cuda", **kwargs):
super().__init__() super().__init__()
self.model = model self.model = model
self.ddpm_num_timesteps = model.num_timesteps self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule self.schedule = schedule
self.device = device
def register_buffer(self, name, attr): def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device(self.device):
attr = attr.to(torch.device(self.device))
setattr(self, name, attr) setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):

View File

@ -9,13 +9,18 @@ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, mak
class PLMSSampler(object): class PLMSSampler(object):
def __init__(self, model, schedule="linear", **kwargs): def __init__(self, model, schedule="linear", device="cuda", **kwargs):
super().__init__() super().__init__()
self.model = model self.model = model
self.ddpm_num_timesteps = model.num_timesteps self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule self.schedule = schedule
self.device = device
def register_buffer(self, name, attr): def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device(self.device):
attr = attr.to(torch.device(self.device))
setattr(self, name, attr) setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):

View File

@ -58,7 +58,6 @@ import sys
import os import os
from omegaconf import OmegaConf from omegaconf import OmegaConf
from PIL import Image from PIL import Image
import PIL
from tqdm import tqdm, trange from tqdm import tqdm, trange
from itertools import islice from itertools import islice
from einops import rearrange, repeat from einops import rearrange, repeat
@ -286,7 +285,8 @@ The vast majority of these arguments default to reasonable values.
@torch.no_grad() @torch.no_grad()
def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None, def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None,
steps=None,seed=None,grid=None,individual=None,width=None,height=None, steps=None,seed=None,grid=None,individual=None,width=None,height=None,
cfg_scale=None,ddim_eta=None,strength=None,skip_normalize=False): cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None,
skip_normalize=False,variants=None):
""" """
Generate an image from the prompt and the initial image, writing iteration images into the outdir Generate an image from the prompt and the initial image, writing iteration images into the outdir
The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...] The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
@ -324,7 +324,7 @@ The vast majority of these arguments default to reasonable values.
# PLMS sampler not supported yet, so ignore previous sampler # PLMS sampler not supported yet, so ignore previous sampler
if self.sampler_name!='ddim': if self.sampler_name!='ddim':
print(f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler") print(f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler")
sampler = DDIMSampler(model) sampler = DDIMSampler(model, device=self.device)
else: else:
sampler = self.sampler sampler = self.sampler
@ -461,9 +461,9 @@ The vast majority of these arguments default to reasonable values.
msg = f'setting sampler to {self.sampler_name}' msg = f'setting sampler to {self.sampler_name}'
if self.sampler_name=='plms': if self.sampler_name=='plms':
self.sampler = PLMSSampler(self.model) self.sampler = PLMSSampler(self.model, device=self.device)
elif self.sampler_name == 'ddim': elif self.sampler_name == 'ddim':
self.sampler = DDIMSampler(self.model) self.sampler = DDIMSampler(self.model, device=self.device)
elif self.sampler_name == 'k_dpm_2_a': elif self.sampler_name == 'k_dpm_2_a':
self.sampler = KSampler(self.model,'dpm_2_ancestral') self.sampler = KSampler(self.model,'dpm_2_ancestral')
elif self.sampler_name == 'k_dpm_2': elif self.sampler_name == 'k_dpm_2':
@ -478,7 +478,7 @@ The vast majority of these arguments default to reasonable values.
self.sampler = KSampler(self.model,'lms') self.sampler = KSampler(self.model,'lms')
else: else:
msg = f'unsupported sampler {self.sampler_name}, defaulting to plms' msg = f'unsupported sampler {self.sampler_name}, defaulting to plms'
self.sampler = PLMSSampler(self.model) self.sampler = PLMSSampler(self.model, device=self.device)
print(msg) print(msg)
@ -505,7 +505,7 @@ The vast majority of these arguments default to reasonable values.
w, h = image.size w, h = image.size
print(f"loaded input image of size ({w}, {h}) from {path}") print(f"loaded input image of size ({w}, {h}) from {path}")
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS) image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2) image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image) image = torch.from_numpy(image)

View File

@ -191,16 +191,16 @@ def main_loop(t2i,parser,log,infile):
print(f"{newopt.init_img}") print(f"{newopt.init_img}")
try: try:
variantResults = t2i.img2img(**vars(newopt)) variantResults = t2i.img2img(**vars(newopt))
allVariantResults.append([newopt,variantResults])
except AssertionError as e: except AssertionError as e:
print(e) print(e)
continue continue
allVariantResults.append([newopt,variantResults])
print(f"{opt.variants} Variants generated!") print(f"{opt.variants} Variants generated!")
print("Outputs:") print("Outputs:")
write_log_message(t2i,opt,results,log) write_log_message(t2i,opt,results,log)
if len(allVariantResults)>0: if allVariantResults:
print("Variant outputs:") print("Variant outputs:")
for vr in allVariantResults: for vr in allVariantResults:
write_log_message(t2i,vr[0],vr[1],log) write_log_message(t2i,vr[0],vr[1],log)

@ -1 +1 @@
Subproject commit db5799068749bf3a6d5845120ed32df16b7d883b Subproject commit ef1bf07627c9a10ba9137e68a0206b844544a7d9