diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py
index 19af5baae6..33a09da9bf 100644
--- a/invokeai/app/invocations/flux_text_to_image.py
+++ b/invokeai/app/invocations/flux_text_to_image.py
@@ -112,7 +112,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         with transformer_info as transformer:
             assert isinstance(transformer, Flux)
 
-            def step_callback(img: torch.Tensor, state: PipelineIntermediateState) -> None:
+            def step_callback() -> None:
                 if context.util.is_canceled():
                     raise CanceledException
 
diff --git a/invokeai/backend/flux/sampling.py b/invokeai/backend/flux/sampling.py
index ab9d41797b..5001959e50 100644
--- a/invokeai/backend/flux/sampling.py
+++ b/invokeai/backend/flux/sampling.py
@@ -109,7 +109,7 @@ def denoise(
     vec: Tensor,
     # sampling parameters
     timesteps: list[float],
-    step_callback: Callable[[Tensor, PipelineIntermediateState], None],
+    step_callback: Callable[[], None],
     guidance: float = 4.0,
 ):
     dtype = model.txt_in.bias.dtype
@@ -123,7 +123,6 @@ def denoise(
     # this is ignored for schnell
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
 
-    step_count = 0
     for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
         pred = model(
@@ -137,17 +136,7 @@ def denoise(
         )
 
         img = img + (t_prev - t_curr) * pred
 
-        step_callback(
-            img,
-            PipelineIntermediateState(
-                step=step_count,
-                order=0,
-                total_steps=len(timesteps),
-                timestep=math.floor(t_curr),
-                latents=img,
-            ),
-        )
-        step_count += 1
+        step_callback()
 
     return img
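
For reviewers, a minimal, self-contained sketch of the callback contract this diff establishes: denoise() now invokes a zero-argument Callable[[], None] once per sampling step, so any cancellation check lives in the caller's closure rather than in a PipelineIntermediateState payload. Aside from the step_callback shape and the per-step loop structure, every name below is an illustrative stand-in, not InvokeAI API.

from typing import Callable


class CanceledException(Exception):
    """Illustrative stand-in for InvokeAI's CanceledException."""


def make_step_callback(is_canceled: Callable[[], bool]) -> Callable[[], None]:
    # Matches the simplified signature from this diff: no latents, no
    # PipelineIntermediateState -- just a per-step cancellation poll.
    def step_callback() -> None:
        if is_canceled():
            raise CanceledException()

    return step_callback


def fake_denoise_loop(timesteps: list[float], step_callback: Callable[[], None]) -> None:
    # Mirrors the loop shape in sampling.py: one callback call per
    # (t_curr, t_prev) step pair.
    for _t_curr, _t_prev in zip(timesteps[:-1], timesteps[1:], strict=True):
        step_callback()


# Usage: runs to completion; a closure returning True would raise on step 1.
fake_denoise_loop([1.0, 0.75, 0.5, 0.25, 0.0], make_step_callback(lambda: False))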