mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into feat/standalone_diffusers_ti
This commit is contained in:
commit
cbfd1d1b27
19
README.md
19
README.md
@ -132,8 +132,10 @@ and go to http://localhost:9090.
|
|||||||
|
|
||||||
### Command-Line Installation (for developers and users familiar with Terminals)
|
### Command-Line Installation (for developers and users familiar with Terminals)
|
||||||
|
|
||||||
You must have Python 3.9 or 3.10 installed on your machine. Earlier or later versions are
|
You must have Python 3.9 or 3.10 installed on your machine. Earlier or
|
||||||
not supported.
|
later versions are not supported.
|
||||||
|
Node.js also needs to be installed along with yarn (can be installed with
|
||||||
|
the command `npm install -g yarn` if needed)
|
||||||
|
|
||||||
1. Open a command-line window on your machine. The PowerShell is recommended for Windows.
|
1. Open a command-line window on your machine. The PowerShell is recommended for Windows.
|
||||||
2. Create a directory to install InvokeAI into. You'll need at least 15 GB of free space:
|
2. Create a directory to install InvokeAI into. You'll need at least 15 GB of free space:
|
||||||
@ -197,11 +199,18 @@ not supported.
|
|||||||
7. Launch the web server (do it every time you run InvokeAI):
|
7. Launch the web server (do it every time you run InvokeAI):
|
||||||
|
|
||||||
```terminal
|
```terminal
|
||||||
invokeai --web
|
invokeai-web
|
||||||
```
|
```
|
||||||
|
|
||||||
8. Point your browser to http://localhost:9090 to bring up the web interface.
|
8. Build Node.js assets
|
||||||
9. Type `banana sushi` in the box on the top left and click `Invoke`.
|
|
||||||
|
```terminal
|
||||||
|
cd invokeai/frontend/web/
|
||||||
|
yarn vite build
|
||||||
|
```
|
||||||
|
|
||||||
|
9. Point your browser to http://localhost:9090 to bring up the web interface.
|
||||||
|
10. Type `banana sushi` in the box on the top left and click `Invoke`.
|
||||||
|
|
||||||
Be sure to activate the virtual environment each time before re-launching InvokeAI,
|
Be sure to activate the virtual environment each time before re-launching InvokeAI,
|
||||||
using `source .venv/bin/activate` or `.venv\Scripts\activate`.
|
using `source .venv/bin/activate` or `.venv\Scripts\activate`.
|
||||||
|
@ -11,6 +11,7 @@ from invokeai.app.services.board_images import (
|
|||||||
)
|
)
|
||||||
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
||||||
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||||
from invokeai.app.services.resource_name import SimpleNameService
|
from invokeai.app.services.resource_name import SimpleNameService
|
||||||
@ -20,7 +21,6 @@ from invokeai.version.invokeai_version import __version__
|
|||||||
|
|
||||||
from ..services.default_graphs import create_system_graphs
|
from ..services.default_graphs import create_system_graphs
|
||||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
from ..services.restoration_services import RestorationServices
|
|
||||||
from ..services.graph import GraphExecutionState, LibraryGraph
|
from ..services.graph import GraphExecutionState, LibraryGraph
|
||||||
from ..services.image_file_storage import DiskImageFileStorage
|
from ..services.image_file_storage import DiskImageFileStorage
|
||||||
from ..services.invocation_queue import MemoryInvocationQueue
|
from ..services.invocation_queue import MemoryInvocationQueue
|
||||||
@ -57,8 +57,8 @@ class ApiDependencies:
|
|||||||
invoker: Invoker = None
|
invoker: Invoker = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def initialize(config, event_handler_id: int, logger: Logger = logger):
|
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
|
||||||
logger.debug(f'InvokeAI version {__version__}')
|
logger.debug(f"InvokeAI version {__version__}")
|
||||||
logger.debug(f"Internet connectivity is {config.internet_available}")
|
logger.debug(f"Internet connectivity is {config.internet_available}")
|
||||||
|
|
||||||
events = FastAPIEventService(event_handler_id)
|
events = FastAPIEventService(event_handler_id)
|
||||||
@ -117,7 +117,7 @@ class ApiDependencies:
|
|||||||
)
|
)
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
model_manager=ModelManagerService(config,logger),
|
model_manager=ModelManagerService(config, logger),
|
||||||
events=events,
|
events=events,
|
||||||
latents=latents,
|
latents=latents,
|
||||||
images=images,
|
images=images,
|
||||||
@ -129,7 +129,6 @@ class ApiDependencies:
|
|||||||
),
|
),
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
restoration=RestorationServices(config, logger),
|
|
||||||
configuration=config,
|
configuration=config,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
)
|
)
|
||||||
|
@ -39,6 +39,7 @@ from .invocations.baseinvocation import BaseInvocation
|
|||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import invokeai.backend.util.hotfixes
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
import invokeai.backend.util.mps_fixes
|
import invokeai.backend.util.mps_fixes
|
||||||
|
|
||||||
|
@ -54,10 +54,10 @@ from .services.invocation_services import InvocationServices
|
|||||||
from .services.invoker import Invoker
|
from .services.invoker import Invoker
|
||||||
from .services.model_manager_service import ModelManagerService
|
from .services.model_manager_service import ModelManagerService
|
||||||
from .services.processor import DefaultInvocationProcessor
|
from .services.processor import DefaultInvocationProcessor
|
||||||
from .services.restoration_services import RestorationServices
|
|
||||||
from .services.sqlite import SqliteItemStorage
|
from .services.sqlite import SqliteItemStorage
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import invokeai.backend.util.hotfixes
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
import invokeai.backend.util.mps_fixes
|
import invokeai.backend.util.mps_fixes
|
||||||
|
|
||||||
@ -295,7 +295,6 @@ def invoke_cli():
|
|||||||
),
|
),
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
restoration=RestorationServices(config,logger=logger),
|
|
||||||
logger=logger,
|
logger=logger,
|
||||||
configuration=config,
|
configuration=config,
|
||||||
)
|
)
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
from typing import Literal
|
from os.path import exists
|
||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
from pydantic.fields import Field
|
import numpy as np
|
||||||
|
from pydantic import Field, validator
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||||
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
|
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
|
||||||
@ -55,3 +57,41 @@ class DynamicPromptInvocation(BaseInvocation):
|
|||||||
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
|
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
|
||||||
|
|
||||||
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
|
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
|
||||||
|
|
||||||
|
|
||||||
|
class PromptsFromFileInvocation(BaseInvocation):
|
||||||
|
'''Loads prompts from a text file'''
|
||||||
|
# fmt: off
|
||||||
|
type: Literal['prompt_from_file'] = 'prompt_from_file'
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
file_path: str = Field(description="Path to prompt text file")
|
||||||
|
pre_prompt: Optional[str] = Field(description="String to prepend to each prompt")
|
||||||
|
post_prompt: Optional[str] = Field(description="String to append to each prompt")
|
||||||
|
start_line: int = Field(default=1, ge=1, description="Line in the file to start start from")
|
||||||
|
max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)")
|
||||||
|
#fmt: on
|
||||||
|
|
||||||
|
@validator("file_path")
|
||||||
|
def file_path_exists(cls, v):
|
||||||
|
if not exists(v):
|
||||||
|
raise ValueError(FileNotFoundError)
|
||||||
|
return v
|
||||||
|
|
||||||
|
def promptsFromFile(self, file_path: str, pre_prompt: str, post_prompt: str, start_line: int, max_prompts: int):
|
||||||
|
prompts = []
|
||||||
|
start_line -= 1
|
||||||
|
end_line = start_line + max_prompts
|
||||||
|
if max_prompts <= 0:
|
||||||
|
end_line = np.iinfo(np.int32).max
|
||||||
|
with open(file_path) as f:
|
||||||
|
for i, line in enumerate(f):
|
||||||
|
if i >= start_line and i < end_line:
|
||||||
|
prompts.append((pre_prompt or '') + line.strip() + (post_prompt or ''))
|
||||||
|
if i >= end_line:
|
||||||
|
break
|
||||||
|
return prompts
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
|
||||||
|
prompts = self.promptsFromFile(self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts)
|
||||||
|
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
|
||||||
|
@ -1,55 +0,0 @@
|
|||||||
from typing import Literal, Optional
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
|
||||||
from .image import ImageOutput
|
|
||||||
|
|
||||||
|
|
||||||
class RestoreFaceInvocation(BaseInvocation):
|
|
||||||
"""Restores faces in an image."""
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["restore_face"] = "restore_face"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: Optional[ImageField] = Field(description="The input image")
|
|
||||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["restoration", "image"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
|
||||||
results = context.services.restoration.upscale_and_reconstruct(
|
|
||||||
image_list=[[image, 0]],
|
|
||||||
upscale=None,
|
|
||||||
strength=self.strength, # GFPGAN strength
|
|
||||||
save_original=False,
|
|
||||||
image_callback=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Results are image and seed, unwrap for now
|
|
||||||
# TODO: can this return multiple results?
|
|
||||||
image_dto = context.services.images.create(
|
|
||||||
image=results[0][0],
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ImageOutput(
|
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
|
||||||
width=image_dto.width,
|
|
||||||
height=image_dto.height,
|
|
||||||
)
|
|
@ -1,48 +1,112 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
||||||
|
from pathlib import Path, PosixPath
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Union, cast
|
||||||
|
|
||||||
|
import cv2 as cv
|
||||||
|
import numpy as np
|
||||||
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
|
from PIL import Image
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
|
||||||
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
from .image import ImageOutput
|
from .image import ImageOutput
|
||||||
|
|
||||||
|
# TODO: Populate this from disk?
|
||||||
|
# TODO: Use model manager to load?
|
||||||
|
REALESRGAN_MODELS = Literal[
|
||||||
|
"RealESRGAN_x4plus.pth",
|
||||||
|
"RealESRGAN_x4plus_anime_6B.pth",
|
||||||
|
"ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||||
|
]
|
||||||
|
|
||||||
class UpscaleInvocation(BaseInvocation):
|
|
||||||
"""Upscales an image."""
|
|
||||||
|
|
||||||
# fmt: off
|
class RealESRGANInvocation(BaseInvocation):
|
||||||
type: Literal["upscale"] = "upscale"
|
"""Upscales an image using RealESRGAN."""
|
||||||
|
|
||||||
# Inputs
|
type: Literal["realesrgan"] = "realesrgan"
|
||||||
image: Optional[ImageField] = Field(description="The input image", default=None)
|
image: Union[ImageField, None] = Field(default=None, description="The input image")
|
||||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
|
model_name: REALESRGAN_MODELS = Field(
|
||||||
level: Literal[2, 4] = Field(default=2, description="The upscale level")
|
default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use"
|
||||||
# fmt: on
|
)
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["upscaling", "image"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
results = context.services.restoration.upscale_and_reconstruct(
|
models_path = context.services.configuration.models_path
|
||||||
image_list=[[image, 0]],
|
|
||||||
upscale=(self.level, self.strength),
|
rrdbnet_model = None
|
||||||
strength=0.0, # GFPGAN strength
|
netscale = None
|
||||||
save_original=False,
|
esrgan_model_path = None
|
||||||
image_callback=None,
|
|
||||||
|
if self.model_name in [
|
||||||
|
"RealESRGAN_x4plus.pth",
|
||||||
|
"ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||||
|
]:
|
||||||
|
# x4 RRDBNet model
|
||||||
|
rrdbnet_model = RRDBNet(
|
||||||
|
num_in_ch=3,
|
||||||
|
num_out_ch=3,
|
||||||
|
num_feat=64,
|
||||||
|
num_block=23,
|
||||||
|
num_grow_ch=32,
|
||||||
|
scale=4,
|
||||||
|
)
|
||||||
|
netscale = 4
|
||||||
|
elif self.model_name in ["RealESRGAN_x4plus_anime_6B.pth"]:
|
||||||
|
# x4 RRDBNet model, 6 blocks
|
||||||
|
rrdbnet_model = RRDBNet(
|
||||||
|
num_in_ch=3,
|
||||||
|
num_out_ch=3,
|
||||||
|
num_feat=64,
|
||||||
|
num_block=6, # 6 blocks
|
||||||
|
num_grow_ch=32,
|
||||||
|
scale=4,
|
||||||
|
)
|
||||||
|
netscale = 4
|
||||||
|
# TODO: add x2 models handling?
|
||||||
|
# elif self.model_name in ["RealESRGAN_x2plus"]:
|
||||||
|
# # x2 RRDBNet model
|
||||||
|
# model = RRDBNet(
|
||||||
|
# num_in_ch=3,
|
||||||
|
# num_out_ch=3,
|
||||||
|
# num_feat=64,
|
||||||
|
# num_block=23,
|
||||||
|
# num_grow_ch=32,
|
||||||
|
# scale=2,
|
||||||
|
# )
|
||||||
|
# model_path = Path()
|
||||||
|
# netscale = 2
|
||||||
|
else:
|
||||||
|
msg = f"Invalid RealESRGAN model: {self.model_name}"
|
||||||
|
context.services.logger.error(msg)
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
|
||||||
|
|
||||||
|
upsampler = RealESRGANer(
|
||||||
|
scale=netscale,
|
||||||
|
model_path=str(models_path / esrgan_model_path),
|
||||||
|
model=rrdbnet_model,
|
||||||
|
half=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Results are image and seed, unwrap for now
|
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
|
||||||
# TODO: can this return multiple results?
|
cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
|
||||||
|
|
||||||
|
# We can pass an `outscale` value here, but it just resizes the image by that factor after
|
||||||
|
# upscaling, so it's kinda pointless for our purposes. If you want something other than 4x
|
||||||
|
# upscaling, you'll need to add a resize node after this one.
|
||||||
|
upscaled_image, img_mode = upsampler.enhance(cv_image)
|
||||||
|
|
||||||
|
# back to PIL
|
||||||
|
pil_image = Image.fromarray(
|
||||||
|
cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)
|
||||||
|
).convert("RGBA")
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=results[0][0],
|
image=pil_image,
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
|
@ -271,13 +271,13 @@ class InvokeAISettings(BaseSettings):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _excluded(self)->List[str]:
|
def _excluded(self)->List[str]:
|
||||||
# combination of deprecated parameters and internal ones that shouldn't be exposed
|
# internal fields that shouldn't be exposed as command line options
|
||||||
return ['type','initconf']
|
return ['type','initconf']
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _excluded_from_yaml(self)->List[str]:
|
def _excluded_from_yaml(self)->List[str]:
|
||||||
# combination of deprecated parameters and internal ones that shouldn't be exposed
|
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
|
||||||
return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model']
|
return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model', 'restore']
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
env_file_encoding = 'utf-8'
|
env_file_encoding = 'utf-8'
|
||||||
@ -366,7 +366,7 @@ setting environment variables INVOKEAI_<setting>.
|
|||||||
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
|
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
|
||||||
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
|
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
|
||||||
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
|
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
|
||||||
restore : bool = Field(default=True, description="Enable/disable face restoration code", category='Features')
|
restore : bool = Field(default=True, description="Enable/disable face restoration code (DEPRECATED)", category='DEPRECATED')
|
||||||
|
|
||||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
||||||
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
||||||
|
@ -10,10 +10,9 @@ if TYPE_CHECKING:
|
|||||||
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
||||||
from invokeai.app.services.events import EventServiceBase
|
from invokeai.app.services.events import EventServiceBase
|
||||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||||
from invokeai.app.services.restoration_services import RestorationServices
|
|
||||||
from invokeai.app.services.invocation_queue import InvocationQueueABC
|
from invokeai.app.services.invocation_queue import InvocationQueueABC
|
||||||
from invokeai.app.services.item_storage import ItemStorageABC
|
from invokeai.app.services.item_storage import ItemStorageABC
|
||||||
from invokeai.app.services.config import InvokeAISettings
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
|
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
|
||||||
from invokeai.app.services.invoker import InvocationProcessorABC
|
from invokeai.app.services.invoker import InvocationProcessorABC
|
||||||
|
|
||||||
@ -24,7 +23,7 @@ class InvocationServices:
|
|||||||
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
||||||
board_images: "BoardImagesServiceABC"
|
board_images: "BoardImagesServiceABC"
|
||||||
boards: "BoardServiceABC"
|
boards: "BoardServiceABC"
|
||||||
configuration: "InvokeAISettings"
|
configuration: "InvokeAIAppConfig"
|
||||||
events: "EventServiceBase"
|
events: "EventServiceBase"
|
||||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
|
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
|
||||||
graph_library: "ItemStorageABC"["LibraryGraph"]
|
graph_library: "ItemStorageABC"["LibraryGraph"]
|
||||||
@ -34,13 +33,12 @@ class InvocationServices:
|
|||||||
model_manager: "ModelManagerServiceBase"
|
model_manager: "ModelManagerServiceBase"
|
||||||
processor: "InvocationProcessorABC"
|
processor: "InvocationProcessorABC"
|
||||||
queue: "InvocationQueueABC"
|
queue: "InvocationQueueABC"
|
||||||
restoration: "RestorationServices"
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
board_images: "BoardImagesServiceABC",
|
board_images: "BoardImagesServiceABC",
|
||||||
boards: "BoardServiceABC",
|
boards: "BoardServiceABC",
|
||||||
configuration: "InvokeAISettings",
|
configuration: "InvokeAIAppConfig",
|
||||||
events: "EventServiceBase",
|
events: "EventServiceBase",
|
||||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
|
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
|
||||||
graph_library: "ItemStorageABC"["LibraryGraph"],
|
graph_library: "ItemStorageABC"["LibraryGraph"],
|
||||||
@ -50,7 +48,6 @@ class InvocationServices:
|
|||||||
model_manager: "ModelManagerServiceBase",
|
model_manager: "ModelManagerServiceBase",
|
||||||
processor: "InvocationProcessorABC",
|
processor: "InvocationProcessorABC",
|
||||||
queue: "InvocationQueueABC",
|
queue: "InvocationQueueABC",
|
||||||
restoration: "RestorationServices",
|
|
||||||
):
|
):
|
||||||
self.board_images = board_images
|
self.board_images = board_images
|
||||||
self.boards = boards
|
self.boards = boards
|
||||||
@ -65,4 +62,3 @@ class InvocationServices:
|
|||||||
self.model_manager = model_manager
|
self.model_manager = model_manager
|
||||||
self.processor = processor
|
self.processor = processor
|
||||||
self.queue = queue
|
self.queue = queue
|
||||||
self.restoration = restoration
|
|
||||||
|
@ -1,113 +0,0 @@
|
|||||||
import sys
|
|
||||||
import traceback
|
|
||||||
import torch
|
|
||||||
from typing import types
|
|
||||||
from ...backend.restoration import Restoration
|
|
||||||
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
|
||||||
|
|
||||||
# This should be a real base class for postprocessing functions,
|
|
||||||
# but right now we just instantiate the existing gfpgan, esrgan
|
|
||||||
# and codeformer functions.
|
|
||||||
class RestorationServices:
|
|
||||||
'''Face restoration and upscaling'''
|
|
||||||
|
|
||||||
def __init__(self,args,logger:types.ModuleType):
|
|
||||||
try:
|
|
||||||
gfpgan, codeformer, esrgan = None, None, None
|
|
||||||
if args.restore or args.esrgan:
|
|
||||||
restoration = Restoration()
|
|
||||||
# TODO: redo for new model structure
|
|
||||||
if False and args.restore:
|
|
||||||
gfpgan, codeformer = restoration.load_face_restore_models(
|
|
||||||
args.gfpgan_model_path
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info("Face restoration disabled")
|
|
||||||
if False and args.esrgan:
|
|
||||||
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
|
|
||||||
else:
|
|
||||||
logger.info("Upscaling disabled")
|
|
||||||
else:
|
|
||||||
logger.info("Face restoration and upscaling disabled")
|
|
||||||
except (ModuleNotFoundError, ImportError):
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
|
|
||||||
self.device = torch.device(choose_torch_device())
|
|
||||||
self.gfpgan = gfpgan
|
|
||||||
self.codeformer = codeformer
|
|
||||||
self.esrgan = esrgan
|
|
||||||
self.logger = logger
|
|
||||||
self.logger.info('Face restoration initialized')
|
|
||||||
|
|
||||||
# note that this one method does gfpgan and codepath reconstruction, as well as
|
|
||||||
# esrgan upscaling
|
|
||||||
# TO DO: refactor into separate methods
|
|
||||||
def upscale_and_reconstruct(
|
|
||||||
self,
|
|
||||||
image_list,
|
|
||||||
facetool="gfpgan",
|
|
||||||
upscale=None,
|
|
||||||
upscale_denoise_str=0.75,
|
|
||||||
strength=0.0,
|
|
||||||
codeformer_fidelity=0.75,
|
|
||||||
save_original=False,
|
|
||||||
image_callback=None,
|
|
||||||
prefix=None,
|
|
||||||
):
|
|
||||||
results = []
|
|
||||||
for r in image_list:
|
|
||||||
image, seed = r
|
|
||||||
try:
|
|
||||||
if strength > 0:
|
|
||||||
if self.gfpgan is not None or self.codeformer is not None:
|
|
||||||
if facetool == "gfpgan":
|
|
||||||
if self.gfpgan is None:
|
|
||||||
self.logger.info(
|
|
||||||
"GFPGAN not found. Face restoration is disabled."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
image = self.gfpgan.process(image, strength, seed)
|
|
||||||
if facetool == "codeformer":
|
|
||||||
if self.codeformer is None:
|
|
||||||
self.logger.info(
|
|
||||||
"CodeFormer not found. Face restoration is disabled."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
cf_device = (
|
|
||||||
CPU_DEVICE if self.device == MPS_DEVICE else self.device
|
|
||||||
)
|
|
||||||
image = self.codeformer.process(
|
|
||||||
image=image,
|
|
||||||
strength=strength,
|
|
||||||
device=cf_device,
|
|
||||||
seed=seed,
|
|
||||||
fidelity=codeformer_fidelity,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.logger.info("Face Restoration is disabled.")
|
|
||||||
if upscale is not None:
|
|
||||||
if self.esrgan is not None:
|
|
||||||
if len(upscale) < 2:
|
|
||||||
upscale.append(0.75)
|
|
||||||
image = self.esrgan.process(
|
|
||||||
image,
|
|
||||||
upscale[1],
|
|
||||||
seed,
|
|
||||||
int(upscale[0]),
|
|
||||||
denoise_str=upscale_denoise_str,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.logger.info("ESRGAN is disabled. Image not upscaled.")
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.info(
|
|
||||||
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if image_callback is not None:
|
|
||||||
image_callback(image, seed, upscaled=True, use_prefix=prefix)
|
|
||||||
else:
|
|
||||||
r[0] = image
|
|
||||||
|
|
||||||
results.append([image, seed])
|
|
||||||
|
|
||||||
return results
|
|
@ -30,8 +30,6 @@ from huggingface_hub import login as hf_hub_login
|
|||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoProcessor,
|
|
||||||
CLIPSegForImageSegmentation,
|
|
||||||
CLIPTextModel,
|
CLIPTextModel,
|
||||||
CLIPTokenizer,
|
CLIPTokenizer,
|
||||||
AutoFeatureExtractor,
|
AutoFeatureExtractor,
|
||||||
@ -45,7 +43,6 @@ from invokeai.app.services.config import (
|
|||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
||||||
from invokeai.frontend.install.widgets import (
|
from invokeai.frontend.install.widgets import (
|
||||||
SingleSelectColumns,
|
|
||||||
CenteredButtonPress,
|
CenteredButtonPress,
|
||||||
IntTitleSlider,
|
IntTitleSlider,
|
||||||
set_min_terminal_size,
|
set_min_terminal_size,
|
||||||
@ -226,64 +223,30 @@ def download_conversion_models():
|
|||||||
|
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
def download_realesrgan():
|
def download_realesrgan():
|
||||||
logger.info("Installing models from RealESRGAN...")
|
logger.info("Installing RealESRGAN models...")
|
||||||
model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
|
URLs = [
|
||||||
wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth"
|
dict(
|
||||||
|
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||||
model_dest = config.root_path / "models/core/upscaling/realesrgan/realesr-general-x4v3.pth"
|
dest = "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
|
||||||
wdn_model_dest = config.root_path / "models/core/upscaling/realesrgan/realesr-general-wdn-x4v3.pth"
|
description = "RealESRGAN_x4plus.pth",
|
||||||
|
),
|
||||||
download_with_progress_bar(model_url, str(model_dest), "RealESRGAN")
|
dict(
|
||||||
download_with_progress_bar(wdn_model_url, str(wdn_model_dest), "RealESRGANwdn")
|
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||||
|
dest = "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
|
||||||
|
description = "RealESRGAN_x4plus_anime_6B.pth",
|
||||||
def download_gfpgan():
|
),
|
||||||
logger.info("Installing GFPGAN models...")
|
dict(
|
||||||
for model in (
|
url= "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||||
[
|
dest= "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||||
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
|
description = "ESRGAN_SRx4_DF2KOST_official.pth",
|
||||||
"./models/core/face_restoration/gfpgan/GFPGANv1.4.pth",
|
),
|
||||||
],
|
]
|
||||||
[
|
for model in URLs:
|
||||||
"https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth",
|
download_with_progress_bar(model['url'], config.models_path / model['dest'], model['description'])
|
||||||
"./models/core/face_restoration/gfpgan/weights/detection_Resnet50_Final.pth",
|
|
||||||
],
|
|
||||||
[
|
|
||||||
"https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth",
|
|
||||||
"./models/core/face_restoration/gfpgan/weights/parsing_parsenet.pth",
|
|
||||||
],
|
|
||||||
):
|
|
||||||
model_url, model_dest = model[0], config.root_path / model[1]
|
|
||||||
download_with_progress_bar(model_url, str(model_dest), "GFPGAN weights")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
def download_codeformer():
|
|
||||||
logger.info("Installing CodeFormer model file...")
|
|
||||||
model_url = (
|
|
||||||
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
|
||||||
)
|
|
||||||
model_dest = config.root_path / "models/core/face_restoration/codeformer/codeformer.pth"
|
|
||||||
download_with_progress_bar(model_url, str(model_dest), "CodeFormer")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def download_clipseg():
|
|
||||||
logger.info("Installing clipseg model for text-based masking...")
|
|
||||||
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
|
||||||
try:
|
|
||||||
hf_download_from_pretrained(AutoProcessor, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
|
|
||||||
hf_download_from_pretrained(CLIPSegForImageSegmentation, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
|
|
||||||
except Exception:
|
|
||||||
logger.info("Error installing clipseg model:")
|
|
||||||
logger.info(traceback.format_exc())
|
|
||||||
|
|
||||||
|
|
||||||
def download_support_models():
|
def download_support_models():
|
||||||
download_realesrgan()
|
download_realesrgan()
|
||||||
download_gfpgan()
|
|
||||||
download_codeformer()
|
|
||||||
download_clipseg()
|
|
||||||
download_conversion_models()
|
download_conversion_models()
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
@ -858,9 +821,9 @@ def main():
|
|||||||
download_support_models()
|
download_support_models()
|
||||||
|
|
||||||
if opt.skip_sd_weights:
|
if opt.skip_sd_weights:
|
||||||
logger.info("\n** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **")
|
logger.warning("SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST")
|
||||||
elif models_to_download:
|
elif models_to_download:
|
||||||
logger.info("\n** DOWNLOADING DIFFUSION WEIGHTS **")
|
logger.info("DOWNLOADING DIFFUSION WEIGHTS")
|
||||||
process_and_execute(opt, models_to_download)
|
process_and_execute(opt, models_to_download)
|
||||||
|
|
||||||
postscript(errors=errors)
|
postscript(errors=errors)
|
||||||
|
@ -117,6 +117,7 @@ class ModelInstall(object):
|
|||||||
|
|
||||||
# supplement with entries in models.yaml
|
# supplement with entries in models.yaml
|
||||||
installed_models = self.mgr.list_models()
|
installed_models = self.mgr.list_models()
|
||||||
|
|
||||||
for md in installed_models:
|
for md in installed_models:
|
||||||
base = md['base_model']
|
base = md['base_model']
|
||||||
model_type = md['model_type']
|
model_type = md['model_type']
|
||||||
@ -134,6 +135,12 @@ class ModelInstall(object):
|
|||||||
)
|
)
|
||||||
return {x : model_dict[x] for x in sorted(model_dict.keys(),key=lambda y: model_dict[y].name.lower())}
|
return {x : model_dict[x] for x in sorted(model_dict.keys(),key=lambda y: model_dict[y].name.lower())}
|
||||||
|
|
||||||
|
def list_models(self, model_type):
|
||||||
|
installed = self.mgr.list_models(model_type=model_type)
|
||||||
|
print(f'Installed models of type `{model_type}`:')
|
||||||
|
for i in installed:
|
||||||
|
print(f"{i['model_name']}\t{i['base_model']}\t{i['path']}")
|
||||||
|
|
||||||
def starter_models(self)->Set[str]:
|
def starter_models(self)->Set[str]:
|
||||||
models = set()
|
models = set()
|
||||||
for key, value in self.datasets.items():
|
for key, value in self.datasets.items():
|
||||||
|
@ -908,7 +908,6 @@ class ModelManager(object):
|
|||||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||||
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
|
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
|
||||||
|
|
||||||
|
|
||||||
class ScanAndImport(ModelSearch):
|
class ScanAndImport(ModelSearch):
|
||||||
def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall):
|
def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall):
|
||||||
super().__init__(directories, logger)
|
super().__init__(directories, logger)
|
||||||
|
@ -1,4 +0,0 @@
|
|||||||
"""
|
|
||||||
Initialization file for the invokeai.backend.restoration package
|
|
||||||
"""
|
|
||||||
from .base import Restoration
|
|
@ -1,45 +0,0 @@
|
|||||||
import invokeai.backend.util.logging as logger
|
|
||||||
|
|
||||||
class Restoration:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def load_face_restore_models(
|
|
||||||
self, gfpgan_model_path="./models/core/face_restoration/gfpgan/GFPGANv1.4.pth"
|
|
||||||
):
|
|
||||||
# Load GFPGAN
|
|
||||||
gfpgan = self.load_gfpgan(gfpgan_model_path)
|
|
||||||
if gfpgan.gfpgan_model_exists:
|
|
||||||
logger.info("GFPGAN Initialized")
|
|
||||||
else:
|
|
||||||
logger.info("GFPGAN Disabled")
|
|
||||||
gfpgan = None
|
|
||||||
|
|
||||||
# Load CodeFormer
|
|
||||||
codeformer = self.load_codeformer()
|
|
||||||
if codeformer.codeformer_model_exists:
|
|
||||||
logger.info("CodeFormer Initialized")
|
|
||||||
else:
|
|
||||||
logger.info("CodeFormer Disabled")
|
|
||||||
codeformer = None
|
|
||||||
|
|
||||||
return gfpgan, codeformer
|
|
||||||
|
|
||||||
# Face Restore Models
|
|
||||||
def load_gfpgan(self, gfpgan_model_path):
|
|
||||||
from .gfpgan import GFPGAN
|
|
||||||
|
|
||||||
return GFPGAN(gfpgan_model_path)
|
|
||||||
|
|
||||||
def load_codeformer(self):
|
|
||||||
from .codeformer import CodeFormerRestoration
|
|
||||||
|
|
||||||
return CodeFormerRestoration()
|
|
||||||
|
|
||||||
# Upscale Models
|
|
||||||
def load_esrgan(self, esrgan_bg_tile=400):
|
|
||||||
from .realesrgan import ESRGAN
|
|
||||||
|
|
||||||
esrgan = ESRGAN(esrgan_bg_tile)
|
|
||||||
logger.info("ESRGAN Initialized")
|
|
||||||
return esrgan
|
|
@ -1,120 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
|
|
||||||
pretrained_model_url = (
|
|
||||||
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CodeFormerRestoration:
|
|
||||||
def __init__(
|
|
||||||
self, codeformer_dir="./models/core/face_restoration/codeformer", codeformer_model_path="codeformer.pth"
|
|
||||||
) -> None:
|
|
||||||
|
|
||||||
self.globals = InvokeAIAppConfig.get_config()
|
|
||||||
codeformer_dir = self.globals.root_dir / codeformer_dir
|
|
||||||
self.model_path = codeformer_dir / codeformer_model_path
|
|
||||||
self.codeformer_model_exists = self.model_path.exists()
|
|
||||||
|
|
||||||
if not self.codeformer_model_exists:
|
|
||||||
logger.error(f"NOT FOUND: CodeFormer model not found at {self.model_path}")
|
|
||||||
sys.path.append(os.path.abspath(codeformer_dir))
|
|
||||||
|
|
||||||
def process(self, image, strength, device, seed=None, fidelity=0.75):
|
|
||||||
if seed is not None:
|
|
||||||
logger.info(f"CodeFormer - Restoring Faces for image seed:{seed}")
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
|
|
||||||
from basicsr.utils import img2tensor, tensor2img
|
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
|
||||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
|
||||||
from PIL import Image
|
|
||||||
from torchvision.transforms.functional import normalize
|
|
||||||
|
|
||||||
from .codeformer_arch import CodeFormer
|
|
||||||
|
|
||||||
cf_class = CodeFormer
|
|
||||||
|
|
||||||
cf = cf_class(
|
|
||||||
dim_embd=512,
|
|
||||||
codebook_size=1024,
|
|
||||||
n_head=8,
|
|
||||||
n_layers=9,
|
|
||||||
connect_list=["32", "64", "128", "256"],
|
|
||||||
).to(device)
|
|
||||||
|
|
||||||
# note that this file should already be downloaded and cached at
|
|
||||||
# this point
|
|
||||||
checkpoint_path = load_file_from_url(
|
|
||||||
url=pretrained_model_url,
|
|
||||||
model_dir=os.path.abspath(os.path.dirname(self.model_path)),
|
|
||||||
progress=True,
|
|
||||||
)
|
|
||||||
checkpoint = torch.load(checkpoint_path)["params_ema"]
|
|
||||||
cf.load_state_dict(checkpoint)
|
|
||||||
cf.eval()
|
|
||||||
|
|
||||||
image = image.convert("RGB")
|
|
||||||
# Codeformer expects a BGR np array; make array and flip channels
|
|
||||||
bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]
|
|
||||||
|
|
||||||
face_helper = FaceRestoreHelper(
|
|
||||||
upscale_factor=1,
|
|
||||||
use_parse=True,
|
|
||||||
device=device,
|
|
||||||
model_rootpath = self.globals.model_path / 'core/face_restoration/gfpgan/weights'
|
|
||||||
)
|
|
||||||
face_helper.clean_all()
|
|
||||||
face_helper.read_image(bgr_image_array)
|
|
||||||
face_helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5)
|
|
||||||
face_helper.align_warp_face()
|
|
||||||
|
|
||||||
for idx, cropped_face in enumerate(face_helper.cropped_faces):
|
|
||||||
cropped_face_t = img2tensor(
|
|
||||||
cropped_face / 255.0, bgr2rgb=True, float32=True
|
|
||||||
)
|
|
||||||
normalize(
|
|
||||||
cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True
|
|
||||||
)
|
|
||||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with torch.no_grad():
|
|
||||||
output = cf(cropped_face_t, w=fidelity, adain=True)[0]
|
|
||||||
restored_face = tensor2img(
|
|
||||||
output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)
|
|
||||||
)
|
|
||||||
del output
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
except RuntimeError as error:
|
|
||||||
logger.error(f"Failed inference for CodeFormer: {error}.")
|
|
||||||
restored_face = cropped_face
|
|
||||||
|
|
||||||
restored_face = restored_face.astype("uint8")
|
|
||||||
face_helper.add_restored_face(restored_face)
|
|
||||||
|
|
||||||
face_helper.get_inverse_affine(None)
|
|
||||||
|
|
||||||
restored_img = face_helper.paste_faces_to_input_image()
|
|
||||||
|
|
||||||
# Flip the channels back to RGB
|
|
||||||
res = Image.fromarray(restored_img[..., ::-1])
|
|
||||||
|
|
||||||
if strength < 1.0:
|
|
||||||
# Resize the image to the new image if the sizes have changed
|
|
||||||
if restored_img.size != image.size:
|
|
||||||
image = image.resize(res.size)
|
|
||||||
res = Image.blend(image, res, strength)
|
|
||||||
|
|
||||||
cf = None
|
|
||||||
|
|
||||||
return res
|
|
@ -1,325 +0,0 @@
|
|||||||
import math
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from basicsr.utils import get_root_logger
|
|
||||||
from basicsr.utils.registry import ARCH_REGISTRY
|
|
||||||
from torch import Tensor, nn
|
|
||||||
|
|
||||||
from .vqgan_arch import *
|
|
||||||
|
|
||||||
|
|
||||||
def calc_mean_std(feat, eps=1e-5):
|
|
||||||
"""Calculate mean and std for adaptive_instance_normalization.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
feat (Tensor): 4D tensor.
|
|
||||||
eps (float): A small value added to the variance to avoid
|
|
||||||
divide-by-zero. Default: 1e-5.
|
|
||||||
"""
|
|
||||||
size = feat.size()
|
|
||||||
assert len(size) == 4, "The input feature should be 4D tensor."
|
|
||||||
b, c = size[:2]
|
|
||||||
feat_var = feat.view(b, c, -1).var(dim=2) + eps
|
|
||||||
feat_std = feat_var.sqrt().view(b, c, 1, 1)
|
|
||||||
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
|
|
||||||
return feat_mean, feat_std
|
|
||||||
|
|
||||||
|
|
||||||
def adaptive_instance_normalization(content_feat, style_feat):
|
|
||||||
"""Adaptive instance normalization.
|
|
||||||
|
|
||||||
Adjust the reference features to have the similar color and illuminations
|
|
||||||
as those in the degradate features.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content_feat (Tensor): The reference feature.
|
|
||||||
style_feat (Tensor): The degradate features.
|
|
||||||
"""
|
|
||||||
size = content_feat.size()
|
|
||||||
style_mean, style_std = calc_mean_std(style_feat)
|
|
||||||
content_mean, content_std = calc_mean_std(content_feat)
|
|
||||||
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(
|
|
||||||
size
|
|
||||||
)
|
|
||||||
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
|
||||||
|
|
||||||
|
|
||||||
class PositionEmbeddingSine(nn.Module):
|
|
||||||
"""
|
|
||||||
This is a more standard version of the position embedding, very similar to the one
|
|
||||||
used by the Attention is all you need paper, generalized to work on images.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.num_pos_feats = num_pos_feats
|
|
||||||
self.temperature = temperature
|
|
||||||
self.normalize = normalize
|
|
||||||
if scale is not None and normalize is False:
|
|
||||||
raise ValueError("normalize should be True if scale is passed")
|
|
||||||
if scale is None:
|
|
||||||
scale = 2 * math.pi
|
|
||||||
self.scale = scale
|
|
||||||
|
|
||||||
def forward(self, x, mask=None):
|
|
||||||
if mask is None:
|
|
||||||
mask = torch.zeros(
|
|
||||||
(x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
|
|
||||||
)
|
|
||||||
not_mask = ~mask
|
|
||||||
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
|
||||||
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
|
||||||
if self.normalize:
|
|
||||||
eps = 1e-6
|
|
||||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
|
||||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
|
||||||
|
|
||||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
|
||||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
|
||||||
|
|
||||||
pos_x = x_embed[:, :, :, None] / dim_t
|
|
||||||
pos_y = y_embed[:, :, :, None] / dim_t
|
|
||||||
pos_x = torch.stack(
|
|
||||||
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
|
||||||
).flatten(3)
|
|
||||||
pos_y = torch.stack(
|
|
||||||
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
|
||||||
).flatten(3)
|
|
||||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
|
||||||
return pos
|
|
||||||
|
|
||||||
|
|
||||||
def _get_activation_fn(activation):
|
|
||||||
"""Return an activation function given a string"""
|
|
||||||
if activation == "relu":
|
|
||||||
return F.relu
|
|
||||||
if activation == "gelu":
|
|
||||||
return F.gelu
|
|
||||||
if activation == "glu":
|
|
||||||
return F.glu
|
|
||||||
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerSALayer(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
|
|
||||||
# Implementation of Feedforward model - MLP
|
|
||||||
self.linear1 = nn.Linear(embed_dim, dim_mlp)
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
|
||||||
self.linear2 = nn.Linear(dim_mlp, embed_dim)
|
|
||||||
|
|
||||||
self.norm1 = nn.LayerNorm(embed_dim)
|
|
||||||
self.norm2 = nn.LayerNorm(embed_dim)
|
|
||||||
self.dropout1 = nn.Dropout(dropout)
|
|
||||||
self.dropout2 = nn.Dropout(dropout)
|
|
||||||
|
|
||||||
self.activation = _get_activation_fn(activation)
|
|
||||||
|
|
||||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
|
||||||
return tensor if pos is None else tensor + pos
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
tgt,
|
|
||||||
tgt_mask: Optional[Tensor] = None,
|
|
||||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
query_pos: Optional[Tensor] = None,
|
|
||||||
):
|
|
||||||
# self attention
|
|
||||||
tgt2 = self.norm1(tgt)
|
|
||||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
|
||||||
tgt2 = self.self_attn(
|
|
||||||
q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
|
|
||||||
)[0]
|
|
||||||
tgt = tgt + self.dropout1(tgt2)
|
|
||||||
|
|
||||||
# ffn
|
|
||||||
tgt2 = self.norm2(tgt)
|
|
||||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
|
||||||
tgt = tgt + self.dropout2(tgt2)
|
|
||||||
return tgt
|
|
||||||
|
|
||||||
|
|
||||||
class Fuse_sft_block(nn.Module):
|
|
||||||
def __init__(self, in_ch, out_ch):
|
|
||||||
super().__init__()
|
|
||||||
self.encode_enc = ResBlock(2 * in_ch, out_ch)
|
|
||||||
|
|
||||||
self.scale = nn.Sequential(
|
|
||||||
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
|
||||||
nn.LeakyReLU(0.2, True),
|
|
||||||
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.shift = nn.Sequential(
|
|
||||||
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
|
||||||
nn.LeakyReLU(0.2, True),
|
|
||||||
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, enc_feat, dec_feat, w=1):
|
|
||||||
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
|
|
||||||
scale = self.scale(enc_feat)
|
|
||||||
shift = self.shift(enc_feat)
|
|
||||||
residual = w * (dec_feat * scale + shift)
|
|
||||||
out = dec_feat + residual
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
@ARCH_REGISTRY.register()
|
|
||||||
class CodeFormer(VQAutoEncoder):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim_embd=512,
|
|
||||||
n_head=8,
|
|
||||||
n_layers=9,
|
|
||||||
codebook_size=1024,
|
|
||||||
latent_size=256,
|
|
||||||
connect_list=["32", "64", "128", "256"],
|
|
||||||
fix_modules=["quantize", "generator"],
|
|
||||||
):
|
|
||||||
super(CodeFormer, self).__init__(
|
|
||||||
512, 64, [1, 2, 2, 4, 4, 8], "nearest", 2, [16], codebook_size
|
|
||||||
)
|
|
||||||
|
|
||||||
if fix_modules is not None:
|
|
||||||
for module in fix_modules:
|
|
||||||
for param in getattr(self, module).parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
self.connect_list = connect_list
|
|
||||||
self.n_layers = n_layers
|
|
||||||
self.dim_embd = dim_embd
|
|
||||||
self.dim_mlp = dim_embd * 2
|
|
||||||
|
|
||||||
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
|
|
||||||
self.feat_emb = nn.Linear(256, self.dim_embd)
|
|
||||||
|
|
||||||
# transformer
|
|
||||||
self.ft_layers = nn.Sequential(
|
|
||||||
*[
|
|
||||||
TransformerSALayer(
|
|
||||||
embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0
|
|
||||||
)
|
|
||||||
for _ in range(self.n_layers)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# logits_predict head
|
|
||||||
self.idx_pred_layer = nn.Sequential(
|
|
||||||
nn.LayerNorm(dim_embd), nn.Linear(dim_embd, codebook_size, bias=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.channels = {
|
|
||||||
"16": 512,
|
|
||||||
"32": 256,
|
|
||||||
"64": 256,
|
|
||||||
"128": 128,
|
|
||||||
"256": 128,
|
|
||||||
"512": 64,
|
|
||||||
}
|
|
||||||
|
|
||||||
# after second residual block for > 16, before attn layer for ==16
|
|
||||||
self.fuse_encoder_block = {
|
|
||||||
"512": 2,
|
|
||||||
"256": 5,
|
|
||||||
"128": 8,
|
|
||||||
"64": 11,
|
|
||||||
"32": 14,
|
|
||||||
"16": 18,
|
|
||||||
}
|
|
||||||
# after first residual block for > 16, before attn layer for ==16
|
|
||||||
self.fuse_generator_block = {
|
|
||||||
"16": 6,
|
|
||||||
"32": 9,
|
|
||||||
"64": 12,
|
|
||||||
"128": 15,
|
|
||||||
"256": 18,
|
|
||||||
"512": 21,
|
|
||||||
}
|
|
||||||
|
|
||||||
# fuse_convs_dict
|
|
||||||
self.fuse_convs_dict = nn.ModuleDict()
|
|
||||||
for f_size in self.connect_list:
|
|
||||||
in_ch = self.channels[f_size]
|
|
||||||
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
|
||||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
|
||||||
module.weight.data.normal_(mean=0.0, std=0.02)
|
|
||||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
|
||||||
module.bias.data.zero_()
|
|
||||||
elif isinstance(module, nn.LayerNorm):
|
|
||||||
module.bias.data.zero_()
|
|
||||||
module.weight.data.fill_(1.0)
|
|
||||||
|
|
||||||
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
|
|
||||||
# ################### Encoder #####################
|
|
||||||
enc_feat_dict = {}
|
|
||||||
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
|
|
||||||
for i, block in enumerate(self.encoder.blocks):
|
|
||||||
x = block(x)
|
|
||||||
if i in out_list:
|
|
||||||
enc_feat_dict[str(x.shape[-1])] = x.clone()
|
|
||||||
|
|
||||||
lq_feat = x
|
|
||||||
# ################# Transformer ###################
|
|
||||||
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
|
|
||||||
pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1)
|
|
||||||
# BCHW -> BC(HW) -> (HW)BC
|
|
||||||
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
|
|
||||||
query_emb = feat_emb
|
|
||||||
# Transformer encoder
|
|
||||||
for layer in self.ft_layers:
|
|
||||||
query_emb = layer(query_emb, query_pos=pos_emb)
|
|
||||||
|
|
||||||
# output logits
|
|
||||||
logits = self.idx_pred_layer(query_emb) # (hw)bn
|
|
||||||
logits = logits.permute(1, 0, 2) # (hw)bn -> b(hw)n
|
|
||||||
|
|
||||||
if code_only: # for training stage II
|
|
||||||
# logits doesn't need softmax before cross_entropy loss
|
|
||||||
return logits, lq_feat
|
|
||||||
|
|
||||||
# ################# Quantization ###################
|
|
||||||
# if self.training:
|
|
||||||
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
|
|
||||||
# # b(hw)c -> bc(hw) -> bchw
|
|
||||||
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
|
|
||||||
# ------------
|
|
||||||
soft_one_hot = F.softmax(logits, dim=2)
|
|
||||||
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
|
|
||||||
quant_feat = self.quantize.get_codebook_feat(
|
|
||||||
top_idx, shape=[x.shape[0], 16, 16, 256]
|
|
||||||
)
|
|
||||||
# preserve gradients
|
|
||||||
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
|
|
||||||
|
|
||||||
if detach_16:
|
|
||||||
quant_feat = quant_feat.detach() # for training stage III
|
|
||||||
if adain:
|
|
||||||
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
|
|
||||||
|
|
||||||
# ################## Generator ####################
|
|
||||||
x = quant_feat
|
|
||||||
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
|
||||||
|
|
||||||
for i, block in enumerate(self.generator.blocks):
|
|
||||||
x = block(x)
|
|
||||||
if i in fuse_list: # fuse after i-th block
|
|
||||||
f_size = str(x.shape[-1])
|
|
||||||
if w > 0:
|
|
||||||
x = self.fuse_convs_dict[f_size](
|
|
||||||
enc_feat_dict[f_size].detach(), x, w
|
|
||||||
)
|
|
||||||
out = x
|
|
||||||
# logits doesn't need softmax before cross_entropy loss
|
|
||||||
return out, logits, lq_feat
|
|
@ -1,84 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
|
|
||||||
class GFPGAN:
|
|
||||||
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
|
|
||||||
self.globals = InvokeAIAppConfig.get_config()
|
|
||||||
if not os.path.isabs(gfpgan_model_path):
|
|
||||||
gfpgan_model_path = self.globals.root_dir / gfpgan_model_path
|
|
||||||
self.model_path = gfpgan_model_path
|
|
||||||
self.gfpgan_model_exists = os.path.isfile(self.model_path)
|
|
||||||
|
|
||||||
if not self.gfpgan_model_exists:
|
|
||||||
logger.error(f"NOT FOUND: GFPGAN model not found at {self.model_path}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def model_exists(self):
|
|
||||||
return os.path.isfile(self.model_path)
|
|
||||||
|
|
||||||
def process(self, image, strength: float, seed: str = None):
|
|
||||||
if seed is not None:
|
|
||||||
logger.info(f"GFPGAN - Restoring Faces for image seed:{seed}")
|
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
cwd = os.getcwd()
|
|
||||||
os.chdir(self.globals.root_dir / 'models')
|
|
||||||
try:
|
|
||||||
from gfpgan import GFPGANer
|
|
||||||
|
|
||||||
self.gfpgan = GFPGANer(
|
|
||||||
model_path=self.model_path,
|
|
||||||
upscale=1,
|
|
||||||
arch="clean",
|
|
||||||
channel_multiplier=2,
|
|
||||||
bg_upsampler=None,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
logger.error("Error loading GFPGAN:", file=sys.stderr)
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
os.chdir(cwd)
|
|
||||||
|
|
||||||
if self.gfpgan is None:
|
|
||||||
logger.warning("WARNING: GFPGAN not initialized.")
|
|
||||||
logger.warning(
|
|
||||||
f"Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
|
|
||||||
)
|
|
||||||
|
|
||||||
image = image.convert("RGB")
|
|
||||||
|
|
||||||
# GFPGAN expects a BGR np array; make array and flip channels
|
|
||||||
bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]
|
|
||||||
|
|
||||||
_, _, restored_img = self.gfpgan.enhance(
|
|
||||||
bgr_image_array,
|
|
||||||
has_aligned=False,
|
|
||||||
only_center_face=False,
|
|
||||||
paste_back=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Flip the channels back to RGB
|
|
||||||
res = Image.fromarray(restored_img[..., ::-1])
|
|
||||||
|
|
||||||
if strength < 1.0:
|
|
||||||
# Resize the image to the new image if the sizes have changed
|
|
||||||
if restored_img.size != image.size:
|
|
||||||
image = image.resize(res.size)
|
|
||||||
res = Image.blend(image, res, strength)
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
self.gfpgan = None
|
|
||||||
|
|
||||||
return res
|
|
@ -1,118 +0,0 @@
|
|||||||
import math
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
|
|
||||||
class Outcrop(object):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
image,
|
|
||||||
generate, # current generate object
|
|
||||||
):
|
|
||||||
self.image = image
|
|
||||||
self.generate = generate
|
|
||||||
|
|
||||||
def process(
|
|
||||||
self,
|
|
||||||
extents: dict,
|
|
||||||
opt, # current options
|
|
||||||
orig_opt, # ones originally used to generate the image
|
|
||||||
image_callback=None,
|
|
||||||
prefix=None,
|
|
||||||
):
|
|
||||||
# grow and mask the image
|
|
||||||
extended_image = self._extend_all(extents)
|
|
||||||
|
|
||||||
# switch samplers temporarily
|
|
||||||
curr_sampler = self.generate.sampler
|
|
||||||
self.generate.sampler_name = opt.sampler_name
|
|
||||||
self.generate._set_scheduler()
|
|
||||||
|
|
||||||
def wrapped_callback(img, seed, **kwargs):
|
|
||||||
preferred_seed = (
|
|
||||||
orig_opt.seed
|
|
||||||
if orig_opt.seed is not None and orig_opt.seed >= 0
|
|
||||||
else seed
|
|
||||||
)
|
|
||||||
image_callback(img, preferred_seed, use_prefix=prefix, **kwargs)
|
|
||||||
|
|
||||||
result = self.generate.prompt2image(
|
|
||||||
opt.prompt,
|
|
||||||
seed=opt.seed or orig_opt.seed,
|
|
||||||
sampler=self.generate.sampler,
|
|
||||||
steps=opt.steps,
|
|
||||||
cfg_scale=opt.cfg_scale,
|
|
||||||
ddim_eta=self.generate.ddim_eta,
|
|
||||||
width=extended_image.width,
|
|
||||||
height=extended_image.height,
|
|
||||||
init_img=extended_image,
|
|
||||||
strength=0.90,
|
|
||||||
image_callback=wrapped_callback if image_callback else None,
|
|
||||||
seam_size=opt.seam_size or 96,
|
|
||||||
seam_blur=opt.seam_blur or 16,
|
|
||||||
seam_strength=opt.seam_strength or 0.7,
|
|
||||||
seam_steps=20,
|
|
||||||
tile_size=32,
|
|
||||||
color_match=True,
|
|
||||||
force_outpaint=True, # this just stops the warning about erased regions
|
|
||||||
)
|
|
||||||
|
|
||||||
# swap sampler back
|
|
||||||
self.generate.sampler = curr_sampler
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _extend_all(
|
|
||||||
self,
|
|
||||||
extents: dict,
|
|
||||||
) -> Image:
|
|
||||||
"""
|
|
||||||
Extend the image in direction ('top','bottom','left','right') by
|
|
||||||
the indicated value. The image canvas is extended, and the empty
|
|
||||||
rectangular section will be filled with a blurred copy of the
|
|
||||||
adjacent image.
|
|
||||||
"""
|
|
||||||
image = self.image
|
|
||||||
for direction in extents:
|
|
||||||
assert direction in [
|
|
||||||
"top",
|
|
||||||
"left",
|
|
||||||
"bottom",
|
|
||||||
"right",
|
|
||||||
], 'Direction must be one of "top", "left", "bottom", "right"'
|
|
||||||
pixels = extents[direction]
|
|
||||||
# round pixels up to the nearest 64
|
|
||||||
pixels = math.ceil(pixels / 64) * 64
|
|
||||||
logger.info(f"extending image {direction}ward by {pixels} pixels")
|
|
||||||
image = self._rotate(image, direction)
|
|
||||||
image = self._extend(image, pixels)
|
|
||||||
image = self._rotate(image, direction, reverse=True)
|
|
||||||
return image
|
|
||||||
|
|
||||||
def _rotate(self, image: Image, direction: str, reverse=False) -> Image:
|
|
||||||
"""
|
|
||||||
Rotates image so that the area to extend is always at the top top.
|
|
||||||
Simplifies logic later. The reverse argument, if true, will undo the
|
|
||||||
previous transpose.
|
|
||||||
"""
|
|
||||||
transposes = {
|
|
||||||
"right": ["ROTATE_90", "ROTATE_270"],
|
|
||||||
"bottom": ["ROTATE_180", "ROTATE_180"],
|
|
||||||
"left": ["ROTATE_270", "ROTATE_90"],
|
|
||||||
}
|
|
||||||
if direction not in transposes:
|
|
||||||
return image
|
|
||||||
transpose = transposes[direction][1 if reverse else 0]
|
|
||||||
return image.transpose(Image.Transpose.__dict__[transpose])
|
|
||||||
|
|
||||||
def _extend(self, image: Image, pixels: int) -> Image:
|
|
||||||
extended_img = Image.new("RGBA", (image.width, image.height + pixels))
|
|
||||||
|
|
||||||
extended_img.paste((0, 0, 0), [0, 0, image.width, image.height + pixels])
|
|
||||||
extended_img.paste(image, box=(0, pixels))
|
|
||||||
|
|
||||||
# now make the top part transparent to use as a mask
|
|
||||||
alpha = extended_img.getchannel("A")
|
|
||||||
alpha.paste(0, (0, 0, extended_img.width, pixels))
|
|
||||||
extended_img.putalpha(alpha)
|
|
||||||
|
|
||||||
return extended_img
|
|
@ -1,102 +0,0 @@
|
|||||||
import math
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
from PIL import Image, ImageFilter
|
|
||||||
|
|
||||||
|
|
||||||
class Outpaint(object):
|
|
||||||
def __init__(self, image, generate):
|
|
||||||
self.image = image
|
|
||||||
self.generate = generate
|
|
||||||
|
|
||||||
def process(self, opt, old_opt, image_callback=None, prefix=None):
|
|
||||||
image = self._create_outpaint_image(self.image, opt.out_direction)
|
|
||||||
|
|
||||||
seed = old_opt.seed
|
|
||||||
prompt = old_opt.prompt
|
|
||||||
|
|
||||||
def wrapped_callback(img, seed, **kwargs):
|
|
||||||
image_callback(img, seed, use_prefix=prefix, **kwargs)
|
|
||||||
|
|
||||||
return self.generate.prompt2image(
|
|
||||||
prompt,
|
|
||||||
seed=seed,
|
|
||||||
sampler=self.generate.sampler,
|
|
||||||
steps=opt.steps,
|
|
||||||
cfg_scale=opt.cfg_scale,
|
|
||||||
ddim_eta=self.generate.ddim_eta,
|
|
||||||
width=opt.width,
|
|
||||||
height=opt.height,
|
|
||||||
init_img=image,
|
|
||||||
strength=0.83,
|
|
||||||
image_callback=wrapped_callback,
|
|
||||||
prefix=prefix,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_outpaint_image(self, image, direction_args):
|
|
||||||
assert len(direction_args) in [
|
|
||||||
1,
|
|
||||||
2,
|
|
||||||
], "Direction (-D) must have exactly one or two arguments."
|
|
||||||
|
|
||||||
if len(direction_args) == 1:
|
|
||||||
direction = direction_args[0]
|
|
||||||
pixels = None
|
|
||||||
elif len(direction_args) == 2:
|
|
||||||
direction = direction_args[0]
|
|
||||||
pixels = int(direction_args[1])
|
|
||||||
|
|
||||||
assert direction in [
|
|
||||||
"top",
|
|
||||||
"left",
|
|
||||||
"bottom",
|
|
||||||
"right",
|
|
||||||
], 'Direction (-D) must be one of "top", "left", "bottom", "right"'
|
|
||||||
|
|
||||||
image = image.convert("RGBA")
|
|
||||||
# we always extend top, but rotate to extend along the requested side
|
|
||||||
if direction == "left":
|
|
||||||
image = image.transpose(Image.Transpose.ROTATE_270)
|
|
||||||
elif direction == "bottom":
|
|
||||||
image = image.transpose(Image.Transpose.ROTATE_180)
|
|
||||||
elif direction == "right":
|
|
||||||
image = image.transpose(Image.Transpose.ROTATE_90)
|
|
||||||
|
|
||||||
pixels = image.height // 2 if pixels is None else int(pixels)
|
|
||||||
assert (
|
|
||||||
0 < pixels < image.height
|
|
||||||
), "Direction (-D) pixels length must be in the range 0 - image.size"
|
|
||||||
|
|
||||||
# the top part of the image is taken from the source image mirrored
|
|
||||||
# coordinates (0,0) are the upper left corner of an image
|
|
||||||
top = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).convert("RGBA")
|
|
||||||
top = top.crop((0, top.height - pixels, top.width, top.height))
|
|
||||||
|
|
||||||
# setting all alpha of the top part to 0
|
|
||||||
alpha = top.getchannel("A")
|
|
||||||
alpha.paste(0, (0, 0, top.width, top.height))
|
|
||||||
top.putalpha(alpha)
|
|
||||||
|
|
||||||
# taking the bottom from the original image
|
|
||||||
bottom = image.crop((0, 0, image.width, image.height - pixels))
|
|
||||||
|
|
||||||
new_img = image.copy()
|
|
||||||
new_img.paste(top, (0, 0))
|
|
||||||
new_img.paste(bottom, (0, pixels))
|
|
||||||
|
|
||||||
# create a 10% dither in the middle
|
|
||||||
dither = min(image.height // 10, pixels)
|
|
||||||
for x in range(0, image.width, 2):
|
|
||||||
for y in range(pixels - dither, pixels + dither):
|
|
||||||
(r, g, b, a) = new_img.getpixel((x, y))
|
|
||||||
new_img.putpixel((x, y), (r, g, b, 0))
|
|
||||||
|
|
||||||
# let's rotate back again
|
|
||||||
if direction == "left":
|
|
||||||
new_img = new_img.transpose(Image.Transpose.ROTATE_90)
|
|
||||||
elif direction == "bottom":
|
|
||||||
new_img = new_img.transpose(Image.Transpose.ROTATE_180)
|
|
||||||
elif direction == "right":
|
|
||||||
new_img = new_img.transpose(Image.Transpose.ROTATE_270)
|
|
||||||
|
|
||||||
return new_img
|
|
@ -1,104 +0,0 @@
|
|||||||
import warnings
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
from PIL.Image import Image as ImageType
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
|
||||||
|
|
||||||
class ESRGAN:
|
|
||||||
def __init__(self, bg_tile_size=400) -> None:
|
|
||||||
self.bg_tile_size = bg_tile_size
|
|
||||||
|
|
||||||
def load_esrgan_bg_upsampler(self, denoise_str):
|
|
||||||
if not torch.cuda.is_available(): # CPU or MPS on M1
|
|
||||||
use_half_precision = False
|
|
||||||
else:
|
|
||||||
use_half_precision = True
|
|
||||||
|
|
||||||
from realesrgan import RealESRGANer
|
|
||||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
|
||||||
|
|
||||||
model = SRVGGNetCompact(
|
|
||||||
num_in_ch=3,
|
|
||||||
num_out_ch=3,
|
|
||||||
num_feat=64,
|
|
||||||
num_conv=32,
|
|
||||||
upscale=4,
|
|
||||||
act_type="prelu",
|
|
||||||
)
|
|
||||||
model_path = config.models_path / "core/upscaling/realesrgan/realesr-general-x4v3.pth"
|
|
||||||
wdn_model_path = config.models_path / "core/upscaling/realesrgan/realesr-general-wdn-x4v3.pth"
|
|
||||||
scale = 4
|
|
||||||
|
|
||||||
bg_upsampler = RealESRGANer(
|
|
||||||
scale=scale,
|
|
||||||
model_path=[model_path, wdn_model_path],
|
|
||||||
model=model,
|
|
||||||
tile=self.bg_tile_size,
|
|
||||||
dni_weight=[denoise_str, 1 - denoise_str],
|
|
||||||
tile_pad=10,
|
|
||||||
pre_pad=0,
|
|
||||||
half=use_half_precision,
|
|
||||||
)
|
|
||||||
|
|
||||||
return bg_upsampler
|
|
||||||
|
|
||||||
def process(
|
|
||||||
self,
|
|
||||||
image: ImageType,
|
|
||||||
strength: float,
|
|
||||||
seed: str = None,
|
|
||||||
upsampler_scale: int = 2,
|
|
||||||
denoise_str: float = 0.75,
|
|
||||||
):
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
|
|
||||||
try:
|
|
||||||
upsampler = self.load_esrgan_bg_upsampler(denoise_str)
|
|
||||||
except Exception:
|
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
logger.error("Error loading Real-ESRGAN:")
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
if upsampler_scale == 0:
|
|
||||||
logger.warning("Real-ESRGAN: Invalid scaling option. Image not upscaled.")
|
|
||||||
return image
|
|
||||||
|
|
||||||
if seed is not None:
|
|
||||||
logger.info(
|
|
||||||
f"Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
|
|
||||||
)
|
|
||||||
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
|
|
||||||
image = image.convert("RGB")
|
|
||||||
|
|
||||||
# REALSRGAN expects a BGR np array; make array and flip channels
|
|
||||||
bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]
|
|
||||||
|
|
||||||
output, _ = upsampler.enhance(
|
|
||||||
bgr_image_array,
|
|
||||||
outscale=upsampler_scale,
|
|
||||||
alpha_upsampler="realesrgan",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Flip the channels back to RGB
|
|
||||||
res = Image.fromarray(output[..., ::-1])
|
|
||||||
|
|
||||||
if strength < 1.0:
|
|
||||||
# Resize the image to the new image if the sizes have changed
|
|
||||||
if output.size != image.size:
|
|
||||||
image = image.resize(res.size)
|
|
||||||
res = Image.blend(image, res, strength)
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
upsampler = None
|
|
||||||
|
|
||||||
return res
|
|
@ -1,514 +0,0 @@
|
|||||||
"""
|
|
||||||
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
|
|
||||||
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
|
|
||||||
|
|
||||||
"""
|
|
||||||
import copy
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from basicsr.utils import get_root_logger
|
|
||||||
from basicsr.utils.registry import ARCH_REGISTRY
|
|
||||||
|
|
||||||
|
|
||||||
def normalize(in_channels):
|
|
||||||
return torch.nn.GroupNorm(
|
|
||||||
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
def swish(x):
|
|
||||||
return x * torch.sigmoid(x)
|
|
||||||
|
|
||||||
|
|
||||||
# Define VQVAE classes
|
|
||||||
class VectorQuantizer(nn.Module):
|
|
||||||
def __init__(self, codebook_size, emb_dim, beta):
|
|
||||||
super(VectorQuantizer, self).__init__()
|
|
||||||
self.codebook_size = codebook_size # number of embeddings
|
|
||||||
self.emb_dim = emb_dim # dimension of embedding
|
|
||||||
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
|
|
||||||
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
|
|
||||||
self.embedding.weight.data.uniform_(
|
|
||||||
-1.0 / self.codebook_size, 1.0 / self.codebook_size
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, z):
|
|
||||||
# reshape z -> (batch, height, width, channel) and flatten
|
|
||||||
z = z.permute(0, 2, 3, 1).contiguous()
|
|
||||||
z_flattened = z.view(-1, self.emb_dim)
|
|
||||||
|
|
||||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
|
||||||
d = (
|
|
||||||
(z_flattened**2).sum(dim=1, keepdim=True)
|
|
||||||
+ (self.embedding.weight**2).sum(1)
|
|
||||||
- 2 * torch.matmul(z_flattened, self.embedding.weight.t())
|
|
||||||
)
|
|
||||||
|
|
||||||
mean_distance = torch.mean(d)
|
|
||||||
# find closest encodings
|
|
||||||
# min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
|
|
||||||
min_encoding_scores, min_encoding_indices = torch.topk(
|
|
||||||
d, 1, dim=1, largest=False
|
|
||||||
)
|
|
||||||
# [0-1], higher score, higher confidence
|
|
||||||
min_encoding_scores = torch.exp(-min_encoding_scores / 10)
|
|
||||||
|
|
||||||
min_encodings = torch.zeros(
|
|
||||||
min_encoding_indices.shape[0], self.codebook_size
|
|
||||||
).to(z)
|
|
||||||
min_encodings.scatter_(1, min_encoding_indices, 1)
|
|
||||||
|
|
||||||
# get quantized latent vectors
|
|
||||||
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
|
|
||||||
# compute loss for embedding
|
|
||||||
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
|
|
||||||
(z_q - z.detach()) ** 2
|
|
||||||
)
|
|
||||||
# preserve gradients
|
|
||||||
z_q = z + (z_q - z).detach()
|
|
||||||
|
|
||||||
# perplexity
|
|
||||||
e_mean = torch.mean(min_encodings, dim=0)
|
|
||||||
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
|
|
||||||
# reshape back to match original input shape
|
|
||||||
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
|
||||||
|
|
||||||
return (
|
|
||||||
z_q,
|
|
||||||
loss,
|
|
||||||
{
|
|
||||||
"perplexity": perplexity,
|
|
||||||
"min_encodings": min_encodings,
|
|
||||||
"min_encoding_indices": min_encoding_indices,
|
|
||||||
"min_encoding_scores": min_encoding_scores,
|
|
||||||
"mean_distance": mean_distance,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_codebook_feat(self, indices, shape):
|
|
||||||
# input indices: batch*token_num -> (batch*token_num)*1
|
|
||||||
# shape: batch, height, width, channel
|
|
||||||
indices = indices.view(-1, 1)
|
|
||||||
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
|
|
||||||
min_encodings.scatter_(1, indices, 1)
|
|
||||||
# get quantized latent vectors
|
|
||||||
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
|
|
||||||
|
|
||||||
if shape is not None: # reshape back to match original input shape
|
|
||||||
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
|
|
||||||
|
|
||||||
return z_q
|
|
||||||
|
|
||||||
|
|
||||||
class GumbelQuantizer(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
codebook_size,
|
|
||||||
emb_dim,
|
|
||||||
num_hiddens,
|
|
||||||
straight_through=False,
|
|
||||||
kl_weight=5e-4,
|
|
||||||
temp_init=1.0,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.codebook_size = codebook_size # number of embeddings
|
|
||||||
self.emb_dim = emb_dim # dimension of embedding
|
|
||||||
self.straight_through = straight_through
|
|
||||||
self.temperature = temp_init
|
|
||||||
self.kl_weight = kl_weight
|
|
||||||
self.proj = nn.Conv2d(
|
|
||||||
num_hiddens, codebook_size, 1
|
|
||||||
) # projects last encoder layer to quantized logits
|
|
||||||
self.embed = nn.Embedding(codebook_size, emb_dim)
|
|
||||||
|
|
||||||
def forward(self, z):
|
|
||||||
hard = self.straight_through if self.training else True
|
|
||||||
|
|
||||||
logits = self.proj(z)
|
|
||||||
|
|
||||||
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
|
|
||||||
|
|
||||||
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
|
|
||||||
|
|
||||||
# + kl divergence to the prior loss
|
|
||||||
qy = F.softmax(logits, dim=1)
|
|
||||||
diff = (
|
|
||||||
self.kl_weight
|
|
||||||
* torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
|
|
||||||
)
|
|
||||||
min_encoding_indices = soft_one_hot.argmax(dim=1)
|
|
||||||
|
|
||||||
return z_q, diff, {"min_encoding_indices": min_encoding_indices}
|
|
||||||
|
|
||||||
|
|
||||||
class Downsample(nn.Module):
|
|
||||||
def __init__(self, in_channels):
|
|
||||||
super().__init__()
|
|
||||||
self.conv = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
pad = (0, 1, 0, 1)
|
|
||||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
|
||||||
x = self.conv(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Upsample(nn.Module):
|
|
||||||
def __init__(self, in_channels):
|
|
||||||
super().__init__()
|
|
||||||
self.conv = nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
|
||||||
x = self.conv(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class ResBlock(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels=None):
|
|
||||||
super(ResBlock, self).__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = in_channels if out_channels is None else out_channels
|
|
||||||
self.norm1 = normalize(in_channels)
|
|
||||||
self.conv1 = nn.Conv2d(
|
|
||||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
self.norm2 = normalize(out_channels)
|
|
||||||
self.conv2 = nn.Conv2d(
|
|
||||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
if self.in_channels != self.out_channels:
|
|
||||||
self.conv_out = nn.Conv2d(
|
|
||||||
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x_in):
|
|
||||||
x = x_in
|
|
||||||
x = self.norm1(x)
|
|
||||||
x = swish(x)
|
|
||||||
x = self.conv1(x)
|
|
||||||
x = self.norm2(x)
|
|
||||||
x = swish(x)
|
|
||||||
x = self.conv2(x)
|
|
||||||
if self.in_channels != self.out_channels:
|
|
||||||
x_in = self.conv_out(x_in)
|
|
||||||
|
|
||||||
return x + x_in
|
|
||||||
|
|
||||||
|
|
||||||
class AttnBlock(nn.Module):
|
|
||||||
def __init__(self, in_channels):
|
|
||||||
super().__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
|
|
||||||
self.norm = normalize(in_channels)
|
|
||||||
self.q = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
|
||||||
)
|
|
||||||
self.k = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
|
||||||
)
|
|
||||||
self.v = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
|
||||||
)
|
|
||||||
self.proj_out = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
h_ = x
|
|
||||||
h_ = self.norm(h_)
|
|
||||||
q = self.q(h_)
|
|
||||||
k = self.k(h_)
|
|
||||||
v = self.v(h_)
|
|
||||||
|
|
||||||
# compute attention
|
|
||||||
b, c, h, w = q.shape
|
|
||||||
q = q.reshape(b, c, h * w)
|
|
||||||
q = q.permute(0, 2, 1)
|
|
||||||
k = k.reshape(b, c, h * w)
|
|
||||||
w_ = torch.bmm(q, k)
|
|
||||||
w_ = w_ * (int(c) ** (-0.5))
|
|
||||||
w_ = F.softmax(w_, dim=2)
|
|
||||||
|
|
||||||
# attend to values
|
|
||||||
v = v.reshape(b, c, h * w)
|
|
||||||
w_ = w_.permute(0, 2, 1)
|
|
||||||
h_ = torch.bmm(v, w_)
|
|
||||||
h_ = h_.reshape(b, c, h, w)
|
|
||||||
|
|
||||||
h_ = self.proj_out(h_)
|
|
||||||
|
|
||||||
return x + h_
|
|
||||||
|
|
||||||
|
|
||||||
class Encoder(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
nf,
|
|
||||||
emb_dim,
|
|
||||||
ch_mult,
|
|
||||||
num_res_blocks,
|
|
||||||
resolution,
|
|
||||||
attn_resolutions,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.nf = nf
|
|
||||||
self.num_resolutions = len(ch_mult)
|
|
||||||
self.num_res_blocks = num_res_blocks
|
|
||||||
self.resolution = resolution
|
|
||||||
self.attn_resolutions = attn_resolutions
|
|
||||||
|
|
||||||
curr_res = self.resolution
|
|
||||||
in_ch_mult = (1,) + tuple(ch_mult)
|
|
||||||
|
|
||||||
blocks = []
|
|
||||||
# initial convultion
|
|
||||||
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
|
|
||||||
|
|
||||||
# residual and downsampling blocks, with attention on smaller res (16x16)
|
|
||||||
for i in range(self.num_resolutions):
|
|
||||||
block_in_ch = nf * in_ch_mult[i]
|
|
||||||
block_out_ch = nf * ch_mult[i]
|
|
||||||
for _ in range(self.num_res_blocks):
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
|
||||||
block_in_ch = block_out_ch
|
|
||||||
if curr_res in attn_resolutions:
|
|
||||||
blocks.append(AttnBlock(block_in_ch))
|
|
||||||
|
|
||||||
if i != self.num_resolutions - 1:
|
|
||||||
blocks.append(Downsample(block_in_ch))
|
|
||||||
curr_res = curr_res // 2
|
|
||||||
|
|
||||||
# non-local attention block
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
|
||||||
blocks.append(AttnBlock(block_in_ch))
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
|
||||||
|
|
||||||
# normalise and convert to latent size
|
|
||||||
blocks.append(normalize(block_in_ch))
|
|
||||||
blocks.append(
|
|
||||||
nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)
|
|
||||||
)
|
|
||||||
self.blocks = nn.ModuleList(blocks)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
for block in self.blocks:
|
|
||||||
x = block(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Generator(nn.Module):
|
|
||||||
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
|
|
||||||
super().__init__()
|
|
||||||
self.nf = nf
|
|
||||||
self.ch_mult = ch_mult
|
|
||||||
self.num_resolutions = len(self.ch_mult)
|
|
||||||
self.num_res_blocks = res_blocks
|
|
||||||
self.resolution = img_size
|
|
||||||
self.attn_resolutions = attn_resolutions
|
|
||||||
self.in_channels = emb_dim
|
|
||||||
self.out_channels = 3
|
|
||||||
block_in_ch = self.nf * self.ch_mult[-1]
|
|
||||||
curr_res = self.resolution // 2 ** (self.num_resolutions - 1)
|
|
||||||
|
|
||||||
blocks = []
|
|
||||||
# initial conv
|
|
||||||
blocks.append(
|
|
||||||
nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)
|
|
||||||
)
|
|
||||||
|
|
||||||
# non-local attention block
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
|
||||||
blocks.append(AttnBlock(block_in_ch))
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
|
||||||
|
|
||||||
for i in reversed(range(self.num_resolutions)):
|
|
||||||
block_out_ch = self.nf * self.ch_mult[i]
|
|
||||||
|
|
||||||
for _ in range(self.num_res_blocks):
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
|
||||||
block_in_ch = block_out_ch
|
|
||||||
|
|
||||||
if curr_res in self.attn_resolutions:
|
|
||||||
blocks.append(AttnBlock(block_in_ch))
|
|
||||||
|
|
||||||
if i != 0:
|
|
||||||
blocks.append(Upsample(block_in_ch))
|
|
||||||
curr_res = curr_res * 2
|
|
||||||
|
|
||||||
blocks.append(normalize(block_in_ch))
|
|
||||||
blocks.append(
|
|
||||||
nn.Conv2d(
|
|
||||||
block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.blocks = nn.ModuleList(blocks)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
for block in self.blocks:
|
|
||||||
x = block(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
@ARCH_REGISTRY.register()
|
|
||||||
class VQAutoEncoder(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
img_size,
|
|
||||||
nf,
|
|
||||||
ch_mult,
|
|
||||||
quantizer="nearest",
|
|
||||||
res_blocks=2,
|
|
||||||
attn_resolutions=[16],
|
|
||||||
codebook_size=1024,
|
|
||||||
emb_dim=256,
|
|
||||||
beta=0.25,
|
|
||||||
gumbel_straight_through=False,
|
|
||||||
gumbel_kl_weight=1e-8,
|
|
||||||
model_path=None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
logger = get_root_logger()
|
|
||||||
self.in_channels = 3
|
|
||||||
self.nf = nf
|
|
||||||
self.n_blocks = res_blocks
|
|
||||||
self.codebook_size = codebook_size
|
|
||||||
self.embed_dim = emb_dim
|
|
||||||
self.ch_mult = ch_mult
|
|
||||||
self.resolution = img_size
|
|
||||||
self.attn_resolutions = attn_resolutions
|
|
||||||
self.quantizer_type = quantizer
|
|
||||||
self.encoder = Encoder(
|
|
||||||
self.in_channels,
|
|
||||||
self.nf,
|
|
||||||
self.embed_dim,
|
|
||||||
self.ch_mult,
|
|
||||||
self.n_blocks,
|
|
||||||
self.resolution,
|
|
||||||
self.attn_resolutions,
|
|
||||||
)
|
|
||||||
if self.quantizer_type == "nearest":
|
|
||||||
self.beta = beta # 0.25
|
|
||||||
self.quantize = VectorQuantizer(
|
|
||||||
self.codebook_size, self.embed_dim, self.beta
|
|
||||||
)
|
|
||||||
elif self.quantizer_type == "gumbel":
|
|
||||||
self.gumbel_num_hiddens = emb_dim
|
|
||||||
self.straight_through = gumbel_straight_through
|
|
||||||
self.kl_weight = gumbel_kl_weight
|
|
||||||
self.quantize = GumbelQuantizer(
|
|
||||||
self.codebook_size,
|
|
||||||
self.embed_dim,
|
|
||||||
self.gumbel_num_hiddens,
|
|
||||||
self.straight_through,
|
|
||||||
self.kl_weight,
|
|
||||||
)
|
|
||||||
self.generator = Generator(
|
|
||||||
self.nf,
|
|
||||||
self.embed_dim,
|
|
||||||
self.ch_mult,
|
|
||||||
self.n_blocks,
|
|
||||||
self.resolution,
|
|
||||||
self.attn_resolutions,
|
|
||||||
)
|
|
||||||
|
|
||||||
if model_path is not None:
|
|
||||||
chkpt = torch.load(model_path, map_location="cpu")
|
|
||||||
if "params_ema" in chkpt:
|
|
||||||
self.load_state_dict(
|
|
||||||
torch.load(model_path, map_location="cpu")["params_ema"]
|
|
||||||
)
|
|
||||||
logger.info(f"vqgan is loaded from: {model_path} [params_ema]")
|
|
||||||
elif "params" in chkpt:
|
|
||||||
self.load_state_dict(
|
|
||||||
torch.load(model_path, map_location="cpu")["params"]
|
|
||||||
)
|
|
||||||
logger.info(f"vqgan is loaded from: {model_path} [params]")
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Wrong params!")
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.encoder(x)
|
|
||||||
quant, codebook_loss, quant_stats = self.quantize(x)
|
|
||||||
x = self.generator(quant)
|
|
||||||
return x, codebook_loss, quant_stats
|
|
||||||
|
|
||||||
|
|
||||||
# patch based discriminator
|
|
||||||
@ARCH_REGISTRY.register()
|
|
||||||
class VQGANDiscriminator(nn.Module):
|
|
||||||
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
layers = [
|
|
||||||
nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),
|
|
||||||
nn.LeakyReLU(0.2, True),
|
|
||||||
]
|
|
||||||
ndf_mult = 1
|
|
||||||
ndf_mult_prev = 1
|
|
||||||
for n in range(1, n_layers): # gradually increase the number of filters
|
|
||||||
ndf_mult_prev = ndf_mult
|
|
||||||
ndf_mult = min(2**n, 8)
|
|
||||||
layers += [
|
|
||||||
nn.Conv2d(
|
|
||||||
ndf * ndf_mult_prev,
|
|
||||||
ndf * ndf_mult,
|
|
||||||
kernel_size=4,
|
|
||||||
stride=2,
|
|
||||||
padding=1,
|
|
||||||
bias=False,
|
|
||||||
),
|
|
||||||
nn.BatchNorm2d(ndf * ndf_mult),
|
|
||||||
nn.LeakyReLU(0.2, True),
|
|
||||||
]
|
|
||||||
|
|
||||||
ndf_mult_prev = ndf_mult
|
|
||||||
ndf_mult = min(2**n_layers, 8)
|
|
||||||
|
|
||||||
layers += [
|
|
||||||
nn.Conv2d(
|
|
||||||
ndf * ndf_mult_prev,
|
|
||||||
ndf * ndf_mult,
|
|
||||||
kernel_size=4,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
bias=False,
|
|
||||||
),
|
|
||||||
nn.BatchNorm2d(ndf * ndf_mult),
|
|
||||||
nn.LeakyReLU(0.2, True),
|
|
||||||
]
|
|
||||||
|
|
||||||
layers += [
|
|
||||||
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)
|
|
||||||
] # output 1 channel prediction map
|
|
||||||
self.main = nn.Sequential(*layers)
|
|
||||||
|
|
||||||
if model_path is not None:
|
|
||||||
chkpt = torch.load(model_path, map_location="cpu")
|
|
||||||
if "params_d" in chkpt:
|
|
||||||
self.load_state_dict(
|
|
||||||
torch.load(model_path, map_location="cpu")["params_d"]
|
|
||||||
)
|
|
||||||
elif "params" in chkpt:
|
|
||||||
self.load_state_dict(
|
|
||||||
torch.load(model_path, map_location="cpu")["params"]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Wrong params!")
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.main(x)
|
|
@ -221,7 +221,7 @@ class ControlNetData:
|
|||||||
control_mode: str = Field(default="balanced")
|
control_mode: str = Field(default="balanced")
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass
|
||||||
class ConditioningData:
|
class ConditioningData:
|
||||||
unconditioned_embeddings: torch.Tensor
|
unconditioned_embeddings: torch.Tensor
|
||||||
text_embeddings: torch.Tensor
|
text_embeddings: torch.Tensor
|
||||||
@ -507,6 +507,40 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
control_data: List[ControlNetData] = None,
|
control_data: List[ControlNetData] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
def _pad_conditioning(cond, target_len, encoder_attention_mask):
|
||||||
|
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
|
||||||
|
|
||||||
|
if cond.shape[1] < max_len:
|
||||||
|
conditioning_attention_mask = torch.cat([
|
||||||
|
conditioning_attention_mask,
|
||||||
|
torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
|
||||||
|
], dim=1)
|
||||||
|
|
||||||
|
cond = torch.cat([
|
||||||
|
cond,
|
||||||
|
torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
|
||||||
|
], dim=1)
|
||||||
|
|
||||||
|
if encoder_attention_mask is None:
|
||||||
|
encoder_attention_mask = conditioning_attention_mask
|
||||||
|
else:
|
||||||
|
encoder_attention_mask = torch.cat([
|
||||||
|
encoder_attention_mask,
|
||||||
|
conditioning_attention_mask,
|
||||||
|
])
|
||||||
|
|
||||||
|
return cond, encoder_attention_mask
|
||||||
|
|
||||||
|
encoder_attention_mask = None
|
||||||
|
if conditioning_data.unconditioned_embeddings.shape[1] != conditioning_data.text_embeddings.shape[1]:
|
||||||
|
max_len = max(conditioning_data.unconditioned_embeddings.shape[1], conditioning_data.text_embeddings.shape[1])
|
||||||
|
conditioning_data.unconditioned_embeddings, encoder_attention_mask = _pad_conditioning(
|
||||||
|
conditioning_data.unconditioned_embeddings, max_len, encoder_attention_mask
|
||||||
|
)
|
||||||
|
conditioning_data.text_embeddings, encoder_attention_mask = _pad_conditioning(
|
||||||
|
conditioning_data.text_embeddings, max_len, encoder_attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
self._adjust_memory_efficient_attention(latents)
|
self._adjust_memory_efficient_attention(latents)
|
||||||
if run_id is None:
|
if run_id is None:
|
||||||
run_id = secrets.token_urlsafe(self.ID_LENGTH)
|
run_id = secrets.token_urlsafe(self.ID_LENGTH)
|
||||||
@ -546,6 +580,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
total_step_count=len(timesteps),
|
total_step_count=len(timesteps),
|
||||||
additional_guidance=additional_guidance,
|
additional_guidance=additional_guidance,
|
||||||
control_data=control_data,
|
control_data=control_data,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
latents = step_output.prev_sample
|
latents = step_output.prev_sample
|
||||||
@ -603,6 +638,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
down_block_res_samples, mid_block_res_sample = None, None
|
down_block_res_samples, mid_block_res_sample = None, None
|
||||||
|
|
||||||
if control_data is not None:
|
if control_data is not None:
|
||||||
|
# TODO: rewrite to pass with conditionings
|
||||||
|
encoder_attention_mask = kwargs.get("encoder_attention_mask", None)
|
||||||
# control_data should be type List[ControlNetData]
|
# control_data should be type List[ControlNetData]
|
||||||
# this loop covers both ControlNet (one ControlNetData in list)
|
# this loop covers both ControlNet (one ControlNetData in list)
|
||||||
# and MultiControlNet (multiple ControlNetData in list)
|
# and MultiControlNet (multiple ControlNetData in list)
|
||||||
@ -649,6 +686,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
controlnet_cond=control_datum.image_tensor,
|
controlnet_cond=control_datum.image_tensor,
|
||||||
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
|
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
|
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
)
|
)
|
||||||
|
@ -241,45 +241,12 @@ class InvokeAIDiffuserComponent:
|
|||||||
|
|
||||||
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||||
# fast batched path
|
# fast batched path
|
||||||
|
|
||||||
def _pad_conditioning(cond, target_len, encoder_attention_mask):
|
|
||||||
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
|
|
||||||
|
|
||||||
if cond.shape[1] < max_len:
|
|
||||||
conditioning_attention_mask = torch.cat([
|
|
||||||
conditioning_attention_mask,
|
|
||||||
torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
|
|
||||||
], dim=1)
|
|
||||||
|
|
||||||
cond = torch.cat([
|
|
||||||
cond,
|
|
||||||
torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
|
|
||||||
], dim=1)
|
|
||||||
|
|
||||||
if encoder_attention_mask is None:
|
|
||||||
encoder_attention_mask = conditioning_attention_mask
|
|
||||||
else:
|
|
||||||
encoder_attention_mask = torch.cat([
|
|
||||||
encoder_attention_mask,
|
|
||||||
conditioning_attention_mask,
|
|
||||||
])
|
|
||||||
|
|
||||||
return cond, encoder_attention_mask
|
|
||||||
|
|
||||||
x_twice = torch.cat([x] * 2)
|
x_twice = torch.cat([x] * 2)
|
||||||
sigma_twice = torch.cat([sigma] * 2)
|
sigma_twice = torch.cat([sigma] * 2)
|
||||||
|
|
||||||
encoder_attention_mask = None
|
|
||||||
if unconditioning.shape[1] != conditioning.shape[1]:
|
|
||||||
max_len = max(unconditioning.shape[1], conditioning.shape[1])
|
|
||||||
unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
|
|
||||||
conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
|
|
||||||
|
|
||||||
both_conditionings = torch.cat([unconditioning, conditioning])
|
both_conditionings = torch.cat([unconditioning, conditioning])
|
||||||
both_results = self.model_forward_callback(
|
both_results = self.model_forward_callback(
|
||||||
x_twice, sigma_twice, both_conditionings,
|
x_twice, sigma_twice, both_conditionings, **kwargs,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
634
invokeai/backend/util/hotfixes.py
Normal file
634
invokeai/backend/util/hotfixes.py
Normal file
@ -0,0 +1,634 @@
|
|||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||||
|
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
|
||||||
|
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
||||||
|
from diffusers.models.modeling_utils import ModelMixin
|
||||||
|
from diffusers.models.unet_2d_blocks import (
|
||||||
|
CrossAttnDownBlock2D,
|
||||||
|
DownBlock2D,
|
||||||
|
UNetMidBlock2DCrossAttn,
|
||||||
|
get_down_block,
|
||||||
|
)
|
||||||
|
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||||
|
|
||||||
|
import diffusers
|
||||||
|
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
|
||||||
|
|
||||||
|
# Modified ControlNetModel with encoder_attention_mask argument added
|
||||||
|
|
||||||
|
class ControlNetModel(ModelMixin, ConfigMixin):
|
||||||
|
"""
|
||||||
|
A ControlNet model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (`int`, defaults to 4):
|
||||||
|
The number of channels in the input sample.
|
||||||
|
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||||
|
Whether to flip the sin to cos in the time embedding.
|
||||||
|
freq_shift (`int`, defaults to 0):
|
||||||
|
The frequency shift to apply to the time embedding.
|
||||||
|
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||||
|
The tuple of downsample blocks to use.
|
||||||
|
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||||
|
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||||
|
The tuple of output channels for each block.
|
||||||
|
layers_per_block (`int`, defaults to 2):
|
||||||
|
The number of layers per block.
|
||||||
|
downsample_padding (`int`, defaults to 1):
|
||||||
|
The padding to use for the downsampling convolution.
|
||||||
|
mid_block_scale_factor (`float`, defaults to 1):
|
||||||
|
The scale factor to use for the mid block.
|
||||||
|
act_fn (`str`, defaults to "silu"):
|
||||||
|
The activation function to use.
|
||||||
|
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||||
|
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
||||||
|
in post-processing.
|
||||||
|
norm_eps (`float`, defaults to 1e-5):
|
||||||
|
The epsilon to use for the normalization.
|
||||||
|
cross_attention_dim (`int`, defaults to 1280):
|
||||||
|
The dimension of the cross attention features.
|
||||||
|
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
||||||
|
The dimension of the attention heads.
|
||||||
|
use_linear_projection (`bool`, defaults to `False`):
|
||||||
|
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||||
|
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
||||||
|
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
||||||
|
num_class_embeds (`int`, *optional*, defaults to 0):
|
||||||
|
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||||
|
class conditioning with `class_embed_type` equal to `None`.
|
||||||
|
upcast_attention (`bool`, defaults to `False`):
|
||||||
|
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
||||||
|
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
||||||
|
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
||||||
|
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
||||||
|
`class_embed_type="projection"`.
|
||||||
|
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
||||||
|
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
||||||
|
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
||||||
|
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
||||||
|
global_pool_conditions (`bool`, defaults to `False`):
|
||||||
|
"""
|
||||||
|
|
||||||
|
_supports_gradient_checkpointing = True
|
||||||
|
|
||||||
|
@register_to_config
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int = 4,
|
||||||
|
conditioning_channels: int = 3,
|
||||||
|
flip_sin_to_cos: bool = True,
|
||||||
|
freq_shift: int = 0,
|
||||||
|
down_block_types: Tuple[str] = (
|
||||||
|
"CrossAttnDownBlock2D",
|
||||||
|
"CrossAttnDownBlock2D",
|
||||||
|
"CrossAttnDownBlock2D",
|
||||||
|
"DownBlock2D",
|
||||||
|
),
|
||||||
|
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||||
|
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||||
|
layers_per_block: int = 2,
|
||||||
|
downsample_padding: int = 1,
|
||||||
|
mid_block_scale_factor: float = 1,
|
||||||
|
act_fn: str = "silu",
|
||||||
|
norm_num_groups: Optional[int] = 32,
|
||||||
|
norm_eps: float = 1e-5,
|
||||||
|
cross_attention_dim: int = 1280,
|
||||||
|
attention_head_dim: Union[int, Tuple[int]] = 8,
|
||||||
|
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
||||||
|
use_linear_projection: bool = False,
|
||||||
|
class_embed_type: Optional[str] = None,
|
||||||
|
num_class_embeds: Optional[int] = None,
|
||||||
|
upcast_attention: bool = False,
|
||||||
|
resnet_time_scale_shift: str = "default",
|
||||||
|
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||||
|
controlnet_conditioning_channel_order: str = "rgb",
|
||||||
|
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
||||||
|
global_pool_conditions: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# If `num_attention_heads` is not defined (which is the case for most models)
|
||||||
|
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
||||||
|
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
||||||
|
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
||||||
|
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
||||||
|
# which is why we correct for the naming here.
|
||||||
|
num_attention_heads = num_attention_heads or attention_head_dim
|
||||||
|
|
||||||
|
# Check inputs
|
||||||
|
if len(block_out_channels) != len(down_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
# input
|
||||||
|
conv_in_kernel = 3
|
||||||
|
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||||
|
self.conv_in = nn.Conv2d(
|
||||||
|
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
||||||
|
)
|
||||||
|
|
||||||
|
# time
|
||||||
|
time_embed_dim = block_out_channels[0] * 4
|
||||||
|
|
||||||
|
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||||
|
timestep_input_dim = block_out_channels[0]
|
||||||
|
|
||||||
|
self.time_embedding = TimestepEmbedding(
|
||||||
|
timestep_input_dim,
|
||||||
|
time_embed_dim,
|
||||||
|
act_fn=act_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
# class embedding
|
||||||
|
if class_embed_type is None and num_class_embeds is not None:
|
||||||
|
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||||
|
elif class_embed_type == "timestep":
|
||||||
|
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||||
|
elif class_embed_type == "identity":
|
||||||
|
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||||
|
elif class_embed_type == "projection":
|
||||||
|
if projection_class_embeddings_input_dim is None:
|
||||||
|
raise ValueError(
|
||||||
|
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
||||||
|
)
|
||||||
|
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
||||||
|
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
||||||
|
# 2. it projects from an arbitrary input dimension.
|
||||||
|
#
|
||||||
|
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
||||||
|
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
||||||
|
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
||||||
|
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||||
|
else:
|
||||||
|
self.class_embedding = None
|
||||||
|
|
||||||
|
# control net conditioning embedding
|
||||||
|
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
|
||||||
|
conditioning_embedding_channels=block_out_channels[0],
|
||||||
|
block_out_channels=conditioning_embedding_out_channels,
|
||||||
|
conditioning_channels=conditioning_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.down_blocks = nn.ModuleList([])
|
||||||
|
self.controlnet_down_blocks = nn.ModuleList([])
|
||||||
|
|
||||||
|
if isinstance(only_cross_attention, bool):
|
||||||
|
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
||||||
|
|
||||||
|
if isinstance(attention_head_dim, int):
|
||||||
|
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||||
|
|
||||||
|
if isinstance(num_attention_heads, int):
|
||||||
|
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
||||||
|
|
||||||
|
# down
|
||||||
|
output_channel = block_out_channels[0]
|
||||||
|
|
||||||
|
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||||
|
controlnet_block = zero_module(controlnet_block)
|
||||||
|
self.controlnet_down_blocks.append(controlnet_block)
|
||||||
|
|
||||||
|
for i, down_block_type in enumerate(down_block_types):
|
||||||
|
input_channel = output_channel
|
||||||
|
output_channel = block_out_channels[i]
|
||||||
|
is_final_block = i == len(block_out_channels) - 1
|
||||||
|
|
||||||
|
down_block = get_down_block(
|
||||||
|
down_block_type,
|
||||||
|
num_layers=layers_per_block,
|
||||||
|
in_channels=input_channel,
|
||||||
|
out_channels=output_channel,
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
add_downsample=not is_final_block,
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
num_attention_heads=num_attention_heads[i],
|
||||||
|
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
||||||
|
downsample_padding=downsample_padding,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
|
only_cross_attention=only_cross_attention[i],
|
||||||
|
upcast_attention=upcast_attention,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
)
|
||||||
|
self.down_blocks.append(down_block)
|
||||||
|
|
||||||
|
for _ in range(layers_per_block):
|
||||||
|
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||||
|
controlnet_block = zero_module(controlnet_block)
|
||||||
|
self.controlnet_down_blocks.append(controlnet_block)
|
||||||
|
|
||||||
|
if not is_final_block:
|
||||||
|
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||||
|
controlnet_block = zero_module(controlnet_block)
|
||||||
|
self.controlnet_down_blocks.append(controlnet_block)
|
||||||
|
|
||||||
|
# mid
|
||||||
|
mid_block_channel = block_out_channels[-1]
|
||||||
|
|
||||||
|
controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
|
||||||
|
controlnet_block = zero_module(controlnet_block)
|
||||||
|
self.controlnet_mid_block = controlnet_block
|
||||||
|
|
||||||
|
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||||
|
in_channels=mid_block_channel,
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
output_scale_factor=mid_block_scale_factor,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
num_attention_heads=num_attention_heads[-1],
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
|
upcast_attention=upcast_attention,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_unet(
|
||||||
|
cls,
|
||||||
|
unet: UNet2DConditionModel,
|
||||||
|
controlnet_conditioning_channel_order: str = "rgb",
|
||||||
|
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
||||||
|
load_weights_from_unet: bool = True,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
unet (`UNet2DConditionModel`):
|
||||||
|
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
|
||||||
|
where applicable.
|
||||||
|
"""
|
||||||
|
controlnet = cls(
|
||||||
|
in_channels=unet.config.in_channels,
|
||||||
|
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
||||||
|
freq_shift=unet.config.freq_shift,
|
||||||
|
down_block_types=unet.config.down_block_types,
|
||||||
|
only_cross_attention=unet.config.only_cross_attention,
|
||||||
|
block_out_channels=unet.config.block_out_channels,
|
||||||
|
layers_per_block=unet.config.layers_per_block,
|
||||||
|
downsample_padding=unet.config.downsample_padding,
|
||||||
|
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
||||||
|
act_fn=unet.config.act_fn,
|
||||||
|
norm_num_groups=unet.config.norm_num_groups,
|
||||||
|
norm_eps=unet.config.norm_eps,
|
||||||
|
cross_attention_dim=unet.config.cross_attention_dim,
|
||||||
|
attention_head_dim=unet.config.attention_head_dim,
|
||||||
|
num_attention_heads=unet.config.num_attention_heads,
|
||||||
|
use_linear_projection=unet.config.use_linear_projection,
|
||||||
|
class_embed_type=unet.config.class_embed_type,
|
||||||
|
num_class_embeds=unet.config.num_class_embeds,
|
||||||
|
upcast_attention=unet.config.upcast_attention,
|
||||||
|
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
||||||
|
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
||||||
|
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
||||||
|
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
if load_weights_from_unet:
|
||||||
|
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
|
||||||
|
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
||||||
|
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
||||||
|
|
||||||
|
if controlnet.class_embedding:
|
||||||
|
controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
|
||||||
|
|
||||||
|
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
|
||||||
|
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
|
||||||
|
|
||||||
|
return controlnet
|
||||||
|
|
||||||
|
@property
|
||||||
|
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||||
|
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||||
|
r"""
|
||||||
|
Returns:
|
||||||
|
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||||
|
indexed by its weight name.
|
||||||
|
"""
|
||||||
|
# set recursively
|
||||||
|
processors = {}
|
||||||
|
|
||||||
|
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||||
|
if hasattr(module, "set_processor"):
|
||||||
|
processors[f"{name}.processor"] = module.processor
|
||||||
|
|
||||||
|
for sub_name, child in module.named_children():
|
||||||
|
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||||
|
|
||||||
|
return processors
|
||||||
|
|
||||||
|
for name, module in self.named_children():
|
||||||
|
fn_recursive_add_processors(name, module, processors)
|
||||||
|
|
||||||
|
return processors
|
||||||
|
|
||||||
|
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||||
|
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||||
|
r"""
|
||||||
|
Sets the attention processor to use to compute attention.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||||
|
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||||
|
for **all** `Attention` layers.
|
||||||
|
|
||||||
|
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||||
|
processor. This is strongly recommended when setting trainable attention processors.
|
||||||
|
|
||||||
|
"""
|
||||||
|
count = len(self.attn_processors.keys())
|
||||||
|
|
||||||
|
if isinstance(processor, dict) and len(processor) != count:
|
||||||
|
raise ValueError(
|
||||||
|
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||||
|
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||||
|
)
|
||||||
|
|
||||||
|
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||||
|
if hasattr(module, "set_processor"):
|
||||||
|
if not isinstance(processor, dict):
|
||||||
|
module.set_processor(processor)
|
||||||
|
else:
|
||||||
|
module.set_processor(processor.pop(f"{name}.processor"))
|
||||||
|
|
||||||
|
for sub_name, child in module.named_children():
|
||||||
|
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||||
|
|
||||||
|
for name, module in self.named_children():
|
||||||
|
fn_recursive_attn_processor(name, module, processor)
|
||||||
|
|
||||||
|
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||||
|
def set_default_attn_processor(self):
|
||||||
|
"""
|
||||||
|
Disables custom attention processors and sets the default attention implementation.
|
||||||
|
"""
|
||||||
|
self.set_attn_processor(AttnProcessor())
|
||||||
|
|
||||||
|
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||||
|
def set_attention_slice(self, slice_size):
|
||||||
|
r"""
|
||||||
|
Enable sliced attention computation.
|
||||||
|
|
||||||
|
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
||||||
|
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
||||||
|
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
||||||
|
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
||||||
|
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
||||||
|
must be a multiple of `slice_size`.
|
||||||
|
"""
|
||||||
|
sliceable_head_dims = []
|
||||||
|
|
||||||
|
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
||||||
|
if hasattr(module, "set_attention_slice"):
|
||||||
|
sliceable_head_dims.append(module.sliceable_head_dim)
|
||||||
|
|
||||||
|
for child in module.children():
|
||||||
|
fn_recursive_retrieve_sliceable_dims(child)
|
||||||
|
|
||||||
|
# retrieve number of attention layers
|
||||||
|
for module in self.children():
|
||||||
|
fn_recursive_retrieve_sliceable_dims(module)
|
||||||
|
|
||||||
|
num_sliceable_layers = len(sliceable_head_dims)
|
||||||
|
|
||||||
|
if slice_size == "auto":
|
||||||
|
# half the attention head size is usually a good trade-off between
|
||||||
|
# speed and memory
|
||||||
|
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
||||||
|
elif slice_size == "max":
|
||||||
|
# make smallest slice possible
|
||||||
|
slice_size = num_sliceable_layers * [1]
|
||||||
|
|
||||||
|
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
||||||
|
|
||||||
|
if len(slice_size) != len(sliceable_head_dims):
|
||||||
|
raise ValueError(
|
||||||
|
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
||||||
|
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(len(slice_size)):
|
||||||
|
size = slice_size[i]
|
||||||
|
dim = sliceable_head_dims[i]
|
||||||
|
if size is not None and size > dim:
|
||||||
|
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
||||||
|
|
||||||
|
# Recursively walk through all the children.
|
||||||
|
# Any children which exposes the set_attention_slice method
|
||||||
|
# gets the message
|
||||||
|
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
||||||
|
if hasattr(module, "set_attention_slice"):
|
||||||
|
module.set_attention_slice(slice_size.pop())
|
||||||
|
|
||||||
|
for child in module.children():
|
||||||
|
fn_recursive_set_attention_slice(child, slice_size)
|
||||||
|
|
||||||
|
reversed_slice_size = list(reversed(slice_size))
|
||||||
|
for module in self.children():
|
||||||
|
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||||
|
|
||||||
|
def _set_gradient_checkpointing(self, module, value=False):
|
||||||
|
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
||||||
|
module.gradient_checkpointing = value
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
sample: torch.FloatTensor,
|
||||||
|
timestep: Union[torch.Tensor, float, int],
|
||||||
|
encoder_hidden_states: torch.Tensor,
|
||||||
|
controlnet_cond: torch.FloatTensor,
|
||||||
|
conditioning_scale: float = 1.0,
|
||||||
|
class_labels: Optional[torch.Tensor] = None,
|
||||||
|
timestep_cond: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
guess_mode: bool = False,
|
||||||
|
return_dict: bool = True,
|
||||||
|
) -> Union[ControlNetOutput, Tuple]:
|
||||||
|
"""
|
||||||
|
The [`ControlNetModel`] forward method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample (`torch.FloatTensor`):
|
||||||
|
The noisy input tensor.
|
||||||
|
timestep (`Union[torch.Tensor, float, int]`):
|
||||||
|
The number of timesteps to denoise an input.
|
||||||
|
encoder_hidden_states (`torch.Tensor`):
|
||||||
|
The encoder hidden states.
|
||||||
|
controlnet_cond (`torch.FloatTensor`):
|
||||||
|
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||||
|
conditioning_scale (`float`, defaults to `1.0`):
|
||||||
|
The scale factor for ControlNet outputs.
|
||||||
|
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||||
|
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
|
||||||
|
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
||||||
|
encoder_attention_mask (`torch.Tensor`):
|
||||||
|
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
||||||
|
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
||||||
|
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
||||||
|
guess_mode (`bool`, defaults to `False`):
|
||||||
|
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
||||||
|
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||||
|
return_dict (`bool`, defaults to `True`):
|
||||||
|
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
|
||||||
|
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
||||||
|
returned where the first element is the sample tensor.
|
||||||
|
"""
|
||||||
|
# check channel order
|
||||||
|
channel_order = self.config.controlnet_conditioning_channel_order
|
||||||
|
|
||||||
|
if channel_order == "rgb":
|
||||||
|
# in rgb order by default
|
||||||
|
...
|
||||||
|
elif channel_order == "bgr":
|
||||||
|
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
|
||||||
|
|
||||||
|
# prepare attention_mask
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||||
|
attention_mask = attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||||
|
if encoder_attention_mask is not None:
|
||||||
|
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
||||||
|
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
# 1. time
|
||||||
|
timesteps = timestep
|
||||||
|
if not torch.is_tensor(timesteps):
|
||||||
|
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||||
|
# This would be a good case for the `match` statement (Python 3.10+)
|
||||||
|
is_mps = sample.device.type == "mps"
|
||||||
|
if isinstance(timestep, float):
|
||||||
|
dtype = torch.float32 if is_mps else torch.float64
|
||||||
|
else:
|
||||||
|
dtype = torch.int32 if is_mps else torch.int64
|
||||||
|
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||||
|
elif len(timesteps.shape) == 0:
|
||||||
|
timesteps = timesteps[None].to(sample.device)
|
||||||
|
|
||||||
|
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||||
|
timesteps = timesteps.expand(sample.shape[0])
|
||||||
|
|
||||||
|
t_emb = self.time_proj(timesteps)
|
||||||
|
|
||||||
|
# timesteps does not contain any weights and will always return f32 tensors
|
||||||
|
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||||
|
# there might be better ways to encapsulate this.
|
||||||
|
t_emb = t_emb.to(dtype=sample.dtype)
|
||||||
|
|
||||||
|
emb = self.time_embedding(t_emb, timestep_cond)
|
||||||
|
|
||||||
|
if self.class_embedding is not None:
|
||||||
|
if class_labels is None:
|
||||||
|
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||||
|
|
||||||
|
if self.config.class_embed_type == "timestep":
|
||||||
|
class_labels = self.time_proj(class_labels)
|
||||||
|
|
||||||
|
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||||
|
emb = emb + class_emb
|
||||||
|
|
||||||
|
# 2. pre-process
|
||||||
|
sample = self.conv_in(sample)
|
||||||
|
|
||||||
|
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
||||||
|
|
||||||
|
sample = sample + controlnet_cond
|
||||||
|
|
||||||
|
# 3. down
|
||||||
|
down_block_res_samples = (sample,)
|
||||||
|
for downsample_block in self.down_blocks:
|
||||||
|
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||||
|
sample, res_samples = downsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||||
|
|
||||||
|
down_block_res_samples += res_samples
|
||||||
|
|
||||||
|
# 4. mid
|
||||||
|
if self.mid_block is not None:
|
||||||
|
sample = self.mid_block(
|
||||||
|
sample,
|
||||||
|
emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. Control net blocks
|
||||||
|
|
||||||
|
controlnet_down_block_res_samples = ()
|
||||||
|
|
||||||
|
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
||||||
|
down_block_res_sample = controlnet_block(down_block_res_sample)
|
||||||
|
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
||||||
|
|
||||||
|
down_block_res_samples = controlnet_down_block_res_samples
|
||||||
|
|
||||||
|
mid_block_res_sample = self.controlnet_mid_block(sample)
|
||||||
|
|
||||||
|
# 6. scaling
|
||||||
|
if guess_mode and not self.config.global_pool_conditions:
|
||||||
|
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
||||||
|
|
||||||
|
scales = scales * conditioning_scale
|
||||||
|
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
||||||
|
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
||||||
|
else:
|
||||||
|
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
||||||
|
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
||||||
|
|
||||||
|
if self.config.global_pool_conditions:
|
||||||
|
down_block_res_samples = [
|
||||||
|
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
||||||
|
]
|
||||||
|
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (down_block_res_samples, mid_block_res_sample)
|
||||||
|
|
||||||
|
return ControlNetOutput(
|
||||||
|
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
||||||
|
)
|
||||||
|
|
||||||
|
diffusers.ControlNetModel = ControlNetModel
|
||||||
|
diffusers.models.controlnet.ControlNetModel = ControlNetModel
|
@ -58,22 +58,29 @@ sd-1/main/waifu-diffusion:
|
|||||||
recommended: False
|
recommended: False
|
||||||
sd-1/controlnet/canny:
|
sd-1/controlnet/canny:
|
||||||
repo_id: lllyasviel/control_v11p_sd15_canny
|
repo_id: lllyasviel/control_v11p_sd15_canny
|
||||||
|
recommended: True
|
||||||
sd-1/controlnet/inpaint:
|
sd-1/controlnet/inpaint:
|
||||||
repo_id: lllyasviel/control_v11p_sd15_inpaint
|
repo_id: lllyasviel/control_v11p_sd15_inpaint
|
||||||
sd-1/controlnet/mlsd:
|
sd-1/controlnet/mlsd:
|
||||||
repo_id: lllyasviel/control_v11p_sd15_mlsd
|
repo_id: lllyasviel/control_v11p_sd15_mlsd
|
||||||
sd-1/controlnet/depth:
|
sd-1/controlnet/depth:
|
||||||
repo_id: lllyasviel/control_v11f1p_sd15_depth
|
repo_id: lllyasviel/control_v11f1p_sd15_depth
|
||||||
|
recommended: True
|
||||||
sd-1/controlnet/normal_bae:
|
sd-1/controlnet/normal_bae:
|
||||||
repo_id: lllyasviel/control_v11p_sd15_normalbae
|
repo_id: lllyasviel/control_v11p_sd15_normalbae
|
||||||
sd-1/controlnet/seg:
|
sd-1/controlnet/seg:
|
||||||
repo_id: lllyasviel/control_v11p_sd15_seg
|
repo_id: lllyasviel/control_v11p_sd15_seg
|
||||||
sd-1/controlnet/lineart:
|
sd-1/controlnet/lineart:
|
||||||
repo_id: lllyasviel/control_v11p_sd15_lineart
|
repo_id: lllyasviel/control_v11p_sd15_lineart
|
||||||
|
recommended: True
|
||||||
sd-1/controlnet/lineart_anime:
|
sd-1/controlnet/lineart_anime:
|
||||||
repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime
|
repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime
|
||||||
|
sd-1/controlnet/openpose:
|
||||||
|
repo_id: lllyasviel/control_v11p_sd15_openpose
|
||||||
|
recommended: True
|
||||||
sd-1/controlnet/scribble:
|
sd-1/controlnet/scribble:
|
||||||
repo_id: lllyasviel/control_v11p_sd15_scribble
|
repo_id: lllyasviel/control_v11p_sd15_scribble
|
||||||
|
recommended: False
|
||||||
sd-1/controlnet/softedge:
|
sd-1/controlnet/softedge:
|
||||||
repo_id: lllyasviel/control_v11p_sd15_softedge
|
repo_id: lllyasviel/control_v11p_sd15_softedge
|
||||||
sd-1/controlnet/shuffle:
|
sd-1/controlnet/shuffle:
|
||||||
@ -84,6 +91,7 @@ sd-1/controlnet/ip2p:
|
|||||||
repo_id: lllyasviel/control_v11e_sd15_ip2p
|
repo_id: lllyasviel/control_v11e_sd15_ip2p
|
||||||
sd-1/embedding/EasyNegative:
|
sd-1/embedding/EasyNegative:
|
||||||
path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
|
path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
|
||||||
|
recommended: True
|
||||||
sd-1/embedding/ahx-beta-453407d:
|
sd-1/embedding/ahx-beta-453407d:
|
||||||
repo_id: sd-concepts-library/ahx-beta-453407d
|
repo_id: sd-concepts-library/ahx-beta-453407d
|
||||||
sd-1/lora/LowRA:
|
sd-1/lora/LowRA:
|
||||||
|
@ -256,6 +256,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
widgets = dict()
|
widgets = dict()
|
||||||
model_list = [x for x in self.all_models if self.all_models[x].model_type==model_type and not x in exclude]
|
model_list = [x for x in self.all_models if self.all_models[x].model_type==model_type and not x in exclude]
|
||||||
model_labels = [self.model_labels[x] for x in model_list]
|
model_labels = [self.model_labels[x] for x in model_list]
|
||||||
|
|
||||||
|
show_recommended = len(self.installed_models)==0
|
||||||
if len(model_list) > 0:
|
if len(model_list) > 0:
|
||||||
max_width = max([len(x) for x in model_labels])
|
max_width = max([len(x) for x in model_labels])
|
||||||
columns = window_width // (max_width+8) # 8 characters for "[x] " and padding
|
columns = window_width // (max_width+8) # 8 characters for "[x] " and padding
|
||||||
@ -280,7 +282,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
value=[
|
value=[
|
||||||
model_list.index(x)
|
model_list.index(x)
|
||||||
for x in model_list
|
for x in model_list
|
||||||
if self.all_models[x].installed
|
if (show_recommended and self.all_models[x].recommended) \
|
||||||
|
or self.all_models[x].installed
|
||||||
],
|
],
|
||||||
max_height=len(model_list)//columns + 1,
|
max_height=len(model_list)//columns + 1,
|
||||||
relx=4,
|
relx=4,
|
||||||
@ -672,7 +675,9 @@ def select_and_download_models(opt: Namespace):
|
|||||||
# pass
|
# pass
|
||||||
|
|
||||||
installer = ModelInstall(config, prediction_type_helper=helper)
|
installer = ModelInstall(config, prediction_type_helper=helper)
|
||||||
if opt.add or opt.delete:
|
if opt.list_models:
|
||||||
|
installer.list_models(opt.list_models)
|
||||||
|
elif opt.add or opt.delete:
|
||||||
selections = InstallSelections(
|
selections = InstallSelections(
|
||||||
install_models = opt.add or [],
|
install_models = opt.add or [],
|
||||||
remove_models = opt.delete or []
|
remove_models = opt.delete or []
|
||||||
@ -745,7 +750,7 @@ def main():
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--list-models",
|
"--list-models",
|
||||||
choices=["diffusers","loras","controlnets","tis"],
|
choices=[x.value for x in ModelType],
|
||||||
help="list installed models",
|
help="list installed models",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -773,7 +778,7 @@ def main():
|
|||||||
config.parse_args(invoke_args)
|
config.parse_args(invoke_args)
|
||||||
logger = InvokeAILogger().getLogger(config=config)
|
logger = InvokeAILogger().getLogger(config=config)
|
||||||
|
|
||||||
if not (config.conf_path / 'models.yaml').exists():
|
if not config.model_conf_path.exists():
|
||||||
logger.info(
|
logger.info(
|
||||||
"Your InvokeAI root directory is not set up. Calling invokeai-configure."
|
"Your InvokeAI root directory is not set up. Calling invokeai-configure."
|
||||||
)
|
)
|
||||||
|
355
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
355
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
@ -75,11 +75,6 @@ export type paths = {
|
|||||||
* @description Gets a list of models
|
* @description Gets a list of models
|
||||||
*/
|
*/
|
||||||
get: operations["list_models"];
|
get: operations["list_models"];
|
||||||
/**
|
|
||||||
* Import Model
|
|
||||||
* @description Add a model using its local path, repo_id, or remote URL
|
|
||||||
*/
|
|
||||||
post: operations["import_model"];
|
|
||||||
};
|
};
|
||||||
"/api/v1/models/{base_model}/{model_type}/{model_name}": {
|
"/api/v1/models/{base_model}/{model_type}/{model_name}": {
|
||||||
/**
|
/**
|
||||||
@ -93,13 +88,53 @@ export type paths = {
|
|||||||
*/
|
*/
|
||||||
patch: operations["update_model"];
|
patch: operations["update_model"];
|
||||||
};
|
};
|
||||||
|
"/api/v1/models/import": {
|
||||||
|
/**
|
||||||
|
* Import Model
|
||||||
|
* @description Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically
|
||||||
|
*/
|
||||||
|
post: operations["import_model"];
|
||||||
|
};
|
||||||
|
"/api/v1/models/add": {
|
||||||
|
/**
|
||||||
|
* Add Model
|
||||||
|
* @description Add a model using the configuration information appropriate for its type. Only local models can be added by path
|
||||||
|
*/
|
||||||
|
post: operations["add_model"];
|
||||||
|
};
|
||||||
|
"/api/v1/models/rename/{base_model}/{model_type}/{model_name}": {
|
||||||
|
/**
|
||||||
|
* Rename Model
|
||||||
|
* @description Rename a model
|
||||||
|
*/
|
||||||
|
post: operations["rename_model"];
|
||||||
|
};
|
||||||
"/api/v1/models/convert/{base_model}/{model_type}/{model_name}": {
|
"/api/v1/models/convert/{base_model}/{model_type}/{model_name}": {
|
||||||
/**
|
/**
|
||||||
* Convert Model
|
* Convert Model
|
||||||
* @description Convert a checkpoint model into a diffusers model
|
* @description Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none.
|
||||||
*/
|
*/
|
||||||
put: operations["convert_model"];
|
put: operations["convert_model"];
|
||||||
};
|
};
|
||||||
|
"/api/v1/models/search": {
|
||||||
|
/** Search For Models */
|
||||||
|
get: operations["search_for_models"];
|
||||||
|
};
|
||||||
|
"/api/v1/models/ckpt_confs": {
|
||||||
|
/**
|
||||||
|
* List Ckpt Configs
|
||||||
|
* @description Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT.
|
||||||
|
*/
|
||||||
|
get: operations["list_ckpt_configs"];
|
||||||
|
};
|
||||||
|
"/api/v1/models/sync": {
|
||||||
|
/**
|
||||||
|
* Sync To Config
|
||||||
|
* @description Call after making changes to models.yaml, autoimport directories or models directory to synchronize
|
||||||
|
* in-memory data structures with disk data structures.
|
||||||
|
*/
|
||||||
|
get: operations["sync_to_config"];
|
||||||
|
};
|
||||||
"/api/v1/models/merge/{base_model}": {
|
"/api/v1/models/merge/{base_model}": {
|
||||||
/**
|
/**
|
||||||
* Merge Models
|
* Merge Models
|
||||||
@ -397,6 +432,11 @@ export type components = {
|
|||||||
* @default false
|
* @default false
|
||||||
*/
|
*/
|
||||||
force?: boolean;
|
force?: boolean;
|
||||||
|
/**
|
||||||
|
* Merge Dest Directory
|
||||||
|
* @description Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||||
|
*/
|
||||||
|
merge_dest_directory?: string;
|
||||||
};
|
};
|
||||||
/** Body_remove_board_image */
|
/** Body_remove_board_image */
|
||||||
Body_remove_board_image: {
|
Body_remove_board_image: {
|
||||||
@ -1186,7 +1226,7 @@ export type components = {
|
|||||||
* @description The nodes in this graph
|
* @description The nodes in this graph
|
||||||
*/
|
*/
|
||||||
nodes?: {
|
nodes?: {
|
||||||
[key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
|
[key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RealESRGANInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
* Edges
|
* Edges
|
||||||
@ -3302,7 +3342,7 @@ export type components = {
|
|||||||
/** ModelsList */
|
/** ModelsList */
|
||||||
ModelsList: {
|
ModelsList: {
|
||||||
/** Models */
|
/** Models */
|
||||||
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[];
|
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[];
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
* MultiplyInvocation
|
* MultiplyInvocation
|
||||||
@ -3893,6 +3933,41 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
step?: number;
|
step?: number;
|
||||||
};
|
};
|
||||||
|
/**
|
||||||
|
* RealESRGANInvocation
|
||||||
|
* @description Upscales an image using RealESRGAN.
|
||||||
|
*/
|
||||||
|
RealESRGANInvocation: {
|
||||||
|
/**
|
||||||
|
* Id
|
||||||
|
* @description The id of this node. Must be unique among all nodes.
|
||||||
|
*/
|
||||||
|
id: string;
|
||||||
|
/**
|
||||||
|
* Is Intermediate
|
||||||
|
* @description Whether or not this node is an intermediate node.
|
||||||
|
* @default false
|
||||||
|
*/
|
||||||
|
is_intermediate?: boolean;
|
||||||
|
/**
|
||||||
|
* Type
|
||||||
|
* @default realesrgan
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
type?: "realesrgan";
|
||||||
|
/**
|
||||||
|
* Image
|
||||||
|
* @description The input image
|
||||||
|
*/
|
||||||
|
image?: components["schemas"]["ImageField"];
|
||||||
|
/**
|
||||||
|
* Model Name
|
||||||
|
* @description The Real-ESRGAN model to use
|
||||||
|
* @default RealESRGAN_x4plus.pth
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
model_name?: "RealESRGAN_x4plus.pth" | "RealESRGAN_x4plus_anime_6B.pth" | "ESRGAN_SRx4_DF2KOST_official-ff704c30.pth";
|
||||||
|
};
|
||||||
/**
|
/**
|
||||||
* ResizeLatentsInvocation
|
* ResizeLatentsInvocation
|
||||||
* @description Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8.
|
* @description Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8.
|
||||||
@ -4452,47 +4527,6 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
loras: (components["schemas"]["LoraInfo"])[];
|
loras: (components["schemas"]["LoraInfo"])[];
|
||||||
};
|
};
|
||||||
/**
|
|
||||||
* UpscaleInvocation
|
|
||||||
* @description Upscales an image.
|
|
||||||
*/
|
|
||||||
UpscaleInvocation: {
|
|
||||||
/**
|
|
||||||
* Id
|
|
||||||
* @description The id of this node. Must be unique among all nodes.
|
|
||||||
*/
|
|
||||||
id: string;
|
|
||||||
/**
|
|
||||||
* Is Intermediate
|
|
||||||
* @description Whether or not this node is an intermediate node.
|
|
||||||
* @default false
|
|
||||||
*/
|
|
||||||
is_intermediate?: boolean;
|
|
||||||
/**
|
|
||||||
* Type
|
|
||||||
* @default upscale
|
|
||||||
* @enum {string}
|
|
||||||
*/
|
|
||||||
type?: "upscale";
|
|
||||||
/**
|
|
||||||
* Image
|
|
||||||
* @description The input image
|
|
||||||
*/
|
|
||||||
image?: components["schemas"]["ImageField"];
|
|
||||||
/**
|
|
||||||
* Strength
|
|
||||||
* @description The strength
|
|
||||||
* @default 0.75
|
|
||||||
*/
|
|
||||||
strength?: number;
|
|
||||||
/**
|
|
||||||
* Level
|
|
||||||
* @description The upscale level
|
|
||||||
* @default 2
|
|
||||||
* @enum {integer}
|
|
||||||
*/
|
|
||||||
level?: 2 | 4;
|
|
||||||
};
|
|
||||||
/**
|
/**
|
||||||
* VAEModelField
|
* VAEModelField
|
||||||
* @description Vae model field
|
* @description Vae model field
|
||||||
@ -4619,18 +4653,18 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
image?: components["schemas"]["ImageField"];
|
image?: components["schemas"]["ImageField"];
|
||||||
};
|
};
|
||||||
/**
|
|
||||||
* StableDiffusion1ModelFormat
|
|
||||||
* @description An enumeration.
|
|
||||||
* @enum {string}
|
|
||||||
*/
|
|
||||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
|
||||||
/**
|
/**
|
||||||
* StableDiffusion2ModelFormat
|
* StableDiffusion2ModelFormat
|
||||||
* @description An enumeration.
|
* @description An enumeration.
|
||||||
* @enum {string}
|
* @enum {string}
|
||||||
*/
|
*/
|
||||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||||
|
/**
|
||||||
|
* StableDiffusion1ModelFormat
|
||||||
|
* @description An enumeration.
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||||
};
|
};
|
||||||
responses: never;
|
responses: never;
|
||||||
parameters: never;
|
parameters: never;
|
||||||
@ -4741,7 +4775,7 @@ export type operations = {
|
|||||||
};
|
};
|
||||||
requestBody: {
|
requestBody: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
|
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RealESRGANInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
responses: {
|
responses: {
|
||||||
@ -4778,7 +4812,7 @@ export type operations = {
|
|||||||
};
|
};
|
||||||
requestBody: {
|
requestBody: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
|
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RealESRGANInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
responses: {
|
responses: {
|
||||||
@ -4997,37 +5031,6 @@ export type operations = {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
/**
|
|
||||||
* Import Model
|
|
||||||
* @description Add a model using its local path, repo_id, or remote URL
|
|
||||||
*/
|
|
||||||
import_model: {
|
|
||||||
requestBody: {
|
|
||||||
content: {
|
|
||||||
"application/json": components["schemas"]["Body_import_model"];
|
|
||||||
};
|
|
||||||
};
|
|
||||||
responses: {
|
|
||||||
/** @description The model imported successfully */
|
|
||||||
201: {
|
|
||||||
content: {
|
|
||||||
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
|
|
||||||
};
|
|
||||||
};
|
|
||||||
/** @description The model could not be found */
|
|
||||||
404: never;
|
|
||||||
/** @description There is already a model corresponding to this path or repo_id */
|
|
||||||
409: never;
|
|
||||||
/** @description Validation Error */
|
|
||||||
422: {
|
|
||||||
content: {
|
|
||||||
"application/json": components["schemas"]["HTTPValidationError"];
|
|
||||||
};
|
|
||||||
};
|
|
||||||
/** @description The model appeared to import successfully, but could not be found in the model manager */
|
|
||||||
424: never;
|
|
||||||
};
|
|
||||||
};
|
|
||||||
/**
|
/**
|
||||||
* Delete Model
|
* Delete Model
|
||||||
* @description Delete Model
|
* @description Delete Model
|
||||||
@ -5044,12 +5047,6 @@ export type operations = {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
responses: {
|
responses: {
|
||||||
/** @description Successful Response */
|
|
||||||
200: {
|
|
||||||
content: {
|
|
||||||
"application/json": unknown;
|
|
||||||
};
|
|
||||||
};
|
|
||||||
/** @description Model deleted successfully */
|
/** @description Model deleted successfully */
|
||||||
204: never;
|
204: never;
|
||||||
/** @description Model not found */
|
/** @description Model not found */
|
||||||
@ -5079,14 +5076,14 @@ export type operations = {
|
|||||||
};
|
};
|
||||||
requestBody: {
|
requestBody: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
|
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
responses: {
|
responses: {
|
||||||
/** @description The model was updated successfully */
|
/** @description The model was updated successfully */
|
||||||
200: {
|
200: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
|
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
/** @description Bad request */
|
/** @description Bad request */
|
||||||
@ -5101,12 +5098,118 @@ export type operations = {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
/**
|
||||||
|
* Import Model
|
||||||
|
* @description Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically
|
||||||
|
*/
|
||||||
|
import_model: {
|
||||||
|
requestBody: {
|
||||||
|
content: {
|
||||||
|
"application/json": components["schemas"]["Body_import_model"];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
responses: {
|
||||||
|
/** @description The model imported successfully */
|
||||||
|
201: {
|
||||||
|
content: {
|
||||||
|
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
/** @description The model could not be found */
|
||||||
|
404: never;
|
||||||
|
/** @description There is already a model corresponding to this path or repo_id */
|
||||||
|
409: never;
|
||||||
|
/** @description Validation Error */
|
||||||
|
422: {
|
||||||
|
content: {
|
||||||
|
"application/json": components["schemas"]["HTTPValidationError"];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
/** @description The model appeared to import successfully, but could not be found in the model manager */
|
||||||
|
424: never;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
/**
|
||||||
|
* Add Model
|
||||||
|
* @description Add a model using the configuration information appropriate for its type. Only local models can be added by path
|
||||||
|
*/
|
||||||
|
add_model: {
|
||||||
|
requestBody: {
|
||||||
|
content: {
|
||||||
|
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
responses: {
|
||||||
|
/** @description The model added successfully */
|
||||||
|
201: {
|
||||||
|
content: {
|
||||||
|
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
/** @description The model could not be found */
|
||||||
|
404: never;
|
||||||
|
/** @description There is already a model corresponding to this path or repo_id */
|
||||||
|
409: never;
|
||||||
|
/** @description Validation Error */
|
||||||
|
422: {
|
||||||
|
content: {
|
||||||
|
"application/json": components["schemas"]["HTTPValidationError"];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
/** @description The model appeared to add successfully, but could not be found in the model manager */
|
||||||
|
424: never;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
/**
|
||||||
|
* Rename Model
|
||||||
|
* @description Rename a model
|
||||||
|
*/
|
||||||
|
rename_model: {
|
||||||
|
parameters: {
|
||||||
|
query?: {
|
||||||
|
/** @description new model name */
|
||||||
|
new_name?: string;
|
||||||
|
/** @description new model base */
|
||||||
|
new_base?: components["schemas"]["BaseModelType"];
|
||||||
|
};
|
||||||
|
path: {
|
||||||
|
/** @description Base model */
|
||||||
|
base_model: components["schemas"]["BaseModelType"];
|
||||||
|
/** @description The type of model */
|
||||||
|
model_type: components["schemas"]["ModelType"];
|
||||||
|
/** @description current model name */
|
||||||
|
model_name: string;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
responses: {
|
||||||
|
/** @description The model was renamed successfully */
|
||||||
|
201: {
|
||||||
|
content: {
|
||||||
|
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
/** @description The model could not be found */
|
||||||
|
404: never;
|
||||||
|
/** @description There is already a model corresponding to the new name */
|
||||||
|
409: never;
|
||||||
|
/** @description Validation Error */
|
||||||
|
422: {
|
||||||
|
content: {
|
||||||
|
"application/json": components["schemas"]["HTTPValidationError"];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
/**
|
/**
|
||||||
* Convert Model
|
* Convert Model
|
||||||
* @description Convert a checkpoint model into a diffusers model
|
* @description Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none.
|
||||||
*/
|
*/
|
||||||
convert_model: {
|
convert_model: {
|
||||||
parameters: {
|
parameters: {
|
||||||
|
query?: {
|
||||||
|
/** @description Save the converted model to the designated directory */
|
||||||
|
convert_dest_directory?: string;
|
||||||
|
};
|
||||||
path: {
|
path: {
|
||||||
/** @description Base model */
|
/** @description Base model */
|
||||||
base_model: components["schemas"]["BaseModelType"];
|
base_model: components["schemas"]["BaseModelType"];
|
||||||
@ -5120,7 +5223,7 @@ export type operations = {
|
|||||||
/** @description Model converted successfully */
|
/** @description Model converted successfully */
|
||||||
200: {
|
200: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
|
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
/** @description Bad request */
|
/** @description Bad request */
|
||||||
@ -5135,6 +5238,60 @@ export type operations = {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
/** Search For Models */
|
||||||
|
search_for_models: {
|
||||||
|
parameters: {
|
||||||
|
query: {
|
||||||
|
/** @description Directory path to search for models */
|
||||||
|
search_path: string;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
responses: {
|
||||||
|
/** @description Directory searched successfully */
|
||||||
|
200: {
|
||||||
|
content: {
|
||||||
|
"application/json": (string)[];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
/** @description Invalid directory path */
|
||||||
|
404: never;
|
||||||
|
/** @description Validation Error */
|
||||||
|
422: {
|
||||||
|
content: {
|
||||||
|
"application/json": components["schemas"]["HTTPValidationError"];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
|
/**
|
||||||
|
* List Ckpt Configs
|
||||||
|
* @description Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT.
|
||||||
|
*/
|
||||||
|
list_ckpt_configs: {
|
||||||
|
responses: {
|
||||||
|
/** @description paths retrieved successfully */
|
||||||
|
200: {
|
||||||
|
content: {
|
||||||
|
"application/json": (string)[];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
|
/**
|
||||||
|
* Sync To Config
|
||||||
|
* @description Call after making changes to models.yaml, autoimport directories or models directory to synchronize
|
||||||
|
* in-memory data structures with disk data structures.
|
||||||
|
*/
|
||||||
|
sync_to_config: {
|
||||||
|
responses: {
|
||||||
|
/** @description synchronization successful */
|
||||||
|
201: {
|
||||||
|
content: {
|
||||||
|
"application/json": unknown;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
/**
|
/**
|
||||||
* Merge Models
|
* Merge Models
|
||||||
* @description Convert a checkpoint model into a diffusers model
|
* @description Convert a checkpoint model into a diffusers model
|
||||||
@ -5155,7 +5312,7 @@ export type operations = {
|
|||||||
/** @description Model converted successfully */
|
/** @description Model converted successfully */
|
||||||
200: {
|
200: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
|
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
/** @description Incompatible models */
|
/** @description Incompatible models */
|
||||||
|
@ -1 +1 @@
|
|||||||
__version__ = "3.0.0+b5"
|
__version__ = "3.0.0+b6"
|
||||||
|
@ -55,7 +55,6 @@ def mock_services() -> InvocationServices:
|
|||||||
),
|
),
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||||
processor = DefaultInvocationProcessor(),
|
processor = DefaultInvocationProcessor(),
|
||||||
restoration = None, # type: ignore
|
|
||||||
configuration = None, # type: ignore
|
configuration = None, # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -48,7 +48,6 @@ def mock_services() -> InvocationServices:
|
|||||||
),
|
),
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||||
processor = DefaultInvocationProcessor(),
|
processor = DefaultInvocationProcessor(),
|
||||||
restoration = None, # type: ignore
|
|
||||||
configuration = None, # type: ignore
|
configuration = None, # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from .test_nodes import ImageToImageTestInvocation, TextToImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation
|
from .test_nodes import ImageToImageTestInvocation, TextToImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation
|
||||||
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
|
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
|
||||||
from invokeai.app.invocations.upscale import UpscaleInvocation
|
from invokeai.app.invocations.upscale import RealESRGANInvocation
|
||||||
from invokeai.app.invocations.image import *
|
from invokeai.app.invocations.image import *
|
||||||
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
||||||
from invokeai.app.invocations.params import ParamIntInvocation
|
from invokeai.app.invocations.params import ParamIntInvocation
|
||||||
@ -19,7 +19,7 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg
|
|||||||
def test_connections_are_compatible():
|
def test_connections_are_compatible():
|
||||||
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
from_field = "image"
|
from_field = "image"
|
||||||
to_node = UpscaleInvocation(id = "2")
|
to_node = RealESRGANInvocation(id = "2")
|
||||||
to_field = "image"
|
to_field = "image"
|
||||||
|
|
||||||
result = are_connections_compatible(from_node, from_field, to_node, to_field)
|
result = are_connections_compatible(from_node, from_field, to_node, to_field)
|
||||||
@ -29,7 +29,7 @@ def test_connections_are_compatible():
|
|||||||
def test_connections_are_incompatible():
|
def test_connections_are_incompatible():
|
||||||
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
from_field = "image"
|
from_field = "image"
|
||||||
to_node = UpscaleInvocation(id = "2")
|
to_node = RealESRGANInvocation(id = "2")
|
||||||
to_field = "strength"
|
to_field = "strength"
|
||||||
|
|
||||||
result = are_connections_compatible(from_node, from_field, to_node, to_field)
|
result = are_connections_compatible(from_node, from_field, to_node, to_field)
|
||||||
@ -39,7 +39,7 @@ def test_connections_are_incompatible():
|
|||||||
def test_connections_incompatible_with_invalid_fields():
|
def test_connections_incompatible_with_invalid_fields():
|
||||||
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
from_field = "invalid_field"
|
from_field = "invalid_field"
|
||||||
to_node = UpscaleInvocation(id = "2")
|
to_node = RealESRGANInvocation(id = "2")
|
||||||
to_field = "image"
|
to_field = "image"
|
||||||
|
|
||||||
# From field is invalid
|
# From field is invalid
|
||||||
@ -86,10 +86,10 @@ def test_graph_fails_to_update_node_if_type_changes():
|
|||||||
g = Graph()
|
g = Graph()
|
||||||
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
g.add_node(n)
|
g.add_node(n)
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
|
|
||||||
nu = UpscaleInvocation(id = "1")
|
nu = RealESRGANInvocation(id = "1")
|
||||||
|
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
g.update_node("1", nu)
|
g.update_node("1", nu)
|
||||||
@ -98,7 +98,7 @@ def test_graph_allows_non_conflicting_id_change():
|
|||||||
g = Graph()
|
g = Graph()
|
||||||
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
g.add_node(n)
|
g.add_node(n)
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
e1 = create_edge(n.id,"image",n2.id,"image")
|
e1 = create_edge(n.id,"image",n2.id,"image")
|
||||||
g.add_edge(e1)
|
g.add_edge(e1)
|
||||||
@ -128,7 +128,7 @@ def test_graph_fails_to_update_node_id_if_conflict():
|
|||||||
def test_graph_adds_edge():
|
def test_graph_adds_edge():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
e = create_edge(n1.id,"image",n2.id,"image")
|
e = create_edge(n1.id,"image",n2.id,"image")
|
||||||
@ -139,7 +139,7 @@ def test_graph_adds_edge():
|
|||||||
|
|
||||||
def test_graph_fails_to_add_edge_with_cycle():
|
def test_graph_fails_to_add_edge_with_cycle():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = UpscaleInvocation(id = "1")
|
n1 = RealESRGANInvocation(id = "1")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
e = create_edge(n1.id,"image",n1.id,"image")
|
e = create_edge(n1.id,"image",n1.id,"image")
|
||||||
with pytest.raises(InvalidEdgeError):
|
with pytest.raises(InvalidEdgeError):
|
||||||
@ -148,8 +148,8 @@ def test_graph_fails_to_add_edge_with_cycle():
|
|||||||
def test_graph_fails_to_add_edge_with_long_cycle():
|
def test_graph_fails_to_add_edge_with_long_cycle():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
n3 = UpscaleInvocation(id = "3")
|
n3 = RealESRGANInvocation(id = "3")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
g.add_node(n3)
|
g.add_node(n3)
|
||||||
@ -164,7 +164,7 @@ def test_graph_fails_to_add_edge_with_long_cycle():
|
|||||||
def test_graph_fails_to_add_edge_with_missing_node_id():
|
def test_graph_fails_to_add_edge_with_missing_node_id():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
e1 = create_edge("1","image","3","image")
|
e1 = create_edge("1","image","3","image")
|
||||||
@ -177,8 +177,8 @@ def test_graph_fails_to_add_edge_with_missing_node_id():
|
|||||||
def test_graph_fails_to_add_edge_when_destination_exists():
|
def test_graph_fails_to_add_edge_when_destination_exists():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
n3 = UpscaleInvocation(id = "3")
|
n3 = RealESRGANInvocation(id = "3")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
g.add_node(n3)
|
g.add_node(n3)
|
||||||
@ -194,7 +194,7 @@ def test_graph_fails_to_add_edge_when_destination_exists():
|
|||||||
def test_graph_fails_to_add_edge_with_mismatched_types():
|
def test_graph_fails_to_add_edge_with_mismatched_types():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
e1 = create_edge("1","image","2","strength")
|
e1 = create_edge("1","image","2","strength")
|
||||||
@ -344,7 +344,7 @@ def test_graph_iterator_invalid_if_output_and_input_types_different():
|
|||||||
def test_graph_validates():
|
def test_graph_validates():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
e1 = create_edge("1","image","2","image")
|
e1 = create_edge("1","image","2","image")
|
||||||
@ -377,8 +377,8 @@ def test_graph_invalid_if_subgraph_invalid():
|
|||||||
|
|
||||||
def test_graph_invalid_if_has_cycle():
|
def test_graph_invalid_if_has_cycle():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = UpscaleInvocation(id = "1")
|
n1 = RealESRGANInvocation(id = "1")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
g.nodes[n1.id] = n1
|
g.nodes[n1.id] = n1
|
||||||
g.nodes[n2.id] = n2
|
g.nodes[n2.id] = n2
|
||||||
e1 = create_edge("1","image","2","image")
|
e1 = create_edge("1","image","2","image")
|
||||||
@ -391,7 +391,7 @@ def test_graph_invalid_if_has_cycle():
|
|||||||
def test_graph_invalid_with_invalid_connection():
|
def test_graph_invalid_with_invalid_connection():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
g.nodes[n1.id] = n1
|
g.nodes[n1.id] = n1
|
||||||
g.nodes[n2.id] = n2
|
g.nodes[n2.id] = n2
|
||||||
e1 = create_edge("1","image","2","strength")
|
e1 = create_edge("1","image","2","strength")
|
||||||
@ -503,7 +503,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
|
|||||||
|
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
|
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
|
|
||||||
with pytest.raises(NodeNotFoundError):
|
with pytest.raises(NodeNotFoundError):
|
||||||
@ -512,7 +512,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
|
|||||||
def test_graph_gets_networkx_graph():
|
def test_graph_gets_networkx_graph():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
e = create_edge(n1.id,"image",n2.id,"image")
|
e = create_edge(n1.id,"image",n2.id,"image")
|
||||||
@ -529,7 +529,7 @@ def test_graph_gets_networkx_graph():
|
|||||||
def test_graph_can_serialize():
|
def test_graph_can_serialize():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
e = create_edge(n1.id,"image",n2.id,"image")
|
e = create_edge(n1.id,"image",n2.id,"image")
|
||||||
@ -541,7 +541,7 @@ def test_graph_can_serialize():
|
|||||||
def test_graph_can_deserialize():
|
def test_graph_can_deserialize():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = RealESRGANInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
e = create_edge(n1.id,"image",n2.id,"image")
|
e = create_edge(n1.id,"image",n2.id,"image")
|
||||||
|
Loading…
Reference in New Issue
Block a user