Fixes issue with cuda/current mismatch

This commit is contained in:
Sean McLellan 2022-08-24 13:14:08 -04:00
parent c6b5e930dc
commit b5cdbd3b0b
5 changed files with 21 additions and 12 deletions

View File

@ -10,13 +10,17 @@ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, mak
class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
def __init__(self, model, schedule="linear", device="cuda", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device(self.device):
attr = attr.to(torch.device(self.device))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):

View File

@ -9,13 +9,18 @@ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, mak
class PLMSSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
def __init__(self, model, schedule="linear", device="cuda", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device(self.device):
attr = attr.to(torch.device(self.device))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):

View File

@ -58,7 +58,6 @@ import sys
import os
from omegaconf import OmegaConf
from PIL import Image
import PIL
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange, repeat
@ -286,7 +285,8 @@ The vast majority of these arguments default to reasonable values.
@torch.no_grad()
def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None,
steps=None,seed=None,grid=None,individual=None,width=None,height=None,
cfg_scale=None,ddim_eta=None,strength=None,skip_normalize=False):
cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None,
skip_normalize=False,variants=None):
"""
Generate an image from the prompt and the initial image, writing iteration images into the outdir
The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
@ -324,7 +324,7 @@ The vast majority of these arguments default to reasonable values.
# PLMS sampler not supported yet, so ignore previous sampler
if self.sampler_name!='ddim':
print(f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler")
sampler = DDIMSampler(model)
sampler = DDIMSampler(model, device=self.device)
else:
sampler = self.sampler
@ -461,9 +461,9 @@ The vast majority of these arguments default to reasonable values.
msg = f'setting sampler to {self.sampler_name}'
if self.sampler_name=='plms':
self.sampler = PLMSSampler(self.model)
self.sampler = PLMSSampler(self.model, device=self.device)
elif self.sampler_name == 'ddim':
self.sampler = DDIMSampler(self.model)
self.sampler = DDIMSampler(self.model, device=self.device)
elif self.sampler_name == 'k_dpm_2_a':
self.sampler = KSampler(self.model,'dpm_2_ancestral')
elif self.sampler_name == 'k_dpm_2':
@ -478,7 +478,7 @@ The vast majority of these arguments default to reasonable values.
self.sampler = KSampler(self.model,'lms')
else:
msg = f'unsupported sampler {self.sampler_name}, defaulting to plms'
self.sampler = PLMSSampler(self.model)
self.sampler = PLMSSampler(self.model, device=self.device)
print(msg)
@ -505,7 +505,7 @@ The vast majority of these arguments default to reasonable values.
w, h = image.size
print(f"loaded input image of size ({w}, {h}) from {path}")
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)

View File

@ -191,16 +191,16 @@ def main_loop(t2i,parser,log,infile):
print(f"{newopt.init_img}")
try:
variantResults = t2i.img2img(**vars(newopt))
allVariantResults.append([newopt,variantResults])
except AssertionError as e:
print(e)
continue
allVariantResults.append([newopt,variantResults])
print(f"{opt.variants} Variants generated!")
print("Outputs:")
write_log_message(t2i,opt,results,log)
if len(allVariantResults)>0:
if allVariantResults:
print("Variant outputs:")
for vr in allVariantResults:
write_log_message(t2i,vr[0],vr[1],log)

@ -1 +1 @@
Subproject commit db5799068749bf3a6d5845120ed32df16b7d883b
Subproject commit ef1bf07627c9a10ba9137e68a0206b844544a7d9