mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix total_steps in generation event, order field added
This commit is contained in:
parent
b4a74f6523
commit
e98f7eda2e
@ -35,6 +35,7 @@ class EventServiceBase:
|
|||||||
source_node_id: str,
|
source_node_id: str,
|
||||||
progress_image: Optional[ProgressImage],
|
progress_image: Optional[ProgressImage],
|
||||||
step: int,
|
step: int,
|
||||||
|
order: int,
|
||||||
total_steps: int,
|
total_steps: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Emitted when there is generation progress"""
|
"""Emitted when there is generation progress"""
|
||||||
@ -46,6 +47,7 @@ class EventServiceBase:
|
|||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
progress_image=progress_image.dict() if progress_image is not None else None,
|
progress_image=progress_image.dict() if progress_image is not None else None,
|
||||||
step=step,
|
step=step,
|
||||||
|
order=order,
|
||||||
total_steps=total_steps,
|
total_steps=total_steps,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -115,5 +115,6 @@ def stable_diffusion_step_callback(
|
|||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
|
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
|
||||||
step=intermediate_state.step,
|
step=intermediate_state.step,
|
||||||
total_steps=node["steps"],
|
order=intermediate_state.order,
|
||||||
|
total_steps=intermediate_state.total_steps,
|
||||||
)
|
)
|
||||||
|
@ -6,6 +6,7 @@ import inspect
|
|||||||
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
|
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
|
import math
|
||||||
import einops
|
import einops
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -42,6 +43,8 @@ from .diffusion import (
|
|||||||
@dataclass
|
@dataclass
|
||||||
class PipelineIntermediateState:
|
class PipelineIntermediateState:
|
||||||
step: int
|
step: int
|
||||||
|
order: int
|
||||||
|
total_steps: int
|
||||||
timestep: int
|
timestep: int
|
||||||
latents: torch.Tensor
|
latents: torch.Tensor
|
||||||
predicted_original: Optional[torch.Tensor] = None
|
predicted_original: Optional[torch.Tensor] = None
|
||||||
@ -484,6 +487,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
|
|
||||||
yield PipelineIntermediateState(
|
yield PipelineIntermediateState(
|
||||||
step=-1,
|
step=-1,
|
||||||
|
order=self.scheduler.order,
|
||||||
|
total_steps=len(timesteps),
|
||||||
timestep=self.scheduler.config.num_train_timesteps,
|
timestep=self.scheduler.config.num_train_timesteps,
|
||||||
latents=latents,
|
latents=latents,
|
||||||
)
|
)
|
||||||
@ -522,6 +527,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
|
|
||||||
yield PipelineIntermediateState(
|
yield PipelineIntermediateState(
|
||||||
step=i,
|
step=i,
|
||||||
|
order=self.scheduler.order,
|
||||||
|
total_steps=len(timesteps),
|
||||||
timestep=int(t),
|
timestep=int(t),
|
||||||
latents=latents,
|
latents=latents,
|
||||||
predicted_original=predicted_original,
|
predicted_original=predicted_original,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user