mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into main-text-fixup-PR
This commit is contained in:
commit
7d7a28beb3
@ -4,6 +4,8 @@ from datetime import datetime, timezone
|
|||||||
from typing import Any, Literal, Optional, Union
|
from typing import Any, Literal, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from skimage.exposure.histogram_matching import match_histograms
|
from skimage.exposure.histogram_matching import match_histograms
|
||||||
@ -12,7 +14,9 @@ from ..services.image_storage import ImageType
|
|||||||
from ..services.invocation_services import InvocationServices
|
from ..services.invocation_services import InvocationServices
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
from .image import ImageField, ImageOutput
|
from .image import ImageField, ImageOutput
|
||||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
|
||||||
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
|
from ...backend.util.util import image_to_dataURL
|
||||||
|
|
||||||
SAMPLER_NAME_VALUES = Literal[
|
SAMPLER_NAME_VALUES = Literal[
|
||||||
tuple(InvokeAIGenerator.schedulers())
|
tuple(InvokeAIGenerator.schedulers())
|
||||||
@ -41,18 +45,32 @@ class TextToImageInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||||
def dispatch_progress(
|
def dispatch_progress(
|
||||||
self, context: InvocationContext, sample: Any = None, step: int = 0
|
self, context: InvocationContext, sample: Tensor, step: int
|
||||||
) -> None:
|
) -> None:
|
||||||
|
# TODO: only output a preview image when requested
|
||||||
|
image = Generator.sample_to_lowres_estimated_image(sample)
|
||||||
|
|
||||||
|
(width, height) = image.size
|
||||||
|
width *= 8
|
||||||
|
height *= 8
|
||||||
|
|
||||||
|
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||||
|
|
||||||
context.services.events.emit_generator_progress(
|
context.services.events.emit_generator_progress(
|
||||||
context.graph_execution_state_id,
|
context.graph_execution_state_id,
|
||||||
self.id,
|
self.id,
|
||||||
|
{
|
||||||
|
"width": width,
|
||||||
|
"height": height,
|
||||||
|
"dataURL": dataURL
|
||||||
|
},
|
||||||
step,
|
step,
|
||||||
float(step) / float(self.steps),
|
self.steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
def step_callback(sample, step=0):
|
def step_callback(state: PipelineIntermediateState):
|
||||||
self.dispatch_progress(context, sample, step)
|
self.dispatch_progress(context, state.latents, state.step)
|
||||||
|
|
||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, TypedDict
|
||||||
|
|
||||||
|
ProgressImage = TypedDict(
|
||||||
|
"ProgressImage", {"dataURL": str, "width": int, "height": int}
|
||||||
|
)
|
||||||
|
|
||||||
class EventServiceBase:
|
class EventServiceBase:
|
||||||
session_event: str = "session_event"
|
session_event: str = "session_event"
|
||||||
@ -23,8 +26,9 @@ class EventServiceBase:
|
|||||||
self,
|
self,
|
||||||
graph_execution_state_id: str,
|
graph_execution_state_id: str,
|
||||||
invocation_id: str,
|
invocation_id: str,
|
||||||
|
progress_image: ProgressImage | None,
|
||||||
step: int,
|
step: int,
|
||||||
percent: float,
|
total_steps: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Emitted when there is generation progress"""
|
"""Emitted when there is generation progress"""
|
||||||
self.__emit_session_event(
|
self.__emit_session_event(
|
||||||
@ -32,8 +36,9 @@ class EventServiceBase:
|
|||||||
payload=dict(
|
payload=dict(
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
invocation_id=invocation_id,
|
invocation_id=invocation_id,
|
||||||
|
progress_image=progress_image,
|
||||||
step=step,
|
step=step,
|
||||||
percent=percent,
|
total_steps=total_steps,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -773,6 +773,24 @@ class GraphExecutionState(BaseModel):
|
|||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Declare all fields as required; necessary for OpenAPI schema generation build.
|
||||||
|
# Technically only fields without a `default_factory` need to be listed here.
|
||||||
|
# See: https://github.com/pydantic/pydantic/discussions/4577
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
'required': [
|
||||||
|
'id',
|
||||||
|
'graph',
|
||||||
|
'execution_graph',
|
||||||
|
'executed',
|
||||||
|
'executed_history',
|
||||||
|
'results',
|
||||||
|
'errors',
|
||||||
|
'prepared_source_mapping',
|
||||||
|
'source_prepared_mapping',
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
def next(self) -> BaseInvocation | None:
|
def next(self) -> BaseInvocation | None:
|
||||||
"""Gets the next node ready to execute."""
|
"""Gets the next node ready to execute."""
|
||||||
|
|
||||||
|
@ -497,7 +497,8 @@ class Generator:
|
|||||||
matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
|
matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
|
||||||
return matched_result
|
return matched_result
|
||||||
|
|
||||||
def sample_to_lowres_estimated_image(self, samples):
|
@staticmethod
|
||||||
|
def sample_to_lowres_estimated_image(samples):
|
||||||
# origingally adapted from code by @erucipe and @keturn here:
|
# origingally adapted from code by @erucipe and @keturn here:
|
||||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||||
|
|
||||||
|
@ -3,6 +3,9 @@ import math
|
|||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import io
|
||||||
|
import base64
|
||||||
|
|
||||||
from collections import abc
|
from collections import abc
|
||||||
from inspect import isfunction
|
from inspect import isfunction
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -364,3 +367,16 @@ def url_attachment_name(url: str) -> dict:
|
|||||||
def download_with_progress_bar(url: str, dest: Path) -> bool:
|
def download_with_progress_bar(url: str, dest: Path) -> bool:
|
||||||
result = download_with_resume(url, dest, access_token=None)
|
result = download_with_resume(url, dest, access_token=None)
|
||||||
return result is not None
|
return result is not None
|
||||||
|
|
||||||
|
|
||||||
|
def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str:
|
||||||
|
"""
|
||||||
|
Converts an image into a base64 image dataURL.
|
||||||
|
"""
|
||||||
|
buffered = io.BytesIO()
|
||||||
|
image.save(buffered, format=image_format)
|
||||||
|
mime_type = Image.MIME.get(image_format.upper(), "image/" + image_format.lower())
|
||||||
|
image_base64 = f"data:{mime_type};base64," + base64.b64encode(
|
||||||
|
buffered.getvalue()
|
||||||
|
).decode("UTF-8")
|
||||||
|
return image_base64
|
||||||
|
@ -38,7 +38,7 @@ dependencies = [
|
|||||||
"albumentations",
|
"albumentations",
|
||||||
"click",
|
"click",
|
||||||
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||||
"compel==0.1.10",
|
"compel==1.0.1",
|
||||||
"datasets",
|
"datasets",
|
||||||
"diffusers[torch]~=0.14",
|
"diffusers[torch]~=0.14",
|
||||||
"dnspython==2.2.1",
|
"dnspython==2.2.1",
|
||||||
|
Loading…
Reference in New Issue
Block a user