Merge branch 'main' into feat/standalone_diffusers_ti

This commit is contained in:
blessedcoolant 2023-07-17 22:01:52 +12:00 committed by GitHub
commit cbfd1d1b27
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 1163 additions and 1863 deletions

View File

@ -132,8 +132,10 @@ and go to http://localhost:9090.
### Command-Line Installation (for developers and users familiar with Terminals)
You must have Python 3.9 or 3.10 installed on your machine. Earlier or later versions are
not supported.
You must have Python 3.9 or 3.10 installed on your machine. Earlier or
later versions are not supported.
Node.js also needs to be installed along with yarn (can be installed with
the command `npm install -g yarn` if needed)
1. Open a command-line window on your machine. The PowerShell is recommended for Windows.
2. Create a directory to install InvokeAI into. You'll need at least 15 GB of free space:
@ -197,11 +199,18 @@ not supported.
7. Launch the web server (do it every time you run InvokeAI):
```terminal
invokeai --web
invokeai-web
```
8. Point your browser to http://localhost:9090 to bring up the web interface.
9. Type `banana sushi` in the box on the top left and click `Invoke`.
8. Build Node.js assets
```terminal
cd invokeai/frontend/web/
yarn vite build
```
9. Point your browser to http://localhost:9090 to bring up the web interface.
10. Type `banana sushi` in the box on the top left and click `Invoke`.
Be sure to activate the virtual environment each time before re-launching InvokeAI,
using `source .venv/bin/activate` or `.venv\Scripts\activate`.

View File

@ -11,6 +11,7 @@ from invokeai.app.services.board_images import (
)
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.resource_name import SimpleNameService
@ -20,7 +21,6 @@ from invokeai.version.invokeai_version import __version__
from ..services.default_graphs import create_system_graphs
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..services.restoration_services import RestorationServices
from ..services.graph import GraphExecutionState, LibraryGraph
from ..services.image_file_storage import DiskImageFileStorage
from ..services.invocation_queue import MemoryInvocationQueue
@ -57,8 +57,8 @@ class ApiDependencies:
invoker: Invoker = None
@staticmethod
def initialize(config, event_handler_id: int, logger: Logger = logger):
logger.debug(f'InvokeAI version {__version__}')
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
logger.debug(f"InvokeAI version {__version__}")
logger.debug(f"Internet connectivity is {config.internet_available}")
events = FastAPIEventService(event_handler_id)
@ -117,7 +117,7 @@ class ApiDependencies:
)
services = InvocationServices(
model_manager=ModelManagerService(config,logger),
model_manager=ModelManagerService(config, logger),
events=events,
latents=latents,
images=images,
@ -129,7 +129,6 @@ class ApiDependencies:
),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config, logger),
configuration=config,
logger=logger,
)

View File

@ -39,6 +39,7 @@ from .invocations.baseinvocation import BaseInvocation
import torch
import invokeai.backend.util.hotfixes
if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes

View File

@ -54,10 +54,10 @@ from .services.invocation_services import InvocationServices
from .services.invoker import Invoker
from .services.model_manager_service import ModelManagerService
from .services.processor import DefaultInvocationProcessor
from .services.restoration_services import RestorationServices
from .services.sqlite import SqliteItemStorage
import torch
import invokeai.backend.util.hotfixes
if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes
@ -295,7 +295,6 @@ def invoke_cli():
),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config,logger=logger),
logger=logger,
configuration=config,
)

View File

@ -1,6 +1,8 @@
from typing import Literal
from os.path import exists
from typing import Literal, Optional
from pydantic.fields import Field
import numpy as np
from pydantic import Field, validator
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
@ -55,3 +57,41 @@ class DynamicPromptInvocation(BaseInvocation):
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
class PromptsFromFileInvocation(BaseInvocation):
'''Loads prompts from a text file'''
# fmt: off
type: Literal['prompt_from_file'] = 'prompt_from_file'
# Inputs
file_path: str = Field(description="Path to prompt text file")
pre_prompt: Optional[str] = Field(description="String to prepend to each prompt")
post_prompt: Optional[str] = Field(description="String to append to each prompt")
start_line: int = Field(default=1, ge=1, description="Line in the file to start start from")
max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)")
#fmt: on
@validator("file_path")
def file_path_exists(cls, v):
if not exists(v):
raise ValueError(FileNotFoundError)
return v
def promptsFromFile(self, file_path: str, pre_prompt: str, post_prompt: str, start_line: int, max_prompts: int):
prompts = []
start_line -= 1
end_line = start_line + max_prompts
if max_prompts <= 0:
end_line = np.iinfo(np.int32).max
with open(file_path) as f:
for i, line in enumerate(f):
if i >= start_line and i < end_line:
prompts.append((pre_prompt or '') + line.strip() + (post_prompt or ''))
if i >= end_line:
break
return prompts
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
prompts = self.promptsFromFile(self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts)
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))

View File

