diff --git a/ldm/generate.py b/ldm/generate.py index d268b909db..292329d24c 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -34,23 +34,7 @@ from ldm.dream.image_util import InitImageResizer from ldm.dream.devices import choose_torch_device, choose_precision from ldm.dream.conditioning import get_uc_and_c -def fix_func(orig): - if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): - def new_func(*args, **kw): - device = kw.get("device", "mps") - kw["device"]="cpu" - return orig(*args, **kw).to(device) - return new_func - return orig -torch.rand = fix_func(torch.rand) -torch.rand_like = fix_func(torch.rand_like) -torch.randn = fix_func(torch.randn) -torch.randn_like = fix_func(torch.randn_like) -torch.randint = fix_func(torch.randint) -torch.randint_like = fix_func(torch.randint_like) -torch.bernoulli = fix_func(torch.bernoulli) -torch.multinomial = fix_func(torch.multinomial) def fix_func(orig): if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): @@ -70,23 +54,7 @@ torch.randint_like = fix_func(torch.randint_like) torch.bernoulli = fix_func(torch.bernoulli) torch.multinomial = fix_func(torch.multinomial) -def fix_func(orig): - if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): - def new_func(*args, **kw): - device = kw.get("device", "mps") - kw["device"]="cpu" - return orig(*args, **kw).to(device) - return new_func - return orig -torch.rand = fix_func(torch.rand) -torch.rand_like = fix_func(torch.rand_like) -torch.randn = fix_func(torch.randn) -torch.randn_like = fix_func(torch.randn_like) -torch.randint = fix_func(torch.randint) -torch.randint_like = fix_func(torch.randint_like) -torch.bernoulli = fix_func(torch.bernoulli) -torch.multinomial = fix_func(torch.multinomial) """Simplified text to image API for stable diffusion/latent diffusion diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py index a31ec96d6f..3056cbb6b8 100644 --- a/ldm/models/diffusion/sampler.py +++ b/ldm/models/diffusion/sampler.py @@ -39,6 +39,7 @@ class Sampler(object): ddim_eta=0.0, verbose=False, ): + self.total_steps = ddim_num_steps self.ddim_timesteps = make_ddim_timesteps( ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, @@ -211,6 +212,7 @@ class Sampler(object): if ddim_use_original_steps else np.flip(timesteps) ) + total_steps=steps iterator = tqdm( @@ -305,7 +307,7 @@ class Sampler(object): time_range = np.flip(timesteps) total_steps = timesteps.shape[0] - print(f'>> Running {self.__class__.__name__} Sampling with {total_steps} timesteps') + print(f'>> Running {self.__class__.__name__} sampling starting at step {self.total_steps - t_start} of {self.total_steps} ({total_steps} new sampling steps)') iterator = tqdm(time_range, desc='Decoding image', total=total_steps) x_dec = x_latent @@ -351,11 +353,10 @@ class Sampler(object): return x_dec def get_initial_image(self,x_T,shape,timesteps=None): - x = torch.randn(shape, device=self.device) if x_T is None: - return x + return torch.randn(shape, device=self.device) else: - return x_T + x + return x_T def p_sample( self,