mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into feat/taesd
This commit is contained in:
commit
7df67d077a
@ -279,8 +279,8 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
crop_left: int = InputField(default=0, description="")
|
crop_left: int = InputField(default=0, description="")
|
||||||
target_width: int = InputField(default=1024, description="")
|
target_width: int = InputField(default=1024, description="")
|
||||||
target_height: int = InputField(default=1024, description="")
|
target_height: int = InputField(default=1024, description="")
|
||||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
|
||||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
|
@ -72,10 +72,10 @@ class CoreMetadata(BaseModelExcludeNull):
|
|||||||
)
|
)
|
||||||
refiner_steps: Optional[int] = Field(default=None, description="The number of steps used for the refiner")
|
refiner_steps: Optional[int] = Field(default=None, description="The number of steps used for the refiner")
|
||||||
refiner_scheduler: Optional[str] = Field(default=None, description="The scheduler used for the refiner")
|
refiner_scheduler: Optional[str] = Field(default=None, description="The scheduler used for the refiner")
|
||||||
refiner_positive_aesthetic_store: Optional[float] = Field(
|
refiner_positive_aesthetic_score: Optional[float] = Field(
|
||||||
default=None, description="The aesthetic score used for the refiner"
|
default=None, description="The aesthetic score used for the refiner"
|
||||||
)
|
)
|
||||||
refiner_negative_aesthetic_store: Optional[float] = Field(
|
refiner_negative_aesthetic_score: Optional[float] = Field(
|
||||||
default=None, description="The aesthetic score used for the refiner"
|
default=None, description="The aesthetic score used for the refiner"
|
||||||
)
|
)
|
||||||
refiner_start: Optional[float] = Field(default=None, description="The start value used for refiner denoising")
|
refiner_start: Optional[float] = Field(default=None, description="The start value used for refiner denoising")
|
||||||
@ -160,11 +160,11 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
|||||||
default=None,
|
default=None,
|
||||||
description="The scheduler used for the refiner",
|
description="The scheduler used for the refiner",
|
||||||
)
|
)
|
||||||
refiner_positive_aesthetic_store: Optional[float] = InputField(
|
refiner_positive_aesthetic_score: Optional[float] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description="The aesthetic score used for the refiner",
|
description="The aesthetic score used for the refiner",
|
||||||
)
|
)
|
||||||
refiner_negative_aesthetic_store: Optional[float] = InputField(
|
refiner_negative_aesthetic_score: Optional[float] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description="The aesthetic score used for the refiner",
|
description="The aesthetic score used for the refiner",
|
||||||
)
|
)
|
||||||
|
@ -249,14 +249,14 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
|||||||
"""Apply selected lora to unet and text_encoder."""
|
"""Apply selected lora to unet and text_encoder."""
|
||||||
|
|
||||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||||
weight: float = Field(default=0.75, description=FieldDescriptions.lora_weight)
|
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||||
unet: Optional[UNetField] = Field(
|
unet: Optional[UNetField] = InputField(
|
||||||
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNET"
|
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
|
||||||
)
|
)
|
||||||
clip: Optional[ClipField] = Field(
|
clip: Optional[ClipField] = InputField(
|
||||||
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1"
|
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1"
|
||||||
)
|
)
|
||||||
clip2: Optional[ClipField] = Field(
|
clip2: Optional[ClipField] = InputField(
|
||||||
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2"
|
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -49,6 +49,7 @@ class ModelProbe(object):
|
|||||||
"StableDiffusionInpaintPipeline": ModelType.Main,
|
"StableDiffusionInpaintPipeline": ModelType.Main,
|
||||||
"StableDiffusionXLPipeline": ModelType.Main,
|
"StableDiffusionXLPipeline": ModelType.Main,
|
||||||
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||||
|
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
||||||
"AutoencoderKL": ModelType.Vae,
|
"AutoencoderKL": ModelType.Vae,
|
||||||
"AutoencoderTiny": ModelType.Vae,
|
"AutoencoderTiny": ModelType.Vae,
|
||||||
"ControlNetModel": ModelType.ControlNet,
|
"ControlNetModel": ModelType.ControlNet,
|
||||||
|
@ -265,7 +265,7 @@ class InvokeAICrossAttentionMixin:
|
|||||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
||||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||||
else:
|
else:
|
||||||
slice_size = math.floor(2 ** 30 / (q.shape[0] * q.shape[1]))
|
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||||
|
|
||||||
def einsum_op_mps_v2(self, q, k, v):
|
def einsum_op_mps_v2(self, q, k, v):
|
||||||
|
@ -215,7 +215,10 @@ class InvokeAIDiffuserComponent:
|
|||||||
dim=0,
|
dim=0,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
(encoder_hidden_states, encoder_attention_mask,) = self._concat_conditionings_for_batch(
|
(
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
) = self._concat_conditionings_for_batch(
|
||||||
conditioning_data.unconditioned_embeddings.embeds,
|
conditioning_data.unconditioned_embeddings.embeds,
|
||||||
conditioning_data.text_embeddings.embeds,
|
conditioning_data.text_embeddings.embeds,
|
||||||
)
|
)
|
||||||
@ -277,7 +280,10 @@ class InvokeAIDiffuserComponent:
|
|||||||
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
|
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
|
||||||
|
|
||||||
if wants_cross_attention_control:
|
if wants_cross_attention_control:
|
||||||
(unconditioned_next_x, conditioned_next_x,) = self._apply_cross_attention_controlled_conditioning(
|
(
|
||||||
|
unconditioned_next_x,
|
||||||
|
conditioned_next_x,
|
||||||
|
) = self._apply_cross_attention_controlled_conditioning(
|
||||||
sample,
|
sample,
|
||||||
timestep,
|
timestep,
|
||||||
conditioning_data,
|
conditioning_data,
|
||||||
@ -285,7 +291,10 @@ class InvokeAIDiffuserComponent:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
elif self.sequential_guidance:
|
elif self.sequential_guidance:
|
||||||
(unconditioned_next_x, conditioned_next_x,) = self._apply_standard_conditioning_sequentially(
|
(
|
||||||
|
unconditioned_next_x,
|
||||||
|
conditioned_next_x,
|
||||||
|
) = self._apply_standard_conditioning_sequentially(
|
||||||
sample,
|
sample,
|
||||||
timestep,
|
timestep,
|
||||||
conditioning_data,
|
conditioning_data,
|
||||||
@ -293,7 +302,10 @@ class InvokeAIDiffuserComponent:
|
|||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
(unconditioned_next_x, conditioned_next_x,) = self._apply_standard_conditioning(
|
(
|
||||||
|
unconditioned_next_x,
|
||||||
|
conditioned_next_x,
|
||||||
|
) = self._apply_standard_conditioning(
|
||||||
sample,
|
sample,
|
||||||
timestep,
|
timestep,
|
||||||
conditioning_data,
|
conditioning_data,
|
||||||
|
@ -395,7 +395,7 @@ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
|||||||
D = np.diag(np.random.rand(3))
|
D = np.diag(np.random.rand(3))
|
||||||
U = orth(np.random.rand(3, 3))
|
U = orth(np.random.rand(3, 3))
|
||||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||||
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
|
||||||
img = np.clip(img, 0.0, 1.0)
|
img = np.clip(img, 0.0, 1.0)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
@ -413,7 +413,7 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
|
|||||||
D = np.diag(np.random.rand(3))
|
D = np.diag(np.random.rand(3))
|
||||||
U = orth(np.random.rand(3, 3))
|
U = orth(np.random.rand(3, 3))
|
||||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||||
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
|
||||||
img = np.clip(img, 0.0, 1.0)
|
img = np.clip(img, 0.0, 1.0)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
@ -399,7 +399,7 @@ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
|||||||
D = np.diag(np.random.rand(3))
|
D = np.diag(np.random.rand(3))
|
||||||
U = orth(np.random.rand(3, 3))
|
U = orth(np.random.rand(3, 3))
|
||||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||||
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
|
||||||
img = np.clip(img, 0.0, 1.0)
|
img = np.clip(img, 0.0, 1.0)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
@ -417,7 +417,7 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
|
|||||||
D = np.diag(np.random.rand(3))
|
D = np.diag(np.random.rand(3))
|
||||||
U = orth(np.random.rand(3, 3))
|
U = orth(np.random.rand(3, 3))
|
||||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||||
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
|
||||||
img = np.clip(img, 0.0, 1.0)
|
img = np.clip(img, 0.0, 1.0)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
@ -562,18 +562,14 @@ def rgb2ycbcr(img, only_y=True):
|
|||||||
if only_y:
|
if only_y:
|
||||||
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
|
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
|
||||||
else:
|
else:
|
||||||
rlt = (
|
rlt = np.matmul(
|
||||||
np.matmul(
|
img,
|
||||||
img,
|
[
|
||||||
[
|
[65.481, -37.797, 112.0],
|
||||||
[65.481, -37.797, 112.0],
|
[128.553, -74.203, -93.786],
|
||||||
[128.553, -74.203, -93.786],
|
[24.966, 112.0, -18.214],
|
||||||
[24.966, 112.0, -18.214],
|
],
|
||||||
],
|
) / 255.0 + [16, 128, 128]
|
||||||
)
|
|
||||||
/ 255.0
|
|
||||||
+ [16, 128, 128]
|
|
||||||
)
|
|
||||||
if in_img_type == np.uint8:
|
if in_img_type == np.uint8:
|
||||||
rlt = rlt.round()
|
rlt = rlt.round()
|
||||||
else:
|
else:
|
||||||
@ -592,18 +588,14 @@ def ycbcr2rgb(img):
|
|||||||
if in_img_type != np.uint8:
|
if in_img_type != np.uint8:
|
||||||
img *= 255.0
|
img *= 255.0
|
||||||
# convert
|
# convert
|
||||||
rlt = (
|
rlt = np.matmul(
|
||||||
np.matmul(
|
img,
|
||||||
img,
|
[
|
||||||
[
|
[0.00456621, 0.00456621, 0.00456621],
|
||||||
[0.00456621, 0.00456621, 0.00456621],
|
[0, -0.00153632, 0.00791071],
|
||||||
[0, -0.00153632, 0.00791071],
|
[0.00625893, -0.00318811, 0],
|
||||||
[0.00625893, -0.00318811, 0],
|
],
|
||||||
],
|
) * 255.0 + [-222.921, 135.576, -276.836]
|
||||||
)
|
|
||||||
* 255.0
|
|
||||||
+ [-222.921, 135.576, -276.836]
|
|
||||||
)
|
|
||||||
if in_img_type == np.uint8:
|
if in_img_type == np.uint8:
|
||||||
rlt = rlt.round()
|
rlt = rlt.round()
|
||||||
else:
|
else:
|
||||||
@ -626,18 +618,14 @@ def bgr2ycbcr(img, only_y=True):
|
|||||||
if only_y:
|
if only_y:
|
||||||
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
|
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
|
||||||
else:
|
else:
|
||||||
rlt = (
|
rlt = np.matmul(
|
||||||
np.matmul(
|
img,
|
||||||
img,
|
[
|
||||||
[
|
[24.966, 112.0, -18.214],
|
||||||
[24.966, 112.0, -18.214],
|
[128.553, -74.203, -93.786],
|
||||||
[128.553, -74.203, -93.786],
|
[65.481, -37.797, 112.0],
|
||||||
[65.481, -37.797, 112.0],
|
],
|
||||||
],
|
) / 255.0 + [16, 128, 128]
|
||||||
)
|
|
||||||
/ 255.0
|
|
||||||
+ [16, 128, 128]
|
|
||||||
)
|
|
||||||
if in_img_type == np.uint8:
|
if in_img_type == np.uint8:
|
||||||
rlt = rlt.round()
|
rlt = rlt.round()
|
||||||
else:
|
else:
|
||||||
@ -728,11 +716,11 @@ def ssim(img1, img2):
|
|||||||
|
|
||||||
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
|
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
|
||||||
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
|
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
|
||||||
mu1_sq = mu1 ** 2
|
mu1_sq = mu1**2
|
||||||
mu2_sq = mu2 ** 2
|
mu2_sq = mu2**2
|
||||||
mu1_mu2 = mu1 * mu2
|
mu1_mu2 = mu1 * mu2
|
||||||
sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
|
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
|
||||||
sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
|
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
|
||||||
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
|
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
|
||||||
|
|
||||||
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
|
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
|
||||||
@ -749,8 +737,8 @@ def ssim(img1, img2):
|
|||||||
# matlab 'imresize' function, now only support 'bicubic'
|
# matlab 'imresize' function, now only support 'bicubic'
|
||||||
def cubic(x):
|
def cubic(x):
|
||||||
absx = torch.abs(x)
|
absx = torch.abs(x)
|
||||||
absx2 = absx ** 2
|
absx2 = absx**2
|
||||||
absx3 = absx ** 3
|
absx3 = absx**3
|
||||||
return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (
|
return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (
|
||||||
-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2
|
-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2
|
||||||
) * (((absx > 1) * (absx <= 2)).type_as(absx))
|
) * (((absx > 1) * (absx <= 2)).type_as(absx))
|
||||||
|
@ -475,7 +475,10 @@ class TextualInversionDataset(Dataset):
|
|||||||
|
|
||||||
if self.center_crop:
|
if self.center_crop:
|
||||||
crop = min(img.shape[0], img.shape[1])
|
crop = min(img.shape[0], img.shape[1])
|
||||||
(h, w,) = (
|
(
|
||||||
|
h,
|
||||||
|
w,
|
||||||
|
) = (
|
||||||
img.shape[0],
|
img.shape[0],
|
||||||
img.shape[1],
|
img.shape[1],
|
||||||
)
|
)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import math
|
import math
|
||||||
import torch
|
|
||||||
import diffusers
|
|
||||||
|
|
||||||
|
import diffusers
|
||||||
|
import torch
|
||||||
|
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
torch.empty = torch.zeros
|
torch.empty = torch.zeros
|
||||||
@ -203,7 +203,7 @@ class ChunkedSlicedAttnProcessor:
|
|||||||
if attn.upcast_attention:
|
if attn.upcast_attention:
|
||||||
out_item_size = 4
|
out_item_size = 4
|
||||||
|
|
||||||
chunk_size = 2 ** 29
|
chunk_size = 2**29
|
||||||
|
|
||||||
out_size = query.shape[1] * key.shape[1] * out_item_size
|
out_size = query.shape[1] * key.shape[1] * out_item_size
|
||||||
chunks_count = min(query.shape[1], math.ceil((out_size - 1) / chunk_size))
|
chunks_count = min(query.shape[1], math.ceil((out_size - 1) / chunk_size))
|
||||||
|
@ -207,7 +207,7 @@ def parallel_data_prefetch(
|
|||||||
return gather_res
|
return gather_res
|
||||||
|
|
||||||
|
|
||||||
def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
|
def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
|
||||||
delta = (res[0] / shape[0], res[1] / shape[1])
|
delta = (res[0] / shape[0], res[1] / shape[1])
|
||||||
d = (shape[0] // res[0], shape[1] // res[1])
|
d = (shape[0] // res[0], shape[1] // res[1])
|
||||||
|
|
||||||
|
@ -104,22 +104,22 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
|
|||||||
]);
|
]);
|
||||||
|
|
||||||
const handleSetControlImageToDimensions = useCallback(() => {
|
const handleSetControlImageToDimensions = useCallback(() => {
|
||||||
if (!processedControlImage) {
|
if (!controlImage) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (activeTabName === 'unifiedCanvas') {
|
if (activeTabName === 'unifiedCanvas') {
|
||||||
dispatch(
|
dispatch(
|
||||||
setBoundingBoxDimensions({
|
setBoundingBoxDimensions({
|
||||||
width: processedControlImage.width,
|
width: controlImage.width,
|
||||||
height: processedControlImage.height,
|
height: controlImage.height,
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
dispatch(setWidth(processedControlImage.width));
|
dispatch(setWidth(controlImage.width));
|
||||||
dispatch(setHeight(processedControlImage.height));
|
dispatch(setHeight(controlImage.height));
|
||||||
}
|
}
|
||||||
}, [processedControlImage, activeTabName, dispatch]);
|
}, [controlImage, activeTabName, dispatch]);
|
||||||
|
|
||||||
const handleMouseEnter = useCallback(() => {
|
const handleMouseEnter = useCallback(() => {
|
||||||
setIsMouseOverImage(true);
|
setIsMouseOverImage(true);
|
||||||
|
@ -110,7 +110,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
|
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
|
||||||
lastSelectedImage?.image_name ?? skipToken,
|
lastSelectedImage ?? skipToken,
|
||||||
{
|
{
|
||||||
selectFromResult: (res) => ({
|
selectFromResult: (res) => ({
|
||||||
isLoading: res.isFetching,
|
isLoading: res.isFetching,
|
||||||
|
@ -52,7 +52,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
||||||
|
|
||||||
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
|
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
|
||||||
imageDTO.image_name,
|
imageDTO,
|
||||||
{
|
{
|
||||||
selectFromResult: (res) => ({
|
selectFromResult: (res) => ({
|
||||||
isLoading: res.isFetching,
|
isLoading: res.isFetching,
|
||||||
|
@ -101,13 +101,15 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
onClick={handleRecallSeed}
|
onClick={handleRecallSeed}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{metadata.model !== undefined && metadata.model !== null && (
|
{metadata.model !== undefined &&
|
||||||
<ImageMetadataItem
|
metadata.model !== null &&
|
||||||
label="Model"
|
metadata.model.model_name && (
|
||||||
value={metadata.model.model_name}
|
<ImageMetadataItem
|
||||||
onClick={handleRecallModel}
|
label="Model"
|
||||||
/>
|
value={metadata.model.model_name}
|
||||||
)}
|
onClick={handleRecallModel}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
{metadata.width && (
|
{metadata.width && (
|
||||||
<ImageMetadataItem
|
<ImageMetadataItem
|
||||||
label="Width"
|
label="Width"
|
||||||
|
@ -27,15 +27,12 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
|||||||
// dispatch(setShouldShowImageDetails(false));
|
// dispatch(setShouldShowImageDetails(false));
|
||||||
// });
|
// });
|
||||||
|
|
||||||
const { metadata, workflow } = useGetImageMetadataFromFileQuery(
|
const { metadata, workflow } = useGetImageMetadataFromFileQuery(image, {
|
||||||
image.image_name,
|
selectFromResult: (res) => ({
|
||||||
{
|
metadata: res?.currentData?.metadata,
|
||||||
selectFromResult: (res) => ({
|
workflow: res?.currentData?.workflow,
|
||||||
metadata: res?.currentData?.metadata,
|
}),
|
||||||
workflow: res?.currentData?.workflow,
|
});
|
||||||
}),
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import { store } from 'app/store/store';
|
|
||||||
import {
|
import {
|
||||||
SchedulerParam,
|
SchedulerParam,
|
||||||
zBaseModel,
|
zBaseModel,
|
||||||
|
zMainModel,
|
||||||
zMainOrOnnxModel,
|
zMainOrOnnxModel,
|
||||||
|
zOnnxModel,
|
||||||
zSDXLRefinerModel,
|
zSDXLRefinerModel,
|
||||||
zScheduler,
|
zScheduler,
|
||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
@ -10,7 +11,6 @@ import { keyBy } from 'lodash-es';
|
|||||||
import { OpenAPIV3 } from 'openapi-types';
|
import { OpenAPIV3 } from 'openapi-types';
|
||||||
import { RgbaColor } from 'react-colorful';
|
import { RgbaColor } from 'react-colorful';
|
||||||
import { Node } from 'reactflow';
|
import { Node } from 'reactflow';
|
||||||
import { JsonObject } from 'type-fest';
|
|
||||||
import { Graph, ImageDTO, _InputField, _OutputField } from 'services/api/types';
|
import { Graph, ImageDTO, _InputField, _OutputField } from 'services/api/types';
|
||||||
import {
|
import {
|
||||||
AnyInvocationType,
|
AnyInvocationType,
|
||||||
@ -18,6 +18,7 @@ import {
|
|||||||
ProgressImage,
|
ProgressImage,
|
||||||
} from 'services/events/types';
|
} from 'services/events/types';
|
||||||
import { O } from 'ts-toolbelt';
|
import { O } from 'ts-toolbelt';
|
||||||
|
import { JsonObject } from 'type-fest';
|
||||||
import { z } from 'zod';
|
import { z } from 'zod';
|
||||||
|
|
||||||
export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>;
|
export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>;
|
||||||
@ -770,12 +771,14 @@ export const zCoreMetadata = z
|
|||||||
steps: z.number().int().nullish(),
|
steps: z.number().int().nullish(),
|
||||||
scheduler: z.string().nullish(),
|
scheduler: z.string().nullish(),
|
||||||
clip_skip: z.number().int().nullish(),
|
clip_skip: z.number().int().nullish(),
|
||||||
model: zMainOrOnnxModel.nullish(),
|
model: z
|
||||||
controlnets: z.array(zControlField).nullish(),
|
.union([zMainModel.deepPartial(), zOnnxModel.deepPartial()])
|
||||||
|
.nullish(),
|
||||||
|
controlnets: z.array(zControlField.deepPartial()).nullish(),
|
||||||
loras: z
|
loras: z
|
||||||
.array(
|
.array(
|
||||||
z.object({
|
z.object({
|
||||||
lora: zLoRAModelField,
|
lora: zLoRAModelField.deepPartial(),
|
||||||
weight: z.number(),
|
weight: z.number(),
|
||||||
})
|
})
|
||||||
)
|
)
|
||||||
@ -785,15 +788,15 @@ export const zCoreMetadata = z
|
|||||||
init_image: z.string().nullish(),
|
init_image: z.string().nullish(),
|
||||||
positive_style_prompt: z.string().nullish(),
|
positive_style_prompt: z.string().nullish(),
|
||||||
negative_style_prompt: z.string().nullish(),
|
negative_style_prompt: z.string().nullish(),
|
||||||
refiner_model: zSDXLRefinerModel.nullish(),
|
refiner_model: zSDXLRefinerModel.deepPartial().nullish(),
|
||||||
refiner_cfg_scale: z.number().nullish(),
|
refiner_cfg_scale: z.number().nullish(),
|
||||||
refiner_steps: z.number().int().nullish(),
|
refiner_steps: z.number().int().nullish(),
|
||||||
refiner_scheduler: z.string().nullish(),
|
refiner_scheduler: z.string().nullish(),
|
||||||
refiner_positive_aesthetic_store: z.number().nullish(),
|
refiner_positive_aesthetic_score: z.number().nullish(),
|
||||||
refiner_negative_aesthetic_store: z.number().nullish(),
|
refiner_negative_aesthetic_score: z.number().nullish(),
|
||||||
refiner_start: z.number().nullish(),
|
refiner_start: z.number().nullish(),
|
||||||
})
|
})
|
||||||
.catchall(z.record(z.any()));
|
.passthrough();
|
||||||
|
|
||||||
export type CoreMetadata = z.infer<typeof zCoreMetadata>;
|
export type CoreMetadata = z.infer<typeof zCoreMetadata>;
|
||||||
|
|
||||||
@ -936,22 +939,10 @@ export const zWorkflow = z.object({
|
|||||||
});
|
});
|
||||||
|
|
||||||
export const zValidatedWorkflow = zWorkflow.transform((workflow) => {
|
export const zValidatedWorkflow = zWorkflow.transform((workflow) => {
|
||||||
const nodeTemplates = store.getState().nodes.nodeTemplates;
|
|
||||||
const { nodes, edges } = workflow;
|
const { nodes, edges } = workflow;
|
||||||
const warnings: WorkflowWarning[] = [];
|
const warnings: WorkflowWarning[] = [];
|
||||||
const invocationNodes = nodes.filter(isWorkflowInvocationNode);
|
const invocationNodes = nodes.filter(isWorkflowInvocationNode);
|
||||||
const keyedNodes = keyBy(invocationNodes, 'id');
|
const keyedNodes = keyBy(invocationNodes, 'id');
|
||||||
invocationNodes.forEach((node, i) => {
|
|
||||||
const nodeTemplate = nodeTemplates[node.data.type];
|
|
||||||
if (!nodeTemplate) {
|
|
||||||
warnings.push({
|
|
||||||
message: `Node "${node.data.label || node.data.id}" skipped`,
|
|
||||||
issues: [`Unable to find template for type "${node.data.type}"`],
|
|
||||||
data: node,
|
|
||||||
});
|
|
||||||
delete nodes[i];
|
|
||||||
}
|
|
||||||
});
|
|
||||||
edges.forEach((edge, i) => {
|
edges.forEach((edge, i) => {
|
||||||
const sourceNode = keyedNodes[edge.source];
|
const sourceNode = keyedNodes[edge.source];
|
||||||
const targetNode = keyedNodes[edge.target];
|
const targetNode = keyedNodes[edge.target];
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
import * as png from '@stevebel/png';
|
import * as png from '@stevebel/png';
|
||||||
|
import { logger } from 'app/logging/logger';
|
||||||
|
import { parseify } from 'common/util/serialize';
|
||||||
import {
|
import {
|
||||||
ImageMetadataAndWorkflow,
|
ImageMetadataAndWorkflow,
|
||||||
zCoreMetadata,
|
zCoreMetadata,
|
||||||
@ -18,6 +20,11 @@ export const getMetadataAndWorkflowFromImageBlob = async (
|
|||||||
const metadataResult = zCoreMetadata.safeParse(JSON.parse(rawMetadata));
|
const metadataResult = zCoreMetadata.safeParse(JSON.parse(rawMetadata));
|
||||||
if (metadataResult.success) {
|
if (metadataResult.success) {
|
||||||
data.metadata = metadataResult.data;
|
data.metadata = metadataResult.data;
|
||||||
|
} else {
|
||||||
|
logger('system').error(
|
||||||
|
{ error: parseify(metadataResult.error) },
|
||||||
|
'Problem reading metadata from image'
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -26,6 +33,11 @@ export const getMetadataAndWorkflowFromImageBlob = async (
|
|||||||
const workflowResult = zWorkflow.safeParse(JSON.parse(rawWorkflow));
|
const workflowResult = zWorkflow.safeParse(JSON.parse(rawWorkflow));
|
||||||
if (workflowResult.success) {
|
if (workflowResult.success) {
|
||||||
data.workflow = workflowResult.data;
|
data.workflow = workflowResult.data;
|
||||||
|
} else {
|
||||||
|
logger('system').error(
|
||||||
|
{ error: parseify(workflowResult.error) },
|
||||||
|
'Problem reading workflow from image'
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -60,9 +60,9 @@ export const addSDXLRefinerToGraph = (
|
|||||||
|
|
||||||
if (metadataAccumulator) {
|
if (metadataAccumulator) {
|
||||||
metadataAccumulator.refiner_model = refinerModel;
|
metadataAccumulator.refiner_model = refinerModel;
|
||||||
metadataAccumulator.refiner_positive_aesthetic_store =
|
metadataAccumulator.refiner_positive_aesthetic_score =
|
||||||
refinerPositiveAestheticScore;
|
refinerPositiveAestheticScore;
|
||||||
metadataAccumulator.refiner_negative_aesthetic_store =
|
metadataAccumulator.refiner_negative_aesthetic_score =
|
||||||
refinerNegativeAestheticScore;
|
refinerNegativeAestheticScore;
|
||||||
metadataAccumulator.refiner_cfg_scale = refinerCFGScale;
|
metadataAccumulator.refiner_cfg_scale = refinerCFGScale;
|
||||||
metadataAccumulator.refiner_scheduler = refinerScheduler;
|
metadataAccumulator.refiner_scheduler = refinerScheduler;
|
||||||
|
@ -341,8 +341,8 @@ export const useRecallParameters = () => {
|
|||||||
refiner_cfg_scale,
|
refiner_cfg_scale,
|
||||||
refiner_steps,
|
refiner_steps,
|
||||||
refiner_scheduler,
|
refiner_scheduler,
|
||||||
refiner_positive_aesthetic_store,
|
refiner_positive_aesthetic_score,
|
||||||
refiner_negative_aesthetic_store,
|
refiner_negative_aesthetic_score,
|
||||||
refiner_start,
|
refiner_start,
|
||||||
} = metadata;
|
} = metadata;
|
||||||
|
|
||||||
@ -403,21 +403,21 @@ export const useRecallParameters = () => {
|
|||||||
|
|
||||||
if (
|
if (
|
||||||
isValidSDXLRefinerPositiveAestheticScore(
|
isValidSDXLRefinerPositiveAestheticScore(
|
||||||
refiner_positive_aesthetic_store
|
refiner_positive_aesthetic_score
|
||||||
)
|
)
|
||||||
) {
|
) {
|
||||||
dispatch(
|
dispatch(
|
||||||
setRefinerPositiveAestheticScore(refiner_positive_aesthetic_store)
|
setRefinerPositiveAestheticScore(refiner_positive_aesthetic_score)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
isValidSDXLRefinerNegativeAestheticScore(
|
isValidSDXLRefinerNegativeAestheticScore(
|
||||||
refiner_negative_aesthetic_store
|
refiner_negative_aesthetic_score
|
||||||
)
|
)
|
||||||
) {
|
) {
|
||||||
dispatch(
|
dispatch(
|
||||||
setRefinerNegativeAestheticScore(refiner_negative_aesthetic_store)
|
setRefinerNegativeAestheticScore(refiner_negative_aesthetic_score)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -28,6 +28,8 @@ import {
|
|||||||
} from '../util';
|
} from '../util';
|
||||||
import { boardsApi } from './boards';
|
import { boardsApi } from './boards';
|
||||||
import { ImageMetadataAndWorkflow } from 'features/nodes/types/types';
|
import { ImageMetadataAndWorkflow } from 'features/nodes/types/types';
|
||||||
|
import { fetchBaseQuery } from '@reduxjs/toolkit/dist/query';
|
||||||
|
import { $authToken, $projectId } from '../client';
|
||||||
|
|
||||||
export const imagesApi = api.injectEndpoints({
|
export const imagesApi = api.injectEndpoints({
|
||||||
endpoints: (build) => ({
|
endpoints: (build) => ({
|
||||||
@ -115,18 +117,40 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
],
|
],
|
||||||
keepUnusedDataFor: 86400, // 24 hours
|
keepUnusedDataFor: 86400, // 24 hours
|
||||||
}),
|
}),
|
||||||
getImageMetadataFromFile: build.query<ImageMetadataAndWorkflow, string>({
|
getImageMetadataFromFile: build.query<ImageMetadataAndWorkflow, ImageDTO>({
|
||||||
query: (image_name) => ({
|
queryFn: async (args: ImageDTO, api, extraOptions) => {
|
||||||
url: `images/i/${image_name}/full`,
|
const authToken = $authToken.get();
|
||||||
responseHandler: async (res) => {
|
const projectId = $projectId.get();
|
||||||
return await res.blob();
|
const customBaseQuery = fetchBaseQuery({
|
||||||
},
|
baseUrl: '',
|
||||||
}),
|
prepareHeaders: (headers) => {
|
||||||
providesTags: (result, error, image_name) => [
|
if (authToken) {
|
||||||
{ type: 'ImageMetadataFromFile', id: image_name },
|
headers.set('Authorization', `Bearer ${authToken}`);
|
||||||
|
}
|
||||||
|
if (projectId) {
|
||||||
|
headers.set('project-id', projectId);
|
||||||
|
}
|
||||||
|
|
||||||
|
return headers;
|
||||||
|
},
|
||||||
|
responseHandler: async (res) => {
|
||||||
|
return await res.blob();
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const response = await customBaseQuery(
|
||||||
|
args.image_url,
|
||||||
|
api,
|
||||||
|
extraOptions
|
||||||
|
);
|
||||||
|
const data = await getMetadataAndWorkflowFromImageBlob(
|
||||||
|
response.data as Blob
|
||||||
|
);
|
||||||
|
return { data };
|
||||||
|
},
|
||||||
|
providesTags: (result, error, image_dto) => [
|
||||||
|
{ type: 'ImageMetadataFromFile', id: image_dto.image_name },
|
||||||
],
|
],
|
||||||
transformResponse: (response: Blob) =>
|
|
||||||
getMetadataAndWorkflowFromImageBlob(response),
|
|
||||||
keepUnusedDataFor: 86400, // 24 hours
|
keepUnusedDataFor: 86400, // 24 hours
|
||||||
}),
|
}),
|
||||||
clearIntermediates: build.mutation<number, void>({
|
clearIntermediates: build.mutation<number, void>({
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user