@ -1,55 +0,0 @@
from typing import Literal, Optional
from pydantic import Field
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
class RestoreFaceInvocation(BaseInvocation):
"""Restores faces in an image."""
# fmt: off
type: Literal["restore_face"] = "restore_face"
# Inputs
image: Optional[ImageField] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
# fmt: on
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["restoration", "image"],
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=None,
strength=self.strength, # GFPGAN strength
save_original=False,
image_callback=None,
)
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_dto = context.services.images.create(
image=results[0][0],
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -1,48 +1,112 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal, Optional
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
from pathlib import Path, PosixPath
from typing import Literal, Union, cast
import cv2 as cv
import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from pydantic import Field
from realesrgan import RealESRGANer
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageOutput
# TODO: Populate this from disk?
# TODO: Use model manager to load?
REALESRGAN_MODELS = Literal[
"RealESRGAN_x4plus.pth",
"RealESRGAN_x4plus_anime_6B.pth",
"ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
]
class UpscaleInvocation(BaseInvocation):
"""Upscales an image."""
# fmt: off
type: Literal["upscale"] = "upscale"
class RealESRGANInvocation(BaseInvocation):
"""Upscales an image using RealESRGAN."""
# Inputs
image: Optional[ImageField] = Field(description="The input image", default=None)
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2, 4] = Field(default=2, description="The upscale level")
# fmt: on
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["upscaling", "image"],
},
}
type: Literal["realesrgan"] = "realesrgan"
image: Union[ImageField, None] = Field(default=None, description="The input image")
model_name: REALESRGAN_MODELS = Field(
default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use"
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=(self.level, self.strength),
strength=0.0, # GFPGAN strength
save_original=False,
image_callback=None,
models_path = context.services.configuration.models_path
rrdbnet_model = None
netscale = None
esrgan_model_path = None
if self.model_name in [
"RealESRGAN_x4plus.pth",
"ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
]:
# x4 RRDBNet model
rrdbnet_model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4,
)
netscale = 4
elif self.model_name in ["RealESRGAN_x4plus_anime_6B.pth"]:
# x4 RRDBNet model, 6 blocks
rrdbnet_model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=6, # 6 blocks
num_grow_ch=32,
scale=4,
)
netscale = 4
# TODO: add x2 models handling?
# elif self.model_name in ["RealESRGAN_x2plus"]:
# # x2 RRDBNet model
# model = RRDBNet(
# num_in_ch=3,
# num_out_ch=3,
# num_feat=64,
# num_block=23,
# num_grow_ch=32,
# scale=2,
# )
# model_path = Path()
# netscale = 2
else:
msg = f"Invalid RealESRGAN model: {self.model_name}"
context.services.logger.error(msg)
raise ValueError(msg)
esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
upsampler = RealESRGANer(
scale=netscale,
model_path=str(models_path / esrgan_model_path),
model=rrdbnet_model,
half=False,
)
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
# We can pass an `outscale` value here, but it just resizes the image by that factor after
# upscaling, so it's kinda pointless for our purposes. If you want something other than 4x
# upscaling, you'll need to add a resize node after this one.
upscaled_image, img_mode = upsampler.enhance(cv_image)
# back to PIL
pil_image = Image.fromarray(
cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)
).convert("RGBA")
image_dto = context.services.images.create(
image=results[0][0],
image=pil_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,

View File

@ -271,13 +271,13 @@ class InvokeAISettings(BaseSettings):
@classmethod
def _excluded(self)->List[str]:
# combination of deprecated parameters and internal ones that shouldn't be exposed
# internal fields that shouldn't be exposed as command line options
return ['type','initconf']
@classmethod
def _excluded_from_yaml(self)->List[str]:
# combination of deprecated parameters and internal ones that shouldn't be exposed
return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model']
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model', 'restore']
class Config:
env_file_encoding = 'utf-8'
@ -366,7 +366,7 @@ setting environment variables INVOKEAI_<setting>.
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
restore : bool = Field(default=True, description="Enable/disable face restoration code", category='Features')
restore : bool = Field(default=True, description="Enable/disable face restoration code (DEPRECATED)", category='DEPRECATED')
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')

View File

@ -10,10 +10,9 @@ if TYPE_CHECKING:
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.latent_storage import LatentsStorageBase
from invokeai.app.services.restoration_services import RestorationServices
from invokeai.app.services.invocation_queue import InvocationQueueABC
from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.config import InvokeAISettings
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
from invokeai.app.services.invoker import InvocationProcessorABC
@ -24,7 +23,7 @@ class InvocationServices:
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
board_images: "BoardImagesServiceABC"
boards: "BoardServiceABC"
configuration: "InvokeAISettings"
configuration: "InvokeAIAppConfig"
events: "EventServiceBase"
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
graph_library: "ItemStorageABC"["LibraryGraph"]
@ -34,13 +33,12 @@ class InvocationServices:
model_manager: "ModelManagerServiceBase"
processor: "InvocationProcessorABC"
queue: "InvocationQueueABC"
restoration: "RestorationServices"
def __init__(
self,
board_images: "BoardImagesServiceABC",
boards: "BoardServiceABC",
configuration: "InvokeAISettings",
configuration: "InvokeAIAppConfig",
events: "EventServiceBase",
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
graph_library: "ItemStorageABC"["LibraryGraph"],
@ -50,7 +48,6 @@ class InvocationServices:
model_manager: "ModelManagerServiceBase",
processor: "InvocationProcessorABC",
queue: "InvocationQueueABC",
restoration: "RestorationServices",
):
self.board_images = board_images
self.boards = boards
@ -65,4 +62,3 @@ class InvocationServices:
self.model_manager = model_manager
self.processor = processor
self.queue = queue
self.restoration = restoration

View File

@ -1,113 +0,0 @@
import sys
import traceback
import torch
from typing import types
from ...backend.restoration import Restoration
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
# This should be a real base class for postprocessing functions,
# but right now we just instantiate the existing gfpgan, esrgan
# and codeformer functions.
class RestorationServices:
'''Face restoration and upscaling'''
def __init__(self,args,logger:types.ModuleType):
try:
gfpgan, codeformer, esrgan = None, None, None
if args.restore or args.esrgan:
restoration = Restoration()
# TODO: redo for new model structure
if False and args.restore:
gfpgan, codeformer = restoration.load_face_restore_models(
args.gfpgan_model_path
)
else:
logger.info("Face restoration disabled")
if False and args.esrgan:
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
else:
logger.info("Upscaling disabled")
else:
logger.info("Face restoration and upscaling disabled")
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
self.device = torch.device(choose_torch_device())
self.gfpgan = gfpgan
self.codeformer = codeformer
self.esrgan = esrgan
self.logger = logger
self.logger.info('Face restoration initialized')
# note that this one method does gfpgan and codepath reconstruction, as well as
# esrgan upscaling
# TO DO: refactor into separate methods
def upscale_and_reconstruct(
self,
image_list,
facetool="gfpgan",
upscale=None,
upscale_denoise_str=0.75,
strength=0.0,
codeformer_fidelity=0.75,
save_original=False,
image_callback=None,
prefix=None,
):
results = []
for r in image_list:
image, seed = r
try:
if strength > 0:
if self.gfpgan is not None or self.codeformer is not None:
if facetool == "gfpgan":
if self.gfpgan is None:
self.logger.info(
"GFPGAN not found. Face restoration is disabled."
)
else:
image = self.gfpgan.process(image, strength, seed)
if facetool == "codeformer":
if self.codeformer is None:
self.logger.info(
"CodeFormer not found. Face restoration is disabled."
)
else:
cf_device = (
CPU_DEVICE if self.device == MPS_DEVICE else self.device
)
image = self.codeformer.process(
image=image,
strength=strength,
device=cf_device,
seed=seed,
fidelity=codeformer_fidelity,
)
else:
self.logger.info("Face Restoration is disabled.")
if upscale is not None:
if self.esrgan is not None:
if len(upscale) < 2:
upscale.append(0.75)
image = self.esrgan.process(
image,
upscale[1],
seed,
int(upscale[0]),
denoise_str=upscale_denoise_str,
)
else:
self.logger.info("ESRGAN is disabled. Image not upscaled.")
except Exception as e:
self.logger.info(
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
)
if image_callback is not None:
image_callback(image, seed, upscaled=True, use_prefix=prefix)
else:
r[0] = image
results.append([image, seed])
return results

View File

@ -30,8 +30,6 @@ from huggingface_hub import login as hf_hub_login
from omegaconf import OmegaConf
from tqdm import tqdm
from transformers import (
AutoProcessor,
CLIPSegForImageSegmentation,
CLIPTextModel,
CLIPTokenizer,
AutoFeatureExtractor,
@ -45,7 +43,6 @@ from invokeai.app.services.config import (
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
from invokeai.frontend.install.widgets import (
SingleSelectColumns,
CenteredButtonPress,
IntTitleSlider,
set_min_terminal_size,
@ -226,64 +223,30 @@ def download_conversion_models():
# ---------------------------------------------
def download_realesrgan():
logger.info("Installing models from RealESRGAN...")
model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth"
model_dest = config.root_path / "models/core/upscaling/realesrgan/realesr-general-x4v3.pth"
wdn_model_dest = config.root_path / "models/core/upscaling/realesrgan/realesr-general-wdn-x4v3.pth"
download_with_progress_bar(model_url, str(model_dest), "RealESRGAN")
download_with_progress_bar(wdn_model_url, str(wdn_model_dest), "RealESRGANwdn")
def download_gfpgan():
logger.info("Installing GFPGAN models...")
for model in (
[
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
"./models/core/face_restoration/gfpgan/GFPGANv1.4.pth",
],
[
"https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth",
"./models/core/face_restoration/gfpgan/weights/detection_Resnet50_Final.pth",
],
[
"https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth",
"./models/core/face_restoration/gfpgan/weights/parsing_parsenet.pth",
],
):
model_url, model_dest = model[0], config.root_path / model[1]
download_with_progress_bar(model_url, str(model_dest), "GFPGAN weights")
logger.info("Installing RealESRGAN models...")
URLs = [
dict(
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
dest = "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
description = "RealESRGAN_x4plus.pth",
),
dict(
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
dest = "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
description = "RealESRGAN_x4plus_anime_6B.pth",
),
dict(
url= "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
dest= "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
description = "ESRGAN_SRx4_DF2KOST_official.pth",
),
]
for model in URLs:
download_with_progress_bar(model['url'], config.models_path / model['dest'], model['description'])
# ---------------------------------------------
def download_codeformer():
logger.info("Installing CodeFormer model file...")
model_url = (
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
)
model_dest = config.root_path / "models/core/face_restoration/codeformer/codeformer.pth"
download_with_progress_bar(model_url, str(model_dest), "CodeFormer")
# ---------------------------------------------
def download_clipseg():
logger.info("Installing clipseg model for text-based masking...")
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
try:
hf_download_from_pretrained(AutoProcessor, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
hf_download_from_pretrained(CLIPSegForImageSegmentation, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
except Exception:
logger.info("Error installing clipseg model:")
logger.info(traceback.format_exc())
def download_support_models():
download_realesrgan()
download_gfpgan()
download_codeformer()
download_clipseg()
download_conversion_models()
# -------------------------------------
@ -858,9 +821,9 @@ def main():
download_support_models()
if opt.skip_sd_weights:
logger.info("\n** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **")
logger.warning("SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST")
elif models_to_download:
logger.info("\n** DOWNLOADING DIFFUSION WEIGHTS **")
logger.info("DOWNLOADING DIFFUSION WEIGHTS")
process_and_execute(opt, models_to_download)
postscript(errors=errors)

View File

@ -117,6 +117,7 @@ class ModelInstall(object):
# supplement with entries in models.yaml
installed_models = self.mgr.list_models()
for md in installed_models:
base = md['base_model']
model_type = md['model_type']
@ -134,6 +135,12 @@ class ModelInstall(object):
)
return {x : model_dict[x] for x in sorted(model_dict.keys(),key=lambda y: model_dict[y].name.lower())}
def list_models(self, model_type):
installed = self.mgr.list_models(model_type=model_type)
print(f'Installed models of type `{model_type}`:')
for i in installed:
print(f"{i['model_name']}\t{i['base_model']}\t{i['path']}")
def starter_models(self)->Set[str]:
models = set()
for key, value in self.datasets.items():

View File

@ -908,7 +908,6 @@ class ModelManager(object):
from invokeai.backend.install.model_install_backend import ModelInstall
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
class ScanAndImport(ModelSearch):
def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall):
super().__init__(directories, logger)

View File

@ -1,4 +0,0 @@
"""
Initialization file for the invokeai.backend.restoration package
"""
from .base import Restoration

View File

@ -1,45 +0,0 @@
import invokeai.backend.util.logging as logger
class Restoration:
def __init__(self) -> None:
pass
def load_face_restore_models(
self, gfpgan_model_path="./models/core/face_restoration/gfpgan/GFPGANv1.4.pth"
):
# Load GFPGAN
gfpgan = self.load_gfpgan(gfpgan_model_path)
if gfpgan.gfpgan_model_exists:
logger.info("GFPGAN Initialized")
else:
logger.info("GFPGAN Disabled")
gfpgan = None
# Load CodeFormer
codeformer = self.load_codeformer()
if codeformer.codeformer_model_exists:
logger.info("CodeFormer Initialized")
else:
logger.info("CodeFormer Disabled")
codeformer = None
return gfpgan, codeformer
# Face Restore Models
def load_gfpgan(self, gfpgan_model_path):
from .gfpgan import GFPGAN
return GFPGAN(gfpgan_model_path)
def load_codeformer(self):
from .codeformer import CodeFormerRestoration
return CodeFormerRestoration()
# Upscale Models
def load_esrgan(self, esrgan_bg_tile=400):
from .realesrgan import ESRGAN
esrgan = ESRGAN(esrgan_bg_tile)
logger.info("ESRGAN Initialized")
return esrgan

View File

@ -1,120 +0,0 @@
import os
import sys
import warnings
import numpy as np
import torch
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
pretrained_model_url = (
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
)
class CodeFormerRestoration:
def __init__(
self, codeformer_dir="./models/core/face_restoration/codeformer", codeformer_model_path="codeformer.pth"
) -> None:
self.globals = InvokeAIAppConfig.get_config()
codeformer_dir = self.globals.root_dir / codeformer_dir
self.model_path = codeformer_dir / codeformer_model_path
self.codeformer_model_exists = self.model_path.exists()
if not self.codeformer_model_exists:
logger.error(f"NOT FOUND: CodeFormer model not found at {self.model_path}")
sys.path.append(os.path.abspath(codeformer_dir))
def process(self, image, strength, device, seed=None, fidelity=0.75):
if seed is not None:
logger.info(f"CodeFormer - Restoring Faces for image seed:{seed}")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
from basicsr.utils import img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from PIL import Image
from torchvision.transforms.functional import normalize
from .codeformer_arch import CodeFormer
cf_class = CodeFormer
cf = cf_class(
dim_embd=512,
codebook_size=1024,
n_head=8,
n_layers=9,
connect_list=["32", "64", "128", "256"],
).to(device)
# note that this file should already be downloaded and cached at
# this point
checkpoint_path = load_file_from_url(
url=pretrained_model_url,
model_dir=os.path.abspath(os.path.dirname(self.model_path)),
progress=True,
)
checkpoint = torch.load(checkpoint_path)["params_ema"]
cf.load_state_dict(checkpoint)
cf.eval()
image = image.convert("RGB")
# Codeformer expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]
face_helper = FaceRestoreHelper(
upscale_factor=1,
use_parse=True,
device=device,
model_rootpath = self.globals.model_path / 'core/face_restoration/gfpgan/weights'
)
face_helper.clean_all()
face_helper.read_image(bgr_image_array)
face_helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5)
face_helper.align_warp_face()
for idx, cropped_face in enumerate(face_helper.cropped_faces):
cropped_face_t = img2tensor(
cropped_face / 255.0, bgr2rgb=True, float32=True
)
normalize(
cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True
)
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
try:
with torch.no_grad():
output = cf(cropped_face_t, w=fidelity, adain=True)[0]
restored_face = tensor2img(
output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)
)
del output
torch.cuda.empty_cache()
except RuntimeError as error:
logger.error(f"Failed inference for CodeFormer: {error}.")
restored_face = cropped_face
restored_face = restored_face.astype("uint8")
face_helper.add_restored_face(restored_face)
face_helper.get_inverse_affine(None)
restored_img = face_helper.paste_faces_to_input_image()
# Flip the channels back to RGB
res = Image.fromarray(restored_img[..., ::-1])
if strength < 1.0:
# Resize the image to the new image if the sizes have changed
if restored_img.size != image.size:
image = image.resize(res.size)
res = Image.blend(image, res, strength)
cf = None
return res

View File

@ -1,325 +0,0 @@
import math
from typing import List, Optional
import numpy as np
import torch
import torch.nn.functional as F
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
from torch import Tensor, nn
from .vqgan_arch import *
def calc_mean_std(feat, eps=1e-5):
"""Calculate mean and std for adaptive_instance_normalization.
Args:
feat (Tensor): 4D tensor.
eps (float): A small value added to the variance to avoid
divide-by-zero. Default: 1e-5.
"""
size = feat.size()
assert len(size) == 4, "The input feature should be 4D tensor."
b, c = size[:2]
feat_var = feat.view(b, c, -1).var(dim=2) + eps
feat_std = feat_var.sqrt().view(b, c, 1, 1)
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
return feat_mean, feat_std
def adaptive_instance_normalization(content_feat, style_feat):
"""Adaptive instance normalization.
Adjust the reference features to have the similar color and illuminations
as those in the degradate features.
Args:
content_feat (Tensor): The reference feature.
style_feat (Tensor): The degradate features.
"""
size = content_feat.size()
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(
size
)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, x, mask=None):
if mask is None:
mask = torch.zeros(
(x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
)
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
class TransformerSALayer(nn.Module):
def __init__(
self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"
):
super().__init__()
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
# Implementation of Feedforward model - MLP
self.linear1 = nn.Linear(embed_dim, dim_mlp)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_mlp, embed_dim)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward(
self,
tgt,
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
# self attention
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(
q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
)[0]
tgt = tgt + self.dropout1(tgt2)
# ffn
tgt2 = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout2(tgt2)
return tgt
class Fuse_sft_block(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.encode_enc = ResBlock(2 * in_ch, out_ch)
self.scale = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
)
self.shift = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
)
def forward(self, enc_feat, dec_feat, w=1):
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
scale = self.scale(enc_feat)
shift = self.shift(enc_feat)
residual = w * (dec_feat * scale + shift)
out = dec_feat + residual
return out
@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
def __init__(
self,
dim_embd=512,
n_head=8,
n_layers=9,
codebook_size=1024,
latent_size=256,
connect_list=["32", "64", "128", "256"],
fix_modules=["quantize", "generator"],
):
super(CodeFormer, self).__init__(
512, 64, [1, 2, 2, 4, 4, 8], "nearest", 2, [16], codebook_size
)
if fix_modules is not None:
for module in fix_modules:
for param in getattr(self, module).parameters():
param.requires_grad = False
self.connect_list = connect_list
self.n_layers = n_layers
self.dim_embd = dim_embd
self.dim_mlp = dim_embd * 2
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
self.feat_emb = nn.Linear(256, self.dim_embd)
# transformer
self.ft_layers = nn.Sequential(
*[
TransformerSALayer(
embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0
)
for _ in range(self.n_layers)
]
)
# logits_predict head
self.idx_pred_layer = nn.Sequential(
nn.LayerNorm(dim_embd), nn.Linear(dim_embd, codebook_size, bias=False)
)
self.channels = {
"16": 512,
"32": 256,
"64": 256,
"128": 128,
"256": 128,
"512": 64,
}
# after second residual block for > 16, before attn layer for ==16
self.fuse_encoder_block = {
"512": 2,
"256": 5,
"128": 8,
"64": 11,
"32": 14,
"16": 18,
}
# after first residual block for > 16, before attn layer for ==16
self.fuse_generator_block = {
"16": 6,
"32": 9,
"64": 12,
"128": 15,
"256": 18,
"512": 21,
}
# fuse_convs_dict
self.fuse_convs_dict = nn.ModuleDict()
for f_size in self.connect_list:
in_ch = self.channels[f_size]
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
# ################### Encoder #####################
enc_feat_dict = {}
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.encoder.blocks):
x = block(x)
if i in out_list:
enc_feat_dict[str(x.shape[-1])] = x.clone()
lq_feat = x
# ################# Transformer ###################
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1)
# BCHW -> BC(HW) -> (HW)BC
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
query_emb = feat_emb
# Transformer encoder
for layer in self.ft_layers:
query_emb = layer(query_emb, query_pos=pos_emb)
# output logits
logits = self.idx_pred_layer(query_emb) # (hw)bn
logits = logits.permute(1, 0, 2) # (hw)bn -> b(hw)n
if code_only: # for training stage II
# logits doesn't need softmax before cross_entropy loss
return logits, lq_feat
# ################# Quantization ###################
# if self.training:
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
# # b(hw)c -> bc(hw) -> bchw
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
# ------------
soft_one_hot = F.softmax(logits, dim=2)
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
quant_feat = self.quantize.get_codebook_feat(
top_idx, shape=[x.shape[0], 16, 16, 256]
)
# preserve gradients
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
if detach_16:
quant_feat = quant_feat.detach() # for training stage III
if adain:
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
# ################## Generator ####################
x = quant_feat
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.generator.blocks):
x = block(x)
if i in fuse_list: # fuse after i-th block
f_size = str(x.shape[-1])
if w > 0:
x = self.fuse_convs_dict[f_size](
enc_feat_dict[f_size].detach(), x, w
)
out = x
# logits doesn't need softmax before cross_entropy loss
return out, logits, lq_feat

