Merge branch 'bugfix/convert-script' of github.com:invoke-ai/InvokeAI into bugfix/convert-script

This commit is contained in:
Lincoln Stein 2023-07-29 17:31:02 -04:00
commit bb18251fad
7 changed files with 490 additions and 326 deletions

View File

@ -139,8 +139,19 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
useHotkeys('s', handleUseSeed, [imageDTO]); useHotkeys('s', handleUseSeed, [imageDTO]);
const handleUsePrompt = useCallback(() => { const handleUsePrompt = useCallback(() => {
recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt); recallBothPrompts(
}, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]); metadata?.positive_prompt,
metadata?.negative_prompt,
metadata?.positive_style_prompt,
metadata?.negative_style_prompt
);
}, [
metadata?.negative_prompt,
metadata?.positive_prompt,
metadata?.positive_style_prompt,
metadata?.negative_style_prompt,
recallBothPrompts,
]);
useHotkeys('p', handleUsePrompt, [imageDTO]); useHotkeys('p', handleUsePrompt, [imageDTO]);

View File

@ -102,8 +102,19 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
// Recall parameters handlers // Recall parameters handlers
const handleRecallPrompt = useCallback(() => { const handleRecallPrompt = useCallback(() => {
recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt); recallBothPrompts(
}, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]); metadata?.positive_prompt,
metadata?.negative_prompt,
metadata?.positive_style_prompt,
metadata?.negative_style_prompt
);
}, [
metadata?.negative_prompt,
metadata?.positive_prompt,
metadata?.positive_style_prompt,
metadata?.negative_style_prompt,
recallBothPrompts,
]);
const handleRecallSeed = useCallback(() => { const handleRecallSeed = useCallback(() => {
recallSeed(metadata?.seed); recallSeed(metadata?.seed);

View File

@ -1,5 +1,15 @@
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import {
refinerModelChanged,
setNegativeStylePromptSDXL,
setPositiveStylePromptSDXL,
setRefinerAestheticScore,
setRefinerCFGScale,
setRefinerScheduler,
setRefinerStart,
setRefinerSteps,
} from 'features/sdxl/store/sdxlSlice';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { UnsafeImageMetadata } from 'services/api/endpoints/images'; import { UnsafeImageMetadata } from 'services/api/endpoints/images';
@ -22,6 +32,10 @@ import {
isValidMainModel, isValidMainModel,
isValidNegativePrompt, isValidNegativePrompt,
isValidPositivePrompt, isValidPositivePrompt,
isValidSDXLNegativeStylePrompt,
isValidSDXLPositiveStylePrompt,
isValidSDXLRefinerAestheticScore,
isValidSDXLRefinerStart,
isValidScheduler, isValidScheduler,
isValidSeed, isValidSeed,
isValidSteps, isValidSteps,
@ -74,17 +88,34 @@ export const useRecallParameters = () => {
* Recall both prompts with toast * Recall both prompts with toast
*/ */
const recallBothPrompts = useCallback( const recallBothPrompts = useCallback(
(positivePrompt: unknown, negativePrompt: unknown) => { (
positivePrompt: unknown,
negativePrompt: unknown,
positiveStylePrompt: unknown,
negativeStylePrompt: unknown
) => {
if ( if (
isValidPositivePrompt(positivePrompt) || isValidPositivePrompt(positivePrompt) ||
isValidNegativePrompt(negativePrompt) isValidNegativePrompt(negativePrompt) ||
isValidSDXLPositiveStylePrompt(positiveStylePrompt) ||
isValidSDXLNegativeStylePrompt(negativeStylePrompt)
) { ) {
if (isValidPositivePrompt(positivePrompt)) { if (isValidPositivePrompt(positivePrompt)) {
dispatch(setPositivePrompt(positivePrompt)); dispatch(setPositivePrompt(positivePrompt));
} }
if (isValidNegativePrompt(negativePrompt)) { if (isValidNegativePrompt(negativePrompt)) {
dispatch(setNegativePrompt(negativePrompt)); dispatch(setNegativePrompt(negativePrompt));
} }
if (isValidSDXLPositiveStylePrompt(positiveStylePrompt)) {
dispatch(setPositiveStylePromptSDXL(positiveStylePrompt));
}
if (isValidSDXLPositiveStylePrompt(negativeStylePrompt)) {
dispatch(setNegativeStylePromptSDXL(negativeStylePrompt));
}
parameterSetToast(); parameterSetToast();
return; return;
} }
@ -123,6 +154,36 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall SDXL Positive Style Prompt with toast
*/
const recallSDXLPositiveStylePrompt = useCallback(
(positiveStylePrompt: unknown) => {
if (!isValidSDXLPositiveStylePrompt(positiveStylePrompt)) {
parameterNotSetToast();
return;
}
dispatch(setPositiveStylePromptSDXL(positiveStylePrompt));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall SDXL Negative Style Prompt with toast
*/
const recallSDXLNegativeStylePrompt = useCallback(
(negativeStylePrompt: unknown) => {
if (!isValidSDXLNegativeStylePrompt(negativeStylePrompt)) {
parameterNotSetToast();
return;
}
dispatch(setNegativeStylePromptSDXL(negativeStylePrompt));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/** /**
* Recall seed with toast * Recall seed with toast
*/ */
@ -271,6 +332,14 @@ export const useRecallParameters = () => {
steps, steps,
width, width,
strength, strength,
positive_style_prompt,
negative_style_prompt,
refiner_model,
refiner_cfg_scale,
refiner_steps,
refiner_scheduler,
refiner_aesthetic_store,
refiner_start,
} = metadata; } = metadata;
if (isValidCfgScale(cfg_scale)) { if (isValidCfgScale(cfg_scale)) {
@ -304,6 +373,38 @@ export const useRecallParameters = () => {
dispatch(setImg2imgStrength(strength)); dispatch(setImg2imgStrength(strength));
} }
if (isValidSDXLPositiveStylePrompt(positive_style_prompt)) {
dispatch(setPositiveStylePromptSDXL(positive_style_prompt));
}
if (isValidSDXLNegativeStylePrompt(negative_style_prompt)) {
dispatch(setNegativeStylePromptSDXL(negative_style_prompt));
}
if (isValidMainModel(refiner_model)) {
dispatch(refinerModelChanged(refiner_model));
}
if (isValidSteps(refiner_steps)) {
dispatch(setRefinerSteps(refiner_steps));
}
if (isValidCfgScale(refiner_cfg_scale)) {
dispatch(setRefinerCFGScale(refiner_cfg_scale));
}
if (isValidScheduler(refiner_scheduler)) {
dispatch(setRefinerScheduler(refiner_scheduler));
}
if (isValidSDXLRefinerAestheticScore(refiner_aesthetic_store)) {
dispatch(setRefinerAestheticScore(refiner_aesthetic_store));
}
if (isValidSDXLRefinerStart(refiner_start)) {
dispatch(setRefinerStart(refiner_start));
}
allParameterSetToast(); allParameterSetToast();
}, },
[allParameterNotSetToast, allParameterSetToast, dispatch] [allParameterNotSetToast, allParameterSetToast, dispatch]
@ -313,6 +414,8 @@ export const useRecallParameters = () => {
recallBothPrompts, recallBothPrompts,
recallPositivePrompt, recallPositivePrompt,
recallNegativePrompt, recallNegativePrompt,
recallSDXLPositiveStylePrompt,
recallSDXLNegativeStylePrompt,
recallSeed, recallSeed,
recallCfgScale, recallCfgScale,
recallModel, recallModel,

View File

@ -310,6 +310,39 @@ export type PrecisionParam = z.infer<typeof zPrecision>;
export const isValidPrecision = (val: unknown): val is PrecisionParam => export const isValidPrecision = (val: unknown): val is PrecisionParam =>
zPrecision.safeParse(val).success; zPrecision.safeParse(val).success;
/**
* Zod schema for SDXL refiner aesthetic score parameter
*/
export const zSDXLRefinerAestheticScore = z.number().min(1).max(10);
/**
* Type alias for SDXL refiner aesthetic score parameter, inferred from its zod schema
*/
export type SDXLRefinerAestheticScoreParam = z.infer<
typeof zSDXLRefinerAestheticScore
>;
/**
* Validates/type-guards a value as a SDXL refiner aesthetic score parameter
*/
export const isValidSDXLRefinerAestheticScore = (
val: unknown
): val is SDXLRefinerAestheticScoreParam =>
zSDXLRefinerAestheticScore.safeParse(val).success;
/**
* Zod schema for SDXL start parameter
*/
export const zSDXLRefinerstart = z.number().min(0).max(1);
/**
* Type alias for SDXL start, inferred from its zod schema
*/
export type SDXLRefinerStartParam = z.infer<typeof zSDXLRefinerstart>;
/**
* Validates/type-guards a value as a SDXL refiner aesthetic score parameter
*/
export const isValidSDXLRefinerStart = (
val: unknown
): val is SDXLRefinerStartParam => zSDXLRefinerstart.safeParse(val).success;
// /** // /**
// * Zod schema for BaseModelType // * Zod schema for BaseModelType
// */ // */

View File

@ -21,8 +21,8 @@ export default function ParamSDXLConcatButton() {
return ( return (
<IAIIconButton <IAIIconButton
aria-label="Concat" aria-label="Concatenate Prompt & Style"
tooltip="Concatenates Basic Prompt with Style (Recommended)" tooltip="Concatenate Prompt & Style"
variant="outline" variant="outline"
isChecked={shouldConcatSDXLStylePrompt} isChecked={shouldConcatSDXLStylePrompt}
onClick={handleShouldConcatPromptChange} onClick={handleShouldConcatPromptChange}

View File

@ -40,7 +40,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"#@title 1. Check current GPU assigned\n", "# @title 1. Check current GPU assigned\n",
"!nvidia-smi -L\n", "!nvidia-smi -L\n",
"!nvidia-smi" "!nvidia-smi"
] ]
@ -54,7 +54,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"#@title 2. Download stable-diffusion Repository\n", "# @title 2. Download stable-diffusion Repository\n",
"from os.path import exists\n", "from os.path import exists\n",
"\n", "\n",
"!git clone --quiet https://github.com/invoke-ai/InvokeAI.git # Original repo\n", "!git clone --quiet https://github.com/invoke-ai/InvokeAI.git # Original repo\n",
@ -71,7 +71,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"#@title 3. Install dependencies\n", "# @title 3. Install dependencies\n",
"import gc\n", "import gc\n",
"\n", "\n",
"!wget https://raw.githubusercontent.com/invoke-ai/InvokeAI/development/environments-and-requirements/requirements-base.txt\n", "!wget https://raw.githubusercontent.com/invoke-ai/InvokeAI/development/environments-and-requirements/requirements-base.txt\n",
@ -92,7 +92,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"#@title 4. Restart Runtime\n", "# @title 4. Restart Runtime\n",
"exit()" "exit()"
] ]
}, },
@ -105,8 +105,9 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"#@title 5. Load small ML models required\n", "# @title 5. Load small ML models required\n",
"import gc\n", "import gc\n",
"\n",
"%cd /content/InvokeAI/\n", "%cd /content/InvokeAI/\n",
"!python scripts/preload_models.py\n", "!python scripts/preload_models.py\n",
"gc.collect()" "gc.collect()"
@ -130,9 +131,10 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"#@title 6. Mount google Drive\n", "# @title 6. Mount google Drive\n",
"from google.colab import drive\n", "from google.colab import drive\n",
"drive.mount('/content/drive')" "\n",
"drive.mount(\"/content/drive\")"
] ]
}, },
{ {
@ -144,16 +146,16 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"#@title 7. Drive Path to model\n", "# @title 7. Drive Path to model\n",
"#@markdown Path should start with /content/drive/path-to-your-file <br>\n", "# @markdown Path should start with /content/drive/path-to-your-file <br>\n",
"#@markdown <font color=\"red\">Note:</font> Model should be downloaded from https://huggingface.co <br>\n", "# @markdown <font color=\"red\">Note:</font> Model should be downloaded from https://huggingface.co <br>\n",
"#@markdown Lastest release: [Stable-Diffusion-v-1-4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original)\n", "# @markdown Lastest release: [Stable-Diffusion-v-1-4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original)\n",
"from os.path import exists\n", "from os.path import exists\n",
"\n", "\n",
"model_path = \"\" #@param {type:\"string\"}\n", "model_path = \"\" # @param {type:\"string\"}\n",
"if exists(model_path):\n", "if exists(model_path):\n",
" print(\"✅ Valid directory\")\n", " print(\"✅ Valid directory\")\n",
"else: \n", "else:\n",
" print(\"❌ File doesn't exist\")" " print(\"❌ File doesn't exist\")"
] ]
}, },
@ -166,10 +168,10 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"#@title 8. Symlink to model\n", "# @title 8. Symlink to model\n",
"\n", "\n",
"from os.path import exists\n", "from os.path import exists\n",
"import os \n", "import os\n",
"\n", "\n",
"# Folder creation if it doesn't exist\n", "# Folder creation if it doesn't exist\n",
"if exists(\"/content/InvokeAI/models/ldm/stable-diffusion-v1\"):\n", "if exists(\"/content/InvokeAI/models/ldm/stable-diffusion-v1\"):\n",
@ -181,10 +183,10 @@
"# Symbolic link if it doesn't exist\n", "# Symbolic link if it doesn't exist\n",
"if exists(\"/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt\"):\n", "if exists(\"/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt\"):\n",
" print(\"❗ Symlink already created\")\n", " print(\"❗ Symlink already created\")\n",
"else: \n", "else:\n",
" src = model_path\n", " src = model_path\n",
" dst = '/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt'\n", " dst = \"/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt\"\n",
" os.symlink(src, dst) \n", " os.symlink(src, dst)\n",
" print(\"✅ Symbolic link created successfully\")" " print(\"✅ Symbolic link created successfully\")"
] ]
}, },
@ -206,12 +208,12 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"#@title 9. Run Terminal and Execute Dream bot\n", "# @title 9. Run Terminal and Execute Dream bot\n",
"#@markdown <font color=\"blue\">Steps:</font> <br>\n", "# @markdown <font color=\"blue\">Steps:</font> <br>\n",
"#@markdown 1. Execute command `python scripts/invoke.py` to run InvokeAI.<br>\n", "# @markdown 1. Execute command `python scripts/invoke.py` to run InvokeAI.<br>\n",
"#@markdown 2. After initialized you'll see `Dream>` line.<br>\n", "# @markdown 2. After initialized you'll see `Dream>` line.<br>\n",
"#@markdown 3. Example text: `Astronaut floating in a distant galaxy` <br>\n", "# @markdown 3. Example text: `Astronaut floating in a distant galaxy` <br>\n",
"#@markdown 4. To quit Dream bot use: `q` command.<br>\n", "# @markdown 4. To quit Dream bot use: `q` command.<br>\n",
"\n", "\n",
"%load_ext colabxterm\n", "%load_ext colabxterm\n",
"%xterm\n", "%xterm\n",

