mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fixes issue with cuda/current mismatch
This commit is contained in:
parent
c6b5e930dc
commit
b5cdbd3b0b
@ -10,13 +10,17 @@ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, mak
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
def __init__(self, model, schedule="linear", **kwargs):
|
||||
def __init__(self, model, schedule="linear", device="cuda", **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
self.device = device
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device(self.device):
|
||||
attr = attr.to(torch.device(self.device))
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||
|
@ -9,13 +9,18 @@ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, mak
|
||||
|
||||
|
||||
class PLMSSampler(object):
|
||||
def __init__(self, model, schedule="linear", **kwargs):
|
||||
def __init__(self, model, schedule="linear", device="cuda", **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
self.device = device
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device(self.device):
|
||||
attr = attr.to(torch.device(self.device))
|
||||
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||
|
@ -58,7 +58,6 @@ import sys
|
||||
import os
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
import PIL
|
||||
from tqdm import tqdm, trange
|
||||
from itertools import islice
|
||||
from einops import rearrange, repeat
|
||||
@ -286,7 +285,8 @@ The vast majority of these arguments default to reasonable values.
|
||||
@torch.no_grad()
|
||||
def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None,
|
||||
steps=None,seed=None,grid=None,individual=None,width=None,height=None,
|
||||
cfg_scale=None,ddim_eta=None,strength=None,skip_normalize=False):
|
||||
cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None,
|
||||
skip_normalize=False,variants=None):
|
||||
"""
|
||||
Generate an image from the prompt and the initial image, writing iteration images into the outdir
|
||||
The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
|
||||
@ -324,7 +324,7 @@ The vast majority of these arguments default to reasonable values.
|
||||
# PLMS sampler not supported yet, so ignore previous sampler
|
||||
if self.sampler_name!='ddim':
|
||||
print(f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler")
|
||||
sampler = DDIMSampler(model)
|
||||
sampler = DDIMSampler(model, device=self.device)
|
||||
else:
|
||||
sampler = self.sampler
|
||||
|
||||
@ -461,9 +461,9 @@ The vast majority of these arguments default to reasonable values.
|
||||
|
||||
msg = f'setting sampler to {self.sampler_name}'
|
||||
if self.sampler_name=='plms':
|
||||
self.sampler = PLMSSampler(self.model)
|
||||
self.sampler = PLMSSampler(self.model, device=self.device)
|
||||
elif self.sampler_name == 'ddim':
|
||||
self.sampler = DDIMSampler(self.model)
|
||||
self.sampler = DDIMSampler(self.model, device=self.device)
|
||||
elif self.sampler_name == 'k_dpm_2_a':
|
||||
self.sampler = KSampler(self.model,'dpm_2_ancestral')
|
||||
elif self.sampler_name == 'k_dpm_2':
|
||||
@ -478,7 +478,7 @@ The vast majority of these arguments default to reasonable values.
|
||||
self.sampler = KSampler(self.model,'lms')
|
||||
else:
|
||||
msg = f'unsupported sampler {self.sampler_name}, defaulting to plms'
|
||||
self.sampler = PLMSSampler(self.model)
|
||||
self.sampler = PLMSSampler(self.model, device=self.device)
|
||||
|
||||
print(msg)
|
||||
|
||||
@ -505,7 +505,7 @@ The vast majority of these arguments default to reasonable values.
|
||||
w, h = image.size
|
||||
print(f"loaded input image of size ({w}, {h}) from {path}")
|
||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
|
||||
image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
|
@ -191,16 +191,16 @@ def main_loop(t2i,parser,log,infile):
|
||||
print(f"{newopt.init_img}")
|
||||
try:
|
||||
variantResults = t2i.img2img(**vars(newopt))
|
||||
allVariantResults.append([newopt,variantResults])
|
||||
except AssertionError as e:
|
||||
print(e)
|
||||
continue
|
||||
allVariantResults.append([newopt,variantResults])
|
||||
print(f"{opt.variants} Variants generated!")
|
||||
|
||||
print("Outputs:")
|
||||
write_log_message(t2i,opt,results,log)
|
||||
|
||||
if len(allVariantResults)>0:
|
||||
if allVariantResults:
|
||||
print("Variant outputs:")
|
||||
for vr in allVariantResults:
|
||||
write_log_message(t2i,vr[0],vr[1],log)
|
||||
|
@ -1 +1 @@
|
||||
Subproject commit db5799068749bf3a6d5845120ed32df16b7d883b
|
||||
Subproject commit ef1bf07627c9a10ba9137e68a0206b844544a7d9
|
Loading…
Reference in New Issue
Block a user