Apply requested changes

Co-Authored-By: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Sergey Borisov 2023-08-10 06:19:22 +03:00
parent 17fed1c870
commit e9ec5ab85c
4 changed files with 40 additions and 19 deletions

invokeai/app/invocations/compel.py

@@ -157,7 +157,15 @@ class CompelInvocation(BaseInvocation):
 class SDXLPromptInvocationBase:
-    def run_clip_compel(self, context, clip_field, prompt, get_pooled, lora_prefix, zero_on_empty):
+    def run_clip_compel(
+        self,
+        context: InvocationContext,
+        clip_field: ClipField,
+        prompt: str,
+        get_pooled: bool,
+        lora_prefix: str,
+        zero_on_empty: bool,
+    ):
         tokenizer_info = context.services.model_manager.get_model(
             **clip_field.tokenizer.dict(),
             context=context,
View File

@@ -705,12 +705,16 @@ class MaskEdgeInvocation(BaseInvocation, PILInvocationConfig):
 class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
+    """
+    Shifts the colors of a target image to match the reference image, optionally
+    using a mask to only color-correct certain regions of the target image.
+    """
     type: Literal["color_correct"] = "color_correct"
-    init: Optional[ImageField] = Field(default=None, description="Initial image")
-    result: Optional[ImageField] = Field(default=None, description="Resulted image")
-    mask: Optional[ImageField] = Field(default=None, description="Mask image")
+    image: Optional[ImageField] = Field(default=None, description="The image to color-correct")
+    reference: Optional[ImageField] = Field(default=None, description="Reference image for color-correction")
+    mask: Optional[ImageField] = Field(default=None, description="Mask to use when applying color-correction")
     mask_blur_radius: float = Field(default=8, description="Mask blur radius")

     def invoke(self, context: InvocationContext) -> ImageOutput:
@@ -721,11 +725,11 @@ class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
         ).convert("L")

         init_image = context.services.images.get_pil_image(
-            self.init.image_name
+            self.reference.image_name
         )

         result = context.services.images.get_pil_image(
-            self.result.image_name
+            self.image.image_name
         ).convert("RGBA")
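
Note: the renames make the node self-describing: `image` is what gets corrected and `reference` is what its colors should match, where `init`/`result` implied nothing. A minimal sketch of the node with the new field names (the image names are placeholders):

# hypothetical usage of the renamed inputs
node = ColorCorrectInvocation(
    id="color_correct_1",
    image=ImageField(image_name="inpainted.png"),    # target to correct
    reference=ImageField(image_name="original.png"), # colors to match
    mask=ImageField(image_name="mask.png"),          # optional region limit
    mask_blur_radius=8,
)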

invokeai/app/invocations/latents.py

@@ -336,7 +336,9 @@ class TextToLatentsInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         with SilenceWarnings():
             noise = context.services.latents.get(self.noise.latents_name)
-            seed = self.noise.seed or 0
+            seed = self.noise.seed
+            if seed is None:
+                seed = 0

             # Get the source node id (we are invoking the prepared node)
             graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
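
Note: with a fallback of 0 the two spellings happen to compute the same value, but `seed or 0` would also swallow a legitimate seed of 0 if the default were ever nonzero; the explicit None check states the intent. A small demonstration, using an assumed nonzero default to expose the difference:

# "or" cannot tell a valid seed of 0 from a missing seed
def fallback_or(seed, default=42):
    return seed or default

def fallback_none(seed, default=42):
    return default if seed is None else seed

assert fallback_or(0) == 42       # valid seed 0 silently replaced
assert fallback_none(0) == 0      # valid seed 0 kept
assert fallback_none(None) == 42  # only a missing seed falls back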
@@ -420,6 +422,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
     # Inputs
+    noise: Optional[LatentsField] = Field(description="The noise to use (test override for future optional)")
     # denoising_start = 1 - strength
     denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
     #denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
@@ -462,16 +465,23 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         with SilenceWarnings():  # this quenches NSFW nag from diffusers
-            latent = context.services.latents.get(self.latents.latents_name)
-            seed = self.latents.seed or 0
+            seed = None
             noise = None
             if self.noise is not None:
                 noise = context.services.latents.get(self.noise.latents_name)
-                if self.noise.seed is not None:
-                    seed = self.noise.seed
+                seed = self.noise.seed

-            mask = self.prep_mask_tensor(self.mask, context, latent)
+            if self.latents is not None:
+                latents = context.services.latents.get(self.latents.latents_name)
+                if seed is None:
+                    seed = self.latents.seed
+            else:
+                latents = torch.zeros_like(noise)
+
+            if seed is None:
+                seed = 0
+
+            mask = self.prep_mask_tensor(self.mask, context, latents)

             # Get the source node id (we are invoking the prepared node)
             graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
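
Note: the rewritten block resolves the seed with a clear priority (noise.seed, then latents.seed, then 0) and synthesizes zero latents when only noise is supplied. The same logic distilled into a standalone sketch (the function name and signature are illustrative, not part of the codebase):

from typing import Optional
import torch

def resolve_seed_and_latents(
    noise: Optional[torch.Tensor],
    noise_seed: Optional[int],
    latents: Optional[torch.Tensor],
    latents_seed: Optional[int],
):
    seed = noise_seed                      # highest priority: the noise's seed
    if latents is None:
        latents = torch.zeros_like(noise)  # no base image: start from zeros
    elif seed is None:
        seed = latents_seed                # fall back to the base latents' seed
    return (0 if seed is None else seed), latents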
@@ -497,7 +507,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
                 unet_info.context.model, _lora_loader()
             ), unet_info as unet:
-                latent = latent.to(device=unet.device, dtype=unet.dtype)
+                latents = latents.to(device=unet.device, dtype=unet.dtype)
                 if noise is not None:
                     noise = noise.to(device=unet.device, dtype=unet.dtype)
                 if mask is not None:
@@ -516,7 +526,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
                     model=pipeline,
                     context=context,
                     control_input=self.control,
-                    latents_shape=latent.shape,
+                    latents_shape=latents.shape,
                     # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                     do_classifier_free_guidance=True,
                     exit_stack=exit_stack,
@@ -531,7 +541,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             )

             result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
-                latents=latent,
+                latents=latents,
                 timesteps=timesteps,
                 noise=noise,
                 seed=seed,

invokeai/app/util/step_callback.py

@@ -58,6 +58,8 @@ def stable_diffusion_step_callback(
     # TODO: only output a preview image when requested
     if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
+        # fast latents preview matrix for sdxl
+        # generated by @StAlKeR7779
         sdxl_latent_rgb_factors = torch.tensor(
             [
                 # R       G       B
@@ -72,9 +74,6 @@ def stable_diffusion_step_callback(
         sdxl_smooth_matrix = torch.tensor(
             [
-                # [ 0.0478, 0.1285, 0.0478],
-                # [ 0.1285, 0.2948, 0.1285],
-                # [ 0.0478, 0.1285, 0.0478],
                 [0.0358, 0.0964, 0.0358],
                 [0.0964, 0.4711, 0.0964],
                 [0.0358, 0.0964, 0.0358],
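
Note: the two tensors turn a 4-channel SDXL latent into an RGB preview: the factor matrix projects the latent channels to RGB, and the 3x3 kernel (its weights sum to roughly 1) smooths each channel. A sketch of how they could be applied, assuming a [4, H, W] latent sample and a [4, 3] factor matrix; the function name and plumbing are illustrative:

import torch
import torch.nn.functional as F

def sdxl_latents_to_preview(latents, rgb_factors, smooth_matrix):
    # [H, W, 4] @ [4, 3] -> [H, W, 3]: project latent channels to RGB
    rgb = latents.permute(1, 2, 0) @ rgb_factors
    # smooth each RGB channel with the same 3x3 kernel
    kernel = smooth_matrix.reshape(1, 1, 3, 3)
    rgb = rgb.permute(2, 0, 1).unsqueeze(1)  # [3, 1, H, W]
    rgb = F.conv2d(rgb, kernel, padding=1)   # blur, preserving spatial size
    return rgb.squeeze(1).permute(1, 2, 0)   # back to [H, W, 3]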