mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix: SDXL Metadata not being retrieved (#4057)
## What type of PR is this? (check all applicable) - [x] Bug Fix ## Have you discussed this change with the InvokeAI team? - [x] Yes ## Description - SDXL Metadata was not being retrieved. This PR fixes it.
This commit is contained in:
commit
9a1cfadd8b
@ -6,8 +6,7 @@ from pydantic import Field
|
||||
|
||||
from invokeai.app.invocations.prompt import PromptOutput
|
||||
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
from .math import FloatOutput, IntOutput
|
||||
|
||||
# Pass-through parameter nodes - used by subgraphs
|
||||
@ -68,6 +67,7 @@ class ParamStringInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
return StringOutput(text=self.text)
|
||||
|
||||
|
||||
class ParamPromptInvocation(BaseInvocation):
|
||||
"""A prompt input parameter"""
|
||||
|
||||
@ -80,4 +80,4 @@ class ParamPromptInvocation(BaseInvocation):
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptOutput:
|
||||
return PromptOutput(prompt=self.prompt)
|
||||
return PromptOutput(prompt=self.prompt)
|
||||
|
@ -139,8 +139,19 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
useHotkeys('s', handleUseSeed, [imageDTO]);
|
||||
|
||||
const handleUsePrompt = useCallback(() => {
|
||||
recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt);
|
||||
}, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]);
|
||||
recallBothPrompts(
|
||||
metadata?.positive_prompt,
|
||||
metadata?.negative_prompt,
|
||||
metadata?.positive_style_prompt,
|
||||
metadata?.negative_style_prompt
|
||||
);
|
||||
}, [
|
||||
metadata?.negative_prompt,
|
||||
metadata?.positive_prompt,
|
||||
metadata?.positive_style_prompt,
|
||||
metadata?.negative_style_prompt,
|
||||
recallBothPrompts,
|
||||
]);
|
||||
|
||||
useHotkeys('p', handleUsePrompt, [imageDTO]);
|
||||
|
||||
|
@ -102,8 +102,19 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
|
||||
// Recall parameters handlers
|
||||
const handleRecallPrompt = useCallback(() => {
|
||||
recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt);
|
||||
}, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]);
|
||||
recallBothPrompts(
|
||||
metadata?.positive_prompt,
|
||||
metadata?.negative_prompt,
|
||||
metadata?.positive_style_prompt,
|
||||
metadata?.negative_style_prompt
|
||||
);
|
||||
}, [
|
||||
metadata?.negative_prompt,
|
||||
metadata?.positive_prompt,
|
||||
metadata?.positive_style_prompt,
|
||||
metadata?.negative_style_prompt,
|
||||
recallBothPrompts,
|
||||
]);
|
||||
|
||||
const handleRecallSeed = useCallback(() => {
|
||||
recallSeed(metadata?.seed);
|
||||
|
@ -1,5 +1,15 @@
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import {
|
||||
refinerModelChanged,
|
||||
setNegativeStylePromptSDXL,
|
||||
setPositiveStylePromptSDXL,
|
||||
setRefinerAestheticScore,
|
||||
setRefinerCFGScale,
|
||||
setRefinerScheduler,
|
||||
setRefinerStart,
|
||||
setRefinerSteps,
|
||||
} from 'features/sdxl/store/sdxlSlice';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { UnsafeImageMetadata } from 'services/api/endpoints/images';
|
||||
@ -22,6 +32,10 @@ import {
|
||||
isValidMainModel,
|
||||
isValidNegativePrompt,
|
||||
isValidPositivePrompt,
|
||||
isValidSDXLNegativeStylePrompt,
|
||||
isValidSDXLPositiveStylePrompt,
|
||||
isValidSDXLRefinerAestheticScore,
|
||||
isValidSDXLRefinerStart,
|
||||
isValidScheduler,
|
||||
isValidSeed,
|
||||
isValidSteps,
|
||||
@ -74,17 +88,34 @@ export const useRecallParameters = () => {
|
||||
* Recall both prompts with toast
|
||||
*/
|
||||
const recallBothPrompts = useCallback(
|
||||
(positivePrompt: unknown, negativePrompt: unknown) => {
|
||||
(
|
||||
positivePrompt: unknown,
|
||||
negativePrompt: unknown,
|
||||
positiveStylePrompt: unknown,
|
||||
negativeStylePrompt: unknown
|
||||
) => {
|
||||
if (
|
||||
isValidPositivePrompt(positivePrompt) ||
|
||||
isValidNegativePrompt(negativePrompt)
|
||||
isValidNegativePrompt(negativePrompt) ||
|
||||
isValidSDXLPositiveStylePrompt(positiveStylePrompt) ||
|
||||
isValidSDXLNegativeStylePrompt(negativeStylePrompt)
|
||||
) {
|
||||
if (isValidPositivePrompt(positivePrompt)) {
|
||||
dispatch(setPositivePrompt(positivePrompt));
|
||||
}
|
||||
|
||||
if (isValidNegativePrompt(negativePrompt)) {
|
||||
dispatch(setNegativePrompt(negativePrompt));
|
||||
}
|
||||
|
||||
if (isValidSDXLPositiveStylePrompt(positiveStylePrompt)) {
|
||||
dispatch(setPositiveStylePromptSDXL(positiveStylePrompt));
|
||||
}
|
||||
|
||||
if (isValidSDXLPositiveStylePrompt(negativeStylePrompt)) {
|
||||
dispatch(setNegativeStylePromptSDXL(negativeStylePrompt));
|
||||
}
|
||||
|
||||
parameterSetToast();
|
||||
return;
|
||||
}
|
||||
@ -123,6 +154,36 @@ export const useRecallParameters = () => {
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall SDXL Positive Style Prompt with toast
|
||||
*/
|
||||
const recallSDXLPositiveStylePrompt = useCallback(
|
||||
(positiveStylePrompt: unknown) => {
|
||||
if (!isValidSDXLPositiveStylePrompt(positiveStylePrompt)) {
|
||||
parameterNotSetToast();
|
||||
return;
|
||||
}
|
||||
dispatch(setPositiveStylePromptSDXL(positiveStylePrompt));
|
||||
parameterSetToast();
|
||||
},
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall SDXL Negative Style Prompt with toast
|
||||
*/
|
||||
const recallSDXLNegativeStylePrompt = useCallback(
|
||||
(negativeStylePrompt: unknown) => {
|
||||
if (!isValidSDXLNegativeStylePrompt(negativeStylePrompt)) {
|
||||
parameterNotSetToast();
|
||||
return;
|
||||
}
|
||||
dispatch(setNegativeStylePromptSDXL(negativeStylePrompt));
|
||||
parameterSetToast();
|
||||
},
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall seed with toast
|
||||
*/
|
||||
@ -271,6 +332,14 @@ export const useRecallParameters = () => {
|
||||
steps,
|
||||
width,
|
||||
strength,
|
||||
positive_style_prompt,
|
||||
negative_style_prompt,
|
||||
refiner_model,
|
||||
refiner_cfg_scale,
|
||||
refiner_steps,
|
||||
refiner_scheduler,
|
||||
refiner_aesthetic_store,
|
||||
refiner_start,
|
||||
} = metadata;
|
||||
|
||||
if (isValidCfgScale(cfg_scale)) {
|
||||
@ -304,6 +373,38 @@ export const useRecallParameters = () => {
|
||||
dispatch(setImg2imgStrength(strength));
|
||||
}
|
||||
|
||||
if (isValidSDXLPositiveStylePrompt(positive_style_prompt)) {
|
||||
dispatch(setPositiveStylePromptSDXL(positive_style_prompt));
|
||||
}
|
||||
|
||||
if (isValidSDXLNegativeStylePrompt(negative_style_prompt)) {
|
||||
dispatch(setNegativeStylePromptSDXL(negative_style_prompt));
|
||||
}
|
||||
|
||||
if (isValidMainModel(refiner_model)) {
|
||||
dispatch(refinerModelChanged(refiner_model));
|
||||
}
|
||||
|
||||
if (isValidSteps(refiner_steps)) {
|
||||
dispatch(setRefinerSteps(refiner_steps));
|
||||
}
|
||||
|
||||
if (isValidCfgScale(refiner_cfg_scale)) {
|
||||
dispatch(setRefinerCFGScale(refiner_cfg_scale));
|
||||
}
|
||||
|
||||
if (isValidScheduler(refiner_scheduler)) {
|
||||
dispatch(setRefinerScheduler(refiner_scheduler));
|
||||
}
|
||||
|
||||
if (isValidSDXLRefinerAestheticScore(refiner_aesthetic_store)) {
|
||||
dispatch(setRefinerAestheticScore(refiner_aesthetic_store));
|
||||
}
|
||||
|
||||
if (isValidSDXLRefinerStart(refiner_start)) {
|
||||
dispatch(setRefinerStart(refiner_start));
|
||||
}
|
||||
|
||||
allParameterSetToast();
|
||||
},
|
||||
[allParameterNotSetToast, allParameterSetToast, dispatch]
|
||||
@ -313,6 +414,8 @@ export const useRecallParameters = () => {
|
||||
recallBothPrompts,
|
||||
recallPositivePrompt,
|
||||
recallNegativePrompt,
|
||||
recallSDXLPositiveStylePrompt,
|
||||
recallSDXLNegativeStylePrompt,
|
||||
recallSeed,
|
||||
recallCfgScale,
|
||||
recallModel,
|
||||
|
@ -310,6 +310,39 @@ export type PrecisionParam = z.infer<typeof zPrecision>;
|
||||
export const isValidPrecision = (val: unknown): val is PrecisionParam =>
|
||||
zPrecision.safeParse(val).success;
|
||||
|
||||
/**
|
||||
* Zod schema for SDXL refiner aesthetic score parameter
|
||||
*/
|
||||
export const zSDXLRefinerAestheticScore = z.number().min(1).max(10);
|
||||
/**
|
||||
* Type alias for SDXL refiner aesthetic score parameter, inferred from its zod schema
|
||||
*/
|
||||
export type SDXLRefinerAestheticScoreParam = z.infer<
|
||||
typeof zSDXLRefinerAestheticScore
|
||||
>;
|
||||
/**
|
||||
* Validates/type-guards a value as a SDXL refiner aesthetic score parameter
|
||||
*/
|
||||
export const isValidSDXLRefinerAestheticScore = (
|
||||
val: unknown
|
||||
): val is SDXLRefinerAestheticScoreParam =>
|
||||
zSDXLRefinerAestheticScore.safeParse(val).success;
|
||||
|
||||
/**
|
||||
* Zod schema for SDXL start parameter
|
||||
*/
|
||||
export const zSDXLRefinerstart = z.number().min(0).max(1);
|
||||
/**
|
||||
* Type alias for SDXL start, inferred from its zod schema
|
||||
*/
|
||||
export type SDXLRefinerStartParam = z.infer<typeof zSDXLRefinerstart>;
|
||||
/**
|
||||
* Validates/type-guards a value as a SDXL refiner aesthetic score parameter
|
||||
*/
|
||||
export const isValidSDXLRefinerStart = (
|
||||
val: unknown
|
||||
): val is SDXLRefinerStartParam => zSDXLRefinerstart.safeParse(val).success;
|
||||
|
||||
// /**
|
||||
// * Zod schema for BaseModelType
|
||||
// */
|
||||
|
@ -21,8 +21,8 @@ export default function ParamSDXLConcatButton() {
|
||||
|
||||
return (
|
||||
<IAIIconButton
|
||||
aria-label="Concat"
|
||||
tooltip="Concatenates Basic Prompt with Style (Recommended)"
|
||||
aria-label="Concatenate Prompt & Style"
|
||||
tooltip="Concatenate Prompt & Style"
|
||||
variant="outline"
|
||||
isChecked={shouldConcatSDXLStylePrompt}
|
||||
onClick={handleShouldConcatPromptChange}
|
||||
|
@ -1,281 +1,283 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "ycYWcsEKc6w7"
|
||||
},
|
||||
"source": [
|
||||
"# Stable Diffusion AI Notebook (Release 2.0.0)\n",
|
||||
"\n",
|
||||
"<img src=\"https://user-images.githubusercontent.com/60411196/186547976-d9de378a-9de8-4201-9c25-c057a9c59bad.jpeg\" alt=\"stable-diffusion-ai\" width=\"170px\"/> <br>\n",
|
||||
"#### Instructions:\n",
|
||||
"1. Execute each cell in order to mount a Dream bot and create images from text. <br>\n",
|
||||
"2. Once cells 1-8 were run correctly you'll be executing a terminal in cell #9, you'll need to enter `python scripts/dream.py` command to run Dream bot.<br> \n",
|
||||
"3. After launching dream bot, you'll see: <br> `Dream > ` in terminal. <br> Insert a command, eg. `Dream > Astronaut floating in a distant galaxy`, or type `-h` for help.\n",
|
||||
"3. After completion you'll see your generated images in path `stable-diffusion/outputs/img-samples/`, you can also show last generated images in cell #10.\n",
|
||||
"4. To quit Dream bot use `q` command. <br> \n",
|
||||
"---\n",
|
||||
"<font color=\"red\">Note:</font> It takes some time to load, but after installing all dependencies you can use the bot all time you want while colab instance is up. <br>\n",
|
||||
"<font color=\"red\">Requirements:</font> For this notebook to work you need to have [Stable-Diffusion-v-1-4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original) stored in your Google Drive, it will be needed in cell #7\n",
|
||||
"##### For more details visit Github repository: [invoke-ai/InvokeAI](https://github.com/invoke-ai/InvokeAI)\n",
|
||||
"---\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "dr32VLxlnouf"
|
||||
},
|
||||
"source": [
|
||||
"## ◢ Installation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "a2Z5Qu_o8VtQ"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title 1. Check current GPU assigned\n",
|
||||
"!nvidia-smi -L\n",
|
||||
"!nvidia-smi"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "vbI9ZsQHzjqF"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title 2. Download stable-diffusion Repository\n",
|
||||
"from os.path import exists\n",
|
||||
"\n",
|
||||
"!git clone --quiet https://github.com/invoke-ai/InvokeAI.git # Original repo\n",
|
||||
"%cd /content/InvokeAI/\n",
|
||||
"!git checkout --quiet tags/v2.0.0"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "QbXcGXYEFSNB"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title 3. Install dependencies\n",
|
||||
"import gc\n",
|
||||
"\n",
|
||||
"!wget https://raw.githubusercontent.com/invoke-ai/InvokeAI/development/environments-and-requirements/requirements-base.txt\n",
|
||||
"!wget https://raw.githubusercontent.com/invoke-ai/InvokeAI/development/environments-and-requirements/requirements-win-colab-cuda.txt\n",
|
||||
"!pip install colab-xterm\n",
|
||||
"!pip install -r requirements-lin-win-colab-CUDA.txt\n",
|
||||
"!pip install clean-fid torchtext\n",
|
||||
"!pip install transformers\n",
|
||||
"gc.collect()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "8rSMhgnAttQa"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title 4. Restart Runtime\n",
|
||||
"exit()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "ChIDWxLVHGGJ"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title 5. Load small ML models required\n",
|
||||
"import gc\n",
|
||||
"%cd /content/InvokeAI/\n",
|
||||
"!python scripts/preload_models.py\n",
|
||||
"gc.collect()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "795x1tMoo8b1"
|
||||
},
|
||||
"source": [
|
||||
"## ◢ Configuration"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "YEWPV-sF1RDM"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title 6. Mount google Drive\n",
|
||||
"from google.colab import drive\n",
|
||||
"drive.mount('/content/drive')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "zRTJeZ461WGu"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title 7. Drive Path to model\n",
|
||||
"#@markdown Path should start with /content/drive/path-to-your-file <br>\n",
|
||||
"#@markdown <font color=\"red\">Note:</font> Model should be downloaded from https://huggingface.co <br>\n",
|
||||
"#@markdown Lastest release: [Stable-Diffusion-v-1-4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original)\n",
|
||||
"from os.path import exists\n",
|
||||
"\n",
|
||||
"model_path = \"\" #@param {type:\"string\"}\n",
|
||||
"if exists(model_path):\n",
|
||||
" print(\"✅ Valid directory\")\n",
|
||||
"else: \n",
|
||||
" print(\"❌ File doesn't exist\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "UY-NNz4I8_aG"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title 8. Symlink to model\n",
|
||||
"\n",
|
||||
"from os.path import exists\n",
|
||||
"import os \n",
|
||||
"\n",
|
||||
"# Folder creation if it doesn't exist\n",
|
||||
"if exists(\"/content/InvokeAI/models/ldm/stable-diffusion-v1\"):\n",
|
||||
" print(\"❗ Dir stable-diffusion-v1 already exists\")\n",
|
||||
"else:\n",
|
||||
" %mkdir /content/InvokeAI/models/ldm/stable-diffusion-v1\n",
|
||||
" print(\"✅ Dir stable-diffusion-v1 created\")\n",
|
||||
"\n",
|
||||
"# Symbolic link if it doesn't exist\n",
|
||||
"if exists(\"/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt\"):\n",
|
||||
" print(\"❗ Symlink already created\")\n",
|
||||
"else: \n",
|
||||
" src = model_path\n",
|
||||
" dst = '/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt'\n",
|
||||
" os.symlink(src, dst) \n",
|
||||
" print(\"✅ Symbolic link created successfully\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "Mc28N0_NrCQH"
|
||||
},
|
||||
"source": [
|
||||
"## ◢ Execution"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "ir4hCrMIuUpl"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title 9. Run Terminal and Execute Dream bot\n",
|
||||
"#@markdown <font color=\"blue\">Steps:</font> <br>\n",
|
||||
"#@markdown 1. Execute command `python scripts/invoke.py` to run InvokeAI.<br>\n",
|
||||
"#@markdown 2. After initialized you'll see `Dream>` line.<br>\n",
|
||||
"#@markdown 3. Example text: `Astronaut floating in a distant galaxy` <br>\n",
|
||||
"#@markdown 4. To quit Dream bot use: `q` command.<br>\n",
|
||||
"\n",
|
||||
"%load_ext colabxterm\n",
|
||||
"%xterm\n",
|
||||
"gc.collect()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "qnLohSHmKoGk"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title 10. Show the last 15 generated images\n",
|
||||
"import glob\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import matplotlib.image as mpimg\n",
|
||||
"%matplotlib inline\n",
|
||||
"\n",
|
||||
"images = []\n",
|
||||
"for img_path in sorted(glob.glob('/content/InvokeAI/outputs/img-samples/*.png'), reverse=True):\n",
|
||||
" images.append(mpimg.imread(img_path))\n",
|
||||
"\n",
|
||||
"images = images[:15] \n",
|
||||
"\n",
|
||||
"plt.figure(figsize=(20,10))\n",
|
||||
"\n",
|
||||
"columns = 5\n",
|
||||
"for i, image in enumerate(images):\n",
|
||||
" ax = plt.subplot(len(images) / columns + 1, columns, i + 1)\n",
|
||||
" ax.axes.xaxis.set_visible(False)\n",
|
||||
" ax.axes.yaxis.set_visible(False)\n",
|
||||
" ax.axis('off')\n",
|
||||
" plt.imshow(image)\n",
|
||||
" gc.collect()\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"collapsed_sections": [],
|
||||
"private_outputs": true,
|
||||
"provenance": []
|
||||
},
|
||||
"gpuClass": "standard",
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.9.12 64-bit",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"version": "3.9.12"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "4e870c5c5fe42db7e2c5647ae5af656ff3391bf8c2b729cbf7fa0e16ca8cb5af"
|
||||
}
|
||||
}
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "ycYWcsEKc6w7"
|
||||
},
|
||||
"source": [
|
||||
"# Stable Diffusion AI Notebook (Release 2.0.0)\n",
|
||||
"\n",
|
||||
"<img src=\"https://user-images.githubusercontent.com/60411196/186547976-d9de378a-9de8-4201-9c25-c057a9c59bad.jpeg\" alt=\"stable-diffusion-ai\" width=\"170px\"/> <br>\n",
|
||||
"#### Instructions:\n",
|
||||
"1. Execute each cell in order to mount a Dream bot and create images from text. <br>\n",
|
||||
"2. Once cells 1-8 were run correctly you'll be executing a terminal in cell #9, you'll need to enter `python scripts/dream.py` command to run Dream bot.<br> \n",
|
||||
"3. After launching dream bot, you'll see: <br> `Dream > ` in terminal. <br> Insert a command, eg. `Dream > Astronaut floating in a distant galaxy`, or type `-h` for help.\n",
|
||||
"3. After completion you'll see your generated images in path `stable-diffusion/outputs/img-samples/`, you can also show last generated images in cell #10.\n",
|
||||
"4. To quit Dream bot use `q` command. <br> \n",
|
||||
"---\n",
|
||||
"<font color=\"red\">Note:</font> It takes some time to load, but after installing all dependencies you can use the bot all time you want while colab instance is up. <br>\n",
|
||||
"<font color=\"red\">Requirements:</font> For this notebook to work you need to have [Stable-Diffusion-v-1-4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original) stored in your Google Drive, it will be needed in cell #7\n",
|
||||
"##### For more details visit Github repository: [invoke-ai/InvokeAI](https://github.com/invoke-ai/InvokeAI)\n",
|
||||
"---\n"
|
||||
]
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "dr32VLxlnouf"
|
||||
},
|
||||
"source": [
|
||||
"## ◢ Installation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "a2Z5Qu_o8VtQ"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# @title 1. Check current GPU assigned\n",
|
||||
"!nvidia-smi -L\n",
|
||||
"!nvidia-smi"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "vbI9ZsQHzjqF"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# @title 2. Download stable-diffusion Repository\n",
|
||||
"from os.path import exists\n",
|
||||
"\n",
|
||||
"!git clone --quiet https://github.com/invoke-ai/InvokeAI.git # Original repo\n",
|
||||
"%cd /content/InvokeAI/\n",
|
||||
"!git checkout --quiet tags/v2.0.0"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "QbXcGXYEFSNB"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# @title 3. Install dependencies\n",
|
||||
"import gc\n",
|
||||
"\n",
|
||||
"!wget https://raw.githubusercontent.com/invoke-ai/InvokeAI/development/environments-and-requirements/requirements-base.txt\n",
|
||||
"!wget https://raw.githubusercontent.com/invoke-ai/InvokeAI/development/environments-and-requirements/requirements-win-colab-cuda.txt\n",
|
||||
"!pip install colab-xterm\n",
|
||||
"!pip install -r requirements-lin-win-colab-CUDA.txt\n",
|
||||
"!pip install clean-fid torchtext\n",
|
||||
"!pip install transformers\n",
|
||||
"gc.collect()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "8rSMhgnAttQa"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# @title 4. Restart Runtime\n",
|
||||
"exit()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "ChIDWxLVHGGJ"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# @title 5. Load small ML models required\n",
|
||||
"import gc\n",
|
||||
"\n",
|
||||
"%cd /content/InvokeAI/\n",
|
||||
"!python scripts/preload_models.py\n",
|
||||
"gc.collect()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "795x1tMoo8b1"
|
||||
},
|
||||
"source": [
|
||||
"## ◢ Configuration"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "YEWPV-sF1RDM"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# @title 6. Mount google Drive\n",
|
||||
"from google.colab import drive\n",
|
||||
"\n",
|
||||
"drive.mount(\"/content/drive\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "zRTJeZ461WGu"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# @title 7. Drive Path to model\n",
|
||||
"# @markdown Path should start with /content/drive/path-to-your-file <br>\n",
|
||||
"# @markdown <font color=\"red\">Note:</font> Model should be downloaded from https://huggingface.co <br>\n",
|
||||
"# @markdown Lastest release: [Stable-Diffusion-v-1-4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original)\n",
|
||||
"from os.path import exists\n",
|
||||
"\n",
|
||||
"model_path = \"\" # @param {type:\"string\"}\n",
|
||||
"if exists(model_path):\n",
|
||||
" print(\"✅ Valid directory\")\n",
|
||||
"else:\n",
|
||||
" print(\"❌ File doesn't exist\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "UY-NNz4I8_aG"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# @title 8. Symlink to model\n",
|
||||
"\n",
|
||||
"from os.path import exists\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"# Folder creation if it doesn't exist\n",
|
||||
"if exists(\"/content/InvokeAI/models/ldm/stable-diffusion-v1\"):\n",
|
||||
" print(\"❗ Dir stable-diffusion-v1 already exists\")\n",
|
||||
"else:\n",
|
||||
" %mkdir /content/InvokeAI/models/ldm/stable-diffusion-v1\n",
|
||||
" print(\"✅ Dir stable-diffusion-v1 created\")\n",
|
||||
"\n",
|
||||
"# Symbolic link if it doesn't exist\n",
|
||||
"if exists(\"/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt\"):\n",
|
||||
" print(\"❗ Symlink already created\")\n",
|
||||
"else:\n",
|
||||
" src = model_path\n",
|
||||
" dst = \"/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt\"\n",
|
||||
" os.symlink(src, dst)\n",
|
||||
" print(\"✅ Symbolic link created successfully\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "Mc28N0_NrCQH"
|
||||
},
|
||||
"source": [
|
||||
"## ◢ Execution"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "ir4hCrMIuUpl"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# @title 9. Run Terminal and Execute Dream bot\n",
|
||||
"# @markdown <font color=\"blue\">Steps:</font> <br>\n",
|
||||
"# @markdown 1. Execute command `python scripts/invoke.py` to run InvokeAI.<br>\n",
|
||||
"# @markdown 2. After initialized you'll see `Dream>` line.<br>\n",
|
||||
"# @markdown 3. Example text: `Astronaut floating in a distant galaxy` <br>\n",
|
||||
"# @markdown 4. To quit Dream bot use: `q` command.<br>\n",
|
||||
"\n",
|
||||
"%load_ext colabxterm\n",
|
||||
"%xterm\n",
|
||||
"gc.collect()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "qnLohSHmKoGk"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title 10. Show the last 15 generated images\n",
|
||||
"import glob\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import matplotlib.image as mpimg\n",
|
||||
"%matplotlib inline\n",
|
||||
"\n",
|
||||
"images = []\n",
|
||||
"for img_path in sorted(glob.glob('/content/InvokeAI/outputs/img-samples/*.png'), reverse=True):\n",
|
||||
" images.append(mpimg.imread(img_path))\n",
|
||||
"\n",
|
||||
"images = images[:15] \n",
|
||||
"\n",
|
||||
"plt.figure(figsize=(20,10))\n",
|
||||
"\n",
|
||||
"columns = 5\n",
|
||||
"for i, image in enumerate(images):\n",
|
||||
" ax = plt.subplot(len(images) / columns + 1, columns, i + 1)\n",
|
||||
" ax.axes.xaxis.set_visible(False)\n",
|
||||
" ax.axes.yaxis.set_visible(False)\n",
|
||||
" ax.axis('off')\n",
|
||||
" plt.imshow(image)\n",
|
||||
" gc.collect()\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"collapsed_sections": [],
|
||||
"private_outputs": true,
|
||||
"provenance": []
|
||||
},
|
||||
"gpuClass": "standard",
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.9.12 64-bit",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"version": "3.9.12"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "4e870c5c5fe42db7e2c5647ae5af656ff3391bf8c2b729cbf7fa0e16ca8cb5af"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
|
@ -52,17 +52,17 @@
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Cloning into 'latent-diffusion'...\n",
|
||||
"remote: Enumerating objects: 992, done.\u001B[K\n",
|
||||
"remote: Counting objects: 100% (695/695), done.\u001B[K\n",
|
||||
"remote: Compressing objects: 100% (397/397), done.\u001B[K\n",
|
||||
"remote: Total 992 (delta 375), reused 564 (delta 253), pack-reused 297\u001B[K\n",
|
||||
"remote: Enumerating objects: 992, done.\u001b[K\n",
|
||||
"remote: Counting objects: 100% (695/695), done.\u001b[K\n",
|
||||
"remote: Compressing objects: 100% (397/397), done.\u001b[K\n",
|
||||
"remote: Total 992 (delta 375), reused 564 (delta 253), pack-reused 297\u001b[K\n",
|
||||
"Receiving objects: 100% (992/992), 30.78 MiB | 29.43 MiB/s, done.\n",
|
||||
"Resolving deltas: 100% (510/510), done.\n",
|
||||
"Cloning into 'taming-transformers'...\n",
|
||||
"remote: Enumerating objects: 1335, done.\u001B[K\n",
|
||||
"remote: Counting objects: 100% (525/525), done.\u001B[K\n",
|
||||
"remote: Compressing objects: 100% (493/493), done.\u001B[K\n",
|
||||
"remote: Total 1335 (delta 58), reused 481 (delta 30), pack-reused 810\u001B[K\n",
|
||||
"remote: Enumerating objects: 1335, done.\u001b[K\n",
|
||||
"remote: Counting objects: 100% (525/525), done.\u001b[K\n",
|
||||
"remote: Compressing objects: 100% (493/493), done.\u001b[K\n",
|
||||
"remote: Total 1335 (delta 58), reused 481 (delta 30), pack-reused 810\u001b[K\n",
|
||||
"Receiving objects: 100% (1335/1335), 412.35 MiB | 30.53 MiB/s, done.\n",
|
||||
"Resolving deltas: 100% (267/267), done.\n",
|
||||
"Obtaining file:///content/taming-transformers\n",
|
||||
@ -73,23 +73,24 @@
|
||||
"Installing collected packages: taming-transformers\n",
|
||||
" Running setup.py develop for taming-transformers\n",
|
||||
"Successfully installed taming-transformers-0.0.1\n",
|
||||
"\u001B[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
|
||||
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
|
||||
"tensorflow 2.8.0 requires tf-estimator-nightly==2.8.0.dev2021122109, which is not installed.\n",
|
||||
"arviz 0.11.4 requires typing-extensions<4,>=3.7.4.3, but you have typing-extensions 4.1.1 which is incompatible.\u001B[0m\n"
|
||||
"arviz 0.11.4 requires typing-extensions<4,>=3.7.4.3, but you have typing-extensions 4.1.1 which is incompatible.\u001b[0m\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"#@title Installation\n",
|
||||
"# @title Installation\n",
|
||||
"!git clone https://github.com/CompVis/latent-diffusion.git\n",
|
||||
"!git clone https://github.com/CompVis/taming-transformers\n",
|
||||
"!pip install -e ./taming-transformers\n",
|
||||
"!pip install omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops\n",
|
||||
"\n",
|
||||
"import sys\n",
|
||||
"\n",
|
||||
"sys.path.append(\".\")\n",
|
||||
"sys.path.append('./taming-transformers')\n",
|
||||
"from taming.models import vqgan "
|
||||
"sys.path.append(\"./taming-transformers\")\n",
|
||||
"from taming.models import vqgan"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -104,11 +105,11 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"#@title Download\n",
|
||||
"%cd latent-diffusion/ \n",
|
||||
"# @title Download\n",
|
||||
"%cd latent-diffusion/\n",
|
||||
"\n",
|
||||
"!mkdir -p models/ldm/cin256-v2/\n",
|
||||
"!wget -O models/ldm/cin256-v2/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/cin/model.ckpt "
|
||||
"!wget -O models/ldm/cin256-v2/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/cin/model.ckpt"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@ -203,7 +204,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"#@title loading utils\n",
|
||||
"# @title loading utils\n",
|
||||
"import torch\n",
|
||||
"from omegaconf import OmegaConf\n",
|
||||
"\n",
|
||||
@ -212,7 +213,7 @@
|
||||
"\n",
|
||||
"def load_model_from_config(config, ckpt):\n",
|
||||
" print(f\"Loading model from {ckpt}\")\n",
|
||||
" pl_sd = torch.load(ckpt)#, map_location=\"cpu\")\n",
|
||||
" pl_sd = torch.load(ckpt) # , map_location=\"cpu\")\n",
|
||||
" sd = pl_sd[\"state_dict\"]\n",
|
||||
" model = instantiate_from_config(config.model)\n",
|
||||
" m, u = model.load_state_dict(sd, strict=False)\n",
|
||||
@ -222,7 +223,7 @@
|
||||
"\n",
|
||||
"\n",
|
||||
"def get_model():\n",
|
||||
" config = OmegaConf.load(\"configs/latent-diffusion/cin256-v2.yaml\") \n",
|
||||
" config = OmegaConf.load(\"configs/latent-diffusion/cin256-v2.yaml\")\n",
|
||||
" model = load_model_from_config(config, \"models/ldm/cin256-v2/model.ckpt\")\n",
|
||||
" return model"
|
||||
],
|
||||
@ -276,18 +277,18 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import numpy as np \n",
|
||||
"import numpy as np\n",
|
||||
"from PIL import Image\n",
|
||||
"from einops import rearrange\n",
|
||||
"from torchvision.utils import make_grid\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"classes = [25, 187, 448, 992] # define classes to be sampled here\n",
|
||||
"classes = [25, 187, 448, 992] # define classes to be sampled here\n",
|
||||
"n_samples_per_class = 6\n",
|
||||
"\n",
|
||||
"ddim_steps = 20\n",
|
||||
"ddim_eta = 0.0\n",
|
||||
"scale = 3.0 # for unconditional guidance\n",
|
||||
"scale = 3.0 # for unconditional guidance\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"all_samples = list()\n",
|
||||
@ -295,36 +296,39 @@
|
||||
"with torch.no_grad():\n",
|
||||
" with model.ema_scope():\n",
|
||||
" uc = model.get_learned_conditioning(\n",
|
||||
" {model.cond_stage_key: torch.tensor(n_samples_per_class*[1000]).to(model.device)}\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" {model.cond_stage_key: torch.tensor(n_samples_per_class * [1000]).to(model.device)}\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" for class_label in classes:\n",
|
||||
" print(f\"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.\")\n",
|
||||
" xc = torch.tensor(n_samples_per_class*[class_label])\n",
|
||||
" print(\n",
|
||||
" f\"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.\"\n",
|
||||
" )\n",
|
||||
" xc = torch.tensor(n_samples_per_class * [class_label])\n",
|
||||
" c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})\n",
|
||||
" \n",
|
||||
" samples_ddim, _ = sampler.sample(S=ddim_steps,\n",
|
||||
" conditioning=c,\n",
|
||||
" batch_size=n_samples_per_class,\n",
|
||||
" shape=[3, 64, 64],\n",
|
||||
" verbose=False,\n",
|
||||
" unconditional_guidance_scale=scale,\n",
|
||||
" unconditional_conditioning=uc, \n",
|
||||
" eta=ddim_eta)\n",
|
||||
"\n",
|
||||
" samples_ddim, _ = sampler.sample(\n",
|
||||
" S=ddim_steps,\n",
|
||||
" conditioning=c,\n",
|
||||
" batch_size=n_samples_per_class,\n",
|
||||
" shape=[3, 64, 64],\n",
|
||||
" verbose=False,\n",
|
||||
" unconditional_guidance_scale=scale,\n",
|
||||
" unconditional_conditioning=uc,\n",
|
||||
" eta=ddim_eta,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" x_samples_ddim = model.decode_first_stage(samples_ddim)\n",
|
||||
" x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, \n",
|
||||
" min=0.0, max=1.0)\n",
|
||||
" x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)\n",
|
||||
" all_samples.append(x_samples_ddim)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# display as grid\n",
|
||||
"grid = torch.stack(all_samples, 0)\n",
|
||||
"grid = rearrange(grid, 'n b c h w -> (n b) c h w')\n",
|
||||
"grid = rearrange(grid, \"n b c h w -> (n b) c h w\")\n",
|
||||
"grid = make_grid(grid, nrow=n_samples_per_class)\n",
|
||||
"\n",
|
||||
"# to image\n",
|
||||
"grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()\n",
|
||||
"grid = 255.0 * rearrange(grid, \"c h w -> h w c\").cpu().numpy()\n",
|
||||
"Image.fromarray(grid.astype(np.uint8))"
|
||||
],
|
||||
"metadata": {
|
||||
|
Loading…
Reference in New Issue
Block a user