fix(nodes): fix step_callback + fast latents generation
This depends on the small change in #2957.
Parent: 5347c12fed
Commit: 67f8f222d9
@@ -4,6 +4,8 @@ from datetime import datetime, timezone
 from typing import Any, Literal, Optional, Union
 
 import numpy as np
+
+from torch import Tensor
 from PIL import Image
 from pydantic import Field
 from skimage.exposure.histogram_matching import match_histograms
@@ -12,7 +14,9 @@ from ..services.image_storage import ImageType
 from ..services.invocation_services import InvocationServices
 from .baseinvocation import BaseInvocation, InvocationContext
 from .image import ImageField, ImageOutput
-from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
+from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
+from ...backend.stable_diffusion import PipelineIntermediateState
+from ...backend.util.util import image_to_dataURL
 
 SAMPLER_NAME_VALUES = Literal[
     tuple(InvokeAIGenerator.schedulers())
@@ -41,18 +45,32 @@ class TextToImageInvocation(BaseInvocation):
 
     # TODO: pass this an emitter method or something? or a session for dispatching?
     def dispatch_progress(
-        self, context: InvocationContext, sample: Any = None, step: int = 0
-    ) -> None:
+        self, context: InvocationContext, sample: Tensor, step: int
+    ) -> None:
+        # TODO: only output a preview image when requested
+        image = Generator.sample_to_lowres_estimated_image(sample)
+
+        (width, height) = image.size
+        width *= 8
+        height *= 8
+
+        dataURL = image_to_dataURL(image, image_format="JPEG")
         context.services.events.emit_generator_progress(
             context.graph_execution_state_id,
             self.id,
+            {
+                "width": width,
+                "height": height,
+                "dataURL": dataURL
+            },
             step,
             float(step) / float(self.steps),
+            self.steps,
         )
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        def step_callback(sample, step=0):
-            self.dispatch_progress(context, sample, step)
+        def step_callback(state: PipelineIntermediateState):
+            self.dispatch_progress(context, state.latents, state.step)
 
         # Handle invalid model parameter
         # TODO: figure out if this can be done via a validator that uses the model_cache
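The preview in dispatch_progress comes from Generator.sample_to_lowres_estimated_image, which is what makes the "fast latents" part of this change cheap: it skips the VAE decoder entirely. Below is a rough standalone sketch of that idea, not InvokeAI's implementation; the function name latents_to_preview and the factor values are illustrative approximations for SD 1.x latents.

# Rough sketch of the "fast latents" preview idea: project the 4-channel
# Stable Diffusion latent tensor to RGB with a small fixed matrix instead of
# running the VAE decoder. Factor values are commonly cited approximations
# for SD 1.x latents and are illustrative only.
import torch
from PIL import Image

# one row of (R, G, B) factors per latent channel (illustrative values)
LATENT_RGB_FACTORS = torch.tensor([
    [ 0.298,  0.207,  0.208],
    [ 0.187,  0.286,  0.173],
    [-0.158,  0.189,  0.264],
    [-0.184, -0.271, -0.473],
])


def latents_to_preview(latents: torch.Tensor) -> Image.Image:
    """Approximate an RGB preview from latents shaped (1, 4, H/8, W/8)."""
    # per-pixel matrix multiply: (4, h, w) x (4, 3) -> (h, w, 3)
    rgb = torch.einsum("chw,cr->hwr", latents[0].cpu().float(), LATENT_RGB_FACTORS)
    # normalize to 0..255 for display
    rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-8) * 255.0
    return Image.fromarray(rgb.byte().numpy())

The resulting image is latent-sized, which is why dispatch_progress multiplies width and height by 8 to report the full output resolution. The remaining hunks below are in the events service, where emit_generator_progress gains the progress_image parameter.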
@@ -1,7 +1,10 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 
-from typing import Any, Dict
+from typing import Any, Dict, TypedDict
+
+ProgressImage = TypedDict(
+    "ProgressImage", {"dataURL": str, "width": int, "height": int}
+)
 
 class EventServiceBase:
     session_event: str = "session_event"
@@ -23,8 +26,9 @@ class EventServiceBase:
         self,
         graph_execution_state_id: str,
         invocation_id: str,
+        progress_image: ProgressImage | None,
         step: int,
         percent: float,
         total_steps: int,
     ) -> None:
         """Emitted when there is generation progress"""
         self.__emit_session_event(
@@ -32,8 +36,9 @@ class EventServiceBase:
             payload=dict(
                 graph_execution_state_id=graph_execution_state_id,
                 invocation_id=invocation_id,
+                progress_image=progress_image,
                 step=step,
                 percent=percent,
                 total_steps=total_steps,
             ),
         )
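For context, here is a minimal sketch of how a client might consume the generator_progress event defined above. The handler name and the event plumbing around it are assumptions for illustration; only the payload shape comes from this diff.

# Illustrative consumer of the "generator_progress" payload. Only the payload
# shape (progress_image, step, percent, total_steps) is taken from the diff;
# the handler itself is hypothetical.
import base64
import io

from PIL import Image


def handle_generator_progress(payload: dict) -> None:
    progress_image = payload.get("progress_image")
    if progress_image is not None:
        # dataURL is expected to look like "data:image/jpeg;base64,<bytes>"
        _, encoded = progress_image["dataURL"].split(",", 1)
        preview = Image.open(io.BytesIO(base64.b64decode(encoded)))
        # width/height report the full output resolution; the preview itself is
        # the 1/8-scale latent estimate, so upscale it for display if desired
        preview = preview.resize((progress_image["width"], progress_image["height"]))
    print(f'step {payload["step"]}/{payload["total_steps"]} ({payload["percent"]:.0%})')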