Merge branch 'main' into bugfix/convert-script

Lincoln Stein 2023-07-29 17:30:40 -04:00 committed by GitHub
commit 078b33bda2
8 changed files with 493 additions and 329 deletions

View File

@@ -6,8 +6,7 @@ from pydantic import Field
 from invokeai.app.invocations.prompt import PromptOutput
-from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
-                             InvocationConfig, InvocationContext)
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
 from .math import FloatOutput, IntOutput
 
 # Pass-through parameter nodes - used by subgraphs
@@ -68,6 +67,7 @@ class ParamStringInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> StringOutput:
         return StringOutput(text=self.text)
 
+
 class ParamPromptInvocation(BaseInvocation):
     """A prompt input parameter"""

View File

@@ -139,8 +139,19 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
   useHotkeys('s', handleUseSeed, [imageDTO]);
 
   const handleUsePrompt = useCallback(() => {
-    recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt);
-  }, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]);
+    recallBothPrompts(
+      metadata?.positive_prompt,
+      metadata?.negative_prompt,
+      metadata?.positive_style_prompt,
+      metadata?.negative_style_prompt
+    );
+  }, [
+    metadata?.negative_prompt,
+    metadata?.positive_prompt,
+    metadata?.positive_style_prompt,
+    metadata?.negative_style_prompt,
+    recallBothPrompts,
+  ]);
 
   useHotkeys('p', handleUsePrompt, [imageDTO]);

View File

@@ -102,8 +102,19 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
   // Recall parameters handlers
   const handleRecallPrompt = useCallback(() => {
-    recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt);
-  }, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]);
+    recallBothPrompts(
+      metadata?.positive_prompt,
+      metadata?.negative_prompt,
+      metadata?.positive_style_prompt,
+      metadata?.negative_style_prompt
+    );
+  }, [
+    metadata?.negative_prompt,
+    metadata?.positive_prompt,
+    metadata?.positive_style_prompt,
+    metadata?.negative_style_prompt,
+    recallBothPrompts,
+  ]);
 
   const handleRecallSeed = useCallback(() => {
     recallSeed(metadata?.seed);
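
Editor's note: both call sites above now forward the SDXL style prompts alongside the base prompts. A minimal sketch of the widened call, assuming a simplified stand-in for the real UnsafeImageMetadata shape (the type and sample values below are illustrative only):

// Illustrative sketch: PromptMetadata is a hypothetical stand-in type;
// only the four field names and the call shape come from the diff above.
type PromptMetadata = {
  positive_prompt?: string;
  negative_prompt?: string;
  positive_style_prompt?: string;
  negative_style_prompt?: string;
};

declare function recallBothPrompts(
  positivePrompt: unknown,
  negativePrompt: unknown,
  positiveStylePrompt: unknown,
  negativeStylePrompt: unknown
): void;

const metadata: PromptMetadata | undefined = {
  positive_prompt: 'a castle at dusk',
  positive_style_prompt: 'oil painting, dramatic lighting',
};

// All four fields are forwarded; the hook validates each one separately,
// so absent fields are simply skipped rather than rejected wholesale.
recallBothPrompts(
  metadata?.positive_prompt,
  metadata?.negative_prompt,
  metadata?.positive_style_prompt,
  metadata?.negative_style_prompt
);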

View File

@@ -1,5 +1,15 @@
 import { useAppToaster } from 'app/components/Toaster';
 import { useAppDispatch } from 'app/store/storeHooks';
+import {
+  refinerModelChanged,
+  setNegativeStylePromptSDXL,
+  setPositiveStylePromptSDXL,
+  setRefinerAestheticScore,
+  setRefinerCFGScale,
+  setRefinerScheduler,
+  setRefinerStart,
+  setRefinerSteps,
+} from 'features/sdxl/store/sdxlSlice';
 import { useCallback } from 'react';
 import { useTranslation } from 'react-i18next';
 import { UnsafeImageMetadata } from 'services/api/endpoints/images';
@@ -22,6 +32,10 @@ import {
   isValidMainModel,
   isValidNegativePrompt,
   isValidPositivePrompt,
+  isValidSDXLNegativeStylePrompt,
+  isValidSDXLPositiveStylePrompt,
+  isValidSDXLRefinerAestheticScore,
+  isValidSDXLRefinerStart,
   isValidScheduler,
   isValidSeed,
   isValidSteps,
@@ -74,17 +88,34 @@ export const useRecallParameters = () => {
    * Recall both prompts with toast
    */
   const recallBothPrompts = useCallback(
-    (positivePrompt: unknown, negativePrompt: unknown) => {
+    (
+      positivePrompt: unknown,
+      negativePrompt: unknown,
+      positiveStylePrompt: unknown,
+      negativeStylePrompt: unknown
+    ) => {
       if (
         isValidPositivePrompt(positivePrompt) ||
-        isValidNegativePrompt(negativePrompt)
+        isValidNegativePrompt(negativePrompt) ||
+        isValidSDXLPositiveStylePrompt(positiveStylePrompt) ||
+        isValidSDXLNegativeStylePrompt(negativeStylePrompt)
       ) {
         if (isValidPositivePrompt(positivePrompt)) {
           dispatch(setPositivePrompt(positivePrompt));
         }
         if (isValidNegativePrompt(negativePrompt)) {
           dispatch(setNegativePrompt(negativePrompt));
         }
+        if (isValidSDXLPositiveStylePrompt(positiveStylePrompt)) {
+          dispatch(setPositiveStylePromptSDXL(positiveStylePrompt));
+        }
+        if (isValidSDXLNegativeStylePrompt(negativeStylePrompt)) {
+          dispatch(setNegativeStylePromptSDXL(negativeStylePrompt));
+        }
         parameterSetToast();
         return;
       }
@@ -123,6 +154,36 @@ export const useRecallParameters = () => {
     [dispatch, parameterSetToast, parameterNotSetToast]
   );
 
+  /**
+   * Recall SDXL Positive Style Prompt with toast
+   */
+  const recallSDXLPositiveStylePrompt = useCallback(
+    (positiveStylePrompt: unknown) => {
+      if (!isValidSDXLPositiveStylePrompt(positiveStylePrompt)) {
+        parameterNotSetToast();
+        return;
+      }
+      dispatch(setPositiveStylePromptSDXL(positiveStylePrompt));
+      parameterSetToast();
+    },
+    [dispatch, parameterSetToast, parameterNotSetToast]
+  );
+
+  /**
+   * Recall SDXL Negative Style Prompt with toast
+   */
+  const recallSDXLNegativeStylePrompt = useCallback(
+    (negativeStylePrompt: unknown) => {
+      if (!isValidSDXLNegativeStylePrompt(negativeStylePrompt)) {
+        parameterNotSetToast();
+        return;
+      }
+      dispatch(setNegativeStylePromptSDXL(negativeStylePrompt));
+      parameterSetToast();
+    },
+    [dispatch, parameterSetToast, parameterNotSetToast]
+  );
+
   /**
    * Recall seed with toast
    */
@@ -271,6 +332,14 @@ export const useRecallParameters = () => {
       steps,
       width,
       strength,
+      positive_style_prompt,
+      negative_style_prompt,
+      refiner_model,
+      refiner_cfg_scale,
+      refiner_steps,
+      refiner_scheduler,
+      refiner_aesthetic_store,
+      refiner_start,
     } = metadata;
 
     if (isValidCfgScale(cfg_scale)) {
@@ -304,6 +373,38 @@ export const useRecallParameters = () => {
       dispatch(setImg2imgStrength(strength));
     }
 
+    if (isValidSDXLPositiveStylePrompt(positive_style_prompt)) {
+      dispatch(setPositiveStylePromptSDXL(positive_style_prompt));
+    }
+
+    if (isValidSDXLNegativeStylePrompt(negative_style_prompt)) {
+      dispatch(setNegativeStylePromptSDXL(negative_style_prompt));
+    }
+
+    if (isValidMainModel(refiner_model)) {
+      dispatch(refinerModelChanged(refiner_model));
+    }
+
+    if (isValidSteps(refiner_steps)) {
+      dispatch(setRefinerSteps(refiner_steps));
+    }
+
+    if (isValidCfgScale(refiner_cfg_scale)) {
+      dispatch(setRefinerCFGScale(refiner_cfg_scale));
+    }
+
+    if (isValidScheduler(refiner_scheduler)) {
+      dispatch(setRefinerScheduler(refiner_scheduler));
+    }
+
+    if (isValidSDXLRefinerAestheticScore(refiner_aesthetic_store)) {
+      dispatch(setRefinerAestheticScore(refiner_aesthetic_store));
+    }
+
+    if (isValidSDXLRefinerStart(refiner_start)) {
+      dispatch(setRefinerStart(refiner_start));
+    }
+
     allParameterSetToast();
   },
   [allParameterNotSetToast, allParameterSetToast, dispatch]
@@ -313,6 +414,8 @@ export const useRecallParameters = () => {
     recallBothPrompts,
     recallPositivePrompt,
     recallNegativePrompt,
+    recallSDXLPositiveStylePrompt,
+    recallSDXLNegativeStylePrompt,
     recallSeed,
     recallCfgScale,
     recallModel,
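
Editor's note: the hook's return object now also exposes the two style-prompt recallers. A usage sketch, assuming a small wrapper hook; the import path and wrapper name are assumptions, only the hook name and the two recaller names come from the diff above:

import { useCallback } from 'react';
// Import path is assumed for illustration; the hook is the one changed above.
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';

// Hypothetical wrapper: binds the two new recallers to a metadata object.
export const useRecallStylePrompts = (metadata?: {
  positive_style_prompt?: unknown;
  negative_style_prompt?: unknown;
}) => {
  const { recallSDXLPositiveStylePrompt, recallSDXLNegativeStylePrompt } =
    useRecallParameters();

  // Each recaller validates its argument and raises its own toast, so raw
  // metadata fields can be passed straight through as `unknown`.
  const recallPositiveStyle = useCallback(
    () => recallSDXLPositiveStylePrompt(metadata?.positive_style_prompt),
    [metadata?.positive_style_prompt, recallSDXLPositiveStylePrompt]
  );
  const recallNegativeStyle = useCallback(
    () => recallSDXLNegativeStylePrompt(metadata?.negative_style_prompt),
    [metadata?.negative_style_prompt, recallSDXLNegativeStylePrompt]
  );

  return { recallPositiveStyle, recallNegativeStyle };
};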

View File

@@ -310,6 +310,39 @@ export type PrecisionParam = z.infer<typeof zPrecision>;
 export const isValidPrecision = (val: unknown): val is PrecisionParam =>
   zPrecision.safeParse(val).success;
 
+/**
+ * Zod schema for SDXL refiner aesthetic score parameter
+ */
+export const zSDXLRefinerAestheticScore = z.number().min(1).max(10);
+/**
+ * Type alias for SDXL refiner aesthetic score parameter, inferred from its zod schema
+ */
+export type SDXLRefinerAestheticScoreParam = z.infer<
+  typeof zSDXLRefinerAestheticScore
+>;
+/**
+ * Validates/type-guards a value as an SDXL refiner aesthetic score parameter
+ */
+export const isValidSDXLRefinerAestheticScore = (
+  val: unknown
+): val is SDXLRefinerAestheticScoreParam =>
+  zSDXLRefinerAestheticScore.safeParse(val).success;
+
+/**
+ * Zod schema for SDXL refiner start parameter
+ */
+export const zSDXLRefinerStart = z.number().min(0).max(1);
+/**
+ * Type alias for SDXL refiner start parameter, inferred from its zod schema
+ */
+export type SDXLRefinerStartParam = z.infer<typeof zSDXLRefinerStart>;
+/**
+ * Validates/type-guards a value as an SDXL refiner start parameter
+ */
+export const isValidSDXLRefinerStart = (
+  val: unknown
+): val is SDXLRefinerStartParam => zSDXLRefinerStart.safeParse(val).success;
+
 // /**
 //  * Zod schema for BaseModelType
 //  */
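
Editor's note: the new validators follow the file's existing safeParse pattern: a schema, an inferred type alias, and a type guard. A self-contained sketch of how the guards narrow an unknown metadata value; only the zod import and the sample values are added here, the schema shapes match the diff above:

import { z } from 'zod';

// Same shapes as the schemas added above.
const zSDXLRefinerAestheticScore = z.number().min(1).max(10);
const zSDXLRefinerStart = z.number().min(0).max(1);

const isValidSDXLRefinerAestheticScore = (
  val: unknown
): val is z.infer<typeof zSDXLRefinerAestheticScore> =>
  zSDXLRefinerAestheticScore.safeParse(val).success;

// A raw metadata field arrives as `unknown` and is narrowed before dispatch.
const fromMetadata: unknown = 6.5;
if (isValidSDXLRefinerAestheticScore(fromMetadata)) {
  // Here fromMetadata is typed as number, guaranteed to lie in [1, 10].
  console.log(`aesthetic score: ${fromMetadata}`);
}

console.log(isValidSDXLRefinerAestheticScore(42)); // false: above max(10)
console.log(zSDXLRefinerStart.safeParse(0.7).success); // true: within [0, 1]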

View File

@@ -21,8 +21,8 @@ export default function ParamSDXLConcatButton() {
   return (
     <IAIIconButton
-      aria-label="Concat"
-      tooltip="Concatenates Basic Prompt with Style (Recommended)"
+      aria-label="Concatenate Prompt & Style"
+      tooltip="Concatenate Prompt & Style"
       variant="outline"
       isChecked={shouldConcatSDXLStylePrompt}
       onClick={handleShouldConcatPromptChange}

View File

@@ -107,6 +107,7 @@
 "source": [
 "# @title 5. Load small ML models required\n",
 "import gc\n",
+"\n",
 "%cd /content/InvokeAI/\n",
 "!python scripts/preload_models.py\n",
 "gc.collect()"
@@ -132,7 +133,8 @@
 "source": [
 "# @title 6. Mount google Drive\n",
 "from google.colab import drive\n",
-"drive.mount('/content/drive')"
+"\n",
+"drive.mount(\"/content/drive\")"
 ]
 },
 {
@@ -183,7 +185,7 @@
 "    print(\"❗ Symlink already created\")\n",
 "else:\n",
 "    src = model_path\n",
-"    dst = '/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt'\n",
+"    dst = \"/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt\"\n",
 "    os.symlink(src, dst)\n",
 "    print(\"✅ Symbolic link created successfully\")"

View File

@@ -52,17 +52,17 @@
 "name": "stdout",
 "text": [
 "Cloning into 'latent-diffusion'...\n",
-"remote: Enumerating objects: 992, done.\u001B[K\n",
-"remote: Counting objects: 100% (695/695), done.\u001B[K\n",
-"remote: Compressing objects: 100% (397/397), done.\u001B[K\n",
-"remote: Total 992 (delta 375), reused 564 (delta 253), pack-reused 297\u001B[K\n",
+"remote: Enumerating objects: 992, done.\u001b[K\n",
+"remote: Counting objects: 100% (695/695), done.\u001b[K\n",
+"remote: Compressing objects: 100% (397/397), done.\u001b[K\n",
+"remote: Total 992 (delta 375), reused 564 (delta 253), pack-reused 297\u001b[K\n",
 "Receiving objects: 100% (992/992), 30.78 MiB | 29.43 MiB/s, done.\n",
 "Resolving deltas: 100% (510/510), done.\n",
 "Cloning into 'taming-transformers'...\n",
-"remote: Enumerating objects: 1335, done.\u001B[K\n",
-"remote: Counting objects: 100% (525/525), done.\u001B[K\n",
-"remote: Compressing objects: 100% (493/493), done.\u001B[K\n",
-"remote: Total 1335 (delta 58), reused 481 (delta 30), pack-reused 810\u001B[K\n",
+"remote: Enumerating objects: 1335, done.\u001b[K\n",
+"remote: Counting objects: 100% (525/525), done.\u001b[K\n",
+"remote: Compressing objects: 100% (493/493), done.\u001b[K\n",
+"remote: Total 1335 (delta 58), reused 481 (delta 30), pack-reused 810\u001b[K\n",
 "Receiving objects: 100% (1335/1335), 412.35 MiB | 30.53 MiB/s, done.\n",
 "Resolving deltas: 100% (267/267), done.\n",
 "Obtaining file:///content/taming-transformers\n",
@@ -73,9 +73,9 @@
 "Installing collected packages: taming-transformers\n",
 "  Running setup.py develop for taming-transformers\n",
 "Successfully installed taming-transformers-0.0.1\n",
-"\u001B[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
 "tensorflow 2.8.0 requires tf-estimator-nightly==2.8.0.dev2021122109, which is not installed.\n",
-"arviz 0.11.4 requires typing-extensions<4,>=3.7.4.3, but you have typing-extensions 4.1.1 which is incompatible.\u001B[0m\n"
+"arviz 0.11.4 requires typing-extensions<4,>=3.7.4.3, but you have typing-extensions 4.1.1 which is incompatible.\u001b[0m\n"
 ]
 }
 ],
@@ -87,8 +87,9 @@
 "!pip install omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops\n",
 "\n",
 "import sys\n",
+"\n",
 "sys.path.append(\".\")\n",
-"sys.path.append('./taming-transformers')\n",
+"sys.path.append(\"./taming-transformers\")\n",
 "from taming.models import vqgan"
 ]
 },
@@ -299,32 +300,35 @@
 "    )\n",
 "\n",
 "    for class_label in classes:\n",
-"        print(f\"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.\")\n",
+"        print(\n",
+"            f\"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.\"\n",
+"        )\n",
 "        xc = torch.tensor(n_samples_per_class * [class_label])\n",
 "        c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})\n",
 "\n",
-"        samples_ddim, _ = sampler.sample(S=ddim_steps,\n",
-"                                         conditioning=c,\n",
-"                                         batch_size=n_samples_per_class,\n",
-"                                         shape=[3, 64, 64],\n",
-"                                         verbose=False,\n",
-"                                         unconditional_guidance_scale=scale,\n",
-"                                         unconditional_conditioning=uc,\n",
-"                                         eta=ddim_eta)\n",
+"        samples_ddim, _ = sampler.sample(\n",
+"            S=ddim_steps,\n",
+"            conditioning=c,\n",
+"            batch_size=n_samples_per_class,\n",
+"            shape=[3, 64, 64],\n",
+"            verbose=False,\n",
+"            unconditional_guidance_scale=scale,\n",
+"            unconditional_conditioning=uc,\n",
+"            eta=ddim_eta,\n",
+"        )\n",
 "\n",
 "        x_samples_ddim = model.decode_first_stage(samples_ddim)\n",
-"        x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, \n",
-"                                     min=0.0, max=1.0)\n",
+"        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)\n",
 "        all_samples.append(x_samples_ddim)\n",
 "\n",
 "\n",
 "# display as grid\n",
 "grid = torch.stack(all_samples, 0)\n",
-"grid = rearrange(grid, 'n b c h w -> (n b) c h w')\n",
+"grid = rearrange(grid, \"n b c h w -> (n b) c h w\")\n",
 "grid = make_grid(grid, nrow=n_samples_per_class)\n",
 "\n",
 "# to image\n",
-"grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()\n",
+"grid = 255.0 * rearrange(grid, \"c h w -> h w c\").cpu().numpy()\n",
 "Image.fromarray(grid.astype(np.uint8))"
 ],
 "metadata": {