mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fixes issue with cuda/current mismatch
This commit is contained in:
parent
c6b5e930dc
commit
b5cdbd3b0b
@ -10,13 +10,17 @@ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, mak
|
|||||||
|
|
||||||
|
|
||||||
class DDIMSampler(object):
|
class DDIMSampler(object):
|
||||||
def __init__(self, model, schedule="linear", **kwargs):
|
def __init__(self, model, schedule="linear", device="cuda", **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = model
|
self.model = model
|
||||||
self.ddpm_num_timesteps = model.num_timesteps
|
self.ddpm_num_timesteps = model.num_timesteps
|
||||||
self.schedule = schedule
|
self.schedule = schedule
|
||||||
|
self.device = device
|
||||||
|
|
||||||
def register_buffer(self, name, attr):
|
def register_buffer(self, name, attr):
|
||||||
|
if type(attr) == torch.Tensor:
|
||||||
|
if attr.device != torch.device(self.device):
|
||||||
|
attr = attr.to(torch.device(self.device))
|
||||||
setattr(self, name, attr)
|
setattr(self, name, attr)
|
||||||
|
|
||||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||||
|
@ -9,13 +9,18 @@ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, mak
|
|||||||
|
|
||||||
|
|
||||||
class PLMSSampler(object):
|
class PLMSSampler(object):
|
||||||
def __init__(self, model, schedule="linear", **kwargs):
|
def __init__(self, model, schedule="linear", device="cuda", **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = model
|
self.model = model
|
||||||
self.ddpm_num_timesteps = model.num_timesteps
|
self.ddpm_num_timesteps = model.num_timesteps
|
||||||
self.schedule = schedule
|
self.schedule = schedule
|
||||||
|
self.device = device
|
||||||
|
|
||||||
def register_buffer(self, name, attr):
|
def register_buffer(self, name, attr):
|
||||||
|
if type(attr) == torch.Tensor:
|
||||||
|
if attr.device != torch.device(self.device):
|
||||||
|
attr = attr.to(torch.device(self.device))
|
||||||
|
|
||||||
setattr(self, name, attr)
|
setattr(self, name, attr)
|
||||||
|
|
||||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||||
|
@ -58,7 +58,6 @@ import sys
|
|||||||
import os
|
import os
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import PIL
|
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
@ -286,7 +285,8 @@ The vast majority of these arguments default to reasonable values.
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None,
|
def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None,
|
||||||
steps=None,seed=None,grid=None,individual=None,width=None,height=None,
|
steps=None,seed=None,grid=None,individual=None,width=None,height=None,
|
||||||
cfg_scale=None,ddim_eta=None,strength=None,skip_normalize=False):
|
cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None,
|
||||||
|
skip_normalize=False,variants=None):
|
||||||
"""
|
"""
|
||||||
Generate an image from the prompt and the initial image, writing iteration images into the outdir
|
Generate an image from the prompt and the initial image, writing iteration images into the outdir
|
||||||
The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
|
The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
|
||||||
@ -324,7 +324,7 @@ The vast majority of these arguments default to reasonable values.
|
|||||||
# PLMS sampler not supported yet, so ignore previous sampler
|
# PLMS sampler not supported yet, so ignore previous sampler
|
||||||
if self.sampler_name!='ddim':
|
if self.sampler_name!='ddim':
|
||||||
print(f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler")
|
print(f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler")
|
||||||
sampler = DDIMSampler(model)
|
sampler = DDIMSampler(model, device=self.device)
|
||||||
else:
|
else:
|
||||||
sampler = self.sampler
|
sampler = self.sampler
|
||||||
|
|
||||||
@ -461,9 +461,9 @@ The vast majority of these arguments default to reasonable values.
|
|||||||
|
|
||||||
msg = f'setting sampler to {self.sampler_name}'
|
msg = f'setting sampler to {self.sampler_name}'
|
||||||
if self.sampler_name=='plms':
|
if self.sampler_name=='plms':
|
||||||
self.sampler = PLMSSampler(self.model)
|
self.sampler = PLMSSampler(self.model, device=self.device)
|
||||||
elif self.sampler_name == 'ddim':
|
elif self.sampler_name == 'ddim':
|
||||||
self.sampler = DDIMSampler(self.model)
|
self.sampler = DDIMSampler(self.model, device=self.device)
|
||||||
elif self.sampler_name == 'k_dpm_2_a':
|
elif self.sampler_name == 'k_dpm_2_a':
|
||||||
self.sampler = KSampler(self.model,'dpm_2_ancestral')
|
self.sampler = KSampler(self.model,'dpm_2_ancestral')
|
||||||
elif self.sampler_name == 'k_dpm_2':
|
elif self.sampler_name == 'k_dpm_2':
|
||||||
@ -478,7 +478,7 @@ The vast majority of these arguments default to reasonable values.
|
|||||||
self.sampler = KSampler(self.model,'lms')
|
self.sampler = KSampler(self.model,'lms')
|
||||||
else:
|
else:
|
||||||
msg = f'unsupported sampler {self.sampler_name}, defaulting to plms'
|
msg = f'unsupported sampler {self.sampler_name}, defaulting to plms'
|
||||||
self.sampler = PLMSSampler(self.model)
|
self.sampler = PLMSSampler(self.model, device=self.device)
|
||||||
|
|
||||||
print(msg)
|
print(msg)
|
||||||
|
|
||||||
@ -505,7 +505,7 @@ The vast majority of these arguments default to reasonable values.
|
|||||||
w, h = image.size
|
w, h = image.size
|
||||||
print(f"loaded input image of size ({w}, {h}) from {path}")
|
print(f"loaded input image of size ({w}, {h}) from {path}")
|
||||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||||
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
|
image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
|
||||||
image = np.array(image).astype(np.float32) / 255.0
|
image = np.array(image).astype(np.float32) / 255.0
|
||||||
image = image[None].transpose(0, 3, 1, 2)
|
image = image[None].transpose(0, 3, 1, 2)
|
||||||
image = torch.from_numpy(image)
|
image = torch.from_numpy(image)
|
||||||
|
@ -191,16 +191,16 @@ def main_loop(t2i,parser,log,infile):
|
|||||||
print(f"{newopt.init_img}")
|
print(f"{newopt.init_img}")
|
||||||
try:
|
try:
|
||||||
variantResults = t2i.img2img(**vars(newopt))
|
variantResults = t2i.img2img(**vars(newopt))
|
||||||
|
allVariantResults.append([newopt,variantResults])
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
print(e)
|
print(e)
|
||||||
continue
|
continue
|
||||||
allVariantResults.append([newopt,variantResults])
|
|
||||||
print(f"{opt.variants} Variants generated!")
|
print(f"{opt.variants} Variants generated!")
|
||||||
|
|
||||||
print("Outputs:")
|
print("Outputs:")
|
||||||
write_log_message(t2i,opt,results,log)
|
write_log_message(t2i,opt,results,log)
|
||||||
|
|
||||||
if len(allVariantResults)>0:
|
if allVariantResults:
|
||||||
print("Variant outputs:")
|
print("Variant outputs:")
|
||||||
for vr in allVariantResults:
|
for vr in allVariantResults:
|
||||||
write_log_message(t2i,vr[0],vr[1],log)
|
write_log_message(t2i,vr[0],vr[1],log)
|
||||||
|
@ -1 +1 @@
|
|||||||
Subproject commit db5799068749bf3a6d5845120ed32df16b7d883b
|
Subproject commit ef1bf07627c9a10ba9137e68a0206b844544a7d9
|
Loading…
Reference in New Issue
Block a user