View File

@ -1,84 +0,0 @@
import os
import sys
import warnings
import numpy as np
import torch
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
class GFPGAN:
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
self.globals = InvokeAIAppConfig.get_config()
if not os.path.isabs(gfpgan_model_path):
gfpgan_model_path = self.globals.root_dir / gfpgan_model_path
self.model_path = gfpgan_model_path
self.gfpgan_model_exists = os.path.isfile(self.model_path)
if not self.gfpgan_model_exists:
logger.error(f"NOT FOUND: GFPGAN model not found at {self.model_path}")
return None
def model_exists(self):
return os.path.isfile(self.model_path)
def process(self, image, strength: float, seed: str = None):
if seed is not None:
logger.info(f"GFPGAN - Restoring Faces for image seed:{seed}")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
cwd = os.getcwd()
os.chdir(self.globals.root_dir / 'models')
try:
from gfpgan import GFPGANer
self.gfpgan = GFPGANer(
model_path=self.model_path,
upscale=1,
arch="clean",
channel_multiplier=2,
bg_upsampler=None,
)
except Exception:
import traceback
logger.error("Error loading GFPGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
os.chdir(cwd)
if self.gfpgan is None:
logger.warning("WARNING: GFPGAN not initialized.")
logger.warning(
f"Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
)
image = image.convert("RGB")
# GFPGAN expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]
_, _, restored_img = self.gfpgan.enhance(
bgr_image_array,
has_aligned=False,
only_center_face=False,
paste_back=True,
)
# Flip the channels back to RGB
res = Image.fromarray(restored_img[..., ::-1])
if strength < 1.0:
# Resize the image to the new image if the sizes have changed
if restored_img.size != image.size:
image = image.resize(res.size)
res = Image.blend(image, res, strength)
if torch.cuda.is_available():
torch.cuda.empty_cache()
self.gfpgan = None
return res

View File

@ -1,118 +0,0 @@
import math
from PIL import Image
import invokeai.backend.util.logging as logger
class Outcrop(object):
def __init__(
self,
image,
generate, # current generate object
):
self.image = image
self.generate = generate
def process(
self,
extents: dict,
opt, # current options
orig_opt, # ones originally used to generate the image
image_callback=None,
prefix=None,
):
# grow and mask the image
extended_image = self._extend_all(extents)
# switch samplers temporarily
curr_sampler = self.generate.sampler
self.generate.sampler_name = opt.sampler_name
self.generate._set_scheduler()
def wrapped_callback(img, seed, **kwargs):
preferred_seed = (
orig_opt.seed
if orig_opt.seed is not None and orig_opt.seed >= 0
else seed
)
image_callback(img, preferred_seed, use_prefix=prefix, **kwargs)
result = self.generate.prompt2image(
opt.prompt,
seed=opt.seed or orig_opt.seed,
sampler=self.generate.sampler,
steps=opt.steps,
cfg_scale=opt.cfg_scale,
ddim_eta=self.generate.ddim_eta,
width=extended_image.width,
height=extended_image.height,
init_img=extended_image,
strength=0.90,
image_callback=wrapped_callback if image_callback else None,
seam_size=opt.seam_size or 96,
seam_blur=opt.seam_blur or 16,
seam_strength=opt.seam_strength or 0.7,
seam_steps=20,
tile_size=32,
color_match=True,
force_outpaint=True, # this just stops the warning about erased regions
)
# swap sampler back
self.generate.sampler = curr_sampler
return result
def _extend_all(
self,
extents: dict,
) -> Image:
"""
Extend the image in direction ('top','bottom','left','right') by
the indicated value. The image canvas is extended, and the empty
rectangular section will be filled with a blurred copy of the
adjacent image.
"""
image = self.image
for direction in extents:
assert direction in [
"top",
"left",
"bottom",
"right",
], 'Direction must be one of "top", "left", "bottom", "right"'
pixels = extents[direction]
# round pixels up to the nearest 64
pixels = math.ceil(pixels / 64) * 64
logger.info(f"extending image {direction}ward by {pixels} pixels")
image = self._rotate(image, direction)
image = self._extend(image, pixels)
image = self._rotate(image, direction, reverse=True)
return image
def _rotate(self, image: Image, direction: str, reverse=False) -> Image:
"""
Rotates image so that the area to extend is always at the top top.
Simplifies logic later. The reverse argument, if true, will undo the
previous transpose.
"""
transposes = {
"right": ["ROTATE_90", "ROTATE_270"],
"bottom": ["ROTATE_180", "ROTATE_180"],
"left": ["ROTATE_270", "ROTATE_90"],
}
if direction not in transposes:
return image
transpose = transposes[direction][1 if reverse else 0]
return image.transpose(Image.Transpose.__dict__[transpose])
def _extend(self, image: Image, pixels: int) -> Image:
extended_img = Image.new("RGBA", (image.width, image.height + pixels))
extended_img.paste((0, 0, 0), [0, 0, image.width, image.height + pixels])
extended_img.paste(image, box=(0, pixels))
# now make the top part transparent to use as a mask
alpha = extended_img.getchannel("A")
alpha.paste(0, (0, 0, extended_img.width, pixels))
extended_img.putalpha(alpha)
return extended_img

View File

@ -1,102 +0,0 @@
import math
import warnings
from PIL import Image, ImageFilter
class Outpaint(object):
def __init__(self, image, generate):
self.image = image
self.generate = generate
def process(self, opt, old_opt, image_callback=None, prefix=None):
image = self._create_outpaint_image(self.image, opt.out_direction)
seed = old_opt.seed
prompt = old_opt.prompt
def wrapped_callback(img, seed, **kwargs):
image_callback(img, seed, use_prefix=prefix, **kwargs)
return self.generate.prompt2image(
prompt,
seed=seed,
sampler=self.generate.sampler,
steps=opt.steps,
cfg_scale=opt.cfg_scale,
ddim_eta=self.generate.ddim_eta,
width=opt.width,
height=opt.height,
init_img=image,
strength=0.83,
image_callback=wrapped_callback,
prefix=prefix,
)
def _create_outpaint_image(self, image, direction_args):
assert len(direction_args) in [
1,
2,
], "Direction (-D) must have exactly one or two arguments."
if len(direction_args) == 1:
direction = direction_args[0]
pixels = None
elif len(direction_args) == 2:
direction = direction_args[0]
pixels = int(direction_args[1])
assert direction in [
"top",
"left",
"bottom",
"right",
], 'Direction (-D) must be one of "top", "left", "bottom", "right"'
image = image.convert("RGBA")
# we always extend top, but rotate to extend along the requested side
if direction == "left":
image = image.transpose(Image.Transpose.ROTATE_270)
elif direction == "bottom":
image = image.transpose(Image.Transpose.ROTATE_180)
elif direction == "right":
image = image.transpose(Image.Transpose.ROTATE_90)
pixels = image.height // 2 if pixels is None else int(pixels)
assert (
0 < pixels < image.height
), "Direction (-D) pixels length must be in the range 0 - image.size"
# the top part of the image is taken from the source image mirrored
# coordinates (0,0) are the upper left corner of an image
top = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).convert("RGBA")
top = top.crop((0, top.height - pixels, top.width, top.height))
# setting all alpha of the top part to 0
alpha = top.getchannel("A")
alpha.paste(0, (0, 0, top.width, top.height))
top.putalpha(alpha)
# taking the bottom from the original image
bottom = image.crop((0, 0, image.width, image.height - pixels))
new_img = image.copy()
new_img.paste(top, (0, 0))
new_img.paste(bottom, (0, pixels))
# create a 10% dither in the middle
dither = min(image.height // 10, pixels)
for x in range(0, image.width, 2):
for y in range(pixels - dither, pixels + dither):
(r, g, b, a) = new_img.getpixel((x, y))
new_img.putpixel((x, y), (r, g, b, 0))
# let's rotate back again
if direction == "left":
new_img = new_img.transpose(Image.Transpose.ROTATE_90)
elif direction == "bottom":
new_img = new_img.transpose(Image.Transpose.ROTATE_180)
elif direction == "right":
new_img = new_img.transpose(Image.Transpose.ROTATE_270)
return new_img

View File

@ -1,104 +0,0 @@
import warnings
import numpy as np
import torch
from PIL import Image
from PIL.Image import Image as ImageType
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig.get_config()
class ESRGAN:
def __init__(self, bg_tile_size=400) -> None:
self.bg_tile_size = bg_tile_size
def load_esrgan_bg_upsampler(self, denoise_str):
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
else:
use_half_precision = True
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
model = SRVGGNetCompact(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_conv=32,
upscale=4,
act_type="prelu",
)
model_path = config.models_path / "core/upscaling/realesrgan/realesr-general-x4v3.pth"
wdn_model_path = config.models_path / "core/upscaling/realesrgan/realesr-general-wdn-x4v3.pth"
scale = 4
bg_upsampler = RealESRGANer(
scale=scale,
model_path=[model_path, wdn_model_path],
model=model,
tile=self.bg_tile_size,
dni_weight=[denoise_str, 1 - denoise_str],
tile_pad=10,
pre_pad=0,
half=use_half_precision,
)
return bg_upsampler
def process(
self,
image: ImageType,
strength: float,
seed: str = None,
upsampler_scale: int = 2,
denoise_str: float = 0.75,
):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
try:
upsampler = self.load_esrgan_bg_upsampler(denoise_str)
except Exception:
import sys
import traceback
logger.error("Error loading Real-ESRGAN:")
print(traceback.format_exc(), file=sys.stderr)
if upsampler_scale == 0:
logger.warning("Real-ESRGAN: Invalid scaling option. Image not upscaled.")
return image
if seed is not None:
logger.info(
f"Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
)
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
image = image.convert("RGB")
# REALSRGAN expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]
output, _ = upsampler.enhance(
bgr_image_array,
outscale=upsampler_scale,
alpha_upsampler="realesrgan",
)
# Flip the channels back to RGB
res = Image.fromarray(output[..., ::-1])
if strength < 1.0:
# Resize the image to the new image if the sizes have changed
if output.size != image.size:
image = image.resize(res.size)
res = Image.blend(image, res, strength)
if torch.cuda.is_available():
torch.cuda.empty_cache()
upsampler = None
return res

View File

@ -1,514 +0,0 @@
"""
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
"""
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
def normalize(in_channels):
return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
@torch.jit.script
def swish(x):
return x * torch.sigmoid(x)
# Define VQVAE classes
class VectorQuantizer(nn.Module):
def __init__(self, codebook_size, emb_dim, beta):
super(VectorQuantizer, self).__init__()
self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
self.embedding.weight.data.uniform_(
-1.0 / self.codebook_size, 1.0 / self.codebook_size
)
def forward(self, z):
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.emb_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (
(z_flattened**2).sum(dim=1, keepdim=True)
+ (self.embedding.weight**2).sum(1)
- 2 * torch.matmul(z_flattened, self.embedding.weight.t())
)
mean_distance = torch.mean(d)
# find closest encodings
# min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
min_encoding_scores, min_encoding_indices = torch.topk(
d, 1, dim=1, largest=False
)
# [0-1], higher score, higher confidence
min_encoding_scores = torch.exp(-min_encoding_scores / 10)
min_encodings = torch.zeros(
min_encoding_indices.shape[0], self.codebook_size
).to(z)
min_encodings.scatter_(1, min_encoding_indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
# compute loss for embedding
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
(z_q - z.detach()) ** 2
)
# preserve gradients
z_q = z + (z_q - z).detach()
# perplexity
e_mean = torch.mean(min_encodings, dim=0)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return (
z_q,
loss,
{
"perplexity": perplexity,
"min_encodings": min_encodings,
"min_encoding_indices": min_encoding_indices,
"min_encoding_scores": min_encoding_scores,
"mean_distance": mean_distance,
},
)
def get_codebook_feat(self, indices, shape):
# input indices: batch*token_num -> (batch*token_num)*1
# shape: batch, height, width, channel
indices = indices.view(-1, 1)
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
min_encodings.scatter_(1, indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
if shape is not None: # reshape back to match original input shape
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
return z_q
class GumbelQuantizer(nn.Module):
def __init__(
self,
codebook_size,
emb_dim,
num_hiddens,
straight_through=False,
kl_weight=5e-4,
temp_init=1.0,
):
super().__init__()
self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding
self.straight_through = straight_through
self.temperature = temp_init
self.kl_weight = kl_weight
self.proj = nn.Conv2d(
num_hiddens, codebook_size, 1
) # projects last encoder layer to quantized logits
self.embed = nn.Embedding(codebook_size, emb_dim)
def forward(self, z):
hard = self.straight_through if self.training else True
logits = self.proj(z)
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
# + kl divergence to the prior loss
qy = F.softmax(logits, dim=1)
diff = (
self.kl_weight
* torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
)
min_encoding_indices = soft_one_hot.argmax(dim=1)
return z_q, diff, {"min_encoding_indices": min_encoding_indices}
class Downsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def forward(self, x):
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x):
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class ResBlock(nn.Module):
def __init__(self, in_channels, out_channels=None):
super(ResBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.norm1 = normalize(in_channels)
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
self.norm2 = normalize(out_channels)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels:
self.conv_out = nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x_in):
x = x_in
x = self.norm1(x)
x = swish(x)
x = self.conv1(x)
x = self.norm2(x)
x = swish(x)
x = self.conv2(x)
if self.in_channels != self.out_channels:
x_in = self.conv_out(x_in)
return x + x_in
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1)
k = k.reshape(b, c, h * w)
w_ = torch.bmm(q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = F.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return x + h_
class Encoder(nn.Module):
def __init__(
self,
in_channels,
nf,
emb_dim,
ch_mult,
num_res_blocks,
resolution,
attn_resolutions,
):
super().__init__()
self.nf = nf
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.attn_resolutions = attn_resolutions
curr_res = self.resolution
in_ch_mult = (1,) + tuple(ch_mult)
blocks = []
# initial convultion
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
# residual and downsampling blocks, with attention on smaller res (16x16)
for i in range(self.num_resolutions):
block_in_ch = nf * in_ch_mult[i]
block_out_ch = nf * ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
if i != self.num_resolutions - 1:
blocks.append(Downsample(block_in_ch))
curr_res = curr_res // 2
# non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch))
blocks.append(AttnBlock(block_in_ch))
blocks.append(ResBlock(block_in_ch, block_in_ch))
# normalise and convert to latent size
blocks.append(normalize(block_in_ch))
blocks.append(
nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)
)
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
class Generator(nn.Module):
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
super().__init__()
self.nf = nf
self.ch_mult = ch_mult
self.num_resolutions = len(self.ch_mult)
self.num_res_blocks = res_blocks
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.in_channels = emb_dim
self.out_channels = 3
block_in_ch = self.nf * self.ch_mult[-1]
curr_res = self.resolution // 2 ** (self.num_resolutions - 1)
blocks = []
# initial conv
blocks.append(
nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)
)
# non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch))
blocks.append(AttnBlock(block_in_ch))
blocks.append(ResBlock(block_in_ch, block_in_ch))
for i in reversed(range(self.num_resolutions)):
block_out_ch = self.nf * self.ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in self.attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
if i != 0:
blocks.append(Upsample(block_in_ch))
curr_res = curr_res * 2
blocks.append(normalize(block_in_ch))
blocks.append(
nn.Conv2d(
block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1
)
)
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
def __init__(
self,
img_size,
nf,
ch_mult,
quantizer="nearest",
res_blocks=2,
attn_resolutions=[16],
codebook_size=1024,
emb_dim=256,
beta=0.25,
gumbel_straight_through=False,
gumbel_kl_weight=1e-8,
model_path=None,
):
super().__init__()
logger = get_root_logger()
self.in_channels = 3
self.nf = nf
self.n_blocks = res_blocks
self.codebook_size = codebook_size
self.embed_dim = emb_dim
self.ch_mult = ch_mult
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.quantizer_type = quantizer
self.encoder = Encoder(
self.in_channels,
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions,
)
if self.quantizer_type == "nearest":
self.beta = beta # 0.25
self.quantize = VectorQuantizer(
self.codebook_size, self.embed_dim, self.beta
)
elif self.quantizer_type == "gumbel":
self.gumbel_num_hiddens = emb_dim
self.straight_through = gumbel_straight_through
self.kl_weight = gumbel_kl_weight
self.quantize = GumbelQuantizer(
self.codebook_size,
self.embed_dim,
self.gumbel_num_hiddens,
self.straight_through,
self.kl_weight,
)
self.generator = Generator(
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions,
)
if model_path is not None:
chkpt = torch.load(model_path, map_location="cpu")
if "params_ema" in chkpt:
self.load_state_dict(
torch.load(model_path, map_location="cpu")["params_ema"]
)
logger.info(f"vqgan is loaded from: {model_path} [params_ema]")
elif "params" in chkpt:
self.load_state_dict(
torch.load(model_path, map_location="cpu")["params"]
)
logger.info(f"vqgan is loaded from: {model_path} [params]")
else:
raise ValueError(f"Wrong params!")
def forward(self, x):
x = self.encoder(x)
quant, codebook_loss, quant_stats = self.quantize(x)
x = self.generator(quant)
return x, codebook_loss, quant_stats
# patch based discriminator
@ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module):
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
super().__init__()
layers = [
nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, True),
]
ndf_mult = 1
ndf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
ndf_mult_prev = ndf_mult
ndf_mult = min(2**n, 8)
layers += [
nn.Conv2d(
ndf * ndf_mult_prev,
ndf * ndf_mult,
kernel_size=4,
stride=2,
padding=1,
bias=False,
),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True),
]
ndf_mult_prev = ndf_mult
ndf_mult = min(2**n_layers, 8)
layers += [
nn.Conv2d(
ndf * ndf_mult_prev,
ndf * ndf_mult,
kernel_size=4,
stride=1,
padding=1,
bias=False,
),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True),
]
layers += [
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)
] # output 1 channel prediction map
self.main = nn.Sequential(*layers)
if model_path is not None:
chkpt = torch.load(model_path, map_location="cpu")
if "params_d" in chkpt:
self.load_state_dict(
torch.load(model_path, map_location="cpu")["params_d"]
)
elif "params" in chkpt:
self.load_state_dict(
torch.load(model_path, map_location="cpu")["params"]
)
else:
raise ValueError(f"Wrong params!")
def forward(self, x):
return self.main(x)

