Merge branch 'main' into refactor/model-manager-2

This commit is contained in:
Lincoln Stein 2023-11-10 19:24:19 -05:00 committed by GitHub
commit cb8cdefd59
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 651 additions and 408 deletions

View File

@ -460,10 +460,10 @@ def get_torch_source() -> (Union[str, None], str):
url = "https://download.pytorch.org/whl/cpu"
if device == "cuda":
url = "https://download.pytorch.org/whl/cu121"
url = "https://download.pytorch.org/whl/cu118"
optional_modules = "[xformers,onnx-cuda]"
if device == "cuda_and_dml":
url = "https://download.pytorch.org/whl/cu121"
url = "https://download.pytorch.org/whl/cu118"
optional_modules = "[xformers,onnx-directml]"
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13

View File

@ -160,13 +160,14 @@ class CoreMetadataInvocation(BaseInvocation):
)
# High resolution fix metadata.
hrf_width: Optional[int] = InputField(
hrf_enabled: Optional[float] = InputField(
default=None,
description="The high resolution fix height and width multipler.",
description="Whether or not high resolution fix was enabled.",
)
hrf_height: Optional[int] = InputField(
# TODO: should this be stricter or do we just let the UI handle it?
hrf_method: Optional[str] = InputField(
default=None,
description="The high resolution fix height and width multipler.",
description="The high resolution fix upscale method.",
)
hrf_strength: Optional[float] = InputField(
default=None,

View File

@ -254,7 +254,13 @@ class ModelInstall(object):
elif path.is_dir() and any(
[
(path / x).exists()
for x in {"config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"}
for x in {
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"pytorch_lora_weights.safetensors",
}
]
):
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
@ -357,7 +363,7 @@ class ModelInstall(object):
for suffix in ["safetensors", "bin"]:
if f"{prefix}pytorch_lora_weights.{suffix}" in files:
location = self._download_hf_model(
repo_id, ["pytorch_lora_weights.bin"], staging, subfolder=subfolder
repo_id, [f"pytorch_lora_weights.{suffix}"], staging, subfolder=subfolder
) # LoRA
break
elif (

View File

@ -166,6 +166,15 @@ class ModelPatcher:
init_tokens_count = None
new_tokens_added = None
# TODO: This is required since Transformers 4.32 see
# https://github.com/huggingface/transformers/pull/25088
# More information by NVIDIA:
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
# This value might need to be changed in the future and take the GPUs model into account as there seem
# to be ideal values for different GPUS. This value is temporary!
# For references to the current discussion please see https://github.com/invoke-ai/InvokeAI/pull/4817
pad_to_multiple_of = 8
try:
# HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
@ -175,7 +184,7 @@ class ModelPatcher:
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
ti_manager = TextualInversionManager(ti_tokenizer)
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
init_tokens_count = text_encoder.resize_token_embeddings(None, pad_to_multiple_of).num_embeddings
def _get_trigger(ti_name, index):
trigger = ti_name
@ -190,7 +199,7 @@ class ModelPatcher:
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
# modify text_encoder
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
model_embeddings = text_encoder.get_input_embeddings()
for ti_name, ti in ti_list:
@ -222,7 +231,7 @@ class ModelPatcher:
finally:
if init_tokens_count and new_tokens_added:
text_encoder.resize_token_embeddings(init_tokens_count)
text_encoder.resize_token_embeddings(init_tokens_count, pad_to_multiple_of)
@classmethod
@contextmanager

View File

@ -183,12 +183,13 @@ class ModelProbe(object):
if model:
class_name = model.__class__.__name__
else:
for suffix in ["bin", "safetensors"]:
if (folder_path / f"learned_embeds.{suffix}").exists():
return ModelType.TextualInversion
if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
return ModelType.Lora
if (folder_path / "unet/model.onnx").exists():
return ModelType.ONNX
if (folder_path / "learned_embeds.bin").exists():
return ModelType.TextualInversion
if (folder_path / "pytorch_lora_weights.bin").exists():
return ModelType.Lora
if (folder_path / "image_encoder.txt").exists():
return ModelType.IPAdapter

View File

@ -68,8 +68,9 @@ class LoRAModel(ModelBase):
raise ModelNotFoundException()
if os.path.isdir(path):
if os.path.exists(os.path.join(path, "pytorch_lora_weights.bin")):
return LoRAModelFormat.Diffusers
for ext in ["safetensors", "bin"]:
if os.path.exists(os.path.join(path, f"pytorch_lora_weights.{ext}")):
return LoRAModelFormat.Diffusers
if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
@ -86,8 +87,10 @@ class LoRAModel(ModelBase):
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) == LoRAModelFormat.Diffusers:
# TODO: add diffusers lora when it stabilizes a bit
raise NotImplementedError("Diffusers lora not supported")
for ext in ["safetensors", "bin"]: # return path to the safetensors file inside the folder
path = Path(model_path, f"pytorch_lora_weights.{ext}")
if path.exists():
return path
else:
return model_path

View File

@ -221,6 +221,19 @@
"resetIPAdapterImage": "Reset IP Adapter Image",
"ipAdapterImageFallback": "No IP Adapter Image Selected"
},
"hrf": {
"hrf": "High Resolution Fix",
"enableHrf": "Enable High Resolution Fix",
"enableHrfTooltip": "Generate with a lower initial resolution, upscale to the base resolution, then run Image-to-Image.",
"upscaleMethod": "Upscale Method",
"hrfStrength": "High Resolution Fix Strength",
"strengthTooltip": "Lower values result in fewer details, which may reduce potential artifacts.",
"metadata": {
"enabled": "High Resolution Fix Enabled",
"strength": "High Resolution Fix Strength",
"method": "High Resolution Fix Method"
}
},
"embedding": {
"addEmbedding": "Add Embedding",
"incompatibleModel": "Incompatible base model:",
@ -1258,15 +1271,11 @@
},
"compositingBlur": {
"heading": "Blur",
"paragraphs": [
"The blur radius of the mask."
]
"paragraphs": ["The blur radius of the mask."]
},
"compositingBlurMethod": {
"heading": "Blur Method",
"paragraphs": [
"The method of blur applied to the masked area."
]
"paragraphs": ["The method of blur applied to the masked area."]
},
"compositingCoherencePass": {
"heading": "Coherence Pass",
@ -1276,9 +1285,7 @@
},
"compositingCoherenceMode": {
"heading": "Mode",
"paragraphs": [
"The mode of the Coherence Pass."
]
"paragraphs": ["The mode of the Coherence Pass."]
},
"compositingCoherenceSteps": {
"heading": "Steps",
@ -1296,9 +1303,7 @@
},
"compositingMaskAdjustments": {
"heading": "Mask Adjustments",
"paragraphs": [
"Adjust the mask."
]
"paragraphs": ["Adjust the mask."]
},
"controlNetBeginEnd": {
"heading": "Begin / End Step Percentage",
@ -1356,9 +1361,7 @@
},
"infillMethod": {
"heading": "Infill Method",
"paragraphs": [
"Method to infill the selected area."
]
"paragraphs": ["Method to infill the selected area."]
},
"lora": {
"heading": "LoRA Weight",

View File

@ -35,6 +35,9 @@ const ImageMetadataActions = (props: Props) => {
recallWidth,
recallHeight,
recallStrength,
recallHrfEnabled,
recallHrfStrength,
recallHrfMethod,
recallLoRA,
recallControlNet,
recallIPAdapter,
@ -81,6 +84,18 @@ const ImageMetadataActions = (props: Props) => {
recallStrength(metadata?.strength);
}, [metadata?.strength, recallStrength]);
const handleRecallHrfEnabled = useCallback(() => {
recallHrfEnabled(metadata?.hrf_enabled);
}, [metadata?.hrf_enabled, recallHrfEnabled]);
const handleRecallHrfStrength = useCallback(() => {
recallHrfStrength(metadata?.hrf_strength);
}, [metadata?.hrf_strength, recallHrfStrength]);
const handleRecallHrfMethod = useCallback(() => {
recallHrfMethod(metadata?.hrf_method);
}, [metadata?.hrf_method, recallHrfMethod]);
const handleRecallLoRA = useCallback(
(lora: LoRAMetadataItem) => {
recallLoRA(lora);
@ -225,6 +240,27 @@ const ImageMetadataActions = (props: Props) => {
onClick={handleRecallStrength}
/>
)}
{metadata.hrf_enabled && (
<ImageMetadataItem
label={t('hrf.metadata.enabled')}
value={metadata.hrf_enabled}
onClick={handleRecallHrfEnabled}
/>
)}
{metadata.hrf_enabled && metadata.hrf_strength && (
<ImageMetadataItem
label={t('hrf.metadata.strength')}
value={metadata.hrf_strength}
onClick={handleRecallHrfStrength}
/>
)}
{metadata.hrf_enabled && metadata.hrf_method && (
<ImageMetadataItem
label={t('hrf.metadata.method')}
value={metadata.hrf_method}
onClick={handleRecallHrfMethod}
/>
)}
{metadata.loras &&
metadata.loras.map((lora, index) => {
if (isValidLoRAModel(lora.lora)) {

View File

@ -1424,6 +1424,9 @@ export const zCoreMetadata = z
loras: z.array(zLoRAMetadataItem).nullish().catch(null),
vae: zVaeModelField.nullish().catch(null),
strength: z.number().nullish().catch(null),
hrf_enabled: z.boolean().nullish().catch(null),
hrf_strength: z.number().nullish().catch(null),
hrf_method: z.string().nullish().catch(null),
init_image: z.string().nullish().catch(null),
positive_style_prompt: z.string().nullish().catch(null),
negative_style_prompt: z.string().nullish().catch(null),

View File

@ -1,22 +1,26 @@
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
DenoiseLatentsInvocation,
ESRGANInvocation,
Edge,
LatentsToImageInvocation,
NoiseInvocation,
ResizeLatentsInvocation,
} from 'services/api/types';
import {
DENOISE_LATENTS,
DENOISE_LATENTS_HRF,
ESRGAN_HRF,
IMAGE_TO_LATENTS_HRF,
LATENTS_TO_IMAGE,
LATENTS_TO_IMAGE_HRF,
LATENTS_TO_IMAGE_HRF_HR,
LATENTS_TO_IMAGE_HRF_LR,
MAIN_MODEL_LOADER,
NOISE,
NOISE_HRF,
RESCALE_LATENTS,
RESIZE_HRF,
VAE_LOADER,
} from './constants';
import { upsertMetadata } from './metadata';
@ -56,6 +60,52 @@ function copyConnectionsToDenoiseLatentsHrf(graph: NonNullableGraph): void {
graph.edges = graph.edges.concat(newEdges);
}
/**
* Calculates the new resolution for high-resolution features (HRF) based on base model type.
* Adjusts the width and height to maintain the aspect ratio and constrains them by the model's dimension limits,
* rounding down to the nearest multiple of 8.
*
* @param {string} baseModel The base model type, which determines the base dimension used in calculations.
* @param {number} width The current width to be adjusted for HRF.
* @param {number} height The current height to be adjusted for HRF.
* @return {{newWidth: number, newHeight: number}} The new width and height, adjusted and rounded as needed.
*/
function calculateHrfRes(
baseModel: string,
width: number,
height: number
): { newWidth: number; newHeight: number } {
const aspect = width / height;
let dimension;
if (baseModel == 'sdxl') {
dimension = 1024;
} else {
dimension = 512;
}
const minDimension = Math.floor(dimension * 0.5);
const modelArea = dimension * dimension; // Assuming square images for model_area
let initWidth;
let initHeight;
if (aspect > 1.0) {
initHeight = Math.max(minDimension, Math.sqrt(modelArea / aspect));
initWidth = initHeight * aspect;
} else {
initWidth = Math.max(minDimension, Math.sqrt(modelArea * aspect));
initHeight = initWidth / aspect;
}
// Cap initial height and width to final height and width.
initWidth = Math.min(width, initWidth);
initHeight = Math.min(height, initHeight);
const newWidth = roundToMultiple(Math.floor(initWidth), 8);
const newHeight = roundToMultiple(Math.floor(initHeight), 8);
return { newWidth, newHeight };
}
// Adds the high-res fix feature to the given graph.
export const addHrfToGraph = (
state: RootState,
@ -71,151 +121,61 @@ export const addHrfToGraph = (
}
const log = logger('txt2img');
const { vae, hrfWidth, hrfHeight, hrfStrength } = state.generation;
const { vae, hrfStrength, hrfEnabled, hrfMethod } = state.generation;
const isAutoVae = !vae;
const width = state.generation.width;
const height = state.generation.height;
const baseModel = state.generation.model
? state.generation.model.base_model
: 'sd1';
const { newWidth: hrfWidth, newHeight: hrfHeight } = calculateHrfRes(
baseModel,
width,
height
);
// Pre-existing (original) graph nodes.
const originalDenoiseLatentsNode = graph.nodes[DENOISE_LATENTS] as
| DenoiseLatentsInvocation
| undefined;
const originalNoiseNode = graph.nodes[NOISE] as NoiseInvocation | undefined;
// Original latents to image should pick this up.
const originalLatentsToImageNode = graph.nodes[LATENTS_TO_IMAGE] as
| LatentsToImageInvocation
| undefined;
// Check if originalDenoiseLatentsNode is undefined and log an error
if (!originalDenoiseLatentsNode) {
log.error('originalDenoiseLatentsNode is undefined');
return;
}
// Check if originalNoiseNode is undefined and log an error
if (!originalNoiseNode) {
log.error('originalNoiseNode is undefined');
return;
}
// Check if originalLatentsToImageNode is undefined and log an error
if (!originalLatentsToImageNode) {
log.error('originalLatentsToImageNode is undefined');
return;
}
// Change height and width of original noise node to initial resolution.
if (originalNoiseNode) {
originalNoiseNode.width = hrfWidth;
originalNoiseNode.height = hrfHeight;
}
// Define new nodes.
// Denoise latents node to be run on upscaled latents.
const denoiseLatentsHrfNode: DenoiseLatentsInvocation = {
type: 'denoise_latents',
id: DENOISE_LATENTS_HRF,
is_intermediate: originalDenoiseLatentsNode?.is_intermediate,
cfg_scale: originalDenoiseLatentsNode?.cfg_scale,
scheduler: originalDenoiseLatentsNode?.scheduler,
steps: originalDenoiseLatentsNode?.steps,
denoising_start: 1 - hrfStrength,
denoising_end: 1,
// Define new nodes and their connections, roughly in order of operations.
graph.nodes[LATENTS_TO_IMAGE_HRF_LR] = {
type: 'l2i',
id: LATENTS_TO_IMAGE_HRF_LR,
fp32: originalLatentsToImageNode?.fp32,
is_intermediate: true,
};
// New base resolution noise node.
const hrfNoiseNode: NoiseInvocation = {
type: 'noise',
id: NOISE_HRF,
seed: originalNoiseNode?.seed,
use_cpu: originalNoiseNode?.use_cpu,
is_intermediate: originalNoiseNode?.is_intermediate,
};
const rescaleLatentsNode: ResizeLatentsInvocation = {
id: RESCALE_LATENTS,
type: 'lresize',
width: state.generation.width,
height: state.generation.height,
};
// New node to convert latents to image.
const latentsToImageHrfNode: LatentsToImageInvocation | undefined =
originalLatentsToImageNode
? {
type: 'l2i',
id: LATENTS_TO_IMAGE_HRF,
fp32: originalLatentsToImageNode?.fp32,
is_intermediate: originalLatentsToImageNode?.is_intermediate,
}
: undefined;
// Add new nodes to graph.
graph.nodes[LATENTS_TO_IMAGE_HRF] =
latentsToImageHrfNode as LatentsToImageInvocation;
graph.nodes[DENOISE_LATENTS_HRF] =
denoiseLatentsHrfNode as DenoiseLatentsInvocation;
graph.nodes[NOISE_HRF] = hrfNoiseNode as NoiseInvocation;
graph.nodes[RESCALE_LATENTS] = rescaleLatentsNode as ResizeLatentsInvocation;
// Connect nodes.
graph.edges.push(
{
// Set up rescale latents.
source: {
node_id: DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: RESCALE_LATENTS,
field: 'latents',
},
},
// Set up new noise node
{
source: {
node_id: RESCALE_LATENTS,
field: 'height',
},
destination: {
node_id: NOISE_HRF,
field: 'height',
},
},
{
source: {
node_id: RESCALE_LATENTS,
field: 'width',
},
destination: {
node_id: NOISE_HRF,
field: 'width',
},
},
// Set up new denoise node.
{
source: {
node_id: RESCALE_LATENTS,
field: 'latents',
},
destination: {
node_id: DENOISE_LATENTS_HRF,
field: 'latents',
},
},
{
source: {
node_id: NOISE_HRF,
field: 'noise',
},
destination: {
node_id: DENOISE_LATENTS_HRF,
field: 'noise',
},
},
// Set up new latents to image node.
{
source: {
node_id: DENOISE_LATENTS_HRF,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE_HRF,
node_id: LATENTS_TO_IMAGE_HRF_LR,
field: 'latents',
},
},
@ -225,17 +185,188 @@ export const addHrfToGraph = (
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE_HRF,
node_id: LATENTS_TO_IMAGE_HRF_LR,
field: 'vae',
},
}
);
upsertMetadata(graph, {
hrf_height: hrfHeight,
hrf_width: hrfWidth,
hrf_strength: hrfStrength,
});
graph.nodes[RESIZE_HRF] = {
id: RESIZE_HRF,
type: 'img_resize',
is_intermediate: true,
width: width,
height: height,
};
if (hrfMethod == 'ESRGAN') {
let model_name: ESRGANInvocation['model_name'] = 'RealESRGAN_x2plus.pth';
if ((width * height) / (hrfWidth * hrfHeight) > 2) {
model_name = 'RealESRGAN_x4plus.pth';
}
graph.nodes[ESRGAN_HRF] = {
id: ESRGAN_HRF,
type: 'esrgan',
model_name,
is_intermediate: true,
};
graph.edges.push(
{
source: {
node_id: LATENTS_TO_IMAGE_HRF_LR,
field: 'image',
},
destination: {
node_id: ESRGAN_HRF,
field: 'image',
},
},
{
source: {
node_id: ESRGAN_HRF,
field: 'image',
},
destination: {
node_id: RESIZE_HRF,
field: 'image',
},
}
);
} else {
graph.edges.push({
source: {
node_id: LATENTS_TO_IMAGE_HRF_LR,
field: 'image',
},
destination: {
node_id: RESIZE_HRF,
field: 'image',
},
});
}
graph.nodes[NOISE_HRF] = {
type: 'noise',
id: NOISE_HRF,
seed: originalNoiseNode?.seed,
use_cpu: originalNoiseNode?.use_cpu,
is_intermediate: true,
};
graph.edges.push(
{
source: {
node_id: RESIZE_HRF,
field: 'height',
},
destination: {
node_id: NOISE_HRF,
field: 'height',
},
},
{
source: {
node_id: RESIZE_HRF,
field: 'width',
},
destination: {
node_id: NOISE_HRF,
field: 'width',
},
}
);
graph.nodes[IMAGE_TO_LATENTS_HRF] = {
type: 'i2l',
id: IMAGE_TO_LATENTS_HRF,
is_intermediate: true,
};
graph.edges.push(
{
source: {
node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
},
destination: {
node_id: IMAGE_TO_LATENTS_HRF,
field: 'vae',
},
},
{
source: {
node_id: RESIZE_HRF,
field: 'image',
},
destination: {
node_id: IMAGE_TO_LATENTS_HRF,
field: 'image',
},
}
);
graph.nodes[DENOISE_LATENTS_HRF] = {
type: 'denoise_latents',
id: DENOISE_LATENTS_HRF,
is_intermediate: true,
cfg_scale: originalDenoiseLatentsNode?.cfg_scale,
scheduler: originalDenoiseLatentsNode?.scheduler,
steps: originalDenoiseLatentsNode?.steps,
denoising_start: 1 - state.generation.hrfStrength,
denoising_end: 1,
};
graph.edges.push(
{
source: {
node_id: IMAGE_TO_LATENTS_HRF,
field: 'latents',
},
destination: {
node_id: DENOISE_LATENTS_HRF,
field: 'latents',
},
},
{
source: {
node_id: NOISE_HRF,
field: 'noise',
},
destination: {
node_id: DENOISE_LATENTS_HRF,
field: 'noise',
},
}
);
copyConnectionsToDenoiseLatentsHrf(graph);
graph.nodes[LATENTS_TO_IMAGE_HRF_HR] = {
type: 'l2i',
id: LATENTS_TO_IMAGE_HRF_HR,
fp32: originalLatentsToImageNode?.fp32,
is_intermediate: true,
};
graph.edges.push(
{
source: {
node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE_HRF_HR,
field: 'vae',
},
},
{
source: {
node_id: DENOISE_LATENTS_HRF,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE_HRF_HR,
field: 'latents',
},
}
);
upsertMetadata(graph, {
hrf_strength: hrfStrength,
hrf_enabled: hrfEnabled,
hrf_method: hrfMethod,
});
};

View File

@ -5,7 +5,7 @@ import { SaveImageInvocation } from 'services/api/types';
import {
CANVAS_OUTPUT,
LATENTS_TO_IMAGE,
LATENTS_TO_IMAGE_HRF,
LATENTS_TO_IMAGE_HRF_HR,
NSFW_CHECKER,
SAVE_IMAGE,
WATERMARKER,
@ -62,10 +62,10 @@ export const addSaveImageNode = (
},
destination,
});
} else if (LATENTS_TO_IMAGE_HRF in graph.nodes) {
} else if (LATENTS_TO_IMAGE_HRF_HR in graph.nodes) {
graph.edges.push({
source: {
node_id: LATENTS_TO_IMAGE_HRF,
node_id: LATENTS_TO_IMAGE_HRF_HR,
field: 'image',
},
destination,

View File

@ -4,7 +4,11 @@ export const NEGATIVE_CONDITIONING = 'negative_conditioning';
export const DENOISE_LATENTS = 'denoise_latents';
export const DENOISE_LATENTS_HRF = 'denoise_latents_hrf';
export const LATENTS_TO_IMAGE = 'latents_to_image';
export const LATENTS_TO_IMAGE_HRF = 'latents_to_image_hrf';
export const LATENTS_TO_IMAGE_HRF_HR = 'latents_to_image_hrf_hr';
export const LATENTS_TO_IMAGE_HRF_LR = 'latents_to_image_hrf_lr';
export const IMAGE_TO_LATENTS_HRF = 'image_to_latents_hrf';
export const RESIZE_HRF = 'resize_hrf';
export const ESRGAN_HRF = 'esrgan_hrf';
export const SAVE_IMAGE = 'save_image';
export const NSFW_CHECKER = 'nsfw_checker';
export const WATERMARKER = 'invisible_watermark';
@ -21,7 +25,6 @@ export const CLIP_SKIP = 'clip_skip';
export const IMAGE_TO_LATENTS = 'image_to_latents';
export const LATENTS_TO_LATENTS = 'latents_to_latents';
export const RESIZE = 'resize_image';
export const RESCALE_LATENTS = 'rescale_latents';
export const IMG2IMG_RESIZE = 'img2img_resize';
export const CANVAS_OUTPUT = 'canvas_output';
export const INPAINT_IMAGE = 'inpaint_image';

View File

@ -7,10 +7,9 @@ import IAICollapse from 'common/components/IAICollapse';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import ParamHrfHeight from './ParamHrfHeight';
import ParamHrfStrength from './ParamHrfStrength';
import ParamHrfToggle from './ParamHrfToggle';
import ParamHrfWidth from './ParamHrfWidth';
import ParamHrfMethod from './ParamHrfMethod';
const selector = createSelector(
stateSelector,
@ -37,28 +36,11 @@ export default function ParamHrfCollapse() {
}
return (
<IAICollapse label="High Resolution Fix" activeLabel={activeLabel}>
<IAICollapse label={t('hrf.hrf')} activeLabel={activeLabel}>
<Flex sx={{ flexDir: 'column', gap: 2 }}>
<ParamHrfToggle />
{hrfEnabled && (
<Flex
sx={{
gap: 2,
p: 4,
borderRadius: 4,
flexDirection: 'column',
w: 'full',
bg: 'base.100',
_dark: {
bg: 'base.750',
},
}}
>
<ParamHrfWidth />
<ParamHrfHeight />
</Flex>
)}
{hrfEnabled && <ParamHrfStrength />}
<ParamHrfStrength />
<ParamHrfMethod />
</Flex>
</IAICollapse>
);

View File

@ -1,87 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider, { IAIFullSliderProps } from 'common/components/IAISlider';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import {
setHrfHeight,
setHrfWidth,
} from 'features/parameters/store/generationSlice';
import { memo, useCallback } from 'react';
function findPrevMultipleOfEight(n: number): number {
return Math.floor((n - 1) / 8) * 8;
}
const selector = createSelector(
[stateSelector],
({ generation, hotkeys, config }) => {
const { min, fineStep, coarseStep } = config.sd.height;
const { model, height, hrfHeight, aspectRatio, hrfEnabled } = generation;
const step = hotkeys.shift ? fineStep : coarseStep;
return {
model,
height,
hrfHeight,
min,
step,
aspectRatio,
hrfEnabled,
};
},
defaultSelectorOptions
);
type ParamHeightProps = Omit<
IAIFullSliderProps,
'label' | 'value' | 'onChange'
>;
const ParamHrfHeight = (props: ParamHeightProps) => {
const { height, hrfHeight, min, step, aspectRatio, hrfEnabled } =
useAppSelector(selector);
const dispatch = useAppDispatch();
const maxHrfHeight = Math.max(findPrevMultipleOfEight(height), min);
const handleChange = useCallback(
(v: number) => {
dispatch(setHrfHeight(v));
if (aspectRatio) {
const newWidth = roundToMultiple(v * aspectRatio, 8);
dispatch(setHrfWidth(newWidth));
}
},
[dispatch, aspectRatio]
);
const handleReset = useCallback(() => {
dispatch(setHrfHeight(maxHrfHeight));
if (aspectRatio) {
const newWidth = roundToMultiple(maxHrfHeight * aspectRatio, 8);
dispatch(setHrfWidth(newWidth));
}
}, [dispatch, maxHrfHeight, aspectRatio]);
return (
<IAISlider
label="Initial Height"
value={hrfHeight}
min={min}
step={step}
max={maxHrfHeight}
onChange={handleChange}
handleReset={handleReset}
withInput
withReset
withSliderMarks
sliderNumberInputProps={{ max: maxHrfHeight }}
isDisabled={!hrfEnabled}
{...props}
/>
);
};
export default memo(ParamHrfHeight);

View File

@ -0,0 +1,49 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { setHrfMethod } from 'features/parameters/store/generationSlice';
import { HrfMethodParam } from 'features/parameters/types/parameterSchemas';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const selector = createSelector(
stateSelector,
({ generation }) => {
const { hrfMethod, hrfEnabled } = generation;
return { hrfMethod, hrfEnabled };
},
defaultSelectorOptions
);
const DATA = ['ESRGAN', 'bilinear'];
// Dropdown selection for the type of high resolution fix method to use.
const ParamHrfMethodSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { hrfMethod, hrfEnabled } = useAppSelector(selector);
const handleChange = useCallback(
(v: HrfMethodParam | null) => {
if (!v) {
return;
}
dispatch(setHrfMethod(v));
},
[dispatch]
);
return (
<IAIMantineSelect
label={t('hrf.upscaleMethod')}
value={hrfMethod}
data={DATA}
onChange={handleChange}
disabled={!hrfEnabled}
/>
);
};
export default memo(ParamHrfMethodSelect);

View File

@ -5,6 +5,8 @@ import { memo, useCallback } from 'react';
import { stateSelector } from 'app/store/store';
import { setHrfStrength } from 'features/parameters/store/generationSlice';
import IAISlider from 'common/components/IAISlider';
import { Tooltip } from '@chakra-ui/react';
import { useTranslation } from 'react-i18next';
const selector = createSelector(
[stateSelector],
@ -31,6 +33,7 @@ const ParamHrfStrength = () => {
const { hrfStrength, initial, min, sliderMax, step, hrfEnabled } =
useAppSelector(selector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleHrfStrengthReset = useCallback(() => {
dispatch(setHrfStrength(initial));
@ -44,20 +47,21 @@ const ParamHrfStrength = () => {
);
return (
<IAISlider
label="Denoising Strength"
aria-label="High Resolution Denoising Strength"
min={min}
max={sliderMax}
step={step}
value={hrfStrength}
onChange={handleHrfStrengthChange}
withSliderMarks
withInput
withReset
handleReset={handleHrfStrengthReset}
isDisabled={!hrfEnabled}
/>
<Tooltip label={t('hrf.strengthTooltip')} placement="right" hasArrow>
<IAISlider
label={t('parameters.denoisingStrength')}
min={min}
max={sliderMax}
step={step}
value={hrfStrength}
onChange={handleHrfStrengthChange}
withSliderMarks
withInput
withReset
handleReset={handleHrfStrengthReset}
isDisabled={!hrfEnabled}
/>
</Tooltip>
);
};

View File

@ -3,9 +3,11 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { setHrfEnabled } from 'features/parameters/store/generationSlice';
import { ChangeEvent, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export default function ParamHrfToggle() {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const hrfEnabled = useAppSelector(
(state: RootState) => state.generation.hrfEnabled
@ -19,9 +21,10 @@ export default function ParamHrfToggle() {
return (
<IAISwitch
label="Enable High Resolution Fix"
label={t('hrf.enableHrf')}
isChecked={hrfEnabled}
onChange={handleHrfEnabled}
tooltip={t('hrf.enableHrfTooltip')}
/>
);
}

View File

@ -1,84 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider, { IAIFullSliderProps } from 'common/components/IAISlider';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import {
setHrfHeight,
setHrfWidth,
} from 'features/parameters/store/generationSlice';
import { memo, useCallback } from 'react';
function findPrevMultipleOfEight(n: number): number {
return Math.floor((n - 1) / 8) * 8;
}
const selector = createSelector(
[stateSelector],
({ generation, hotkeys, config }) => {
const { min, fineStep, coarseStep } = config.sd.width;
const { model, width, hrfWidth, aspectRatio, hrfEnabled } = generation;
const step = hotkeys.shift ? fineStep : coarseStep;
return {
model,
width,
hrfWidth,
min,
step,
aspectRatio,
hrfEnabled,
};
},
defaultSelectorOptions
);
type ParamWidthProps = Omit<IAIFullSliderProps, 'label' | 'value' | 'onChange'>;
const ParamHrfWidth = (props: ParamWidthProps) => {
const { width, hrfWidth, min, step, aspectRatio, hrfEnabled } =
useAppSelector(selector);
const dispatch = useAppDispatch();
const maxHrfWidth = Math.max(findPrevMultipleOfEight(width), min);
const handleChange = useCallback(
(v: number) => {
dispatch(setHrfWidth(v));
if (aspectRatio) {
const newHeight = roundToMultiple(v / aspectRatio, 8);
dispatch(setHrfHeight(newHeight));
}
},
[dispatch, aspectRatio]
);
const handleReset = useCallback(() => {
dispatch(setHrfWidth(maxHrfWidth));
if (aspectRatio) {
const newHeight = roundToMultiple(maxHrfWidth / aspectRatio, 8);
dispatch(setHrfHeight(newHeight));
}
}, [dispatch, maxHrfWidth, aspectRatio]);
return (
<IAISlider
label="Initial Width"
value={hrfWidth}
min={min}
step={step}
max={maxHrfWidth}
onChange={handleChange}
handleReset={handleReset}
withInput
withReset
withSliderMarks
sliderNumberInputProps={{ max: maxHrfWidth }}
isDisabled={!hrfEnabled}
{...props}
/>
);
};
export default memo(ParamHrfWidth);

View File

@ -55,6 +55,9 @@ import { initialImageSelected, modelSelected } from '../store/actions';
import {
setCfgScale,
setHeight,
setHrfEnabled,
setHrfMethod,
setHrfStrength,
setImg2imgStrength,
setNegativePrompt,
setPositivePrompt,
@ -67,6 +70,7 @@ import {
isValidCfgScale,
isValidControlNetModel,
isValidHeight,
isValidHrfMethod,
isValidIPAdapterModel,
isValidLoRAModel,
isValidMainModel,
@ -83,6 +87,7 @@ import {
isValidSteps,
isValidStrength,
isValidWidth,
isValidBoolean,
} from '../types/parameterSchemas';
const selector = createSelector(
@ -361,6 +366,51 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall high resolution enabled with toast
*/
const recallHrfEnabled = useCallback(
(hrfEnabled: unknown) => {
if (!isValidBoolean(hrfEnabled)) {
parameterNotSetToast();
return;
}
dispatch(setHrfEnabled(hrfEnabled));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall high resolution strength with toast
*/
const recallHrfStrength = useCallback(
(hrfStrength: unknown) => {
if (!isValidStrength(hrfStrength)) {
parameterNotSetToast();
return;
}
dispatch(setHrfStrength(hrfStrength));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall high resolution method with toast
*/
const recallHrfMethod = useCallback(
(hrfMethod: unknown) => {
if (!isValidHrfMethod(hrfMethod)) {
parameterNotSetToast();
return;
}
dispatch(setHrfMethod(hrfMethod));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall LoRA with toast
*/
@ -711,6 +761,9 @@ export const useRecallParameters = () => {
steps,
width,
strength,
hrf_enabled,
hrf_strength,
hrf_method,
positive_style_prompt,
negative_style_prompt,
refiner_model,
@ -729,34 +782,55 @@ export const useRecallParameters = () => {
if (isValidCfgScale(cfg_scale)) {
dispatch(setCfgScale(cfg_scale));
}
if (isValidMainModel(model)) {
dispatch(modelSelected(model));
}
if (isValidPositivePrompt(positive_prompt)) {
dispatch(setPositivePrompt(positive_prompt));
}
if (isValidNegativePrompt(negative_prompt)) {
dispatch(setNegativePrompt(negative_prompt));
}
if (isValidScheduler(scheduler)) {
dispatch(setScheduler(scheduler));
}
if (isValidSeed(seed)) {
dispatch(setSeed(seed));
}
if (isValidSteps(steps)) {
dispatch(setSteps(steps));
}
if (isValidWidth(width)) {
dispatch(setWidth(width));
}
if (isValidHeight(height)) {
dispatch(setHeight(height));
}
if (isValidStrength(strength)) {
dispatch(setImg2imgStrength(strength));
}
if (isValidBoolean(hrf_enabled)) {
dispatch(setHrfEnabled(hrf_enabled));
}
if (isValidStrength(hrf_strength)) {
dispatch(setHrfStrength(hrf_strength));
}
if (isValidHrfMethod(hrf_method)) {
dispatch(setHrfMethod(hrf_method));
}
if (isValidSDXLPositiveStylePrompt(positive_style_prompt)) {
dispatch(setPositiveStylePromptSDXL(positive_style_prompt));
}
@ -862,6 +936,9 @@ export const useRecallParameters = () => {
recallWidth,
recallHeight,
recallStrength,
recallHrfEnabled,
recallHrfStrength,
recallHrfMethod,
recallLoRA,
recallControlNet,
recallIPAdapter,

View File

@ -11,6 +11,7 @@ import {
CanvasCoherenceModeParam,
CfgScaleParam,
HeightParam,
HrfMethodParam,
MainModelParam,
MaskBlurMethodParam,
NegativePromptParam,
@ -27,10 +28,9 @@ import {
} from '../types/parameterSchemas';
export interface GenerationState {
hrfHeight: HeightParam;
hrfWidth: WidthParam;
hrfEnabled: boolean;
hrfStrength: StrengthParam;
hrfMethod: HrfMethodParam;
cfgScale: CfgScaleParam;
height: HeightParam;
img2imgStrength: StrengthParam;
@ -73,10 +73,9 @@ export interface GenerationState {
}
export const initialGenerationState: GenerationState = {
hrfHeight: 64,
hrfWidth: 64,
hrfStrength: 0.75,
hrfStrength: 0.45,
hrfEnabled: false,
hrfMethod: 'ESRGAN',
cfgScale: 7.5,
height: 512,
img2imgStrength: 0.75,
@ -279,18 +278,15 @@ export const generationSlice = createSlice({
setClipSkip: (state, action: PayloadAction<number>) => {
state.clipSkip = action.payload;
},
setHrfHeight: (state, action: PayloadAction<number>) => {
state.hrfHeight = action.payload;
},
setHrfWidth: (state, action: PayloadAction<number>) => {
state.hrfWidth = action.payload;
},
setHrfStrength: (state, action: PayloadAction<number>) => {
state.hrfStrength = action.payload;
},
setHrfEnabled: (state, action: PayloadAction<boolean>) => {
state.hrfEnabled = action.payload;
},
setHrfMethod: (state, action: PayloadAction<HrfMethodParam>) => {
state.hrfMethod = action.payload;
},
shouldUseCpuNoiseChanged: (state, action: PayloadAction<boolean>) => {
state.shouldUseCpuNoise = action.payload;
},
@ -375,10 +371,9 @@ export const {
setSeamlessXAxis,
setSeamlessYAxis,
setClipSkip,
setHrfHeight,
setHrfWidth,
setHrfStrength,
setHrfEnabled,
setHrfStrength,
setHrfMethod,
shouldUseCpuNoiseChanged,
setAspectRatio,
setShouldLockAspectRatio,

View File

@ -400,6 +400,20 @@ export type PrecisionParam = z.infer<typeof zPrecision>;
export const isValidPrecision = (val: unknown): val is PrecisionParam =>
zPrecision.safeParse(val).success;
/**
* Zod schema for a high resolution fix method parameter.
*/
export const zHrfMethod = z.enum(['ESRGAN', 'bilinear']);
/**
* Type alias for high resolution fix method parameter, inferred from its zod schema
*/
export type HrfMethodParam = z.infer<typeof zHrfMethod>;
/**
* Validates/type-guards a value as a high resolution fix method parameter
*/
export const isValidHrfMethod = (val: unknown): val is HrfMethodParam =>
zHrfMethod.safeParse(val).success;
/**
* Zod schema for SDXL refiner positive aesthetic score parameter
*/
@ -482,6 +496,17 @@ export const isValidCoherenceModeParam = (
): val is CanvasCoherenceModeParam =>
zCanvasCoherenceMode.safeParse(val).success;
/**
* Zod schema for a boolean.
*/
export const zBoolean = z.boolean();
/**
* Validates/type-guards a value as a boolean parameter
*/
export const isValidBoolean = (val: unknown): val is boolean =>
zBoolean.safeParse(val).success && val !== null && val !== undefined;
// /**
// * Zod schema for BaseModelType
// */

View File

@ -70,7 +70,7 @@ export const initialConfigState: AppConfig = {
coarseStep: 0.05,
},
hrfStrength: {
initial: 0.7,
initial: 0.45,
min: 0,
sliderMax: 1,
inputMax: 1,

File diff suppressed because one or more lines are too long

View File

@ -127,7 +127,6 @@ export type CompelInvocation = s['CompelInvocation'];
export type DynamicPromptInvocation = s['DynamicPromptInvocation'];
export type NoiseInvocation = s['NoiseInvocation'];
export type DenoiseLatentsInvocation = s['DenoiseLatentsInvocation'];
export type ResizeLatentsInvocation = s['ResizeLatentsInvocation'];
export type ONNXTextToLatentsInvocation = s['ONNXTextToLatentsInvocation'];
export type SDXLLoraLoaderInvocation = s['SDXLLoraLoaderInvocation'];
export type ImageToLatentsInvocation = s['ImageToLatentsInvocation'];

View File

@ -83,7 +83,7 @@ dependencies = [
"torchvision~=0.16",
"torchmetrics~=0.11.0",
"torchsde~=0.2.5",
"transformers~=4.31.0",
"transformers~=4.35.0",
"uvicorn[standard]~=0.21.1",
"windows-curses; sys_platform=='win32'",
]