View File

@ -52,17 +52,17 @@
"name": "stdout", "name": "stdout",
"text": [ "text": [
"Cloning into 'latent-diffusion'...\n", "Cloning into 'latent-diffusion'...\n",
"remote: Enumerating objects: 992, done.\u001B[K\n", "remote: Enumerating objects: 992, done.\u001b[K\n",
"remote: Counting objects: 100% (695/695), done.\u001B[K\n", "remote: Counting objects: 100% (695/695), done.\u001b[K\n",
"remote: Compressing objects: 100% (397/397), done.\u001B[K\n", "remote: Compressing objects: 100% (397/397), done.\u001b[K\n",
"remote: Total 992 (delta 375), reused 564 (delta 253), pack-reused 297\u001B[K\n", "remote: Total 992 (delta 375), reused 564 (delta 253), pack-reused 297\u001b[K\n",
"Receiving objects: 100% (992/992), 30.78 MiB | 29.43 MiB/s, done.\n", "Receiving objects: 100% (992/992), 30.78 MiB | 29.43 MiB/s, done.\n",
"Resolving deltas: 100% (510/510), done.\n", "Resolving deltas: 100% (510/510), done.\n",
"Cloning into 'taming-transformers'...\n", "Cloning into 'taming-transformers'...\n",
"remote: Enumerating objects: 1335, done.\u001B[K\n", "remote: Enumerating objects: 1335, done.\u001b[K\n",
"remote: Counting objects: 100% (525/525), done.\u001B[K\n", "remote: Counting objects: 100% (525/525), done.\u001b[K\n",
"remote: Compressing objects: 100% (493/493), done.\u001B[K\n", "remote: Compressing objects: 100% (493/493), done.\u001b[K\n",
"remote: Total 1335 (delta 58), reused 481 (delta 30), pack-reused 810\u001B[K\n", "remote: Total 1335 (delta 58), reused 481 (delta 30), pack-reused 810\u001b[K\n",
"Receiving objects: 100% (1335/1335), 412.35 MiB | 30.53 MiB/s, done.\n", "Receiving objects: 100% (1335/1335), 412.35 MiB | 30.53 MiB/s, done.\n",
"Resolving deltas: 100% (267/267), done.\n", "Resolving deltas: 100% (267/267), done.\n",
"Obtaining file:///content/taming-transformers\n", "Obtaining file:///content/taming-transformers\n",
@ -73,23 +73,24 @@
"Installing collected packages: taming-transformers\n", "Installing collected packages: taming-transformers\n",
" Running setup.py develop for taming-transformers\n", " Running setup.py develop for taming-transformers\n",
"Successfully installed taming-transformers-0.0.1\n", "Successfully installed taming-transformers-0.0.1\n",
"\u001B[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"tensorflow 2.8.0 requires tf-estimator-nightly==2.8.0.dev2021122109, which is not installed.\n", "tensorflow 2.8.0 requires tf-estimator-nightly==2.8.0.dev2021122109, which is not installed.\n",
"arviz 0.11.4 requires typing-extensions<4,>=3.7.4.3, but you have typing-extensions 4.1.1 which is incompatible.\u001B[0m\n" "arviz 0.11.4 requires typing-extensions<4,>=3.7.4.3, but you have typing-extensions 4.1.1 which is incompatible.\u001b[0m\n"
] ]
} }
], ],
"source": [ "source": [
"#@title Installation\n", "# @title Installation\n",
"!git clone https://github.com/CompVis/latent-diffusion.git\n", "!git clone https://github.com/CompVis/latent-diffusion.git\n",
"!git clone https://github.com/CompVis/taming-transformers\n", "!git clone https://github.com/CompVis/taming-transformers\n",
"!pip install -e ./taming-transformers\n", "!pip install -e ./taming-transformers\n",
"!pip install omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops\n", "!pip install omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops\n",
"\n", "\n",
"import sys\n", "import sys\n",
"\n",
"sys.path.append(\".\")\n", "sys.path.append(\".\")\n",
"sys.path.append('./taming-transformers')\n", "sys.path.append(\"./taming-transformers\")\n",
"from taming.models import vqgan " "from taming.models import vqgan"
] ]
}, },
{ {
@ -104,11 +105,11 @@
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "source": [
"#@title Download\n", "# @title Download\n",
"%cd latent-diffusion/ \n", "%cd latent-diffusion/\n",
"\n", "\n",
"!mkdir -p models/ldm/cin256-v2/\n", "!mkdir -p models/ldm/cin256-v2/\n",
"!wget -O models/ldm/cin256-v2/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/cin/model.ckpt " "!wget -O models/ldm/cin256-v2/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/cin/model.ckpt"
], ],
"metadata": { "metadata": {
"colab": { "colab": {
@ -203,7 +204,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "source": [
"#@title loading utils\n", "# @title loading utils\n",
"import torch\n", "import torch\n",
"from omegaconf import OmegaConf\n", "from omegaconf import OmegaConf\n",
"\n", "\n",
@ -212,7 +213,7 @@
"\n", "\n",
"def load_model_from_config(config, ckpt):\n", "def load_model_from_config(config, ckpt):\n",
" print(f\"Loading model from {ckpt}\")\n", " print(f\"Loading model from {ckpt}\")\n",
" pl_sd = torch.load(ckpt)#, map_location=\"cpu\")\n", " pl_sd = torch.load(ckpt) # , map_location=\"cpu\")\n",
" sd = pl_sd[\"state_dict\"]\n", " sd = pl_sd[\"state_dict\"]\n",
" model = instantiate_from_config(config.model)\n", " model = instantiate_from_config(config.model)\n",
" m, u = model.load_state_dict(sd, strict=False)\n", " m, u = model.load_state_dict(sd, strict=False)\n",
@ -222,7 +223,7 @@
"\n", "\n",
"\n", "\n",
"def get_model():\n", "def get_model():\n",
" config = OmegaConf.load(\"configs/latent-diffusion/cin256-v2.yaml\") \n", " config = OmegaConf.load(\"configs/latent-diffusion/cin256-v2.yaml\")\n",
" model = load_model_from_config(config, \"models/ldm/cin256-v2/model.ckpt\")\n", " model = load_model_from_config(config, \"models/ldm/cin256-v2/model.ckpt\")\n",
" return model" " return model"
], ],
@ -276,7 +277,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "source": [
"import numpy as np \n", "import numpy as np\n",
"from PIL import Image\n", "from PIL import Image\n",
"from einops import rearrange\n", "from einops import rearrange\n",
"from torchvision.utils import make_grid\n", "from torchvision.utils import make_grid\n",
@ -295,36 +296,39 @@
"with torch.no_grad():\n", "with torch.no_grad():\n",
" with model.ema_scope():\n", " with model.ema_scope():\n",
" uc = model.get_learned_conditioning(\n", " uc = model.get_learned_conditioning(\n",
" {model.cond_stage_key: torch.tensor(n_samples_per_class*[1000]).to(model.device)}\n", " {model.cond_stage_key: torch.tensor(n_samples_per_class * [1000]).to(model.device)}\n",
" )\n", " )\n",
" \n", "\n",
" for class_label in classes:\n", " for class_label in classes:\n",
" print(f\"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.\")\n", " print(\n",
" xc = torch.tensor(n_samples_per_class*[class_label])\n", " f\"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.\"\n",
" )\n",
" xc = torch.tensor(n_samples_per_class * [class_label])\n",
" c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})\n", " c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})\n",
" \n", "\n",
" samples_ddim, _ = sampler.sample(S=ddim_steps,\n", " samples_ddim, _ = sampler.sample(\n",
" S=ddim_steps,\n",
" conditioning=c,\n", " conditioning=c,\n",
" batch_size=n_samples_per_class,\n", " batch_size=n_samples_per_class,\n",
" shape=[3, 64, 64],\n", " shape=[3, 64, 64],\n",
" verbose=False,\n", " verbose=False,\n",
" unconditional_guidance_scale=scale,\n", " unconditional_guidance_scale=scale,\n",
" unconditional_conditioning=uc, \n", " unconditional_conditioning=uc,\n",
" eta=ddim_eta)\n", " eta=ddim_eta,\n",
" )\n",
"\n", "\n",
" x_samples_ddim = model.decode_first_stage(samples_ddim)\n", " x_samples_ddim = model.decode_first_stage(samples_ddim)\n",
" x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, \n", " x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)\n",
" min=0.0, max=1.0)\n",
" all_samples.append(x_samples_ddim)\n", " all_samples.append(x_samples_ddim)\n",
"\n", "\n",
"\n", "\n",
"# display as grid\n", "# display as grid\n",
"grid = torch.stack(all_samples, 0)\n", "grid = torch.stack(all_samples, 0)\n",
"grid = rearrange(grid, 'n b c h w -> (n b) c h w')\n", "grid = rearrange(grid, \"n b c h w -> (n b) c h w\")\n",
"grid = make_grid(grid, nrow=n_samples_per_class)\n", "grid = make_grid(grid, nrow=n_samples_per_class)\n",
"\n", "\n",
"# to image\n", "# to image\n",
"grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()\n", "grid = 255.0 * rearrange(grid, \"c h w -> h w c\").cpu().numpy()\n",
"Image.fromarray(grid.astype(np.uint8))" "Image.fromarray(grid.astype(np.uint8))"
], ],
"metadata": { "metadata": {