mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
keep the effect of _start and _end arguments consistent across k* and other samplers
This commit is contained in:
parent
ee4273d760
commit
cc2042bd4c
@ -51,7 +51,8 @@ class Img2Img(Generator):
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
init_latent = self.init_latent, # changes how noising is performed in ksampler
|
||||
extra_conditioning_info = extra_conditioning_info
|
||||
extra_conditioning_info = extra_conditioning_info,
|
||||
all_timesteps_count = steps
|
||||
)
|
||||
|
||||
return self.sample_to_image(samples)
|
||||
|
@ -29,6 +29,10 @@ class CrossAttentionControl:
|
||||
|
||||
class Context:
|
||||
def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int):
|
||||
"""
|
||||
:param arguments: Arguments for the cross-attention control process
|
||||
:param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
|
||||
"""
|
||||
self.arguments = arguments
|
||||
self.step_count = step_count
|
||||
|
||||
|
@ -16,9 +16,10 @@ class DDIMSampler(Sampler):
|
||||
super().prepare_to_sample(t_enc, **kwargs)
|
||||
|
||||
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
|
||||
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
|
||||
|
||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc)
|
||||
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count)
|
||||
else:
|
||||
self.invokeai_diffuser.remove_cross_attention_control()
|
||||
|
||||
|
@ -21,9 +21,10 @@ class PLMSSampler(Sampler):
|
||||
super().prepare_to_sample(t_enc, **kwargs)
|
||||
|
||||
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
|
||||
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
|
||||
|
||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc)
|
||||
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count)
|
||||
else:
|
||||
self.invokeai_diffuser.remove_cross_attention_control()
|
||||
|
||||
|
@ -235,7 +235,7 @@ class Sampler(object):
|
||||
dynamic_ncols=True,
|
||||
)
|
||||
old_eps = []
|
||||
self.prepare_to_sample(t_enc=total_steps,**kwargs)
|
||||
self.prepare_to_sample(t_enc=total_steps,all_timesteps_count=steps,**kwargs)
|
||||
img = self.get_initial_image(x_T,shape,total_steps)
|
||||
|
||||
# probably don't need this at all
|
||||
@ -310,6 +310,7 @@ class Sampler(object):
|
||||
use_original_steps=False,
|
||||
init_latent = None,
|
||||
mask = None,
|
||||
all_timesteps_count = None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
@ -327,7 +328,7 @@ class Sampler(object):
|
||||
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||
x_dec = x_latent
|
||||
x0 = init_latent
|
||||
self.prepare_to_sample(t_enc=total_steps,**kwargs)
|
||||
self.prepare_to_sample(t_enc=total_steps, all_timesteps_count=all_timesteps_count, **kwargs)
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
@ -359,7 +360,7 @@ class Sampler(object):
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
t_next = ts_next,
|
||||
step_count=total_steps
|
||||
step_count=len(self.ddim_timesteps)
|
||||
)
|
||||
|
||||
x_dec, pred_x0, e_t = outs
|
||||
|
Loading…
Reference in New Issue
Block a user