Merge branch 'main' into feat/onnx

This commit is contained in:
Brandon Rising 2023-07-31 10:58:40 -04:00
commit f5ac73b091
37 changed files with 1019 additions and 666 deletions

View File

@ -394,7 +394,7 @@ rm .\.venv -r -force
python -mvenv .venv python -mvenv .venv
.\.venv\Scripts\activate .\.venv\Scripts\activate
pip install invokeai pip install invokeai
invokeai-configure --root . invokeai-configure --yes --root .
``` ```
If you see anything marked as an error during this process please stop If you see anything marked as an error during this process please stop

View File

@ -14,20 +14,25 @@ The nodes linked below have been developed and contributed by members of the Inv
## List of Nodes ## List of Nodes
### Face Mask ### FaceTools
**Description:** This node autodetects a face in the image using MediaPipe and masks it by making it transparent. Via outpainting you can swap faces with other faces, or invert the mask and swap things around the face with other things. Additionally, you can supply X and Y offset values to scale/change the shape of the mask for finer control. The node also outputs an all-white mask in the same dimensions as the input image. This is needed by the inpaint node (and unified canvas) for outpainting. **Description:** FaceTools is a collection of nodes created to manipulate faces as you would in Unified Canvas. It includes FaceMask, FaceOff, and FacePlace. FaceMask autodetects a face in the image using MediaPipe and creates a mask from it. FaceOff similarly detects a face, then takes the face off of the image by adding a square bounding box around it and cropping/scaling it. FacePlace puts the bounded face image from FaceOff back onto the original image. Using these nodes with other inpainting node(s), you can put new faces on existing things, put new things around existing faces, and work closer with a face as a bounded image. Additionally, you can supply X and Y offset values to scale/change the shape of the mask for finer control on FaceMask and FaceOff. See GitHub repository below for usage examples.
**Node Link:** https://github.com/ymgenesis/InvokeAI/blob/facemaskmediapipe/invokeai/app/invocations/facemask.py **Node Link:** https://github.com/ymgenesis/FaceTools/
**Example Node Graph:** https://www.mediafire.com/file/gohn5sb1bfp8use/21-July_2023-FaceMask.json/file **FaceMask Output Examples**
**Output Examples** ![5cc8abce-53b0-487a-b891-3bf94dcc8960](https://github.com/invoke-ai/InvokeAI/assets/25252829/43f36d24-1429-4ab1-bd06-a4bedfe0955e)
![b920b710-1882-49a0-8d02-82dff2cca907](https://github.com/invoke-ai/InvokeAI/assets/25252829/7660c1ed-bf7d-4d0a-947f-1fc1679557ba)
![71a91805-fda5-481c-b380-264665703133](https://github.com/invoke-ai/InvokeAI/assets/25252829/f8f6a2ee-2b68-4482-87da-b90221d5c3e2)
![2e3168cb-af6a-475d-bfac-c7b7fd67b4c2](https://github.com/ymgenesis/InvokeAI/assets/25252829/a5ad7d44-2ada-4b3c-a56e-a21f8244a1ac) <hr>
![2_annotated](https://github.com/ymgenesis/InvokeAI/assets/25252829/53416c8a-a23b-4d76-bb6d-3cfd776e0096)
![2fe2150c-fd08-4e26-8c36-f0610bf441bb](https://github.com/ymgenesis/InvokeAI/assets/25252829/b0f7ecfe-f093-4147-a904-b9f131b41dc9) ### Ideal Size
![831b6b98-4f0f-4360-93c8-69a9c1338cbe](https://github.com/ymgenesis/InvokeAI/assets/25252829/fc7b0622-e361-4155-8a76-082894d084f0)
**Description:** This node calculates an ideal image size for a first pass of a multi-pass upscaling. The aim is to avoid duplication that results from choosing a size larger than the model is capable of.
**Node Link:** https://github.com/JPPhoto/ideal-size-node
-------------------------------- --------------------------------
### Super Cool Node Template ### Super Cool Node Template
@ -42,11 +47,5 @@ The nodes linked below have been developed and contributed by members of the Inv
![Invoke AI](https://invoke-ai.github.io/InvokeAI/assets/invoke_ai_banner.png) ![Invoke AI](https://invoke-ai.github.io/InvokeAI/assets/invoke_ai_banner.png)
### Ideal Size
**Description:** This node calculates an ideal image size for a first pass of a multi-pass upscaling. The aim is to avoid duplication that results from choosing a size larger than the model is capable of.
**Node Link:** https://github.com/JPPhoto/ideal-size-node
## Help ## Help
If you run into any issues with a node, please post in the [InvokeAI Discord](https://discord.gg/ZmtBAhwWhy). If you run into any issues with a node, please post in the [InvokeAI Discord](https://discord.gg/ZmtBAhwWhy).

25
flake.lock Normal file
View File

@ -0,0 +1,25 @@
{
"nodes": {
"nixpkgs": {
"locked": {
"lastModified": 1690630721,
"narHash": "sha256-Y04onHyBQT4Erfr2fc82dbJTfXGYrf4V0ysLUYnPOP8=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "d2b52322f35597c62abf56de91b0236746b2a03d",
"type": "github"
},
"original": {
"id": "nixpkgs",
"type": "indirect"
}
},
"root": {
"inputs": {
"nixpkgs": "nixpkgs"
}
}
},
"root": "root",
"version": 7
}

81
flake.nix Normal file
View File

@ -0,0 +1,81 @@
# Important note: this flake does not attempt to create a fully isolated, 'pure'
# Python environment for InvokeAI. Instead, it depends on local invocations of
# virtualenv/pip to install the required (binary) packages, most importantly the
# prebuilt binary pytorch packages with CUDA support.
# ML Python packages with CUDA support, like pytorch, are notoriously expensive
# to compile so it's purposefuly not what this flake does.
{
description = "An (impure) flake to develop on InvokeAI.";
outputs = { self, nixpkgs }:
let
system = "x86_64-linux";
pkgs = import nixpkgs {
inherit system;
config.allowUnfree = true;
};
python = pkgs.python310;
mkShell = { dir, install }:
let
setupScript = pkgs.writeScript "setup-invokai" ''
# This must be sourced using 'source', not executed.
${python}/bin/python -m venv ${dir}
${dir}/bin/python -m pip install ${install}
# ${dir}/bin/python -c 'import torch; assert(torch.cuda.is_available())'
source ${dir}/bin/activate
'';
in
pkgs.mkShell rec {
buildInputs = with pkgs; [
# Backend: graphics, CUDA.
cudaPackages.cudnn
cudaPackages.cuda_nvrtc
cudatoolkit
freeglut
glib
gperf
procps
libGL
libGLU
linuxPackages.nvidia_x11
python
stdenv.cc
stdenv.cc.cc.lib
xorg.libX11
xorg.libXext
xorg.libXi
xorg.libXmu
xorg.libXrandr
xorg.libXv
zlib
# Pre-commit hooks.
black
# Frontend.
yarn
nodejs
];
LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath buildInputs;
CUDA_PATH = pkgs.cudatoolkit;
EXTRA_LDFLAGS = "-L${pkgs.linuxPackages.nvidia_x11}/lib";
shellHook = ''
if [[ -f "${dir}/bin/activate" ]]; then
source "${dir}/bin/activate"
echo "Using Python: $(which python)"
else
echo "Use 'source ${setupScript}' to set up the environment."
fi
'';
};
in
{
devShells.${system} = rec {
develop = mkShell { dir = "venv"; install = "-e '.[xformers]' --extra-index-url https://download.pytorch.org/whl/cu118"; };
default = develop;
};
};
}

View File

@ -13,7 +13,7 @@ from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Union from typing import Union
SUPPORTED_PYTHON = ">=3.9.0,<3.11" SUPPORTED_PYTHON = ">=3.9.0,<=3.11.100"
INSTALLER_REQS = ["rich", "semver", "requests", "plumbum", "prompt-toolkit"] INSTALLER_REQS = ["rich", "semver", "requests", "plumbum", "prompt-toolkit"]
BOOTSTRAP_VENV_PREFIX = "invokeai-installer-tmp" BOOTSTRAP_VENV_PREFIX = "invokeai-installer-tmp"
@ -149,7 +149,7 @@ class Installer:
return venv_dir return venv_dir
def install( def install(
self, root: str = "~/invokeai-3", version: str = "latest", yes_to_all=False, find_links: Path = None self, root: str = "~/invokeai", version: str = "latest", yes_to_all=False, find_links: Path = None
) -> None: ) -> None:
""" """
Install the InvokeAI application into the given runtime path Install the InvokeAI application into the given runtime path
@ -168,7 +168,8 @@ class Installer:
messages.welcome() messages.welcome()
self.dest = Path(root).expanduser().resolve() if yes_to_all else messages.dest_path(root) default_path = os.environ.get("INVOKEAI_ROOT") or Path(root).expanduser().resolve()
self.dest = default_path if yes_to_all else messages.dest_path(root)
# create the venv for the app # create the venv for the app
self.venv = self.app_venv() self.venv = self.app_venv()
@ -248,6 +249,9 @@ class InvokeAiInstance:
pip[ pip[
"install", "install",
"--require-virtualenv", "--require-virtualenv",
"numpy~=1.24.0", # choose versions that won't be uninstalled during phase 2
"urllib3~=1.26.0",
"requests~=2.28.0",
"torch~=2.0.0", "torch~=2.0.0",
"torchmetrics==0.11.4", "torchmetrics==0.11.4",
"torchvision>=0.14.1", "torchvision>=0.14.1",

View File

@ -3,6 +3,7 @@ InvokeAI Installer
""" """
import argparse import argparse
import os
from pathlib import Path from pathlib import Path
from installer import Installer from installer import Installer
@ -15,7 +16,7 @@ if __name__ == "__main__":
dest="root", dest="root",
type=str, type=str,
help="Destination path for installation", help="Destination path for installation",
default="~/invokeai", default=os.environ.get("INVOKEAI_ROOT") or "~/invokeai",
) )
parser.add_argument( parser.add_argument(
"-y", "-y",

View File

@ -41,7 +41,7 @@ IF /I "%choice%" == "1" (
python .venv\Scripts\invokeai-configure.exe --skip-sd-weight --skip-support-models python .venv\Scripts\invokeai-configure.exe --skip-sd-weight --skip-support-models
) ELSE IF /I "%choice%" == "7" ( ) ELSE IF /I "%choice%" == "7" (
echo Running invokeai-configure... echo Running invokeai-configure...
python .venv\Scripts\invokeai-configure.exe --yes --default_only python .venv\Scripts\invokeai-configure.exe --yes --skip-sd-weight
) ELSE IF /I "%choice%" == "8" ( ) ELSE IF /I "%choice%" == "8" (
echo Developer Console echo Developer Console
echo Python command is: echo Python command is:

View File

@ -82,7 +82,7 @@ do_choice() {
7) 7)
clear clear
printf "Re-run the configure script to fix a broken install or to complete a major upgrade\n" printf "Re-run the configure script to fix a broken install or to complete a major upgrade\n"
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only --skip-sd-weights
;; ;;
8) 8)
clear clear

View File

@ -4,6 +4,8 @@ from typing import Literal
from pydantic import Field from pydantic import Field
from invokeai.app.invocations.prompt import PromptOutput
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .math import FloatOutput, IntOutput from .math import FloatOutput, IntOutput
@ -64,3 +66,18 @@ class ParamStringInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> StringOutput: def invoke(self, context: InvocationContext) -> StringOutput:
return StringOutput(text=self.text) return StringOutput(text=self.text)
class ParamPromptInvocation(BaseInvocation):
"""A prompt input parameter"""
type: Literal["param_prompt"] = "param_prompt"
prompt: str = Field(default="", description="The prompt value")
class Config(InvocationConfig):
schema_extra = {
"ui": {"tags": ["param", "prompt"], "title": "Prompt"},
}
def invoke(self, context: InvocationContext) -> PromptOutput:
return PromptOutput(prompt=self.prompt)

View File

@ -171,7 +171,6 @@ from pydantic import BaseSettings, Field, parse_obj_as
from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args
INIT_FILE = Path("invokeai.yaml") INIT_FILE = Path("invokeai.yaml")
MODEL_CORE = Path("models/core")
DB_FILE = Path("invokeai.db") DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init") LEGACY_INIT_FILE = Path("invokeai.init")
@ -275,7 +274,7 @@ class InvokeAISettings(BaseSettings):
@classmethod @classmethod
def _excluded(self) -> List[str]: def _excluded(self) -> List[str]:
# internal fields that shouldn't be exposed as command line options # internal fields that shouldn't be exposed as command line options
return ["type", "initconf"] return ["type", "initconf", "cached_root"]
@classmethod @classmethod
def _excluded_from_yaml(self) -> List[str]: def _excluded_from_yaml(self) -> List[str]:
@ -291,6 +290,7 @@ class InvokeAISettings(BaseSettings):
"restore", "restore",
"root", "root",
"nsfw_checker", "nsfw_checker",
"cached_root",
] ]
class Config: class Config:
@ -357,7 +357,7 @@ def _find_root() -> Path:
venv = Path(os.environ.get("VIRTUAL_ENV") or ".") venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
if os.environ.get("INVOKEAI_ROOT"): if os.environ.get("INVOKEAI_ROOT"):
root = Path(os.environ.get("INVOKEAI_ROOT")).resolve() root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE, MODEL_CORE]]): elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
root = (venv.parent).resolve() root = (venv.parent).resolve()
else: else:
root = Path("~/invokeai").expanduser().resolve() root = Path("~/invokeai").expanduser().resolve()
@ -424,6 +424,7 @@ class InvokeAIAppConfig(InvokeAISettings):
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging") log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other") version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
cached_root : Path = Field(default=None, description="internal use only", category="DEPRECATED")
# fmt: on # fmt: on
def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False): def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False):
@ -471,10 +472,15 @@ class InvokeAIAppConfig(InvokeAISettings):
""" """
Path to the runtime root directory Path to the runtime root directory
""" """
if self.root: # we cache value of root to protect against it being '.' and the cwd changing
return Path(self.root).expanduser().absolute() if self.cached_root:
root = self.cached_root
elif self.root:
root = Path(self.root).expanduser().absolute()
else: else:
return self.find_root() root = self.find_root()
self.cached_root = root
return self.cached_root
@property @property
def root_dir(self) -> Path: def root_dir(self) -> Path:

View File

@ -181,7 +181,7 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
def download_conversion_models(): def download_conversion_models():
target_dir = config.root_path / "models/core/convert" target_dir = config.models_path / "core/convert"
kwargs = dict() # for future use kwargs = dict() # for future use
try: try:
logger.info("Downloading core tokenizers and text encoders") logger.info("Downloading core tokenizers and text encoders")

View File

@ -7,7 +7,7 @@ import warnings
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import List, Dict, Callable, Union, Set from typing import List, Dict, Callable, Union, Set, Optional
import requests import requests
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
@ -129,7 +129,9 @@ class ModelInstall(object):
model_dict[key] = ModelLoadInfo(**value) model_dict[key] = ModelLoadInfo(**value)
# supplement with entries in models.yaml # supplement with entries in models.yaml
installed_models = self.mgr.list_models() installed_models = [x for x in self.mgr.list_models()]
# suppresses autoloaded models
# installed_models = [x for x in self.mgr.list_models() if not self._is_autoloaded(x)]
for md in installed_models: for md in installed_models:
base = md["base_model"] base = md["base_model"]
@ -148,6 +150,17 @@ class ModelInstall(object):
) )
return {x: model_dict[x] for x in sorted(model_dict.keys(), key=lambda y: model_dict[y].name.lower())} return {x: model_dict[x] for x in sorted(model_dict.keys(), key=lambda y: model_dict[y].name.lower())}
def _is_autoloaded(self, model_info: dict) -> bool:
path = model_info.get("path")
if not path:
return False
for autodir in ["autoimport_dir", "lora_dir", "embedding_dir", "controlnet_dir"]:
if autodir_path := getattr(self.config, autodir):
autodir_path = self.config.root_path / autodir_path
if Path(path).is_relative_to(autodir_path):
return True
return False
def list_models(self, model_type): def list_models(self, model_type):
installed = self.mgr.list_models(model_type=model_type) installed = self.mgr.list_models(model_type=model_type)
print(f"Installed models of type `{model_type}`:") print(f"Installed models of type `{model_type}`:")
@ -274,6 +287,7 @@ class ModelInstall(object):
logger.error(f"Unable to download {url}. Skipping.") logger.error(f"Unable to download {url}. Skipping.")
info = ModelProbe().heuristic_probe(location) info = ModelProbe().heuristic_probe(location)
dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
dest.parent.mkdir(parents=True, exist_ok=True)
models_path = shutil.move(location, dest) models_path = shutil.move(location, dest)
# staged version will be garbage-collected at this time # staged version will be garbage-collected at this time
@ -349,7 +363,7 @@ class ModelInstall(object):
if key in self.datasets: if key in self.datasets:
description = self.datasets[key].get("description") or description description = self.datasets[key].get("description") or description
rel_path = self.relative_to_root(path) rel_path = self.relative_to_root(path, self.config.models_path)
attributes = dict( attributes = dict(
path=str(rel_path), path=str(rel_path),
@ -389,8 +403,8 @@ class ModelInstall(object):
attributes.update(dict(config=str(legacy_conf))) attributes.update(dict(config=str(legacy_conf)))
return attributes return attributes
def relative_to_root(self, path: Path) -> Path: def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path:
root = self.config.root_path root = root or self.config.root_path
if path.is_relative_to(root): if path.is_relative_to(root):
return path.relative_to(root) return path.relative_to(root)
else: else:

View File

@ -63,7 +63,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionS
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from invokeai.app.services.config import InvokeAIAppConfig, MODEL_CORE from invokeai.app.services.config import InvokeAIAppConfig
from picklescan.scanner import scan_file_path from picklescan.scanner import scan_file_path
from .models import BaseModelType, ModelVariantType from .models import BaseModelType, ModelVariantType
@ -81,7 +81,7 @@ if is_accelerate_available():
from accelerate.utils import set_module_tensor_to_device from accelerate.utils import set_module_tensor_to_device
logger = InvokeAILogger.getLogger(__name__) logger = InvokeAILogger.getLogger(__name__)
CONVERT_MODEL_ROOT = InvokeAIAppConfig.get_config().root_path / MODEL_CORE / "convert" CONVERT_MODEL_ROOT = InvokeAIAppConfig.get_config().models_path / "core/convert"
def shave_segments(path, n_shave_prefix_segments=1): def shave_segments(path, n_shave_prefix_segments=1):
@ -1070,7 +1070,7 @@ def convert_controlnet_checkpoint(
extract_ema, extract_ema,
use_linear_projection=None, use_linear_projection=None,
cross_attention_dim=None, cross_attention_dim=None,
precision: torch.dtype = torch.float32, precision: Optional[torch.dtype] = None,
): ):
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
ctrlnet_config["upcast_attention"] = upcast_attention ctrlnet_config["upcast_attention"] = upcast_attention
@ -1111,7 +1111,6 @@ def convert_controlnet_checkpoint(
return controlnet.to(precision) return controlnet.to(precision)
# TO DO - PASS PRECISION
def download_from_original_stable_diffusion_ckpt( def download_from_original_stable_diffusion_ckpt(
checkpoint_path: str, checkpoint_path: str,
model_version: BaseModelType, model_version: BaseModelType,
@ -1121,7 +1120,7 @@ def download_from_original_stable_diffusion_ckpt(
prediction_type: str = None, prediction_type: str = None,
model_type: str = None, model_type: str = None,
extract_ema: bool = False, extract_ema: bool = False,
precision: torch.dtype = torch.float32, precision: Optional[torch.dtype] = None,
scheduler_type: str = "pndm", scheduler_type: str = "pndm",
num_in_channels: Optional[int] = None, num_in_channels: Optional[int] = None,
upcast_attention: Optional[bool] = None, upcast_attention: Optional[bool] = None,
@ -1194,6 +1193,8 @@ def download_from_original_stable_diffusion_ckpt(
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if
needed. needed.
precision (`torch.dtype`, *optional*, defauts to `None`):
If not provided the precision will be set to the precision of the original file.
return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file. return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
""" """
@ -1252,6 +1253,10 @@ def download_from_original_stable_diffusion_ckpt(
logger.debug(f"model_type = {model_type}; original_config_file = {original_config_file}") logger.debug(f"model_type = {model_type}; original_config_file = {original_config_file}")
precision_probing_key = "model.diffusion_model.input_blocks.0.0.bias"
logger.debug(f"original checkpoint precision == {checkpoint[precision_probing_key].dtype}")
precision = precision or checkpoint[precision_probing_key].dtype
if original_config_file is None: if original_config_file is None:
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias" key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
@ -1279,9 +1284,12 @@ def download_from_original_stable_diffusion_ckpt(
original_config_file = BytesIO(requests.get(config_url).content) original_config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(original_config_file) original_config = OmegaConf.load(original_config_file)
if original_config["model"]["params"].get("use_ema") is not None:
extract_ema = original_config["model"]["params"]["use_ema"]
if ( if (
model_version == BaseModelType.StableDiffusion2 model_version == BaseModelType.StableDiffusion2
and original_config["model"]["params"]["parameterization"] == "v" and original_config["model"]["params"].get("parameterization") == "v"
): ):
prediction_type = "v_prediction" prediction_type = "v_prediction"
upcast_attention = True upcast_attention = True
@ -1447,7 +1455,7 @@ def download_from_original_stable_diffusion_ckpt(
if controlnet: if controlnet:
pipe = pipeline_class( pipe = pipeline_class(
vae=vae.to(precision), vae=vae.to(precision),
text_encoder=text_model, text_encoder=text_model.to(precision),
tokenizer=tokenizer, tokenizer=tokenizer,
unet=unet.to(precision), unet=unet.to(precision),
scheduler=scheduler, scheduler=scheduler,
@ -1459,7 +1467,7 @@ def download_from_original_stable_diffusion_ckpt(
else: else:
pipe = pipeline_class( pipe = pipeline_class(
vae=vae.to(precision), vae=vae.to(precision),
text_encoder=text_model, text_encoder=text_model.to(precision),
tokenizer=tokenizer, tokenizer=tokenizer,
unet=unet.to(precision), unet=unet.to(precision),
scheduler=scheduler, scheduler=scheduler,
@ -1484,8 +1492,8 @@ def download_from_original_stable_diffusion_ckpt(
image_noising_scheduler=image_noising_scheduler, image_noising_scheduler=image_noising_scheduler,
# regular denoising components # regular denoising components
tokenizer=tokenizer, tokenizer=tokenizer,
text_encoder=text_model, text_encoder=text_model.to(precision),
unet=unet, unet=unet.to(precision),
scheduler=scheduler, scheduler=scheduler,
# vae # vae
vae=vae, vae=vae,
@ -1560,7 +1568,7 @@ def download_from_original_stable_diffusion_ckpt(
if controlnet: if controlnet:
pipe = pipeline_class( pipe = pipeline_class(
vae=vae.to(precision), vae=vae.to(precision),
text_encoder=text_model, text_encoder=text_model.to(precision),
tokenizer=tokenizer, tokenizer=tokenizer,
unet=unet.to(precision), unet=unet.to(precision),
controlnet=controlnet, controlnet=controlnet,
@ -1571,7 +1579,7 @@ def download_from_original_stable_diffusion_ckpt(
else: else:
pipe = pipeline_class( pipe = pipeline_class(
vae=vae.to(precision), vae=vae.to(precision),
text_encoder=text_model, text_encoder=text_model.to(precision),
tokenizer=tokenizer, tokenizer=tokenizer,
unet=unet.to(precision), unet=unet.to(precision),
scheduler=scheduler, scheduler=scheduler,
@ -1594,9 +1602,9 @@ def download_from_original_stable_diffusion_ckpt(
pipe = StableDiffusionXLPipeline( pipe = StableDiffusionXLPipeline(
vae=vae.to(precision), vae=vae.to(precision),
text_encoder=text_encoder, text_encoder=text_encoder.to(precision),
tokenizer=tokenizer, tokenizer=tokenizer,
text_encoder_2=text_encoder_2, text_encoder_2=text_encoder_2.to(precision),
tokenizer_2=tokenizer_2, tokenizer_2=tokenizer_2,
unet=unet.to(precision), unet=unet.to(precision),
scheduler=scheduler, scheduler=scheduler,
@ -1639,7 +1647,7 @@ def download_controlnet_from_original_ckpt(
original_config_file: str, original_config_file: str,
image_size: int = 512, image_size: int = 512,
extract_ema: bool = False, extract_ema: bool = False,
precision: torch.dtype = torch.float32, precision: Optional[torch.dtype] = None,
num_in_channels: Optional[int] = None, num_in_channels: Optional[int] = None,
upcast_attention: Optional[bool] = None, upcast_attention: Optional[bool] = None,
device: str = None, device: str = None,
@ -1680,6 +1688,12 @@ def download_controlnet_from_original_ckpt(
while "state_dict" in checkpoint: while "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"] checkpoint = checkpoint["state_dict"]
# use original precision
precision_probing_key = "input_blocks.0.0.bias"
ckpt_precision = checkpoint[precision_probing_key].dtype
logger.debug(f"original controlnet precision = {ckpt_precision}")
precision = precision or ckpt_precision
original_config = OmegaConf.load(original_config_file) original_config = OmegaConf.load(original_config_file)
if num_in_channels is not None: if num_in_channels is not None:
@ -1699,7 +1713,7 @@ def download_controlnet_from_original_ckpt(
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
) )
return controlnet return controlnet.to(precision)
def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL: def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL:

View File

@ -187,7 +187,9 @@ class ModelCache(object):
# TODO: lock for no copies on simultaneous calls? # TODO: lock for no copies on simultaneous calls?
cache_entry = self._cached_models.get(key, None) cache_entry = self._cached_models.get(key, None)
if cache_entry is None: if cache_entry is None:
self.logger.info(f"Loading model {model_path}, type {base_model}:{model_type}:{submodel}") self.logger.info(
f"Loading model {model_path}, type {base_model.value}:{model_type.value}:{submodel.value if submodel else ''}"
)
# this will remove older cached models until # this will remove older cached models until
# there is sufficient room to load the requested model # there is sufficient room to load the requested model

View File

@ -423,7 +423,7 @@ class ModelManager(object):
return (model_name, base_model, model_type) return (model_name, base_model, model_type)
def _get_model_cache_path(self, model_path): def _get_model_cache_path(self, model_path):
return self.app_config.models_path / ".cache" / hashlib.md5(str(model_path).encode()).hexdigest() return self.resolve_model_path(Path(".cache") / hashlib.md5(str(model_path).encode()).hexdigest())
@classmethod @classmethod
def initialize_model_config(cls, config_path: Path): def initialize_model_config(cls, config_path: Path):
@ -456,7 +456,7 @@ class ModelManager(object):
raise ModelNotFoundException(f"Model not found - {model_key}") raise ModelNotFoundException(f"Model not found - {model_key}")
model_config = self.models[model_key] model_config = self.models[model_key]
model_path = self.app_config.root_path / model_config.path model_path = self.resolve_model_path(model_config.path)
if not model_path.exists(): if not model_path.exists():
if model_class.save_to_config: if model_class.save_to_config:
@ -586,7 +586,7 @@ class ModelManager(object):
# expose paths as absolute to help web UI # expose paths as absolute to help web UI
if path := model_dict.get("path"): if path := model_dict.get("path"):
model_dict["path"] = str(self.app_config.root_path / path) model_dict["path"] = str(self.resolve_model_path(path))
models.append(model_dict) models.append(model_dict)
return models return models
@ -623,7 +623,7 @@ class ModelManager(object):
self.cache.uncache_model(cache_id) self.cache.uncache_model(cache_id)
# if model inside invoke models folder - delete files # if model inside invoke models folder - delete files
model_path = self.app_config.root_path / model_cfg.path model_path = self.resolve_model_path(model_cfg.path)
cache_path = self._get_model_cache_path(model_path) cache_path = self._get_model_cache_path(model_path)
if cache_path.exists(): if cache_path.exists():
rmtree(str(cache_path)) rmtree(str(cache_path))
@ -654,10 +654,9 @@ class ModelManager(object):
The returned dict has the same format as the dict returned by The returned dict has the same format as the dict returned by
model_info(). model_info().
""" """
# relativize paths as they go in - this makes it easier to move the root directory around # relativize paths as they go in - this makes it easier to move the models directory around
if path := model_attributes.get("path"): if path := model_attributes.get("path"):
if Path(path).is_relative_to(self.app_config.root_path): model_attributes["path"] = str(self.relative_model_path(Path(path)))
model_attributes["path"] = str(Path(path).relative_to(self.app_config.root_path))
model_class = MODEL_CLASSES[base_model][model_type] model_class = MODEL_CLASSES[base_model][model_type]
model_config = model_class.create_config(**model_attributes) model_config = model_class.create_config(**model_attributes)
@ -715,7 +714,7 @@ class ModelManager(object):
if not model_cfg: if not model_cfg:
raise ModelNotFoundException(f"Unknown model: {model_key}") raise ModelNotFoundException(f"Unknown model: {model_key}")
old_path = self.app_config.root_path / model_cfg.path old_path = self.resolve_model_path(model_cfg.path)
new_name = new_name or model_name new_name = new_name or model_name
new_base = new_base or base_model new_base = new_base or base_model
new_key = self.create_key(new_name, new_base, model_type) new_key = self.create_key(new_name, new_base, model_type)
@ -724,15 +723,15 @@ class ModelManager(object):
# if this is a model file/directory that we manage ourselves, we need to move it # if this is a model file/directory that we manage ourselves, we need to move it
if old_path.is_relative_to(self.app_config.models_path): if old_path.is_relative_to(self.app_config.models_path):
new_path = ( new_path = self.resolve_model_path(
self.app_config.root_path Path(
/ "models" BaseModelType(new_base).value,
/ BaseModelType(new_base).value ModelType(model_type).value,
/ ModelType(model_type).value new_name,
/ new_name )
) )
move(old_path, new_path) move(old_path, new_path)
model_cfg.path = str(new_path.relative_to(self.app_config.root_path)) model_cfg.path = str(new_path.relative_to(self.app_config.models_path))
# clean up caches # clean up caches
old_model_cache = self._get_model_cache_path(old_path) old_model_cache = self._get_model_cache_path(old_path)
@ -782,7 +781,7 @@ class ModelManager(object):
**submodel, **submodel,
) )
checkpoint_path = self.app_config.root_path / info["path"] checkpoint_path = self.app_config.root_path / info["path"]
old_diffusers_path = self.app_config.models_path / model.location old_diffusers_path = self.resolve_model_path(model.location)
new_diffusers_path = ( new_diffusers_path = (
dest_directory or self.app_config.models_path / base_model.value / model_type.value dest_directory or self.app_config.models_path / base_model.value / model_type.value
) / model_name ) / model_name
@ -795,7 +794,7 @@ class ModelManager(object):
info["path"] = ( info["path"] = (
str(new_diffusers_path) str(new_diffusers_path)
if dest_directory if dest_directory
else str(new_diffusers_path.relative_to(self.app_config.root_path)) else str(new_diffusers_path.relative_to(self.app_config.models_path))
) )
info.pop("config") info.pop("config")
@ -810,6 +809,15 @@ class ModelManager(object):
return result return result
def resolve_model_path(self, path: Union[Path, str]) -> Path:
"""return relative paths based on configured models_path"""
return self.app_config.models_path / path
def relative_model_path(self, model_path: Path) -> Path:
if model_path.is_relative_to(self.app_config.models_path):
model_path = model_path.relative_to(self.app_config.models_path)
return model_path
def search_models(self, search_folder): def search_models(self, search_folder):
self.logger.info(f"Finding Models In: {search_folder}") self.logger.info(f"Finding Models In: {search_folder}")
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt") models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
@ -883,10 +891,17 @@ class ModelManager(object):
new_models_found = False new_models_found = False
self.logger.info(f"Scanning {self.app_config.models_path} for new models") self.logger.info(f"Scanning {self.app_config.models_path} for new models")
with Chdir(self.app_config.root_path): with Chdir(self.app_config.models_path):
for model_key, model_config in list(self.models.items()): for model_key, model_config in list(self.models.items()):
model_name, cur_base_model, cur_model_type = self.parse_key(model_key) model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
model_path = self.app_config.root_path.absolute() / model_config.path
# Patch for relative path bug in older models.yaml - paths should not
# be starting with a hard-coded 'models'. This will also fix up
# models.yaml when committed.
if model_config.path.startswith("models"):
model_config.path = str(Path(*Path(model_config.path).parts[1:]))
model_path = self.resolve_model_path(model_config.path).absolute()
if not model_path.exists(): if not model_path.exists():
model_class = MODEL_CLASSES[cur_base_model][cur_model_type] model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
if model_class.save_to_config: if model_class.save_to_config:
@ -905,7 +920,7 @@ class ModelManager(object):
if model_type is not None and cur_model_type != model_type: if model_type is not None and cur_model_type != model_type:
continue continue
model_class = MODEL_CLASSES[cur_base_model][cur_model_type] model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
models_dir = self.app_config.models_path / cur_base_model.value / cur_model_type.value models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value))
if not models_dir.exists(): if not models_dir.exists():
continue # TODO: or create all folders? continue # TODO: or create all folders?
@ -919,9 +934,7 @@ class ModelManager(object):
if model_key in self.models: if model_key in self.models:
raise DuplicateModelException(f"Model with key {model_key} added twice") raise DuplicateModelException(f"Model with key {model_key} added twice")
if model_path.is_relative_to(self.app_config.root_path): model_path = self.relative_model_path(model_path)
model_path = model_path.relative_to(self.app_config.root_path)
model_config: ModelConfigBase = model_class.probe_config(str(model_path)) model_config: ModelConfigBase = model_class.probe_config(str(model_path))
self.models[model_key] = model_config self.models[model_key] = model_config
new_models_found = True new_models_found = True
@ -932,12 +945,11 @@ class ModelManager(object):
except NotImplementedError as e: except NotImplementedError as e:
self.logger.warning(e) self.logger.warning(e)
imported_models = self.autoimport() imported_models = self.scan_autoimport_directory()
if (new_models_found or imported_models) and self.config_path: if (new_models_found or imported_models) and self.config_path:
self.commit() self.commit()
def autoimport(self) -> Dict[str, AddModelResult]: def scan_autoimport_directory(self) -> Dict[str, AddModelResult]:
""" """
Scan the autoimport directory (if defined) and import new models, delete defunct models. Scan the autoimport directory (if defined) and import new models, delete defunct models.
""" """
@ -971,7 +983,7 @@ class ModelManager(object):
# LS: hacky # LS: hacky
# Patch in the SD VAE from core so that it is available for use by the UI # Patch in the SD VAE from core so that it is available for use by the UI
try: try:
self.heuristic_import({config.root_path / "models/core/convert/sd-vae-ft-mse"}) self.heuristic_import({self.resolve_model_path("core/convert/sd-vae-ft-mse")})
except: except:
pass pass

View File

@ -17,6 +17,7 @@ from .base import (
ModelNotFoundException, ModelNotFoundException,
) )
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
import invokeai.backend.util.logging as logger
class ControlNetModelFormat(str, Enum): class ControlNetModelFormat(str, Enum):
@ -66,7 +67,7 @@ class ControlNetModel(ModelBase):
child_type: Optional[SubModelType] = None, child_type: Optional[SubModelType] = None,
): ):
if child_type is not None: if child_type is not None:
raise Exception("There is no child models in controlnet model") raise Exception("There are no child models in controlnet model")
model = None model = None
for variant in ["fp16", None]: for variant in ["fp16", None]:
@ -124,9 +125,7 @@ class ControlNetModel(ModelBase):
return model_path return model_path
@classmethod
def _convert_controlnet_ckpt_and_cache( def _convert_controlnet_ckpt_and_cache(
cls,
model_path: str, model_path: str,
output_path: str, output_path: str,
base_model: BaseModelType, base_model: BaseModelType,
@ -141,6 +140,7 @@ def _convert_controlnet_ckpt_and_cache(
weights = app_config.root_path / model_path weights = app_config.root_path / model_path
output_path = Path(output_path) output_path = Path(output_path)
logger.info(f"Converting {weights} to diffusers format")
# return cached version if it exists # return cached version if it exists
if output_path.exists(): if output_path.exists():
return output_path return output_path

View File

@ -123,6 +123,7 @@ class StableDiffusion1Model(DiffusersModel):
return _convert_ckpt_and_cache( return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion1, version=BaseModelType.StableDiffusion1,
model_config=config, model_config=config,
load_safety_checker=False,
output_path=output_path, output_path=output_path,
) )
else: else:
@ -259,7 +260,7 @@ def _convert_ckpt_and_cache(
""" """
app_config = InvokeAIAppConfig.get_config() app_config = InvokeAIAppConfig.get_config()
weights = app_config.root_path / model_config.path weights = app_config.models_path / model_config.path
config_file = app_config.root_path / model_config.config config_file = app_config.root_path / model_config.config
output_path = Path(output_path) output_path = Path(output_path)

View File

@ -112,7 +112,7 @@ def main():
extras = get_extras() extras = get_extras()
print(f":crossed_fingers: Upgrading to [yellow]{tag if tag else release}[/yellow]") print(f":crossed_fingers: Upgrading to [yellow]{tag or release or branch}[/yellow]")
if release: if release:
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_SRC}/{release}.zip" --use-pep517 --upgrade' cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_SRC}/{release}.zip" --use-pep517 --upgrade'
elif tag: elif tag:

View File

@ -58,6 +58,9 @@ logger = InvokeAILogger.getLogger()
# from https://stackoverflow.com/questions/92438/stripping-non-printable-characters-from-a-string-in-python # from https://stackoverflow.com/questions/92438/stripping-non-printable-characters-from-a-string-in-python
NOPRINT_TRANS_TABLE = {i: None for i in range(0, sys.maxunicode + 1) if not chr(i).isprintable()} NOPRINT_TRANS_TABLE = {i: None for i in range(0, sys.maxunicode + 1) if not chr(i).isprintable()}
# maximum number of installed models we can display before overflowing vertically
MAX_OTHER_MODELS = 72
def make_printable(s: str) -> str: def make_printable(s: str) -> str:
"""Replace non-printable characters in a string""" """Replace non-printable characters in a string"""
@ -102,7 +105,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
SingleSelectColumns, SingleSelectColumns,
values=[ values=[
"STARTER MODELS", "STARTER MODELS",
"MORE MODELS", "MAIN MODELS",
"CONTROLNETS", "CONTROLNETS",
"LORA/LYCORIS", "LORA/LYCORIS",
"TEXTUAL INVERSION", "TEXTUAL INVERSION",
@ -153,7 +156,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
BufferBox, BufferBox,
name="Log Messages", name="Log Messages",
editable=False, editable=False,
max_height=8, max_height=15,
) )
self.nextrely += 1 self.nextrely += 1
@ -253,6 +256,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
model_labels = [self.model_labels[x] for x in model_list] model_labels = [self.model_labels[x] for x in model_list]
show_recommended = len(self.installed_models) == 0 show_recommended = len(self.installed_models) == 0
truncated = False
if len(model_list) > 0: if len(model_list) > 0:
max_width = max([len(x) for x in model_labels]) max_width = max([len(x) for x in model_labels])
columns = window_width // (max_width + 8) # 8 characters for "[x] " and padding columns = window_width // (max_width + 8) # 8 characters for "[x] " and padding
@ -271,6 +275,10 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
) )
) )
if len(model_labels) > MAX_OTHER_MODELS:
model_labels = model_labels[0:MAX_OTHER_MODELS]
truncated = True
widgets.update( widgets.update(
models_selected=self.add_widget_intelligent( models_selected=self.add_widget_intelligent(
MultiSelectColumns, MultiSelectColumns,
@ -289,6 +297,16 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
models=model_list, models=model_list,
) )
if truncated:
widgets.update(
warning_message=self.add_widget_intelligent(
npyscreen.FixedText,
value=f"Too many models to display (max={MAX_OTHER_MODELS}). Some are not displayed.",
editable=False,
color="CAUTION",
)
)
self.nextrely += 1 self.nextrely += 1
widgets.update( widgets.update(
download_ids=self.add_widget_intelligent( download_ids=self.add_widget_intelligent(
@ -313,7 +331,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
widgets = self.add_model_widgets( widgets = self.add_model_widgets(
model_type=model_type, model_type=model_type,
window_width=window_width, window_width=window_width,
install_prompt=f"Additional {model_type.value.title()} models already installed.", install_prompt=f"Installed {model_type.value.title()} models. Unchecked models in the InvokeAI root directory will be deleted. Enter URLs, paths or repo_ids to import.",
**kwargs, **kwargs,
) )
@ -399,7 +417,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
self.ok_button.hidden = True self.ok_button.hidden = True
self.display() self.display()
# for communication with the subprocess # TO DO: Spawn a worker thread, not a subprocess
parent_conn, child_conn = Pipe() parent_conn, child_conn = Pipe()
p = Process( p = Process(
target=process_and_execute, target=process_and_execute,
@ -414,7 +432,6 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
self.subprocess_connection = parent_conn self.subprocess_connection = parent_conn
self.subprocess = p self.subprocess = p
app.install_selections = InstallSelections() app.install_selections = InstallSelections()
# process_and_execute(app.opt, app.install_selections)
def on_back(self): def on_back(self):
self.parentApp.switchFormPrevious() self.parentApp.switchFormPrevious()
@ -489,8 +506,6 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
# rebuild the form, saving and restoring some of the fields that need to be preserved. # rebuild the form, saving and restoring some of the fields that need to be preserved.
saved_messages = self.monitor.entry_widget.values saved_messages = self.monitor.entry_widget.values
# autoload_dir = str(config.root_path / self.pipeline_models['autoload_directory'].value)
# autoscan = self.pipeline_models['autoscan_on_startup'].value
app.main_form = app.addForm( app.main_form = app.addForm(
"MAIN", "MAIN",
@ -544,12 +559,6 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
if downloads := section.get("download_ids"): if downloads := section.get("download_ids"):
selections.install_models.extend(downloads.value.split()) selections.install_models.extend(downloads.value.split())
# load directory and whether to scan on startup
# if self.parentApp.autoload_pending:
# selections.scan_directory = str(config.root_path / self.pipeline_models['autoload_directory'].value)
# self.parentApp.autoload_pending = False
# selections.autoscan_on_startup = self.pipeline_models['autoscan_on_startup'].value
class AddModelApplication(npyscreen.NPSAppManaged): class AddModelApplication(npyscreen.NPSAppManaged):
def __init__(self, opt): def __init__(self, opt):
@ -639,6 +648,11 @@ def process_and_execute(
selections: InstallSelections, selections: InstallSelections,
conn_out: Connection = None, conn_out: Connection = None,
): ):
# need to reinitialize config in subprocess
config = InvokeAIAppConfig.get_config()
args = ["--root", opt.root] if opt.root else []
config.parse_args(args)
# set up so that stderr is sent to conn_out # set up so that stderr is sent to conn_out
if conn_out: if conn_out:
translator = StderrToMessage(conn_out) translator = StderrToMessage(conn_out)
@ -656,38 +670,11 @@ def process_and_execute(
conn_out.close() conn_out.close()
def do_listings(opt) -> bool:
"""List installed models of various sorts, and return
True if any were requested."""
model_manager = ModelManager(config.model_conf_path)
if opt.list_models == "diffusers":
print("Diffuser models:")
model_manager.print_models()
elif opt.list_models == "controlnets":
print("Installed Controlnet Models:")
cnm = model_manager.list_controlnet_models()
print(textwrap.indent("\n".join([x for x in cnm if cnm[x]]), prefix=" "))
elif opt.list_models == "loras":
print("Installed LoRA/LyCORIS Models:")
cnm = model_manager.list_lora_models()
print(textwrap.indent("\n".join([x for x in cnm if cnm[x]]), prefix=" "))
elif opt.list_models == "tis":
print("Installed Textual Inversion Embeddings:")
cnm = model_manager.list_ti_models()
print(textwrap.indent("\n".join([x for x in cnm if cnm[x]]), prefix=" "))
else:
return False
return True
# -------------------------------------------------------- # --------------------------------------------------------
def select_and_download_models(opt: Namespace): def select_and_download_models(opt: Namespace):
precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
config.precision = precision config.precision = precision
helper = lambda x: ask_user_for_prediction_type(x) helper = lambda x: ask_user_for_prediction_type(x)
# if do_listings(opt):
# pass
installer = ModelInstall(config, prediction_type_helper=helper) installer = ModelInstall(config, prediction_type_helper=helper)
if opt.list_models: if opt.list_models:
installer.list_models(opt.list_models) installer.list_models(opt.list_models)
@ -706,8 +693,6 @@ def select_and_download_models(opt: Namespace):
# needed to support the probe() method running under a subprocess # needed to support the probe() method running under a subprocess
torch.multiprocessing.set_start_method("spawn") torch.multiprocessing.set_start_method("spawn")
# the third argument is needed in the Windows 11 environment in
# order to launch and resize a console window running this program
set_min_terminal_size(MIN_COLS, MIN_LINES) set_min_terminal_size(MIN_COLS, MIN_LINES)
installApp = AddModelApplication(opt) installApp = AddModelApplication(opt)
try: try:

View File

@ -320,7 +320,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
def get_model_names(self, base_model: BaseModelType = None) -> List[str]: def get_model_names(self, base_model: BaseModelType = None) -> List[str]:
model_names = [ model_names = [
info["name"] info["model_name"]
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model) for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
if info["model_format"] == "diffusers" if info["model_format"] == "diffusers"
] ]

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -12,7 +12,7 @@
margin: 0; margin: 0;
} }
</style> </style>
<script type="module" crossorigin src="./assets/index-bad7ff83.js"></script> <script type="module" crossorigin src="./assets/index-9bb68e3a.js"></script>
</head> </head>
<body dir="ltr"> <body dir="ltr">

View File

@ -139,8 +139,19 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
useHotkeys('s', handleUseSeed, [imageDTO]); useHotkeys('s', handleUseSeed, [imageDTO]);
const handleUsePrompt = useCallback(() => { const handleUsePrompt = useCallback(() => {
recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt); recallBothPrompts(
}, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]); metadata?.positive_prompt,
metadata?.negative_prompt,
metadata?.positive_style_prompt,
metadata?.negative_style_prompt
);
}, [
metadata?.negative_prompt,
metadata?.positive_prompt,
metadata?.positive_style_prompt,
metadata?.negative_style_prompt,
recallBothPrompts,
]);
useHotkeys('p', handleUsePrompt, [imageDTO]); useHotkeys('p', handleUsePrompt, [imageDTO]);

View File

@ -102,8 +102,19 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
// Recall parameters handlers // Recall parameters handlers
const handleRecallPrompt = useCallback(() => { const handleRecallPrompt = useCallback(() => {
recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt); recallBothPrompts(
}, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]); metadata?.positive_prompt,
metadata?.negative_prompt,
metadata?.positive_style_prompt,
metadata?.negative_style_prompt
);
}, [
metadata?.negative_prompt,
metadata?.positive_prompt,
metadata?.positive_style_prompt,
metadata?.negative_style_prompt,
recallBothPrompts,
]);
const handleRecallSeed = useCallback(() => { const handleRecallSeed = useCallback(() => {
recallSeed(metadata?.seed); recallSeed(metadata?.seed);

View File

@ -1,4 +1,4 @@
import { Input } from '@chakra-ui/react'; import { Input, Textarea } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { import {
@ -12,10 +12,11 @@ const StringInputFieldComponent = (
props: FieldComponentProps<StringInputFieldValue, StringInputFieldTemplate> props: FieldComponentProps<StringInputFieldValue, StringInputFieldTemplate>
) => { ) => {
const { nodeId, field } = props; const { nodeId, field } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const handleValueChanged = (e: ChangeEvent<HTMLInputElement>) => { const handleValueChanged = (
e: ChangeEvent<HTMLInputElement | HTMLTextAreaElement>
) => {
dispatch( dispatch(
fieldValueChanged({ fieldValueChanged({
nodeId, nodeId,
@ -25,7 +26,11 @@ const StringInputFieldComponent = (
); );
}; };
return <Input onChange={handleValueChanged} value={field.value}></Input>; return ['prompt', 'style'].includes(field.name.toLowerCase()) ? (
<Textarea onChange={handleValueChanged} value={field.value} rows={2} />
) : (
<Input onChange={handleValueChanged} value={field.value} />
);
}; };
export default memo(StringInputFieldComponent); export default memo(StringInputFieldComponent);

View File

@ -1,5 +1,15 @@
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import {
refinerModelChanged,
setNegativeStylePromptSDXL,
setPositiveStylePromptSDXL,
setRefinerAestheticScore,
setRefinerCFGScale,
setRefinerScheduler,
setRefinerStart,
setRefinerSteps,
} from 'features/sdxl/store/sdxlSlice';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { UnsafeImageMetadata } from 'services/api/endpoints/images'; import { UnsafeImageMetadata } from 'services/api/endpoints/images';
@ -22,6 +32,10 @@ import {
isValidMainModel, isValidMainModel,
isValidNegativePrompt, isValidNegativePrompt,
isValidPositivePrompt, isValidPositivePrompt,
isValidSDXLNegativeStylePrompt,
isValidSDXLPositiveStylePrompt,
isValidSDXLRefinerAestheticScore,
isValidSDXLRefinerStart,
isValidScheduler, isValidScheduler,
isValidSeed, isValidSeed,
isValidSteps, isValidSteps,
@ -74,17 +88,34 @@ export const useRecallParameters = () => {
* Recall both prompts with toast * Recall both prompts with toast
*/ */
const recallBothPrompts = useCallback( const recallBothPrompts = useCallback(
(positivePrompt: unknown, negativePrompt: unknown) => { (
positivePrompt: unknown,
negativePrompt: unknown,
positiveStylePrompt: unknown,
negativeStylePrompt: unknown
) => {
if ( if (
isValidPositivePrompt(positivePrompt) || isValidPositivePrompt(positivePrompt) ||
isValidNegativePrompt(negativePrompt) isValidNegativePrompt(negativePrompt) ||
isValidSDXLPositiveStylePrompt(positiveStylePrompt) ||
isValidSDXLNegativeStylePrompt(negativeStylePrompt)
) { ) {
if (isValidPositivePrompt(positivePrompt)) { if (isValidPositivePrompt(positivePrompt)) {
dispatch(setPositivePrompt(positivePrompt)); dispatch(setPositivePrompt(positivePrompt));
} }
if (isValidNegativePrompt(negativePrompt)) { if (isValidNegativePrompt(negativePrompt)) {
dispatch(setNegativePrompt(negativePrompt)); dispatch(setNegativePrompt(negativePrompt));
} }
if (isValidSDXLPositiveStylePrompt(positiveStylePrompt)) {
dispatch(setPositiveStylePromptSDXL(positiveStylePrompt));
}
if (isValidSDXLPositiveStylePrompt(negativeStylePrompt)) {
dispatch(setNegativeStylePromptSDXL(negativeStylePrompt));
}
parameterSetToast(); parameterSetToast();
return; return;
} }
@ -123,6 +154,36 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall SDXL Positive Style Prompt with toast
*/
const recallSDXLPositiveStylePrompt = useCallback(
(positiveStylePrompt: unknown) => {
if (!isValidSDXLPositiveStylePrompt(positiveStylePrompt)) {
parameterNotSetToast();
return;
}
dispatch(setPositiveStylePromptSDXL(positiveStylePrompt));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall SDXL Negative Style Prompt with toast
*/
const recallSDXLNegativeStylePrompt = useCallback(
(negativeStylePrompt: unknown) => {
if (!isValidSDXLNegativeStylePrompt(negativeStylePrompt)) {
parameterNotSetToast();
return;
}
dispatch(setNegativeStylePromptSDXL(negativeStylePrompt));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/** /**
* Recall seed with toast * Recall seed with toast
*/ */
@ -271,6 +332,14 @@ export const useRecallParameters = () => {
steps, steps,
width, width,
strength, strength,
positive_style_prompt,
negative_style_prompt,
refiner_model,
refiner_cfg_scale,
refiner_steps,
refiner_scheduler,
refiner_aesthetic_store,
refiner_start,
} = metadata; } = metadata;
if (isValidCfgScale(cfg_scale)) { if (isValidCfgScale(cfg_scale)) {
@ -304,6 +373,38 @@ export const useRecallParameters = () => {
dispatch(setImg2imgStrength(strength)); dispatch(setImg2imgStrength(strength));
} }
if (isValidSDXLPositiveStylePrompt(positive_style_prompt)) {
dispatch(setPositiveStylePromptSDXL(positive_style_prompt));
}
if (isValidSDXLNegativeStylePrompt(negative_style_prompt)) {
dispatch(setNegativeStylePromptSDXL(negative_style_prompt));
}
if (isValidMainModel(refiner_model)) {
dispatch(refinerModelChanged(refiner_model));
}
if (isValidSteps(refiner_steps)) {
dispatch(setRefinerSteps(refiner_steps));
}
if (isValidCfgScale(refiner_cfg_scale)) {
dispatch(setRefinerCFGScale(refiner_cfg_scale));
}
if (isValidScheduler(refiner_scheduler)) {
dispatch(setRefinerScheduler(refiner_scheduler));
}
if (isValidSDXLRefinerAestheticScore(refiner_aesthetic_store)) {
dispatch(setRefinerAestheticScore(refiner_aesthetic_store));
}
if (isValidSDXLRefinerStart(refiner_start)) {
dispatch(setRefinerStart(refiner_start));
}
allParameterSetToast(); allParameterSetToast();
}, },
[allParameterNotSetToast, allParameterSetToast, dispatch] [allParameterNotSetToast, allParameterSetToast, dispatch]
@ -313,6 +414,8 @@ export const useRecallParameters = () => {
recallBothPrompts, recallBothPrompts,
recallPositivePrompt, recallPositivePrompt,
recallNegativePrompt, recallNegativePrompt,
recallSDXLPositiveStylePrompt,
recallSDXLNegativeStylePrompt,
recallSeed, recallSeed,
recallCfgScale, recallCfgScale,
recallModel, recallModel,

View File

@ -324,6 +324,39 @@ export type PrecisionParam = z.infer<typeof zPrecision>;
export const isValidPrecision = (val: unknown): val is PrecisionParam => export const isValidPrecision = (val: unknown): val is PrecisionParam =>
zPrecision.safeParse(val).success; zPrecision.safeParse(val).success;
/**
* Zod schema for SDXL refiner aesthetic score parameter
*/
export const zSDXLRefinerAestheticScore = z.number().min(1).max(10);
/**
* Type alias for SDXL refiner aesthetic score parameter, inferred from its zod schema
*/
export type SDXLRefinerAestheticScoreParam = z.infer<
typeof zSDXLRefinerAestheticScore
>;
/**
* Validates/type-guards a value as a SDXL refiner aesthetic score parameter
*/
export const isValidSDXLRefinerAestheticScore = (
val: unknown
): val is SDXLRefinerAestheticScoreParam =>
zSDXLRefinerAestheticScore.safeParse(val).success;
/**
* Zod schema for SDXL start parameter
*/
export const zSDXLRefinerstart = z.number().min(0).max(1);
/**
* Type alias for SDXL start, inferred from its zod schema
*/
export type SDXLRefinerStartParam = z.infer<typeof zSDXLRefinerstart>;
/**
* Validates/type-guards a value as a SDXL refiner aesthetic score parameter
*/
export const isValidSDXLRefinerStart = (
val: unknown
): val is SDXLRefinerStartParam => zSDXLRefinerstart.safeParse(val).success;
// /** // /**
// * Zod schema for BaseModelType // * Zod schema for BaseModelType
// */ // */

View File

@ -21,8 +21,8 @@ export default function ParamSDXLConcatButton() {
return ( return (
<IAIIconButton <IAIIconButton
aria-label="Concat" aria-label="Concatenate Prompt & Style"
tooltip="Concatenates Basic Prompt with Style (Recommended)" tooltip="Concatenate Prompt & Style"
variant="outline" variant="outline"
isChecked={shouldConcatSDXLStylePrompt} isChecked={shouldConcatSDXLStylePrompt}
onClick={handleShouldConcatPromptChange} onClick={handleShouldConcatPromptChange}

View File

@ -1381,7 +1381,7 @@ export type components = {
* @description The nodes in this graph * @description The nodes in this graph
*/ */
nodes?: { nodes?: {
[key: string]: (components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined; [key: string]: (components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
}; };
/** /**
* Edges * Edges
@ -1424,7 +1424,7 @@ export type components = {
* @description The results of node executions * @description The results of node executions
*/ */
results: { results: {
[key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["MetadataAccumulatorOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined; [key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["MetadataAccumulatorOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
}; };
/** /**
* Errors * Errors
@ -4265,6 +4265,35 @@ export type components = {
*/ */
a?: number; a?: number;
}; };
/**
* ParamPromptInvocation
* @description A prompt input parameter
*/
ParamPromptInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default param_prompt
* @enum {string}
*/
type?: "param_prompt";
/**
* Prompt
* @description The prompt value
* @default
*/
prompt?: string;
};
/** /**
* ParamStringInvocation * ParamStringInvocation
* @description A string parameter * @description A string parameter
@ -5874,24 +5903,18 @@ export type components = {
* @enum {string} * @enum {string}
*/ */
StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
/**
* ControlNetModelFormat
* @description An enumeration.
* @enum {string}
*/
ControlNetModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/** /**
* StableDiffusion1ModelFormat * StableDiffusion1ModelFormat
* @description An enumeration. * @description An enumeration.
* @enum {string} * @enum {string}
*/ */
StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* ControlNetModelFormat
* @description An enumeration.
* @enum {string}
*/
ControlNetModelFormat: "checkpoint" | "diffusers";
}; };
responses: never; responses: never;
parameters: never; parameters: never;
@ -6002,7 +6025,7 @@ export type operations = {
}; };
requestBody: { requestBody: {
content: { content: {
"application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; "application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
}; };
}; };
responses: { responses: {
@ -6039,7 +6062,7 @@ export type operations = {
}; };
requestBody: { requestBody: {
content: { content: {
"application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; "application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
}; };
}; };
responses: { responses: {

View File

@ -1 +1 @@
__version__ = "3.0.1" __version__ = "3.0.1post3"

View File

@ -1,281 +1,283 @@
{ {
"cells": [ "cells": [
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "ycYWcsEKc6w7" "id": "ycYWcsEKc6w7"
}, },
"source": [ "source": [
"# Stable Diffusion AI Notebook (Release 2.0.0)\n", "# Stable Diffusion AI Notebook (Release 2.0.0)\n",
"\n", "\n",
"<img src=\"https://user-images.githubusercontent.com/60411196/186547976-d9de378a-9de8-4201-9c25-c057a9c59bad.jpeg\" alt=\"stable-diffusion-ai\" width=\"170px\"/> <br>\n", "<img src=\"https://user-images.githubusercontent.com/60411196/186547976-d9de378a-9de8-4201-9c25-c057a9c59bad.jpeg\" alt=\"stable-diffusion-ai\" width=\"170px\"/> <br>\n",
"#### Instructions:\n", "#### Instructions:\n",
"1. Execute each cell in order to mount a Dream bot and create images from text. <br>\n", "1. Execute each cell in order to mount a Dream bot and create images from text. <br>\n",
"2. Once cells 1-8 were run correctly you'll be executing a terminal in cell #9, you'll need to enter `python scripts/dream.py` command to run Dream bot.<br> \n", "2. Once cells 1-8 were run correctly you'll be executing a terminal in cell #9, you'll need to enter `python scripts/dream.py` command to run Dream bot.<br> \n",
"3. After launching dream bot, you'll see: <br> `Dream > ` in terminal. <br> Insert a command, eg. `Dream > Astronaut floating in a distant galaxy`, or type `-h` for help.\n", "3. After launching dream bot, you'll see: <br> `Dream > ` in terminal. <br> Insert a command, eg. `Dream > Astronaut floating in a distant galaxy`, or type `-h` for help.\n",
"3. After completion you'll see your generated images in path `stable-diffusion/outputs/img-samples/`, you can also show last generated images in cell #10.\n", "3. After completion you'll see your generated images in path `stable-diffusion/outputs/img-samples/`, you can also show last generated images in cell #10.\n",
"4. To quit Dream bot use `q` command. <br> \n", "4. To quit Dream bot use `q` command. <br> \n",
"---\n", "---\n",
"<font color=\"red\">Note:</font> It takes some time to load, but after installing all dependencies you can use the bot all time you want while colab instance is up. <br>\n", "<font color=\"red\">Note:</font> It takes some time to load, but after installing all dependencies you can use the bot all time you want while colab instance is up. <br>\n",
"<font color=\"red\">Requirements:</font> For this notebook to work you need to have [Stable-Diffusion-v-1-4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original) stored in your Google Drive, it will be needed in cell #7\n", "<font color=\"red\">Requirements:</font> For this notebook to work you need to have [Stable-Diffusion-v-1-4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original) stored in your Google Drive, it will be needed in cell #7\n",
"##### For more details visit Github repository: [invoke-ai/InvokeAI](https://github.com/invoke-ai/InvokeAI)\n", "##### For more details visit Github repository: [invoke-ai/InvokeAI](https://github.com/invoke-ai/InvokeAI)\n",
"---\n" "---\n"
] ]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dr32VLxlnouf"
},
"source": [
"## ◢ Installation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "a2Z5Qu_o8VtQ"
},
"outputs": [],
"source": [
"#@title 1. Check current GPU assigned\n",
"!nvidia-smi -L\n",
"!nvidia-smi"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "vbI9ZsQHzjqF"
},
"outputs": [],
"source": [
"#@title 2. Download stable-diffusion Repository\n",
"from os.path import exists\n",
"\n",
"!git clone --quiet https://github.com/invoke-ai/InvokeAI.git # Original repo\n",
"%cd /content/InvokeAI/\n",
"!git checkout --quiet tags/v2.0.0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "QbXcGXYEFSNB"
},
"outputs": [],
"source": [
"#@title 3. Install dependencies\n",
"import gc\n",
"\n",
"!wget https://raw.githubusercontent.com/invoke-ai/InvokeAI/development/environments-and-requirements/requirements-base.txt\n",
"!wget https://raw.githubusercontent.com/invoke-ai/InvokeAI/development/environments-and-requirements/requirements-win-colab-cuda.txt\n",
"!pip install colab-xterm\n",
"!pip install -r requirements-lin-win-colab-CUDA.txt\n",
"!pip install clean-fid torchtext\n",
"!pip install transformers\n",
"gc.collect()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "8rSMhgnAttQa"
},
"outputs": [],
"source": [
"#@title 4. Restart Runtime\n",
"exit()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "ChIDWxLVHGGJ"
},
"outputs": [],
"source": [
"#@title 5. Load small ML models required\n",
"import gc\n",
"%cd /content/InvokeAI/\n",
"!python scripts/preload_models.py\n",
"gc.collect()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "795x1tMoo8b1"
},
"source": [
"## ◢ Configuration"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "YEWPV-sF1RDM"
},
"outputs": [],
"source": [
"#@title 6. Mount google Drive\n",
"from google.colab import drive\n",
"drive.mount('/content/drive')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "zRTJeZ461WGu"
},
"outputs": [],
"source": [
"#@title 7. Drive Path to model\n",
"#@markdown Path should start with /content/drive/path-to-your-file <br>\n",
"#@markdown <font color=\"red\">Note:</font> Model should be downloaded from https://huggingface.co <br>\n",
"#@markdown Lastest release: [Stable-Diffusion-v-1-4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original)\n",
"from os.path import exists\n",
"\n",
"model_path = \"\" #@param {type:\"string\"}\n",
"if exists(model_path):\n",
" print(\"✅ Valid directory\")\n",
"else: \n",
" print(\"❌ File doesn't exist\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "UY-NNz4I8_aG"
},
"outputs": [],
"source": [
"#@title 8. Symlink to model\n",
"\n",
"from os.path import exists\n",
"import os \n",
"\n",
"# Folder creation if it doesn't exist\n",
"if exists(\"/content/InvokeAI/models/ldm/stable-diffusion-v1\"):\n",
" print(\"❗ Dir stable-diffusion-v1 already exists\")\n",
"else:\n",
" %mkdir /content/InvokeAI/models/ldm/stable-diffusion-v1\n",
" print(\"✅ Dir stable-diffusion-v1 created\")\n",
"\n",
"# Symbolic link if it doesn't exist\n",
"if exists(\"/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt\"):\n",
" print(\"❗ Symlink already created\")\n",
"else: \n",
" src = model_path\n",
" dst = '/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt'\n",
" os.symlink(src, dst) \n",
" print(\"✅ Symbolic link created successfully\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Mc28N0_NrCQH"
},
"source": [
"## ◢ Execution"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "ir4hCrMIuUpl"
},
"outputs": [],
"source": [
"#@title 9. Run Terminal and Execute Dream bot\n",
"#@markdown <font color=\"blue\">Steps:</font> <br>\n",
"#@markdown 1. Execute command `python scripts/invoke.py` to run InvokeAI.<br>\n",
"#@markdown 2. After initialized you'll see `Dream>` line.<br>\n",
"#@markdown 3. Example text: `Astronaut floating in a distant galaxy` <br>\n",
"#@markdown 4. To quit Dream bot use: `q` command.<br>\n",
"\n",
"%load_ext colabxterm\n",
"%xterm\n",
"gc.collect()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "qnLohSHmKoGk"
},
"outputs": [],
"source": [
"#@title 10. Show the last 15 generated images\n",
"import glob\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"%matplotlib inline\n",
"\n",
"images = []\n",
"for img_path in sorted(glob.glob('/content/InvokeAI/outputs/img-samples/*.png'), reverse=True):\n",
" images.append(mpimg.imread(img_path))\n",
"\n",
"images = images[:15] \n",
"\n",
"plt.figure(figsize=(20,10))\n",
"\n",
"columns = 5\n",
"for i, image in enumerate(images):\n",
" ax = plt.subplot(len(images) / columns + 1, columns, i + 1)\n",
" ax.axes.xaxis.set_visible(False)\n",
" ax.axes.yaxis.set_visible(False)\n",
" ax.axis('off')\n",
" plt.imshow(image)\n",
" gc.collect()\n",
"\n"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"private_outputs": true,
"provenance": []
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3.9.12 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.9.12"
},
"vscode": {
"interpreter": {
"hash": "4e870c5c5fe42db7e2c5647ae5af656ff3391bf8c2b729cbf7fa0e16ca8cb5af"
}
}
}, },
"nbformat": 4, {
"nbformat_minor": 0 "cell_type": "markdown",
"metadata": {
"id": "dr32VLxlnouf"
},
"source": [
"## ◢ Installation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "a2Z5Qu_o8VtQ"
},
"outputs": [],
"source": [
"# @title 1. Check current GPU assigned\n",
"!nvidia-smi -L\n",
"!nvidia-smi"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "vbI9ZsQHzjqF"
},
"outputs": [],
"source": [
"# @title 2. Download stable-diffusion Repository\n",
"from os.path import exists\n",
"\n",
"!git clone --quiet https://github.com/invoke-ai/InvokeAI.git # Original repo\n",
"%cd /content/InvokeAI/\n",
"!git checkout --quiet tags/v2.0.0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "QbXcGXYEFSNB"
},
"outputs": [],
"source": [
"# @title 3. Install dependencies\n",
"import gc\n",
"\n",
"!wget https://raw.githubusercontent.com/invoke-ai/InvokeAI/development/environments-and-requirements/requirements-base.txt\n",
"!wget https://raw.githubusercontent.com/invoke-ai/InvokeAI/development/environments-and-requirements/requirements-win-colab-cuda.txt\n",
"!pip install colab-xterm\n",
"!pip install -r requirements-lin-win-colab-CUDA.txt\n",
"!pip install clean-fid torchtext\n",
"!pip install transformers\n",
"gc.collect()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "8rSMhgnAttQa"
},
"outputs": [],
"source": [
"# @title 4. Restart Runtime\n",
"exit()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "ChIDWxLVHGGJ"
},
"outputs": [],
"source": [
"# @title 5. Load small ML models required\n",
"import gc\n",
"\n",
"%cd /content/InvokeAI/\n",
"!python scripts/preload_models.py\n",
"gc.collect()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "795x1tMoo8b1"
},
"source": [
"## ◢ Configuration"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "YEWPV-sF1RDM"
},
"outputs": [],
"source": [
"# @title 6. Mount google Drive\n",
"from google.colab import drive\n",
"\n",
"drive.mount(\"/content/drive\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "zRTJeZ461WGu"
},
"outputs": [],
"source": [
"# @title 7. Drive Path to model\n",
"# @markdown Path should start with /content/drive/path-to-your-file <br>\n",
"# @markdown <font color=\"red\">Note:</font> Model should be downloaded from https://huggingface.co <br>\n",
"# @markdown Lastest release: [Stable-Diffusion-v-1-4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original)\n",
"from os.path import exists\n",
"\n",
"model_path = \"\" # @param {type:\"string\"}\n",
"if exists(model_path):\n",
" print(\"✅ Valid directory\")\n",
"else:\n",
" print(\"❌ File doesn't exist\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "UY-NNz4I8_aG"
},
"outputs": [],
"source": [
"# @title 8. Symlink to model\n",
"\n",
"from os.path import exists\n",
"import os\n",
"\n",
"# Folder creation if it doesn't exist\n",
"if exists(\"/content/InvokeAI/models/ldm/stable-diffusion-v1\"):\n",
" print(\"❗ Dir stable-diffusion-v1 already exists\")\n",
"else:\n",
" %mkdir /content/InvokeAI/models/ldm/stable-diffusion-v1\n",
" print(\"✅ Dir stable-diffusion-v1 created\")\n",
"\n",
"# Symbolic link if it doesn't exist\n",
"if exists(\"/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt\"):\n",
" print(\"❗ Symlink already created\")\n",
"else:\n",
" src = model_path\n",
" dst = \"/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt\"\n",
" os.symlink(src, dst)\n",
" print(\"✅ Symbolic link created successfully\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Mc28N0_NrCQH"
},
"source": [
"## ◢ Execution"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "ir4hCrMIuUpl"
},
"outputs": [],
"source": [
"# @title 9. Run Terminal and Execute Dream bot\n",
"# @markdown <font color=\"blue\">Steps:</font> <br>\n",
"# @markdown 1. Execute command `python scripts/invoke.py` to run InvokeAI.<br>\n",
"# @markdown 2. After initialized you'll see `Dream>` line.<br>\n",
"# @markdown 3. Example text: `Astronaut floating in a distant galaxy` <br>\n",
"# @markdown 4. To quit Dream bot use: `q` command.<br>\n",
"\n",
"%load_ext colabxterm\n",
"%xterm\n",
"gc.collect()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "qnLohSHmKoGk"
},
"outputs": [],
"source": [
"#@title 10. Show the last 15 generated images\n",
"import glob\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"%matplotlib inline\n",
"\n",
"images = []\n",
"for img_path in sorted(glob.glob('/content/InvokeAI/outputs/img-samples/*.png'), reverse=True):\n",
" images.append(mpimg.imread(img_path))\n",
"\n",
"images = images[:15] \n",
"\n",
"plt.figure(figsize=(20,10))\n",
"\n",
"columns = 5\n",
"for i, image in enumerate(images):\n",
" ax = plt.subplot(len(images) / columns + 1, columns, i + 1)\n",
" ax.axes.xaxis.set_visible(False)\n",
" ax.axes.yaxis.set_visible(False)\n",
" ax.axis('off')\n",
" plt.imshow(image)\n",
" gc.collect()\n",
"\n"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"private_outputs": true,
"provenance": []
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3.9.12 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.9.12"
},
"vscode": {
"interpreter": {
"hash": "4e870c5c5fe42db7e2c5647ae5af656ff3391bf8c2b729cbf7fa0e16ca8cb5af"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
} }

View File

@ -58,15 +58,15 @@ dependencies = [
"invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids "invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
"matplotlib", # needed for plotting of Penner easing functions "matplotlib", # needed for plotting of Penner easing functions
"mediapipe", # needed for "mediapipeface" controlnet model "mediapipe", # needed for "mediapipeface" controlnet model
"numpy",
"npyscreen", "npyscreen",
"numpy==1.24.4",
"omegaconf", "omegaconf",
"onnx", "onnx",
"opencv-python", "opencv-python",
"pydantic==1.*",
"picklescan", "picklescan",
"pillow", "pillow",
"prompt-toolkit", "prompt-toolkit",
"pydantic==1.10.10",
"pympler~=1.0.1", "pympler~=1.0.1",
"pypatchmatch", "pypatchmatch",
'pyperclip', 'pyperclip',
@ -82,7 +82,7 @@ dependencies = [
"test-tube~=0.7.5", "test-tube~=0.7.5",
"torch~=2.0.1", "torch~=2.0.1",
"torchvision~=0.15.2", "torchvision~=0.15.2",
"torchmetrics~=1.0.1", "torchmetrics~=0.11.0",
"torchsde~=0.2.5", "torchsde~=0.2.5",
"transformers~=4.31.0", "transformers~=4.31.0",
"uvicorn[standard]~=0.21.1", "uvicorn[standard]~=0.21.1",

View File

@ -52,17 +52,17 @@
"name": "stdout", "name": "stdout",
"text": [ "text": [
"Cloning into 'latent-diffusion'...\n", "Cloning into 'latent-diffusion'...\n",
"remote: Enumerating objects: 992, done.\u001B[K\n", "remote: Enumerating objects: 992, done.\u001b[K\n",
"remote: Counting objects: 100% (695/695), done.\u001B[K\n", "remote: Counting objects: 100% (695/695), done.\u001b[K\n",
"remote: Compressing objects: 100% (397/397), done.\u001B[K\n", "remote: Compressing objects: 100% (397/397), done.\u001b[K\n",
"remote: Total 992 (delta 375), reused 564 (delta 253), pack-reused 297\u001B[K\n", "remote: Total 992 (delta 375), reused 564 (delta 253), pack-reused 297\u001b[K\n",
"Receiving objects: 100% (992/992), 30.78 MiB | 29.43 MiB/s, done.\n", "Receiving objects: 100% (992/992), 30.78 MiB | 29.43 MiB/s, done.\n",
"Resolving deltas: 100% (510/510), done.\n", "Resolving deltas: 100% (510/510), done.\n",
"Cloning into 'taming-transformers'...\n", "Cloning into 'taming-transformers'...\n",
"remote: Enumerating objects: 1335, done.\u001B[K\n", "remote: Enumerating objects: 1335, done.\u001b[K\n",
"remote: Counting objects: 100% (525/525), done.\u001B[K\n", "remote: Counting objects: 100% (525/525), done.\u001b[K\n",
"remote: Compressing objects: 100% (493/493), done.\u001B[K\n", "remote: Compressing objects: 100% (493/493), done.\u001b[K\n",
"remote: Total 1335 (delta 58), reused 481 (delta 30), pack-reused 810\u001B[K\n", "remote: Total 1335 (delta 58), reused 481 (delta 30), pack-reused 810\u001b[K\n",
"Receiving objects: 100% (1335/1335), 412.35 MiB | 30.53 MiB/s, done.\n", "Receiving objects: 100% (1335/1335), 412.35 MiB | 30.53 MiB/s, done.\n",
"Resolving deltas: 100% (267/267), done.\n", "Resolving deltas: 100% (267/267), done.\n",
"Obtaining file:///content/taming-transformers\n", "Obtaining file:///content/taming-transformers\n",
@ -73,23 +73,24 @@
"Installing collected packages: taming-transformers\n", "Installing collected packages: taming-transformers\n",
" Running setup.py develop for taming-transformers\n", " Running setup.py develop for taming-transformers\n",
"Successfully installed taming-transformers-0.0.1\n", "Successfully installed taming-transformers-0.0.1\n",
"\u001B[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"tensorflow 2.8.0 requires tf-estimator-nightly==2.8.0.dev2021122109, which is not installed.\n", "tensorflow 2.8.0 requires tf-estimator-nightly==2.8.0.dev2021122109, which is not installed.\n",
"arviz 0.11.4 requires typing-extensions<4,>=3.7.4.3, but you have typing-extensions 4.1.1 which is incompatible.\u001B[0m\n" "arviz 0.11.4 requires typing-extensions<4,>=3.7.4.3, but you have typing-extensions 4.1.1 which is incompatible.\u001b[0m\n"
] ]
} }
], ],
"source": [ "source": [
"#@title Installation\n", "# @title Installation\n",
"!git clone https://github.com/CompVis/latent-diffusion.git\n", "!git clone https://github.com/CompVis/latent-diffusion.git\n",
"!git clone https://github.com/CompVis/taming-transformers\n", "!git clone https://github.com/CompVis/taming-transformers\n",
"!pip install -e ./taming-transformers\n", "!pip install -e ./taming-transformers\n",
"!pip install omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops\n", "!pip install omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops\n",
"\n", "\n",
"import sys\n", "import sys\n",
"\n",
"sys.path.append(\".\")\n", "sys.path.append(\".\")\n",
"sys.path.append('./taming-transformers')\n", "sys.path.append(\"./taming-transformers\")\n",
"from taming.models import vqgan " "from taming.models import vqgan"
] ]
}, },
{ {
@ -104,11 +105,11 @@
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "source": [
"#@title Download\n", "# @title Download\n",
"%cd latent-diffusion/ \n", "%cd latent-diffusion/\n",
"\n", "\n",
"!mkdir -p models/ldm/cin256-v2/\n", "!mkdir -p models/ldm/cin256-v2/\n",
"!wget -O models/ldm/cin256-v2/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/cin/model.ckpt " "!wget -O models/ldm/cin256-v2/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/cin/model.ckpt"
], ],
"metadata": { "metadata": {
"colab": { "colab": {
@ -203,7 +204,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "source": [
"#@title loading utils\n", "# @title loading utils\n",
"import torch\n", "import torch\n",
"from omegaconf import OmegaConf\n", "from omegaconf import OmegaConf\n",
"\n", "\n",
@ -212,7 +213,7 @@
"\n", "\n",
"def load_model_from_config(config, ckpt):\n", "def load_model_from_config(config, ckpt):\n",
" print(f\"Loading model from {ckpt}\")\n", " print(f\"Loading model from {ckpt}\")\n",
" pl_sd = torch.load(ckpt)#, map_location=\"cpu\")\n", " pl_sd = torch.load(ckpt) # , map_location=\"cpu\")\n",
" sd = pl_sd[\"state_dict\"]\n", " sd = pl_sd[\"state_dict\"]\n",
" model = instantiate_from_config(config.model)\n", " model = instantiate_from_config(config.model)\n",
" m, u = model.load_state_dict(sd, strict=False)\n", " m, u = model.load_state_dict(sd, strict=False)\n",
@ -222,7 +223,7 @@
"\n", "\n",
"\n", "\n",
"def get_model():\n", "def get_model():\n",
" config = OmegaConf.load(\"configs/latent-diffusion/cin256-v2.yaml\") \n", " config = OmegaConf.load(\"configs/latent-diffusion/cin256-v2.yaml\")\n",
" model = load_model_from_config(config, \"models/ldm/cin256-v2/model.ckpt\")\n", " model = load_model_from_config(config, \"models/ldm/cin256-v2/model.ckpt\")\n",
" return model" " return model"
], ],
@ -276,18 +277,18 @@
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "source": [
"import numpy as np \n", "import numpy as np\n",
"from PIL import Image\n", "from PIL import Image\n",
"from einops import rearrange\n", "from einops import rearrange\n",
"from torchvision.utils import make_grid\n", "from torchvision.utils import make_grid\n",
"\n", "\n",
"\n", "\n",
"classes = [25, 187, 448, 992] # define classes to be sampled here\n", "classes = [25, 187, 448, 992] # define classes to be sampled here\n",
"n_samples_per_class = 6\n", "n_samples_per_class = 6\n",
"\n", "\n",
"ddim_steps = 20\n", "ddim_steps = 20\n",
"ddim_eta = 0.0\n", "ddim_eta = 0.0\n",
"scale = 3.0 # for unconditional guidance\n", "scale = 3.0 # for unconditional guidance\n",
"\n", "\n",
"\n", "\n",
"all_samples = list()\n", "all_samples = list()\n",
@ -295,36 +296,39 @@
"with torch.no_grad():\n", "with torch.no_grad():\n",
" with model.ema_scope():\n", " with model.ema_scope():\n",
" uc = model.get_learned_conditioning(\n", " uc = model.get_learned_conditioning(\n",
" {model.cond_stage_key: torch.tensor(n_samples_per_class*[1000]).to(model.device)}\n", " {model.cond_stage_key: torch.tensor(n_samples_per_class * [1000]).to(model.device)}\n",
" )\n", " )\n",
" \n", "\n",
" for class_label in classes:\n", " for class_label in classes:\n",
" print(f\"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.\")\n", " print(\n",
" xc = torch.tensor(n_samples_per_class*[class_label])\n", " f\"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.\"\n",
" )\n",
" xc = torch.tensor(n_samples_per_class * [class_label])\n",
" c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})\n", " c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})\n",
" \n", "\n",
" samples_ddim, _ = sampler.sample(S=ddim_steps,\n", " samples_ddim, _ = sampler.sample(\n",
" conditioning=c,\n", " S=ddim_steps,\n",
" batch_size=n_samples_per_class,\n", " conditioning=c,\n",
" shape=[3, 64, 64],\n", " batch_size=n_samples_per_class,\n",
" verbose=False,\n", " shape=[3, 64, 64],\n",
" unconditional_guidance_scale=scale,\n", " verbose=False,\n",
" unconditional_conditioning=uc, \n", " unconditional_guidance_scale=scale,\n",
" eta=ddim_eta)\n", " unconditional_conditioning=uc,\n",
" eta=ddim_eta,\n",
" )\n",
"\n", "\n",
" x_samples_ddim = model.decode_first_stage(samples_ddim)\n", " x_samples_ddim = model.decode_first_stage(samples_ddim)\n",
" x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, \n", " x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)\n",
" min=0.0, max=1.0)\n",
" all_samples.append(x_samples_ddim)\n", " all_samples.append(x_samples_ddim)\n",
"\n", "\n",
"\n", "\n",
"# display as grid\n", "# display as grid\n",
"grid = torch.stack(all_samples, 0)\n", "grid = torch.stack(all_samples, 0)\n",
"grid = rearrange(grid, 'n b c h w -> (n b) c h w')\n", "grid = rearrange(grid, \"n b c h w -> (n b) c h w\")\n",
"grid = make_grid(grid, nrow=n_samples_per_class)\n", "grid = make_grid(grid, nrow=n_samples_per_class)\n",
"\n", "\n",
"# to image\n", "# to image\n",
"grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()\n", "grid = 255.0 * rearrange(grid, \"c h w -> h w c\").cpu().numpy()\n",
"Image.fromarray(grid.astype(np.uint8))" "Image.fromarray(grid.astype(np.uint8))"
], ],
"metadata": { "metadata": {