Merge branch 'main' into ryan/regional-prompting-naive
@@ -18,12 +18,47 @@ Note that any releases marked as _pre-release_ are in a beta state. You may expe
 
 The Model Manager tab in the UI provides a few ways to install models, including using your already-downloaded models. You'll see a popup directing you there on first startup. For more information, see the [model install docs].
 
+## Missing models after updating to v4
+
+If you find some models are missing after updating to v4, it's likely they weren't correctly registered before the update and didn't get picked up in the migration.
+
+You can use the `Scan Folder` tab in the Model Manager UI to fix this. The models will either be in the old, now-unused `autoimport` folder, or your `models` folder.
+
+- Find and copy your install's old `autoimport` folder path, inside the main install folder.
+- Go to the Model Manager and click `Scan Folder`.
+- Paste the path and scan.
+- IMPORTANT: Uncheck `Inplace install`.
+- Click `Install All` to install all found models, or just install the models you want.
+
+Next, find and copy your install's `models` folder path (this could be your custom models folder path, or the `models` folder inside the main install folder).
+
+Follow the same steps to scan and import the missing models.
+
 ## Slow generation
 
 - Check the [system requirements] to ensure that your system is capable of generating images.
 - Check the `ram` setting in `invokeai.yaml`. This setting tells Invoke how much of your system RAM can be used to cache models. Having this too high or too low can slow things down. That said, it's generally safest to not set this at all and instead let Invoke manage it.
 - Check the `vram` setting in `invokeai.yaml`. This setting tells Invoke how much of your GPU VRAM can be used to cache models. Counter-intuitively, if this setting is too high, Invoke will need to do a lot of shuffling of models as it juggles the VRAM cache and the currently-loaded model. The default value of 0.25 generally works well for GPUs with less than 16GB of VRAM. Even on a 24GB card, the default works well.
 - Check that your generations are happening on your GPU (if you have one). InvokeAI will log what is being used for generation upon startup. If your GPU isn't used, re-install to ensure the correct versions of torch get installed.
+- If you are on Windows, you may have exceeded your GPU's VRAM capacity and are using slower [shared GPU memory](#shared-gpu-memory-windows). There's a guide to opt out of this behaviour in the linked FAQ entry.
+
+## Shared GPU Memory (Windows)
+
+!!! tip "Nvidia GPUs with driver 536.40"
+
+    This only applies to current Nvidia cards with driver 536.40 or later, released in June 2023.
+
+When the GPU doesn't have enough VRAM for a task, Windows is able to allocate some of its CPU RAM to the GPU. This is much slower than VRAM, but it does allow the system to generate when it otherwise might not have enough VRAM.
+
+When shared GPU memory is used, generation slows down dramatically - but at least it doesn't crash.
+
+If you'd like to opt out of this behavior and instead get an error when you exceed your GPU's VRAM, follow [this guide from Nvidia](https://nvidia.custhelp.com/app/answers/detail/a_id/5490).
+
+Here's how to get the python path required in the linked guide:
+
+- Run `invoke.bat`.
+- Select option 2 for developer console.
+- At least one python path will be printed. Copy the path that includes your invoke installation directory (typically the first).
+
 ## Installer cannot find python (Windows)
 
@@ -44,7 +44,7 @@ The installation process is simple, with a few prompts:
 
 - Select the version to install. Unless you have a specific reason to install a specific version, select the default (the latest version).
 - Select location for the install. Be sure you have enough space in this folder for the base application, as described in the [installation requirements].
-- Select a GPU device. If you are unsure, you can let the installer figure it out.
+- Select a GPU device.
 
 !!! info "Slow Installation"
 
@@ -6,11 +6,7 @@
 
 ## Introduction
 
-!!! tip "Conda"
-
-    As of InvokeAI v2.3.0 installation using the `conda` package manager is no longer being supported. It will likely still work, but we are not testing this installation method.
-
-InvokeAI is distributed as a python package on PyPI, installable with `pip`. There are a few things that are handled by the installer that you'll need to manage manually, described in this guide.
+InvokeAI is distributed as a python package on PyPI, installable with `pip`. There are a few things that are handled by the installer and launcher that you'll need to manage manually, described in this guide.
 
 ### Requirements
 
@@ -40,11 +36,11 @@ Before you start, go through the [installation requirements].
 
 1. Enter the root (invokeai) directory and create a virtual Python environment within it named `.venv`.
 
-    !!! info "Virtual Environment Location"
+    !!! warning "Virtual Environment Location"
 
         While you may create the virtual environment anywhere in the file system, we recommend that you create it within the root directory as shown here. This allows the application to automatically detect its data directories.
 
-        If you choose a different location for the venv, then you must set the `INVOKEAI_ROOT` environment variable or pass the directory using the `--root` CLI arg.
+        If you choose a different location for the venv, then you _must_ set the `INVOKEAI_ROOT` environment variable or specify the root directory using the `--root` CLI arg.
 
     ```terminal
     cd $INVOKEAI_ROOT
@@ -81,30 +77,22 @@ Before you start, go through the [installation requirements].
     python3 -m pip install --upgrade pip
     ```
 
-1. Install the InvokeAI Package. The `--extra-index-url` option is used to select the correct `torch` backend:
+1. Install the InvokeAI Package. The base command is `pip install InvokeAI --use-pep517`, but you may need to change this depending on your system and the desired features.
 
-    === "CUDA (NVidia)"
-
-        ```bash
-        pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
-        ```
-
-    === "ROCm (AMD)"
-
-        ```bash
-        pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.6
-        ```
-
-    === "CPU (Intel Macs & non-GPU systems)"
-
-        ```bash
-        pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu
-        ```
-
-    === "MPS (Apple Silicon)"
-
-        ```bash
-        pip install InvokeAI --use-pep517
-        ```
+    - You may need to provide an [extra index URL]. Select your platform configuration using [this tool on the PyTorch website]. Copy the `--extra-index-url` string from this and append it to your install command.
+
+    !!! example "Install with an extra index URL"
+
+        ```bash
+        pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
+        ```
+
+    - If you have a CUDA GPU and want to install with `xformers`, you need to add an option to the package name. Note that `xformers` is not necessary. PyTorch includes an implementation of the SDP attention algorithm with the same performance.
+
+    !!! example "Install with `xformers`"
+
+        ```bash
+        pip install "InvokeAI[xformers]" --use-pep517
+        ```
 
 1. Deactivate and reactivate your runtime directory so that the invokeai-specific commands become available in the environment:
@@ -126,37 +114,6 @@ Before you start, go through the [installation requirements].
 
 Run `invokeai-web` to start the UI. You must activate the virtual environment before running the app.
 
-If the virtual environment you selected is NOT inside `INVOKEAI_ROOT`, then you must specify the path to the root directory by adding
-`--root_dir \path\to\invokeai`.
-
-!!! tip
-
-    You can permanently set the location of the runtime directory
-    by setting the environment variable `INVOKEAI_ROOT` to the
-    path of the directory. As mentioned previously, this is
-    recommended if your virtual environment is located outside of
-    your runtime directory.
-
-## Unsupported Conda Install
-
-Congratulations, you found the "secret" Conda installation instructions. If you really **really** want to use Conda with InvokeAI, you can do so using this unsupported recipe:
-
-```sh
-mkdir ~/invokeai
-conda create -n invokeai python=3.11
-conda activate invokeai
-# Adjust this as described above for the appropriate torch backend
-pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
-invokeai-web --root ~/invokeai
-```
-
-The `pip install` command shown in this recipe is for Linux/Windows
-systems with an NVIDIA GPU. See step (6) above for the command to use
-with other platforms/GPU combinations. If you don't wish to pass the
-`--root` argument to `invokeai` with each launch, you may set the
-environment variable `INVOKEAI_ROOT` to point to the installation directory.
-
-Note that if you run into problems with the Conda installation, the InvokeAI
-staff will **not** be able to help you out. Caveat Emptor!
-
-[installation requirements]: INSTALL_REQUIREMENTS.md
+!!! warning
+
+    If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root_dir \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable.
@@ -3,6 +3,7 @@
 InvokeAI installer script
 """
 
+import locale
 import os
 import platform
 import re
@@ -316,7 +317,9 @@ def upgrade_pip(venv_path: Path) -> str | None:
     python = str(venv_path.expanduser().resolve() / python)
 
     try:
-        result = subprocess.check_output([python, "-m", "pip", "install", "--upgrade", "pip"]).decode()
+        result = subprocess.check_output([python, "-m", "pip", "install", "--upgrade", "pip"]).decode(
+            encoding=locale.getpreferredencoding()
+        )
     except subprocess.CalledProcessError as e:
         print(e)
        result = None
|
|||||||
# device can be one of: "cuda", "rocm", "cpu", "cuda_and_dml, autodetect"
|
# device can be one of: "cuda", "rocm", "cpu", "cuda_and_dml, autodetect"
|
||||||
device = select_gpu()
|
device = select_gpu()
|
||||||
|
|
||||||
|
# The correct extra index URLs for torch are inconsistent, see https://pytorch.org/get-started/locally/#start-locally
|
||||||
|
|
||||||
url = None
|
url = None
|
||||||
optional_modules = "[onnx]"
|
optional_modules: str | None = None
|
||||||
if OS == "Linux":
|
if OS == "Linux":
|
||||||
if device.value == "rocm":
|
if device.value == "rocm":
|
||||||
url = "https://download.pytorch.org/whl/rocm5.6"
|
url = "https://download.pytorch.org/whl/rocm5.6"
|
||||||
elif device.value == "cpu":
|
elif device.value == "cpu":
|
||||||
url = "https://download.pytorch.org/whl/cpu"
|
url = "https://download.pytorch.org/whl/cpu"
|
||||||
|
elif device.value == "cuda":
|
||||||
|
# CUDA uses the default PyPi index
|
||||||
|
optional_modules = "[xformers,onnx-cuda]"
|
||||||
elif OS == "Windows":
|
elif OS == "Windows":
|
||||||
if device.value == "cuda":
|
if device.value == "cuda":
|
||||||
url = "https://download.pytorch.org/whl/cu121"
|
url = "https://download.pytorch.org/whl/cu121"
|
||||||
optional_modules = "[xformers,onnx-cuda]"
|
optional_modules = "[xformers,onnx-cuda]"
|
||||||
if device.value == "cuda_and_dml":
|
elif device.value == "cpu":
|
||||||
url = "https://download.pytorch.org/whl/cu121"
|
# CPU uses the default PyPi index, no optional modules
|
||||||
optional_modules = "[xformers,onnx-directml]"
|
pass
|
||||||
|
elif OS == "Darwin":
|
||||||
|
# macOS uses the default PyPi index, no optional modules
|
||||||
|
pass
|
||||||
|
|
||||||
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
|
# Fall back to defaults
|
||||||
|
|
||||||
return (url, optional_modules)
|
return (url, optional_modules)
|
||||||
|
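To make the returned tuple concrete, here is a hedged sketch of how an installer might turn `(url, optional_modules)` into a pip command. The helper name, version handling, and exact flags are illustrative; only the `--use-pep517` / `--extra-index-url` pattern comes from the diff.

```python
# Sketch: build a pip install command from get_torch_source()-style output.
# `build_pip_command` and the version argument are made up for illustration.
def build_pip_command(url: str | None, optional_modules: str | None, version: str = "") -> list[str]:
    package = "InvokeAI" + (optional_modules or "") + (f"=={version}" if version else "")
    cmd = ["pip", "install", package, "--use-pep517"]
    if url is not None:
        cmd += ["--extra-index-url", url]
    return cmd


# Example: CUDA on Windows -> cu121 index plus the xformers/onnx extras.
print(build_pip_command("https://download.pytorch.org/whl/cu121", "[xformers,onnx-cuda]"))
```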
@@ -207,10 +207,8 @@ def dest_path(dest: Optional[str | Path] = None) -> Path | None:
 
 class GpuType(Enum):
     CUDA = "cuda"
-    CUDA_AND_DML = "cuda_and_dml"
     ROCM = "rocm"
     CPU = "cpu"
-    AUTODETECT = "autodetect"
 
 
 def select_gpu() -> GpuType:
@@ -226,10 +224,6 @@ def select_gpu() -> GpuType:
         "an [gold1 b]NVIDIA[/] GPU (using CUDA™)",
         GpuType.CUDA,
     )
-    nvidia_with_dml = (
-        "an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX) -- ALPHA",
-        GpuType.CUDA_AND_DML,
-    )
     amd = (
         "an [gold1 b]AMD[/] GPU (using ROCm™)",
         GpuType.ROCM,
@@ -238,27 +232,19 @@ def select_gpu() -> GpuType:
         "Do not install any GPU support, use CPU for generation (slow)",
         GpuType.CPU,
     )
-    autodetect = (
-        "I'm not sure what to choose",
-        GpuType.AUTODETECT,
-    )
 
     options = []
     if OS == "Windows":
-        options = [nvidia, nvidia_with_dml, cpu]
+        options = [nvidia, cpu]
     if OS == "Linux":
         options = [nvidia, amd, cpu]
     elif OS == "Darwin":
         options = [cpu]
-        # future CoreML?
 
     if len(options) == 1:
         print(f'Your platform [gold1]{OS}-{ARCH}[/] only supports the "{options[0][1]}" driver. Proceeding with that.')
         return options[0][1]
 
-    # "I don't know" is always added the last option
-    options.append(autodetect)  # type: ignore
-
     options = {str(i): opt for i, opt in enumerate(options, 1)}
 
     console.rule(":space_invader: GPU (Graphics Card) selection :space_invader:")
@@ -292,11 +278,6 @@ def select_gpu() -> GpuType:
         ),
     )
 
-    if options[choice][1] is GpuType.AUTODETECT:
-        console.print(
-            "No problem. We will install CUDA support first :crossed_fingers: If Invoke does not detect a GPU, please re-run the installer and select one of the other GPU types."
-        )
-
     return options[choice][1]
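For orientation, a standalone sketch of the option-tuple pattern `select_gpu()` relies on: each choice is a `(label, GpuType)` pair, numbered into a dict for the console prompt. The trimmed labels and the hard-coded choice are illustrative.

```python
# Sketch of the (label, value) option tuples used by the GPU selection prompt.
from enum import Enum


class GpuType(Enum):
    CUDA = "cuda"
    ROCM = "rocm"
    CPU = "cpu"


nvidia = ("an NVIDIA GPU (using CUDA)", GpuType.CUDA)
amd = ("an AMD GPU (using ROCm)", GpuType.ROCM)
cpu = ("no GPU support, CPU only (slow)", GpuType.CPU)

options = {str(i): opt for i, opt in enumerate([nvidia, amd, cpu], 1)}
choice = "1"  # in the installer this comes from a console prompt
print(options[choice][1])  # GpuType.CUDA
```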
@@ -12,7 +12,7 @@ from pydantic import BaseModel, Field
 
 from invokeai.app.invocations.upscale import ESRGAN_MODELS
 from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
-from invokeai.backend.image_util.patchmatch import PatchMatch
+from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch
 from invokeai.backend.image_util.safety_checker import SafetyChecker
 from invokeai.backend.util.logging import logging
 from invokeai.version import __version__
@@ -100,7 +100,7 @@ async def get_app_deps() -> AppDependencyVersions:
 
 @app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
 async def get_config() -> AppConfig:
-    infill_methods = ["tile", "lama", "cv2"]
+    infill_methods = ["tile", "lama", "cv2", "color"]  # TODO: add mosaic back
     if PatchMatch.patchmatch_available():
         infill_methods.append("patchmatch")
 
@@ -219,28 +219,13 @@ async def scan_for_models(
         non_core_model_paths = [p for p in found_model_paths if not p.is_relative_to(core_models_path)]
 
         installed_models = ApiDependencies.invoker.services.model_manager.store.search_by_attr()
-        resolved_installed_model_paths: list[str] = []
-        installed_model_sources: list[str] = []
-
-        # This call lists all installed models.
-        for model in installed_models:
-            path = pathlib.Path(model.path)
-            # If the model has a source, we need to add it to the list of installed sources.
-            if model.source:
-                installed_model_sources.append(model.source)
-            # If the path is not absolute, that means it is in the app models directory, and we need to join it with
-            # the models path before resolving.
-            if not path.is_absolute():
-                resolved_installed_model_paths.append(str(pathlib.Path(models_path, path).resolve()))
-                continue
-            resolved_installed_model_paths.append(str(path.resolve()))
-
         scan_results: list[FoundModel] = []
 
-        # Check if the model is installed by comparing the resolved paths, appending to the scan result.
+        # Check if the model is installed by comparing paths, appending to the scan result.
         for p in non_core_model_paths:
             path = str(p)
-            is_installed = path in resolved_installed_model_paths or path in installed_model_sources
+            is_installed = any(str(models_path / m.path) == path for m in installed_models)
             found_model = FoundModel(path=path, is_installed=is_installed)
             scan_results.append(found_model)
     except Exception as e:
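A quick standalone illustration of the simplified check above: installed model paths are stored relative to the models directory, so they are joined with it before being compared with each scanned path. The directory and the installed entries below are made-up example values.

```python
# Sketch: decide whether a scanned path matches an installed model.
from pathlib import Path

models_path = Path("/opt/invokeai/models")  # illustrative models directory
installed_relative_paths = [Path("sd-1/main/model-a"), Path("sdxl/main/model-b")]


def is_installed(scanned: Path) -> bool:
    # Join each installed relative path with the models directory, then compare
    # it as a string against the scanned absolute path.
    return any(str(models_path / rel) == str(scanned) for rel in installed_relative_paths)


print(is_installed(Path("/opt/invokeai/models/sd-1/main/model-a")))  # True
print(is_installed(Path("/tmp/other/model-c")))                      # False
```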
@@ -1,154 +1,91 @@
-# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
+from abc import abstractmethod
+from typing import Literal, get_args
 
-import math
-from typing import Literal, Optional, get_args
-
-import numpy as np
-from PIL import Image, ImageOps
+from PIL import Image
 
 from invokeai.app.invocations.fields import ColorField, ImageField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.app.util.download_with_progress import download_with_progress_bar
 from invokeai.app.util.misc import SEED_MAX
-from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
-from invokeai.backend.image_util.lama import LaMA
-from invokeai.backend.image_util.patchmatch import PatchMatch
+from invokeai.backend.image_util.infill_methods.cv2_inpaint import cv2_inpaint
+from invokeai.backend.image_util.infill_methods.lama import LaMA
+from invokeai.backend.image_util.infill_methods.mosaic import infill_mosaic
+from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch, infill_patchmatch
+from invokeai.backend.image_util.infill_methods.tile import infill_tile
+from invokeai.backend.util.logging import InvokeAILogger
 
 from .baseinvocation import BaseInvocation, invocation
 from .fields import InputField, WithBoard, WithMetadata
 from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
 
+logger = InvokeAILogger.get_logger()
 
-def infill_methods() -> list[str]:
-    methods = ["tile", "solid", "lama", "cv2"]
+
+def get_infill_methods():
+    methods = Literal["tile", "color", "lama", "cv2"]  # TODO: add mosaic back
     if PatchMatch.patchmatch_available():
-        methods.insert(0, "patchmatch")
+        methods = Literal["patchmatch", "tile", "color", "lama", "cv2"]  # TODO: add mosaic back
     return methods
 
 
-INFILL_METHODS = Literal[tuple(infill_methods())]
+INFILL_METHODS = get_infill_methods()
 DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
 
 
-def infill_lama(im: Image.Image) -> Image.Image:
-    lama = LaMA()
-    return lama(im)
-
-
-def infill_patchmatch(im: Image.Image) -> Image.Image:
-    if im.mode != "RGBA":
-        return im
-
-    # Skip patchmatch if patchmatch isn't available
-    if not PatchMatch.patchmatch_available():
-        return im
-
-    # Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
-    im_patched_np = PatchMatch.inpaint(im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3)
-    im_patched = Image.fromarray(im_patched_np, mode="RGB")
-    return im_patched
-
-
-def infill_cv2(im: Image.Image) -> Image.Image:
-    return cv2_inpaint(im)
-
-
-def get_tile_images(image: np.ndarray, width=8, height=8):
-    _nrows, _ncols, depth = image.shape
-    _strides = image.strides
-
-    nrows, _m = divmod(_nrows, height)
-    ncols, _n = divmod(_ncols, width)
-    if _m != 0 or _n != 0:
-        return None
-
-    return np.lib.stride_tricks.as_strided(
-        np.ravel(image),
-        shape=(nrows, ncols, height, width, depth),
-        strides=(height * _strides[0], width * _strides[1], *_strides),
-        writeable=False,
-    )
-
-
-def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image:
-    # Only fill if there's an alpha layer
-    if im.mode != "RGBA":
-        return im
-
-    a = np.asarray(im, dtype=np.uint8)
-
-    tile_size_tuple = (tile_size, tile_size)
-
-    # Get the image as tiles of a specified size
-    tiles = get_tile_images(a, *tile_size_tuple).copy()
-
-    # Get the mask as tiles
-    tiles_mask = tiles[:, :, :, :, 3]
-
-    # Find any mask tiles with any fully transparent pixels (we will be replacing these later)
-    tmask_shape = tiles_mask.shape
-    tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
-    n, ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
-    tiles_mask = tiles_mask > 0
-    tiles_mask = tiles_mask.reshape((n, ny)).all(axis=1)
-
-    # Get RGB tiles in single array and filter by the mask
-    tshape = tiles.shape
-    tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), *tiles.shape[2:]))
-    filtered_tiles = tiles_all[tiles_mask]
-
-    if len(filtered_tiles) == 0:
-        return im
-
-    # Find all invalid tiles and replace with a random valid tile
-    replace_count = (tiles_mask == False).sum()  # noqa: E712
-    rng = np.random.default_rng(seed=seed)
-    tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]
-
-    # Convert back to an image
-    tiles_all = tiles_all.reshape(tshape)
-    tiles_all = tiles_all.swapaxes(1, 2)
-    st = tiles_all.reshape(
-        (
-            math.prod(tiles_all.shape[0:2]),
-            math.prod(tiles_all.shape[2:4]),
-            tiles_all.shape[4],
-        )
-    )
-    si = Image.fromarray(st, mode="RGBA")
-
-    return si
+class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
+    """Base class for invocations that preprocess images for Infilling"""
+
+    image: ImageField = InputField(description="The image to process")
+
+    @abstractmethod
+    def infill(self, image: Image.Image) -> Image.Image:
+        """Infill the image with the specified method"""
+        pass
+
+    def load_image(self, context: InvocationContext) -> tuple[Image.Image, bool]:
+        """Process the image to have an alpha channel before being infilled"""
+        image = context.images.get_pil(self.image.image_name)
+        has_alpha = True if image.mode == "RGBA" else False
+        return image, has_alpha
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        # Retrieve and process image to be infilled
+        input_image, has_alpha = self.load_image(context)
+
+        # If the input image has no alpha channel, return it
+        if has_alpha is False:
+            return ImageOutput.build(context.images.get_dto(self.image.image_name))
+
+        # Perform Infill action
+        infilled_image = self.infill(input_image)
+
+        # Create ImageDTO for Infilled Image
+        infilled_image_dto = context.images.save(image=infilled_image)
+
+        # Return Infilled Image
+        return ImageOutput.build(infilled_image_dto)
 
 
 @invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
-class InfillColorInvocation(BaseInvocation, WithMetadata, WithBoard):
+class InfillColorInvocation(InfillImageProcessorInvocation):
     """Infills transparent areas of an image with a solid color"""
 
-    image: ImageField = InputField(description="The image to infill")
     color: ColorField = InputField(
         default=ColorField(r=127, g=127, b=127, a=255),
         description="The color to use to infill",
     )
 
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.images.get_pil(self.image.image_name)
-
+    def infill(self, image: Image.Image):
         solid_bg = Image.new("RGBA", image.size, self.color.tuple())
         infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
         infilled.paste(image, (0, 0), image.split()[-1])
-
-        image_dto = context.images.save(image=infilled)
-
-        return ImageOutput.build(image_dto)
+        return infilled
 
 
 @invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.3")
-class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard):
+class InfillTileInvocation(InfillImageProcessorInvocation):
     """Infills transparent areas of an image with tiles of the image"""
 
-    image: ImageField = InputField(description="The image to infill")
     tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
     seed: int = InputField(
         default=0,
@@ -157,92 +94,74 @@ class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard):
         description="The seed to use for tile generation (omit for random)",
     )
 
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.images.get_pil(self.image.image_name)
-
-        infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size)
-        infilled.paste(image, (0, 0), image.split()[-1])
-
-        image_dto = context.images.save(image=infilled)
-
-        return ImageOutput.build(image_dto)
+    def infill(self, image: Image.Image):
+        output = infill_tile(image, seed=self.seed, tile_size=self.tile_size)
+        return output.infilled
 
 
 @invocation(
     "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2"
 )
-class InfillPatchMatchInvocation(BaseInvocation, WithMetadata, WithBoard):
+class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
     """Infills transparent areas of an image using the PatchMatch algorithm"""
 
-    image: ImageField = InputField(description="The image to infill")
     downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
     resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
 
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.images.get_pil(self.image.image_name).convert("RGBA")
-
+    def infill(self, image: Image.Image):
         resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
 
-        infill_image = image.copy()
         width = int(image.width / self.downscale)
         height = int(image.height / self.downscale)
-        infill_image = infill_image.resize(
+
+        infilled = image.resize(
             (width, height),
             resample=resample_mode,
         )
-
-        if PatchMatch.patchmatch_available():
-            infilled = infill_patchmatch(infill_image)
-        else:
-            raise ValueError("PatchMatch is not available on this system")
-
+        infilled = infill_patchmatch(image)
         infilled = infilled.resize(
             (image.width, image.height),
             resample=resample_mode,
         )
-
         infilled.paste(image, (0, 0), mask=image.split()[-1])
-        # image.paste(infilled, (0, 0), mask=image.split()[-1])
-
-        image_dto = context.images.save(image=infilled)
-
-        return ImageOutput.build(image_dto)
+        return infilled
 
 
 @invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
-class LaMaInfillInvocation(BaseInvocation, WithMetadata, WithBoard):
+class LaMaInfillInvocation(InfillImageProcessorInvocation):
     """Infills transparent areas of an image using the LaMa model"""
 
-    image: ImageField = InputField(description="The image to infill")
-
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.images.get_pil(self.image.image_name)
-
-        # Downloads the LaMa model if it doesn't already exist
-        download_with_progress_bar(
-            name="LaMa Inpainting Model",
-            url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
-            dest_path=context.config.get().models_path / "core/misc/lama/lama.pt",
-        )
-
-        infilled = infill_lama(image.copy())
-
-        image_dto = context.images.save(image=infilled)
-
-        return ImageOutput.build(image_dto)
+    def infill(self, image: Image.Image):
+        lama = LaMA()
+        return lama(image)
 
 
 @invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
-class CV2InfillInvocation(BaseInvocation, WithMetadata, WithBoard):
+class CV2InfillInvocation(InfillImageProcessorInvocation):
     """Infills transparent areas of an image using OpenCV Inpainting"""
 
+    def infill(self, image: Image.Image):
+        return cv2_inpaint(image)
+
+
+# @invocation(
+#     "infill_mosaic", title="Mosaic Infill", tags=["image", "inpaint", "outpaint"], category="inpaint", version="1.0.0"
+# )
+class MosaicInfillInvocation(InfillImageProcessorInvocation):
+    """Infills transparent areas of an image with a mosaic pattern drawing colors from the rest of the image"""
+
     image: ImageField = InputField(description="The image to infill")
+    tile_width: int = InputField(default=64, description="Width of the tile")
+    tile_height: int = InputField(default=64, description="Height of the tile")
+    min_color: ColorField = InputField(
+        default=ColorField(r=0, g=0, b=0, a=255),
+        description="The min threshold for color",
+    )
+    max_color: ColorField = InputField(
+        default=ColorField(r=255, g=255, b=255, a=255),
+        description="The max threshold for color",
+    )
 
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.images.get_pil(self.image.image_name)
-
-        infilled = infill_cv2(image.copy())
-
-        image_dto = context.images.save(image=infilled)
-
-        return ImageOutput.build(image_dto)
+    def infill(self, image: Image.Image):
+        return infill_mosaic(image, (self.tile_width, self.tile_height), self.min_color.tuple(), self.max_color.tuple())
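The base class introduced above means a new infill node only has to provide an `infill()` method. The sketch below shows that contract in a self-contained way: `SimpleInfillBase` is a stub standing in for `InfillImageProcessorInvocation`, and the "blur infill" behaviour is made up purely to illustrate subclassing.

```python
# Illustrative only: the subclassing pattern the new base class enables.
from abc import ABC, abstractmethod

from PIL import Image, ImageFilter


class SimpleInfillBase(ABC):
    """Stub of the base-class contract: subclasses only provide infill()."""

    @abstractmethod
    def infill(self, image: Image.Image) -> Image.Image: ...

    def run(self, image: Image.Image) -> Image.Image:
        # Mirrors invoke(): skip images without an alpha channel, otherwise infill.
        return image if image.mode != "RGBA" else self.infill(image)


class BlurInfill(SimpleInfillBase):
    def __init__(self, radius: float = 8.0) -> None:
        self.radius = radius

    def infill(self, image: Image.Image) -> Image.Image:
        # Blur a copy of the image, then paste the original back over opaque
        # pixels so only the transparent areas keep the blurred content.
        blurred = image.convert("RGB").filter(ImageFilter.GaussianBlur(self.radius)).convert("RGBA")
        blurred.paste(image, (0, 0), image.split()[-1])
        return blurred


result = BlurInfill().run(Image.new("RGBA", (64, 64), (255, 0, 0, 0)))
print(result.size)  # (64, 64)
```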
@@ -1,21 +1,22 @@
 from builtins import float
-from typing import List, Union
+from typing import List, Literal, Union
 
 from pydantic import BaseModel, Field, field_validator, model_validator
 from typing_extensions import Self
 
-from invokeai.app.invocations.baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    invocation,
-    invocation_output,
-)
+from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
 from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.invocations.primitives import ImageField
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, IPAdapterConfig, ModelType
+from invokeai.backend.model_manager.config import (
+    AnyModelConfig,
+    BaseModelType,
+    IPAdapterCheckpointConfig,
+    IPAdapterInvokeAIConfig,
+    ModelType,
+)
 
 
 class IPAdapterField(BaseModel):
@@ -48,12 +49,15 @@ class IPAdapterOutput(BaseInvocationOutput):
     ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
 
 
+CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
+
+
 @invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.2.2")
 class IPAdapterInvocation(BaseInvocation):
     """Collects IP-Adapter info to pass to other nodes."""
 
     # Inputs
-    image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
+    image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).", ui_order=1)
     ip_adapter_model: ModelIdentifierField = InputField(
         description="The IP-Adapter model.",
         title="IP-Adapter Model",
@@ -61,7 +65,11 @@ class IPAdapterInvocation(BaseInvocation):
         ui_order=-1,
         ui_type=UIType.IPAdapterModel,
     )
+    clip_vision_model: Literal["ViT-H", "ViT-G"] = InputField(
+        description="CLIP Vision model to use. Overrides model settings. Mandatory for checkpoint models.",
+        default="ViT-H",
+        ui_order=2,
+    )
     weight: Union[float, List[float]] = InputField(
         default=1, description="The weight given to the IP-Adapter", title="Weight"
     )
@@ -86,10 +94,16 @@ class IPAdapterInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> IPAdapterOutput:
         # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
         ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
-        assert isinstance(ip_adapter_info, IPAdapterConfig)
-        image_encoder_model_id = ip_adapter_info.image_encoder_model_id
-        image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
+        assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))
+
+        if isinstance(ip_adapter_info, IPAdapterInvokeAIConfig):
+            image_encoder_model_id = ip_adapter_info.image_encoder_model_id
+            image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
+        else:
+            image_encoder_model_name = CLIP_VISION_MODEL_MAP[self.clip_vision_model]
+
         image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
+
         return IPAdapterOutput(
             ip_adapter=IPAdapterField(
                 image=self.image,
@@ -102,19 +116,25 @@ class IPAdapterInvocation(BaseInvocation):
         )
 
     def _get_image_encoder(self, context: InvocationContext, image_encoder_model_name: str) -> AnyModelConfig:
-        found = False
-        while not found:
-            image_encoder_models = context.models.search_by_attrs(
-                name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
-            )
-            found = len(image_encoder_models) > 0
-            if not found:
-                context.logger.warning(
-                    f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed."
-                )
-                context.logger.warning("Downloading and installing now. This may take a while.")
-                installer = context._services.model_manager.install
-                job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
-                installer.wait_for_job(job, timeout=600)  # wait up to 10 minutes - then raise a TimeoutException
+        image_encoder_models = context.models.search_by_attrs(
+            name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
+        )
+
+        if not len(image_encoder_models) > 0:
+            context.logger.warning(
+                f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed. \
+                    Downloading and installing now. This may take a while."
+            )
+
+            installer = context._services.model_manager.install
+            job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
+            installer.wait_for_job(job, timeout=600)  # Wait for up to 10 minutes
+            image_encoder_models = context.models.search_by_attrs(
+                name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
+            )
+
+            if len(image_encoder_models) == 0:
+                context.logger.error("Error while fetching CLIP Vision Image Encoder")
+
         assert len(image_encoder_models) == 1
+
         return image_encoder_models[0]
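A compressed, standalone sketch of the encoder-selection rule the IP-Adapter hunks above add: InvokeAI-format configs carry their own encoder id, while checkpoint-format configs fall back to the node's `clip_vision_model` choice. The dataclasses below are stand-ins for the real config classes, purely for illustration.

```python
# Sketch of the new CLIP Vision encoder selection rule, with stub config types.
from dataclasses import dataclass

CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}


@dataclass
class InvokeAIStyleConfig:
    image_encoder_model_id: str  # e.g. "InvokeAI/ip_adapter_sd_image_encoder"


@dataclass
class CheckpointStyleConfig:
    pass  # checkpoint IP-Adapters do not embed an encoder id


def encoder_name(config: object, clip_vision_model: str = "ViT-H") -> str:
    # InvokeAI-format configs carry the encoder repo id; checkpoint-format
    # configs rely on the node's clip_vision_model selection instead.
    if isinstance(config, InvokeAIStyleConfig):
        return config.image_encoder_model_id.split("/")[-1].strip()
    return CLIP_VISION_MODEL_MAP[clip_vision_model]


print(encoder_name(InvokeAIStyleConfig("InvokeAI/ip_adapter_sd_image_encoder")))  # ip_adapter_sd_image_encoder
print(encoder_name(CheckpointStyleConfig(), "ViT-G"))                             # ip_adapter_sdxl_image_encoder
```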
@@ -44,11 +44,7 @@ from invokeai.app.invocations.fields import (
     WithMetadata,
 )
 from invokeai.app.invocations.ip_adapter import IPAdapterField
-from invokeai.app.invocations.primitives import (
-    DenoiseMaskOutput,
-    ImageOutput,
-    LatentsOutput,
-)
+from invokeai.app.invocations.primitives import DenoiseMaskOutput, ImageOutput, LatentsOutput
 from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
@@ -76,12 +72,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
 )
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from ...backend.util.devices import choose_precision, choose_torch_device
-from .baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    invocation,
-    invocation_output,
-)
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from .controlnet_image_processors import ControlField
 from .model import ModelIdentifierField, UNetField, VAEField
@@ -1423,7 +1414,7 @@ class IdealSizeInvocation(BaseInvocation):
         return tuple((x - x % multiple_of) for x in args)
 
     def invoke(self, context: InvocationContext) -> IdealSizeOutput:
-        unet_config = context.models.get_config(**self.unet.unet.model_dump())
+        unet_config = context.models.get_config(self.unet.unet.key)
         aspect = self.width / self.height
         dimension: float = 512
         if unet_config.base == BaseModelType.StableDiffusion2:
@@ -2,16 +2,8 @@ from typing import Any, Literal, Optional, Union
 
 from pydantic import BaseModel, ConfigDict, Field
 
-from invokeai.app.invocations.baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    invocation,
-    invocation_output,
-)
-from invokeai.app.invocations.controlnet_image_processors import (
-    CONTROLNET_MODE_VALUES,
-    CONTROLNET_RESIZE_VALUES,
-)
+from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
+from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES
 from invokeai.app.invocations.fields import (
     FieldDescriptions,
     ImageField,
@@ -43,6 +35,7 @@ class IPAdapterMetadataField(BaseModel):
 
     image: ImageField = Field(description="The IP-Adapter image prompt.")
     ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
+    clip_vision_model: Literal["ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
     weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
     begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
     end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
@@ -3,6 +3,7 @@
 
 from __future__ import annotations
 
+import locale
 import os
 import re
 import shutil
@@ -317,11 +318,10 @@ class InvokeAIAppConfig(BaseSettings):
     @staticmethod
     def find_root() -> Path:
         """Choose the runtime root directory when not specified on command line or init file."""
-        venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
         if os.environ.get("INVOKEAI_ROOT"):
             root = Path(os.environ["INVOKEAI_ROOT"])
-        elif any((venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]):
-            root = (venv.parent).resolve()
+        elif venv := os.environ.get("VIRTUAL_ENV", None):
+            root = Path(venv).parent.resolve()
         else:
             root = Path("~/invokeai").expanduser().resolve()
         return root
@@ -402,7 +402,7 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
         An instance of `InvokeAIAppConfig` with the loaded and migrated settings.
     """
     assert config_path.suffix == ".yaml"
-    with open(config_path) as file:
+    with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
         loaded_config_dict = yaml.safe_load(file)
 
     assert isinstance(loaded_config_dict, dict)
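The `find_root()` change above replaces the init-file probe with a simpler precedence; a standalone sketch of that order, with made-up environment values, behaves like this:

```python
# Sketch of the new root-resolution precedence: INVOKEAI_ROOT wins, then the
# parent of an active virtualenv, then ~/invokeai. Env values are illustrative.
import os
from pathlib import Path


def find_root() -> Path:
    if os.environ.get("INVOKEAI_ROOT"):
        return Path(os.environ["INVOKEAI_ROOT"])
    if venv := os.environ.get("VIRTUAL_ENV", None):
        return Path(venv).parent.resolve()
    return Path("~/invokeai").expanduser().resolve()


os.environ["VIRTUAL_ENV"] = "/opt/invokeai/.venv"  # illustrative value
print(find_root())  # /opt/invokeai
```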
@@ -1,5 +1,6 @@
 """Model installation class."""
 
+import locale
 import os
 import re
 import signal
@@ -323,7 +324,8 @@ class ModelInstallService(ModelInstallServiceBase):
             legacy_models_yaml_path = Path(self._app_config.root_path, legacy_models_yaml_path)
 
         if legacy_models_yaml_path.exists():
-            legacy_models_yaml = yaml.safe_load(legacy_models_yaml_path.read_text())
+            with open(legacy_models_yaml_path, "rt", encoding=locale.getpreferredencoding()) as file:
+                legacy_models_yaml = yaml.safe_load(file)
 
             yaml_metadata = legacy_models_yaml.pop("__metadata__")
             yaml_version = yaml_metadata.get("version")
@@ -564,7 +566,7 @@ class ModelInstallService(ModelInstallServiceBase):
             # The model is not in the models directory - we don't need to move it.
             return model
 
-        new_path = (models_dir / model.base.value / model.type.value / model.name).with_suffix(old_path.suffix)
+        new_path = models_dir / model.base.value / model.type.value / old_path.name
 
         if old_path == new_path or new_path.exists() and old_path == new_path.resolve():
             return model
@@ -80,6 +80,7 @@ class ModelManagerService(ModelManagerServiceBase):
         ram_cache = ModelCache(
             max_cache_size=app_config.ram,
             max_vram_cache_size=app_config.vram,
+            lazy_offloading=app_config.lazy_offload,
             logger=logger,
             execution_device=execution_device,
         )
@@ -2,7 +2,7 @@
 Initialization file for invokeai.backend.image_util methods.
 """

-from .patchmatch import PatchMatch  # noqa: F401
+from .infill_methods.patchmatch import PatchMatch  # noqa: F401
 from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata  # noqa: F401
 from .seamless import configure_model_padding  # noqa: F401
 from .util import InitImageResizer, make_grid  # noqa: F401
@@ -7,6 +7,7 @@ from PIL import Image

 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config.config_default import get_config
+from invokeai.app.util.download_with_progress import download_with_progress_bar
 from invokeai.backend.util.devices import choose_torch_device

@@ -30,6 +31,14 @@ class LaMA:
     def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
         device = choose_torch_device()
         model_location = get_config().models_path / "core/misc/lama/lama.pt"

+        if not model_location.exists():
+            download_with_progress_bar(
+                name="LaMa Inpainting Model",
+                url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
+                dest_path=model_location,
+            )
+
         model = load_jit_model(model_location, device)

         image = np.asarray(input_image.convert("RGB"))
New file: invokeai/backend/image_util/infill_methods/mosaic.py (60 lines)

from typing import Tuple

import numpy as np
from PIL import Image


def infill_mosaic(
    image: Image.Image,
    tile_shape: Tuple[int, int] = (64, 64),
    min_color: Tuple[int, int, int, int] = (0, 0, 0, 0),
    max_color: Tuple[int, int, int, int] = (255, 255, 255, 0),
) -> Image.Image:
    """
    image:PIL - A PIL Image
    tile_shape: Tuple[int,int] - Tile width & Tile Height
    min_color: Tuple[int,int,int] - RGB values for the lowest color to clip to (0-255)
    max_color: Tuple[int,int,int] - RGB values for the highest color to clip to (0-255)
    """

    np_image = np.array(image)  # Convert image to np array
    alpha = np_image[:, :, 3]  # Get the mask from the alpha channel of the image
    non_transparent_pixels = np_image[alpha != 0, :3]  # List of non-transparent pixels

    # Create color tiles to paste in the empty areas of the image
    tile_width, tile_height = tile_shape

    # Clip the range of colors in the image to a particular spectrum only
    r_min, g_min, b_min, _ = min_color
    r_max, g_max, b_max, _ = max_color
    non_transparent_pixels[:, 0] = np.clip(non_transparent_pixels[:, 0], r_min, r_max)
    non_transparent_pixels[:, 1] = np.clip(non_transparent_pixels[:, 1], g_min, g_max)
    non_transparent_pixels[:, 2] = np.clip(non_transparent_pixels[:, 2], b_min, b_max)

    tiles = []
    for _ in range(256):
        color = non_transparent_pixels[np.random.randint(len(non_transparent_pixels))]
        tile = np.zeros((tile_height, tile_width, 3), dtype=np.uint8)
        tile[:, :] = color
        tiles.append(tile)

    # Fill the transparent area with tiles
    filled_image = np.zeros((image.height, image.width, 3), dtype=np.uint8)

    for x in range(image.width):
        for y in range(image.height):
            tile = tiles[np.random.randint(len(tiles))]
            try:
                filled_image[
                    y - (y % tile_height) : y - (y % tile_height) + tile_height,
                    x - (x % tile_width) : x - (x % tile_width) + tile_width,
                ] = tile
            except ValueError:
                # Need to handle edge cases - literally
                pass

    filled_image = Image.fromarray(filled_image)  # Convert the filled tiles image to PIL
    image = Image.composite(
        image, filled_image, image.split()[-1]
    )  # Composite the original image on top of the filled tiles
    return image
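For reference, a minimal usage sketch of the new mosaic infill. The file paths are placeholders, and the input must be RGBA, with transparent pixels marking the area to fill:

```python
from PIL import Image

from invokeai.backend.image_util.infill_methods.mosaic import infill_mosaic

image = Image.open("masked.png").convert("RGBA")  # placeholder input; transparent pixels get filled
filled = infill_mosaic(image, tile_shape=(32, 32))
filled.save("mosaic_filled.png")
```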
New file: invokeai/backend/image_util/infill_methods/patchmatch.py (67 lines)

"""
This module defines a singleton object, "patchmatch" that
wraps the actual patchmatch object. It respects the global
"try_patchmatch" attribute, so that patchmatch loading can
be suppressed or deferred
"""

import numpy as np
from PIL import Image

import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config


class PatchMatch:
    """
    Thin class wrapper around the patchmatch function.
    """

    patch_match = None
    tried_load: bool = False

    def __init__(self):
        super().__init__()

    @classmethod
    def _load_patch_match(cls):
        if cls.tried_load:
            return
        if get_config().patchmatch:
            from patchmatch import patch_match as pm

            if pm.patchmatch_available:
                logger.info("Patchmatch initialized")
                cls.patch_match = pm
            else:
                logger.info("Patchmatch not loaded (nonfatal)")
        else:
            logger.info("Patchmatch loading disabled")
        cls.tried_load = True

    @classmethod
    def patchmatch_available(cls) -> bool:
        cls._load_patch_match()
        if not cls.patch_match:
            return False
        return cls.patch_match.patchmatch_available

    @classmethod
    def inpaint(cls, image: Image.Image) -> Image.Image:
        if cls.patch_match is None or not cls.patchmatch_available():
            return image

        np_image = np.array(image)
        mask = 255 - np_image[:, :, 3]
        infilled = cls.patch_match.inpaint(np_image[:, :, :3], mask, patch_size=3)
        return Image.fromarray(infilled, mode="RGB")


def infill_patchmatch(image: Image.Image) -> Image.Image:
    IS_PATCHMATCH_AVAILABLE = PatchMatch.patchmatch_available()

    if not IS_PATCHMATCH_AVAILABLE:
        logger.warning("PatchMatch is not available on this system")
        return image

    return PatchMatch.inpaint(image)
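A matching usage sketch for the new PatchMatch wrapper; it degrades gracefully and returns the input unchanged when pypatchmatch is unavailable or disabled in the config (the input path is a placeholder):

```python
from PIL import Image

from invokeai.backend.image_util.infill_methods.patchmatch import infill_patchmatch

image = Image.open("masked.png").convert("RGBA")  # placeholder input
filled = infill_patchmatch(image)  # no-op (with a warning) if PatchMatch cannot be loaded
```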
Ten binary image files (roughly 2–60 KiB each) were also added in this change; the diff view shows only their sizes.
New file: invokeai/backend/image_util/infill_methods/tile.ipynb (95 lines)

This is a Jupyter notebook (kernel ".invokeai", Python 3.10.12) with two code cells.

Cell 1:

"""Smoke test for the tile infill"""

from pathlib import Path
from typing import Optional
from PIL import Image
from invokeai.backend.image_util.infill_methods.tile import infill_tile

images: list[tuple[str, Image.Image]] = []

for i in sorted(Path("./test_images/").glob("*.webp")):
    images.append((i.name, Image.open(i)))
    images.append((i.name, Image.open(i).transpose(Image.FLIP_LEFT_RIGHT)))
    images.append((i.name, Image.open(i).transpose(Image.FLIP_TOP_BOTTOM)))
    images.append((i.name, Image.open(i).resize((512, 512))))
    images.append((i.name, Image.open(i).resize((1234, 461))))

outputs: list[tuple[str, Image.Image, Image.Image, Optional[Image.Image]]] = []

for name, image in images:
    try:
        output = infill_tile(image, seed=0, tile_size=32)
        outputs.append((name, image, output.infilled, output.tile_image))
    except ValueError as e:
        print(f"Skipping image {name}: {e}")

Cell 2:

# Display the images in jupyter notebook
import matplotlib.pyplot as plt
from PIL import ImageOps

fig, axes = plt.subplots(len(outputs), 3, figsize=(10, 3 * len(outputs)))
plt.subplots_adjust(hspace=0)

for i, (name, original, infilled, tile_image) in enumerate(outputs):
    # Add a border to each image, helps to see the edges
    size = original.size
    original = ImageOps.expand(original, border=5, fill="red")
    filled = ImageOps.expand(infilled, border=5, fill="red")
    if tile_image:
        tile_image = ImageOps.expand(tile_image, border=5, fill="red")

    axes[i, 0].imshow(original)
    axes[i, 0].axis("off")
    axes[i, 0].set_title(f"Original ({name} - {size})")

    if tile_image:
        axes[i, 1].imshow(tile_image)
        axes[i, 1].axis("off")
        axes[i, 1].set_title("Tile Image")
    else:
        axes[i, 1].axis("off")
        axes[i, 1].set_title("NO TILES GENERATED (NO TRANSPARENCY)")

    axes[i, 2].imshow(filled)
    axes[i, 2].axis("off")
    axes[i, 2].set_title("Filled")
New file: invokeai/backend/image_util/infill_methods/tile.py (122 lines)

from dataclasses import dataclass
from typing import Optional

import numpy as np
from PIL import Image


def create_tile_pool(img_array: np.ndarray, tile_size: tuple[int, int]) -> list[np.ndarray]:
    """
    Create a pool of tiles from non-transparent areas of the image by systematically walking through the image.

    Args:
        img_array: numpy array of the image.
        tile_size: tuple (tile_width, tile_height) specifying the size of each tile.

    Returns:
        A list of numpy arrays, each representing a tile.
    """
    tiles: list[np.ndarray] = []
    rows, cols = img_array.shape[:2]
    tile_width, tile_height = tile_size

    for y in range(0, rows - tile_height + 1, tile_height):
        for x in range(0, cols - tile_width + 1, tile_width):
            tile = img_array[y : y + tile_height, x : x + tile_width]
            # Check if the image has an alpha channel and the tile is completely opaque
            if img_array.shape[2] == 4 and np.all(tile[:, :, 3] == 255):
                tiles.append(tile)
            elif img_array.shape[2] == 3:  # If no alpha channel, append the tile
                tiles.append(tile)

    if not tiles:
        raise ValueError(
            "Not enough opaque pixels to generate any tiles. Use a smaller tile size or a different image."
        )

    return tiles


def create_filled_image(
    img_array: np.ndarray, tile_pool: list[np.ndarray], tile_size: tuple[int, int], seed: int
) -> np.ndarray:
    """
    Create an image of the same dimensions as the original, filled entirely with tiles from the pool.

    Args:
        img_array: numpy array of the original image.
        tile_pool: A list of numpy arrays, each representing a tile.
        tile_size: tuple (tile_width, tile_height) specifying the size of each tile.

    Returns:
        A numpy array representing the filled image.
    """

    rows, cols, _ = img_array.shape
    tile_width, tile_height = tile_size

    # Prep an empty RGB image
    filled_img_array = np.zeros((rows, cols, 3), dtype=img_array.dtype)

    # Make the random tile selection reproducible
    rng = np.random.default_rng(seed)

    for y in range(0, rows, tile_height):
        for x in range(0, cols, tile_width):
            # Pick a random tile from the pool
            tile = tile_pool[rng.integers(len(tile_pool))]

            # Calculate the space available (may be less than tile size near the edges)
            space_y = min(tile_height, rows - y)
            space_x = min(tile_width, cols - x)

            # Crop the tile if necessary to fit into the available space
            cropped_tile = tile[:space_y, :space_x, :3]

            # Fill the available space with the (possibly cropped) tile
            filled_img_array[y : y + space_y, x : x + space_x, :3] = cropped_tile

    return filled_img_array


@dataclass
class InfillTileOutput:
    infilled: Image.Image
    tile_image: Optional[Image.Image] = None


def infill_tile(image_to_infill: Image.Image, seed: int, tile_size: int) -> InfillTileOutput:
    """Infills an image with random tiles from the image itself.

    If the image is not an RGBA image, it is returned untouched.

    Args:
        image: The image to infill.
        tile_size: The size of the tiles to use for infilling.

    Raises:
        ValueError: If there are not enough opaque pixels to generate any tiles.
    """

    if image_to_infill.mode != "RGBA":
        return InfillTileOutput(infilled=image_to_infill)

    # Internally, we want a tuple of (tile_width, tile_height). In the future, the tile size can be any rectangle.
    _tile_size = (tile_size, tile_size)
    np_image = np.array(image_to_infill, dtype=np.uint8)

    # Create the pool of tiles that we will use to infill
    tile_pool = create_tile_pool(np_image, _tile_size)

    # Create an image from the tiles, same size as the original
    tile_np_image = create_filled_image(np_image, tile_pool, _tile_size, seed)

    # Paste the OG image over the tile image, effectively infilling the area
    tile_image = Image.fromarray(tile_np_image, "RGB")
    infilled = tile_image.copy()
    infilled.paste(image_to_infill, (0, 0), image_to_infill.split()[-1])

    # I think we want this to be "RGBA"?
    infilled.convert("RGBA")

    return InfillTileOutput(infilled=infilled, tile_image=tile_image)
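The smoke-test notebook above exercises this entry point; the short version (input path is a placeholder):

```python
from PIL import Image

from invokeai.backend.image_util.infill_methods.tile import infill_tile

image = Image.open("masked.png").convert("RGBA")  # placeholder input
result = infill_tile(image, seed=0, tile_size=32)
result.infilled.save("tile_filled.png")
# result.tile_image holds the intermediate tile mosaic; it is None for non-RGBA inputs.
```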
@@ -1,49 +0,0 @@ (deleted file, 49 lines; superseded by the infill_methods version above)

"""
This module defines a singleton object, "patchmatch" that
wraps the actual patchmatch object. It respects the global
"try_patchmatch" attribute, so that patchmatch loading can
be suppressed or deferred
"""

import numpy as np

import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config


class PatchMatch:
    """
    Thin class wrapper around the patchmatch function.
    """

    patch_match = None
    tried_load: bool = False

    def __init__(self):
        super().__init__()

    @classmethod
    def _load_patch_match(self):
        if self.tried_load:
            return
        if get_config().patchmatch:
            from patchmatch import patch_match as pm

            if pm.patchmatch_available:
                logger.info("Patchmatch initialized")
            else:
                logger.info("Patchmatch not loaded (nonfatal)")
            self.patch_match = pm
        else:
            logger.info("Patchmatch loading disabled")
        self.tried_load = True

    @classmethod
    def patchmatch_available(self) -> bool:
        self._load_patch_match()
        return self.patch_match and self.patch_match.patchmatch_available

    @classmethod
    def inpaint(self, *args, **kwargs) -> np.ndarray:
        if self.patchmatch_available():
            return self.patch_match.inpaint(*args, **kwargs)
@@ -1,8 +1,11 @@
 # copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
 # and modified as needed

-from typing import Optional, Union
+import pathlib
+from typing import List, Optional, TypedDict, Union

+import safetensors
+import safetensors.torch
 import torch
 from PIL import Image
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
@@ -13,10 +16,17 @@ from ..raw_model import RawModel
 from .resampler import Resampler


+class IPAdapterStateDict(TypedDict):
+    ip_adapter: dict[str, torch.Tensor]
+    image_proj: dict[str, torch.Tensor]
+
+
 class ImageProjModel(torch.nn.Module):
     """Image Projection Model"""

-    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
+    def __init__(
+        self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024, clip_extra_context_tokens: int = 4
+    ):
         super().__init__()

         self.cross_attention_dim = cross_attention_dim
@@ -25,7 +35,7 @@ class ImageProjModel(torch.nn.Module):
         self.norm = torch.nn.LayerNorm(cross_attention_dim)

     @classmethod
-    def from_state_dict(cls, state_dict: dict[torch.Tensor], clip_extra_context_tokens=4):
+    def from_state_dict(cls, state_dict: dict[str, torch.Tensor], clip_extra_context_tokens: int = 4):
         """Initialize an ImageProjModel from a state_dict.

         The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
@@ -45,7 +55,7 @@ class ImageProjModel(torch.nn.Module):
         model.load_state_dict(state_dict)
         return model

-    def forward(self, image_embeds):
+    def forward(self, image_embeds: torch.Tensor):
         embeds = image_embeds
         clip_extra_context_tokens = self.proj(embeds).reshape(
             -1, self.clip_extra_context_tokens, self.cross_attention_dim
@@ -57,7 +67,7 @@ class ImageProjModel(torch.nn.Module):
 class MLPProjModel(torch.nn.Module):
     """SD model with image prompt"""

-    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
+    def __init__(self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024):
         super().__init__()

         self.proj = torch.nn.Sequential(
@@ -68,7 +78,7 @@ class MLPProjModel(torch.nn.Module):
         )

     @classmethod
-    def from_state_dict(cls, state_dict: dict[torch.Tensor]):
+    def from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
         """Initialize an MLPProjModel from a state_dict.

         The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
@@ -87,7 +97,7 @@ class MLPProjModel(torch.nn.Module):
         model.load_state_dict(state_dict)
         return model

-    def forward(self, image_embeds):
+    def forward(self, image_embeds: torch.Tensor):
         clip_extra_context_tokens = self.proj(image_embeds)
         return clip_extra_context_tokens

@@ -97,7 +107,7 @@ class IPAdapter(RawModel):

     def __init__(
         self,
-        state_dict: dict[str, torch.Tensor],
+        state_dict: IPAdapterStateDict,
         device: torch.device,
         dtype: torch.dtype = torch.float16,
         num_tokens: int = 4,
@@ -129,24 +139,27 @@ class IPAdapter(RawModel):

         return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)

-    def _init_image_proj_model(self, state_dict):
+    def _init_image_proj_model(
+        self, state_dict: dict[str, torch.Tensor]
+    ) -> Union[ImageProjModel, Resampler, MLPProjModel]:
         return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)

     @torch.inference_mode()
-    def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
-        if isinstance(pil_image, Image.Image):
-            pil_image = [pil_image]
+    def get_image_embeds(self, pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection):
         clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
         clip_image_embeds = image_encoder(clip_image.to(self.device, dtype=self.dtype)).image_embeds
-        image_prompt_embeds = self._image_proj_model(clip_image_embeds)
-        uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
-        return image_prompt_embeds, uncond_image_prompt_embeds
+        try:
+            image_prompt_embeds = self._image_proj_model(clip_image_embeds)
+            uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
+            return image_prompt_embeds, uncond_image_prompt_embeds
+        except RuntimeError as e:
+            raise RuntimeError("Selected CLIP Vision Model is incompatible with the current IP Adapter") from e


 class IPAdapterPlus(IPAdapter):
     """IP-Adapter with fine-grained features"""

-    def _init_image_proj_model(self, state_dict):
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]) -> Union[Resampler, MLPProjModel]:
         return Resampler.from_state_dict(
             state_dict=state_dict,
             depth=4,
@@ -157,31 +170,32 @@ class IPAdapterPlus(IPAdapter):
         ).to(self.device, dtype=self.dtype)

     @torch.inference_mode()
-    def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
-        if isinstance(pil_image, Image.Image):
-            pil_image = [pil_image]
+    def get_image_embeds(self, pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection):
         clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
         clip_image = clip_image.to(self.device, dtype=self.dtype)
         clip_image_embeds = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
-        image_prompt_embeds = self._image_proj_model(clip_image_embeds)
         uncond_clip_image_embeds = image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[
             -2
         ]
-        uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds)
-        return image_prompt_embeds, uncond_image_prompt_embeds
+        try:
+            image_prompt_embeds = self._image_proj_model(clip_image_embeds)
+            uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds)
+            return image_prompt_embeds, uncond_image_prompt_embeds
+        except RuntimeError as e:
+            raise RuntimeError("Selected CLIP Vision Model is incompatible with the current IP Adapter") from e


 class IPAdapterFull(IPAdapterPlus):
     """IP-Adapter Plus with full features."""

-    def _init_image_proj_model(self, state_dict: dict[torch.Tensor]):
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
         return MLPProjModel.from_state_dict(state_dict).to(self.device, dtype=self.dtype)


 class IPAdapterPlusXL(IPAdapterPlus):
     """IP-Adapter Plus for SDXL."""

-    def _init_image_proj_model(self, state_dict):
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
         return Resampler.from_state_dict(
             state_dict=state_dict,
             depth=4,
@@ -192,24 +206,48 @@ class IPAdapterPlusXL(IPAdapterPlus):
         ).to(self.device, dtype=self.dtype)


-def build_ip_adapter(
-    ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
-) -> Union[IPAdapter, IPAdapterPlus]:
-    state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
+def load_ip_adapter_tensors(ip_adapter_ckpt_path: pathlib.Path, device: str) -> IPAdapterStateDict:
+    state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}

-    if "proj.weight" in state_dict["image_proj"]:  # IPAdapter (with ImageProjModel).
+    if ip_adapter_ckpt_path.suffix == ".safetensors":
+        model = safetensors.torch.load_file(ip_adapter_ckpt_path, device=device)
+        for key in model.keys():
+            if key.startswith("image_proj."):
+                state_dict["image_proj"][key.replace("image_proj.", "")] = model[key]
+            elif key.startswith("ip_adapter."):
+                state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model[key]
+            else:
+                raise RuntimeError(f"Encountered unexpected IP Adapter state dict key: '{key}'.")
+    else:
+        ip_adapter_diffusers_checkpoint_path = ip_adapter_ckpt_path / "ip_adapter.bin"
+        state_dict = torch.load(ip_adapter_diffusers_checkpoint_path, map_location="cpu")
+
+    return state_dict
+
+
+def build_ip_adapter(
+    ip_adapter_ckpt_path: pathlib.Path, device: torch.device, dtype: torch.dtype = torch.float16
+) -> Union[IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterPlus]:
+    state_dict = load_ip_adapter_tensors(ip_adapter_ckpt_path, device.type)
+
+    # IPAdapter (with ImageProjModel)
+    if "proj.weight" in state_dict["image_proj"]:
         return IPAdapter(state_dict, device=device, dtype=dtype)
-    elif "proj_in.weight" in state_dict["image_proj"]:  # IPAdaterPlus or IPAdapterPlusXL (with Resampler).
+
+    # IPAdaterPlus or IPAdapterPlusXL (with Resampler)
+    elif "proj_in.weight" in state_dict["image_proj"]:
         cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
         if cross_attention_dim == 768:
-            # SD1 IP-Adapter Plus
-            return IPAdapterPlus(state_dict, device=device, dtype=dtype)
+            return IPAdapterPlus(state_dict, device=device, dtype=dtype)  # SD1 IP-Adapter Plus
         elif cross_attention_dim == 2048:
-            # SDXL IP-Adapter Plus
-            return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)
+            return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)  # SDXL IP-Adapter Plus
         else:
             raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.")
-    elif "proj.0.weight" in state_dict["image_proj"]:  # IPAdapterFull (with MLPProjModel).
+
+    # IPAdapterFull (with MLPProjModel)
+    elif "proj.0.weight" in state_dict["image_proj"]:
         return IPAdapterFull(state_dict, device=device, dtype=dtype)
+
+    # Unrecognized IP Adapter Architectures
     else:
         raise ValueError(f"'{ip_adapter_ckpt_path}' has an unrecognized IP-Adapter model architecture.")
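With this change both checkpoint formats go through the same entry point. A hedged usage sketch (the checkpoint path is a placeholder):

```python
from pathlib import Path

import torch

from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter

# A .safetensors file is split into {"image_proj": ..., "ip_adapter": ...} before dispatch;
# a diffusers-style folder is still expected to contain an ip_adapter.bin.
ip_adapter = build_ip_adapter(
    ip_adapter_ckpt_path=Path("/models/ip-adapter-plus_sd15.safetensors"),  # placeholder path
    device=torch.device("cpu"),
    dtype=torch.float16,
)
print(type(ip_adapter).__name__)  # e.g. IPAdapterPlus for an SD1 "plus" adapter
```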
@@ -9,8 +9,8 @@ import torch.nn as nn


 # FFN
-def FeedForward(dim, mult=4):
-    inner_dim = int(dim * mult)
+def FeedForward(dim: int, mult: int = 4):
+    inner_dim = dim * mult
     return nn.Sequential(
         nn.LayerNorm(dim),
         nn.Linear(dim, inner_dim, bias=False),
@@ -19,8 +19,8 @@ def FeedForward(dim, mult=4):
     )


-def reshape_tensor(x, heads):
-    bs, length, width = x.shape
+def reshape_tensor(x: torch.Tensor, heads: int):
+    bs, length, _ = x.shape
     # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
     x = x.view(bs, length, heads, -1)
     # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
@@ -31,7 +31,7 @@ def reshape_tensor(x, heads):


 class PerceiverAttention(nn.Module):
-    def __init__(self, *, dim, dim_head=64, heads=8):
+    def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8):
         super().__init__()
         self.scale = dim_head**-0.5
         self.dim_head = dim_head
@@ -45,7 +45,7 @@ class PerceiverAttention(nn.Module):
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
         self.to_out = nn.Linear(inner_dim, dim, bias=False)

-    def forward(self, x, latents):
+    def forward(self, x: torch.Tensor, latents: torch.Tensor):
         """
         Args:
             x (torch.Tensor): image features
@@ -80,14 +80,14 @@ class PerceiverAttention(nn.Module):
 class Resampler(nn.Module):
     def __init__(
         self,
-        dim=1024,
-        depth=8,
-        dim_head=64,
-        heads=16,
-        num_queries=8,
-        embedding_dim=768,
-        output_dim=1024,
-        ff_mult=4,
+        dim: int = 1024,
+        depth: int = 8,
+        dim_head: int = 64,
+        heads: int = 16,
+        num_queries: int = 8,
+        embedding_dim: int = 768,
+        output_dim: int = 1024,
+        ff_mult: int = 4,
     ):
         super().__init__()

@@ -110,7 +110,15 @@ class Resampler(nn.Module):
         )

     @classmethod
-    def from_state_dict(cls, state_dict: dict[torch.Tensor], depth=8, dim_head=64, heads=16, num_queries=8, ff_mult=4):
+    def from_state_dict(
+        cls,
+        state_dict: dict[str, torch.Tensor],
+        depth: int = 8,
+        dim_head: int = 64,
+        heads: int = 16,
+        num_queries: int = 8,
+        ff_mult: int = 4,
+    ):
         """A convenience function that initializes a Resampler from a state_dict.

         Some of the shape parameters are inferred from the state_dict (e.g. dim, embedding_dim, etc.). At the time of
@@ -145,7 +153,7 @@ class Resampler(nn.Module):
         model.load_state_dict(state_dict)
         return model

-    def forward(self, x):
+    def forward(self, x: torch.Tensor):
         latents = self.latents.repeat(x.size(0), 1, 1)

         x = self.proj_in(x)
@@ -323,10 +323,13 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
         return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")


-class IPAdapterConfig(ModelConfigBase):
-    """Model config for IP Adaptor format models."""
-
+class IPAdapterBaseConfig(ModelConfigBase):
     type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
+
+
+class IPAdapterInvokeAIConfig(IPAdapterBaseConfig):
+    """Model config for IP Adapter diffusers format models."""
+
     image_encoder_model_id: str
     format: Literal[ModelFormat.InvokeAI]

@@ -335,6 +338,16 @@ class IPAdapterConfig(ModelConfigBase):
         return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")


+class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
+    """Model config for IP Adapter checkpoint format models."""
+
+    format: Literal[ModelFormat.Checkpoint]
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.Checkpoint.value}")
+
+
 class CLIPVisionDiffusersConfig(DiffusersConfigBase):
     """Model config for CLIPVision."""

@@ -390,7 +403,8 @@ AnyModelConfig = Annotated[
         Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
         Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
         Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
-        Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
+        Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
+        Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
         Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
         Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
     ],
@@ -37,7 +37,7 @@ class ModelLoader(ModelLoaderBase):
         self._logger = logger
         self._ram_cache = ram_cache
         self._convert_cache = convert_cache
-        self._torch_dtype = torch_dtype(choose_torch_device(), app_config)
+        self._torch_dtype = torch_dtype(choose_torch_device())

     def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
         """
@@ -117,7 +117,7 @@ class ModelCacheBase(ABC, Generic[T]):

     @property
     @abstractmethod
-    def stats(self) -> CacheStats:
+    def stats(self) -> Optional[CacheStats]:
         """Return collected CacheStats object."""
         pass

@@ -269,9 +269,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
         if torch.device(source_device).type == torch.device(target_device).type:
             return

-        # may raise an exception here if insufficient GPU VRAM
-        self._check_free_vram(target_device, cache_entry.size)
-
         start_model_to_time = time.time()
         snapshot_before = self._capture_memory_snapshot()
         cache_entry.model.to(target_device)
@@ -329,11 +326,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
             f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
         )

-    def make_room(self, model_size: int) -> None:
+    def make_room(self, size: int) -> None:
         """Make enough room in the cache to accommodate a new model of indicated size."""
         # calculate how much memory this model will require
         # multiplier = 2 if self.precision==torch.float32 else 1
-        bytes_needed = model_size
+        bytes_needed = size
         maximum_size = self.max_cache_size * GIG  # stored in GB, convert to bytes
         current_size = self.cache_size()

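The renamed `size` parameter feeds the same byte arithmetic as before. A small sketch of the budget check, assuming `GIG` is the cache's bytes-per-GB constant (the helper name is illustrative):

```python
GIG = 2**30  # assumption: a binary gigabyte, matching the cache's GIG constant


def bytes_to_free(size: int, max_cache_size_gb: float, current_size: int) -> int:
    # make_room() must evict at least this many bytes so the incoming model
    # fits under the configured RAM cache ceiling.
    maximum_size = int(max_cache_size_gb * GIG)
    return max(0, current_size + size - maximum_size)
```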
@@ -388,7 +385,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
             # 1 from onnx runtime object
             if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
                 self.logger.debug(
-                    f"Removing {model_key} from RAM cache to free at least {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
+                    f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
                 )
                 current_size -= cache_entry.size
                 models_cleared += 1
@@ -420,13 +417,3 @@ class ModelCache(ModelCacheBase[AnyModel]):
                 mps.empty_cache()

         self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
-
-    def _check_free_vram(self, target_device: torch.device, needed_size: int) -> None:
-        if target_device.type != "cuda":
-            return
-        vram_device = (  # mem_get_info() needs an indexed device
-            target_device if target_device.index is not None else torch.device(str(target_device), index=0)
-        )
-        free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device))
-        if needed_size > free_mem:
-            raise torch.cuda.OutOfMemoryError
@@ -34,7 +34,6 @@ class ModelLocker(ModelLockerBase):

         # NOTE that the model has to have the to() method in order for this code to move it into GPU!
         self._cache_entry.lock()
-
         try:
             if self._cache.lazy_offloading:
                 self._cache.offload_unlocked_models(self._cache_entry.size)
@@ -51,6 +50,7 @@ class ModelLocker(ModelLockerBase):
         except Exception:
             self._cache_entry.unlock()
             raise
+
         return self.model

     def unlock(self) -> None:
@@ -7,19 +7,13 @@ from typing import Optional
 import torch

 from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
-from invokeai.backend.model_manager import (
-    AnyModel,
-    AnyModelConfig,
-    BaseModelType,
-    ModelFormat,
-    ModelType,
-    SubModelType,
-)
+from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
 from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
 from invokeai.backend.raw_model import RawModel


 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
+@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.Checkpoint)
 class IPAdapterInvokeAILoader(ModelLoader):
     """Class to load IP Adapter diffusers models."""

@@ -32,7 +26,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
             raise ValueError("There are no submodels in an IP-Adapter model.")
         model_path = Path(config.path)
         model: RawModel = build_ip_adapter(
-            ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
+            ip_adapter_ckpt_path=model_path,
             device=torch.device("cpu"),
             dtype=self._torch_dtype,
         )
@@ -230,9 +230,10 @@ class ModelProbe(object):
             return ModelType.LoRA
         elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}):
             return ModelType.ControlNet
+        elif any(key.startswith(v) for v in {"image_proj.", "ip_adapter."}):
+            return ModelType.IPAdapter
         elif key in {"emb_params", "string_to_param"}:
             return ModelType.TextualInversion

         else:
             # diffusers-ti
             if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
@@ -323,7 +324,7 @@ class ModelProbe(object):
         with SilenceWarnings():
             if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
                 cls._scan_model(model_path.name, model_path)
-                model = torch.load(model_path)
+                model = torch.load(model_path, map_location="cpu")
                 assert isinstance(model, dict)
                 return model
             else:
@@ -527,8 +528,25 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):


 class IPAdapterCheckpointProbe(CheckpointProbeBase):
+    """Class for probing IP Adapters"""
+
     def get_base_type(self) -> BaseModelType:
-        raise NotImplementedError()
+        checkpoint = self.checkpoint
+        for key in checkpoint.keys():
+            if not key.startswith(("image_proj.", "ip_adapter.")):
+                continue
+            cross_attention_dim = checkpoint["ip_adapter.1.to_k_ip.weight"].shape[-1]
+            if cross_attention_dim == 768:
+                return BaseModelType.StableDiffusion1
+            elif cross_attention_dim == 1024:
+                return BaseModelType.StableDiffusion2
+            elif cross_attention_dim == 2048:
+                return BaseModelType.StableDiffusionXL
+            else:
+                raise InvalidModelConfigException(
+                    f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}."
+                )
+        raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type")


 class CLIPVisionCheckpointProbe(CheckpointProbeBase):
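The probe above keys everything off the first cross-attention projection in the state dict. An illustrative mapping (the label strings are only for display, not identifiers from the codebase):

```python
import torch


def guess_ip_adapter_base(checkpoint: dict[str, torch.Tensor]) -> str:
    # 768 -> SD1, 1024 -> SD2, 2048 -> SDXL; anything else is rejected by the probe.
    dim = checkpoint["ip_adapter.1.to_k_ip.weight"].shape[-1]
    return {768: "sd-1", 1024: "sd-2", 2048: "sdxl"}.get(dim, "unknown")
```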
@@ -768,7 +786,7 @@ class T2IAdapterFolderProbe(FolderProbeBase):
         )


-############## register probe classes ######
+# Register probe classes
 ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
@@ -6,8 +6,7 @@ from typing import Literal, Optional, Union
 import torch
 from torch import autocast

-from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.app.services.config.config_default import get_config
+from invokeai.app.services.config.config_default import PRECISION, get_config

 CPU_DEVICE = torch.device("cpu")
 CUDA_DEVICE = torch.device("cuda")
@@ -33,35 +32,34 @@ def get_torch_device_name() -> str:
     return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()


-# We are in transition here from using a single global AppConfig to allowing multiple
-# configurations. It is strongly recommended to pass the app_config to this function.
-def choose_precision(
-    device: torch.device, app_config: Optional[InvokeAIAppConfig] = None
-) -> Literal["float32", "float16", "bfloat16"]:
+def choose_precision(device: torch.device) -> Literal["float32", "float16", "bfloat16"]:
     """Return an appropriate precision for the given torch device."""
-    app_config = app_config or get_config()
+    app_config = get_config()
     if device.type == "cuda":
         device_name = torch.cuda.get_device_name(device)
-        if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
-            if app_config.precision == "float32":
-                return "float32"
-            elif app_config.precision == "bfloat16":
-                return "bfloat16"
-            else:
-                return "float16"
+        if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
+            # These GPUs have limited support for float16
+            return "float32"
+        elif app_config.precision == "auto" or app_config.precision == "autocast":
+            # Default to float16 for CUDA devices
+            return "float16"
+        else:
+            # Use the user-defined precision
+            return app_config.precision
     elif device.type == "mps":
-        return "float16"
+        if app_config.precision == "auto" or app_config.precision == "autocast":
+            # Default to float16 for MPS devices
+            return "float16"
+        else:
+            # Use the user-defined precision
+            return app_config.precision
+    # CPU / safe fallback
     return "float32"


-# We are in transition here from using a single global AppConfig to allowing multiple
-# configurations. It is strongly recommended to pass the app_config to this function.
-def torch_dtype(
-    device: Optional[torch.device] = None,
-    app_config: Optional[InvokeAIAppConfig] = None,
-) -> torch.dtype:
+def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
     device = device or choose_torch_device()
-    precision = choose_precision(device, app_config)
+    precision = choose_precision(device)
     if precision == "float16":
         return torch.float16
     if precision == "bfloat16":
@@ -71,7 +69,7 @@ def torch_dtype(
     return torch.float32


-def choose_autocast(precision):
+def choose_autocast(precision: PRECISION):
     """Returns an autocast context or nullcontext for the given precision string"""
     # float16 currently requires autocast to avoid errors like:
     # 'expected scalar type Half but found Float'
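In practice the new precision logic resolves as sketched below; this usage example assumes a working Invoke install and config:

```python
from invokeai.backend.util.devices import choose_torch_device, torch_dtype

# With precision: auto, CUDA and MPS resolve to float16 and CPU stays float32.
# An explicit float32 or bfloat16 setting is now honored as-is, except on
# GTX 1650/1660 cards, which are pinned to float32.
device = choose_torch_device()
dtype = torch_dtype(device)
print(device, dtype)  # e.g. "cuda torch.float16"
```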
@@ -291,7 +291,6 @@
     "canvasMerged": "تم دمج الخط",
     "sentToImageToImage": "تم إرسال إلى صورة إلى صورة",
     "sentToUnifiedCanvas": "تم إرسال إلى لوحة موحدة",
-    "parametersSet": "تم تعيين المعلمات",
     "parametersNotSet": "لم يتم تعيين المعلمات",
     "metadataLoadFailed": "فشل تحميل البيانات الوصفية"
   },

@@ -75,7 +75,8 @@
     "copy": "Kopieren",
     "aboutHeading": "Nutzen Sie Ihre kreative Energie",
     "toResolve": "Lösen",
-    "add": "Hinzufügen"
+    "add": "Hinzufügen",
+    "loglevel": "Protokoll Stufe"
   },
   "gallery": {
     "galleryImageSize": "Bildgröße",

@@ -388,7 +389,14 @@
     "vaePrecision": "VAE-Präzision",
     "variant": "Variante",
     "modelDeleteFailed": "Modell konnte nicht gelöscht werden",
-    "noModelSelected": "Kein Modell ausgewählt"
+    "noModelSelected": "Kein Modell ausgewählt",
+    "huggingFace": "HuggingFace",
+    "defaultSettings": "Standardeinstellungen",
+    "edit": "Bearbeiten",
+    "cancel": "Stornieren",
+    "defaultSettingsSaved": "Standardeinstellungen gespeichert",
+    "addModels": "Model hinzufügen",
+    "deleteModelImage": "Lösche Model Bild"
   },
   "parameters": {
     "images": "Bilder",

@@ -472,7 +480,6 @@
     "canvasMerged": "Leinwand zusammengeführt",
     "sentToImageToImage": "Gesendet an Bild zu Bild",
     "sentToUnifiedCanvas": "Gesendet an Leinwand",
-    "parametersSet": "Parameter festlegen",
     "parametersNotSet": "Parameter nicht festgelegt",
     "metadataLoadFailed": "Metadaten konnten nicht geladen werden",
     "setCanvasInitialImage": "Ausgangsbild setzen",

@@ -677,7 +684,8 @@
     "body": "Körper",
     "hands": "Hände",
     "dwOpenpose": "DW Openpose",
-    "dwOpenposeDescription": "Posenschätzung mit DW Openpose"
+    "dwOpenposeDescription": "Posenschätzung mit DW Openpose",
+    "selectCLIPVisionModel": "Wähle ein CLIP Vision Model aus"
   },
   "queue": {
     "status": "Status",

@@ -765,7 +773,10 @@
     "recallParameters": "Parameter wiederherstellen",
     "cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
     "allPrompts": "Alle Prompts",
-    "imageDimensions": "Bilder Auslösungen"
+    "imageDimensions": "Bilder Auslösungen",
+    "parameterSet": "Parameter {{parameter}} setzen",
+    "recallParameter": "{{label}} Abrufen",
+    "parsingFailed": "Parsing Fehlgeschlagen"
   },
   "popovers": {
     "noiseUseCPU": {

@@ -1030,7 +1041,8 @@
       "title": "Bild"
     },
     "advanced": {
-      "title": "Erweitert"
+      "title": "Erweitert",
+      "options": "$t(accordions.advanced.title) Optionen"
     },
     "control": {
       "title": "Kontrolle"

@@ -217,6 +217,7 @@
     "saveControlImage": "Save Control Image",
     "scribble": "scribble",
     "selectModel": "Select a model",
+    "selectCLIPVisionModel": "Select a CLIP Vision model",
     "setControlImageDimensions": "Set Control Image Dimensions To W/H",
     "showAdvanced": "Show Advanced",
     "small": "Small",

@@ -655,6 +656,7 @@
     "install": "Install",
     "installAll": "Install All",
     "installRepo": "Install Repo",
+    "ipAdapters": "IP Adapters",
     "load": "Load",
     "localOnly": "local only",
     "manual": "Manual",

@@ -682,6 +684,7 @@
     "noModelsInstalled": "No Models Installed",
     "noModelsInstalledDesc1": "Install models with the",
     "noModelSelected": "No Model Selected",
+    "noMatchingModels": "No matching Models",
    "none": "none",
     "path": "Path",
     "pathToConfig": "Path To Config",

@@ -885,6 +888,11 @@
     "imageFit": "Fit Initial Image To Output Size",
     "images": "Images",
     "infillMethod": "Infill Method",
+    "infillMosaicTileWidth": "Tile Width",
+    "infillMosaicTileHeight": "Tile Height",
+    "infillMosaicMinColor": "Min Color",
+    "infillMosaicMaxColor": "Max Color",
+    "infillColorValue": "Fill Color",
     "info": "Info",
     "invoke": {
       "addingImagesTo": "Adding images to",

@@ -1033,10 +1041,10 @@
     "metadataLoadFailed": "Failed to load metadata",
     "modelAddedSimple": "Model Added to Queue",
     "modelImportCanceled": "Model Import Canceled",
+    "parameters": "Parameters",
     "parameterNotSet": "{{parameter}} not set",
     "parameterSet": "{{parameter}} set",
     "parametersNotSet": "Parameters Not Set",
-    "parametersSet": "Parameters Set",
     "problemCopyingCanvas": "Problem Copying Canvas",
     "problemCopyingCanvasDesc": "Unable to export base layer",
     "problemCopyingImage": "Unable to Copy Image",

@@ -1415,6 +1423,7 @@
     "eraseBoundingBox": "Erase Bounding Box",
     "eraser": "Eraser",
     "fillBoundingBox": "Fill Bounding Box",
+    "initialFitImageSize": "Fit Image Size on Drop",
     "invertBrushSizeScrollDirection": "Invert Scroll for Brush Size",
     "layer": "Layer",
     "limitStrokesToBox": "Limit Strokes to Box",

@@ -363,7 +363,6 @@
     "canvasMerged": "Lienzo consolidado",
     "sentToImageToImage": "Enviar hacia Imagen a Imagen",
     "sentToUnifiedCanvas": "Enviar hacia Lienzo Consolidado",
-    "parametersSet": "Parámetros establecidos",
     "parametersNotSet": "Parámetros no establecidos",
     "metadataLoadFailed": "Error al cargar metadatos",
     "serverError": "Error en el servidor",

@@ -298,7 +298,6 @@
     "canvasMerged": "Canvas fusionné",
     "sentToImageToImage": "Envoyé à Image à Image",
     "sentToUnifiedCanvas": "Envoyé à Canvas unifié",
-    "parametersSet": "Paramètres définis",
     "parametersNotSet": "Paramètres non définis",
     "metadataLoadFailed": "Échec du chargement des métadonnées"
   },

@@ -306,7 +306,6 @@
     "canvasMerged": "קנבס מוזג",
     "sentToImageToImage": "נשלח לתמונה לתמונה",
     "sentToUnifiedCanvas": "נשלח אל קנבס מאוחד",
-    "parametersSet": "הגדרת פרמטרים",
     "parametersNotSet": "פרמטרים לא הוגדרו",
     "metadataLoadFailed": "טעינת מטא-נתונים נכשלה"
   },

@@ -366,7 +366,7 @@
     "modelConverted": "Modello convertito",
     "alpha": "Alpha",
     "convertToDiffusersHelpText1": "Questo modello verrà convertito nel formato 🧨 Diffusori.",
-    "convertToDiffusersHelpText3": "Il file Checkpoint su disco verrà eliminato se si trova nella cartella principale di InvokeAI. Se si trova invece in una posizione personalizzata, NON verrà eliminato.",
+    "convertToDiffusersHelpText3": "Il file del modello su disco verrà eliminato se si trova nella cartella principale di InvokeAI. Se si trova invece in una posizione personalizzata, NON verrà eliminato.",
     "v2_base": "v2 (512px)",
     "v2_768": "v2 (768px)",
     "none": "nessuno",

@@ -443,7 +443,8 @@
     "noModelsInstalled": "Nessun modello installato",
     "hfTokenInvalidErrorMessage2": "Aggiornalo in ",
     "main": "Principali",
-    "noModelsInstalledDesc1": "Installa i modelli con"
+    "noModelsInstalledDesc1": "Installa i modelli con",
+    "ipAdapters": "Adattatori IP"
   },
   "parameters": {
     "images": "Immagini",

@@ -568,7 +569,6 @@
     "canvasMerged": "Tela unita",
     "sentToImageToImage": "Inviato a Immagine a Immagine",
     "sentToUnifiedCanvas": "Inviato a Tela Unificata",
-    "parametersSet": "Parametri impostati",
     "parametersNotSet": "Parametri non impostati",
     "metadataLoadFailed": "Impossibile caricare i metadati",
     "serverError": "Errore del Server",

@@ -937,7 +937,8 @@
     "controlnet": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.controlNet))",
     "mediapipeFace": "Mediapipe Volto",
     "ip_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.ipAdapter))",
-    "t2i_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.t2iAdapter))"
+    "t2i_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.t2iAdapter))",
+    "selectCLIPVisionModel": "Seleziona un modello CLIP Vision"
   },
   "queue": {
     "queueFront": "Aggiungi all'inizio della coda",

@@ -420,7 +420,6 @@
     "canvasMerged": "Canvas samengevoegd",
     "sentToImageToImage": "Gestuurd naar Afbeelding naar afbeelding",
     "sentToUnifiedCanvas": "Gestuurd naar Centraal canvas",
-    "parametersSet": "Parameters ingesteld",
     "parametersNotSet": "Parameters niet ingesteld",
     "metadataLoadFailed": "Fout bij laden metagegevens",
     "serverError": "Serverfout",

@@ -267,7 +267,6 @@
     "canvasMerged": "Scalono widoczne warstwy",
     "sentToImageToImage": "Wysłano do Obraz na obraz",
     "sentToUnifiedCanvas": "Wysłano do trybu uniwersalnego",
-    "parametersSet": "Ustawiono parametry",
     "parametersNotSet": "Nie ustawiono parametrów",
     "metadataLoadFailed": "Błąd wczytywania metadanych"
   },

@@ -310,7 +310,6 @@
     "canvasMerged": "Tela Fundida",
     "sentToImageToImage": "Mandar Para Imagem Para Imagem",
     "sentToUnifiedCanvas": "Enviada para a Tela Unificada",
-    "parametersSet": "Parâmetros Definidos",
     "parametersNotSet": "Parâmetros Não Definidos",
     "metadataLoadFailed": "Falha ao tentar carregar metadados"
   },

@@ -307,7 +307,6 @@
     "canvasMerged": "Tela Fundida",
     "sentToImageToImage": "Mandar Para Imagem Para Imagem",
     "sentToUnifiedCanvas": "Enviada para a Tela Unificada",
-    "parametersSet": "Parâmetros Definidos",
     "parametersNotSet": "Parâmetros Não Definidos",
     "metadataLoadFailed": "Falha ao tentar carregar metadados"
   },

@@ -575,7 +575,6 @@
     "canvasMerged": "Холст объединен",
     "sentToImageToImage": "Отправить в img2img",
     "sentToUnifiedCanvas": "Отправлено на Единый холст",
-    "parametersSet": "Параметры заданы",
     "parametersNotSet": "Параметры не заданы",
     "metadataLoadFailed": "Не удалось загрузить метаданные",
     "serverError": "Ошибка сервера",

@@ -315,7 +315,6 @@
     "canvasMerged": "Полотно об'єднане",
     "sentToImageToImage": "Надіслати до img2img",
     "sentToUnifiedCanvas": "Надіслати на полотно",
-    "parametersSet": "Параметри задані",
     "parametersNotSet": "Параметри не задані",
     "metadataLoadFailed": "Не вдалося завантажити метадані",
     "serverError": "Помилка сервера",

@@ -487,7 +487,6 @@
     "canvasMerged": "画布已合并",
     "sentToImageToImage": "已发送到图生图",
     "sentToUnifiedCanvas": "已发送到统一画布",
-    "parametersSet": "参数已设定",
     "parametersNotSet": "参数未设定",
     "metadataLoadFailed": "加载元数据失败",
     "uploadFailedInvalidUploadDesc": "必须是单张的 PNG 或 JPEG 图片",
@@ -43,6 +43,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
         })
       );
       dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }]));
+      dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }]));
     },
   });

@@ -18,6 +18,7 @@ import {
   setShouldAutoSave,
   setShouldCropToBoundingBoxOnSave,
   setShouldDarkenOutsideBoundingBox,
+  setShouldFitImageSize,
   setShouldInvertBrushSizeScrollDirection,
   setShouldRestrictStrokesToBox,
   setShouldShowCanvasDebugInfo,

@@ -48,6 +49,7 @@ const IAICanvasSettingsButtonPopover = () => {
   const shouldSnapToGrid = useAppSelector((s) => s.canvas.shouldSnapToGrid);
   const shouldRestrictStrokesToBox = useAppSelector((s) => s.canvas.shouldRestrictStrokesToBox);
   const shouldAntialias = useAppSelector((s) => s.canvas.shouldAntialias);
+  const shouldFitImageSize = useAppSelector((s) => s.canvas.shouldFitImageSize);

   useHotkeys(
     ['n'],

@@ -102,6 +104,10 @@ const IAICanvasSettingsButtonPopover = () => {
     (e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldAntialias(e.target.checked)),
     [dispatch]
   );
+  const handleChangeShouldFitImageSize = useCallback(
+    (e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldFitImageSize(e.target.checked)),
+    [dispatch]
+  );

   return (
     <Popover>

@@ -165,6 +171,10 @@ const IAICanvasSettingsButtonPopover = () => {
             <FormLabel>{t('unifiedCanvas.antialiasing')}</FormLabel>
             <Checkbox isChecked={shouldAntialias} onChange={handleChangeShouldAntialias} />
           </FormControl>
+          <FormControl>
+            <FormLabel>{t('unifiedCanvas.initialFitImageSize')}</FormLabel>
+            <Checkbox isChecked={shouldFitImageSize} onChange={handleChangeShouldFitImageSize} />
+          </FormControl>
         </FormControlGroup>
         <ClearCanvasHistoryButtonModal />
       </Flex>

@@ -66,6 +66,7 @@ const initialCanvasState: CanvasState = {
   shouldAutoSave: false,
   shouldCropToBoundingBoxOnSave: false,
   shouldDarkenOutsideBoundingBox: false,
+  shouldFitImageSize: true,
  shouldInvertBrushSizeScrollDirection: false,
  shouldLockBoundingBox: false,
  shouldPreserveMaskedArea: false,

@@ -144,11 +145,19 @@ export const canvasSlice = createSlice({
      reducer: (state, action: PayloadActionWithOptimalDimension<ImageDTO>) => {
        const { width, height, image_name } = action.payload;
        const { optimalDimension } = action.meta;
-        const { stageDimensions } = state;
+        const { stageDimensions, shouldFitImageSize } = state;

-        const newBoundingBoxDimensions = {
+        const newBoundingBoxDimensions = shouldFitImageSize
+          ? {
+              width: roundDownToMultiple(width, CANVAS_GRID_SIZE_FINE),
+              height: roundDownToMultiple(height, CANVAS_GRID_SIZE_FINE),
+            }
+          : {
              width: roundDownToMultiple(clamp(width, CANVAS_GRID_SIZE_FINE, optimalDimension), CANVAS_GRID_SIZE_FINE),
-              height: roundDownToMultiple(clamp(height, CANVAS_GRID_SIZE_FINE, optimalDimension), CANVAS_GRID_SIZE_FINE),
+              height: roundDownToMultiple(
+                clamp(height, CANVAS_GRID_SIZE_FINE, optimalDimension),
+                CANVAS_GRID_SIZE_FINE
+              ),
            };

        const newBoundingBoxCoordinates = {

@@ -289,12 +298,19 @@ export const canvasSlice = createSlice({
      const { images, selectedImageIndex } = state.layerState.stagingArea;
      pushToPrevLayerStates(state);

-      if (!images.length) {
-        return;
-      }
-
      images.splice(selectedImageIndex, 1);

+      if (images.length === 0) {
+        pushToPrevLayerStates(state);
+
+        state.layerState.stagingArea = deepClone(initialLayerState.stagingArea);
+
+        state.futureLayerStates = [];
+        state.shouldShowStagingOutline = true;
+        state.shouldShowStagingImage = true;
+        state.batchIds = [];
+      }
+
      if (selectedImageIndex >= images.length) {
        state.layerState.stagingArea.selectedImageIndex = images.length - 1;
      }

@@ -575,6 +591,9 @@ export const canvasSlice = createSlice({
    setShouldAntialias: (state, action: PayloadAction<boolean>) => {
      state.shouldAntialias = action.payload;
    },
+    setShouldFitImageSize: (state, action: PayloadAction<boolean>) => {
+      state.shouldFitImageSize = action.payload;
+    },
    setShouldCropToBoundingBoxOnSave: (state, action: PayloadAction<boolean>) => {
      state.shouldCropToBoundingBoxOnSave = action.payload;
    },

@@ -685,6 +704,7 @@ export const {
  setShouldRestrictStrokesToBox,
  stagingAreaInitialized,
  setShouldAntialias,
+  setShouldFitImageSize,
  canvasResized,
  canvasBatchIdAdded,
  canvasBatchIdsReset,

@@ -120,6 +120,7 @@ export interface CanvasState {
  shouldAutoSave: boolean;
  shouldCropToBoundingBoxOnSave: boolean;
  shouldDarkenOutsideBoundingBox: boolean;
+  shouldFitImageSize: boolean;
  shouldInvertBrushSizeScrollDirection: boolean;
  shouldLockBoundingBox: boolean;
  shouldPreserveMaskedArea: boolean;
@@ -1,12 +1,18 @@
-import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
+import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
+import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
 import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
 import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
 import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
+import { useControlAdapterCLIPVisionModel } from 'features/controlAdapters/hooks/useControlAdapterCLIPVisionModel';
 import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
 import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
 import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
 import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
-import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
+import {
+  controlAdapterCLIPVisionModelChanged,
+  controlAdapterModelChanged,
+} from 'features/controlAdapters/store/controlAdaptersSlice';
+import type { CLIPVisionModel } from 'features/controlAdapters/store/types';
 import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
 import { memo, useCallback, useMemo } from 'react';
 import { useTranslation } from 'react-i18next';

@@ -29,6 +35,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
   const { modelConfig } = useControlAdapterModel(id);
   const dispatch = useAppDispatch();
   const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
+  const currentCLIPVisionModel = useControlAdapterCLIPVisionModel(id);
   const mainModel = useAppSelector(selectMainModel);
   const { t } = useTranslation();

@@ -49,6 +56,16 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
     [dispatch, id]
   );

+  const onCLIPVisionModelChange = useCallback<ComboboxOnChange>(
+    (v) => {
+      if (!v?.value) {
+        return;
+      }
+      dispatch(controlAdapterCLIPVisionModelChanged({ id, clipVisionModel: v.value as CLIPVisionModel }));
+    },
+    [dispatch, id]
+  );
+
   const selectedModel = useMemo(
     () => (modelConfig && controlAdapterType ? { ...modelConfig, model_type: controlAdapterType } : null),
     [controlAdapterType, modelConfig]

@@ -71,9 +88,27 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
     isLoading,
   });

+  const clipVisionOptions = useMemo<ComboboxOption[]>(
+    () => [
+      { label: 'ViT-H', value: 'ViT-H' },
+      { label: 'ViT-G', value: 'ViT-G' },
+    ],
+    []
+  );
+
+  const clipVisionModel = useMemo(
+    () => clipVisionOptions.find((o) => o.value === currentCLIPVisionModel),
+    [clipVisionOptions, currentCLIPVisionModel]
+  );
+
   return (
+    <Flex sx={{ gap: 2 }}>
       <Tooltip label={value?.description}>
-        <FormControl isDisabled={!isEnabled} isInvalid={!value || mainModel?.base !== modelConfig?.base}>
+        <FormControl
+          isDisabled={!isEnabled}
+          isInvalid={!value || mainModel?.base !== modelConfig?.base}
+          sx={{ width: '100%' }}
+        >
           <Combobox
             options={options}
             placeholder={t('controlnet.selectModel')}

@@ -83,6 +118,21 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
           />
         </FormControl>
       </Tooltip>
+      {modelConfig?.type === 'ip_adapter' && modelConfig.format === 'checkpoint' && (
+        <FormControl
+          isDisabled={!isEnabled}
+          isInvalid={!value || mainModel?.base !== modelConfig?.base}
+          sx={{ width: 'max-content', minWidth: 28 }}
+        >
+          <Combobox
+            options={clipVisionOptions}
+            placeholder={t('controlnet.selectCLIPVisionModel')}
+            value={clipVisionModel}
+            onChange={onCLIPVisionModelChange}
+          />
+        </FormControl>
+      )}
+    </Flex>
   );
 };

@@ -0,0 +1,24 @@
+import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
+import { useAppSelector } from 'app/store/storeHooks';
+import {
+  selectControlAdapterById,
+  selectControlAdaptersSlice,
+} from 'features/controlAdapters/store/controlAdaptersSlice';
+import { useMemo } from 'react';
+
+export const useControlAdapterCLIPVisionModel = (id: string) => {
+  const selector = useMemo(
+    () =>
+      createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => {
+        const cn = selectControlAdapterById(controlAdapters, id);
+        if (cn && cn?.type === 'ip_adapter') {
+          return cn.clipVisionModel;
+        }
+      }),
+    [id]
+  );
+
+  const clipVisionModel = useAppSelector(selector);
+
+  return clipVisionModel;
+};
@@ -14,6 +14,7 @@ import { v4 as uuidv4 } from 'uuid';
 import { controlAdapterImageProcessed } from './actions';
 import { CONTROLNET_PROCESSORS } from './constants';
 import type {
+  CLIPVisionModel,
   ControlAdapterConfig,
   ControlAdapterProcessorType,
   ControlAdaptersState,

@@ -244,6 +245,13 @@
        }
        caAdapter.updateOne(state, { id, changes: { controlMode } });
      },
+      controlAdapterCLIPVisionModelChanged: (
+        state,
+        action: PayloadAction<{ id: string; clipVisionModel: CLIPVisionModel }>
+      ) => {
+        const { id, clipVisionModel } = action.payload;
+        caAdapter.updateOne(state, { id, changes: { clipVisionModel } });
+      },
      controlAdapterResizeModeChanged: (
        state,
        action: PayloadAction<{

@@ -381,6 +389,7 @@ export const {
  controlAdapterProcessedImageChanged,
  controlAdapterIsEnabledChanged,
  controlAdapterModelChanged,
+  controlAdapterCLIPVisionModelChanged,
  controlAdapterWeightChanged,
  controlAdapterBeginStepPctChanged,
  controlAdapterEndStepPctChanged,
@@ -243,12 +243,15 @@ export type T2IAdapterConfig = {
   shouldAutoConfig: boolean;
 };

+export type CLIPVisionModel = 'ViT-H' | 'ViT-G';
+
 export type IPAdapterConfig = {
   type: 'ip_adapter';
   id: string;
   isEnabled: boolean;
   controlImage: string | null;
   model: ParameterIPAdapterModel | null;
+  clipVisionModel: CLIPVisionModel;
   weight: number;
   beginStepPct: number;
   endStepPct: number;

@@ -46,6 +46,7 @@ export const initialIPAdapter: Omit<IPAdapterConfig, 'id'> = {
   isEnabled: true,
   controlImage: null,
   model: null,
+  clipVisionModel: 'ViT-H',
   weight: 1,
   beginStepPct: 0,
   endStepPct: 1,
@@ -33,6 +33,7 @@ const ImageMetadataActions = (props: Props) => {
       <MetadataItem metadata={metadata} handlers={handlers.scheduler} />
       <MetadataItem metadata={metadata} handlers={handlers.cfgScale} />
       <MetadataItem metadata={metadata} handlers={handlers.cfgRescaleMultiplier} />
+      <MetadataItem metadata={metadata} handlers={handlers.initialImage} />
       <MetadataItem metadata={metadata} handlers={handlers.strength} />
       <MetadataItem metadata={metadata} handlers={handlers.hrfEnabled} />
       <MetadataItem metadata={metadata} handlers={handlers.hrfMethod} />

@@ -189,6 +189,12 @@ export const handlers = {
     recaller: recallers.cfgScale,
   }),
   height: buildHandlers({ getLabel: () => t('metadata.height'), parser: parsers.height, recaller: recallers.height }),
+  initialImage: buildHandlers({
+    getLabel: () => t('metadata.initImage'),
+    parser: parsers.initialImage,
+    recaller: recallers.initialImage,
+    renderValue: async (imageDTO) => imageDTO.image_name,
+  }),
   negativePrompt: buildHandlers({
     getLabel: () => t('metadata.negativePrompt'),
     parser: parsers.negativePrompt,

@@ -405,6 +411,6 @@ export const parseAndRecallAllMetadata = async (metadata: unknown, skip: (keyof
     })
   );
   if (results.some((result) => result.status === 'fulfilled')) {
-    parameterSetToast(t('toast.parametersSet'));
+    parameterSetToast(t('toast.parameters'));
   }
 };

@@ -1,3 +1,4 @@
+import { getStore } from 'app/store/nanostores/store';
 import {
   initialControlNet,
   initialIPAdapter,

@@ -57,6 +58,8 @@ import {
   isParameterWidth,
 } from 'features/parameters/types/parameterSchemas';
 import { get, isArray, isString } from 'lodash-es';
+import { imagesApi } from 'services/api/endpoints/images';
+import type { ImageDTO } from 'services/api/types';
 import {
   isControlNetModelConfig,
   isIPAdapterModelConfig,

@@ -135,6 +138,14 @@ const parseCFGRescaleMultiplier: MetadataParseFunc<ParameterCFGRescaleMultiplier
 const parseScheduler: MetadataParseFunc<ParameterScheduler> = (metadata) =>
   getProperty(metadata, 'scheduler', isParameterScheduler);

+const parseInitialImage: MetadataParseFunc<ImageDTO> = async (metadata) => {
+  const imageName = await getProperty(metadata, 'init_image', isString);
+  const imageDTORequest = getStore().dispatch(imagesApi.endpoints.getImageDTO.initiate(imageName));
+  const imageDTO = await imageDTORequest.unwrap();
+  imageDTORequest.unsubscribe();
+  return imageDTO;
+};
+
 const parseWidth: MetadataParseFunc<ParameterWidth> = (metadata) => getProperty(metadata, 'width', isParameterWidth);

 const parseHeight: MetadataParseFunc<ParameterHeight> = (metadata) =>

@@ -372,6 +383,7 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
     type: 'ip_adapter',
     isEnabled: true,
     model: zModelIdentifierField.parse(ipAdapterModel),
+    clipVisionModel: 'ViT-H',
     controlImage: image?.image_name ?? null,
     weight: weight ?? initialIPAdapter.weight,
     beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,

@@ -401,6 +413,7 @@ export const parsers = {
   cfgScale: parseCFGScale,
   cfgRescaleMultiplier: parseCFGRescaleMultiplier,
   scheduler: parseScheduler,
+  initialImage: parseInitialImage,
   width: parseWidth,
   height: parseHeight,
   steps: parseSteps,

@@ -17,6 +17,7 @@ import type {
 import { modelSelected } from 'features/parameters/store/actions';
 import {
   heightRecalled,
+  initialImageChanged,
   setCfgRescaleMultiplier,
   setCfgScale,
   setImg2imgStrength,

@@ -61,6 +62,7 @@ import {
   setRefinerStart,
   setRefinerSteps,
 } from 'features/sdxl/store/sdxlSlice';
+import type { ImageDTO } from 'services/api/types';

 const recallPositivePrompt: MetadataRecallFunc<ParameterPositivePrompt> = (positivePrompt) => {
   getStore().dispatch(setPositivePrompt(positivePrompt));

@@ -94,6 +96,10 @@ const recallScheduler: MetadataRecallFunc<ParameterScheduler> = (scheduler) => {
   getStore().dispatch(setScheduler(scheduler));
 };

+const recallInitialImage: MetadataRecallFunc<ImageDTO> = async (imageDTO) => {
+  getStore().dispatch(initialImageChanged(imageDTO));
+};
+
 const recallWidth: MetadataRecallFunc<ParameterWidth> = (width) => {
   getStore().dispatch(widthRecalled(width));
 };

@@ -235,6 +241,7 @@ export const recallers = {
   cfgScale: recallCFGScale,
   cfgRescaleMultiplier: recallCFGRescaleMultiplier,
   scheduler: recallScheduler,
+  initialImage: recallInitialImage,
   width: recallWidth,
   height: recallHeight,
   steps: recallSteps,
@@ -3,7 +3,7 @@ import { createSlice } from '@reduxjs/toolkit';
 import type { PersistConfig } from 'app/store/store';
 import type { ModelType } from 'services/api/types';

-export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'>;
+export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'> | 'refiner';

 type ModelManagerState = {
   _version: 1;

@@ -87,6 +87,10 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
   }, [installJob.source]);

   const progressValue = useMemo(() => {
+    if (installJob.status === 'completed' || installJob.status === 'error' || installJob.status === 'cancelled') {
+      return 100;
+    }
+
     if (isNil(installJob.bytes) || isNil(installJob.total_bytes)) {
       return null;
     }

@@ -96,7 +100,7 @@
     }

     return (installJob.bytes / installJob.total_bytes) * 100;
-  }, [installJob.bytes, installJob.total_bytes]);
+  }, [installJob.bytes, installJob.status, installJob.total_bytes]);

   return (
     <Flex gap={3} w="full" alignItems="center">
@@ -1,48 +1,19 @@
 import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
-import { useAppDispatch } from 'app/store/storeHooks';
-import { addToast } from 'features/system/store/systemSlice';
-import { makeToast } from 'features/system/util/makeToast';
 import { useCallback } from 'react';
 import { useTranslation } from 'react-i18next';
 import { PiPlusBold } from 'react-icons/pi';
 import type { ScanFolderResponse } from 'services/api/endpoints/models';
-import { useInstallModelMutation } from 'services/api/endpoints/models';

 type Props = {
   result: ScanFolderResponse[number];
+  installModel: (source: string) => void;
 };
-export const ScanModelResultItem = ({ result }: Props) => {
+export const ScanModelResultItem = ({ result, installModel }: Props) => {
   const { t } = useTranslation();
-  const dispatch = useAppDispatch();

-  const [installModel] = useInstallModelMutation();
-  const handleQuickAdd = useCallback(() => {
-    installModel({ source: result.path })
-      .unwrap()
-      .then((_) => {
-        dispatch(
-          addToast(
-            makeToast({
-              title: t('toast.modelAddedSimple'),
-              status: 'success',
-            })
-          )
-        );
-      })
-      .catch((error) => {
-        if (error) {
-          dispatch(
-            addToast(
-              makeToast({
-                title: `${error.data.detail} `,
-                status: 'error',
-              })
-            )
-          );
-        }
-      });
-  }, [installModel, result, dispatch, t]);
+  const handleInstall = useCallback(() => {
+    installModel(result.path);
+  }, [installModel, result]);

   return (
     <Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>

@@ -54,7 +25,7 @@ export const ScanModelResultItem = ({ result }: Props) => {
         {result.is_installed ? (
           <Badge>{t('common.installed')}</Badge>
         ) : (
-          <IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={handleQuickAdd} size="sm" />
+          <IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={handleInstall} size="sm" />
         )}
       </Box>
     </Flex>
@@ -1,7 +1,10 @@
 import {
   Button,
+  Checkbox,
   Divider,
   Flex,
+  FormControl,
+  FormLabel,
   Heading,
   IconButton,
   Input,

@@ -12,7 +15,7 @@ import { useAppDispatch } from 'app/store/storeHooks';
 import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
 import { addToast } from 'features/system/store/systemSlice';
 import { makeToast } from 'features/system/util/makeToast';
-import type { ChangeEventHandler } from 'react';
+import type { ChangeEvent, ChangeEventHandler } from 'react';
 import { useCallback, useMemo, useState } from 'react';
 import { useTranslation } from 'react-i18next';
 import { PiXBold } from 'react-icons/pi';

@@ -28,7 +31,7 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
   const { t } = useTranslation();
   const [searchTerm, setSearchTerm] = useState('');
   const dispatch = useAppDispatch();
+  const [inplace, setInplace] = useState(true);
   const [installModel] = useInstallModelMutation();

   const filteredResults = useMemo(() => {

@@ -42,6 +45,10 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
     setSearchTerm(e.target.value.trim());
   }, []);

+  const onChangeInplace = useCallback((e: ChangeEvent<HTMLInputElement>) => {
+    setInplace(e.target.checked);
+  }, []);
+
   const clearSearch = useCallback(() => {
     setSearchTerm('');
   }, []);

@@ -51,7 +58,7 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
       if (result.is_installed) {
         continue;
       }
-      installModel({ source: result.path })
+      installModel({ source: result.path, inplace })
         .unwrap()
         .then((_) => {
           dispatch(

@@ -76,7 +83,37 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
           }
         });
     }
-  }, [installModel, filteredResults, dispatch, t]);
+  }, [filteredResults, installModel, inplace, dispatch, t]);
+
+  const handleInstallOne = useCallback(
+    (source: string) => {
+      installModel({ source, inplace })
+        .unwrap()
+        .then((_) => {
+          dispatch(
+            addToast(
+              makeToast({
+                title: t('toast.modelAddedSimple'),
+                status: 'success',
+              })
+            )
+          );
+        })
+        .catch((error) => {
+          if (error) {
+            dispatch(
+              addToast(
+                makeToast({
+                  title: `${error.data.detail} `,
+                  status: 'error',
+                })
+              )
+            );
+          }
+        });
+    },
+    [installModel, inplace, dispatch, t]
+  );

   return (
     <>

@@ -85,6 +122,10 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
         <Flex justifyContent="space-between" alignItems="center">
           <Heading size="sm">{t('modelManager.scanResults')}</Heading>
           <Flex alignItems="center" gap={3}>
+            <FormControl w="min-content">
+              <FormLabel m={0}>{t('modelManager.inplaceInstall')}</FormLabel>
+              <Checkbox isChecked={inplace} onChange={onChangeInplace} size="md" />
+            </FormControl>
             <Button size="sm" onClick={handleAddAll} isDisabled={filteredResults.length === 0}>
               {t('modelManager.installAll')}
             </Button>

@@ -116,7 +157,7 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
         <ScrollableContent>
           <Flex flexDir="column" gap={3}>
             {filteredResults.map((result) => (
-              <ScanModelResultItem key={result.path} result={result} />
+              <ScanModelResultItem key={result.path} result={result} installModel={handleInstallOne} />
             ))}
           </Flex>
         </ScrollableContent>
@ -1,6 +1,7 @@
|
|||||||
import { Flex } from '@invoke-ai/ui-library';
|
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||||
|
import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
import { memo, useMemo } from 'react';
|
import { memo, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import {
|
import {
|
||||||
@ -9,10 +10,11 @@ import {
|
|||||||
useIPAdapterModels,
|
useIPAdapterModels,
|
||||||
useLoRAModels,
|
useLoRAModels,
|
||||||
useMainModels,
|
useMainModels,
|
||||||
|
useRefinerModels,
|
||||||
useT2IAdapterModels,
|
useT2IAdapterModels,
|
||||||
useVAEModels,
|
useVAEModels,
|
||||||
} from 'services/api/hooks/modelsByType';
|
} from 'services/api/hooks/modelsByType';
|
||||||
import type { AnyModelConfig, ModelType } from 'services/api/types';
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
 import { FetchingModelsLoader } from './FetchingModelsLoader';
 import { ModelListWrapper } from './ModelListWrapper';
@@ -27,6 +29,12 @@ const ModelList = () => {
     [mainModels, searchTerm, filteredModelType]
   );

+  const [refinerModels, { isLoading: isLoadingRefinerModels }] = useRefinerModels();
+  const filteredRefinerModels = useMemo(
+    () => modelsFilter(refinerModels, searchTerm, filteredModelType),
+    [refinerModels, searchTerm, filteredModelType]
+  );
+
   const [loraModels, { isLoading: isLoadingLoRAModels }] = useLoRAModels();
   const filteredLoRAModels = useMemo(
     () => modelsFilter(loraModels, searchTerm, filteredModelType),
@@ -63,6 +71,28 @@ const ModelList = () => {
     [vaeModels, searchTerm, filteredModelType]
   );

+  const totalFilteredModels = useMemo(() => {
+    return (
+      filteredMainModels.length +
+      filteredRefinerModels.length +
+      filteredLoRAModels.length +
+      filteredEmbeddingModels.length +
+      filteredControlNetModels.length +
+      filteredT2IAdapterModels.length +
+      filteredIPAdapterModels.length +
+      filteredVAEModels.length
+    );
+  }, [
+    filteredControlNetModels.length,
+    filteredEmbeddingModels.length,
+    filteredIPAdapterModels.length,
+    filteredLoRAModels.length,
+    filteredMainModels.length,
+    filteredRefinerModels.length,
+    filteredT2IAdapterModels.length,
+    filteredVAEModels.length,
+  ]);
+
   return (
     <ScrollableContent>
       <Flex flexDirection="column" w="full" h="full" gap={4}>
@@ -71,6 +101,11 @@ const ModelList = () => {
         {!isLoadingMainModels && filteredMainModels.length > 0 && (
           <ModelListWrapper title={t('modelManager.main')} modelList={filteredMainModels} key="main" />
         )}
+        {/* Refiner Model List */}
+        {isLoadingRefinerModels && <FetchingModelsLoader loadingMessage="Loading Refiner Models..." />}
+        {!isLoadingRefinerModels && filteredRefinerModels.length > 0 && (
+          <ModelListWrapper title={t('sdxl.refiner')} modelList={filteredRefinerModels} key="refiner" />
+        )}
         {/* LoRAs List */}
         {isLoadingLoRAModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
         {!isLoadingLoRAModels && filteredLoRAModels.length > 0 && (
@@ -108,6 +143,11 @@ const ModelList = () => {
         {!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && (
           <ModelListWrapper title={t('common.t2iAdapter')} modelList={filteredT2IAdapterModels} key="t2i-adapters" />
         )}
+        {totalFilteredModels === 0 && (
+          <Flex w="full" h="full" alignItems="center" justifyContent="center">
+            <Text>{t('modelManager.noMatchingModels')}</Text>
+          </Flex>
+        )}
       </Flex>
     </ScrollableContent>
   );
@@ -118,12 +158,24 @@ export default memo(ModelList);
 const modelsFilter = <T extends AnyModelConfig>(
   data: T[],
   nameFilter: string,
-  filteredModelType: ModelType | null
+  filteredModelType: FilterableModelType | null
 ): T[] => {
   return data.filter((model) => {
     const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase());
-    const matchesType = filteredModelType ? model.type === filteredModelType : true;
+    const matchesType = getMatchesType(model, filteredModelType);

     return matchesFilter && matchesType;
   });
 };
+
+const getMatchesType = (modelConfig: AnyModelConfig, filteredModelType: FilterableModelType | null): boolean => {
+  if (filteredModelType === 'refiner') {
+    return modelConfig.base === 'sdxl-refiner';
+  }
+  if (filteredModelType === 'main' && modelConfig.base === 'sdxl-refiner') {
+    return false;
+  }
+  return filteredModelType ? modelConfig.type === filteredModelType : true;
+};
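
For illustration, a standalone sketch of the filter rule the new `getMatchesType` helper introduces, with the config type pared down to the fields it touches (the project's real `AnyModelConfig` and `FilterableModelType` are richer than this):

```ts
// Minimal sketch of the new filtering rule; types simplified for illustration only.
type FilterableModelType = 'main' | 'refiner' | 'lora'; // partial union, just for the sketch
type ModelConfigLike = { name: string; type: string; base: string };

const getMatchesType = (model: ModelConfigLike, filter: FilterableModelType | null): boolean => {
  if (filter === 'refiner') {
    // "refiner" is a pseudo-type: refiners are matched by base, not by config type.
    return model.base === 'sdxl-refiner';
  }
  if (filter === 'main' && model.base === 'sdxl-refiner') {
    // ...and they are excluded from the plain "main" group.
    return false;
  }
  return filter ? model.type === filter : true;
};

const refiner: ModelConfigLike = { name: 'sdxl-refiner-1.0', type: 'main', base: 'sdxl-refiner' };

console.log(getMatchesType(refiner, 'main')); // false - refiners are pulled out of the "main" group
console.log(getMatchesType(refiner, 'refiner')); // true  - matched by base, not by config type
console.log(getMatchesType(refiner, null)); // true  - no filter selected
```
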
@@ -90,11 +90,13 @@ const ModelListItem = (props: ModelListItemProps) => {
       cursor="pointer"
       onClick={handleSelectModel}
     >
-      <Flex gap={2} w="full" h="full">
+      <Flex gap={2} w="full" h="full" minW={0}>
         <ModelImage image_url={model.cover_image} />
-        <Flex gap={1} alignItems="flex-start" flexDir="column" w="full">
+        <Flex gap={1} alignItems="flex-start" flexDir="column" w="full" minW={0}>
           <Flex gap={2} w="full" alignItems="flex-start">
-            <Text fontWeight="semibold">{model.name}</Text>
+            <Text fontWeight="semibold" noOfLines={1} wordBreak="break-all">
+              {model.name}
+            </Text>
             <Spacer />
           </Flex>
           <Text variant="subtext" noOfLines={1}>
@@ -13,6 +13,7 @@ export const ModelTypeFilter = () => {
   const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = useMemo(
     () => ({
       main: t('modelManager.main'),
+      refiner: t('sdxl.refiner'),
       lora: 'LoRA',
       embedding: t('modelManager.textualInversions'),
       controlnet: 'ControlNet',
@@ -87,9 +87,9 @@ export const Model = () => {
     <Flex flexDir="column" gap={4}>
       <Flex alignItems="flex-start" gap={4}>
         <ModelImageUpload model_key={selectedModelKey} model_image={data.cover_image} />
-        <Flex flexDir="column" gap={1} flexGrow={1}>
+        <Flex flexDir="column" gap={1} flexGrow={1} minW={0}>
           <Flex gap={2}>
-            <Heading as="h2" fontSize="lg">
+            <Heading as="h2" fontSize="lg" noOfLines={1} wordBreak="break-all">
               {data.name}
             </Heading>
             <Spacer />
@@ -114,7 +114,7 @@ export const Model = () => {
             )}
           </Flex>
           {data.source && (
-            <Text variant="subtext">
+            <Text variant="subtext" noOfLines={1} wordBreak="break-all">
               {t('modelManager.source')}: {data?.source}
             </Text>
           )}
@@ -9,7 +9,9 @@ export const ModelAttrView = ({ label, value }: Props) => {
   return (
     <FormControl flexDir="column" alignItems="flex-start" gap={0}>
       <FormLabel>{label}</FormLabel>
-      <Text fontSize="md">{value || '-'}</Text>
+      <Text fontSize="md" noOfLines={1} wordBreak="break-all">
+        {value || '-'}
+      </Text>
     </FormControl>
   );
 };
@@ -53,7 +53,7 @@ export const ModelView = () => {
           </>
         )}

-        {data.type === 'ip_adapter' && (
+        {data.type === 'ip_adapter' && data.format === 'invokeai' && (
           <Flex gap={2}>
             <ModelAttrView label={t('modelManager.imageEncoderModelId')} value={data.image_encoder_model_id} />
           </Flex>
@@ -37,34 +37,50 @@ const NumberFieldInputComponent = (
   );

   const min = useMemo(() => {
+    let min = -NUMPY_RAND_MAX;
     if (!isNil(fieldTemplate.minimum)) {
-      return fieldTemplate.minimum;
+      min = fieldTemplate.minimum;
     }
     if (!isNil(fieldTemplate.exclusiveMinimum)) {
-      return fieldTemplate.exclusiveMinimum + 0.01;
+      min = fieldTemplate.exclusiveMinimum + 0.01;
     }
-    return;
+    return min;
   }, [fieldTemplate.exclusiveMinimum, fieldTemplate.minimum]);

   const max = useMemo(() => {
+    let max = NUMPY_RAND_MAX;
     if (!isNil(fieldTemplate.maximum)) {
-      return fieldTemplate.maximum;
+      max = fieldTemplate.maximum;
     }
     if (!isNil(fieldTemplate.exclusiveMaximum)) {
-      return fieldTemplate.exclusiveMaximum - 0.01;
+      max = fieldTemplate.exclusiveMaximum - 0.01;
     }
-    return;
+    return max;
   }, [fieldTemplate.exclusiveMaximum, fieldTemplate.maximum]);

+  const step = useMemo(() => {
+    if (isNil(fieldTemplate.multipleOf)) {
+      return isIntegerField ? 1 : 0.1;
+    }
+    return fieldTemplate.multipleOf;
+  }, [fieldTemplate.multipleOf, isIntegerField]);
+
+  const fineStep = useMemo(() => {
+    if (isNil(fieldTemplate.multipleOf)) {
+      return isIntegerField ? 1 : 0.01;
+    }
+    return fieldTemplate.multipleOf;
+  }, [fieldTemplate.multipleOf, isIntegerField]);
+
   return (
     <CompositeNumberInput
       defaultValue={fieldTemplate.default}
       onChange={handleValueChanged}
       value={field.value}
-      min={min ?? -NUMPY_RAND_MAX}
-      max={max ?? NUMPY_RAND_MAX}
-      step={isIntegerField ? 1 : 0.1}
-      fineStep={isIntegerField ? 1 : 0.01}
+      min={min}
+      max={max}
+      step={step}
+      fineStep={fineStep}
       className="nodrag"
     />
   );
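
For illustration, a self-contained sketch of the bound and step derivation above, with the field template type simplified and `NUMPY_RAND_MAX` standing in (assumed value) for the project constant used as the fallback bound:

```ts
// Sketch only: derives number input constraints from an OpenAPI-style field template.
const NUMPY_RAND_MAX = 4294967295; // assumed fallback bound, mirroring the project constant

type NumberFieldTemplateLike = {
  minimum?: number;
  maximum?: number;
  exclusiveMinimum?: number;
  exclusiveMaximum?: number;
  multipleOf?: number;
};

const deriveNumberInputProps = (fieldTemplate: NumberFieldTemplateLike, isIntegerField: boolean) => {
  let min = -NUMPY_RAND_MAX;
  if (fieldTemplate.minimum !== undefined) {
    min = fieldTemplate.minimum;
  }
  if (fieldTemplate.exclusiveMinimum !== undefined) {
    min = fieldTemplate.exclusiveMinimum + 0.01;
  }

  let max = NUMPY_RAND_MAX;
  if (fieldTemplate.maximum !== undefined) {
    max = fieldTemplate.maximum;
  }
  if (fieldTemplate.exclusiveMaximum !== undefined) {
    max = fieldTemplate.exclusiveMaximum - 0.01;
  }

  // When the template declares multipleOf, it wins over the integer/float defaults for both steps.
  const step = fieldTemplate.multipleOf ?? (isIntegerField ? 1 : 0.1);
  const fineStep = fieldTemplate.multipleOf ?? (isIntegerField ? 1 : 0.01);

  return { min, max, step, fineStep };
};

// A float field constrained to (0, 1] in increments of 0.05:
console.log(deriveNumberInputProps({ exclusiveMinimum: 0, maximum: 1, multipleOf: 0.05 }, false));
// => { min: 0.01, max: 1, step: 0.05, fineStep: 0.05 }
```
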
@@ -48,7 +48,7 @@ export const addIPAdapterToLinearGraph = async (
     if (!ipAdapter.model) {
       return;
     }
-    const { id, weight, model, beginStepPct, endStepPct, controlImage } = ipAdapter;
+    const { id, weight, model, clipVisionModel, beginStepPct, endStepPct, controlImage } = ipAdapter;

     assert(controlImage, 'IP Adapter image is required');

@@ -58,6 +58,7 @@ export const addIPAdapterToLinearGraph = async (
       is_intermediate: true,
       weight: weight,
       ip_adapter_model: model,
+      clip_vision_model: clipVisionModel,
       begin_step_percent: beginStepPct,
       end_step_percent: endStepPct,
       image: {
@@ -83,7 +84,7 @@ export const addIPAdapterToLinearGraph = async (
 };

 const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => {
-  const { controlImage, beginStepPct, endStepPct, model, weight } = ipAdapter;
+  const { controlImage, beginStepPct, endStepPct, model, clipVisionModel, weight } = ipAdapter;

   assert(model, 'IP Adapter model is required');

@@ -99,6 +100,7 @@ const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadat
 
   return {
     ip_adapter_model: model,
+    clip_vision_model: clipVisionModel,
     weight,
     begin_step_percent: beginStepPct,
     end_step_percent: endStepPct,
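
For illustration, a sketch of what the IP Adapter graph node and metadata now carry after this change; the shapes are simplified and the concrete values (model key, CLIP Vision name, image name) are purely illustrative:

```ts
// Simplified sketch of the metadata field after the change; real types come from the API schema.
type IPAdapterMetadataLike = {
  ip_adapter_model: { key: string };
  clip_vision_model: string; // newly threaded through from the UI-side ipAdapter config
  weight: number;
  begin_step_percent: number;
  end_step_percent: number;
  image: { image_name: string };
};

const exampleMetadata: IPAdapterMetadataLike = {
  ip_adapter_model: { key: 'example-ip-adapter-key' }, // illustrative key
  clip_vision_model: 'ViT-H', // illustrative value
  weight: 0.75,
  begin_step_percent: 0,
  end_step_percent: 1,
  image: { image_name: 'example.png' }, // illustrative image name
};
```
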
@@ -65,6 +65,11 @@ export const buildCanvasOutpaintGraph = async (
     infillTileSize,
     infillPatchmatchDownscaleSize,
     infillMethod,
+    // infillMosaicTileWidth,
+    // infillMosaicTileHeight,
+    // infillMosaicMinColor,
+    // infillMosaicMaxColor,
+    infillColorValue,
     clipSkip,
     seamlessXAxis,
     seamlessYAxis,
@@ -356,6 +361,28 @@ export const buildCanvasOutpaintGraph = async (
     };
   }

+  // TODO: add mosaic back
+  // if (infillMethod === 'mosaic') {
+  //   graph.nodes[INPAINT_INFILL] = {
+  //     type: 'infill_mosaic',
+  //     id: INPAINT_INFILL,
+  //     is_intermediate,
+  //     tile_width: infillMosaicTileWidth,
+  //     tile_height: infillMosaicTileHeight,
+  //     min_color: infillMosaicMinColor,
+  //     max_color: infillMosaicMaxColor,
+  //   };
+  // }
+
+  if (infillMethod === 'color') {
+    graph.nodes[INPAINT_INFILL] = {
+      type: 'infill_rgba',
+      id: INPAINT_INFILL,
+      color: infillColorValue,
+      is_intermediate,
+    };
+  }
+
   // Handle Scale Before Processing
   if (isUsingScaledDimensions) {
     const scaledWidth: number = scaledBoundingBoxDimensions.width;
@@ -66,6 +66,11 @@ export const buildCanvasSDXLOutpaintGraph = async (
     infillTileSize,
     infillPatchmatchDownscaleSize,
     infillMethod,
+    // infillMosaicTileWidth,
+    // infillMosaicTileHeight,
+    // infillMosaicMinColor,
+    // infillMosaicMaxColor,
+    infillColorValue,
     seamlessXAxis,
     seamlessYAxis,
     canvasCoherenceMode,
@@ -365,6 +370,28 @@ export const buildCanvasSDXLOutpaintGraph = async (
     };
   }

+  // TODO: add mosaic back
+  // if (infillMethod === 'mosaic') {
+  //   graph.nodes[INPAINT_INFILL] = {
+  //     type: 'infill_mosaic',
+  //     id: INPAINT_INFILL,
+  //     is_intermediate,
+  //     tile_width: infillMosaicTileWidth,
+  //     tile_height: infillMosaicTileHeight,
+  //     min_color: infillMosaicMinColor,
+  //     max_color: infillMosaicMaxColor,
+  //   };
+  // }
+
+  if (infillMethod === 'color') {
+    graph.nodes[INPAINT_INFILL] = {
+      type: 'infill_rgba',
+      id: INPAINT_INFILL,
+      is_intermediate,
+      color: infillColorValue,
+    };
+  }
+
   // Handle Scale Before Processing
   if (isUsingScaledDimensions) {
     const scaledWidth: number = scaledBoundingBoxDimensions.width;
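
For illustration, a condensed sketch of the branch both outpaint graph builders now share for solid-color infill. The node shape and fields come from the hunks above; the `INPAINT_INFILL` id value and the surrounding graph type are assumptions made only to keep the sketch self-contained:

```ts
// Sketch: adds the color-fill infill node when the user picks the "color" infill method.
// The mosaic branch stays commented out in the builders until the backend node returns.
type RgbaColor = { r: number; g: number; b: number; a: number };
type Graph = { nodes: Record<string, unknown> };

const INPAINT_INFILL = 'inpaint_infill'; // assumed value of the shared node id constant

const addColorInfillNode = (graph: Graph, infillMethod: string, infillColorValue: RgbaColor, isIntermediate = true) => {
  if (infillMethod !== 'color') {
    return;
  }
  graph.nodes[INPAINT_INFILL] = {
    type: 'infill_rgba',
    id: INPAINT_INFILL,
    color: infillColorValue,
    is_intermediate: isIntermediate,
  };
};

// Usage against an in-progress graph:
const graph: Graph = { nodes: {} };
addColorInfillNode(graph, 'color', { r: 0, g: 0, b: 0, a: 1 });
```
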
@@ -0,0 +1,46 @@
+import { Box, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
+import { createSelector } from '@reduxjs/toolkit';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import IAIColorPicker from 'common/components/IAIColorPicker';
+import { selectGenerationSlice, setInfillColorValue } from 'features/parameters/store/generationSlice';
+import { memo, useCallback, useMemo } from 'react';
+import type { RgbaColor } from 'react-colorful';
+import { useTranslation } from 'react-i18next';
+
+const ParamInfillColorOptions = () => {
+  const dispatch = useAppDispatch();
+
+  const selector = useMemo(
+    () =>
+      createSelector(selectGenerationSlice, (generation) => ({
+        infillColor: generation.infillColorValue,
+      })),
+    []
+  );
+
+  const { infillColor } = useAppSelector(selector);
+
+  const infillMethod = useAppSelector((s) => s.generation.infillMethod);
+
+  const { t } = useTranslation();
+
+  const handleInfillColor = useCallback(
+    (v: RgbaColor) => {
+      dispatch(setInfillColorValue(v));
+    },
+    [dispatch]
+  );
+
+  return (
+    <Flex flexDir="column" gap={4}>
+      <FormControl isDisabled={infillMethod !== 'color'}>
+        <FormLabel>{t('parameters.infillColorValue')}</FormLabel>
+        <Box w="full" pt={2} pb={2}>
+          <IAIColorPicker color={infillColor} onChange={handleInfillColor} />
+        </Box>
+      </FormControl>
+    </Flex>
+  );
+};
+
+export default memo(ParamInfillColorOptions);
@@ -0,0 +1,127 @@
+import { Box, CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
+import { createSelector } from '@reduxjs/toolkit';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import IAIColorPicker from 'common/components/IAIColorPicker';
+import {
+  selectGenerationSlice,
+  setInfillMosaicMaxColor,
+  setInfillMosaicMinColor,
+  setInfillMosaicTileHeight,
+  setInfillMosaicTileWidth,
+} from 'features/parameters/store/generationSlice';
+import { memo, useCallback, useMemo } from 'react';
+import type { RgbaColor } from 'react-colorful';
+import { useTranslation } from 'react-i18next';
+
+const ParamInfillMosaicTileSize = () => {
+  const dispatch = useAppDispatch();
+
+  const selector = useMemo(
+    () =>
+      createSelector(selectGenerationSlice, (generation) => ({
+        infillMosaicTileWidth: generation.infillMosaicTileWidth,
+        infillMosaicTileHeight: generation.infillMosaicTileHeight,
+        infillMosaicMinColor: generation.infillMosaicMinColor,
+        infillMosaicMaxColor: generation.infillMosaicMaxColor,
+      })),
+    []
+  );
+
+  const { infillMosaicTileWidth, infillMosaicTileHeight, infillMosaicMinColor, infillMosaicMaxColor } =
+    useAppSelector(selector);
+
+  const infillMethod = useAppSelector((s) => s.generation.infillMethod);
+
+  const { t } = useTranslation();
+
+  const handleInfillMosaicTileWidthChange = useCallback(
+    (v: number) => {
+      dispatch(setInfillMosaicTileWidth(v));
+    },
+    [dispatch]
+  );
+
+  const handleInfillMosaicTileHeightChange = useCallback(
+    (v: number) => {
+      dispatch(setInfillMosaicTileHeight(v));
+    },
+    [dispatch]
+  );
+
+  const handleInfillMosaicMinColor = useCallback(
+    (v: RgbaColor) => {
+      dispatch(setInfillMosaicMinColor(v));
+    },
+    [dispatch]
+  );
+
+  const handleInfillMosaicMaxColor = useCallback(
+    (v: RgbaColor) => {
+      dispatch(setInfillMosaicMaxColor(v));
+    },
+    [dispatch]
+  );
+
+  return (
+    <Flex flexDir="column" gap={4}>
+      <FormControl isDisabled={infillMethod !== 'mosaic'}>
+        <FormLabel>{t('parameters.infillMosaicTileWidth')}</FormLabel>
+        <CompositeSlider
+          min={8}
+          max={256}
+          value={infillMosaicTileWidth}
+          defaultValue={64}
+          onChange={handleInfillMosaicTileWidthChange}
+          step={8}
+          fineStep={8}
+          marks
+        />
+        <CompositeNumberInput
+          min={8}
+          max={1024}
+          value={infillMosaicTileWidth}
+          defaultValue={64}
+          onChange={handleInfillMosaicTileWidthChange}
+          step={8}
+          fineStep={8}
+        />
+      </FormControl>
+      <FormControl isDisabled={infillMethod !== 'mosaic'}>
+        <FormLabel>{t('parameters.infillMosaicTileHeight')}</FormLabel>
+        <CompositeSlider
+          min={8}
+          max={256}
+          value={infillMosaicTileHeight}
+          defaultValue={64}
+          onChange={handleInfillMosaicTileHeightChange}
+          step={8}
+          fineStep={8}
+          marks
+        />
+        <CompositeNumberInput
+          min={8}
+          max={1024}
+          value={infillMosaicTileHeight}
+          defaultValue={64}
+          onChange={handleInfillMosaicTileHeightChange}
+          step={8}
+          fineStep={8}
+        />
+      </FormControl>
+      <FormControl isDisabled={infillMethod !== 'mosaic'}>
+        <FormLabel>{t('parameters.infillMosaicMinColor')}</FormLabel>
+        <Box w="full" pt={2} pb={2}>
+          <IAIColorPicker color={infillMosaicMinColor} onChange={handleInfillMosaicMinColor} />
+        </Box>
+      </FormControl>
+      <FormControl isDisabled={infillMethod !== 'mosaic'}>
+        <FormLabel>{t('parameters.infillMosaicMaxColor')}</FormLabel>
+        <Box w="full" pt={2} pb={2}>
+          <IAIColorPicker color={infillMosaicMaxColor} onChange={handleInfillMosaicMaxColor} />
+        </Box>
+      </FormControl>
+    </Flex>
+  );
+};
+
+export default memo(ParamInfillMosaicTileSize);
@@ -1,6 +1,8 @@
 import { useAppSelector } from 'app/store/storeHooks';
 import { memo } from 'react';

+import ParamInfillColorOptions from './ParamInfillColorOptions';
+import ParamInfillMosaicOptions from './ParamInfillMosaicOptions';
 import ParamInfillPatchmatchDownscaleSize from './ParamInfillPatchmatchDownscaleSize';
 import ParamInfillTilesize from './ParamInfillTilesize';

@@ -14,6 +16,14 @@ const ParamInfillOptions = () => {
     return <ParamInfillPatchmatchDownscaleSize />;
   }

+  if (infillMethod === 'mosaic') {
+    return <ParamInfillMosaicOptions />;
+  }
+
+  if (infillMethod === 'color') {
+    return <ParamInfillColorOptions />;
+  }
+
   return null;
 };

@@ -19,6 +19,7 @@ import type {
 import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
 import { configChanged } from 'features/system/store/configSlice';
 import { clamp } from 'lodash-es';
+import type { RgbaColor } from 'react-colorful';
 import type { ImageDTO } from 'services/api/types';

 import type { GenerationState } from './types';
@@ -43,8 +44,6 @@ const initialGenerationState: GenerationState = {
   shouldFitToWidthHeight: true,
   shouldRandomizeSeed: true,
   steps: 50,
-  infillTileSize: 32,
-  infillPatchmatchDownscaleSize: 1,
   width: 512,
   model: null,
   vae: null,
@@ -55,6 +54,13 @@ const initialGenerationState: GenerationState = {
   shouldUseCpuNoise: true,
   shouldShowAdvancedOptions: false,
   aspectRatio: { ...initialAspectRatioState },
+  infillTileSize: 32,
+  infillPatchmatchDownscaleSize: 1,
+  infillMosaicTileWidth: 64,
+  infillMosaicTileHeight: 64,
+  infillMosaicMinColor: { r: 0, g: 0, b: 0, a: 1 },
+  infillMosaicMaxColor: { r: 255, g: 255, b: 255, a: 1 },
+  infillColorValue: { r: 0, g: 0, b: 0, a: 1 },
 };

 export const generationSlice = createSlice({
@@ -116,15 +122,6 @@ export const generationSlice = createSlice({
     setCanvasCoherenceMinDenoise: (state, action: PayloadAction<number>) => {
       state.canvasCoherenceMinDenoise = action.payload;
     },
-    setInfillMethod: (state, action: PayloadAction<string>) => {
-      state.infillMethod = action.payload;
-    },
-    setInfillTileSize: (state, action: PayloadAction<number>) => {
-      state.infillTileSize = action.payload;
-    },
-    setInfillPatchmatchDownscaleSize: (state, action: PayloadAction<number>) => {
-      state.infillPatchmatchDownscaleSize = action.payload;
-    },
     initialImageChanged: (state, action: PayloadAction<ImageDTO>) => {
       const { image_name, width, height } = action.payload;
       state.initialImage = { imageName: image_name, width, height };
@@ -206,6 +203,30 @@ export const generationSlice = createSlice({
     aspectRatioChanged: (state, action: PayloadAction<AspectRatioState>) => {
       state.aspectRatio = action.payload;
     },
+    setInfillMethod: (state, action: PayloadAction<string>) => {
+      state.infillMethod = action.payload;
+    },
+    setInfillTileSize: (state, action: PayloadAction<number>) => {
+      state.infillTileSize = action.payload;
+    },
+    setInfillPatchmatchDownscaleSize: (state, action: PayloadAction<number>) => {
+      state.infillPatchmatchDownscaleSize = action.payload;
+    },
+    setInfillMosaicTileWidth: (state, action: PayloadAction<number>) => {
+      state.infillMosaicTileWidth = action.payload;
+    },
+    setInfillMosaicTileHeight: (state, action: PayloadAction<number>) => {
+      state.infillMosaicTileHeight = action.payload;
+    },
+    setInfillMosaicMinColor: (state, action: PayloadAction<RgbaColor>) => {
+      state.infillMosaicMinColor = action.payload;
+    },
+    setInfillMosaicMaxColor: (state, action: PayloadAction<RgbaColor>) => {
+      state.infillMosaicMaxColor = action.payload;
+    },
+    setInfillColorValue: (state, action: PayloadAction<RgbaColor>) => {
+      state.infillColorValue = action.payload;
+    },
   },
   extraReducers: (builder) => {
     builder.addCase(configChanged, (state, action) => {
@@ -249,8 +270,6 @@ export const {
   setShouldFitToWidthHeight,
   setShouldRandomizeSeed,
   setSteps,
-  setInfillTileSize,
-  setInfillPatchmatchDownscaleSize,
   initialImageChanged,
   modelChanged,
   vaeSelected,
@@ -264,6 +283,13 @@ export const {
   heightChanged,
   widthRecalled,
   heightRecalled,
+  setInfillTileSize,
+  setInfillPatchmatchDownscaleSize,
+  setInfillMosaicTileWidth,
+  setInfillMosaicTileHeight,
+  setInfillMosaicMinColor,
+  setInfillMosaicMaxColor,
+  setInfillColorValue,
 } = generationSlice.actions;

 export const { selectOptimalDimension } = generationSlice.selectors;
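
For illustration, a small hook sketch showing how the relocated and new infill actions are dispatched from UI code. The hook itself is hypothetical; the imports follow the paths and action names used in the hunks above:

```ts
// Hypothetical hook: switches the canvas to a flat gray fill using the new reducers.
import { useAppDispatch } from 'app/store/storeHooks';
import { setInfillColorValue, setInfillMethod } from 'features/parameters/store/generationSlice';
import { useCallback } from 'react';

export const useGrayColorInfill = () => {
  const dispatch = useAppDispatch();

  return useCallback(() => {
    // Select the "color" infill method, then set the RGBA value it should use.
    dispatch(setInfillMethod('color'));
    dispatch(setInfillColorValue({ r: 127, g: 127, b: 127, a: 1 }));
  }, [dispatch]);
};
```
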
@@ -17,6 +17,7 @@ import type {
   ParameterVAEModel,
   ParameterWidth,
 } from 'features/parameters/types/parameterSchemas';
+import type { RgbaColor } from 'react-colorful';

 export interface GenerationState {
   _version: 2;
@@ -39,8 +40,6 @@ export interface GenerationState {
   shouldFitToWidthHeight: boolean;
   shouldRandomizeSeed: boolean;
   steps: ParameterSteps;
-  infillTileSize: number;
-  infillPatchmatchDownscaleSize: number;
   width: ParameterWidth;
   model: ParameterModel | null;
   vae: ParameterVAEModel | null;
@@ -51,6 +50,13 @@ export interface GenerationState {
   shouldUseCpuNoise: boolean;
   shouldShowAdvancedOptions: boolean;
   aspectRatio: AspectRatioState;
+  infillTileSize: number;
+  infillPatchmatchDownscaleSize: number;
+  infillMosaicTileWidth: number;
+  infillMosaicTileHeight: number;
+  infillMosaicMinColor: RgbaColor;
+  infillMosaicMaxColor: RgbaColor;
+  infillColorValue: RgbaColor;
 }

 export type PayloadActionWithOptimalDimension<T = void> = PayloadAction<T, string, { optimalDimension: number }>;
@@ -61,7 +61,7 @@ export const AdvancedSettingsAccordion = memo(() => {

   return (
     <StandaloneAccordion label={t('accordions.advanced.title')} badges={badges} isOpen={isOpen} onToggle={onToggle}>
-      <Flex gap={4} alignItems="center" p={4} flexDir="column">
+      <Flex gap={4} alignItems="center" p={4} flexDir="column" data-testid="advanced-settings-accordion">
         <Flex gap={4} w="full">
           <ParamVAEModelSelect />
           <ParamVAEPrecision />
@@ -77,7 +77,7 @@ export const ControlSettingsAccordion: React.FC = memo(() => {

   return (
     <StandaloneAccordion label={t('accordions.control.title')} badges={badges} isOpen={isOpen} onToggle={onToggle}>
-      <Flex gap={2} p={4} flexDir="column">
+      <Flex gap={2} p={4} flexDir="column" data-testid="control-accordion">
         <ButtonGroup size="sm" w="full" justifyContent="space-between" variant="ghost" isAttached={false}>
           <Button
             tooltip={t('controlnet.addControlNet')}
@@ -53,7 +53,7 @@ export const GenerationSettingsAccordion = memo(() => {
       isOpen={isOpenAccordion}
       onToggle={onToggleAccordion}
     >
-      <Box px={4} pt={4}>
+      <Box px={4} pt={4} data-testid="generation-accordion">
         <Flex gap={4} flexDir="column">
           <Flex gap={4} alignItems="center">
             <ParamMainModelSelect />
@@ -83,7 +83,7 @@ export const ImageSettingsAccordion = memo(() => {
       isOpen={isOpenAccordion}
       onToggle={onToggleAccordion}
     >
-      <Flex px={4} pt={4} w="full" h="full" flexDir="column">
+      <Flex px={4} pt={4} w="full" h="full" flexDir="column" data-testid="image-settings-accordion">
         {activeTabName === 'unifiedCanvas' ? <ImageSizeCanvas /> : <ImageSizeLinear />}
         <Expander label={t('accordions.advanced.options')} isOpen={isOpenExpander} onToggle={onToggleExpander}>
           <Flex gap={4} pb={4} flexDir="column">
@@ -1,6 +1,7 @@
 import type { PayloadAction } from '@reduxjs/toolkit';
 import { createSlice } from '@reduxjs/toolkit';
 import type { PersistConfig, RootState } from 'app/store/store';
+import { workflowLoadRequested } from 'features/nodes/store/actions';
 import { initialImageChanged } from 'features/parameters/store/generationSlice';

 import type { InvokeTabName } from './tabMap';
@@ -45,6 +46,9 @@ export const uiSlice = createSlice({
     builder.addCase(initialImageChanged, (state) => {
       state.activeTab = 'img2img';
     });
+    builder.addCase(workflowLoadRequested, (state) => {
+      state.activeTab = 'nodes';
+    });
   },
 });

@@ -195,6 +195,7 @@ export const modelsApi = api.injectEndpoints({
           url: buildModelsUrl(`scan_folder?${folderQueryStr}`),
         };
       },
+      providesTags: [{ type: 'ModelScanFolderResults', id: LIST_TAG }],
     }),
     getHuggingFaceModels: build.query<GetHuggingFaceModelsResponse, string>({
       query: (hugging_face_repo) => {
@@ -29,6 +29,7 @@ const tagTypes = [
   'InvocationCacheStatus',
   'ModelConfig',
   'ModelInstalls',
+  'ModelScanFolderResults',
   'T2IAdapterModel',
   'MainModel',
   'VaeModel',
@@ -46,7 +46,7 @@ export type LoRAModelConfig = S['LoRADiffusersConfig'] | S['LoRALyCORISConfig'];
 // TODO(MM2): Can we rename this from Vae -> VAE
 export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
 export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
-export type IPAdapterModelConfig = S['IPAdapterConfig'];
+export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig'];
 export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
 type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
 type DiffusersModelConfig = S['MainDiffusersConfig'];
@@ -27,6 +27,7 @@ from invokeai.app.invocations.fields import (
     OutputField,
     UIComponent,
     UIType,
+    WithBoard,
     WithMetadata,
     WithWorkflow,
 )
@@ -105,6 +106,7 @@ __all__ = [
     "OutputField",
     "UIComponent",
     "UIType",
+    "WithBoard",
     "WithMetadata",
     "WithWorkflow",
     # invokeai.app.invocations.latent
@@ -1 +1 @@
-__version__ = "4.0.1"
+__version__ = "4.0.4"
@@ -87,9 +87,11 @@ def test_rename(
     key = mm2_installer.install_path(embedding_file)
     model_record = store.get_model(key)
     assert model_record.path.endswith("sd-1/embedding/test_embedding.safetensors")
-    store.update_model(key, ModelRecordChanges(name="new_name.safetensors", base=BaseModelType("sd-2")))
+    store.update_model(key, ModelRecordChanges(name="new model name", base=BaseModelType("sd-2")))
     new_model_record = mm2_installer.sync_model_path(key)
-    assert new_model_record.path.endswith("sd-2/embedding/new_name.safetensors")
+    # Renaming the model record shouldn't rename the file
+    assert new_model_record.name == "new model name"
+    assert new_model_record.path.endswith("sd-2/embedding/test_embedding.safetensors")


 @pytest.mark.parametrize(