Fix total_steps in generation event, order field added

This commit is contained in:
Sergey Borisov
2023-08-09 03:34:25 +03:00
parent b4a74f6523
commit e98f7eda2e
3 changed files with 11 additions and 1 deletions

View File

@ -6,6 +6,7 @@ import inspect
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
from pydantic import Field
import math
import einops
import PIL.Image
import numpy as np
@ -42,6 +43,8 @@ from .diffusion import (
@dataclass
class PipelineIntermediateState:
step: int
order: int
total_steps: int
timestep: int
latents: torch.Tensor
predicted_original: Optional[torch.Tensor] = None
@ -484,6 +487,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
yield PipelineIntermediateState(
step=-1,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=self.scheduler.config.num_train_timesteps,
latents=latents,
)
@ -522,6 +527,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
yield PipelineIntermediateState(
step=i,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=int(t),
latents=latents,
predicted_original=predicted_original,