View File

@ -221,7 +221,7 @@ class ControlNetData:
control_mode: str = Field(default="balanced")
@dataclass(frozen=True)
@dataclass
class ConditioningData:
unconditioned_embeddings: torch.Tensor
text_embeddings: torch.Tensor
@ -507,6 +507,40 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
control_data: List[ControlNetData] = None,
**kwargs,
):
def _pad_conditioning(cond, target_len, encoder_attention_mask):
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
if cond.shape[1] < max_len:
conditioning_attention_mask = torch.cat([
conditioning_attention_mask,
torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
], dim=1)
cond = torch.cat([
cond,
torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
], dim=1)
if encoder_attention_mask is None:
encoder_attention_mask = conditioning_attention_mask
else:
encoder_attention_mask = torch.cat([
encoder_attention_mask,
conditioning_attention_mask,
])
return cond, encoder_attention_mask
encoder_attention_mask = None
if conditioning_data.unconditioned_embeddings.shape[1] != conditioning_data.text_embeddings.shape[1]:
max_len = max(conditioning_data.unconditioned_embeddings.shape[1], conditioning_data.text_embeddings.shape[1])
conditioning_data.unconditioned_embeddings, encoder_attention_mask = _pad_conditioning(
conditioning_data.unconditioned_embeddings, max_len, encoder_attention_mask
)
conditioning_data.text_embeddings, encoder_attention_mask = _pad_conditioning(
conditioning_data.text_embeddings, max_len, encoder_attention_mask
)
self._adjust_memory_efficient_attention(latents)
if run_id is None:
run_id = secrets.token_urlsafe(self.ID_LENGTH)
@ -546,6 +580,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count=len(timesteps),
additional_guidance=additional_guidance,
control_data=control_data,
encoder_attention_mask=encoder_attention_mask,
**kwargs,
)
latents = step_output.prev_sample
@ -603,6 +638,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
down_block_res_samples, mid_block_res_sample = None, None
if control_data is not None:
# TODO: rewrite to pass with conditionings
encoder_attention_mask = kwargs.get("encoder_attention_mask", None)
# control_data should be type List[ControlNetData]
# this loop covers both ControlNet (one ControlNetData in list)
# and MultiControlNet (multiple ControlNetData in list)
@ -649,6 +686,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=control_datum.image_tensor,
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
encoder_attention_mask=encoder_attention_mask,
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
return_dict=False,
)

