mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'development' into fnformat
This commit is contained in:
commit
173dc34194
@ -34,23 +34,7 @@ from ldm.dream.image_util import InitImageResizer
|
|||||||
from ldm.dream.devices import choose_torch_device, choose_precision
|
from ldm.dream.devices import choose_torch_device, choose_precision
|
||||||
from ldm.dream.conditioning import get_uc_and_c
|
from ldm.dream.conditioning import get_uc_and_c
|
||||||
|
|
||||||
def fix_func(orig):
|
|
||||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
|
||||||
def new_func(*args, **kw):
|
|
||||||
device = kw.get("device", "mps")
|
|
||||||
kw["device"]="cpu"
|
|
||||||
return orig(*args, **kw).to(device)
|
|
||||||
return new_func
|
|
||||||
return orig
|
|
||||||
|
|
||||||
torch.rand = fix_func(torch.rand)
|
|
||||||
torch.rand_like = fix_func(torch.rand_like)
|
|
||||||
torch.randn = fix_func(torch.randn)
|
|
||||||
torch.randn_like = fix_func(torch.randn_like)
|
|
||||||
torch.randint = fix_func(torch.randint)
|
|
||||||
torch.randint_like = fix_func(torch.randint_like)
|
|
||||||
torch.bernoulli = fix_func(torch.bernoulli)
|
|
||||||
torch.multinomial = fix_func(torch.multinomial)
|
|
||||||
|
|
||||||
def fix_func(orig):
|
def fix_func(orig):
|
||||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||||
@ -70,23 +54,7 @@ torch.randint_like = fix_func(torch.randint_like)
|
|||||||
torch.bernoulli = fix_func(torch.bernoulli)
|
torch.bernoulli = fix_func(torch.bernoulli)
|
||||||
torch.multinomial = fix_func(torch.multinomial)
|
torch.multinomial = fix_func(torch.multinomial)
|
||||||
|
|
||||||
def fix_func(orig):
|
|
||||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
|
||||||
def new_func(*args, **kw):
|
|
||||||
device = kw.get("device", "mps")
|
|
||||||
kw["device"]="cpu"
|
|
||||||
return orig(*args, **kw).to(device)
|
|
||||||
return new_func
|
|
||||||
return orig
|
|
||||||
|
|
||||||
torch.rand = fix_func(torch.rand)
|
|
||||||
torch.rand_like = fix_func(torch.rand_like)
|
|
||||||
torch.randn = fix_func(torch.randn)
|
|
||||||
torch.randn_like = fix_func(torch.randn_like)
|
|
||||||
torch.randint = fix_func(torch.randint)
|
|
||||||
torch.randint_like = fix_func(torch.randint_like)
|
|
||||||
torch.bernoulli = fix_func(torch.bernoulli)
|
|
||||||
torch.multinomial = fix_func(torch.multinomial)
|
|
||||||
|
|
||||||
"""Simplified text to image API for stable diffusion/latent diffusion
|
"""Simplified text to image API for stable diffusion/latent diffusion
|
||||||
|
|
||||||
|
@ -39,6 +39,7 @@ class Sampler(object):
|
|||||||
ddim_eta=0.0,
|
ddim_eta=0.0,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
):
|
):
|
||||||
|
self.total_steps = ddim_num_steps
|
||||||
self.ddim_timesteps = make_ddim_timesteps(
|
self.ddim_timesteps = make_ddim_timesteps(
|
||||||
ddim_discr_method=ddim_discretize,
|
ddim_discr_method=ddim_discretize,
|
||||||
num_ddim_timesteps=ddim_num_steps,
|
num_ddim_timesteps=ddim_num_steps,
|
||||||
@ -211,6 +212,7 @@ class Sampler(object):
|
|||||||
if ddim_use_original_steps
|
if ddim_use_original_steps
|
||||||
else np.flip(timesteps)
|
else np.flip(timesteps)
|
||||||
)
|
)
|
||||||
|
|
||||||
total_steps=steps
|
total_steps=steps
|
||||||
|
|
||||||
iterator = tqdm(
|
iterator = tqdm(
|
||||||
@ -305,7 +307,7 @@ class Sampler(object):
|
|||||||
|
|
||||||
time_range = np.flip(timesteps)
|
time_range = np.flip(timesteps)
|
||||||
total_steps = timesteps.shape[0]
|
total_steps = timesteps.shape[0]
|
||||||
print(f'>> Running {self.__class__.__name__} Sampling with {total_steps} timesteps')
|
print(f'>> Running {self.__class__.__name__} sampling starting at step {self.total_steps - t_start} of {self.total_steps} ({total_steps} new sampling steps)')
|
||||||
|
|
||||||
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||||
x_dec = x_latent
|
x_dec = x_latent
|
||||||
@ -351,11 +353,10 @@ class Sampler(object):
|
|||||||
return x_dec
|
return x_dec
|
||||||
|
|
||||||
def get_initial_image(self,x_T,shape,timesteps=None):
|
def get_initial_image(self,x_T,shape,timesteps=None):
|
||||||
x = torch.randn(shape, device=self.device)
|
|
||||||
if x_T is None:
|
if x_T is None:
|
||||||
return x
|
return torch.randn(shape, device=self.device)
|
||||||
else:
|
else:
|
||||||
return x_T + x
|
return x_T
|
||||||
|
|
||||||
def p_sample(
|
def p_sample(
|
||||||
self,
|
self,
|
||||||
|
Loading…
Reference in New Issue
Block a user