mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Apply requested changes
Co-Authored-By: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
parent
17fed1c870
commit
e9ec5ab85c
@ -157,7 +157,15 @@ class CompelInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
class SDXLPromptInvocationBase:
|
class SDXLPromptInvocationBase:
|
||||||
def run_clip_compel(self, context, clip_field, prompt, get_pooled, lora_prefix, zero_on_empty):
|
def run_clip_compel(
|
||||||
|
self,
|
||||||
|
context: InvocationContext,
|
||||||
|
clip_field: ClipField,
|
||||||
|
prompt: str,
|
||||||
|
get_pooled: bool,
|
||||||
|
lora_prefix: str,
|
||||||
|
zero_on_empty: bool,
|
||||||
|
):
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**clip_field.tokenizer.dict(),
|
**clip_field.tokenizer.dict(),
|
||||||
context=context,
|
context=context,
|
||||||
|
@ -705,12 +705,16 @@ class MaskEdgeInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
|
|
||||||
class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
|
class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
|
"""
|
||||||
|
Shifts the colors of a target image to match the reference image, optionally
|
||||||
|
using a mask to only color-correct certain regions of the target image.
|
||||||
|
"""
|
||||||
|
|
||||||
type: Literal["color_correct"] = "color_correct"
|
type: Literal["color_correct"] = "color_correct"
|
||||||
|
|
||||||
init: Optional[ImageField] = Field(default=None, description="Initial image")
|
image: Optional[ImageField] = Field(default=None, description="The image to color-correct")
|
||||||
result: Optional[ImageField] = Field(default=None, description="Resulted image")
|
reference: Optional[ImageField] = Field(default=None, description="Reference image for color-correction")
|
||||||
mask: Optional[ImageField] = Field(default=None, description="Mask image")
|
mask: Optional[ImageField] = Field(default=None, description="Mask to use when applying color-correction")
|
||||||
mask_blur_radius: float = Field(default=8, description="Mask blur radius")
|
mask_blur_radius: float = Field(default=8, description="Mask blur radius")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
@ -721,11 +725,11 @@ class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
).convert("L")
|
).convert("L")
|
||||||
|
|
||||||
init_image = context.services.images.get_pil_image(
|
init_image = context.services.images.get_pil_image(
|
||||||
self.init.image_name
|
self.reference.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
result = context.services.images.get_pil_image(
|
result = context.services.images.get_pil_image(
|
||||||
self.result.image_name
|
self.image.image_name
|
||||||
).convert("RGBA")
|
).convert("RGBA")
|
||||||
|
|
||||||
|
|
||||||
|
@ -336,7 +336,9 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
with SilenceWarnings():
|
with SilenceWarnings():
|
||||||
noise = context.services.latents.get(self.noise.latents_name)
|
noise = context.services.latents.get(self.noise.latents_name)
|
||||||
seed = self.noise.seed or 0
|
seed = self.noise.seed
|
||||||
|
if seed is None:
|
||||||
|
seed = 0
|
||||||
|
|
||||||
# Get the source node id (we are invoking the prepared node)
|
# Get the source node id (we are invoking the prepared node)
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||||
@ -420,6 +422,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
# Inputs
|
# Inputs
|
||||||
noise: Optional[LatentsField] = Field(description="The noise to use (test override for future optional)")
|
noise: Optional[LatentsField] = Field(description="The noise to use (test override for future optional)")
|
||||||
|
|
||||||
|
# denoising_start = 1 - strength
|
||||||
denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
|
denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
|
||||||
#denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
|
#denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
|
||||||
|
|
||||||
@ -462,16 +465,23 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
with SilenceWarnings(): # this quenches NSFW nag from diffusers
|
with SilenceWarnings(): # this quenches NSFW nag from diffusers
|
||||||
latent = context.services.latents.get(self.latents.latents_name)
|
seed = None
|
||||||
seed = self.latents.seed or 0
|
|
||||||
|
|
||||||
noise = None
|
noise = None
|
||||||
if self.noise is not None:
|
if self.noise is not None:
|
||||||
noise = context.services.latents.get(self.noise.latents_name)
|
noise = context.services.latents.get(self.noise.latents_name)
|
||||||
if self.noise.seed is not None:
|
seed = self.noise.seed
|
||||||
seed = self.noise.seed
|
|
||||||
|
|
||||||
mask = self.prep_mask_tensor(self.mask, context, latent)
|
if self.latents is not None:
|
||||||
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
|
if seed is None:
|
||||||
|
seed = self.latents.seed
|
||||||
|
else:
|
||||||
|
latents = torch.zeros_like(noise)
|
||||||
|
|
||||||
|
if seed is None:
|
||||||
|
seed = 0
|
||||||
|
|
||||||
|
mask = self.prep_mask_tensor(self.mask, context, latents)
|
||||||
|
|
||||||
# Get the source node id (we are invoking the prepared node)
|
# Get the source node id (we are invoking the prepared node)
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||||
@ -497,7 +507,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
|
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
|
||||||
unet_info.context.model, _lora_loader()
|
unet_info.context.model, _lora_loader()
|
||||||
), unet_info as unet:
|
), unet_info as unet:
|
||||||
latent = latent.to(device=unet.device, dtype=unet.dtype)
|
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||||
if noise is not None:
|
if noise is not None:
|
||||||
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
@ -516,7 +526,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
model=pipeline,
|
model=pipeline,
|
||||||
context=context,
|
context=context,
|
||||||
control_input=self.control,
|
control_input=self.control,
|
||||||
latents_shape=latent.shape,
|
latents_shape=latents.shape,
|
||||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||||
do_classifier_free_guidance=True,
|
do_classifier_free_guidance=True,
|
||||||
exit_stack=exit_stack,
|
exit_stack=exit_stack,
|
||||||
@ -531,7 +541,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||||
latents=latent,
|
latents=latents,
|
||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
noise=noise,
|
noise=noise,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
|
@ -58,6 +58,8 @@ def stable_diffusion_step_callback(
|
|||||||
# TODO: only output a preview image when requested
|
# TODO: only output a preview image when requested
|
||||||
|
|
||||||
if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
|
if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
|
||||||
|
# fast latents preview matrix for sdxl
|
||||||
|
# generated by @StAlKeR7779
|
||||||
sdxl_latent_rgb_factors = torch.tensor(
|
sdxl_latent_rgb_factors = torch.tensor(
|
||||||
[
|
[
|
||||||
# R G B
|
# R G B
|
||||||
@ -72,9 +74,6 @@ def stable_diffusion_step_callback(
|
|||||||
|
|
||||||
sdxl_smooth_matrix = torch.tensor(
|
sdxl_smooth_matrix = torch.tensor(
|
||||||
[
|
[
|
||||||
# [ 0.0478, 0.1285, 0.0478],
|
|
||||||
# [ 0.1285, 0.2948, 0.1285],
|
|
||||||
# [ 0.0478, 0.1285, 0.0478],
|
|
||||||
[0.0358, 0.0964, 0.0358],
|
[0.0358, 0.0964, 0.0358],
|
||||||
[0.0964, 0.4711, 0.0964],
|
[0.0964, 0.4711, 0.0964],
|
||||||
[0.0358, 0.0964, 0.0358],
|
[0.0358, 0.0964, 0.0358],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user