Add tqdm progress bar to FLUX denoising.

This commit is contained in:
Ryan Dick 2024-08-20 14:52:05 +00:00 committed by Brandon
parent 0c5e11f521
commit e49105ece5

View File

@ -4,6 +4,7 @@ from typing import Callable
import torch import torch
from einops import rearrange, repeat from einops import rearrange, repeat
from torch import Tensor from torch import Tensor
from tqdm import tqdm
from .model import Flux from .model import Flux
from .modules.conditioner import HFEncoder from .modules.conditioner import HFEncoder
@ -115,7 +116,7 @@ def denoise(
# this is ignored for schnell # this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:], strict=True): for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
pred = model( pred = model(
img=img, img=img,