Remove no longer used code in the flux denoise function

This commit is contained in:
Brandon Rising 2024-08-26 13:07:31 -04:00 committed by Brandon
parent 3ea6c9666e
commit 849da67cc7
2 changed files with 3 additions and 14 deletions

View File

@ -112,7 +112,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
with transformer_info as transformer:
assert isinstance(transformer, Flux)
def step_callback(img: torch.Tensor, state: PipelineIntermediateState) -> None:
def step_callback() -> None:
if context.util.is_canceled():
raise CanceledException

View File

@ -109,7 +109,7 @@ def denoise(
vec: Tensor,
# sampling parameters
timesteps: list[float],
step_callback: Callable[[Tensor, PipelineIntermediateState], None],
step_callback: Callable[[], None],
guidance: float = 4.0,
):
dtype = model.txt_in.bias.dtype
@ -123,7 +123,6 @@ def denoise(
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
step_count = 0
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
pred = model(
@ -137,17 +136,7 @@ def denoise(
)
img = img + (t_prev - t_curr) * pred
step_callback(
img,
PipelineIntermediateState(
step=step_count,
order=0,
total_steps=len(timesteps),
timestep=math.floor(t_curr),
latents=img,
),
)
step_count += 1
step_callback()
return img