Fix total_steps in generation event, order field added

This commit is contained in:
Sergey Borisov 2023-08-09 03:34:25 +03:00
parent b4a74f6523
commit e98f7eda2e
3 changed files with 11 additions and 1 deletions

View File

@ -35,6 +35,7 @@ class EventServiceBase:
source_node_id: str,
progress_image: Optional[ProgressImage],
step: int,
order: int,
total_steps: int,
) -> None:
"""Emitted when there is generation progress"""
@ -46,6 +47,7 @@ class EventServiceBase:
source_node_id=source_node_id,
progress_image=progress_image.dict() if progress_image is not None else None,
step=step,
order=order,
total_steps=total_steps,
),
)

View File

@ -115,5 +115,6 @@ def stable_diffusion_step_callback(
source_node_id=source_node_id,
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
step=intermediate_state.step,
total_steps=node["steps"],
order=intermediate_state.order,
total_steps=intermediate_state.total_steps,
)

View File

@ -6,6 +6,7 @@ import inspect
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
from pydantic import Field
import math
import einops
import PIL.Image
import numpy as np
@ -42,6 +43,8 @@ from .diffusion import (
@dataclass
class PipelineIntermediateState:
step: int
order: int
total_steps: int
timestep: int
latents: torch.Tensor
predicted_original: Optional[torch.Tensor] = None
@ -484,6 +487,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
yield PipelineIntermediateState(
step=-1,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=self.scheduler.config.num_train_timesteps,
latents=latents,
)
@ -522,6 +527,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
yield PipelineIntermediateState(
step=i,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=int(t),
latents=latents,
predicted_original=predicted_original,