View File

@ -241,45 +241,12 @@ class InvokeAIDiffuserComponent:
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
# fast batched path
def _pad_conditioning(cond, target_len, encoder_attention_mask):
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
if cond.shape[1] < max_len:
conditioning_attention_mask = torch.cat([
conditioning_attention_mask,
torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
], dim=1)
cond = torch.cat([
cond,
torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
], dim=1)
if encoder_attention_mask is None:
encoder_attention_mask = conditioning_attention_mask
else:
encoder_attention_mask = torch.cat([
encoder_attention_mask,
conditioning_attention_mask,
])
return cond, encoder_attention_mask
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
encoder_attention_mask = None
if unconditioning.shape[1] != conditioning.shape[1]:
max_len = max(unconditioning.shape[1], conditioning.shape[1])
unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
both_conditionings = torch.cat([unconditioning, conditioning])
both_results = self.model_forward_callback(
x_twice, sigma_twice, both_conditionings,
encoder_attention_mask=encoder_attention_mask,
**kwargs,
x_twice, sigma_twice, both_conditionings, **kwargs,
)
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
return unconditioned_next_x, conditioned_next_x

View File

@ -0,0 +1,634 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import (
CrossAttnDownBlock2D,
DownBlock2D,
UNetMidBlock2DCrossAttn,
get_down_block,
)
from diffusers.models.unet_2d_condition import UNet2DConditionModel
import diffusers
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
# Modified ControlNetModel with encoder_attention_mask argument added
class ControlNetModel(ModelMixin, ConfigMixin):
"""
A ControlNet model.
Args:
in_channels (`int`, defaults to 4):
The number of channels in the input sample.
flip_sin_to_cos (`bool`, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, defaults to 0):
The frequency shift to apply to the time embedding.
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, defaults to 2):
The number of layers per block.
downsample_padding (`int`, defaults to 1):
The padding to use for the downsampling convolution.
mid_block_scale_factor (`float`, defaults to 1):
The scale factor to use for the mid block.
act_fn (`str`, defaults to "silu"):
The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
in post-processing.
norm_eps (`float`, defaults to 1e-5):
The epsilon to use for the normalization.
cross_attention_dim (`int`, defaults to 1280):
The dimension of the cross attention features.
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
The dimension of the attention heads.
use_linear_projection (`bool`, defaults to `False`):
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
num_class_embeds (`int`, *optional*, defaults to 0):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
upcast_attention (`bool`, defaults to `False`):
resnet_time_scale_shift (`str`, defaults to `"default"`):
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
`class_embed_type="projection"`.
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
The tuple of output channel for each block in the `conditioning_embedding` layer.
global_pool_conditions (`bool`, defaults to `False`):
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 4,
conditioning_channels: int = 3,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
attention_head_dim: Union[int, Tuple[int]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
projection_class_embeddings_input_dim: Optional[int] = None,
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
global_pool_conditions: bool = False,
):
super().__init__()
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
# Check inputs
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
# input
conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
)
# time
time_embed_dim = block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
act_fn=act_fn,
)
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
elif class_embed_type == "projection":
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
)
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
# 2. it projects from an arbitrary input dimension.
#
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
else:
self.class_embedding = None
# control net conditioning embedding
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
conditioning_embedding_channels=block_out_channels[0],
block_out_channels=conditioning_embedding_out_channels,
conditioning_channels=conditioning_channels,
)
self.down_blocks = nn.ModuleList([])
self.controlnet_down_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[i],
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
self.down_blocks.append(down_block)
for _ in range(layers_per_block):
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
if not is_final_block:
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
# mid
mid_block_channel = block_out_channels[-1]
controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_mid_block = controlnet_block
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=mid_block_channel,
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
@classmethod
def from_unet(
cls,
unet: UNet2DConditionModel,
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
load_weights_from_unet: bool = True,
):
r"""
Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
Parameters:
unet (`UNet2DConditionModel`):
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
where applicable.
"""
controlnet = cls(
in_channels=unet.config.in_channels,
flip_sin_to_cos=unet.config.flip_sin_to_cos,
freq_shift=unet.config.freq_shift,
down_block_types=unet.config.down_block_types,
only_cross_attention=unet.config.only_cross_attention,
block_out_channels=unet.config.block_out_channels,
layers_per_block=unet.config.layers_per_block,
downsample_padding=unet.config.downsample_padding,
mid_block_scale_factor=unet.config.mid_block_scale_factor,
act_fn=unet.config.act_fn,
norm_num_groups=unet.config.norm_num_groups,
norm_eps=unet.config.norm_eps,
cross_attention_dim=unet.config.cross_attention_dim,
attention_head_dim=unet.config.attention_head_dim,
num_attention_heads=unet.config.num_attention_heads,
use_linear_projection=unet.config.use_linear_projection,
class_embed_type=unet.config.class_embed_type,
num_class_embeds=unet.config.num_class_embeds,
upcast_attention=unet.config.upcast_attention,
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
)
if load_weights_from_unet:
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
if controlnet.class_embedding:
controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
return controlnet
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "set_processor"):
processors[f"{name}.processor"] = module.processor
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
self.set_attn_processor(AttnProcessor())
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
sliceable_head_dims.append(module.sliceable_head_dim)
for child in module.children():
fn_recursive_retrieve_sliceable_dims(child)
# retrieve number of attention layers
for module in self.children():
fn_recursive_retrieve_sliceable_dims(module)
num_sliceable_layers = len(sliceable_head_dims)
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
# make smallest slice possible
slice_size = num_sliceable_layers * [1]
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
)
for i in range(len(slice_size)):
size = slice_size[i]
dim = sliceable_head_dims[i]
if size is not None and size > dim:
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
# Recursively walk through all the children.
# Any children which exposes the set_attention_slice method
# gets the message
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
if hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size.pop())
for child in module.children():
fn_recursive_set_attention_slice(child, slice_size)
reversed_slice_size = list(reversed(slice_size))
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
controlnet_cond: torch.FloatTensor,
conditioning_scale: float = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple]:
"""
The [`ControlNetModel`] forward method.
Args:
sample (`torch.FloatTensor`):
The noisy input tensor.
timestep (`Union[torch.Tensor, float, int]`):
The number of timesteps to denoise an input.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states.
controlnet_cond (`torch.FloatTensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
encoder_attention_mask (`torch.Tensor`):
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
which adds large negative values to the attention scores corresponding to "discard" tokens.
guess_mode (`bool`, defaults to `False`):
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
Returns:
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
returned where the first element is the sample tensor.
"""
# check channel order
channel_order = self.config.controlnet_conditioning_channel_order
if channel_order == "rgb":
# in rgb order by default
...
elif channel_order == "bgr":
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
else:
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None:
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
# 2. pre-process
sample = self.conv_in(sample)
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
sample = sample + controlnet_cond
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
if self.mid_block is not None:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)
# 5. Control net blocks
controlnet_down_block_res_samples = ()
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
down_block_res_sample = controlnet_block(down_block_res_sample)
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
down_block_res_samples = controlnet_down_block_res_samples
mid_block_res_sample = self.controlnet_mid_block(sample)
# 6. scaling
if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
scales = scales * conditioning_scale
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
else:
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * conditioning_scale
if self.config.global_pool_conditions:
down_block_res_samples = [
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
]
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
if not return_dict:
return (down_block_res_samples, mid_block_res_sample)
return ControlNetOutput(
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
)
diffusers.ControlNetModel = ControlNetModel
diffusers.models.controlnet.ControlNetModel = ControlNetModel

View File

@ -58,22 +58,29 @@ sd-1/main/waifu-diffusion:
recommended: False
sd-1/controlnet/canny:
repo_id: lllyasviel/control_v11p_sd15_canny
recommended: True
sd-1/controlnet/inpaint:
repo_id: lllyasviel/control_v11p_sd15_inpaint
sd-1/controlnet/mlsd:
repo_id: lllyasviel/control_v11p_sd15_mlsd
sd-1/controlnet/depth:
repo_id: lllyasviel/control_v11f1p_sd15_depth
recommended: True
sd-1/controlnet/normal_bae:
repo_id: lllyasviel/control_v11p_sd15_normalbae
sd-1/controlnet/seg:
repo_id: lllyasviel/control_v11p_sd15_seg
sd-1/controlnet/lineart:
repo_id: lllyasviel/control_v11p_sd15_lineart
recommended: True
sd-1/controlnet/lineart_anime:
repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime
sd-1/controlnet/openpose:
repo_id: lllyasviel/control_v11p_sd15_openpose
recommended: True
sd-1/controlnet/scribble:
repo_id: lllyasviel/control_v11p_sd15_scribble
recommended: False
sd-1/controlnet/softedge:
repo_id: lllyasviel/control_v11p_sd15_softedge
sd-1/controlnet/shuffle:
@ -84,6 +91,7 @@ sd-1/controlnet/ip2p:
repo_id: lllyasviel/control_v11e_sd15_ip2p
sd-1/embedding/EasyNegative:
path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
recommended: True
sd-1/embedding/ahx-beta-453407d:
repo_id: sd-concepts-library/ahx-beta-453407d
sd-1/lora/LowRA:

View File

@ -256,6 +256,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
widgets = dict()
model_list = [x for x in self.all_models if self.all_models[x].model_type==model_type and not x in exclude]
model_labels = [self.model_labels[x] for x in model_list]
show_recommended = len(self.installed_models)==0
if len(model_list) > 0:
max_width = max([len(x) for x in model_labels])
columns = window_width // (max_width+8) # 8 characters for "[x] " and padding
@ -280,7 +282,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
value=[
model_list.index(x)
for x in model_list
if self.all_models[x].installed
if (show_recommended and self.all_models[x].recommended) \
or self.all_models[x].installed
],
max_height=len(model_list)//columns + 1,
relx=4,
@ -672,7 +675,9 @@ def select_and_download_models(opt: Namespace):
# pass
installer = ModelInstall(config, prediction_type_helper=helper)
if opt.add or opt.delete:
if opt.list_models:
installer.list_models(opt.list_models)
elif opt.add or opt.delete:
selections = InstallSelections(
install_models = opt.add or [],
remove_models = opt.delete or []
@ -745,7 +750,7 @@ def main():
)
parser.add_argument(
"--list-models",
choices=["diffusers","loras","controlnets","tis"],
choices=[x.value for x in ModelType],
help="list installed models",
)
parser.add_argument(
@ -773,7 +778,7 @@ def main():
config.parse_args(invoke_args)
logger = InvokeAILogger().getLogger(config=config)
if not (config.conf_path / 'models.yaml').exists():
if not config.model_conf_path.exists():
logger.info(
"Your InvokeAI root directory is not set up. Calling invokeai-configure."
)

View File

@ -75,11 +75,6 @@ export type paths = {
* @description Gets a list of models
*/
get: operations["list_models"];
/**
* Import Model
* @description Add a model using its local path, repo_id, or remote URL
*/
post: operations["import_model"];
};
"/api/v1/models/{base_model}/{model_type}/{model_name}": {
/**
@ -93,13 +88,53 @@ export type paths = {
*/
patch: operations["update_model"];
};
"/api/v1/models/import": {
/**
* Import Model
* @description Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically
*/
post: operations["import_model"];
};
"/api/v1/models/add": {
/**
* Add Model
* @description Add a model using the configuration information appropriate for its type. Only local models can be added by path
*/
post: operations["add_model"];
};
"/api/v1/models/rename/{base_model}/{model_type}/{model_name}": {
/**
* Rename Model
* @description Rename a model
*/
post: operations["rename_model"];
};
"/api/v1/models/convert/{base_model}/{model_type}/{model_name}": {
/**
* Convert Model
* @description Convert a checkpoint model into a diffusers model
* @description Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none.
*/
put: operations["convert_model"];
};
"/api/v1/models/search": {
/** Search For Models */
get: operations["search_for_models"];
};
"/api/v1/models/ckpt_confs": {
/**
* List Ckpt Configs
* @description Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT.
*/
get: operations["list_ckpt_configs"];
};
"/api/v1/models/sync": {
/**
* Sync To Config
* @description Call after making changes to models.yaml, autoimport directories or models directory to synchronize
* in-memory data structures with disk data structures.
*/
get: operations["sync_to_config"];
};
"/api/v1/models/merge/{base_model}": {
/**
* Merge Models
@ -397,6 +432,11 @@ export type components = {
* @default false
*/
force?: boolean;
/**
* Merge Dest Directory
* @description Save the merged model to the designated directory (with 'merged_model_name' appended)
*/
merge_dest_directory?: string;
};
/** Body_remove_board_image */
Body_remove_board_image: {
@ -1186,7 +1226,7 @@ export type components = {
* @description The nodes in this graph
*/
nodes?: {
[key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
[key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RealESRGANInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
};
/**
* Edges
@ -3302,7 +3342,7 @@ export type components = {
/** ModelsList */
ModelsList: {
/** Models */
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[];
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[];
};
/**
* MultiplyInvocation
@ -3893,6 +3933,41 @@ export type components = {
*/
step?: number;
};
/**
* RealESRGANInvocation
* @description Upscales an image using RealESRGAN.
*/
RealESRGANInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default realesrgan
* @enum {string}
*/
type?: "realesrgan";
/**
* Image
* @description The input image
*/
image?: components["schemas"]["ImageField"];
/**
* Model Name
* @description The Real-ESRGAN model to use
* @default RealESRGAN_x4plus.pth
* @enum {string}
*/
model_name?: "RealESRGAN_x4plus.pth" | "RealESRGAN_x4plus_anime_6B.pth" | "ESRGAN_SRx4_DF2KOST_official-ff704c30.pth";
};
/**
* ResizeLatentsInvocation
* @description Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8.
@ -4452,47 +4527,6 @@ export type components = {
*/
loras: (components["schemas"]["LoraInfo"])[];
};
/**
* UpscaleInvocation
* @description Upscales an image.
*/
UpscaleInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default upscale
* @enum {string}
*/
type?: "upscale";
/**
* Image
* @description The input image
*/
image?: components["schemas"]["ImageField"];
/**
* Strength
* @description The strength
* @default 0.75
*/
strength?: number;
/**
* Level
* @description The upscale level
* @default 2
* @enum {integer}
*/
level?: 2 | 4;
};
/**
* VAEModelField
* @description Vae model field
@ -4619,18 +4653,18 @@ export type components = {
*/
image?: components["schemas"]["ImageField"];
};
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
};
responses: never;
parameters: never;
@ -4741,7 +4775,7 @@ export type operations = {
};
requestBody: {
content: {
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RealESRGANInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
};
};
responses: {
@ -4778,7 +4812,7 @@ export type operations = {
};
requestBody: {
content: {
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RealESRGANInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
};
};
responses: {
@ -4997,37 +5031,6 @@ export type operations = {
};
};
};
/**
* Import Model
* @description Add a model using its local path, repo_id, or remote URL
*/
import_model: {
requestBody: {
content: {
"application/json": components["schemas"]["Body_import_model"];
};
};
responses: {
/** @description The model imported successfully */
201: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
};
};
/** @description The model could not be found */
404: never;
/** @description There is already a model corresponding to this path or repo_id */
409: never;
/** @description Validation Error */
422: {
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
/** @description The model appeared to import successfully, but could not be found in the model manager */
424: never;
};
};
/**
* Delete Model
* @description Delete Model
@ -5044,12 +5047,6 @@ export type operations = {
};
};
responses: {
/** @description Successful Response */
200: {
content: {
"application/json": unknown;
};
};
/** @description Model deleted successfully */
204: never;
/** @description Model not found */
@ -5079,14 +5076,14 @@ export type operations = {
};
requestBody: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
};
};
responses: {
/** @description The model was updated successfully */
200: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
};
};
/** @description Bad request */
@ -5101,12 +5098,118 @@ export type operations = {
};
};
};
/**
* Import Model
* @description Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically
*/
import_model: {
requestBody: {
content: {
"application/json": components["schemas"]["Body_import_model"];
};
};
responses: {
/** @description The model imported successfully */
201: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
};
};
/** @description The model could not be found */
404: never;
/** @description There is already a model corresponding to this path or repo_id */
409: never;
/** @description Validation Error */
422: {
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
/** @description The model appeared to import successfully, but could not be found in the model manager */
424: never;
};
};
/**
* Add Model
* @description Add a model using the configuration information appropriate for its type. Only local models can be added by path
*/
add_model: {
requestBody: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
};
};
responses: {
/** @description The model added successfully */
201: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
};
};
/** @description The model could not be found */
404: never;
/** @description There is already a model corresponding to this path or repo_id */
409: never;
/** @description Validation Error */
422: {
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
/** @description The model appeared to add successfully, but could not be found in the model manager */
424: never;
};
};
/**
* Rename Model
* @description Rename a model
*/
rename_model: {
parameters: {
query?: {
/** @description new model name */
new_name?: string;
/** @description new model base */
new_base?: components["schemas"]["BaseModelType"];
};
path: {
/** @description Base model */
base_model: components["schemas"]["BaseModelType"];
/** @description The type of model */
model_type: components["schemas"]["ModelType"];
/** @description current model name */
model_name: string;
};
};
responses: {
/** @description The model was renamed successfully */
201: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
};
};
/** @description The model could not be found */
404: never;
/** @description There is already a model corresponding to the new name */
409: never;
/** @description Validation Error */
422: {
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
/**
* Convert Model
* @description Convert a checkpoint model into a diffusers model
* @description Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none.
*/
convert_model: {
parameters: {
query?: {
/** @description Save the converted model to the designated directory */
convert_dest_directory?: string;
};
path: {
/** @description Base model */
base_model: components["schemas"]["BaseModelType"];
@ -5120,7 +5223,7 @@ export type operations = {
/** @description Model converted successfully */
200: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
};
};
/** @description Bad request */
@ -5135,6 +5238,60 @@ export type operations = {
};
};
};
/** Search For Models */
search_for_models: {
parameters: {
query: {
/** @description Directory path to search for models */
search_path: string;
};
};
responses: {
/** @description Directory searched successfully */
200: {
content: {
"application/json": (string)[];
};
};
/** @description Invalid directory path */
404: never;
/** @description Validation Error */
422: {
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
/**
* List Ckpt Configs
* @description Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT.
*/
list_ckpt_configs: {
responses: {
/** @description paths retrieved successfully */
200: {
content: {
"application/json": (string)[];
};
};
};
};
/**
* Sync To Config
* @description Call after making changes to models.yaml, autoimport directories or models directory to synchronize
* in-memory data structures with disk data structures.
*/
sync_to_config: {
responses: {
/** @description synchronization successful */
201: {
content: {
"application/json": unknown;
};
};
};
};
/**
* Merge Models
* @description Convert a checkpoint model into a diffusers model
@ -5155,7 +5312,7 @@ export type operations = {
/** @description Model converted successfully */
200: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
};
};
/** @description Incompatible models */

View File

@ -1 +1 @@
__version__ = "3.0.0+b5"
__version__ = "3.0.0+b6"

View File

@ -55,7 +55,6 @@ def mock_services() -> InvocationServices:
),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(),
restoration = None, # type: ignore
configuration = None, # type: ignore
)

View File

@ -48,7 +48,6 @@ def mock_services() -> InvocationServices:
),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(),
restoration = None, # type: ignore
configuration = None, # type: ignore
)

View File

@ -1,6 +1,6 @@
from .test_nodes import ImageToImageTestInvocation, TextToImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
from invokeai.app.invocations.upscale import UpscaleInvocation
from invokeai.app.invocations.upscale import RealESRGANInvocation
from invokeai.app.invocations.image import *
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
from invokeai.app.invocations.params import ParamIntInvocation
@ -19,7 +19,7 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg
def test_connections_are_compatible():
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
from_field = "image"
to_node = UpscaleInvocation(id = "2")
to_node = RealESRGANInvocation(id = "2")
to_field = "image"
result = are_connections_compatible(from_node, from_field, to_node, to_field)
@ -29,7 +29,7 @@ def test_connections_are_compatible():
def test_connections_are_incompatible():
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
from_field = "image"
to_node = UpscaleInvocation(id = "2")
to_node = RealESRGANInvocation(id = "2")
to_field = "strength"
result = are_connections_compatible(from_node, from_field, to_node, to_field)
@ -39,7 +39,7 @@ def test_connections_are_incompatible():
def test_connections_incompatible_with_invalid_fields():
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
from_field = "invalid_field"
to_node = UpscaleInvocation(id = "2")
to_node = RealESRGANInvocation(id = "2")
to_field = "image"
# From field is invalid
@ -86,10 +86,10 @@ def test_graph_fails_to_update_node_if_type_changes():
g = Graph()
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n)
n2 = UpscaleInvocation(id = "2")
n2 = RealESRGANInvocation(id = "2")
g.add_node(n2)
nu = UpscaleInvocation(id = "1")
nu = RealESRGANInvocation(id = "1")
with pytest.raises(TypeError):
g.update_node("1", nu)
@ -98,7 +98,7 @@ def test_graph_allows_non_conflicting_id_change():
g = Graph()
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n)
n2 = UpscaleInvocation(id = "2")
n2 = RealESRGANInvocation(id = "2")
g.add_node(n2)
e1 = create_edge(n.id,"image",n2.id,"image")
g.add_edge(e1)
@ -128,7 +128,7 @@ def test_graph_fails_to_update_node_id_if_conflict():
def test_graph_adds_edge():
g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
n2 = RealESRGANInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e = create_edge(n1.id,"image",n2.id,"image")
@ -139,7 +139,7 @@ def test_graph_adds_edge():
def test_graph_fails_to_add_edge_with_cycle():
g = Graph()
n1 = UpscaleInvocation(id = "1")
n1 = RealESRGANInvocation(id = "1")
g.add_node(n1)
e = create_edge(n1.id,"image",n1.id,"image")
with pytest.raises(InvalidEdgeError):
@ -148,8 +148,8 @@ def test_graph_fails_to_add_edge_with_cycle():
def test_graph_fails_to_add_edge_with_long_cycle():
g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
n3 = UpscaleInvocation(id = "3")
n2 = RealESRGANInvocation(id = "2")
n3 = RealESRGANInvocation(id = "3")
g.add_node(n1)
g.add_node(n2)
g.add_node(n3)
@ -164,7 +164,7 @@ def test_graph_fails_to_add_edge_with_long_cycle():
def test_graph_fails_to_add_edge_with_missing_node_id():
g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
n2 = RealESRGANInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e1 = create_edge("1","image","3","image")
@ -177,8 +177,8 @@ def test_graph_fails_to_add_edge_with_missing_node_id():
def test_graph_fails_to_add_edge_when_destination_exists():
g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
n3 = UpscaleInvocation(id = "3")
n2 = RealESRGANInvocation(id = "2")
n3 = RealESRGANInvocation(id = "3")
g.add_node(n1)
g.add_node(n2)
g.add_node(n3)
@ -194,7 +194,7 @@ def test_graph_fails_to_add_edge_when_destination_exists():
def test_graph_fails_to_add_edge_with_mismatched_types():
g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
n2 = RealESRGANInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e1 = create_edge("1","image","2","strength")
@ -344,7 +344,7 @@ def test_graph_iterator_invalid_if_output_and_input_types_different():
def test_graph_validates():
g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
n2 = RealESRGANInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e1 = create_edge("1","image","2","image")
@ -377,8 +377,8 @@ def test_graph_invalid_if_subgraph_invalid():
def test_graph_invalid_if_has_cycle():
g = Graph()
n1 = UpscaleInvocation(id = "1")
n2 = UpscaleInvocation(id = "2")
n1 = RealESRGANInvocation(id = "1")
n2 = RealESRGANInvocation(id = "2")
g.nodes[n1.id] = n1
g.nodes[n2.id] = n2
e1 = create_edge("1","image","2","image")
@ -391,7 +391,7 @@ def test_graph_invalid_if_has_cycle():
def test_graph_invalid_with_invalid_connection():
g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
n2 = RealESRGANInvocation(id = "2")
g.nodes[n1.id] = n1
g.nodes[n2.id] = n2
e1 = create_edge("1","image","2","strength")
@ -503,7 +503,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
g.add_node(n1)
n2 = UpscaleInvocation(id = "2")
n2 = RealESRGANInvocation(id = "2")
g.add_node(n2)
with pytest.raises(NodeNotFoundError):
@ -512,7 +512,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
def test_graph_gets_networkx_graph():
g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
n2 = RealESRGANInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e = create_edge(n1.id,"image",n2.id,"image")
@ -529,7 +529,7 @@ def test_graph_gets_networkx_graph():
def test_graph_can_serialize():
g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
n2 = RealESRGANInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e = create_edge(n1.id,"image",n2.id,"image")
@ -541,7 +541,7 @@ def test_graph_can_serialize():
def test_graph_can_deserialize():
g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
n2 = RealESRGANInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e = create_edge(n1.id,"image",n2.id,"image")