Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
commit e9ec5ab85c
parent 17fed1c870

    Apply requested changes

    Co-Authored-By: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
@@ -157,7 +157,15 @@ class CompelInvocation(BaseInvocation):


 class SDXLPromptInvocationBase:
-    def run_clip_compel(self, context, clip_field, prompt, get_pooled, lora_prefix, zero_on_empty):
+    def run_clip_compel(
+        self,
+        context: InvocationContext,
+        clip_field: ClipField,
+        prompt: str,
+        get_pooled: bool,
+        lora_prefix: str,
+        zero_on_empty: bool,
+    ):
         tokenizer_info = context.services.model_manager.get_model(
             **clip_field.tokenizer.dict(),
             context=context,
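The expanded signature mainly buys keyword-style call sites and type checking. A minimal runnable sketch of that calling convention, with stub types standing in for InvocationContext and ClipField (the stubs, the stub body, and the example argument values are assumptions, not part of this commit):

```python
from dataclasses import dataclass


@dataclass
class InvocationContext:  # stub for the real services container
    graph_execution_state_id: str = "demo"


@dataclass
class ClipField:  # stub; the real field carries tokenizer/text_encoder refs
    name: str = "clip"


def run_clip_compel(
    context: InvocationContext,
    clip_field: ClipField,
    prompt: str,
    get_pooled: bool,
    lora_prefix: str,
    zero_on_empty: bool,
) -> tuple:
    # Stub body: the real method tokenizes `prompt`, patches in LoRA layers
    # matching `lora_prefix`, and optionally returns pooled embeddings.
    return (prompt, get_pooled, lora_prefix, zero_on_empty)


# Every flag is named at the call site, which is the point of the reflow:
run_clip_compel(
    context=InvocationContext(),
    clip_field=ClipField(),
    prompt="a photo of a cat",
    get_pooled=True,
    lora_prefix="lora_te1_",  # assumed prefix; illustrative only
    zero_on_empty=True,       # empty prompt -> zeroed embedding, per the flag name
)
```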
@@ -705,12 +705,16 @@ class MaskEdgeInvocation(BaseInvocation, PILInvocationConfig):


 class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
+    """
+    Shifts the colors of a target image to match the reference image, optionally
+    using a mask to only color-correct certain regions of the target image.
+    """

     type: Literal["color_correct"] = "color_correct"

-    init: Optional[ImageField] = Field(default=None, description="Initial image")
-    result: Optional[ImageField] = Field(default=None, description="Resulted image")
-    mask: Optional[ImageField] = Field(default=None, description="Mask image")
+    image: Optional[ImageField] = Field(default=None, description="The image to color-correct")
+    reference: Optional[ImageField] = Field(default=None, description="Reference image for color-correction")
+    mask: Optional[ImageField] = Field(default=None, description="Mask to use when applying color-correction")
     mask_blur_radius: float = Field(default=8, description="Mask blur radius")

     def invoke(self, context: InvocationContext) -> ImageOutput:
@@ -721,11 +725,11 @@ class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
         ).convert("L")

         init_image = context.services.images.get_pil_image(
-            self.init.image_name
+            self.reference.image_name
         )

         result = context.services.images.get_pil_image(
-            self.result.image_name
+            self.image.image_name
         ).convert("RGBA")

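With the rename, the inputs read as "color-correct `image` against `reference`, optionally restricted by `mask`". A minimal sketch of one plausible technique behind such a node, per-channel mean/std transfer with an optional mask blend; this is an assumption about the general approach, not a copy of the node's actual math, and the white-keeps-original mask convention is likewise assumed:

```python
from typing import Optional

import numpy as np
from PIL import Image


def color_correct(image: Image.Image, reference: Image.Image,
                  mask: Optional[Image.Image] = None) -> Image.Image:
    img = np.asarray(image.convert("RGB"), dtype=np.float32)
    ref = np.asarray(reference.convert("RGB"), dtype=np.float32)

    # Shift/scale each channel of `image` to match `reference` statistics.
    out = np.empty_like(img)
    for c in range(3):
        scale = ref[..., c].std() / max(float(img[..., c].std()), 1e-6)
        out[..., c] = (img[..., c] - img[..., c].mean()) * scale + ref[..., c].mean()
    out = np.clip(out, 0.0, 255.0)

    if mask is not None:
        # Assumed convention: white mask pixels keep the original colors.
        m = np.asarray(mask.convert("L").resize(image.size),
                       dtype=np.float32)[..., None] / 255.0
        out = out * (1.0 - m) + img * m

    return Image.fromarray(out.astype(np.uint8))
```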
@@ -336,7 +336,9 @@ class TextToLatentsInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         with SilenceWarnings():
             noise = context.services.latents.get(self.noise.latents_name)
-            seed = self.noise.seed or 0
+            seed = self.noise.seed
+            if seed is None:
+                seed = 0

             # Get the source node id (we are invoking the prepared node)
             graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
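With a fallback of 0 the old `or` idiom and the new explicit check happen to agree (a seed of 0 is falsy but also maps to 0), so the rewrite is about intent: only None should trigger the fallback. The difference shows up as soon as the fallback is anything else:

```python
seed = 0
print(seed or 123)                        # 123 - `or` treats a real seed of 0 as missing
print(seed if seed is not None else 123)  # 0   - only None triggers the fallback
```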
@@ -420,6 +422,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
     # Inputs
     noise: Optional[LatentsField] = Field(description="The noise to use (test override for future optional)")

+    # denoising_start = 1 - strength
     denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
     #denoising_end: float = Field(default=1.0, ge=0, le=1, description="")

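The new comment pins down how `denoising_start` relates to the old img2img `strength` knob: denoising_start = 1 - strength, i.e. the fraction of the noise schedule to skip before denoising begins. A one-liner making the mapping concrete (function name is illustrative):

```python
def denoising_start_from_strength(strength: float) -> float:
    # strength: fraction of the schedule used to re-denoise the input image.
    # denoising_start: fraction of the schedule skipped before denoising.
    return 1.0 - strength


assert denoising_start_from_strength(0.75) == 0.25  # strong img2img
assert denoising_start_from_strength(0.0) == 1.0    # no denoising at all
```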
@@ -462,16 +465,23 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         with SilenceWarnings():  # this quenches NSFW nag from diffusers
-            latent = context.services.latents.get(self.latents.latents_name)
-            seed = self.latents.seed or 0
-
+            seed = None
             noise = None
             if self.noise is not None:
                 noise = context.services.latents.get(self.noise.latents_name)
-                if self.noise.seed is not None:
-                    seed = self.noise.seed
+                seed = self.noise.seed

-            mask = self.prep_mask_tensor(self.mask, context, latent)
+            if self.latents is not None:
+                latents = context.services.latents.get(self.latents.latents_name)
+                if seed is None:
+                    seed = self.latents.seed
+            else:
+                latents = torch.zeros_like(noise)
+
+            if seed is None:
+                seed = 0
+
+            mask = self.prep_mask_tensor(self.mask, context, latents)

             # Get the source node id (we are invoking the prepared node)
             graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
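Alongside the rename from `latent` to `latents`, the rewritten block resolves the seed from up to three sources and lets the input latents default to `torch.zeros_like(noise)` when only noise is supplied. A condensed, runnable sketch of the same seed precedence, with plain Optionals standing in for the node fields (function name is illustrative):

```python
from typing import Optional


def resolve_seed(noise_seed: Optional[int], latents_seed: Optional[int]) -> int:
    # Same precedence as the diff: the noise's seed wins, then the input
    # latents' seed, then a fixed fallback of 0.
    if noise_seed is not None:
        return noise_seed
    if latents_seed is not None:
        return latents_seed
    return 0


assert resolve_seed(42, 1234) == 42      # noise seed takes priority
assert resolve_seed(None, 1234) == 1234  # fall back to the latents' seed
assert resolve_seed(None, None) == 0     # final fallback
```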
@@ -497,7 +507,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
                 unet_info.context.model, _lora_loader()
             ), unet_info as unet:
-                latent = latent.to(device=unet.device, dtype=unet.dtype)
+                latents = latents.to(device=unet.device, dtype=unet.dtype)
                 if noise is not None:
                     noise = noise.to(device=unet.device, dtype=unet.dtype)
                 if mask is not None:
@@ -516,7 +526,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
                 model=pipeline,
                 context=context,
                 control_input=self.control,
-                latents_shape=latent.shape,
+                latents_shape=latents.shape,
                 # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                 do_classifier_free_guidance=True,
                 exit_stack=exit_stack,
@@ -531,7 +541,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             )

             result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
-                latents=latent,
+                latents=latents,
                 timesteps=timesteps,
                 noise=noise,
                 seed=seed,
@@ -58,6 +58,8 @@ def stable_diffusion_step_callback(
     # TODO: only output a preview image when requested

     if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
+        # fast latents preview matrix for sdxl
+        # generated by @StAlKeR7779
         sdxl_latent_rgb_factors = torch.tensor(
             [
                 #   R       G       B
@@ -72,9 +74,6 @@ def stable_diffusion_step_callback(

         sdxl_smooth_matrix = torch.tensor(
             [
-                # [ 0.0478, 0.1285, 0.0478],
-                # [ 0.1285, 0.2948, 0.1285],
-                # [ 0.0478, 0.1285, 0.0478],
                 [0.0358, 0.0964, 0.0358],
                 [0.0964, 0.4711, 0.0964],
                 [0.0358, 0.0964, 0.0358],
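For context, these two tensors drive the fast SDXL step preview: the 4x3 factor matrix projects the 4 latent channels onto RGB, and the 3x3 matrix then smooths the result. A runnable sketch of that projection-then-blur recipe; the projection values below are placeholders (the real `sdxl_latent_rgb_factors` rows are cut off in the hunk above), and the function name is illustrative:

```python
import torch
import torch.nn.functional as F

# Placeholder projection: substitute the real sdxl_latent_rgb_factors rows
# from the hunk above. Shape is (4 latent channels, RGB).
sdxl_latent_rgb_factors = torch.tensor([
    [0.3, 0.3, 0.3],  # placeholder for channel 0
    [0.3, 0.3, 0.3],  # placeholder for channel 1
    [0.3, 0.3, 0.3],  # placeholder for channel 2
    [0.3, 0.3, 0.3],  # placeholder for channel 3
])

# Smoothing kernel, values taken from the diff.
sdxl_smooth_matrix = torch.tensor([
    [0.0358, 0.0964, 0.0358],
    [0.0964, 0.4711, 0.0964],
    [0.0358, 0.0964, 0.0358],
])


def sdxl_latents_preview(sample: torch.Tensor) -> torch.Tensor:
    """Project a (4, H, W) SDXL latent to a rough (3, H, W) RGB preview."""
    # (H, W, 4) @ (4, 3) -> (H, W, 3), then back to channels-first.
    rgb = (sample.permute(1, 2, 0) @ sdxl_latent_rgb_factors).permute(2, 0, 1)

    # Depthwise 3x3 blur to soften latent-grid artifacts.
    kernel = sdxl_smooth_matrix.reshape(1, 1, 3, 3).repeat(3, 1, 1, 1)
    rgb = F.conv2d(rgb.unsqueeze(0), kernel, padding=1, groups=3).squeeze(0)

    # Rescale to [0, 1] for display.
    return (rgb - rgb.min()) / (rgb.max() - rgb.min()).clamp(min=1e-6)


preview = sdxl_latents_preview(torch.randn(4, 32, 32))
assert preview.shape == (3, 32, 32)
```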