Fix total_steps in generation event, order field added

This commit is contained in:
Sergey Borisov 2023-08-09 03:34:25 +03:00
parent b4a74f6523
commit e98f7eda2e
3 changed files with 11 additions and 1 deletions

View File

@ -35,6 +35,7 @@ class EventServiceBase:
source_node_id: str, source_node_id: str,
progress_image: Optional[ProgressImage], progress_image: Optional[ProgressImage],
step: int, step: int,
order: int,
total_steps: int, total_steps: int,
) -> None: ) -> None:
"""Emitted when there is generation progress""" """Emitted when there is generation progress"""
@ -46,6 +47,7 @@ class EventServiceBase:
source_node_id=source_node_id, source_node_id=source_node_id,
progress_image=progress_image.dict() if progress_image is not None else None, progress_image=progress_image.dict() if progress_image is not None else None,
step=step, step=step,
order=order,
total_steps=total_steps, total_steps=total_steps,
), ),
) )

View File

@ -115,5 +115,6 @@ def stable_diffusion_step_callback(
source_node_id=source_node_id, source_node_id=source_node_id,
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL), progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
step=intermediate_state.step, step=intermediate_state.step,
total_steps=node["steps"], order=intermediate_state.order,
total_steps=intermediate_state.total_steps,
) )

View File

@ -6,6 +6,7 @@ import inspect
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
from pydantic import Field from pydantic import Field
import math
import einops import einops
import PIL.Image import PIL.Image
import numpy as np import numpy as np
@ -42,6 +43,8 @@ from .diffusion import (
@dataclass @dataclass
class PipelineIntermediateState: class PipelineIntermediateState:
step: int step: int
order: int
total_steps: int
timestep: int timestep: int
latents: torch.Tensor latents: torch.Tensor
predicted_original: Optional[torch.Tensor] = None predicted_original: Optional[torch.Tensor] = None
@ -484,6 +487,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
yield PipelineIntermediateState( yield PipelineIntermediateState(
step=-1, step=-1,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=self.scheduler.config.num_train_timesteps, timestep=self.scheduler.config.num_train_timesteps,
latents=latents, latents=latents,
) )
@ -522,6 +527,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
yield PipelineIntermediateState( yield PipelineIntermediateState(
step=i, step=i,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=int(t), timestep=int(t),
latents=latents, latents=latents,
predicted_original=predicted_original, predicted_original=predicted_original,