mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix broken image generation on plms and ddim samplers
This commit is contained in:
parent
c1230da3ab
commit
333219be35
@ -39,6 +39,7 @@ class Sampler(object):
|
||||
ddim_eta=0.0,
|
||||
verbose=False,
|
||||
):
|
||||
self.total_steps = ddim_num_steps
|
||||
self.ddim_timesteps = make_ddim_timesteps(
|
||||
ddim_discr_method=ddim_discretize,
|
||||
num_ddim_timesteps=ddim_num_steps,
|
||||
@ -211,6 +212,7 @@ class Sampler(object):
|
||||
if ddim_use_original_steps
|
||||
else np.flip(timesteps)
|
||||
)
|
||||
|
||||
total_steps=steps
|
||||
|
||||
iterator = tqdm(
|
||||
@ -305,7 +307,7 @@ class Sampler(object):
|
||||
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
print(f'>> Running {self.__class__.__name__} Sampling with {total_steps} timesteps')
|
||||
print(f'>> Running {self.__class__.__name__} sampling starting at step {self.total_steps - t_start} of {self.total_steps} ({total_steps} new sampling steps)')
|
||||
|
||||
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||
x_dec = x_latent
|
||||
@ -351,11 +353,10 @@ class Sampler(object):
|
||||
return x_dec
|
||||
|
||||
def get_initial_image(self,x_T,shape,timesteps=None):
|
||||
x = torch.randn(shape, device=self.device)
|
||||
if x_T is None:
|
||||
return x
|
||||
return torch.randn(shape, device=self.device)
|
||||
else:
|
||||
return x_T + x
|
||||
return x_T
|
||||
|
||||
def p_sample(
|
||||
self,
|
||||
|
Loading…
Reference in New Issue
Block a user