Merge branch 'main' into main-text-fixup-PR

Lincoln Stein 2023-03-18 09:54:41 -07:00 committed by GitHub
commit 7d7a28beb3
6 changed files with 69 additions and 11 deletions


@@ -4,6 +4,8 @@ from datetime import datetime, timezone
 from typing import Any, Literal, Optional, Union
 
 import numpy as np
+from torch import Tensor
 from PIL import Image
 from pydantic import Field
 from skimage.exposure.histogram_matching import match_histograms
@ -12,7 +14,9 @@ from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput from .image import ImageField, ImageOutput
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util.util import image_to_dataURL
SAMPLER_NAME_VALUES = Literal[ SAMPLER_NAME_VALUES = Literal[
tuple(InvokeAIGenerator.schedulers()) tuple(InvokeAIGenerator.schedulers())
@@ -41,18 +45,32 @@ class TextToImageInvocation(BaseInvocation):
     # TODO: pass this an emitter method or something? or a session for dispatching?
     def dispatch_progress(
-        self, context: InvocationContext, sample: Any = None, step: int = 0
+        self, context: InvocationContext, sample: Tensor, step: int
     ) -> None:
+        # TODO: only output a preview image when requested
+        image = Generator.sample_to_lowres_estimated_image(sample)
+
+        (width, height) = image.size
+        width *= 8
+        height *= 8
+
+        dataURL = image_to_dataURL(image, image_format="JPEG")
+
         context.services.events.emit_generator_progress(
             context.graph_execution_state_id,
             self.id,
+            {
+                "width": width,
+                "height": height,
+                "dataURL": dataURL
+            },
             step,
-            float(step) / float(self.steps),
+            self.steps,
         )
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        def step_callback(sample, step=0):
-            self.dispatch_progress(context, sample, step)
+        def step_callback(state: PipelineIntermediateState):
+            self.dispatch_progress(context, state.latents, state.step)
 
         # Handle invalid model parameter
         # TODO: figure out if this can be done via a validator that uses the model_cache
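
Note on the hunk above: the step callback now receives the pipeline's intermediate latents, decodes them into a rough RGB preview, and ships that preview to the frontend as a JPEG data URL. The width and height are multiplied by 8 because Stable Diffusion's VAE works on latents downsampled 8x relative to the output image. A consumer of the generator_progress event could turn the data URL back into an image roughly like this (a minimal sketch; the helper name and payload access are assumptions, not part of the commit):

import base64
import io

from PIL import Image


def preview_from_progress_payload(payload: dict) -> Image.Image:
    # payload["progress_image"]["dataURL"] is "data:image/jpeg;base64,<encoded bytes>"
    data_url = payload["progress_image"]["dataURL"]
    _header, encoded = data_url.split(",", 1)
    return Image.open(io.BytesIO(base64.b64decode(encoded)))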


@@ -1,7 +1,10 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 
-from typing import Any, Dict
+from typing import Any, Dict, TypedDict
+
+ProgressImage = TypedDict(
+    "ProgressImage", {"dataURL": str, "width": int, "height": int}
+)
 
 
 class EventServiceBase:
     session_event: str = "session_event"
@@ -23,8 +26,9 @@ class EventServiceBase:
         self,
         graph_execution_state_id: str,
         invocation_id: str,
+        progress_image: ProgressImage | None,
         step: int,
-        percent: float,
+        total_steps: int,
     ) -> None:
         """Emitted when there is generation progress"""
         self.__emit_session_event(
@@ -32,8 +36,9 @@ class EventServiceBase:
             payload=dict(
                 graph_execution_state_id=graph_execution_state_id,
                 invocation_id=invocation_id,
+                progress_image=progress_image,
                 step=step,
-                percent=percent,
+                total_steps=total_steps,
             ),
         )
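
With percent replaced by progress_image, step, and total_steps, the event no longer carries a precomputed fraction; a listener that still wants one can derive it from the payload. A minimal sketch (the function name is illustrative, not part of the commit):

def progress_fraction(payload: dict) -> float:
    # Equivalent to the old "percent" field: step / total_steps, guarded against zero.
    total = payload["total_steps"]
    return payload["step"] / total if total else 0.0

Sending the raw counters keeps formatting decisions (percent, "step N of M", progress bars) on the client side.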


@@ -773,6 +773,24 @@ class GraphExecutionState(BaseModel):
         default_factory=dict,
     )
 
+    # Declare all fields as required; necessary for OpenAPI schema generation build.
+    # Technically only fields without a `default_factory` need to be listed here.
+    # See: https://github.com/pydantic/pydantic/discussions/4577
+    class Config:
+        schema_extra = {
+            'required': [
+                'id',
+                'graph',
+                'execution_graph',
+                'executed',
+                'executed_history',
+                'results',
+                'errors',
+                'prepared_source_mapping',
+                'source_prepared_mapping',
+            ]
+        }
+
     def next(self) -> BaseInvocation | None:
         """Gets the next node ready to execute."""


@@ -497,7 +497,8 @@ class Generator:
         matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
         return matched_result
 
-    def sample_to_lowres_estimated_image(self, samples):
+    @staticmethod
+    def sample_to_lowres_estimated_image(samples):
         # originally adapted from code by @erucipe and @keturn here:
         # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
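
Making sample_to_lowres_estimated_image a @staticmethod lets the new dispatch_progress call it as Generator.sample_to_lowres_estimated_image(sample) without constructing a generator. The linked discussion approximates RGB directly from the four Stable Diffusion latent channels with a fixed linear projection instead of running the VAE decoder; a rough standalone sketch of that idea (coefficients and normalization here are illustrative, not the project's exact implementation):

import torch
from PIL import Image


def latents_to_preview(latents: torch.Tensor) -> Image.Image:
    # latents: [4, H/8, W/8] Stable Diffusion latents for one image
    latent_rgb_factors = torch.tensor(
        [
            [0.34, 0.14, 0.07],
            [0.12, 0.40, 0.15],
            [-0.32, 0.25, 0.21],
            [-0.13, -0.19, -0.74],
        ],
        dtype=latents.dtype,
        device=latents.device,
    )
    # Project the 4 latent channels onto RGB, then rescale roughly into [0, 255].
    rgb = torch.einsum("chw,cr->rhw", latents, latent_rgb_factors)
    rgb = ((rgb + 1.0) / 2.0).clamp(0.0, 1.0)
    array = (rgb.permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy()
    return Image.fromarray(array)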


@@ -3,6 +3,9 @@ import math
 import multiprocessing as mp
 import os
 import re
+import io
+import base64
 from collections import abc
 from inspect import isfunction
 from pathlib import Path
@@ -364,3 +367,16 @@ def url_attachment_name(url: str) -> dict:
 def download_with_progress_bar(url: str, dest: Path) -> bool:
     result = download_with_resume(url, dest, access_token=None)
     return result is not None
+
+
+def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str:
+    """
+    Converts an image into a base64 image dataURL.
+    """
+    buffered = io.BytesIO()
+    image.save(buffered, format=image_format)
+    mime_type = Image.MIME.get(image_format.upper(), "image/" + image_format.lower())
+    image_base64 = f"data:{mime_type};base64," + base64.b64encode(
+        buffered.getvalue()
+    ).decode("UTF-8")
+    return image_base64
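
image_to_dataURL produces a string of the form "data:<mime type>;base64,<encoded bytes>", which a browser (or the web UI consuming the progress events above) can use directly as an image source. A quick usage example, assuming the function defined above is in scope:

from PIL import Image

preview = Image.new("RGB", (64, 64), "gray")
data_url = image_to_dataURL(preview, image_format="JPEG")
print(data_url[:23])  # "data:image/jpeg;base64,"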


@@ -38,7 +38,7 @@ dependencies = [
     "albumentations",
     "click",
     "clip_anytorch",  # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
-    "compel==0.1.10",
+    "compel==1.0.1",
     "datasets",
     "diffusers[torch]~=0.14",
     "dnspython==2.2.1",