fix: SDXL Metadata not being retrieved (#4057)

## What type of PR is this? (check all applicable)

- [x] Bug Fix

## Have you discussed this change with the InvokeAI team?
- [x] Yes

## Description

- SDXL metadata (style prompts and refiner settings) was not being retrieved when recalling parameters from an image. This PR threads those fields through the recall hooks and adds the missing parameter validators so they are restored correctly.
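
In short: the prompt-recall path only forwarded the base prompts, so the SDXL style prompts (and the refiner settings handled further down) stored in an image's metadata were dropped. A condensed before/after of the call site, simplified from the diffs below; the `declare` stubs stand in for the real app wiring:

```ts
// Condensed illustration only; stubs stand in for the real app wiring.
declare const metadata: Record<string, unknown> | undefined;
declare function recallBothPrompts(...prompts: unknown[]): void;

// Before: only the base prompts were recalled; SDXL style prompts were dropped.
recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt);

// After: the SDXL style prompts are forwarded too, each validated in the hook.
recallBothPrompts(
  metadata?.positive_prompt,
  metadata?.negative_prompt,
  metadata?.positive_style_prompt,
  metadata?.negative_style_prompt
);
```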
Lincoln Stein 2023-07-29 15:37:02 -04:00 committed by GitHub
commit 9a1cfadd8b
8 changed files with 493 additions and 329 deletions

View File

@@ -6,8 +6,7 @@ from pydantic import Field

 from invokeai.app.invocations.prompt import PromptOutput
-from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
-                             InvocationConfig, InvocationContext)
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
 from .math import FloatOutput, IntOutput

 # Pass-through parameter nodes - used by subgraphs
@@ -68,6 +67,7 @@ class ParamStringInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> StringOutput:
         return StringOutput(text=self.text)

+
 class ParamPromptInvocation(BaseInvocation):
     """A prompt input parameter"""

View File

@@ -139,8 +139,19 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
   useHotkeys('s', handleUseSeed, [imageDTO]);

   const handleUsePrompt = useCallback(() => {
-    recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt);
-  }, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]);
+    recallBothPrompts(
+      metadata?.positive_prompt,
+      metadata?.negative_prompt,
+      metadata?.positive_style_prompt,
+      metadata?.negative_style_prompt
+    );
+  }, [
+    metadata?.negative_prompt,
+    metadata?.positive_prompt,
+    metadata?.positive_style_prompt,
+    metadata?.negative_style_prompt,
+    recallBothPrompts,
+  ]);

   useHotkeys('p', handleUsePrompt, [imageDTO]);

View File

@@ -102,8 +102,19 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
   // Recall parameters handlers
   const handleRecallPrompt = useCallback(() => {
-    recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt);
-  }, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]);
+    recallBothPrompts(
+      metadata?.positive_prompt,
+      metadata?.negative_prompt,
+      metadata?.positive_style_prompt,
+      metadata?.negative_style_prompt
+    );
+  }, [
+    metadata?.negative_prompt,
+    metadata?.positive_prompt,
+    metadata?.positive_style_prompt,
+    metadata?.negative_style_prompt,
+    recallBothPrompts,
+  ]);

   const handleRecallSeed = useCallback(() => {
     recallSeed(metadata?.seed);

View File

@@ -1,5 +1,15 @@
 import { useAppToaster } from 'app/components/Toaster';
 import { useAppDispatch } from 'app/store/storeHooks';
+import {
+  refinerModelChanged,
+  setNegativeStylePromptSDXL,
+  setPositiveStylePromptSDXL,
+  setRefinerAestheticScore,
+  setRefinerCFGScale,
+  setRefinerScheduler,
+  setRefinerStart,
+  setRefinerSteps,
+} from 'features/sdxl/store/sdxlSlice';
 import { useCallback } from 'react';
 import { useTranslation } from 'react-i18next';
 import { UnsafeImageMetadata } from 'services/api/endpoints/images';
@@ -22,6 +32,10 @@ import {
   isValidMainModel,
   isValidNegativePrompt,
   isValidPositivePrompt,
+  isValidSDXLNegativeStylePrompt,
+  isValidSDXLPositiveStylePrompt,
+  isValidSDXLRefinerAestheticScore,
+  isValidSDXLRefinerStart,
   isValidScheduler,
   isValidSeed,
   isValidSteps,
@@ -74,17 +88,34 @@ export const useRecallParameters = () => {
    * Recall both prompts with toast
    */
   const recallBothPrompts = useCallback(
-    (positivePrompt: unknown, negativePrompt: unknown) => {
+    (
+      positivePrompt: unknown,
+      negativePrompt: unknown,
+      positiveStylePrompt: unknown,
+      negativeStylePrompt: unknown
+    ) => {
       if (
         isValidPositivePrompt(positivePrompt) ||
-        isValidNegativePrompt(negativePrompt)
+        isValidNegativePrompt(negativePrompt) ||
+        isValidSDXLPositiveStylePrompt(positiveStylePrompt) ||
+        isValidSDXLNegativeStylePrompt(negativeStylePrompt)
       ) {
         if (isValidPositivePrompt(positivePrompt)) {
           dispatch(setPositivePrompt(positivePrompt));
         }
         if (isValidNegativePrompt(negativePrompt)) {
           dispatch(setNegativePrompt(negativePrompt));
         }
+        if (isValidSDXLPositiveStylePrompt(positiveStylePrompt)) {
+          dispatch(setPositiveStylePromptSDXL(positiveStylePrompt));
+        }
+        if (isValidSDXLNegativeStylePrompt(negativeStylePrompt)) {
+          dispatch(setNegativeStylePromptSDXL(negativeStylePrompt));
+        }
         parameterSetToast();
         return;
       }
@@ -123,6 +154,36 @@ export const useRecallParameters = () => {
     [dispatch, parameterSetToast, parameterNotSetToast]
   );

+  /**
+   * Recall SDXL Positive Style Prompt with toast
+   */
+  const recallSDXLPositiveStylePrompt = useCallback(
+    (positiveStylePrompt: unknown) => {
+      if (!isValidSDXLPositiveStylePrompt(positiveStylePrompt)) {
+        parameterNotSetToast();
+        return;
+      }
+      dispatch(setPositiveStylePromptSDXL(positiveStylePrompt));
+      parameterSetToast();
+    },
+    [dispatch, parameterSetToast, parameterNotSetToast]
+  );
+
+  /**
+   * Recall SDXL Negative Style Prompt with toast
+   */
+  const recallSDXLNegativeStylePrompt = useCallback(
+    (negativeStylePrompt: unknown) => {
+      if (!isValidSDXLNegativeStylePrompt(negativeStylePrompt)) {
+        parameterNotSetToast();
+        return;
+      }
+      dispatch(setNegativeStylePromptSDXL(negativeStylePrompt));
+      parameterSetToast();
+    },
+    [dispatch, parameterSetToast, parameterNotSetToast]
+  );
+
   /**
    * Recall seed with toast
    */
@@ -271,6 +332,14 @@ export const useRecallParameters = () => {
       steps,
       width,
       strength,
+      positive_style_prompt,
+      negative_style_prompt,
+      refiner_model,
+      refiner_cfg_scale,
+      refiner_steps,
+      refiner_scheduler,
+      refiner_aesthetic_store,
+      refiner_start,
     } = metadata;

     if (isValidCfgScale(cfg_scale)) {
@@ -304,6 +373,38 @@ export const useRecallParameters = () => {
       dispatch(setImg2imgStrength(strength));
     }

+    if (isValidSDXLPositiveStylePrompt(positive_style_prompt)) {
+      dispatch(setPositiveStylePromptSDXL(positive_style_prompt));
+    }
+
+    if (isValidSDXLNegativeStylePrompt(negative_style_prompt)) {
+      dispatch(setNegativeStylePromptSDXL(negative_style_prompt));
+    }
+
+    if (isValidMainModel(refiner_model)) {
+      dispatch(refinerModelChanged(refiner_model));
+    }
+
+    if (isValidSteps(refiner_steps)) {
+      dispatch(setRefinerSteps(refiner_steps));
+    }
+
+    if (isValidCfgScale(refiner_cfg_scale)) {
+      dispatch(setRefinerCFGScale(refiner_cfg_scale));
+    }
+
+    if (isValidScheduler(refiner_scheduler)) {
+      dispatch(setRefinerScheduler(refiner_scheduler));
+    }
+
+    if (isValidSDXLRefinerAestheticScore(refiner_aesthetic_store)) {
+      dispatch(setRefinerAestheticScore(refiner_aesthetic_store));
+    }
+
+    if (isValidSDXLRefinerStart(refiner_start)) {
+      dispatch(setRefinerStart(refiner_start));
+    }
+
     allParameterSetToast();
   },
   [allParameterNotSetToast, allParameterSetToast, dispatch]
@@ -313,6 +414,8 @@ export const useRecallParameters = () => {
     recallBothPrompts,
     recallPositivePrompt,
     recallNegativePrompt,
+    recallSDXLPositiveStylePrompt,
+    recallSDXLNegativeStylePrompt,
     recallSeed,
     recallCfgScale,
     recallModel,
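
A minimal sketch of how a component might consume the extended hook after this change. The component and the `metadata` shape are assumptions for illustration; only the hook API comes from this PR, and the import path is assumed:

```tsx
// Hypothetical consumer of the updated hook; import path is an assumption.
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';

function ExampleSdxlRecallButton({ metadata }: { metadata?: Record<string, unknown> }) {
  const { recallBothPrompts, recallSDXLPositiveStylePrompt } = useRecallParameters();

  return (
    <>
      {/* Forwards all four prompt fields; invalid or missing ones are
          ignored by the validators inside the hook. */}
      <button
        onClick={() =>
          recallBothPrompts(
            metadata?.positive_prompt,
            metadata?.negative_prompt,
            metadata?.positive_style_prompt,
            metadata?.negative_style_prompt
          )
        }
      >
        Recall prompts
      </button>
      {/* Or recall a single SDXL field, with its own toast feedback. */}
      <button onClick={() => recallSDXLPositiveStylePrompt(metadata?.positive_style_prompt)}>
        Recall style prompt
      </button>
    </>
  );
}
```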

View File

@@ -310,6 +310,39 @@ export type PrecisionParam = z.infer<typeof zPrecision>;
 export const isValidPrecision = (val: unknown): val is PrecisionParam =>
   zPrecision.safeParse(val).success;

+/**
+ * Zod schema for the SDXL refiner aesthetic score parameter
+ */
+export const zSDXLRefinerAestheticScore = z.number().min(1).max(10);
+/**
+ * Type alias for the SDXL refiner aesthetic score parameter, inferred from its zod schema
+ */
+export type SDXLRefinerAestheticScoreParam = z.infer<
+  typeof zSDXLRefinerAestheticScore
+>;
+/**
+ * Validates/type-guards a value as an SDXL refiner aesthetic score parameter
+ */
+export const isValidSDXLRefinerAestheticScore = (
+  val: unknown
+): val is SDXLRefinerAestheticScoreParam =>
+  zSDXLRefinerAestheticScore.safeParse(val).success;
+/**
+ * Zod schema for the SDXL refiner start parameter
+ */
+export const zSDXLRefinerstart = z.number().min(0).max(1);
+/**
+ * Type alias for the SDXL refiner start parameter, inferred from its zod schema
+ */
+export type SDXLRefinerStartParam = z.infer<typeof zSDXLRefinerstart>;
+/**
+ * Validates/type-guards a value as an SDXL refiner start parameter
+ */
+export const isValidSDXLRefinerStart = (
+  val: unknown
+): val is SDXLRefinerStartParam => zSDXLRefinerstart.safeParse(val).success;
+
 // /**
 //  * Zod schema for BaseModelType
 //  */
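
For illustration, a quick sketch of the new type guards in action; the values are arbitrary and the import path is assumed:

```ts
// Sketch only: exercising the new guards with arbitrary values.
import {
  isValidSDXLRefinerAestheticScore,
  isValidSDXLRefinerStart,
} from 'features/parameters/types/parameterSchemas';

console.log(isValidSDXLRefinerAestheticScore(6.5)); // true: within [1, 10]
console.log(isValidSDXLRefinerAestheticScore(11)); // false: above the max
console.log(isValidSDXLRefinerStart(0.8)); // true: within [0, 1]
console.log(isValidSDXLRefinerStart('0.8')); // false: not a number
```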

View File

@@ -21,8 +21,8 @@ export default function ParamSDXLConcatButton() {
   return (
     <IAIIconButton
-      aria-label="Concat"
-      tooltip="Concatenates Basic Prompt with Style (Recommended)"
+      aria-label="Concatenate Prompt & Style"
+      tooltip="Concatenate Prompt & Style"
       variant="outline"
       isChecked={shouldConcatSDXLStylePrompt}
       onClick={handleShouldConcatPromptChange}

View File

@@ -40,7 +40,7 @@
    },
    "outputs": [],
    "source": [
-    "#@title 1. Check current GPU assigned\n",
+    "# @title 1. Check current GPU assigned\n",
     "!nvidia-smi -L\n",
     "!nvidia-smi"
    ]
@@ -54,7 +54,7 @@
    },
    "outputs": [],
    "source": [
-    "#@title 2. Download stable-diffusion Repository\n",
+    "# @title 2. Download stable-diffusion Repository\n",
     "from os.path import exists\n",
     "\n",
     "!git clone --quiet https://github.com/invoke-ai/InvokeAI.git # Original repo\n",
@@ -71,7 +71,7 @@
    },
    "outputs": [],
    "source": [
-    "#@title 3. Install dependencies\n",
+    "# @title 3. Install dependencies\n",
     "import gc\n",
     "\n",
     "!wget https://raw.githubusercontent.com/invoke-ai/InvokeAI/development/environments-and-requirements/requirements-base.txt\n",
@@ -92,7 +92,7 @@
    },
    "outputs": [],
    "source": [
-    "#@title 4. Restart Runtime\n",
+    "# @title 4. Restart Runtime\n",
     "exit()"
    ]
   },
@@ -105,8 +105,9 @@
    },
    "outputs": [],
    "source": [
-    "#@title 5. Load small ML models required\n",
+    "# @title 5. Load small ML models required\n",
     "import gc\n",
+    "\n",
     "%cd /content/InvokeAI/\n",
     "!python scripts/preload_models.py\n",
     "gc.collect()"
@@ -130,9 +131,10 @@
    },
    "outputs": [],
    "source": [
-    "#@title 6. Mount google Drive\n",
+    "# @title 6. Mount google Drive\n",
     "from google.colab import drive\n",
-    "drive.mount('/content/drive')"
+    "\n",
+    "drive.mount(\"/content/drive\")"
    ]
   },
   {
@@ -144,16 +146,16 @@
    },
    "outputs": [],
    "source": [
-    "#@title 7. Drive Path to model\n",
-    "#@markdown Path should start with /content/drive/path-to-your-file <br>\n",
-    "#@markdown <font color=\"red\">Note:</font> Model should be downloaded from https://huggingface.co <br>\n",
-    "#@markdown Latest release: [Stable-Diffusion-v-1-4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original)\n",
+    "# @title 7. Drive Path to model\n",
+    "# @markdown Path should start with /content/drive/path-to-your-file <br>\n",
+    "# @markdown <font color=\"red\">Note:</font> Model should be downloaded from https://huggingface.co <br>\n",
+    "# @markdown Latest release: [Stable-Diffusion-v-1-4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original)\n",
     "from os.path import exists\n",
     "\n",
-    "model_path = \"\" #@param {type:\"string\"}\n",
+    "model_path = \"\"  # @param {type:\"string\"}\n",
     "if exists(model_path):\n",
     "    print(\"✅ Valid directory\")\n",
-    "else: \n",
+    "else:\n",
     "    print(\"❌ File doesn't exist\")"
    ]
   },
@@ -166,10 +168,10 @@
    },
    "outputs": [],
    "source": [
-    "#@title 8. Symlink to model\n",
+    "# @title 8. Symlink to model\n",
     "\n",
     "from os.path import exists\n",
-    "import os \n",
+    "import os\n",
     "\n",
     "# Folder creation if it doesn't exist\n",
     "if exists(\"/content/InvokeAI/models/ldm/stable-diffusion-v1\"):\n",
@@ -181,10 +183,10 @@
     "# Symbolic link if it doesn't exist\n",
     "if exists(\"/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt\"):\n",
     "    print(\"❗ Symlink already created\")\n",
-    "else: \n",
+    "else:\n",
     "    src = model_path\n",
-    "    dst = '/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt'\n",
-    "    os.symlink(src, dst) \n",
+    "    dst = \"/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt\"\n",
+    "    os.symlink(src, dst)\n",
     "    print(\"✅ Symbolic link created successfully\")"
    ]
   },
@@ -206,12 +208,12 @@
    },
    "outputs": [],
    "source": [
-    "#@title 9. Run Terminal and Execute Dream bot\n",
-    "#@markdown <font color=\"blue\">Steps:</font> <br>\n",
-    "#@markdown 1. Execute command `python scripts/invoke.py` to run InvokeAI.<br>\n",
-    "#@markdown 2. After initialized you'll see `Dream>` line.<br>\n",
-    "#@markdown 3. Example text: `Astronaut floating in a distant galaxy` <br>\n",
-    "#@markdown 4. To quit Dream bot use: `q` command.<br>\n",
+    "# @title 9. Run Terminal and Execute Dream bot\n",
+    "# @markdown <font color=\"blue\">Steps:</font> <br>\n",
+    "# @markdown 1. Execute command `python scripts/invoke.py` to run InvokeAI.<br>\n",
+    "# @markdown 2. After initialized you'll see `Dream>` line.<br>\n",
+    "# @markdown 3. Example text: `Astronaut floating in a distant galaxy` <br>\n",
+    "# @markdown 4. To quit Dream bot use: `q` command.<br>\n",
     "\n",
     "%load_ext colabxterm\n",
     "%xterm\n",

View File

@@ -52,17 +52,17 @@
      "name": "stdout",
      "text": [
       "Cloning into 'latent-diffusion'...\n",
-      "remote: Enumerating objects: 992, done.\u001B[K\n",
-      "remote: Counting objects: 100% (695/695), done.\u001B[K\n",
-      "remote: Compressing objects: 100% (397/397), done.\u001B[K\n",
-      "remote: Total 992 (delta 375), reused 564 (delta 253), pack-reused 297\u001B[K\n",
+      "remote: Enumerating objects: 992, done.\u001b[K\n",
+      "remote: Counting objects: 100% (695/695), done.\u001b[K\n",
+      "remote: Compressing objects: 100% (397/397), done.\u001b[K\n",
+      "remote: Total 992 (delta 375), reused 564 (delta 253), pack-reused 297\u001b[K\n",
       "Receiving objects: 100% (992/992), 30.78 MiB | 29.43 MiB/s, done.\n",
       "Resolving deltas: 100% (510/510), done.\n",
       "Cloning into 'taming-transformers'...\n",
-      "remote: Enumerating objects: 1335, done.\u001B[K\n",
-      "remote: Counting objects: 100% (525/525), done.\u001B[K\n",
-      "remote: Compressing objects: 100% (493/493), done.\u001B[K\n",
-      "remote: Total 1335 (delta 58), reused 481 (delta 30), pack-reused 810\u001B[K\n",
+      "remote: Enumerating objects: 1335, done.\u001b[K\n",
+      "remote: Counting objects: 100% (525/525), done.\u001b[K\n",
+      "remote: Compressing objects: 100% (493/493), done.\u001b[K\n",
+      "remote: Total 1335 (delta 58), reused 481 (delta 30), pack-reused 810\u001b[K\n",
       "Receiving objects: 100% (1335/1335), 412.35 MiB | 30.53 MiB/s, done.\n",
       "Resolving deltas: 100% (267/267), done.\n",
       "Obtaining file:///content/taming-transformers\n",
@@ -73,23 +73,24 @@
       "Installing collected packages: taming-transformers\n",
       "  Running setup.py develop for taming-transformers\n",
       "Successfully installed taming-transformers-0.0.1\n",
-      "\u001B[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+      "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
       "tensorflow 2.8.0 requires tf-estimator-nightly==2.8.0.dev2021122109, which is not installed.\n",
-      "arviz 0.11.4 requires typing-extensions<4,>=3.7.4.3, but you have typing-extensions 4.1.1 which is incompatible.\u001B[0m\n"
+      "arviz 0.11.4 requires typing-extensions<4,>=3.7.4.3, but you have typing-extensions 4.1.1 which is incompatible.\u001b[0m\n"
      ]
     }
    ],
    "source": [
-    "#@title Installation\n",
+    "# @title Installation\n",
     "!git clone https://github.com/CompVis/latent-diffusion.git\n",
     "!git clone https://github.com/CompVis/taming-transformers\n",
     "!pip install -e ./taming-transformers\n",
     "!pip install omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops\n",
     "\n",
     "import sys\n",
+    "\n",
     "sys.path.append(\".\")\n",
-    "sys.path.append('./taming-transformers')\n",
-    "from taming.models import vqgan "
+    "sys.path.append(\"./taming-transformers\")\n",
+    "from taming.models import vqgan"
    ]
   },
   {
@@ -104,11 +105,11 @@
   {
    "cell_type": "code",
    "source": [
-    "#@title Download\n",
-    "%cd latent-diffusion/ \n",
+    "# @title Download\n",
+    "%cd latent-diffusion/\n",
     "\n",
     "!mkdir -p models/ldm/cin256-v2/\n",
-    "!wget -O models/ldm/cin256-v2/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/cin/model.ckpt "
+    "!wget -O models/ldm/cin256-v2/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/cin/model.ckpt"
    ],
    "metadata": {
     "colab": {
@@ -203,7 +204,7 @@
   {
    "cell_type": "code",
    "source": [
-    "#@title loading utils\n",
+    "# @title loading utils\n",
     "import torch\n",
     "from omegaconf import OmegaConf\n",
     "\n",
@@ -212,7 +213,7 @@
     "\n",
     "def load_model_from_config(config, ckpt):\n",
     "    print(f\"Loading model from {ckpt}\")\n",
-    "    pl_sd = torch.load(ckpt)#, map_location=\"cpu\")\n",
+    "    pl_sd = torch.load(ckpt)  # , map_location=\"cpu\")\n",
     "    sd = pl_sd[\"state_dict\"]\n",
     "    model = instantiate_from_config(config.model)\n",
     "    m, u = model.load_state_dict(sd, strict=False)\n",
@@ -222,7 +223,7 @@
     "\n",
     "\n",
     "def get_model():\n",
-    "    config = OmegaConf.load(\"configs/latent-diffusion/cin256-v2.yaml\") \n",
+    "    config = OmegaConf.load(\"configs/latent-diffusion/cin256-v2.yaml\")\n",
     "    model = load_model_from_config(config, \"models/ldm/cin256-v2/model.ckpt\")\n",
     "    return model"
    ],
@@ -276,7 +277,7 @@
   {
    "cell_type": "code",
    "source": [
-    "import numpy as np \n",
+    "import numpy as np\n",
     "from PIL import Image\n",
     "from einops import rearrange\n",
     "from torchvision.utils import make_grid\n",
@@ -295,36 +296,39 @@
     "with torch.no_grad():\n",
     "    with model.ema_scope():\n",
     "        uc = model.get_learned_conditioning(\n",
-    "            {model.cond_stage_key: torch.tensor(n_samples_per_class*[1000]).to(model.device)}\n",
+    "            {model.cond_stage_key: torch.tensor(n_samples_per_class * [1000]).to(model.device)}\n",
     "        )\n",
-    "        \n",
+    "\n",
     "        for class_label in classes:\n",
-    "            print(f\"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.\")\n",
-    "            xc = torch.tensor(n_samples_per_class*[class_label])\n",
+    "            print(\n",
+    "                f\"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.\"\n",
+    "            )\n",
+    "            xc = torch.tensor(n_samples_per_class * [class_label])\n",
     "            c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})\n",
-    "            \n",
-    "            samples_ddim, _ = sampler.sample(S=ddim_steps,\n",
+    "\n",
+    "            samples_ddim, _ = sampler.sample(\n",
+    "                S=ddim_steps,\n",
     "                conditioning=c,\n",
     "                batch_size=n_samples_per_class,\n",
     "                shape=[3, 64, 64],\n",
     "                verbose=False,\n",
     "                unconditional_guidance_scale=scale,\n",
-    "                unconditional_conditioning=uc, \n",
-    "                eta=ddim_eta)\n",
+    "                unconditional_conditioning=uc,\n",
+    "                eta=ddim_eta,\n",
+    "            )\n",
     "\n",
     "            x_samples_ddim = model.decode_first_stage(samples_ddim)\n",
-    "            x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, \n",
-    "                min=0.0, max=1.0)\n",
+    "            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)\n",
     "            all_samples.append(x_samples_ddim)\n",
     "\n",
     "\n",
     "# display as grid\n",
     "grid = torch.stack(all_samples, 0)\n",
-    "grid = rearrange(grid, 'n b c h w -> (n b) c h w')\n",
+    "grid = rearrange(grid, \"n b c h w -> (n b) c h w\")\n",
     "grid = make_grid(grid, nrow=n_samples_per_class)\n",
     "\n",
     "# to image\n",
-    "grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()\n",
+    "grid = 255.0 * rearrange(grid, \"c h w -> h w c\").cpu().numpy()\n",
     "Image.fromarray(grid.astype(np.uint8))"
    ],
    "metadata": {