Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Compare commits: chore/pyda... → release/v3 (2 commits)

Commits: 77eed8712e, 5c5813bbe0

README.md (25 changed lines)
@@ -161,7 +161,7 @@ the command `npm install -g yarn` if needed)

 _For Windows/Linux with an NVIDIA GPU:_

     ```terminal
-    pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
+    pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
     ```

 _For Linux with an AMD GPU:_
@@ -306,30 +306,13 @@ InvokeAI. The second will prepare the 2.3 directory for use with 3.0.

 You may now launch the WebUI in the usual way, by selecting option [1]
 from the launcher script

-#### Migrating Images
+#### Migration Caveats

 The migration script will migrate your invokeai settings and models,
 including textual inversion models, LoRAs and merges that you may have
 installed previously. However it does **not** migrate the generated
-images stored in your 2.3-format outputs directory. To do this, you
-need to run an additional step:
-
-1. From a working InvokeAI 3.0 root directory, start the launcher and
-   enter menu option [8] to open the "developer's console".
-
-2. At the developer's console command line, type the command:
-
-   ```bash
-   invokeai-import-images
-   ```
-
-3. This will lead you through the process of confirming the desired
-   source and destination for the imported images. The images will
-   appear in the gallery board of your choice, and contain the
-   original prompt, model name, and other parameters used to generate
-   the image.
-
-(Many kudos to **techjedi** for contributing this script.)
+images stored in your 2.3-format outputs directory. You will need to
+manually import selected images into the 3.0 gallery via drag-and-drop.

 ## Hardware Requirements
@@ -264,7 +264,7 @@ experimental versions later.

 you can create several levels of subfolders and drop your models into
 whichever ones you want.

-- ***LICENSE***
+- ***Autoimport FolderLICENSE***

 At the bottom of the screen you will see a checkbox for accepting
 the CreativeML Responsible AI Licenses. You need to accept the license
@@ -471,7 +471,7 @@ Then type the following commands:

 === "NVIDIA System"
     ```bash
-    pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu118
+    pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu117
     pip install xformers
     ```
@@ -148,7 +148,7 @@ manager, please follow these steps:

 === "CUDA (NVidia)"

     ```bash
-    pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
+    pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
     ```

 === "ROCm (AMD)"

@@ -312,7 +312,7 @@ installation protocol (important!)

 === "CUDA (NVidia)"
     ```bash
-    pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
+    pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
     ```

 === "ROCm (AMD)"
@@ -356,7 +356,7 @@ you can do so using this unsupported recipe:

 mkdir ~/invokeai
 conda create -n invokeai python=3.10
 conda activate invokeai
-pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
+pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
 invokeai-configure --root ~/invokeai
 invokeai --root ~/invokeai --web
 ```
@@ -34,11 +34,11 @@ directly from NVIDIA. **Do not try to install Ubuntu's
 nvidia-cuda-toolkit package. It is out of date and will cause
 conflicts among the NVIDIA driver and binaries.**

-Go to [CUDA Toolkit
-Downloads](https://developer.nvidia.com/cuda-downloads), and use the
-target selection wizard to choose your operating system, hardware
-platform, and preferred installation method (e.g. "local" versus
-"network").
+Go to [CUDA Toolkit 11.7
+Downloads](https://developer.nvidia.com/cuda-11-7-0-download-archive),
+and use the target selection wizard to choose your operating system,
+hardware platform, and preferred installation method (e.g. "local"
+versus "network").

 This will provide you with a downloadable install file or, depending
 on your choices, a recipe for downloading and running an install shell

@@ -61,7 +61,7 @@ Runtime Site](https://developer.nvidia.com/nvidia-container-runtime)

 When installing torch and torchvision manually with `pip`, remember to provide
 the argument `--extra-index-url
-https://download.pytorch.org/whl/cu118` as described in the [Manual
+https://download.pytorch.org/whl/cu117` as described in the [Manual
 Installation Guide](020_INSTALL_MANUAL.md).

 ## :simple-amd: ROCm
@@ -28,21 +28,18 @@ command line, then just be sure to activate its virtual environment.

 Then run the following three commands:

 ```sh
-pip install xformers~=0.0.19
-pip install triton  # WON'T WORK ON WINDOWS
+pip install xformers==0.0.16rc425
+pip install triton
 python -m xformers.info output
 ```

 The first command installs `xformers`, the second installs the
 `triton` training accelerator, and the third prints out the `xformers`
-installation status. On Windows, please omit the `triton` package,
-which is not available on that platform.
-
-If all goes well, you'll see a report like the
+installation status. If all goes well, you'll see a report like the
 following:

 ```sh
-xFormers 0.0.20
+xFormers 0.0.16rc425
 memory_efficient_attention.cutlassF: available
 memory_efficient_attention.cutlassB: available
 memory_efficient_attention.flshattF: available

@@ -51,28 +48,22 @@ memory_efficient_attention.smallkF: available
 memory_efficient_attention.smallkB: available
 memory_efficient_attention.tritonflashattF: available
 memory_efficient_attention.tritonflashattB: available
 indexing.scaled_index_addF: available
 indexing.scaled_index_addB: available
 indexing.index_select: available
 swiglu.dual_gemm_silu: available
 swiglu.gemm_fused_operand_sum: available
 swiglu.fused.p.cpp: available
 is_triton_available: True
 is_functorch_available: False
-pytorch.version: 2.0.1+cu118
+pytorch.version: 1.13.1+cu117
 pytorch.cuda: available
-gpu.compute_capability: 8.9
-gpu.name: NVIDIA GeForce RTX 4070
+gpu.compute_capability: 8.6
+gpu.name: NVIDIA RTX A2000 12GB
 build.info: available
-build.cuda_version: 1108
-build.python_version: 3.10.11
-build.torch_version: 2.0.1+cu118
+build.cuda_version: 1107
+build.python_version: 3.10.9
+build.torch_version: 1.13.1+cu117
 build.env.TORCH_CUDA_ARCH_LIST: 5.0+PTX 6.0 6.1 7.0 7.5 8.0 8.6
 build.env.XFORMERS_BUILD_TYPE: Release
 build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS: None
 build.env.NVCC_FLAGS: None
-build.env.XFORMERS_PACKAGE_FROM: wheel-v0.0.20
-build.nvcc_version: 11.8.89
+build.env.XFORMERS_PACKAGE_FROM: wheel-v0.0.16rc425
 source.privacy: open source
 ```
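The report above is what `python -m xformers.info` prints. For a quicker programmatic check, a minimal sketch (assuming only that `torch` and `xformers` are installed in the active environment; this is an illustration, not InvokeAI code):

```python
# Minimal availability check for the xformers report shown above.
import torch

try:
    import xformers
    import xformers.ops  # the memory-efficient attention operators
    print(f"xFormers {xformers.__version__} with torch {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
except ImportError as exc:
    print(f"xformers is not usable: {exc}")
```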
@@ -92,14 +83,14 @@ installed from source. These instructions were written for a system
 running Ubuntu 22.04, but other Linux distributions should be able to
 adapt this recipe.

-#### 1. Install CUDA Toolkit 11.8
+#### 1. Install CUDA Toolkit 11.7

 You will need the CUDA developer's toolkit in order to compile and
 install xFormers. **Do not try to install Ubuntu's nvidia-cuda-toolkit
 package.** It is out of date and will cause conflicts among the NVIDIA
 driver and binaries. Instead install the CUDA Toolkit package provided
-by NVIDIA itself. Go to [CUDA Toolkit 11.8
-Downloads](https://developer.nvidia.com/cuda-11-8-0-download-archive)
+by NVIDIA itself. Go to [CUDA Toolkit 11.7
+Downloads](https://developer.nvidia.com/cuda-11-7-0-download-archive)
 and use the target selection wizard to choose your platform and Linux
 distribution. Select an installer type of "runfile (local)" at the
 last step.

@@ -110,17 +101,17 @@ example, the install script recipe for Ubuntu 22.04 running on a
 x86_64 system is:

 ```
-wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
-sudo sh cuda_11.8.0_520.61.05_linux.run
+wget https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run
+sudo sh cuda_11.7.0_515.43.04_linux.run
 ```

 Rather than cut-and-paste this example, we recommend that you walk
 through the toolkit wizard in order to get the most up-to-date
 installer for your system.

-#### 2. Confirm/Install pyTorch 2.0.1 with CUDA 11.8 support
+#### 2. Confirm/Install pyTorch 1.13 with CUDA 11.7 support

-If you are using InvokeAI 3.0.2 or higher, these will already be
+If you are using InvokeAI 2.3 or higher, these will already be
 installed. If not, you can check whether you have the needed libraries
 using a quick command. Activate the invokeai virtual environment,
 either by entering the "developer's console", or manually with a

@@ -133,7 +124,7 @@ Then run the command:
 python -c 'exec("import torch\nprint(torch.__version__)")'
 ```

-If it prints __1.13.1+cu118__ you're good. If not, you can install the
+If it prints __1.13.1+cu117__ you're good. If not, you can install the
 most up-to-date libraries with this command:

 ```sh
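The `python -c 'exec(...)'` one-liner above only prints the torch version string. An equivalent sketch that also shows which CUDA toolkit the wheel was built against (standard `torch` attributes, nothing InvokeAI-specific):

```python
import torch

print(torch.__version__)          # e.g. "1.13.1+cu117" or "2.0.1+cu118"
print(torch.version.cuda)         # CUDA version the wheel was compiled for
print(torch.cuda.is_available())  # True only if a working GPU driver is present
```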
@@ -348,7 +348,7 @@ class InvokeAiInstance:

         introduction()

-        from invokeai.frontend.install.invokeai_configure import invokeai_configure
+        from invokeai.frontend.install import invokeai_configure

         # NOTE: currently the config script does its own arg parsing! this means the command-line switches
         # from the installer will also automatically propagate down to the config script.

@@ -463,10 +463,10 @@ def get_torch_source() -> (Union[str, None], str):
         url = "https://download.pytorch.org/whl/cpu"

     if device == "cuda":
-        url = "https://download.pytorch.org/whl/cu118"
+        url = "https://download.pytorch.org/whl/cu117"
         optional_modules = "[xformers,onnx-cuda]"
     if device == "cuda_and_dml":
-        url = "https://download.pytorch.org/whl/cu118"
+        url = "https://download.pytorch.org/whl/cu117"
         optional_modules = "[xformers,onnx-directml]"

     # in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
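For context, `get_torch_source()` maps the detected device type to a pip `--extra-index-url` and a set of optional extras. A simplified sketch of the logic visible in the hunk above; the `"[onnx]"` default and the function name `torch_source_for` are assumptions for illustration, not the installer's exact code:

```python
from typing import Optional, Tuple

def torch_source_for(device: str) -> Tuple[Optional[str], str]:
    """Pick the PyTorch wheel index and InvokeAI extras for a device type."""
    url: Optional[str] = None
    optional_modules = "[onnx]"  # assumed default; not shown in the hunk
    if device == "cpu":
        url = "https://download.pytorch.org/whl/cpu"
    elif device == "cuda":
        url = "https://download.pytorch.org/whl/cu117"  # the "+" side of this diff
        optional_modules = "[xformers,onnx-cuda]"
    elif device == "cuda_and_dml":
        url = "https://download.pytorch.org/whl/cu117"
        optional_modules = "[xformers,onnx-directml]"
    # in all other cases, torch wheels come from PyPI
    return url, optional_modules
```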
@@ -8,13 +8,16 @@ Preparations:
 to work. Instructions are given here:
 https://invoke-ai.github.io/InvokeAI/installation/INSTALL_AUTOMATED/

+NOTE: At this time we do not recommend Python 3.11. We recommend
+Version 3.10.9, which has been extensively tested with InvokeAI.
+
 Before you start the installer, please open up your system's command
 line window (Terminal or Command) and type the commands:

 python --version

 If all is well, it will print "Python 3.X.X", where the version number
-is at least 3.9.*, and not higher than 3.11.*.
+is at least 3.9.1, and less than 3.11.

 If this works, check the version of the Python package manager, pip:
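The version bounds stated above ("at least 3.9.1, and less than 3.11") translate to a simple interpreter check; a sketch:

```python
import sys

# Gate matching the documented range; the docs recommend 3.10.9 specifically.
if not ((3, 9) <= sys.version_info[:2] <= (3, 10)):
    raise SystemExit(
        f"Python {sys.version.split()[0]} is unsupported; "
        "use a Python >= 3.9 and < 3.11."
    )
```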
@@ -104,12 +104,8 @@ async def update_model(
         ):  # model manager moved model path during rename - don't overwrite it
             info.path = new_info.get("path")

-        # replace empty string values with None/null to avoid phenomenon of vae: ''
-        info_dict = info.model_dump()
-        info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
-
         ApiDependencies.invoker.services.model_manager.update_model(
-            model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info_dict
+            model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info.dict()
         )

         model_raw = ApiDependencies.invoker.services.model_manager.list_model(

@@ -203,7 +199,7 @@ async def add_model(

     try:
         ApiDependencies.invoker.services.model_manager.add_model(
-            info.model_name, info.base_model, info.model_type, model_attributes=info.model_dump()
+            info.model_name, info.base_model, info.model_type, model_attributes=info.dict()
        )
         logger.info(f"Successfully added {info.model_name}")
         model_raw = ApiDependencies.invoker.services.model_manager.list_model(
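The `.model_dump()`/`.dict()` pairs in this and the following hunks are the same rename everywhere: pydantic v2 renamed `BaseModel.dict()` to `model_dump()`. A minimal illustration (generic pydantic, not InvokeAI's models; the field names here are invented for the sketch):

```python
from pydantic import BaseModel

class Info(BaseModel):
    name: str
    vae: str = ""

info = Info(name="stable-diffusion-1.5")
# pydantic v1 (the "+" side of this diff):  info.dict()
# pydantic v2 (the "-" side):               info.model_dump()
# Both return {"name": "stable-diffusion-1.5", "vae": ""}.
```

Note the "-" side also normalized empty strings to None before passing the dict along; the "+" side passes `info.dict()` through unchanged.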
@@ -348,7 +344,7 @@ async def sync_to_config() -> bool:
 )
 async def merge_models(
     base_model: BaseModelType = Path(description="Base model"),
-    model_names: List[str] = Body(description="model name", min_length=2, max_length=3),
+    model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
     merged_model_name: Optional[str] = Body(description="Name of destination model"),
     alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
     interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
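The constraint rename in this hunk follows the same v1/v2 split: pydantic v1 bounds list length with `min_items`/`max_items`, which v2 folded into `min_length`/`max_length`. A v1-style sketch, with plain pydantic `Field` standing in for FastAPI's `Body` (which forwards the same keywords):

```python
from typing import List
from pydantic import BaseModel, Field, ValidationError

class MergeRequest(BaseModel):
    model_names: List[str] = Field(description="model name", min_items=2, max_items=3)

MergeRequest(model_names=["model-a", "model-b"])  # ok
try:
    MergeRequest(model_names=["model-a"])         # too few items
except ValidationError as err:
    print(err)
```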
@@ -15,7 +15,7 @@ from fastapi.staticfiles import StaticFiles
 from fastapi_events.handlers.local import local_handler
 from fastapi_events.middleware import EventHandlerASGIMiddleware
 from pathlib import Path
-#from pydantic.schema import schema
+from pydantic.schema import schema

 # This should come early so that modules can log their initialization properly
 from .services.config import InvokeAIAppConfig

@@ -126,13 +126,13 @@ def custom_openapi():
         output_type = signature(invoker.invoke).return_annotation
         output_types.add(output_type)

-    # output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
-    # for schema_key, output_schema in output_schemas["definitions"].items():
-    #     openapi_schema["components"]["schemas"][schema_key] = output_schema
+    output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
+    for schema_key, output_schema in output_schemas["definitions"].items():
+        openapi_schema["components"]["schemas"][schema_key] = output_schema

-    # # TODO: note that we assume the schema_key here is the TYPE.__name__
-    # # This could break in some cases, figure out a better way to do it
-    # # output_type_titles[schema_key] = output_schema["title"]
+        # TODO: note that we assume the schema_key here is the TYPE.__name__
+        # This could break in some cases, figure out a better way to do it
+        output_type_titles[schema_key] = output_schema["title"]

     # Add a reference to the output type to additionalProperties of the invoker schema
     for invoker in all_invocations:
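The block being un-commented here relies on pydantic v1's top-level `schema()` helper, which JSON-schemas a collection of models into a single `"definitions"` dict; the helper was removed in v2, which is why the "-" side has it commented out. A minimal sketch of what it returns (illustrative model names):

```python
from pydantic import BaseModel
from pydantic.schema import schema  # pydantic v1 only

class ImageOutput(BaseModel):
    image_name: str

class LatentsOutput(BaseModel):
    latents_name: str

bundle = schema([ImageOutput, LatentsOutput], ref_prefix="#/components/schemas/")
for schema_key, output_schema in bundle["definitions"].items():
    print(schema_key, "->", output_schema["title"])
```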
@@ -67,7 +67,7 @@ def add_parsers(
         add_arguments(command_parser)

         # Convert all fields to arguments
-        fields = command.model_fields  # type: ignore
+        fields = command.__fields__  # type: ignore
         for name, field in fields.items():
             if name in exclude_fields:
                 continue

@@ -87,7 +87,7 @@ def add_graph_parsers(
         # Add arguments for inputs
         for exposed_input in graph.exposed_inputs:
             node = graph.graph.get_node(exposed_input.node_path)
-            field = node.model_fields[exposed_input.field]
+            field = node.__fields__[exposed_input.field]
             default_override = getattr(node, exposed_input.field)
             add_field_argument(command_parser, exposed_input.alias, field, default_override)

@@ -194,7 +194,7 @@ def get_graph_execution_history(


 def get_invocation_command(invocation) -> str:
-    fields = invocation.model_fields.items()
+    fields = invocation.__fields__.items()
     type_hints = get_type_hints(type(invocation))
     command = [invocation.type]
     for name, field in fields:
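The `model_fields` → `__fields__` swaps in these CLI hunks are field introspection: pydantic v1 exposes per-field metadata as `Model.__fields__` (a dict of `ModelField` objects), which v2 renamed to `Model.model_fields`. A v1-style sketch of the kind of iteration `add_parsers` performs (the model here is invented for illustration):

```python
from pydantic import BaseModel, Field

class RangeCommand(BaseModel):
    start: int = Field(default=0, description="The start of the range")
    stop: int = Field(default=10, description="The stop of the range")

for name, field in RangeCommand.__fields__.items():
    # In v1, `field` is a pydantic.fields.ModelField
    print(name, field.outer_type_, field.field_info.description)
```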
@@ -118,7 +118,7 @@ class CustomisedSchemaExtra(TypedDict):
 class InvocationConfig(BaseConfig):
     """Customizes pydantic's BaseModel.Config class for use by Invocations.

-    Provide `json_schema_extra` a `ui` dict to add hints for generated UIs.
+    Provide `schema_extra` a `ui` dict to add hints for generated UIs.

     `tags`
     - A list of strings, used to categorise invocations.

@@ -131,7 +131,7 @@ class InvocationConfig(BaseConfig):

     ```python
     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {
                 "tags": ["stable-diffusion", "image"],
                 "type_hints": {

@@ -142,4 +142,4 @@ class InvocationConfig(BaseConfig):
     ```
     """

-    json_schema_extra: CustomisedSchemaExtra
+    schema_extra: CustomisedSchemaExtra
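For context on this rename, which repeats through the rest of the diff: in pydantic v1, a dict assigned to `Config.schema_extra` is merged into the model's generated JSON schema, which is how these `ui` hints reach generated UIs; v2 moved the same hook to `json_schema_extra`. A v1-style sketch:

```python
from pydantic import BaseModel

class RangeInvocation(BaseModel):
    start: int = 0

    class Config:
        schema_extra = {"ui": {"title": "Range", "tags": ["range"]}}

print(RangeInvocation.schema()["ui"])
# -> {'title': 'Range', 'tags': ['range']}
```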
@@ -3,7 +3,7 @@
 from typing import Literal

 import numpy as np
-from pydantic import Field, field_validator
+from pydantic import Field, validator

 from invokeai.app.models.image import ImageField
 from invokeai.app.util.misc import SEED_MAX, get_random_seed

@@ -38,7 +38,7 @@ class ImageCollectionOutput(BaseInvocationOutput):
     collection: list[ImageField] = Field(default=[], description="The output images")

     class Config:
-        json_schema_extra = {"required": ["type", "collection"]}
+        schema_extra = {"required": ["type", "collection"]}


 class RangeInvocation(BaseInvocation):
@@ -52,11 +52,11 @@ class RangeInvocation(BaseInvocation):
     step: int = Field(default=1, description="The step of the range")

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Range", "tags": ["range", "integer", "collection"]},
         }

-    @field_validator("stop")
+    @validator("stop")
     def stop_gt_start(cls, v, values):
         if "start" in values and v <= values["start"]:
             raise ValueError("stop must be greater than start")
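The decorator rename above is the validator half of the same migration: pydantic v1's `@validator` passes previously validated fields in as a `values` dict, while v2's `@field_validator` supplies a `ValidationInfo` object instead. A v1-style sketch matching the hunk:

```python
from pydantic import BaseModel, ValidationError, validator

class Range(BaseModel):
    start: int = 0
    stop: int = 10

    @validator("stop")
    def stop_gt_start(cls, v, values):
        if "start" in values and v <= values["start"]:
            raise ValueError("stop must be greater than start")
        return v

try:
    Range(start=5, stop=3)
except ValidationError as err:
    print(err)  # "stop must be greater than start"
```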
@@ -77,7 +77,7 @@ class RangeOfSizeInvocation(BaseInvocation):
     step: int = Field(default=1, description="The step of the range")

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Sized Range", "tags": ["range", "integer", "size", "collection"]},
         }

@@ -102,7 +102,7 @@ class RandomRangeInvocation(BaseInvocation):
     )

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Random Range", "tags": ["range", "integer", "random", "collection"]},
         }

@@ -127,7 +127,7 @@ class ImageCollectionInvocation(BaseInvocation):
         return ImageCollectionOutput(collection=self.images)

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {
                 "type_hints": {
                     "title": "Image Collection",
@@ -26,7 +26,7 @@ class ConditioningField(BaseModel):
     conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")

     class Config:
-        json_schema_extra = {"required": ["conditioning_name"]}
+        schema_extra = {"required": ["conditioning_name"]}


 @dataclass

@@ -80,18 +80,18 @@ class CompelInvocation(BaseInvocation):

     # Schema customisation
     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
         }

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> CompelOutput:
         tokenizer_info = context.services.model_manager.get_model(
-            **self.clip.tokenizer.model_dump(),
+            **self.clip.tokenizer.dict(),
             context=context,
         )
         text_encoder_info = context.services.model_manager.get_model(
-            **self.clip.text_encoder.model_dump(),
+            **self.clip.text_encoder.dict(),
             context=context,
         )

@@ -178,11 +178,11 @@ class CompelInvocation(BaseInvocation):
 class SDXLPromptInvocationBase:
     def run_clip_raw(self, context, clip_field, prompt, get_pooled, lora_prefix):
         tokenizer_info = context.services.model_manager.get_model(
-            **clip_field.tokenizer.model_dump(),
+            **clip_field.tokenizer.dict(),
             context=context,
         )
         text_encoder_info = context.services.model_manager.get_model(
-            **clip_field.text_encoder.model_dump(),
+            **clip_field.text_encoder.dict(),
             context=context,
         )

@@ -255,11 +255,11 @@ class SDXLPromptInvocationBase:

     def run_clip_compel(self, context, clip_field, prompt, get_pooled, lora_prefix):
         tokenizer_info = context.services.model_manager.get_model(
-            **clip_field.tokenizer.model_dump(),
+            **clip_field.tokenizer.dict(),
             context=context,
         )
         text_encoder_info = context.services.model_manager.get_model(
-            **clip_field.text_encoder.model_dump(),
+            **clip_field.text_encoder.dict(),
             context=context,
         )

@@ -360,7 +360,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):

     # Schema customisation
     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "SDXL Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
         }

@@ -414,7 +414,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase

     # Schema customisation
     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {
                 "title": "SDXL Refiner Prompt (Compel)",
                 "tags": ["prompt", "compel"],

@@ -471,7 +471,7 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):

     # Schema customisation
     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "SDXL Prompt (Raw)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
         }

@@ -525,7 +525,7 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):

     # Schema customisation
     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {
                 "title": "SDXL Refiner Prompt (Raw)",
                 "tags": ["prompt", "compel"],

@@ -580,7 +580,7 @@ class ClipSkipInvocation(BaseInvocation):
     skipped_layers: int = Field(0, description="Number of layers to skip in text_encoder")

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "CLIP Skip", "tags": ["clip", "skip"]},
         }
@@ -24,7 +24,7 @@ from controlnet_aux import (
 )
 from controlnet_aux.util import HWC3, ade_palette
 from PIL import Image
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field, validator

 from ...backend.model_management import BaseModelType, ModelType
 from ..models.image import ImageCategory, ImageField, ResourceOrigin

@@ -123,7 +123,7 @@ class ControlField(BaseModel):
     control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
     resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")

-    @field_validator("control_weight")
+    @validator("control_weight")
     def validate_control_weight(cls, v):
         """Validate that all control weights in the valid range"""
         if isinstance(v, list):

@@ -136,7 +136,7 @@ class ControlField(BaseModel):
         return v

     class Config:
-        json_schema_extra = {
+        schema_extra = {
             "required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"],
             "ui": {
                 "type_hints": {

@@ -176,7 +176,7 @@ class ControlNetInvocation(BaseInvocation):
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {
                 "title": "ControlNet",
                 "tags": ["controlnet", "latents"],

@@ -214,7 +214,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Image Processor", "tags": ["image", "processor"]},
         }

@@ -266,7 +266,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Canny Processor", "tags": ["controlnet", "canny", "image", "processor"]},
         }

@@ -290,7 +290,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig)
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Softedge(HED) Processor", "tags": ["controlnet", "softedge", "hed", "image", "processor"]},
         }

@@ -319,7 +319,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCon
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Lineart Processor", "tags": ["controlnet", "lineart", "image", "processor"]},
         }

@@ -342,7 +342,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocati
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {
                 "title": "Lineart Anime Processor",
                 "tags": ["controlnet", "lineart", "anime", "image", "processor"],

@@ -371,7 +371,7 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Openpose Processor", "tags": ["controlnet", "openpose", "image", "processor"]},
         }

@@ -399,7 +399,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocation
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Midas (Depth) Processor", "tags": ["controlnet", "midas", "depth", "image", "processor"]},
         }

@@ -426,7 +426,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationC
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Normal BAE Processor", "tags": ["controlnet", "normal", "bae", "image", "processor"]},
         }

@@ -451,7 +451,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "MLSD Processor", "tags": ["controlnet", "mlsd", "image", "processor"]},
         }

@@ -480,7 +480,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "PIDI Processor", "tags": ["controlnet", "pidi", "image", "processor"]},
         }

@@ -510,7 +510,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvoca
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {
                 "title": "Content Shuffle Processor",
                 "tags": ["controlnet", "contentshuffle", "image", "processor"],

@@ -539,7 +539,7 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Zoe (Depth) Processor", "tags": ["controlnet", "zoe", "depth", "image", "processor"]},
         }

@@ -560,7 +560,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Mediapipe Processor", "tags": ["controlnet", "mediapipe", "image", "processor"]},
         }

@@ -588,7 +588,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Leres (Depth) Processor", "tags": ["controlnet", "leres", "depth", "image", "processor"]},
         }

@@ -614,7 +614,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {
                 "title": "Tile Resample Processor",
                 "tags": ["controlnet", "tile", "resample", "image", "processor"],

@@ -656,7 +656,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocation
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {
                 "title": "Segment Anything Processor",
                 "tags": ["controlnet", "segment", "anything", "sam", "image", "processor"],
@@ -17,7 +17,7 @@ class CvInvocationConfig(BaseModel):

     # Schema customisation
     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {
                 "tags": ["cv", "image"],
             },

@@ -36,7 +36,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "OpenCV Inpaint", "tags": ["opencv", "inpaint"]},
         }
@@ -1,23 +1,26 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

-from contextlib import contextmanager, ContextDecorator
 from functools import partial
 from typing import Literal, Optional, get_args

 import torch
 from pydantic import Field

 from invokeai.app.models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
 from invokeai.app.util.misc import SEED_MAX, get_random_seed
 from invokeai.backend.generator.inpaint import infill_methods
-from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
-from .compel import ConditioningField
-from .image import ImageOutput
-from .model import UNetField, VaeField
-from ..util.step_callback import stable_diffusion_step_callback

 from ...backend.generator import Inpaint, InvokeAIGenerator
-from ...backend.model_management.lora import ModelPatcher
 from ...backend.stable_diffusion import PipelineIntermediateState
+from ..util.step_callback import stable_diffusion_step_callback
+from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
+from .image import ImageOutput
+
+from ...backend.model_management.lora import ModelPatcher
+from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
+from .model import UNetField, VaeField
+from .compel import ConditioningField
+from contextlib import contextmanager, ExitStack, ContextDecorator

 SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
 INFILL_METHODS = Literal[tuple(infill_methods())]

@@ -129,7 +132,7 @@ class InpaintInvocation(BaseInvocation):

     # Schema customisation
     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"tags": ["stable-diffusion", "image"], "title": "Inpaint"},
         }

@@ -142,7 +145,7 @@ class InpaintInvocation(BaseInvocation):
         stable_diffusion_step_callback(
             context=context,
             intermediate_state=intermediate_state,
-            node=self.model_dump(),
+            node=self.dict(),
             source_node_id=source_node_id,
         )

@@ -169,11 +172,11 @@ class InpaintInvocation(BaseInvocation):
             return

         unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.model_dump(),
+            **self.unet.unet.dict(),
             context=context,
         )
         vae_info = context.services.model_manager.get_model(
-            **self.vae.vae.model_dump(),
+            **self.vae.vae.dict(),
             context=context,
         )

@@ -181,8 +184,6 @@ class InpaintInvocation(BaseInvocation):
         device = context.services.model_manager.mgr.cache.execution_device
         dtype = context.services.model_manager.mgr.cache.precision

-        vae.to(dtype=unet.dtype)
-
         pipeline = StableDiffusionGeneratorPipeline(
             vae=vae,
             text_encoder=None,

@@ -192,6 +193,8 @@ class InpaintInvocation(BaseInvocation):
             safety_checker=None,
             feature_extractor=None,
             requires_safety_checker=False,
+            precision="float16" if dtype == torch.float16 else "float32",
+            execution_device=device,
         )

         yield OldModelInfo(
@@ -39,7 +39,7 @@ class LoadImageInvocation(BaseInvocation):
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Load Image", "tags": ["image", "load"]},
         }

@@ -62,7 +62,7 @@ class ShowImageInvocation(BaseInvocation):
     image: Optional[ImageField] = Field(default=None, description="The image to show")

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Show Image", "tags": ["image", "show"]},
         }

@@ -95,7 +95,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Crop Image", "tags": ["image", "crop"]},
         }

@@ -136,7 +136,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Paste Image", "tags": ["image", "paste"]},
         }

@@ -185,7 +185,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Mask From Alpha", "tags": ["image", "mask", "alpha"]},
         }

@@ -224,7 +224,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Multiply Images", "tags": ["image", "multiply"]},
         }

@@ -265,7 +265,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Image Channel", "tags": ["image", "channel"]},
         }

@@ -305,7 +305,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Convert Image", "tags": ["image", "convert"]},
         }

@@ -343,7 +343,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Blur Image", "tags": ["image", "blur"]},
         }

@@ -405,7 +405,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Resize Image", "tags": ["image", "resize"]},
         }

@@ -448,7 +448,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Scale Image", "tags": ["image", "scale"]},
         }

@@ -493,7 +493,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Image Linear Interpolation", "tags": ["image", "linear", "interpolation", "lerp"]},
         }
@@ -501,7 +501,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
         image = context.services.images.get_pil_image(self.image.image_name)

         image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
-        image_arr = image_arr * (self.max - self.min) + self.min
+        image_arr = image_arr * (self.max - self.min) + self.max

         lerp_image = Image.fromarray(numpy.uint8(image_arr))
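For reference on the changed line: mapping a normalized array `t` in [0, 1] onto a target range [lo, hi] is `t * (hi - lo) + lo`, so the variant ending in `+ self.min` is the one satisfying lerp(0) = min and lerp(1) = max. A standalone sketch:

```python
import numpy as np

def lerp(t: np.ndarray, lo: float, hi: float) -> np.ndarray:
    """Linearly map values in [0, 1] onto [lo, hi]."""
    return t * (hi - lo) + lo

pixels = np.asarray([0.0, 0.5, 1.0], dtype=np.float32)
print(lerp(pixels, 64.0, 192.0))  # [ 64. 128. 192.]
```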
@@ -534,7 +534,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {
                 "title": "Image Inverse Linear Interpolation",
                 "tags": ["image", "linear", "interpolation", "inverse"],

@@ -577,7 +577,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Blur NSFW Images", "tags": ["image", "nsfw", "checker"]},
         }

@@ -600,7 +600,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.model_dump() if self.metadata else None,
+            metadata=self.metadata.dict() if self.metadata else None,
         )

         return ImageOutput(

@@ -629,7 +629,7 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Add Invisible Watermark", "tags": ["image", "watermark", "invisible"]},
         }

@@ -643,7 +643,7 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.model_dump() if self.metadata else None,
+            metadata=self.metadata.dict() if self.metadata else None,
         )

         return ImageOutput(
@@ -125,7 +125,7 @@ class InfillColorInvocation(BaseInvocation):
     )

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Color Infill", "tags": ["image", "inpaint", "color", "infill"]},
         }

@@ -168,7 +168,7 @@ class InfillTileInvocation(BaseInvocation):
     )

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Tile Infill", "tags": ["image", "inpaint", "tile", "infill"]},
         }

@@ -202,7 +202,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
     image: Optional[ImageField] = Field(default=None, description="The image to infill")

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Patch Match Infill", "tags": ["image", "inpaint", "patchmatch", "infill"]},
         }
@@ -5,26 +5,15 @@ from typing import List, Literal, Optional, Union

 import einops
 import torch
 from diffusers import ControlNetModel
 from diffusers.image_processor import VaeImageProcessor
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    XFormersAttnProcessor,
-)
 from diffusers.schedulers import SchedulerMixin as Scheduler
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field, validator

 from invokeai.app.invocations.metadata import CoreMetadata
-from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend.model_management.models import ModelType, SilenceWarnings
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
-from .compel import ConditioningField
-from .controlnet_image_processors import ControlField
-from .image import ImageOutput
-from .model import ModelInfo, UNetField, VaeField
-from ..models.image import ImageCategory, ImageField, ResourceOrigin

-from ...backend.model_management import ModelPatcher
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.stable_diffusion.diffusers_pipeline import (

@@ -35,7 +24,23 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
 )
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
+from ...backend.model_management import ModelPatcher
+from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
+from ..models.image import ImageCategory, ImageField, ResourceOrigin
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
+from .compel import ConditioningField
+from .controlnet_image_processors import ControlField
+from .image import ImageOutput
+from .model import ModelInfo, UNetField, VaeField
+from invokeai.app.util.controlnet_utils import prepare_control_image
+
+from diffusers.models.attention_processor import (
+    AttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)

 DEFAULT_PRECISION = choose_precision(choose_torch_device())
@@ -46,7 +51,7 @@ class LatentsField(BaseModel):
     latents_name: Optional[str] = Field(default=None, description="The name of the latents")

     class Config:
-        json_schema_extra = {"required": ["latents_name"]}
+        schema_extra = {"required": ["latents_name"]}


 class LatentsOutput(BaseInvocationOutput):

@@ -80,7 +85,7 @@ def get_scheduler(
 ) -> Scheduler:
     scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
     orig_scheduler_info = context.services.model_manager.get_model(
-        **scheduler_info.model_dump(),
+        **scheduler_info.dict(),
         context=context,
     )
     with orig_scheduler_info as orig_scheduler:

@@ -121,7 +126,7 @@ class TextToLatentsInvocation(BaseInvocation):
     # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
     # fmt: on

-    @field_validator("cfg_scale")
+    @validator("cfg_scale")
     def ge_one(cls, v):
         """validate that all cfg_scale values are >= 1"""
         if isinstance(v, list):

@@ -135,7 +140,7 @@ class TextToLatentsInvocation(BaseInvocation):

     # Schema customisation
     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {
                 "title": "Text To Latents",
                 "tags": ["latents"],

@@ -158,7 +163,7 @@ class TextToLatentsInvocation(BaseInvocation):
         stable_diffusion_step_callback(
             context=context,
             intermediate_state=intermediate_state,
-            node=self.model_dump(),
+            node=self.dict(),
             source_node_id=source_node_id,
         )

@@ -226,6 +231,7 @@ class TextToLatentsInvocation(BaseInvocation):
             safety_checker=None,
             feature_extractor=None,
             requires_safety_checker=False,
+            precision="float16" if unet.dtype == torch.float16 else "float32",
         )

     def prep_control_data(

@@ -327,7 +333,7 @@ class TextToLatentsInvocation(BaseInvocation):
             return

         unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.model_dump(),
+            **self.unet.unet.dict(),
             context=context,
         )
         with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(

@@ -384,7 +390,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):

     # Schema customisation
     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {
                 "title": "Latent To Latents",
                 "tags": ["latents"],

@@ -420,7 +426,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             return

         unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.model_dump(),
+            **self.unet.unet.dict(),
             context=context,
         )
         with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(

@@ -495,7 +501,7 @@ class LatentsToImageInvocation(BaseInvocation):

     # Schema customisation
     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {
                 "title": "Latents To Image",
                 "tags": ["latents", "image"],

@@ -507,7 +513,7 @@ class LatentsToImageInvocation(BaseInvocation):
         latents = context.services.latents.get(self.latents.latents_name)

         vae_info = context.services.model_manager.get_model(
-            **self.vae.vae.model_dump(),
+            **self.vae.vae.dict(),
             context=context,
         )

@@ -565,7 +571,7 @@ class LatentsToImageInvocation(BaseInvocation):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.model_dump() if self.metadata else None,
+            metadata=self.metadata.dict() if self.metadata else None,
         )

         return ImageOutput(

@@ -593,7 +599,7 @@ class ResizeLatentsInvocation(BaseInvocation):
     )

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Resize Latents", "tags": ["latents", "resize"]},
         }

@@ -634,7 +640,7 @@ class ScaleLatentsInvocation(BaseInvocation):
     )

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Scale Latents", "tags": ["latents", "scale"]},
         }

@@ -675,7 +681,7 @@ class ImageToLatentsInvocation(BaseInvocation):

     # Schema customisation
     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Image To Latents", "tags": ["latents", "image"]},
         }

@@ -686,9 +692,9 @@ class ImageToLatentsInvocation(BaseInvocation):
         # )
         image = context.services.images.get_pil_image(self.image.image_name)

-        # vae_info = context.services.model_manager.get_model(**self.vae.vae.model_dump())
+        # vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
         vae_info = context.services.model_manager.get_model(
-            **self.vae.vae.model_dump(),
+            **self.vae.vae.dict(),
             context=context,
         )
@@ -18,7 +18,7 @@ class MathInvocationConfig(BaseModel):

     # Schema customisation
     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {
                 "tags": ["math"],
             }

@@ -53,7 +53,7 @@ class AddInvocation(BaseInvocation, MathInvocationConfig):
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Add", "tags": ["math", "add"]},
         }

@@ -71,7 +71,7 @@ class SubtractInvocation(BaseInvocation, MathInvocationConfig):
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Subtract", "tags": ["math", "subtract"]},
         }

@@ -89,7 +89,7 @@ class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Multiply", "tags": ["math", "multiply"]},
         }

@@ -107,7 +107,7 @@ class DivideInvocation(BaseInvocation, MathInvocationConfig):
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Divide", "tags": ["math", "divide"]},
         }

@@ -127,7 +127,7 @@ class RandomIntInvocation(BaseInvocation):
     # fmt: on

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"title": "Random Integer", "tags": ["math", "random", "integer"]},
         }
@@ -2,7 +2,6 @@ from typing import Literal, Optional, Union

 from pydantic import Field

-from ...version import __version__
 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,

@@ -24,7 +23,6 @@ class LoRAMetadataField(BaseModelExcludeNull):
 class CoreMetadata(BaseModelExcludeNull):
     """Core generation metadata for an image generated in InvokeAI."""

-    app_version: str = Field(default=__version__, description="The version of InvokeAI used to generate this image")
     generation_mode: str = Field(
         description="The generation mode that output this image",
     )

@@ -142,7 +140,7 @@ class MetadataAccumulatorInvocation(BaseInvocation):
     refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {
                 "title": "Metadata Accumulator",
                 "tags": ["image", "metadata", "generation"],

@@ -152,4 +150,4 @@ class MetadataAccumulatorInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
         """Collects and outputs a CoreMetadata object"""

-        return MetadataAccumulatorOutput(metadata=CoreMetadata(**self.model_dump()))
+        return MetadataAccumulatorOutput(metadata=CoreMetadata(**self.dict()))
@@ -73,7 +73,7 @@ class MainModelLoaderInvocation(BaseInvocation):

     # Schema customisation
     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {
                 "title": "Model Loader",
                 "tags": ["model", "loader"],

@@ -205,7 +205,7 @@ class LoraLoaderInvocation(BaseInvocation):
     clip: Optional[ClipField] = Field(description="Clip model for applying lora")

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {
                 "title": "Lora Loader",
                 "tags": ["lora", "loader"],

@@ -287,7 +287,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
     clip2: Optional[ClipField] = Field(description="Clip2 model for applying lora")

     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {
                 "title": "SDXL Lora Loader",
                 "tags": ["lora", "loader"],

@@ -385,7 +385,7 @@ class VaeLoaderInvocation(BaseInvocation):

     # Schema customisation
     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {
                 "title": "VAE Loader",
                 "tags": ["vae", "loader"],
@@ -3,7 +3,7 @@
 import math
 from typing import Literal

-from pydantic import Field, field_validator
+from pydantic import Field, validator
 import torch
 from invokeai.app.invocations.latent import LatentsField

@@ -110,14 +110,14 @@ class NoiseInvocation(BaseInvocation):

     # Schema customisation
     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {
                 "title": "Noise",
                 "tags": ["latents", "noise"],
             },
         }

-    @field_validator("seed", pre=True)
+    @validator("seed", pre=True)
     def modulo_seed(cls, v):
         """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
         return v % (SEED_MAX + 1)
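A note on `pre=True` above: in pydantic v1 a "pre" validator runs before field parsing, so an out-of-range seed is folded into range rather than rejected. Sketch, with `SEED_MAX` assumed to be 2**32 - 1 for illustration (InvokeAI defines its own constant in `invokeai.app.util.misc`):

```python
from pydantic import BaseModel, validator

SEED_MAX = 2**32 - 1  # assumed value for this sketch

class Noise(BaseModel):
    seed: int = 0

    @validator("seed", pre=True)
    def modulo_seed(cls, v):
        """Return the seed modulo (SEED_MAX + 1) so it is always in range."""
        return v % (SEED_MAX + 1)

print(Noise(seed=2**40).seed)  # wraps to 0
```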
@@ -6,7 +6,7 @@ from typing import List, Literal, Optional, Union
 import re
 import inspect

-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field, validator
 import torch
 import numpy as np
 from diffusers import ControlNetModel, DPMSolverMultistepScheduler

@@ -59,10 +59,10 @@ class ONNXPromptInvocation(BaseInvocation):

     def invoke(self, context: InvocationContext) -> CompelOutput:
         tokenizer_info = context.services.model_manager.get_model(
-            **self.clip.tokenizer.model_dump(),
+            **self.clip.tokenizer.dict(),
         )
         text_encoder_info = context.services.model_manager.get_model(
-            **self.clip.text_encoder.model_dump(),
+            **self.clip.text_encoder.dict(),
         )
         with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
             loras = [

@@ -154,7 +154,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
     # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
     # fmt: on

-    @field_validator("cfg_scale")
+    @validator("cfg_scale")
     def ge_one(cls, v):
         """validate that all cfg_scale values are >= 1"""
         if isinstance(v, list):

@@ -168,7 +168,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):

     # Schema customisation
     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {
                 "tags": ["latents"],
                 "type_hints": {

@@ -226,7 +226,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
         stable_diffusion_step_callback(
             context=context,
             intermediate_state=intermediate_state,
-            node=self.model_dump(),
+            node=self.dict(),
             source_node_id=source_node_id,
         )

@@ -239,7 +239,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
             eta=0.0,
         )

-        unet_info = context.services.model_manager.get_model(**self.unet.unet.model_dump())
+        unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())

         with unet_info as unet, ExitStack() as stack:
             # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]

@@ -314,7 +314,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation):

     # Schema customisation
     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {
                 "tags": ["latents", "image"],
             },

@@ -327,7 +327,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
             raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")

         vae_info = context.services.model_manager.get_model(
-            **self.vae.vae.model_dump(),
+            **self.vae.vae.dict(),
         )

         # clear memory as vae decode can request a lot

@@ -356,7 +356,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
-            metadata=self.metadata.model_dump() if self.metadata else None,
+            metadata=self.metadata.dict() if self.metadata else None,
         )

         return ImageOutput(

@@ -389,7 +389,7 @@ class ONNXSD1ModelLoaderInvocation(BaseInvocation):

     # Schema customisation
     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {"tags": ["model", "loader"], "type_hints": {"model_name": "model"}},  # TODO: rename to model_name?
         }

@@ -472,7 +472,7 @@ class OnnxModelLoaderInvocation(BaseInvocation):

     # Schema customisation
     class Config(InvocationConfig):
-        json_schema_extra = {
+        schema_extra = {
             "ui": {
                 "title": "Onnx Model Loader",
                 "tags": ["model", "loader"],
@ -65,7 +65,7 @@ class FloatLinearRangeInvocation(BaseInvocation):
    steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)")

    class Config(InvocationConfig):
        json_schema_extra = {
        schema_extra = {
            "ui": {"title": "Linear Range (Float)", "tags": ["math", "float", "linear", "range"]},
        }
@ -136,7 +136,7 @@ class StepParamEasingInvocation(BaseInvocation):
    # fmt: on

    class Config(InvocationConfig):
        json_schema_extra = {
        schema_extra = {
            "ui": {"title": "Param Easing By Step", "tags": ["param", "step", "easing"]},
        }

@ -21,7 +21,7 @@ class ParamIntInvocation(BaseInvocation):
    # fmt: on

    class Config(InvocationConfig):
        json_schema_extra = {
        schema_extra = {
            "ui": {"tags": ["param", "integer"], "title": "Integer Parameter"},
        }
@ -38,7 +38,7 @@ class ParamFloatInvocation(BaseInvocation):
    # fmt: on

    class Config(InvocationConfig):
        json_schema_extra = {
        schema_extra = {
            "ui": {"tags": ["param", "float"], "title": "Float Parameter"},
        }
@ -60,7 +60,7 @@ class ParamStringInvocation(BaseInvocation):
    text: str = Field(default="", description="The string value")

    class Config(InvocationConfig):
        json_schema_extra = {
        schema_extra = {
            "ui": {"tags": ["param", "string"], "title": "String Parameter"},
        }
@ -75,7 +75,7 @@ class ParamPromptInvocation(BaseInvocation):
    prompt: str = Field(default="", description="The prompt value")

    class Config(InvocationConfig):
        json_schema_extra = {
        schema_extra = {
            "ui": {"tags": ["param", "prompt"], "title": "Prompt"},
        }
@ -2,7 +2,7 @@ from os.path import exists
from typing import Literal, Optional

import numpy as np
from pydantic import Field, field_validator
from pydantic import Field, validator

from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
@ -18,7 +18,7 @@ class PromptOutput(BaseInvocationOutput):
    # fmt: on

    class Config:
        json_schema_extra = {
        schema_extra = {
            "required": [
                "type",
                "prompt",
@ -37,7 +37,7 @@ class PromptCollectionOutput(BaseInvocationOutput):
    # fmt: on

    class Config:
        json_schema_extra = {"required": ["type", "prompt_collection", "count"]}
        schema_extra = {"required": ["type", "prompt_collection", "count"]}


class DynamicPromptInvocation(BaseInvocation):
@ -49,7 +49,7 @@ class DynamicPromptInvocation(BaseInvocation):
    combinatorial: bool = Field(default=False, description="Whether to use the combinatorial generator")

    class Config(InvocationConfig):
        json_schema_extra = {
        schema_extra = {
            "ui": {"title": "Dynamic Prompt", "tags": ["prompt", "dynamic"]},
        }
@ -79,11 +79,11 @@ class PromptsFromFileInvocation(BaseInvocation):
    # fmt: on

    class Config(InvocationConfig):
        json_schema_extra = {
        schema_extra = {
            "ui": {"title": "Prompts From File", "tags": ["prompt", "file"]},
        }

    @field_validator("file_path")
    @validator("file_path")
    def file_path_exists(cls, v):
        if not exists(v):
            raise ValueError(FileNotFoundError)
@ -3,7 +3,7 @@ import inspect
from tqdm import tqdm
from typing import List, Literal, Optional, Union

from pydantic import Field, field_validator
from pydantic import Field, validator

from ...backend.model_management import ModelType, SubModelType, ModelPatcher
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
@ -49,7 +49,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):

    # Schema customisation
    class Config(InvocationConfig):
        json_schema_extra = {
        schema_extra = {
            "ui": {
                "title": "SDXL Model Loader",
                "tags": ["model", "loader", "sdxl"],
@ -139,7 +139,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):

    # Schema customisation
    class Config(InvocationConfig):
        json_schema_extra = {
        schema_extra = {
            "ui": {
                "title": "SDXL Refiner Model Loader",
                "tags": ["model", "loader", "sdxl_refiner"],
@ -224,7 +224,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
    # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
    # fmt: on

    @field_validator("cfg_scale")
    @validator("cfg_scale")
    def ge_one(cls, v):
        """validate that all cfg_scale values are >= 1"""
        if isinstance(v, list):
@ -238,7 +238,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):

    # Schema customisation
    class Config(InvocationConfig):
        json_schema_extra = {
        schema_extra = {
            "ui": {
                "title": "SDXL Text To Latents",
                "tags": ["latents"],
@ -260,7 +260,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
    ) -> None:
        stable_diffusion_xl_step_callback(
            context=context,
            node=self.model_dump(),
            node=self.dict(),
            source_node_id=source_node_id,
            sample=sample,
            step=step,
@ -303,7 +303,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
                del lora_info
            return

        unet_info = context.services.model_manager.get_model(**self.unet.unet.model_dump(), context=context)
        unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
        do_classifier_free_guidance = True
        cross_attention_kwargs = None
        with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
@ -481,7 +481,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
    # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
    # fmt: on

    @field_validator("cfg_scale")
    @validator("cfg_scale")
    def ge_one(cls, v):
        """validate that all cfg_scale values are >= 1"""
        if isinstance(v, list):
@ -495,7 +495,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):

    # Schema customisation
    class Config(InvocationConfig):
        json_schema_extra = {
        schema_extra = {
            "ui": {
                "title": "SDXL Latents to Latents",
                "tags": ["latents"],
@ -517,7 +517,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
    ) -> None:
        stable_diffusion_xl_step_callback(
            context=context,
            node=self.model_dump(),
            node=self.dict(),
            source_node_id=source_node_id,
            sample=sample,
            step=step,
@ -549,7 +549,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
        )

        unet_info = context.services.model_manager.get_model(
            **self.unet.unet.model_dump(),
            **self.unet.unet.dict(),
            context=context,
        )
@ -32,7 +32,7 @@ class ESRGANInvocation(BaseInvocation):
    model_name: ESRGAN_MODELS = Field(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")

    class Config(InvocationConfig):
        json_schema_extra = {
        schema_extra = {
            "ui": {"title": "Upscale (RealESRGAN)", "tags": ["image", "upscale", "realesrgan"]},
        }

@ -15,7 +15,7 @@ class ImageField(BaseModel):
    image_name: Optional[str] = Field(default=None, description="The name of the image")

    class Config:
        json_schema_extra = {"required": ["image_name"]}
        schema_extra = {"required": ["image_name"]}


class ColorField(BaseModel):
@ -40,7 +40,7 @@ class PILInvocationConfig(BaseModel):
    """Helper class to provide all PIL invocations with additional config"""

    class Config(InvocationConfig):
        json_schema_extra = {
        schema_extra = {
            "ui": {
                "tags": ["PIL", "image"],
            },
@ -58,7 +58,7 @@ class ImageOutput(BaseInvocationOutput):
    # fmt: on

    class Config:
        json_schema_extra = {"required": ["type", "image", "width", "height"]}
        schema_extra = {"required": ["type", "image", "width", "height"]}


class MaskOutput(BaseInvocationOutput):
@ -72,7 +72,7 @@ class MaskOutput(BaseInvocationOutput):
    # fmt: on

    class Config:
        json_schema_extra = {
        schema_extra = {
            "required": [
                "type",
                "mask",
@ -166,8 +166,7 @@ import sys
from argparse import ArgumentParser
from omegaconf import OmegaConf, DictConfig, ListConfig
from pathlib import Path
from pydantic import Field, parse_obj_as
from pydantic_settings import BaseSettings
from pydantic import BaseSettings, Field, parse_obj_as
from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args

INIT_FILE = Path("invokeai.yaml")
@ -188,7 +187,7 @@ class InvokeAISettings(BaseSettings):
    def parse_args(self, argv: list = sys.argv[1:]):
        parser = self.get_parser()
        opt = parser.parse_args(argv)
        for name in self.model_fields:
        for name in self.__fields__:
            if name not in self._excluded():
                value = getattr(opt, name)
                if isinstance(value, ListConfig):
@ -205,7 +204,7 @@ class InvokeAISettings(BaseSettings):
        cls = self.__class__
        type = get_args(get_type_hints(cls)["type"])[0]
        field_dict = dict({type: dict()})
        for name, field in self.model_fields.items():
        for name, field in self.__fields__.items():
            if name in cls._excluded_from_yaml():
                continue
            category = field.field_info.extra.get("category") or "Uncategorized"
@ -239,7 +238,7 @@ class InvokeAISettings(BaseSettings):
        for key, value in os.environ.items():
            upcase_environ[key.upper()] = value

        fields = cls.model_fields
        fields = cls.__fields__
        cls.argparse_groups = {}

        for name, field in fields.items():
@ -44,7 +44,7 @@ class EventServiceBase:
                graph_execution_state_id=graph_execution_state_id,
                node=node,
                source_node_id=source_node_id,
                progress_image=progress_image.model_dump() if progress_image is not None else None,
                progress_image=progress_image.dict() if progress_image is not None else None,
                step=step,
                total_steps=total_steps,
            ),
@ -15,7 +15,7 @@ from typing import (
)

import networkx as nx
from pydantic import BaseModel, model_validator, field_validator
from pydantic import BaseModel, root_validator, validator
from pydantic.fields import Field

from ..invocations import *
@ -156,7 +156,7 @@ class GraphInvocationOutput(BaseInvocationOutput):
    type: Literal["graph_output"] = "graph_output"

    class Config:
        json_schema_extra = {
        schema_extra = {
            "required": [
                "type",
                "image",
@ -186,7 +186,7 @@ class IterateInvocationOutput(BaseInvocationOutput):
    item: Any = Field(description="The item being iterated over")

    class Config:
        json_schema_extra = {
        schema_extra = {
            "required": [
                "type",
                "item",
@ -214,7 +214,7 @@ class CollectInvocationOutput(BaseInvocationOutput):
    collection: list[Any] = Field(description="The collection of input items")

    class Config:
        json_schema_extra = {
        schema_extra = {
            "required": [
                "type",
                "collection",
@ -755,7 +755,7 @@ class GraphExecutionState(BaseModel):
    )

    class Config:
        json_schema_extra = {
        schema_extra = {
            "required": [
                "id",
                "graph",
@ -1110,13 +1110,13 @@ class LibraryGraph(BaseModel):
        description="The outputs exposed by this graph", default_factory=list
    )

    @field_validator("exposed_inputs", "exposed_outputs")
    @validator("exposed_inputs", "exposed_outputs")
    def validate_exposed_aliases(cls, v):
        if len(v) != len(set(i.alias for i in v)):
            raise ValueError("Duplicate exposed alias")
        return v

    @model_validator
    @root_validator
    def validate_exposed_nodes(cls, values):
        graph = values["graph"]
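Aside: the `validator`/`root_validator` pair above is the pydantic 1.x way to split a per-field check from a whole-model check that needs several fields at once. A hedged sketch under pydantic 1.x, with hypothetical fields; in 2.x the whole-model check becomes `@model_validator(mode="after")`:

```python
from typing import Dict, List
from pydantic import BaseModel, Field, root_validator, validator

class Library(BaseModel):
    aliases: List[str] = Field(default_factory=list)
    graph: Dict[str, int] = Field(default_factory=dict)

    @validator("aliases")
    def no_duplicates(cls, v):
        # per-field validator: sees only this field's value
        if len(v) != len(set(v)):
            raise ValueError("Duplicate exposed alias")
        return v

    @root_validator
    def aliases_exist_in_graph(cls, values):
        # whole-model validator: sees every field at once
        missing = [a for a in values["aliases"] if a not in values["graph"]]
        if missing:
            raise ValueError(f"Aliases not present in graph: {missing}")
        return values
```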
@ -2,11 +2,12 @@ from abc import ABC, abstractmethod
from typing import Callable, Generic, Optional, TypeVar

from pydantic import BaseModel, Field
from pydantic.generics import GenericModel

T = TypeVar("T", bound=BaseModel)


class PaginatedResults(BaseModel, Generic[T]):
class PaginatedResults(GenericModel, Generic[T]):
    """Paginated results"""

    # fmt: off
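Aside: `GenericModel` is the pydantic 1.x base for parameterized models; 2.x folds this into `BaseModel` plus `typing.Generic`, which is what the removed line used. A small sketch, assuming pydantic 1.x, with a toy item type:

```python
from typing import Generic, List, TypeVar
from pydantic import BaseModel
from pydantic.generics import GenericModel

T = TypeVar("T", bound=BaseModel)

class Item(BaseModel):
    name: str

class Page(GenericModel, Generic[T]):  # 2.x: class Page(BaseModel, Generic[T])
    items: List[T]
    page: int
    total: int

# parameterizing the model coerces the dicts into Item instances
page = Page[Item](items=[{"name": "a"}], page=1, total=1)
print(page.items[0].name)  # -> "a"
```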
@ -238,7 +238,7 @@ class ModelManagerServiceBase(ABC):
    def merge_models(
        self,
        model_names: List[str] = Field(
            default=None, min_length=2, max_length=3, description="List of model names to merge"
            default=None, min_items=2, max_items=3, description="List of model names to merge"
        ),
        base_model: Union[BaseModelType, str] = Field(
            default=None, description="Base model shared by all models to be merged"
@ -568,7 +568,7 @@ class ModelManagerService(ModelManagerServiceBase):
    def merge_models(
        self,
        model_names: List[str] = Field(
            default=None, min_length=2, max_length=3, description="List of model names to merge"
            default=None, min_items=2, max_items=3, description="List of model names to merge"
        ),
        base_model: Union[BaseModelType, str] = Field(
            default=None, description="Base model shared by all models to be merged"
@ -89,7 +89,7 @@ def image_record_to_dto(
) -> ImageDTO:
    """Converts an image record to an image DTO."""
    return ImageDTO(
        **image_record.model_dump(),
        **image_record.dict(),
        image_url=image_url,
        thumbnail_url=thumbnail_url,
        board_id=board_id,
@ -80,7 +80,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                # Send starting event
                self.__invoker.services.events.emit_invocation_started(
                    graph_execution_state_id=graph_execution_state.id,
                    node=invocation.model_dump(),
                    node=invocation.dict(),
                    source_node_id=source_node_id,
                )
@ -107,9 +107,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                    # Send complete event
                    self.__invoker.services.events.emit_invocation_complete(
                        graph_execution_state_id=graph_execution_state.id,
                        node=invocation.model_dump(),
                        node=invocation.dict(),
                        source_node_id=source_node_id,
                        result=outputs.model_dump(),
                        result=outputs.dict(),
                    )
                    statistics.log_stats()
@ -134,7 +134,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                    # Send error event
                    self.__invoker.services.events.emit_invocation_error(
                        graph_execution_state_id=graph_execution_state.id,
                        node=invocation.model_dump(),
                        node=invocation.dict(),
                        source_node_id=source_node_id,
                        error_type=e.__class__.__name__,
                        error=error,
@ -155,7 +155,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
                self.__invoker.services.events.emit_invocation_error(
                    graph_execution_state_id=graph_execution_state.id,
                    node=invocation.model_dump(),
                    node=invocation.dict(),
                    source_node_id=source_node_id,
                    error_type=e.__class__.__name__,
                    error=traceback.format_exc(),
@ -1,11 +1,25 @@
"""
invokeai.backend.generator.img2img descends from .generator
"""
from typing import Optional

import torch
from accelerate.utils import set_seed
from diffusers import logging

from ..stable_diffusion import (
    ConditioningData,
    PostprocessingSettings,
    StableDiffusionGeneratorPipeline,
)
from .base import Generator


class Img2Img(Generator):
    def __init__(self, model, precision):
        super().__init__(model, precision)
        self.init_latent = None  # by get_noise()

    def get_make_image(
        self,
        sampler,
@ -28,4 +42,51 @@ class Img2Img(Generator):
        Returns a function returning an image derived from the prompt and the initial image
        Return value depends on the seed at the time you call it.
        """
        raise NotImplementedError("replaced by invokeai.app.invocations.latent.LatentsToLatentsInvocation")
        self.perlin = perlin

        # noinspection PyTypeChecker
        pipeline: StableDiffusionGeneratorPipeline = self.model
        pipeline.scheduler = sampler

        uc, c, extra_conditioning_info = conditioning
        conditioning_data = ConditioningData(
            uc,
            c,
            cfg_scale,
            extra_conditioning_info,
            postprocessing_settings=PostprocessingSettings(
                threshold=threshold,
                warmup=warmup,
                h_symmetry_time_pct=h_symmetry_time_pct,
                v_symmetry_time_pct=v_symmetry_time_pct,
            ),
        ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)

        def make_image(x_T: torch.Tensor, seed: int):
            # FIXME: use x_T for initial seeded noise
            # We're not at the moment because the pipeline automatically resizes init_image if
            # necessary, which the x_T input might not match.
            # In the meantime, reset the seed prior to generating pipeline output so we at least get the same result.
            logging.set_verbosity_error()  # quench safety check warnings
            pipeline_output = pipeline.img2img_from_embeddings(
                init_image,
                strength,
                steps,
                conditioning_data,
                noise_func=self.get_noise_like,
                callback=step_callback,
                seed=seed,
            )
            if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
                attention_maps_callback(pipeline_output.attention_map_saver)
            return pipeline.numpy_to_pil(pipeline_output.images)[0]

        return make_image

    def get_noise_like(self, like: torch.Tensor):
        device = like.device
        x = torch.randn_like(like, device=device)
        if self.perlin > 0.0:
            shape = like.shape
            x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(shape[3], shape[2])
        return x
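Aside: `get_noise_like` above mixes gaussian and perlin noise linearly by the factor `self.perlin`. A toy sketch of just that arithmetic; the `torch.rand_like` stand-in below is not the real `get_perlin_noise`, only a placeholder tensor of the same shape:

```python
import torch

def blend_noise(like: torch.Tensor, p: float) -> torch.Tensor:
    gaussian = torch.randn_like(like)
    perlin = torch.rand_like(like) * 2 - 1  # stand-in for get_perlin_noise()
    # convex combination: p=0 is pure gaussian, p=1 is pure perlin
    return (1 - p) * gaussian + p * perlin

x = blend_noise(torch.zeros(1, 4, 64, 64), p=0.25)
print(x.shape)  # torch.Size([1, 4, 64, 64])
```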
@ -377,11 +377,3 @@ class Inpaint(Img2Img):
        )

        return corrected_result

    def get_noise_like(self, like: torch.Tensor):
        device = like.device
        x = torch.randn_like(like, device=device)
        if self.perlin > 0.0:
            shape = like.shape
            x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(shape[3], shape[2])
        return x
@ -21,6 +21,7 @@ from argparse import Namespace
from enum import Enum
from pathlib import Path
from shutil import get_terminal_size
from typing import get_type_hints
from urllib import request

import npyscreen
@ -395,23 +396,13 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
            max_width=80,
            scroll_exit=True,
        )
        self.nextrely += 1
        self.add_widget_intelligent(
            npyscreen.TitleFixedText,
            name="RAM cache size (GB). Make this at least large enough to hold a single full model.",
            begin_entry_at=0,
            editable=False,
            color="CONTROL",
            scroll_exit=True,
        )
        self.nextrely -= 1
        self.max_cache_size = self.add_widget_intelligent(
            npyscreen.Slider,
            value=clip(old_opts.max_cache_size, range=(3.0, MAX_RAM), step=0.5),
            out_of=round(MAX_RAM),
            lowest=0.0,
            step=0.5,
            relx=8,
            IntTitleSlider,
            name="RAM cache size (GB). Make this at least large enough to hold a single full model.",
            value=old_opts.max_cache_size,
            out_of=MAX_RAM,
            lowest=3,
            begin_entry_at=6,
            scroll_exit=True,
        )
        if HAS_CUDA:
@ -427,7 +418,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
        self.nextrely -= 1
        self.max_vram_cache_size = self.add_widget_intelligent(
            npyscreen.Slider,
            value=clip(old_opts.max_vram_cache_size, range=(0, MAX_VRAM), step=0.25),
            value=old_opts.max_vram_cache_size,
            out_of=round(MAX_VRAM * 2) / 2,
            lowest=0.0,
            relx=8,
@ -605,16 +596,6 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections:
    )


# -------------------------------------
def clip(value: float, range: tuple[float, float], step: float) -> float:
    minimum, maximum = range
    if value < minimum:
        value = minimum
    if value > maximum:
        value = maximum
    return round(value / step) * step


# -------------------------------------
def initialize_rootdir(root: Path, yes_to_all: bool = False):
    logger.info("Initializing InvokeAI runtime directory")
@ -716,7 +697,7 @@ def migrate_init_file(legacy_format: Path):
    old = legacy_parser.parse_args([f"@{str(legacy_format)}"])
    new = InvokeAIAppConfig.get_config()

    fields = [x for x, y in InvokeAIAppConfig.model_fields.items() if y.field_info.extra.get("category") != "DEPRECATED"]
    fields = [x for x, y in InvokeAIAppConfig.__fields__.items() if y.field_info.extra.get("category") != "DEPRECATED"]
    for attr in fields:
        if hasattr(old, attr):
            try:
@ -591,6 +591,7 @@ script, which will perform a full upgrade in place.""",
    # TODO: revisit - don't rely on invokeai.yaml to exist yet!
    dest_is_setup = (dest_root / "models/core").exists() and (dest_root / "databases").exists()
    if not dest_is_setup:
        import invokeai.frontend.install.invokeai_configure
        from invokeai.backend.install.invokeai_configure import initialize_rootdir

        initialize_rootdir(dest_root, True)
@ -143,7 +143,7 @@ class ModelPatcher:
                    # with torch.autocast(device_type="cpu"):
                    layer.to(dtype=torch.float32)
                    layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
                    layer_weight = layer.get_weight(original_weights[module_key]) * lora_weight * layer_scale
                    layer_weight = layer.get_weight() * lora_weight * layer_scale

                    if module.weight.shape != layer_weight.shape:
                        # TODO: debug on lycoris
@ -361,8 +361,7 @@ class ONNXModelPatcher:

                    layer.to(dtype=torch.float32)
                    layer_key = layer_key.replace(prefix, "")
                    # TODO: rewrite to pass original tensor weight(required by ia3)
                    layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight
                    layer_weight = layer.get_weight().detach().cpu().numpy() * lora_weight
                    if layer_key is blended_loras:
                        blended_loras[layer_key] += layer_weight
                    else:
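Aside: both patchers apply the same arithmetic: the layer's `get_weight()` delta is scaled by `alpha / rank` and by the user-facing lora weight before being folded into the module weight. A sketch of that scaling for a plain linear layer, with illustrative shapes rather than the repo's actual tensors:

```python
import torch

def patch_linear(weight: torch.Tensor, up: torch.Tensor, down: torch.Tensor,
                 alpha: float, lora_weight: float) -> torch.Tensor:
    rank = down.shape[0]
    layer_scale = alpha / rank if (alpha and rank) else 1.0
    # up @ down is the low-rank delta a LoRALayer's get_weight() would return
    layer_weight = (up @ down) * lora_weight * layer_scale
    return weight + layer_weight.reshape(weight.shape)

w = torch.zeros(320, 768)
up, down = torch.randn(320, 4), torch.randn(4, 768)
patched = patch_linear(w, up, down, alpha=4.0, lora_weight=0.75)
print(patched.shape)  # torch.Size([320, 768])
```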
@ -526,7 +526,7 @@ class ModelManager(object):
        # Does the config explicitly override the submodel?
        if submodel_type is not None and hasattr(model_config, submodel_type):
            submodel_path = getattr(model_config, submodel_type)
            if submodel_path is not None and len(submodel_path) > 0:
            if submodel_path is not None:
                model_path = getattr(model_config, submodel_type)
                is_submodel_override = True
@ -897,7 +897,7 @@ class ModelManager(object):
        Write current configuration out to the indicated file.
        """
        data_to_save = dict()
        data_to_save["__metadata__"] = self.config_meta.model_dump()
        data_to_save["__metadata__"] = self.config_meta.dict()

        for model_key, model_config in self.models.items():
            model_name, base_model, model_type = self.parse_key(model_key)
@ -17,7 +17,6 @@ from .models import (
    SilenceWarnings,
    InvalidModelException,
)
from .util import lora_token_vector_length
from .models.base import read_checkpoint_meta

@ -316,16 +315,38 @@ class LoRACheckpointProbe(CheckpointProbeBase):

    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        token_vector_length = lora_token_vector_length(checkpoint)

        if token_vector_length == 768:
            return BaseModelType.StableDiffusion1
        elif token_vector_length == 1024:
        # SD-2 models are very hard to probe. These probes are brittle and likely to fail in the future
        # There are also some "SD-2 LoRAs" that have identical keys and shapes to SD-1 and will be
        # misclassified as SD-1
        key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
        if key in checkpoint and checkpoint[key].shape[0] == 320:
            return BaseModelType.StableDiffusion2
        elif token_vector_length == 2048:

        key = "lora_unet_output_blocks_5_1_transformer_blocks_1_ff_net_2.lora_up.weight"
        if key in checkpoint:
            return BaseModelType.StableDiffusionXL

        key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
        key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
        key3 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"

        lora_token_vector_length = (
            checkpoint[key1].shape[1]
            if key1 in checkpoint
            else checkpoint[key2].shape[1]
            if key2 in checkpoint
            else checkpoint[key3].shape[0]
            if key3 in checkpoint
            else None
        )

        if lora_token_vector_length == 768:
            return BaseModelType.StableDiffusion1
        elif lora_token_vector_length == 1024:
            return BaseModelType.StableDiffusion2
        else:
            raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}")
            raise InvalidModelException(f"Unknown LoRA type")


class TextualInversionCheckpointProbe(CheckpointProbeBase):
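Aside: the probe above distinguishes base models by the token-embedding width visible in text-encoder LoRA keys (768 for SD-1, 1024 for SD-2, 2048 for SDXL). A minimal sketch of that shape check, using a fabricated one-entry checkpoint in place of a loaded state_dict:

```python
import torch

def guess_base(checkpoint: dict) -> str:
    # shape[1] of a text-encoder lora_down weight is the base model's
    # token embedding width
    key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
    width = checkpoint[key].shape[1] if key in checkpoint else None
    return {768: "sd-1", 1024: "sd-2", 2048: "sdxl"}.get(width, "unknown")

ckpt = {"lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight": torch.zeros(4, 768)}
print(guess_base(ckpt))  # -> "sd-1"
```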
@ -1,21 +1,18 @@
import bisect
import os
from enum import Enum
from pathlib import Path
from typing import Dict, Optional, Union

import torch
from enum import Enum
from typing import Optional, Dict, Union, Literal, Any
from pathlib import Path
from safetensors.torch import load_file

from .base import (
    BaseModelType,
    InvalidModelException,
    ModelBase,
    ModelConfigBase,
    ModelNotFoundException,
    BaseModelType,
    ModelType,
    SubModelType,
    classproperty,
    InvalidModelException,
    ModelNotFoundException,
)

@ -125,7 +122,41 @@ class LoRALayerBase:
        self.rank = None  # set in layer implementation
        self.layer_key = layer_key

    def get_weight(self, orig_weight: torch.Tensor):
    def forward(
        self,
        module: torch.nn.Module,
        input_h: Any,  # for real looks like Tuple[torch.nn.Tensor] but not sure
        multiplier: float,
    ):
        if type(module) == torch.nn.Conv2d:
            op = torch.nn.functional.conv2d
            extra_args = dict(
                stride=module.stride,
                padding=module.padding,
                dilation=module.dilation,
                groups=module.groups,
            )

        else:
            op = torch.nn.functional.linear
            extra_args = {}

        weight = self.get_weight()

        bias = self.bias if self.bias is not None else 0
        scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
        return (
            op(
                *input_h,
                (weight + bias).view(module.weight.shape),
                None,
                **extra_args,
            )
            * multiplier
            * scale
        )

    def get_weight(self):
        raise NotImplementedError()

    def calc_size(self) -> int:
@ -166,7 +197,7 @@ class LoRALayer(LoRALayerBase):

        self.rank = self.down.shape[0]

    def get_weight(self, orig_weight: torch.Tensor):
    def get_weight(self):
        if self.mid is not None:
            up = self.up.reshape(self.up.shape[0], self.up.shape[1])
            down = self.down.reshape(self.down.shape[0], self.down.shape[1])
@ -229,7 +260,7 @@ class LoHALayer(LoRALayerBase):

        self.rank = self.w1_b.shape[0]

    def get_weight(self, orig_weight: torch.Tensor):
    def get_weight(self):
        if self.t1 is None:
            weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)

@ -311,7 +342,7 @@ class LoKRLayer(LoRALayerBase):
        else:
            self.rank = None  # unscaled

    def get_weight(self, orig_weight: torch.Tensor):
    def get_weight(self):
        w1 = self.w1
        if w1 is None:
            w1 = self.w1_a @ self.w1_b
@ -379,7 +410,7 @@ class FullLayer(LoRALayerBase):

        self.rank = None  # unscaled

    def get_weight(self, orig_weight: torch.Tensor):
    def get_weight(self):
        return self.weight

    def calc_size(self) -> int:
@ -397,45 +428,6 @@ class FullLayer(LoRALayerBase):
        self.weight = self.weight.to(device=device, dtype=dtype)


class IA3Layer(LoRALayerBase):
    # weight: torch.Tensor
    # on_input: torch.Tensor

    def __init__(
        self,
        layer_key: str,
        values: dict,
    ):
        super().__init__(layer_key, values)

        self.weight = values["weight"]
        self.on_input = values["on_input"]

        self.rank = None  # unscaled

    def get_weight(self, orig_weight: torch.Tensor):
        weight = self.weight
        if not self.on_input:
            weight = weight.reshape(-1, 1)
        return orig_weight * weight

    def calc_size(self) -> int:
        model_size = super().calc_size()
        model_size += self.weight.nelement() * self.weight.element_size()
        model_size += self.on_input.nelement() * self.on_input.element_size()
        return model_size

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().to(device=device, dtype=dtype)

        self.weight = self.weight.to(device=device, dtype=dtype)
        self.on_input = self.on_input.to(device=device, dtype=dtype)


# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
class LoRAModelRaw:  # (torch.nn.Module):
    _name: str
@ -485,61 +477,30 @@ class LoRAModelRaw:  # (torch.nn.Module):
        return model_size

    @classmethod
    def _convert_sdxl_keys_to_diffusers_format(cls, state_dict):
        """Convert the keys of an SDXL LoRA state_dict to diffusers format.

        The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
        diffusers format, then this function will have no effect.

        This function is adapted from:
        https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409

        Args:
            state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.

        Raises:
            ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.

        Returns:
            Dict[str, Tensor]: The diffusers-format state_dict.
        """
        converted_count = 0  # The number of Stability AI keys converted to diffusers format.
        not_converted_count = 0  # The number of keys that were not converted.

        # Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
        # For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
        # `input_blocks_4_1_proj_in`.
        stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
        stability_unet_keys.sort()

    def _convert_sdxl_compvis_keys(cls, state_dict):
        new_state_dict = dict()
        for full_key, value in state_dict.items():
            if full_key.startswith("lora_unet_"):
                search_key = full_key.replace("lora_unet_", "")
                # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
                position = bisect.bisect_right(stability_unet_keys, search_key)
                map_key = stability_unet_keys[position - 1]
                # Now, check if the map_key *actually* matches the search_key.
                if search_key.startswith(map_key):
                    new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
                    new_state_dict[new_key] = value
                    converted_count += 1
                else:
                    new_state_dict[full_key] = value
                    not_converted_count += 1
            elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
                # The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
                new_state_dict[full_key] = value
                continue
            else:
                raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
            if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
                continue  # clip same

        if converted_count > 0 and not_converted_count > 0:
            raise ValueError(
                f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
                f" not_converted={not_converted_count}"
            )
            if not full_key.startswith("lora_unet_"):
                raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}")
            src_key = full_key.replace("lora_unet_", "")
            try:
                dst_key = None
                while "_" in src_key:
                    if src_key in SDXL_UNET_COMPVIS_MAP:
                        dst_key = SDXL_UNET_COMPVIS_MAP[src_key]
                        break
                    src_key = "_".join(src_key.split("_")[:-1])

                if dst_key is None:
                    raise Exception(f"Unknown sdxl lora key - {full_key}")
                new_key = full_key.replace(src_key, dst_key)
            except:
                print(SDXL_UNET_COMPVIS_MAP)
                raise
            new_state_dict[new_key] = value
        return new_state_dict

    @classmethod
@ -571,7 +532,7 @@ class LoRAModelRaw:  # (torch.nn.Module):
        state_dict = cls._group_state(state_dict)

        if base_model == BaseModelType.StableDiffusionXL:
            state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
            state_dict = cls._convert_sdxl_compvis_keys(state_dict)

        for layer_key, values in state_dict.items():
            # lora and locon
@ -586,15 +547,11 @@ class LoRAModelRaw:  # (torch.nn.Module):
            elif "lokr_w1_b" in values or "lokr_w1" in values:
                layer = LoKRLayer(layer_key, values)

            # diff
            elif "diff" in values:
                layer = FullLayer(layer_key, values)

            # ia3
            elif "weight" in values and "on_input" in values:
                layer = IA3Layer(layer_key, values)

            else:
                # TODO: ia3/... format
                print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
                raise Exception("Unknown lora format!")
@ -622,7 +579,6 @@ class LoRAModelRaw:  # (torch.nn.Module):
# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def make_sdxl_unet_conversion_map():
    """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
    unet_conversion_map_layer = []

    for i in range(3):  # num_blocks is 3 in sdxl
@ -706,6 +662,7 @@ def make_sdxl_unet_conversion_map():
    return unet_conversion_map


SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
    sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
SDXL_UNET_COMPVIS_MAP = {
    f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_")
    for sd, hf in make_sdxl_unet_conversion_map()
}
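Aside: `_convert_sdxl_keys_to_diffusers_format` (the side being removed here) leans on `bisect_right` over a sorted key list to find a candidate prefix in O(log n), then confirms it with `startswith`. A sketch of that lookup, with a one-entry toy mapping standing in for `SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP`:

```python
import bisect

mapping = {"input_blocks_4_1": "down_blocks_1_attentions_0"}  # toy stand-in
sorted_keys = sorted(mapping)

def convert(full_key: str) -> str:
    search_key = full_key.replace("lora_unet_", "")
    # bisect_right places search_key just after any key that is its prefix
    position = bisect.bisect_right(sorted_keys, search_key)
    map_key = sorted_keys[position - 1]
    if search_key.startswith(map_key):
        return full_key.replace(map_key, mapping[map_key])
    return full_key  # unrecognized; left unchanged

print(convert("lora_unet_input_blocks_4_1_proj_in.lora_down.weight"))
# -> "lora_unet_down_blocks_1_attentions_0_proj_in.lora_down.weight"
```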
@ -1,5 +1,6 @@
import os
import json
import invokeai.backend.util.logging as logger
from enum import Enum
from pydantic import Field
from typing import Literal, Optional
@ -11,7 +12,6 @@ from .base import (
    DiffusersModel,
    read_checkpoint_meta,
    classproperty,
    InvalidModelException,
)
from omegaconf import OmegaConf

@ -65,7 +65,7 @@ class StableDiffusionXLModel(DiffusersModel):
            in_channels = unet_config["in_channels"]

        else:
            raise InvalidModelException(f"{path} is not a recognized Stable Diffusion diffusers model")
            raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")

        else:
            raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")
@ -1,75 +0,0 @@
# Copyright (c) 2023 The InvokeAI Development Team
"""Utilities used by the Model Manager"""


def lora_token_vector_length(checkpoint: dict) -> int:
    """
    Given a checkpoint in memory, return the lora token vector length

    :param checkpoint: The checkpoint
    """

    def _get_shape_1(key, tensor, checkpoint):
        lora_token_vector_length = None

        if "." not in key:
            return lora_token_vector_length  # wrong key format
        model_key, lora_key = key.split(".", 1)

        # check lora/locon
        if lora_key == "lora_down.weight":
            lora_token_vector_length = tensor.shape[1]

        # check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
        elif lora_key in ["hada_w1_b", "hada_w2_b"]:
            lora_token_vector_length = tensor.shape[1]

        # check lokr (don't worry about lokr_t2 as it used only in 4d shapes)
        elif "lokr_" in lora_key:
            if model_key + ".lokr_w1" in checkpoint:
                _lokr_w1 = checkpoint[model_key + ".lokr_w1"]
            elif model_key + "lokr_w1_b" in checkpoint:
                _lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
            else:
                return lora_token_vector_length  # unknown format

            if model_key + ".lokr_w2" in checkpoint:
                _lokr_w2 = checkpoint[model_key + ".lokr_w2"]
            elif model_key + "lokr_w2_b" in checkpoint:
                _lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
            else:
                return lora_token_vector_length  # unknown format

            lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]

        elif lora_key == "diff":
            lora_token_vector_length = tensor.shape[1]

        # ia3 can be detected only by shape[0] in text encoder
        elif lora_key == "weight" and "lora_unet_" not in model_key:
            lora_token_vector_length = tensor.shape[0]

        return lora_token_vector_length

    lora_token_vector_length = None
    lora_te1_length = None
    lora_te2_length = None
    for key, tensor in checkpoint.items():
        if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
            lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
        elif key.startswith("lora_te") and "_self_attn_" in key:
            tmp_length = _get_shape_1(key, tensor, checkpoint)
            if key.startswith("lora_te_"):
                lora_token_vector_length = tmp_length
            elif key.startswith("lora_te1_"):
                lora_te1_length = tmp_length
            elif key.startswith("lora_te2_"):
                lora_te2_length = tmp_length

        if lora_te1_length is not None and lora_te2_length is not None:
            lora_token_vector_length = lora_te1_length + lora_te2_length

        if lora_token_vector_length is not None:
            break

    return lora_token_vector_length
@ -4,21 +4,25 @@ import dataclasses
import inspect
import math
import secrets
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
from pydantic import Field

import PIL.Image
import einops
import PIL.Image
import numpy as np
from accelerate.utils import set_seed
import psutil
import torch
import torchvision.transforms as T
from accelerate.utils import set_seed
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
    StableDiffusionPipeline,
)

from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
    StableDiffusionImg2ImgPipeline,
)
@ -27,20 +31,21 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.utils import PIL_INTERPOLATION
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.outputs import BaseOutput
from pydantic import Field
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from typing_extensions import ParamSpec

from invokeai.app.services.config import InvokeAIAppConfig
from ..util import CPU_DEVICE, normalize_device
from .diffusion import (
    AttentionMapSaver,
    InvokeAIDiffuserComponent,
    PostprocessingSettings,
)
from ..util import normalize_device
from .offloading import FullyLoadedModelGroup, ModelGroup


@dataclass
@ -284,6 +289,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """
    _model_group: ModelGroup

    ID_LENGTH = 8

    def __init__(
@ -296,7 +303,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        safety_checker: Optional[StableDiffusionSafetyChecker],
        feature_extractor: Optional[CLIPFeatureExtractor],
        requires_safety_checker: bool = False,
        precision: str = "float32",
        control_model: ControlNetModel = None,
        execution_device: Optional[torch.device] = None,
    ):
        super().__init__(
            vae,
@ -321,6 +330,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            # control_model=control_model,
        )
        self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)

        self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device)
        self._model_group.install(*self._submodels)
        self.control_model = control_model

    def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
@ -356,6 +368,72 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        else:
            self.disable_attention_slicing()

    def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
        # overridden method; types match the superclass.
        if torch_device is None:
            return self
        self._model_group.set_device(torch.device(torch_device))
        self._model_group.ready()

    @property
    def device(self) -> torch.device:
        return self._model_group.execution_device

    @property
    def _submodels(self) -> Sequence[torch.nn.Module]:
        module_names, _, _ = self.extract_init_dict(dict(self.config))
        submodels = []
        for name in module_names.keys():
            if hasattr(self, name):
                value = getattr(self, name)
            else:
                value = getattr(self.config, name)
            if isinstance(value, torch.nn.Module):
                submodels.append(value)
        return submodels

    def image_from_embeddings(
        self,
        latents: torch.Tensor,
        num_inference_steps: int,
        conditioning_data: ConditioningData,
        *,
        noise: torch.Tensor,
        callback: Callable[[PipelineIntermediateState], None] = None,
        run_id=None,
    ) -> InvokeAIStableDiffusionPipelineOutput:
        r"""
        Function invoked when calling the pipeline for generation.

        :param conditioning_data:
        :param latents: Pre-generated un-noised latents, to be used as inputs for
            image generation. Can be used to tweak the same generation with different prompts.
        :param num_inference_steps: The number of denoising steps. More denoising steps usually lead to a higher quality
            image at the expense of slower inference.
        :param noise: Noise to add to the latents, sampled from a Gaussian distribution.
        :param callback:
        :param run_id:
        """
        result_latents, result_attention_map_saver = self.latents_from_embeddings(
            latents,
            num_inference_steps,
            conditioning_data,
            noise=noise,
            run_id=run_id,
            callback=callback,
        )
        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        torch.cuda.empty_cache()

        with torch.inference_mode():
            image = self.decode_latents(result_latents)
            output = InvokeAIStableDiffusionPipelineOutput(
                images=image,
                nsfw_content_detected=[],
                attention_map_saver=result_attention_map_saver,
            )
            return self.check_for_safety(output, dtype=conditioning_data.dtype)

    def latents_from_embeddings(
        self,
        latents: torch.Tensor,
@ -372,7 +450,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        if self.scheduler.config.get("cpu_only", False):
            scheduler_device = torch.device("cpu")
        else:
            scheduler_device = self.unet.device
            scheduler_device = self._model_group.device_for(self.unet)

        if timesteps is None:
            self.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
@ -426,7 +504,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            (batch_size,),
            timesteps[0],
            dtype=timesteps.dtype,
            device=self.unet.device,
            device=self._model_group.device_for(self.unet),
        )
        latents = self.scheduler.add_noise(latents, noise, batched_t)
@ -622,6 +700,79 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            **kwargs,
        ).sample

    def img2img_from_embeddings(
        self,
        init_image: Union[torch.FloatTensor, PIL.Image.Image],
        strength: float,
        num_inference_steps: int,
        conditioning_data: ConditioningData,
        *,
        callback: Callable[[PipelineIntermediateState], None] = None,
        run_id=None,
        noise_func=None,
        seed=None,
    ) -> InvokeAIStableDiffusionPipelineOutput:
        if isinstance(init_image, PIL.Image.Image):
            init_image = image_resized_to_grid_as_tensor(init_image.convert("RGB"))

        if init_image.dim() == 3:
            init_image = einops.rearrange(init_image, "c h w -> 1 c h w")

        # 6. Prepare latent variables
        initial_latents = self.non_noised_latents_from_image(
            init_image,
            device=self._model_group.device_for(self.unet),
            dtype=self.unet.dtype,
        )
        if seed is not None:
            set_seed(seed)
        noise = noise_func(initial_latents)

        return self.img2img_from_latents_and_embeddings(
            initial_latents,
            num_inference_steps,
            conditioning_data,
            strength,
            noise,
            run_id,
            callback,
        )

    def img2img_from_latents_and_embeddings(
        self,
        initial_latents,
        num_inference_steps,
        conditioning_data: ConditioningData,
        strength,
        noise: torch.Tensor,
        run_id=None,
        callback=None,
    ) -> InvokeAIStableDiffusionPipelineOutput:
        timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
        result_latents, result_attention_maps = self.latents_from_embeddings(
            latents=initial_latents
            if strength < 1.0
            else torch.zeros_like(initial_latents, device=initial_latents.device, dtype=initial_latents.dtype),
            num_inference_steps=num_inference_steps,
            conditioning_data=conditioning_data,
            timesteps=timesteps,
            noise=noise,
            run_id=run_id,
            callback=callback,
        )

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        torch.cuda.empty_cache()

        with torch.inference_mode():
            image = self.decode_latents(result_latents)
            output = InvokeAIStableDiffusionPipelineOutput(
                images=image,
                nsfw_content_detected=[],
                attention_map_saver=result_attention_maps,
            )
            return self.check_for_safety(output, dtype=conditioning_data.dtype)

    def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device=None) -> (torch.Tensor, int):
        img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
        assert img2img_pipeline.scheduler is self.scheduler
@ -629,7 +780,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        if self.scheduler.config.get("cpu_only", False):
            scheduler_device = torch.device("cpu")
        else:
            scheduler_device = self.unet.device
            scheduler_device = self._model_group.device_for(self.unet)

        img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
        timesteps, adjusted_steps = img2img_pipeline.get_timesteps(
@ -655,7 +806,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        noise_func=None,
        seed=None,
    ) -> InvokeAIStableDiffusionPipelineOutput:
        device = self.unet.device
        device = self._model_group.device_for(self.unet)
        latents_dtype = self.unet.dtype

        if isinstance(init_image, PIL.Image.Image):
@ -726,17 +877,42 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            nsfw_content_detected=[],
            attention_map_saver=result_attention_maps,
        )
        return output
        return self.check_for_safety(output, dtype=conditioning_data.dtype)

    def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype):
        init_image = init_image.to(device=device, dtype=dtype)
        with torch.inference_mode():
            self._model_group.load(self.vae)
            init_latent_dist = self.vae.encode(init_image).latent_dist
            init_latents = init_latent_dist.sample().to(dtype=dtype)  # FIXME: uses torch.randn. make reproducible!

        init_latents = 0.18215 * init_latents
        return init_latents

    def check_for_safety(self, output, dtype):
        with torch.inference_mode():
            screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype)
        screened_attention_map_saver = None
        if has_nsfw_concept is None or not has_nsfw_concept:
            screened_attention_map_saver = output.attention_map_saver
        return InvokeAIStableDiffusionPipelineOutput(
            screened_images,
            has_nsfw_concept,
            # block the attention maps if NSFW content is detected
            attention_map_saver=screened_attention_map_saver,
        )

    def run_safety_checker(self, image, device=None, dtype=None):
        # overriding to use the model group for device info instead of requiring the caller to know.
        if self.safety_checker is not None:
            device = self._model_group.device_for(self.safety_checker)
        return super().run_safety_checker(image, device, dtype)

    def decode_latents(self, latents):
        # Explicit call to get the vae loaded, since `decode` isn't the forward method.
        self._model_group.load(self.vae)
        return super().decode_latents(latents)

    def debug_latents(self, latents, msg):
        from invokeai.backend.image_util import debug_image
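Aside: `get_img2img_timesteps` above ultimately keeps only the tail of the scheduler's timestep sequence, with `strength` controlling how much of it; the higher the strength, the more denoising is applied to the init latents. This sketch mirrors the spirit of diffusers' `get_timesteps`, not its exact code:

```python
from typing import List

def img2img_timesteps(all_timesteps: List[int], strength: float) -> List[int]:
    num_steps = len(all_timesteps)
    # strength in (0, 1] selects how many of the final steps to run
    init_timestep = min(int(num_steps * strength), num_steps)
    t_start = max(num_steps - init_timestep, 0)
    return all_timesteps[t_start:]

steps = list(range(999, -1, -50))  # a 20-step schedule
print(len(img2img_timesteps(steps, strength=0.75)))  # -> 15
```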
253
invokeai/backend/stable_diffusion/offloading.py
Normal file
@@ -0,0 +1,253 @@
from __future__ import annotations

import warnings
import weakref
from abc import ABCMeta, abstractmethod
from collections.abc import MutableMapping
from typing import Callable, Union

import torch
from accelerate.utils import send_to_device
from torch.utils.hooks import RemovableHandle

OFFLOAD_DEVICE = torch.device("cpu")


class _NoModel:
    """Symbol that indicates no model is loaded.

    (We can't weakref.ref(None), so this was my best idea at the time to come up with something
    type-checkable.)
    """

    def __bool__(self):
        return False

    def to(self, device: torch.device):
        pass

    def __repr__(self):
        return "<NO MODEL>"


NO_MODEL = _NoModel()


class ModelGroup(metaclass=ABCMeta):
    """
    A group of models.

    The use case I had in mind when writing this is the sub-models used by a DiffusionPipeline,
    e.g. its text encoder, U-net, VAE, etc.

    Those models are :py:class:`diffusers.ModelMixin`, but "model" is interchangeable with
    :py:class:`torch.nn.Module` here.
    """

    def __init__(self, execution_device: torch.device):
        self.execution_device = execution_device

    @abstractmethod
    def install(self, *models: torch.nn.Module):
        """Add models to this group."""
        pass

    @abstractmethod
    def uninstall(self, *models: torch.nn.Module):
        """Remove models from this group."""
        pass

    @abstractmethod
    def uninstall_all(self):
        """Remove all models from this group."""

    @abstractmethod
    def load(self, model: torch.nn.Module):
        """Load this model to the execution device."""
        pass

    @abstractmethod
    def offload_current(self):
        """Offload the current model(s) from the execution device."""
        pass

    @abstractmethod
    def ready(self):
        """Ready this group for use."""
        pass

    @abstractmethod
    def set_device(self, device: torch.device):
        """Change which device models from this group will execute on."""
        pass

    @abstractmethod
    def device_for(self, model) -> torch.device:
        """Get the device the given model will execute on.

        The model should already be a member of this group.
        """
        pass

    @abstractmethod
    def __contains__(self, model):
        """Check if the model is a member of this group."""
        pass

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} object at {id(self):x}: device={self.execution_device} >"


class LazilyLoadedModelGroup(ModelGroup):
    """
    Only one model from this group is loaded on the GPU at a time.

    Running the forward method of a model will displace the previously-loaded model,
    offloading it to CPU.

    If you call other methods on the model, e.g. ``model.encode(x)`` instead of ``model(x)``,
    you will need to explicitly load it with :py:meth:`.load`.

    This implementation relies on pytorch forward-pre-hooks, and it will copy forward arguments
    to the appropriate execution device, as long as they are positional arguments and not keyword
    arguments. (I didn't make the rules; that's the way the pytorch 1.13 API works for hooks.)
    """

    _hooks: MutableMapping[torch.nn.Module, RemovableHandle]
    _current_model_ref: Callable[[], Union[torch.nn.Module, _NoModel]]

    def __init__(self, execution_device: torch.device):
        super().__init__(execution_device)
        self._hooks = weakref.WeakKeyDictionary()
        self._current_model_ref = weakref.ref(NO_MODEL)

    def install(self, *models: torch.nn.Module):
        for model in models:
            self._hooks[model] = model.register_forward_pre_hook(self._pre_hook)

    def uninstall(self, *models: torch.nn.Module):
        for model in models:
            hook = self._hooks.pop(model)
            hook.remove()
            if self.is_current_model(model):
                # no longer hooked by this object, so don't claim to manage it
                self.clear_current_model()

    def uninstall_all(self):
        self.uninstall(*self._hooks.keys())

    def _pre_hook(self, module: torch.nn.Module, forward_input):
        self.load(module)
        if len(forward_input) == 0:
            warnings.warn(
                f"Hook for {module.__class__.__name__} got no input. Inputs must be positional, not keywords.",
                stacklevel=3,
            )
        return send_to_device(forward_input, self.execution_device)

    def load(self, module):
        if not self.is_current_model(module):
            self.offload_current()
            self._load(module)

    def offload_current(self):
        module = self._current_model_ref()
        if module is not NO_MODEL:
            module.to(OFFLOAD_DEVICE)
        self.clear_current_model()

    def _load(self, module: torch.nn.Module) -> torch.nn.Module:
        assert self.is_empty(), f"A model is already loaded: {self._current_model_ref()}"
        module = module.to(self.execution_device)
        self.set_current_model(module)
        return module

    def is_current_model(self, model: torch.nn.Module) -> bool:
        """Is the given model the one currently loaded on the execution device?"""
        return self._current_model_ref() is model

    def is_empty(self):
        """Are none of this group's models loaded on the execution device?"""
        return self._current_model_ref() is NO_MODEL

    def set_current_model(self, value):
        self._current_model_ref = weakref.ref(value)

    def clear_current_model(self):
        self._current_model_ref = weakref.ref(NO_MODEL)

    def set_device(self, device: torch.device):
        if device == self.execution_device:
            return
        self.execution_device = device
        current = self._current_model_ref()
        if current is not NO_MODEL:
            current.to(device)

    def device_for(self, model):
        if model not in self:
            raise KeyError(f"This group does not manage this model: {type(model).__name__}", model)
        return self.execution_device  # this implementation only dispatches to one device

    def ready(self):
        pass  # always ready to load on-demand

    def __contains__(self, model):
        return model in self._hooks

    def __repr__(self) -> str:
        return (
            f"<{self.__class__.__name__} object at {id(self):x}: "
            f"current_model={type(self._current_model_ref()).__name__} >"
        )


class FullyLoadedModelGroup(ModelGroup):
    """
    A group of models without any implicit loading or unloading.

    :py:meth:`.ready` loads _all_ the models to the execution device at once.
    """

    _models: weakref.WeakSet

    def __init__(self, execution_device: torch.device):
        super().__init__(execution_device)
        self._models = weakref.WeakSet()

    def install(self, *models: torch.nn.Module):
        for model in models:
            self._models.add(model)
            model.to(self.execution_device)

    def uninstall(self, *models: torch.nn.Module):
        for model in models:
            self._models.remove(model)

    def uninstall_all(self):
        self.uninstall(*self._models)

    def load(self, model):
        model.to(self.execution_device)

    def offload_current(self):
        for model in self._models:
            model.to(OFFLOAD_DEVICE)

    def ready(self):
        for model in self._models:
            self.load(model)

    def set_device(self, device: torch.device):
        self.execution_device = device
        for model in self._models:
            if model.device != OFFLOAD_DEVICE:
                model.to(device)

    def device_for(self, model):
        if model not in self:
            raise KeyError(f"This group does not manage this model: {type(model).__name__}", model)
        return self.execution_device  # this implementation only dispatches to one device

    def __contains__(self, model):
        return model in self._models
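A minimal usage sketch of the two group implementations; the tiny `Linear` modules here are hypothetical stand-ins for a pipeline's text encoder and U-Net:

```python
import torch

from invokeai.backend.stable_diffusion.offloading import (
    FullyLoadedModelGroup,
    LazilyLoadedModelGroup,
)

execution_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# hypothetical stand-ins for a pipeline's sub-models
text_encoder = torch.nn.Linear(8, 8)
unet = torch.nn.Linear(8, 8)

# lazy variant: at most one installed model occupies the execution device
group = LazilyLoadedModelGroup(execution_device)
group.install(text_encoder, unet)

x = torch.randn(1, 8)
text_encoder(x)  # forward-pre-hook loads text_encoder and moves x to the device
unet(x)          # displaces text_encoder to CPU, then loads unet
group.uninstall_all()

# eager variant: ready() keeps every installed model on the device at once
eager = FullyLoadedModelGroup(execution_device)
eager.install(text_encoder, unet)
eager.ready()
eager.uninstall_all()
```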
@@ -1,3 +1,6 @@
"""
Initialization file for invokeai.frontend.config
"""
from .invokeai_configure import main as invokeai_configure
from .invokeai_update import main as invokeai_update
from .model_install import main as invokeai_model_install
@@ -1,795 +0,0 @@
# Copyright (c) 2023 - The InvokeAI Team
# Primary Author: David Lovell (github @f412design, discord @techjedi)
# co-author, minor tweaks - Lincoln Stein

# pylint: disable=line-too-long
# pylint: disable=broad-exception-caught
"""Script to import images into the new database system for 3.0.0"""

import os
import datetime
import shutil
import locale
import sqlite3
import json
import glob
import re
import uuid
import yaml
import PIL
import PIL.Image  # used directly below; imported explicitly rather than relying on PIL.ImageOps side effects
import PIL.ImageOps
import PIL.PngImagePlugin

from pathlib import Path
from prompt_toolkit import prompt
from prompt_toolkit.shortcuts import message_dialog
from prompt_toolkit.completion import PathCompleter
from prompt_toolkit.key_binding import KeyBindings

from invokeai.app.services.config import InvokeAIAppConfig

app_config = InvokeAIAppConfig.get_config()

bindings = KeyBindings()


@bindings.add("c-c")
def _(event):
    raise KeyboardInterrupt


# release notes
# "Use All" with size dimensions not selectable in the UI will not load dimensions
class Config:
    """Configuration loader."""

    def __init__(self):
        pass

    TIMESTAMP_STRING = datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")

    INVOKE_DIRNAME = "invokeai"
    YAML_FILENAME = "invokeai.yaml"
    DATABASE_FILENAME = "invokeai.db"

    database_path = None
    database_backup_dir = None
    outputs_path = None
    thumbnail_path = None

    def find_and_load(self):
        """Find the yaml config file and load it."""
        root = app_config.root_path
        if not self.confirm_and_load(os.path.abspath(root)):
            print("\r\nSpecify custom database and outputs paths:")
            self.confirm_and_load_from_user()

        self.database_backup_dir = os.path.join(os.path.dirname(self.database_path), "backup")
        self.thumbnail_path = os.path.join(self.outputs_path, "thumbnails")

    def confirm_and_load(self, invoke_root):
        """Validate that a yaml path exists, confirm the user wants to use it, and load the config."""
        yaml_path = os.path.join(invoke_root, self.YAML_FILENAME)
        if os.path.exists(yaml_path):
            db_dir, outdir = self.load_paths_from_yaml(yaml_path)
            if os.path.isabs(db_dir):
                database_path = os.path.join(db_dir, self.DATABASE_FILENAME)
            else:
                database_path = os.path.join(invoke_root, db_dir, self.DATABASE_FILENAME)

            if os.path.isabs(outdir):
                outputs_path = os.path.join(outdir, "images")
            else:
                outputs_path = os.path.join(invoke_root, outdir, "images")

            db_exists = os.path.exists(database_path)
            outdir_exists = os.path.exists(outputs_path)

            text = f"Found {self.YAML_FILENAME} file at {yaml_path}:"
            text += f"\n  Database : {database_path}"
            text += f"\n  Outputs  : {outputs_path}"
            text += "\n\nUse these paths for import (yes) or choose different ones (no) [Yn]: "

            if db_exists and outdir_exists:
                if (prompt(text).strip() or "Y").upper().startswith("Y"):
                    self.database_path = database_path
                    self.outputs_path = outputs_path
                    return True
                else:
                    return False
            else:
                print("  Invalid: One or more paths in this config did not exist and cannot be used.")

        else:
            message_dialog(
                title="Path not found",
                text=f"Auto-discovery of configuration failed! Could not find ({yaml_path}). Custom paths can be specified.",
            ).run()
        return False

    def confirm_and_load_from_user(self):
        default = ""
        while True:
            database_path = os.path.expanduser(
                prompt(
                    "Database: Specify absolute path to the database to import into: ",
                    completer=PathCompleter(
                        expanduser=True, file_filter=lambda x: Path(x).is_dir() or x.endswith(".db")
                    ),
                    default=default,
                )
            )
            if database_path.endswith(".db") and os.path.isabs(database_path) and os.path.exists(database_path):
                break
            default = database_path + "/" if Path(database_path).is_dir() else database_path

        default = ""
        while True:
            outputs_path = os.path.expanduser(
                prompt(
                    "Outputs: Specify absolute path to outputs/images directory to import into: ",
                    completer=PathCompleter(expanduser=True, only_directories=True),
                    default=default,
                )
            )
            if outputs_path.endswith("images") and os.path.isabs(outputs_path) and os.path.exists(outputs_path):
                break
            default = outputs_path + "/" if Path(outputs_path).is_dir() else outputs_path

        self.database_path = database_path
        self.outputs_path = outputs_path

        return

    def load_paths_from_yaml(self, yaml_path):
        """Load an InvokeAI yaml file and get the database and outputs paths."""
        try:
            with open(yaml_path, "rt", encoding=locale.getpreferredencoding()) as file:
                yamlinfo = yaml.safe_load(file)
                db_dir = yamlinfo.get("InvokeAI", {}).get("Paths", {}).get("db_dir", None)
                outdir = yamlinfo.get("InvokeAI", {}).get("Paths", {}).get("outdir", None)
                return db_dir, outdir
        except Exception:
            print(f"Failed to load paths from yaml file {yaml_path}!")
            return None, None
class ImportStats:
    """DTO for tracking work progress."""

    def __init__(self):
        pass

    time_start = datetime.datetime.utcnow()
    count_source_files = 0
    count_skipped_file_exists = 0
    count_skipped_db_exists = 0
    count_imported = 0
    count_imported_by_version = {}
    count_file_errors = 0

    @staticmethod
    def get_elapsed_time_string():
        """Get a friendly time string for the time elapsed since processing start."""
        time_now = datetime.datetime.utcnow()
        total_seconds = (time_now - ImportStats.time_start).total_seconds()
        hours = int(total_seconds / 3600)
        minutes = int((total_seconds % 3600) / 60)
        seconds = total_seconds % 60
        out_str = f"{hours} hour(s) -" if hours > 0 else ""
        out_str += f"{minutes} minute(s) -" if minutes > 0 else ""
        out_str += f"{seconds:.2f} second(s)"
        return out_str
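A quick sanity check of the elapsed-time arithmetic above (the value is illustrative):

```python
total_seconds = 3725.5
hours = int(total_seconds / 3600)           # 1
minutes = int((total_seconds % 3600) / 60)  # 2
seconds = total_seconds % 60                # 5.5
# get_elapsed_time_string() would render this as "1 hour(s) -2 minute(s) -5.50 second(s)"
```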
class InvokeAIMetadata:
    """DTO for core Invoke AI generation properties parsed from metadata."""

    def __init__(self):
        pass

    def __str__(self):
        formatted_str = f"{self.generation_mode}~{self.steps}~{self.cfg_scale}~{self.model_name}~{self.scheduler}~{self.seed}~{self.width}~{self.height}~{self.rand_device}~{self.strength}~{self.init_image}"
        formatted_str += f"\r\npositive_prompt: {self.positive_prompt}"
        formatted_str += f"\r\nnegative_prompt: {self.negative_prompt}"
        return formatted_str

    generation_mode = None
    steps = None
    cfg_scale = None
    model_name = None
    scheduler = None
    seed = None
    width = None
    height = None
    rand_device = None
    strength = None
    init_image = None
    positive_prompt = None
    negative_prompt = None
    imported_app_version = None

    def to_json(self):
        """Convert the active instance to json format."""
        prop_dict = {}
        prop_dict["generation_mode"] = self.generation_mode
        # don't render prompt nodes if neither is set, to avoid the UI thinking it can set them.
        # if at least one exists, render them both, but use an empty string instead of None for the empty one;
        # this allows the empty field to actually be cleared by the UI instead of keeping the previous value.
        if self.positive_prompt or self.negative_prompt:
            prop_dict["positive_prompt"] = "" if self.positive_prompt is None else self.positive_prompt
            prop_dict["negative_prompt"] = "" if self.negative_prompt is None else self.negative_prompt
        prop_dict["width"] = self.width
        prop_dict["height"] = self.height
        # only render seed if it has a value, to avoid the UI thinking it can set this and then erroring
        if self.seed:
            prop_dict["seed"] = self.seed
        prop_dict["rand_device"] = self.rand_device
        prop_dict["cfg_scale"] = self.cfg_scale
        prop_dict["steps"] = self.steps
        prop_dict["scheduler"] = self.scheduler
        prop_dict["clip_skip"] = 0
        prop_dict["model"] = {}
        prop_dict["model"]["model_name"] = self.model_name
        prop_dict["model"]["base_model"] = None
        prop_dict["controlnets"] = []
        prop_dict["loras"] = []
        prop_dict["vae"] = None
        prop_dict["strength"] = self.strength
        prop_dict["init_image"] = self.init_image
        prop_dict["positive_style_prompt"] = None
        prop_dict["negative_style_prompt"] = None
        prop_dict["refiner_model"] = None
        prop_dict["refiner_cfg_scale"] = None
        prop_dict["refiner_steps"] = None
        prop_dict["refiner_scheduler"] = None
        prop_dict["refiner_aesthetic_store"] = None
        prop_dict["refiner_start"] = None
        prop_dict["imported_app_version"] = self.imported_app_version

        return json.dumps(prop_dict)
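For reference, a sketch of what `to_json` emits for a minimal txt2img import; the property values are illustrative, and the excerpt omits the None-valued keys:

```python
meta = InvokeAIMetadata()
meta.generation_mode = "txt2img"
meta.positive_prompt = "a lighthouse at dusk"
meta.steps = 30
meta.imported_app_version = "2.3.5"

print(meta.to_json())
# excerpt: {"generation_mode": "txt2img", "positive_prompt": "a lighthouse at dusk",
#           "negative_prompt": "", "steps": 30, ..., "imported_app_version": "2.3.5"}
```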
class InvokeAIMetadataParser:
    """Parses strings with json data to find Invoke AI core metadata properties."""

    def __init__(self):
        pass

    def parse_meta_tag_dream(self, dream_string):
        """Take as input a png metadata json node for the 'dream' field variant from prior to 1.15."""
        props = InvokeAIMetadata()

        props.imported_app_version = "pre1.15"
        seed_match = re.search("-S\\s*(\\d+)", dream_string)
        if seed_match is not None:
            try:
                props.seed = int(seed_match[1])
            except ValueError:
                props.seed = None
            raw_prompt = re.sub("(-S\\s*\\d+)", "", dream_string)
        else:
            raw_prompt = dream_string

        pos_prompt, neg_prompt = self.split_prompt(raw_prompt)

        props.positive_prompt = pos_prompt
        props.negative_prompt = neg_prompt

        return props

    def parse_meta_tag_sd_metadata(self, tag_value):
        """Take as input a png metadata json node for the 'sd-metadata' field variant from 1.15 through 2.3.5 post 2."""
        props = InvokeAIMetadata()

        props.imported_app_version = tag_value.get("app_version")
        props.model_name = tag_value.get("model_weights")
        img_node = tag_value.get("image")
        if img_node is not None:
            props.generation_mode = img_node.get("type")
            props.width = img_node.get("width")
            props.height = img_node.get("height")
            props.seed = img_node.get("seed")
            props.rand_device = "cuda"  # hardcoded since all generations pre 3.0 used cuda random noise instead of cpu
            props.cfg_scale = img_node.get("cfg_scale")
            props.steps = img_node.get("steps")
            props.scheduler = self.map_scheduler(img_node.get("sampler"))
            props.strength = img_node.get("strength")
            if props.strength is None:
                props.strength = img_node.get("strength_steps")  # try second name for this property
            props.init_image = img_node.get("init_image_path")
            if props.init_image is None:  # try second name for this property
                props.init_image = img_node.get("init_img")
            # remove the path info from init_image so if we move the init image, it will be correctly relative in the new location
            if props.init_image is not None:
                props.init_image = os.path.basename(props.init_image)
            raw_prompt = img_node.get("prompt")
            if isinstance(raw_prompt, list):
                raw_prompt = raw_prompt[0].get("prompt")

            props.positive_prompt, props.negative_prompt = self.split_prompt(raw_prompt)

        return props

    def parse_meta_tag_invokeai(self, tag_value):
        """Take as input a png metadata json node for the 'invokeai' field variant from 3.0.0 beta 1 through 5."""
        props = InvokeAIMetadata()

        props.imported_app_version = "3.0.0 or later"
        props.generation_mode = tag_value.get("type")
        if props.generation_mode is not None:
            props.generation_mode = props.generation_mode.replace("t2l", "txt2img").replace("l2l", "img2img")

        props.width = tag_value.get("width")
        props.height = tag_value.get("height")
        props.seed = tag_value.get("seed")
        props.cfg_scale = tag_value.get("cfg_scale")
        props.steps = tag_value.get("steps")
        props.scheduler = tag_value.get("scheduler")
        props.strength = tag_value.get("strength")
        props.positive_prompt = tag_value.get("positive_conditioning")
        props.negative_prompt = tag_value.get("negative_conditioning")

        return props

    def map_scheduler(self, old_scheduler):
        """Convert the legacy sampler names to matching 3.0 schedulers."""
        if old_scheduler is None:
            return None

        match (old_scheduler):
            case "ddim":
                return "ddim"
            case "plms":
                return "pndm"
            case "k_lms":
                return "lms"
            case "k_dpm_2":
                return "kdpm_2"
            case "k_dpm_2_a":
                return "kdpm_2_a"
            case "dpmpp_2":
                return "dpmpp_2s"
            case "k_dpmpp_2":
                return "dpmpp_2m"
            case "k_dpmpp_2_a":
                return None  # invalid; in 2.3.x, selecting this sampler would just fall back to the last run, or to plms in a new session
            case "k_euler":
                return "euler"
            case "k_euler_a":
                return "euler_a"
            case "k_heun":
                return "heun"
        return None

    def split_prompt(self, raw_prompt: str):
        """Split the unified prompt string by extracting all negative prompt blocks out into the negative prompt."""
        if raw_prompt is None:
            return "", ""
        raw_prompt_search = raw_prompt.replace("\r", "").replace("\n", "")
        matches = re.findall(r"\[(.+?)\]", raw_prompt_search)
        if len(matches) > 0:
            negative_prompt = ""
            if len(matches) == 1:
                negative_prompt = matches[0].strip().strip(",")
            else:
                for match in matches:
                    negative_prompt += f"({match.strip().strip(',')})"
            positive_prompt = re.sub(r"(\[.+?\])", "", raw_prompt_search).strip()
        else:
            positive_prompt = raw_prompt_search.strip()
            negative_prompt = ""

        return positive_prompt, negative_prompt
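A quick check of the two helpers above; the prompt strings are illustrative:

```python
parser = InvokeAIMetadataParser()

pos, neg = parser.split_prompt("a cat on a fence [blurry, low quality]")
# pos == "a cat on a fence", neg == "blurry, low quality"

pos, neg = parser.split_prompt("portrait, studio light [ugly] [grainy]")
# multiple bracketed blocks are parenthesized and concatenated: neg == "(ugly)(grainy)"

parser.map_scheduler("k_euler_a")    # -> "euler_a"
parser.map_scheduler("k_dpmpp_2_a")  # -> None (no 3.0 equivalent)
```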
class DatabaseMapper:
    """Class to abstract database functionality."""

    def __init__(self, database_path, database_backup_dir):
        self.database_path = database_path
        self.database_backup_dir = database_backup_dir
        self.connection = None
        self.cursor = None

    def connect(self):
        """Open connection to the database."""
        self.connection = sqlite3.connect(self.database_path)
        self.cursor = self.connection.cursor()

    def get_board_names(self):
        """Get a list of the current board names from the database."""
        sql_get_board_name = "SELECT board_name FROM boards"
        self.cursor.execute(sql_get_board_name)
        rows = self.cursor.fetchall()
        return [row[0] for row in rows]

    def does_image_exist(self, image_name):
        """Check the database to see whether an image name already exists, returning a boolean."""
        # parameterized query, so file names containing quotes can't break the SQL
        sql_get_image_by_name = "SELECT image_name FROM images WHERE image_name=?"
        self.cursor.execute(sql_get_image_by_name, (image_name,))
        rows = self.cursor.fetchall()
        return len(rows) > 0

    def add_new_image_to_database(self, filename, width, height, metadata, modified_date_string):
        """Add an image to the database."""
        sql_add_image = (
            "INSERT INTO images (image_name, image_origin, image_category, width, height,"
            " session_id, node_id, metadata, is_intermediate, created_at, updated_at)"
            " VALUES (?, 'internal', 'general', ?, ?, null, null, ?, 0, ?, ?)"
        )
        self.cursor.execute(
            sql_add_image, (filename, width, height, metadata, modified_date_string, modified_date_string)
        )
        self.connection.commit()

    def get_board_id_with_create(self, board_name):
        """Get the board id for the supplied name, and create the board if one does not exist."""
        sql_find_board = "SELECT board_id FROM boards WHERE board_name=? COLLATE NOCASE"
        self.cursor.execute(sql_find_board, (board_name,))
        rows = self.cursor.fetchall()
        if len(rows) > 0:
            return rows[0][0]
        else:
            board_date_string = datetime.datetime.utcnow().date().isoformat()
            new_board_id = str(uuid.uuid4())
            sql_insert_board = "INSERT INTO boards (board_id, board_name, created_at, updated_at) VALUES (?, ?, ?, ?)"
            self.cursor.execute(sql_insert_board, (new_board_id, board_name, board_date_string, board_date_string))
            self.connection.commit()
            return new_board_id

    def add_image_to_board(self, filename, board_id):
        """Add an image mapping to a board."""
        add_datetime_str = datetime.datetime.utcnow().isoformat()
        sql_add_image_to_board = (
            "INSERT INTO board_images (board_id, image_name, created_at, updated_at) VALUES (?, ?, ?, ?)"
        )
        self.cursor.execute(sql_add_image_to_board, (board_id, filename, add_datetime_str, add_datetime_str))
        self.connection.commit()

    def disconnect(self):
        """Disconnect from the db, cleaning up connections and cursors."""
        if self.cursor is not None:
            self.cursor.close()
        if self.connection is not None:
            self.connection.close()

    def backup(self, timestamp_string):
        """Take a backup of the database."""
        if not os.path.exists(self.database_backup_dir):
            print(f"Database backup directory {self.database_backup_dir} does not exist -> creating...", end="")
            os.makedirs(self.database_backup_dir)
            print("Done!")
        database_backup_path = os.path.join(self.database_backup_dir, f"backup-{timestamp_string}-invokeai.db")
        print(f"Making DB Backup at {database_backup_path}...", end="")
        shutil.copy2(self.database_path, database_backup_path)
        print("Done!")
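A hedged sketch of the mapper's intended call sequence; the paths, file name, and timestamps are illustrative:

```python
db_mapper = DatabaseMapper("/path/to/invokeai.db", "/path/to/backup")
db_mapper.connect()
db_mapper.backup("20230801T120000Z")

if not db_mapper.does_image_exist("example.png"):
    board_id = db_mapper.get_board_id_with_create("IMPORT")
    db_mapper.add_new_image_to_database("example.png", 512, 512, "{}", "2023-08-01 12:00:00")
    db_mapper.add_image_to_board("example.png", board_id)

db_mapper.disconnect()
```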
class MediaImportProcessor:
    """Containing class for script functionality."""

    def __init__(self):
        pass

    board_name_id_map = {}

    def get_import_file_list(self):
        """Ask the user for the import folder and scan for the list of files to return."""
        while True:
            default = ""
            while True:
                import_dir = os.path.expanduser(
                    prompt(
                        "Inputs: Specify absolute path containing InvokeAI .png images to import: ",
                        completer=PathCompleter(expanduser=True, only_directories=True),
                        default=default,
                    )
                )
                if len(import_dir) > 0 and Path(import_dir).is_dir():
                    break
                default = import_dir

            # an answer starting with "N" (also the default) means scan the top level only
            answered_no_recurse = (
                (prompt("Include files from subfolders recursively [yN]? ").strip() or "N").upper().startswith("N")
            )
            if answered_no_recurse:
                is_recurse = False
                matching_file_list = glob.glob(import_dir + "/*.png", recursive=False)
            else:
                is_recurse = True
                matching_file_list = glob.glob(import_dir + "/**/*.png", recursive=True)

            if len(matching_file_list) > 0:
                return import_dir, is_recurse, matching_file_list
            else:
                print(f"The specified path {import_dir} exists, but does not contain .png files!")
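For reference, the difference between the two glob patterns used above; the directory is an illustrative stand-in:

```python
import glob

glob.glob("/imports/*.png", recursive=False)    # top-level files only
glob.glob("/imports/**/*.png", recursive=True)  # '**' walks subfolders, but only with recursive=True
```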
    def get_file_details(self, filepath):
        """Retrieve the embedded metadata fields and dimensions from an image file."""
        with PIL.Image.open(filepath) as img:
            img.load()
            png_width, png_height = img.size
            img_info = img.info
        return img_info, png_width, png_height

    def select_board_option(self, board_names, timestamp_string):
        """Allow the user to choose how a board is selected for imported files."""
        while True:
            print("\r\nOptions for board selection for imported images:")
            print(f"1) Select an existing board name. (found {len(board_names)})")
            print("2) Specify a board name to create/add to.")
            print("3) Create/add to board named 'IMPORT'.")
            print(
                f"4) Create/add to board named 'IMPORT' with the current datetime string appended (e.g. IMPORT_{timestamp_string})."
            )
            print(
                "5) Create/add to board named 'IMPORT' with the original file app_version appended (e.g. IMPORT_2.2.5)."
            )
            input_option = input("Specify desired board option: ")
            match (input_option):
                case "1":
                    if len(board_names) < 1:
                        print("\r\nThere are no existing board names to choose from. Select another option!")
                        continue
                    board_name = self.select_item_from_list(
                        board_names, "board name", True, "Cancel, go back and choose a different board option."
                    )
                    if board_name is not None:
                        return board_name
                case "2":
                    while True:
                        board_name = input("Specify new/existing board name: ")
                        if board_name:
                            return board_name
                case "3":
                    return "IMPORT"
                case "4":
                    return f"IMPORT_{timestamp_string}"
                case "5":
                    return "IMPORT_APPVERSION"
    def select_item_from_list(self, items, entity_name, allow_cancel, cancel_string):
        """Render a list of items in the console, prompt the user for a selection, and ensure a valid entry is made."""
        print(f"Select a {entity_name.lower()} from the following list:")
        index = 1
        for item in items:
            print(f"{index}) {item}")
            index += 1
        if allow_cancel:
            print(f"{index}) {cancel_string}")
        while True:
            try:
                option_number = int(input("Specify number of selection: "))
            except ValueError:
                continue
            if allow_cancel and option_number == index:
                return None
            if option_number >= 1 and option_number <= len(items):
                return items[option_number - 1]
    def import_image(self, filepath: str, board_name_option: str, db_mapper: DatabaseMapper, config: Config):
        """Import a single file by its path."""
        parser = InvokeAIMetadataParser()
        file_name = os.path.basename(filepath)
        file_destination_path = os.path.join(config.outputs_path, file_name)

        print("===============================================================================")
        print(f"Importing {filepath}")

        # check destination to see if the file was previously imported
        if os.path.exists(file_destination_path):
            print("File already exists in the destination, skipping!")
            ImportStats.count_skipped_file_exists += 1
            return

        # check if the file name is already referenced in the database
        if db_mapper.does_image_exist(file_name):
            print("A reference to a file with this name already exists in the database, skipping!")
            ImportStats.count_skipped_db_exists += 1
            return

        # load image info and dimensions
        img_info, png_width, png_height = self.get_file_details(filepath)

        # parse metadata
        destination_needs_meta_update = True
        log_version_note = "(Unknown)"
        if "invokeai_metadata" in img_info:
            # for the latest format, we will just re-emit the same json; no need to parse/modify
            converted_field = None
            latest_json_string = img_info.get("invokeai_metadata")
            log_version_note = "3.0.0+"
            destination_needs_meta_update = False
        else:
            if "sd-metadata" in img_info:
                converted_field = parser.parse_meta_tag_sd_metadata(json.loads(img_info.get("sd-metadata")))
            elif "invokeai" in img_info:
                converted_field = parser.parse_meta_tag_invokeai(json.loads(img_info.get("invokeai")))
            elif "dream" in img_info:
                converted_field = parser.parse_meta_tag_dream(img_info.get("dream"))
            elif "Dream" in img_info:
                converted_field = parser.parse_meta_tag_dream(img_info.get("Dream"))
            else:
                converted_field = InvokeAIMetadata()
                destination_needs_meta_update = False
                print("File does not have metadata from known Invoke AI versions, add only, no update!")

            # use the loaded img dimensions if the metadata didn't have them
            if converted_field.width is None:
                converted_field.width = png_width
            if converted_field.height is None:
                converted_field.height = png_height

            log_version_note = converted_field.imported_app_version if converted_field else "NoVersion"
            log_version_note = log_version_note or "NoVersion"

            latest_json_string = converted_field.to_json()

        print(f"From Invoke AI Version {log_version_note} with dimensions {png_width} x {png_height}.")

        # if metadata needs an update, then update metadata and copy in one shot
        if destination_needs_meta_update:
            print("Updating metadata while copying...", end="")
            self.update_file_metadata_while_copying(
                filepath, file_destination_path, "invokeai_metadata", latest_json_string
            )
            print("Done!")
        else:
            print("No metadata update necessary, copying only...", end="")
            shutil.copy2(filepath, file_destination_path)
            print("Done!")

        # create thumbnail
        print("Creating thumbnail...", end="")
        thumbnail_path = os.path.join(config.thumbnail_path, os.path.splitext(file_name)[0]) + ".webp"
        thumbnail_size = 256, 256
        with PIL.Image.open(filepath) as source_image:
            source_image.thumbnail(thumbnail_size)
            source_image.save(thumbnail_path, "webp")
        print("Done!")

        # finalize the dynamic board name if there is an APPVERSION token in it
        if converted_field is not None:
            board_name = board_name_option.replace("APPVERSION", converted_field.imported_app_version or "NoVersion")
        else:
            board_name = board_name_option.replace("APPVERSION", "Latest")

        # maintain a map of already created/looked-up ids to avoid DB queries
        print("Finding/Creating board...", end="")
        if board_name in self.board_name_id_map:
            board_id = self.board_name_id_map[board_name]
        else:
            board_id = db_mapper.get_board_id_with_create(board_name)
            self.board_name_id_map[board_name] = board_id
        print("Done!")

        # add image to db
        print("Adding image to database......", end="")
        modified_time = datetime.datetime.utcfromtimestamp(os.path.getmtime(filepath))
        db_mapper.add_new_image_to_database(file_name, png_width, png_height, latest_json_string, str(modified_time))
        print("Done!")

        # add image to board
        print("Adding image to board......", end="")
        db_mapper.add_image_to_board(file_name, board_id)
        print("Done!")

        ImportStats.count_imported += 1
        if log_version_note in ImportStats.count_imported_by_version:
            ImportStats.count_imported_by_version[log_version_note] += 1
        else:
            ImportStats.count_imported_by_version[log_version_note] = 1
    def update_file_metadata_while_copying(self, filepath, file_destination_path, tag_name, tag_value):
        """Perform a metadata update with save to a new destination, which accomplishes a copy while updating metadata."""
        with PIL.Image.open(filepath) as target_image:
            existing_img_info = target_image.info
            metadata = PIL.PngImagePlugin.PngInfo()
            # re-add any existing invoke ai tags unless they are the one we are trying to add
            for key in existing_img_info:
                if key != tag_name and key in ("dream", "Dream", "sd-metadata", "invokeai", "invokeai_metadata"):
                    metadata.add_text(key, existing_img_info[key])
            metadata.add_text(tag_name, tag_value)
            target_image.save(file_destination_path, pnginfo=metadata)
    def process(self):
        """Begin main processing."""

        print("===============================================================================")
        print("This script will import images generated by earlier versions of")
        print("InvokeAI into the currently installed root directory:")
        print(f"   {app_config.root_path}")
        print("If this is not what you want to do, type ctrl-C now to cancel.")

        # load config
        print("===============================================================================")
        print("= Configuration & Settings")

        config = Config()
        config.find_and_load()
        db_mapper = DatabaseMapper(config.database_path, config.database_backup_dir)
        db_mapper.connect()

        import_dir, is_recurse, import_file_list = self.get_import_file_list()
        ImportStats.count_source_files = len(import_file_list)

        board_names = db_mapper.get_board_names()
        board_name_option = self.select_board_option(board_names, config.TIMESTAMP_STRING)

        print("\r\n===============================================================================")
        print("= Import Settings Confirmation")

        print()
        print(f"Database File Path               : {config.database_path}")
        print(f"Outputs/Images Directory         : {config.outputs_path}")
        print(f"Import Image Source Directory    : {import_dir}")
        print(f"  Recurse Source SubDirectories  : {'Yes' if is_recurse else 'No'}")
        print(f"Count of .png file(s) found      : {len(import_file_list)}")
        print(f"Board name option specified      : {board_name_option}")
        print(f"Database backup will be taken at : {config.database_backup_dir}")

        print("\r\nNotes about the import process:")
        print("- Source image files will not be modified, only copied to the outputs directory.")
        print("- If the same file name already exists in the destination, the file will be skipped.")
        print("- If the same file name already has a record in the database, the file will be skipped.")
        print("- Invoke AI metadata tags will be updated/written into the imported copy only.")
        print(
            "- On the imported copy, only Invoke AI known tags (latest and legacy) will be retained (dream, sd-metadata, invokeai, invokeai_metadata)."
        )
        print(
            "- A property 'imported_app_version' will be added to metadata that can be viewed in the UI's metadata viewer."
        )
        print(
            "- The new 3.x InvokeAI outputs folder structure is flat, so recursively found source images will all be placed into the single outputs/images folder."
        )

        while True:
            should_continue = prompt("\nDo you wish to continue with the import [Yn] ? ").lower() or "y"
            if should_continue == "n":
                print("\r\nCancelling Import")
                return
            elif should_continue == "y":
                print()
                break

        db_mapper.backup(config.TIMESTAMP_STRING)

        print()
        ImportStats.time_start = datetime.datetime.utcnow()

        for filepath in import_file_list:
            try:
                self.import_image(filepath, board_name_option, db_mapper, config)
            except sqlite3.Error as sql_ex:
                print(f"A database-related exception was found processing {filepath}, will continue to next file.")
                print("Exception detail:")
                print(sql_ex)
                ImportStats.count_file_errors += 1
            except Exception as ex:
                print(f"Exception processing {filepath}, will continue to next file.")
                print("Exception detail:")
                print(ex)
                ImportStats.count_file_errors += 1

        print("\r\n===============================================================================")
        print(f"= Import Complete - Elapsed Time: {ImportStats.get_elapsed_time_string()}")
        print()
        print(f"Source File(s)                          : {ImportStats.count_source_files}")
        print(f"Total Imported                          : {ImportStats.count_imported}")
        print(f"Skipped b/c file already exists on disk : {ImportStats.count_skipped_file_exists}")
        print(f"Skipped b/c file already exists in db   : {ImportStats.count_skipped_db_exists}")
        print(f"Errors during import                    : {ImportStats.count_file_errors}")
        if ImportStats.count_imported > 0:
            print("\r\nBreakdown of imported files by version:")
            for version, count in ImportStats.count_imported_by_version.items():
                print(f"  {version:20} : {count}")


def main():
    try:
        processor = MediaImportProcessor()
        processor.process()
    except KeyboardInterrupt:
        print("\r\n\r\nUser cancelled execution.")


if __name__ == "__main__":
    main()
@@ -1,4 +1,4 @@
"""
Wrapper for invokeai.backend.configure.invokeai_configure
"""
from ...backend.install.invokeai_configure import main as invokeai_configure
from ...backend.install.invokeai_configure import main
@@ -382,8 +382,7 @@ def run_cli(args: Namespace):

def main():
    args = _parse_args()
    if args.root_dir:
        config.parse_args(["--root", str(args.root_dir)])
    config.parse_args(["--root", str(args.root_dir)])

    try:
        if args.front_end:

File diff suppressed because one or more lines are too long
@@ -1,4 +1,4 @@
import{B as m,g7 as Je,A as y,a5 as Ka,g8 as Xa,af as va,aj as d,g9 as b,ga as t,gb as Ya,gc as h,gd as ua,ge as Ja,gf as Qa,aL as Za,gg as et,ad as rt,gh as at}from"./index-815faab3.js";import{s as fa,n as o,t as tt,o as ha,p as ot,q as ma,v as ga,w as ya,x as it,y as Sa,z as pa,A as xr,B as nt,D as lt,E as st,F as xa,G as $a,H as ka,J as dt,K as _a,L as ct,M as bt,N as vt,O as ut,Q as wa,R as ft,S as ht,T as mt,U as gt,V as yt,W as St,e as pt,X as xt}from"./menu-e9f8a36e.js";var za=String.raw,Ca=za`
import{B as m,g7 as Je,A as y,a5 as Ka,g8 as Xa,af as va,aj as d,g9 as b,ga as t,gb as Ya,gc as h,gd as ua,ge as Ja,gf as Qa,aL as Za,gg as et,ad as rt,gh as at}from"./index-dd054634.js";import{s as fa,n as o,t as tt,o as ha,p as ot,q as ma,v as ga,w as ya,x as it,y as Sa,z as pa,A as xr,B as nt,D as lt,E as st,F as xa,G as $a,H as ka,J as dt,K as _a,L as ct,M as bt,N as vt,O as ut,Q as wa,R as ft,S as ht,T as mt,U as gt,V as yt,W as St,e as pt,X as xt}from"./menu-b42141e3.js";var za=String.raw,Ca=za`
:root,
:host {
  --chakra-vh: 100vh;

File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
2 invokeai/frontend/web/dist/index.html vendored
@@ -12,7 +12,7 @@
        margin: 0;
      }
    </style>
    <script type="module" crossorigin src="./assets/index-815faab3.js"></script>
    <script type="module" crossorigin src="./assets/index-dd054634.js"></script>
  </head>

  <body dir="ltr">
@@ -20,7 +20,7 @@ export const addStagingAreaImageSavedListener = () => {
      // we may need to add it to the autoadd board
      const { autoAddBoardId } = getState().gallery;

      if (autoAddBoardId && autoAddBoardId !== 'none') {
      if (autoAddBoardId) {
        await dispatch(
          imagesApi.endpoints.addImageToBoard.initiate({
            imageDTO: newImageDTO,
@@ -1,58 +1,55 @@
import { modelChanged } from 'features/parameters/store/generationSlice';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { forEach } from 'lodash-es';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { mainModelsAdapter, modelsApi } from 'services/api/endpoints/models';
import {
  MainModelConfigEntity,
  modelsApi,
} from 'services/api/endpoints/models';
import { startAppListening } from '..';

export const addTabChangedListener = () => {
  startAppListening({
    actionCreator: setActiveTab,
    effect: async (action, { getState, dispatch }) => {
    effect: (action, { getState, dispatch }) => {
      const activeTabName = action.payload;
      if (activeTabName === 'unifiedCanvas') {
        const currentBaseModel = getState().generation.model?.base_model;
        // grab the models from RTK Query cache
        const { data } = modelsApi.endpoints.getMainModels.select(
          NON_REFINER_BASE_MODELS
        )(getState());

        if (currentBaseModel && ['sd-1', 'sd-2'].includes(currentBaseModel)) {
          // if we're already on a valid model, no change needed
        if (!data) {
          // no models yet, so we can't do anything
          dispatch(modelChanged(null));
          return;
        }

        try {
          // just grab fresh models
          const modelsRequest = dispatch(
            modelsApi.endpoints.getMainModels.initiate(NON_REFINER_BASE_MODELS)
          );
          const models = await modelsRequest.unwrap();
          // cancel this cache subscription
          modelsRequest.unsubscribe();
        // need to filter out all the invalid canvas models (currently, this is just sdxl)
        const validCanvasModels: MainModelConfigEntity[] = [];

          if (!models.ids.length) {
            // no valid canvas models
            dispatch(modelChanged(null));
        forEach(data.entities, (entity) => {
          if (!entity) {
            return;
          }

          // need to filter out all the invalid canvas models (currently sdxl & refiner)
          const validCanvasModels = mainModelsAdapter
            .getSelectors()
            .selectAll(models)
            .filter((model) => ['sd-1', 'sd-2'].includes(model.base_model));

          const firstValidCanvasModel = validCanvasModels[0];

          if (!firstValidCanvasModel) {
            // no valid canvas models
            dispatch(modelChanged(null));
            return;
          if (['sd-1', 'sd-2'].includes(entity.base_model)) {
            validCanvasModels.push(entity);
          }
        });

          const { base_model, model_name, model_type } = firstValidCanvasModel;
        // this could still be undefined even tho TS doesn't say so
        const firstValidCanvasModel = validCanvasModels[0];

          dispatch(modelChanged({ base_model, model_name, model_type }));
        } catch {
          // network request failed, bail
        if (!firstValidCanvasModel) {
          // uh oh, we have no models that are valid for canvas
          dispatch(modelChanged(null));
          return;
        }

        // only store the model name and base model in redux
        const { base_model, model_name, model_type } = firstValidCanvasModel;

        dispatch(modelChanged({ base_model, model_name, model_type }));
      }
    },
  });
@@ -54,7 +54,12 @@ const ParamLoRASelect = () => {
      });
    });

    return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
    // Sort Alphabetically
    data.sort((a, b) =>
      a.label && b.label ? (a.label?.localeCompare(b.label) ? 1 : -1) : -1
    );

    return data.sort((a, b) => (a.disabled && !b.disabled ? -1 : 1));
  }, [loras, loraModels, currentMainModel?.base_model]);

  const handleChange = useCallback(
@@ -365,19 +365,12 @@ export const systemSlice = createSlice({
      state.statusTranslationKey = 'common.statusConnected';
      state.progressImage = null;

      let errorDescription = undefined;

      if (action.payload?.status === 422) {
        errorDescription = 'Validation Error';
      } else if (action.payload?.error) {
        errorDescription = action.payload?.error as string;
      }

      state.toastQueue.push(
        makeToast({
          title: t('toast.serverError'),
          status: 'error',
          description: errorDescription,
          description:
            action.payload?.status === 422 ? 'Validation Error' : undefined,
        })
      );
    });
@@ -60,9 +60,6 @@ type InvokedSessionThunkConfig = {
const isErrorWithStatus = (error: unknown): error is { status: number } =>
  isObject(error) && 'status' in error;

const isErrorWithDetail = (error: unknown): error is { detail: string } =>
  isObject(error) && 'detail' in error;

/**
 * `SessionsService.invokeSession()` thunk
 */
@@ -88,15 +85,7 @@ export const sessionInvoked = createAsyncThunk<
        error: (error as any).body.detail,
      });
    }
    if (isErrorWithDetail(error) && response.status === 403) {
      return rejectWithValue({
        arg,
        status: response.status,
        error: error.detail,
      });
    }
    if (error)
      return rejectWithValue({ arg, status: response.status, error });
    return rejectWithValue({ arg, status: response.status, error });
  }
});
@@ -47,7 +47,7 @@ dependencies = [
  "einops",
  "eventlet",
  "facexlib",
  "fastapi==0.101.0",
  "fastapi==0.88.0",
  "fastapi-events==0.8.0",
  "fastapi-socketio==0.0.10",
  "flask==2.1.3",
@@ -64,8 +64,7 @@ dependencies = [
  "onnx",
  "onnxruntime",
  "opencv-python",
  "pydantic==2.1.1",
  "pydantic-settings",
  "pydantic==1.*",
  "picklescan",
  "pillow",
  "prompt-toolkit",
@@ -78,7 +77,7 @@ dependencies = [
  "realesrgan",
  "requests~=2.28.2",
  "rich~=13.3",
  "safetensors==0.3.1",
  "safetensors~=0.3.0",
  "scikit-image~=0.21.0",
  "send2trash",
  "test-tube~=0.7.5",
@@ -119,7 +118,7 @@ dependencies = [
[project.scripts]

# legacy entrypoints; provided for backwards compatibility
"configure_invokeai.py" = "invokeai.frontend.install.invokeai_configure:invokeai_configure"
"configure_invokeai.py" = "invokeai.frontend.install:invokeai_configure"
"textual_inversion.py" = "invokeai.frontend.training:invokeai_textual_inversion"

# shortcut commands to start cli and web
@@ -131,16 +130,15 @@ dependencies = [
"invokeai-web" = "invokeai.app.api_app:invoke_api"

# full commands
"invokeai-configure" = "invokeai.frontend.install.invokeai_configure:invokeai_configure"
"invokeai-configure" = "invokeai.frontend.install:invokeai_configure"
"invokeai-merge" = "invokeai.frontend.merge:invokeai_merge_diffusers"
"invokeai-ti" = "invokeai.frontend.training:invokeai_textual_inversion"
"invokeai-model-install" = "invokeai.frontend.install.model_install:main"
"invokeai-model-install" = "invokeai.frontend.install:invokeai_model_install"
"invokeai-migrate3" = "invokeai.backend.install.migrate_to_3:main"
"invokeai-update" = "invokeai.frontend.install.invokeai_update:main"
"invokeai-update" = "invokeai.frontend.install:invokeai_update"
"invokeai-metadata" = "invokeai.frontend.CLI.sd_metadata:print_metadata"
"invokeai-node-cli" = "invokeai.app.cli_app:invoke_cli"
"invokeai-node-web" = "invokeai.app.api_app:invoke_api"
"invokeai-import-images" = "invokeai.frontend.install.import_images:main"

[project.urls]
"Homepage" = "https://invoke-ai.github.io/InvokeAI/"
@@ -1,34 +0,0 @@
#!/usr/bin/env python
"""
Read a checkpoint/safetensors file and write out a template .json file containing
its metadata for use in fast model probing.
"""

import sys
import argparse
import json

from pathlib import Path

from invokeai.backend.model_management.models.base import read_checkpoint_meta

parser = argparse.ArgumentParser(description="Create a .json template from checkpoint/safetensors model")
parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file")
parser.add_argument("--template", "--out", type=Path, help="Path to the output .json file")

opt = parser.parse_args()
ckpt = read_checkpoint_meta(opt.checkpoint)
while "state_dict" in ckpt:
    ckpt = ckpt["state_dict"]

tmpl = {}

for key, tensor in ckpt.items():
    tmpl[key] = list(tensor.shape)

try:
    with open(opt.template, "w") as f:
        json.dump(tmpl, f)
    print(f"Template written out as {opt.template}")
except Exception as e:
    print(f"An exception occurred while writing template: {str(e)}")
@@ -1,37 +0,0 @@
#!/usr/bin/env python
"""
Read a checkpoint/safetensors file and compare it to a template .json.
Returns True if their metadata match.
"""

import sys
import argparse
import json

from pathlib import Path

from invokeai.backend.model_management.models.base import read_checkpoint_meta

parser = argparse.ArgumentParser(description="Compare a checkpoint/safetensors file to a JSON metadata template.")
parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file")
parser.add_argument("--template", "--out", type=Path, help="Path to the template .json file to match against")

opt = parser.parse_args()
ckpt = read_checkpoint_meta(opt.checkpoint)
while "state_dict" in ckpt:
    ckpt = ckpt["state_dict"]

checkpoint_metadata = {}

for key, tensor in ckpt.items():
    checkpoint_metadata[key] = list(tensor.shape)

with open(opt.template, "r") as f:
    template = json.load(f)

if checkpoint_metadata == template:
    print("True")
    sys.exit(0)
else:
    print("False")
    sys.exit(-1)
@@ -584,7 +584,7 @@ def test_graph_can_serialize():
    g.add_edge(e)

    # Not throwing on this line is sufficient
    json = g.model_dump_json()
    json = g.json()


def test_graph_can_deserialize():
@@ -596,8 +596,8 @@ def test_graph_can_deserialize():
    e = create_edge(n1.id, "image", n2.id, "image")
    g.add_edge(e)

    json = g.model_dump_json()
    g2 = Graph.model_validate_json(json)
    json = g.json()
    g2 = Graph.parse_raw(json)

    assert g2 is not None
    assert g2.nodes["1"] is not None
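These test changes track the pydantic downgrade in pyproject.toml above; a short sketch of the v2-to-v1 API mapping involved (`g` and `Graph` are the objects from the tests):

```python
# pydantic v2 name (old branch)           # pydantic v1 equivalent (this branch)
json_str = g.model_dump_json()            # json_str = g.json()
g2 = Graph.model_validate_json(json_str)  # g2 = Graph.parse_raw(json_str)
```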
@@ -7,7 +7,6 @@ from invokeai.backend import ModelManager, BaseModelType, ModelType, SubModelType

BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main)
VAE_OVERRIDE_MODEL_NAME = ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main)
VAE_NULL_OVERRIDE_MODEL_NAME = ("SDXL with empty VAE", BaseModelType.StableDiffusionXL, ModelType.Main)


@pytest.fixture
@@ -37,11 +36,3 @@ def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir:
    expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix"
    assert vae_model_path == expected_vae_path
    assert is_override


def test_get_model_path_for_null_overridden_vae(model_manager: ModelManager, datadir: Path):
    model_config = model_manager._get_model_config(
        VAE_NULL_OVERRIDE_MODEL_NAME[1], VAE_NULL_OVERRIDE_MODEL_NAME[0], VAE_NULL_OVERRIDE_MODEL_NAME[2]
    )
    vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
    assert not is_override

@@ -13,10 +13,3 @@ sdxl/main/SDXL with VAE:
  vae: sdxl/vae/sdxl-vae-fp16-fix/
  variant: normal
  format: diffusers

sdxl/main/SDXL with empty VAE:
  path: sdxl/main/SDXL base 1_0
  description: SDXL with customized VAE
  vae: ''
  variant: normal
  format: diffusers