Compare commits: v4.0.2...bugfix-for (1 commit)

Commit 0b238b1ece
@ -18,22 +18,6 @@ Note that any releases marked as _pre-release_ are in a beta state. You may expe

The Model Manager tab in the UI provides a few ways to install models, including using your already-downloaded models. You'll see a popup directing you there on first startup. For more information, see the [model install docs].

## Missing models after updating to v4

If you find some models are missing after updating to v4, it's likely they weren't correctly registered before the update and didn't get picked up in the migration.

You can use the `Scan Folder` tab in the Model Manager UI to fix this. The models will either be in the old, now-unused `autoimport` folder, or your `models` folder.

- Find and copy your install's old `autoimport` folder path, inside the main install folder.
- Go to the Model Manager and click `Scan Folder`.
- Paste the path and scan.
- IMPORTANT: Uncheck `Inplace install`.
- Click `Install All` to install all found models, or just install the models you want.

Next, find and copy your install's `models` folder path (this could be your custom models folder path, or the `models` folder inside the main install folder).

Follow the same steps to scan and import the missing models.

## Slow generation

- Check the [system requirements] to ensure that your system is capable of generating images.
@ -44,7 +44,7 @@ The installation process is simple, with a few prompts:

- Select the version to install. Unless you have a specific reason to install a specific version, select the default (the latest version).
- Select location for the install. Be sure you have enough space in this folder for the base application, as described in the [installation requirements].
- Select a GPU device.
- Select a GPU device. If you are unsure, you can let the installer figure it out.

!!! info "Slow Installation"
@ -6,7 +6,11 @@

## Introduction

InvokeAI is distributed as a python package on PyPI, installable with `pip`. There are a few things that are handled by the installer and launcher that you'll need to manage manually, described in this guide.
!!! tip "Conda"

As of InvokeAI v2.3.0 installation using the `conda` package manager is no longer being supported. It will likely still work, but we are not testing this installation method.

InvokeAI is distributed as a python package on PyPI, installable with `pip`. There are a few things that are handled by the installer that you'll need to manage manually, described in this guide.

### Requirements

@ -36,11 +40,11 @@ Before you start, go through the [installation requirements].

1. Enter the root (invokeai) directory and create a virtual Python environment within it named `.venv`.

!!! warning "Virtual Environment Location"
!!! info "Virtual Environment Location"

While you may create the virtual environment anywhere in the file system, we recommend that you create it within the root directory as shown here. This allows the application to automatically detect its data directories.

If you choose a different location for the venv, then you _must_ set the `INVOKEAI_ROOT` environment variable or specify the root directory using the `--root` CLI arg.
If you choose a different location for the venv, then you must set the `INVOKEAI_ROOT` environment variable or pass the directory using the `--root` CLI arg.

```terminal
cd $INVOKEAI_ROOT
@ -77,23 +81,31 @@ Before you start, go through the [installation requirements].
python3 -m pip install --upgrade pip
```

1. Install the InvokeAI Package. The base command is `pip install InvokeAI --use-pep517`, but you may need to change this depending on your system and the desired features.
1. Install the InvokeAI Package. The `--extra-index-url` option is used to select the correct `torch` backend:

- You may need to provide an [extra index URL]. Select your platform configuration using [this tool on the PyTorch website]. Copy the `--extra-index-url` string from this and append it to your install command.
=== "CUDA (NVidia)"

!!! example "Install with an extra index URL"
```bash
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
```

```bash
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
```
=== "ROCm (AMD)"

- If you have a CUDA GPU and want to install with `xformers`, you need to add an option to the package name. Note that `xformers` is not necessary. PyTorch includes an implementation of the SDP attention algorithm with the same performance.
```bash
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.6
```

!!! example "Install with `xformers`"
=== "CPU (Intel Macs & non-GPU systems)"

```bash
pip install "InvokeAI[xformers]" --use-pep517
```
```bash
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu
```

=== "MPS (Apple Silicon)"

```bash
pip install InvokeAI --use-pep517
```
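
After installing, you can optionally confirm that the expected `torch` backend is available. This quick check is not part of the official steps, just a convenient sanity test:

```bash
# Optional: print the installed torch version and whether CUDA is usable
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```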

1. Deactivate and reactivate your runtime directory so that the invokeai-specific commands become available in the environment:

@ -114,6 +126,37 @@ Before you start, go through the [installation requirements].

Run `invokeai-web` to start the UI. You must activate the virtual environment before running the app.

!!! warning
If the virtual environment you selected is NOT inside `INVOKEAI_ROOT`, then you must specify the path to the root directory by adding `--root_dir \path\to\invokeai`.

If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root_dir \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable.
!!! tip

You can permanently set the location of the runtime directory by setting the environment variable `INVOKEAI_ROOT` to the path of the directory. As mentioned previously, this is recommended if your virtual environment is located outside of your runtime directory.
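
For instance, a minimal launch sequence with the root set permanently might look like this. The paths below are illustrative and assume the default `.venv` location inside the install folder:

```bash
# Illustrative paths: adjust to your own install location
export INVOKEAI_ROOT=~/invokeai
source "$INVOKEAI_ROOT/.venv/bin/activate"
invokeai-web
```
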
## Unsupported Conda Install

Congratulations, you found the "secret" Conda installation instructions. If you really **really** want to use Conda with InvokeAI, you can do so using this unsupported recipe:

```sh
mkdir ~/invokeai
conda create -n invokeai python=3.11
conda activate invokeai
# Adjust this as described above for the appropriate torch backend
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
invokeai-web --root ~/invokeai
```

The `pip install` command shown in this recipe is for Linux/Windows systems with an NVIDIA GPU. See step (6) above for the command to use with other platforms/GPU combinations. If you don't wish to pass the `--root` argument to `invokeai` with each launch, you may set the environment variable `INVOKEAI_ROOT` to point to the installation directory.
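
For example, on a Linux machine with an AMD GPU, the install line in this recipe would instead use the ROCm index URL shown earlier and drop the `xformers` extra. This is a sketch, not an officially tested path:

```bash
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.6
```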

Note that if you run into problems with the Conda installation, the InvokeAI staff will **not** be able to help you out. Caveat Emptor!

[installation requirements]: INSTALL_REQUIREMENTS.md
@ -32,5 +32,5 @@ As described in the [frontend dev toolchain] docs, you can run the UI using a de
[Fork and clone]: https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/fork-a-repo
[InvokeAI repo]: https://github.com/invoke-ai/InvokeAI
[frontend dev toolchain]: ../contributing/frontend/OVERVIEW.md
[manual installation]: ./020_INSTALL_MANUAL.md
[manual installation]: installation/020_INSTALL_MANUAL.md
[editable install]: https://pip.pypa.io/en/latest/cli/pip_install/#cmdoption-e
@ -3,7 +3,6 @@
InvokeAI installer script
"""

import locale
import os
import platform
import re
@ -317,9 +316,7 @@ def upgrade_pip(venv_path: Path) -> str | None:
    python = str(venv_path.expanduser().resolve() / python)

    try:
        result = subprocess.check_output([python, "-m", "pip", "install", "--upgrade", "pip"]).decode(
            encoding=locale.getpreferredencoding()
        )
        result = subprocess.check_output([python, "-m", "pip", "install", "--upgrade", "pip"]).decode()
    except subprocess.CalledProcessError as e:
        print(e)
        result = None
@ -407,29 +404,22 @@ def get_torch_source() -> Tuple[str | None, str | None]:
|
||||
# device can be one of: "cuda", "rocm", "cpu", "cuda_and_dml, autodetect"
|
||||
device = select_gpu()
|
||||
|
||||
# The correct extra index URLs for torch are inconsistent, see https://pytorch.org/get-started/locally/#start-locally
|
||||
|
||||
url = None
|
||||
optional_modules: str | None = None
|
||||
optional_modules = "[onnx]"
|
||||
if OS == "Linux":
|
||||
if device.value == "rocm":
|
||||
url = "https://download.pytorch.org/whl/rocm5.6"
|
||||
elif device.value == "cpu":
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
elif device.value == "cuda":
|
||||
# CUDA uses the default PyPi index
|
||||
optional_modules = "[xformers,onnx-cuda]"
|
||||
|
||||
elif OS == "Windows":
|
||||
if device.value == "cuda":
|
||||
url = "https://download.pytorch.org/whl/cu121"
|
||||
optional_modules = "[xformers,onnx-cuda]"
|
||||
elif device.value == "cpu":
|
||||
# CPU uses the default PyPi index, no optional modules
|
||||
pass
|
||||
elif OS == "Darwin":
|
||||
# macOS uses the default PyPi index, no optional modules
|
||||
pass
|
||||
if device.value == "cuda_and_dml":
|
||||
url = "https://download.pytorch.org/whl/cu121"
|
||||
optional_modules = "[xformers,onnx-directml]"
|
||||
|
||||
# Fall back to defaults
|
||||
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
|
||||
|
||||
return (url, optional_modules)
|
||||
|
@ -207,8 +207,10 @@ def dest_path(dest: Optional[str | Path] = None) -> Path | None:
|
||||
|
||||
class GpuType(Enum):
|
||||
CUDA = "cuda"
|
||||
CUDA_AND_DML = "cuda_and_dml"
|
||||
ROCM = "rocm"
|
||||
CPU = "cpu"
|
||||
AUTODETECT = "autodetect"
|
||||
|
||||
|
||||
def select_gpu() -> GpuType:
|
||||
@ -224,6 +226,10 @@ def select_gpu() -> GpuType:
|
||||
"an [gold1 b]NVIDIA[/] GPU (using CUDA™)",
|
||||
GpuType.CUDA,
|
||||
)
|
||||
nvidia_with_dml = (
|
||||
"an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX) -- ALPHA",
|
||||
GpuType.CUDA_AND_DML,
|
||||
)
|
||||
amd = (
|
||||
"an [gold1 b]AMD[/] GPU (using ROCm™)",
|
||||
GpuType.ROCM,
|
||||
@ -232,19 +238,27 @@ def select_gpu() -> GpuType:
|
||||
"Do not install any GPU support, use CPU for generation (slow)",
|
||||
GpuType.CPU,
|
||||
)
|
||||
autodetect = (
|
||||
"I'm not sure what to choose",
|
||||
GpuType.AUTODETECT,
|
||||
)
|
||||
|
||||
options = []
|
||||
if OS == "Windows":
|
||||
options = [nvidia, cpu]
|
||||
options = [nvidia, nvidia_with_dml, cpu]
|
||||
if OS == "Linux":
|
||||
options = [nvidia, amd, cpu]
|
||||
elif OS == "Darwin":
|
||||
options = [cpu]
|
||||
# future CoreML?
|
||||
|
||||
if len(options) == 1:
|
||||
print(f'Your platform [gold1]{OS}-{ARCH}[/] only supports the "{options[0][1]}" driver. Proceeding with that.')
|
||||
return options[0][1]
|
||||
|
||||
# "I don't know" is always added the last option
|
||||
options.append(autodetect) # type: ignore
|
||||
|
||||
options = {str(i): opt for i, opt in enumerate(options, 1)}
|
||||
|
||||
console.rule(":space_invader: GPU (Graphics Card) selection :space_invader:")
|
||||
@ -278,6 +292,11 @@ def select_gpu() -> GpuType:
|
||||
),
|
||||
)
|
||||
|
||||
if options[choice][1] is GpuType.AUTODETECT:
|
||||
console.print(
|
||||
"No problem. We will install CUDA support first :crossed_fingers: If Invoke does not detect a GPU, please re-run the installer and select one of the other GPU types."
|
||||
)
|
||||
|
||||
return options[choice][1]
|
||||
|
||||
|
||||
|
@ -219,13 +219,28 @@ async def scan_for_models(
|
||||
non_core_model_paths = [p for p in found_model_paths if not p.is_relative_to(core_models_path)]
|
||||
|
||||
installed_models = ApiDependencies.invoker.services.model_manager.store.search_by_attr()
|
||||
resolved_installed_model_paths: list[str] = []
|
||||
installed_model_sources: list[str] = []
|
||||
|
||||
# This call lists all installed models.
|
||||
for model in installed_models:
|
||||
path = pathlib.Path(model.path)
|
||||
# If the model has a source, we need to add it to the list of installed sources.
|
||||
if model.source:
|
||||
installed_model_sources.append(model.source)
|
||||
# If the path is not absolute, that means it is in the app models directory, and we need to join it with
|
||||
# the models path before resolving.
|
||||
if not path.is_absolute():
|
||||
resolved_installed_model_paths.append(str(pathlib.Path(models_path, path).resolve()))
|
||||
continue
|
||||
resolved_installed_model_paths.append(str(path.resolve()))
|
||||
|
||||
scan_results: list[FoundModel] = []
|
||||
|
||||
# Check if the model is installed by comparing paths, appending to the scan result.
|
||||
# Check if the model is installed by comparing the resolved paths, appending to the scan result.
|
||||
for p in non_core_model_paths:
|
||||
path = str(p)
|
||||
is_installed = any(str(models_path / m.path) == path for m in installed_models)
|
||||
is_installed = path in resolved_installed_model_paths or path in installed_model_sources
|
||||
found_model = FoundModel(path=path, is_installed=is_installed)
|
||||
scan_results.append(found_model)
|
||||
except Exception as e:
|
||||
|
@ -3,7 +3,6 @@ Invoke-managed custom node loader. See README.md for more information.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import traceback
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
|
||||
@ -42,15 +41,11 @@ for d in Path(__file__).parent.iterdir():
|
||||
|
||||
logger.info(f"Loading node pack {module_name}")
|
||||
|
||||
try:
|
||||
module = module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
module = module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
loaded_count += 1
|
||||
except Exception:
|
||||
full_error = traceback.format_exc()
|
||||
logger.error(f"Failed to load node pack {module_name}:\n{full_error}")
|
||||
loaded_count += 1
|
||||
|
||||
del init, module_name
|
||||
|
||||
|
@ -1,22 +1,21 @@
|
||||
from builtins import float
|
||||
from typing import List, Literal, Union
|
||||
from typing import List, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
IPAdapterCheckpointConfig,
|
||||
IPAdapterInvokeAIConfig,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, IPAdapterConfig, ModelType
|
||||
|
||||
|
||||
class IPAdapterField(BaseModel):
|
||||
@ -49,15 +48,12 @@ class IPAdapterOutput(BaseInvocationOutput):
|
||||
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
|
||||
|
||||
|
||||
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
|
||||
|
||||
|
||||
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.2.2")
|
||||
class IPAdapterInvocation(BaseInvocation):
|
||||
"""Collects IP-Adapter info to pass to other nodes."""
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).", ui_order=1)
|
||||
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
|
||||
ip_adapter_model: ModelIdentifierField = InputField(
|
||||
description="The IP-Adapter model.",
|
||||
title="IP-Adapter Model",
|
||||
@ -65,11 +61,7 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
ui_order=-1,
|
||||
ui_type=UIType.IPAdapterModel,
|
||||
)
|
||||
clip_vision_model: Literal["auto", "ViT-H", "ViT-G"] = InputField(
|
||||
description="CLIP Vision model to use. Overrides model settings. Mandatory for checkpoint models.",
|
||||
default="auto",
|
||||
ui_order=2,
|
||||
)
|
||||
|
||||
weight: Union[float, List[float]] = InputField(
|
||||
default=1, description="The weight given to the IP-Adapter", title="Weight"
|
||||
)
|
||||
@ -94,21 +86,10 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
||||
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
||||
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
|
||||
assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))
|
||||
|
||||
if self.clip_vision_model == "auto":
|
||||
if isinstance(ip_adapter_info, IPAdapterInvokeAIConfig):
|
||||
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"You need to set the appropriate CLIP Vision model for checkpoint IP Adapter models."
|
||||
)
|
||||
else:
|
||||
image_encoder_model_name = CLIP_VISION_MODEL_MAP[self.clip_vision_model]
|
||||
|
||||
assert isinstance(ip_adapter_info, IPAdapterConfig)
|
||||
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||
image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
|
||||
|
||||
return IPAdapterOutput(
|
||||
ip_adapter=IPAdapterField(
|
||||
image=self.image,
|
||||
@ -121,25 +102,19 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def _get_image_encoder(self, context: InvocationContext, image_encoder_model_name: str) -> AnyModelConfig:
|
||||
image_encoder_models = context.models.search_by_attrs(
|
||||
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
|
||||
)
|
||||
|
||||
if not len(image_encoder_models) > 0:
|
||||
context.logger.warning(
|
||||
f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed. \
|
||||
Downloading and installing now. This may take a while."
|
||||
)
|
||||
|
||||
installer = context._services.model_manager.install
|
||||
job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
|
||||
installer.wait_for_job(job, timeout=600) # Wait for up to 10 minutes
|
||||
found = False
|
||||
while not found:
|
||||
image_encoder_models = context.models.search_by_attrs(
|
||||
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
|
||||
)
|
||||
|
||||
if len(image_encoder_models) == 0:
|
||||
context.logger.error("Error while fetching CLIP Vision Image Encoder")
|
||||
assert len(image_encoder_models) == 1
|
||||
|
||||
found = len(image_encoder_models) > 0
|
||||
if not found:
|
||||
context.logger.warning(
|
||||
f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed."
|
||||
)
|
||||
context.logger.warning("Downloading and installing now. This may take a while.")
|
||||
installer = context._services.model_manager.install
|
||||
job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
|
||||
installer.wait_for_job(job, timeout=600) # wait up to 10 minutes - then raise a TimeoutException
|
||||
assert len(image_encoder_models) == 1
|
||||
return image_encoder_models[0]
|
||||
|
@ -43,7 +43,11 @@ from invokeai.app.invocations.fields import (
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.ip_adapter import IPAdapterField
|
||||
from invokeai.app.invocations.primitives import DenoiseMaskOutput, ImageOutput, LatentsOutput
|
||||
from invokeai.app.invocations.primitives import (
|
||||
DenoiseMaskOutput,
|
||||
ImageOutput,
|
||||
LatentsOutput,
|
||||
)
|
||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
@ -64,7 +68,12 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
)
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.util.devices import choose_precision, choose_torch_device
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .model import ModelIdentifierField, UNetField, VAEField
|
||||
|
||||
|
@ -2,8 +2,16 @@ from typing import Any, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.controlnet_image_processors import (
|
||||
CONTROLNET_MODE_VALUES,
|
||||
CONTROLNET_RESIZE_VALUES,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
@ -35,7 +43,6 @@ class IPAdapterMetadataField(BaseModel):
|
||||
|
||||
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
||||
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
|
||||
clip_vision_model: Literal["ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
|
||||
weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
|
||||
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
|
||||
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
|
||||
|
@ -3,7 +3,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import locale
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
@ -318,10 +317,11 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
@staticmethod
|
||||
def find_root() -> Path:
|
||||
"""Choose the runtime root directory when not specified on command line or init file."""
|
||||
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
|
||||
if os.environ.get("INVOKEAI_ROOT"):
|
||||
root = Path(os.environ["INVOKEAI_ROOT"])
|
||||
elif venv := os.environ.get("VIRTUAL_ENV", None):
|
||||
root = Path(venv).parent.resolve()
|
||||
elif any((venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]):
|
||||
root = (venv.parent).resolve()
|
||||
else:
|
||||
root = Path("~/invokeai").expanduser().resolve()
|
||||
return root
|
||||
@ -373,16 +373,13 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
|
||||
if k == "conf_path":
|
||||
parsed_config_dict["legacy_models_yaml_path"] = v
|
||||
if k == "legacy_conf_dir":
|
||||
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
|
||||
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
|
||||
# If if the incoming config has the default value, skip
|
||||
continue
|
||||
elif Path(v).name == "stable-diffusion":
|
||||
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
|
||||
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
|
||||
else:
|
||||
# Else we do not attempt to migrate this setting
|
||||
# The old default for this was "configs/stable-diffusion". If if the incoming config has that as the value, we won't set it.
|
||||
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
|
||||
# Else we do not attempt to migrate this setting
|
||||
if v != "configs/stable-diffusion":
|
||||
parsed_config_dict["legacy_conf_dir"] = v
|
||||
elif Path(v).name == "stable-diffusion":
|
||||
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
|
||||
elif k in InvokeAIAppConfig.model_fields:
|
||||
# skip unknown fields
|
||||
parsed_config_dict[k] = v
|
||||
@ -402,7 +399,7 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||
An instance of `InvokeAIAppConfig` with the loaded and migrated settings.
|
||||
"""
|
||||
assert config_path.suffix == ".yaml"
|
||||
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
|
||||
with open(config_path) as file:
|
||||
loaded_config_dict = yaml.safe_load(file)
|
||||
|
||||
assert isinstance(loaded_config_dict, dict)
|
||||
|
@ -1,6 +1,5 @@
|
||||
"""Model installation class."""
|
||||
|
||||
import locale
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
@ -324,8 +323,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
legacy_models_yaml_path = Path(self._app_config.root_path, legacy_models_yaml_path)
|
||||
|
||||
if legacy_models_yaml_path.exists():
|
||||
with open(legacy_models_yaml_path, "rt", encoding=locale.getpreferredencoding()) as file:
|
||||
legacy_models_yaml = yaml.safe_load(file)
|
||||
legacy_models_yaml = yaml.safe_load(legacy_models_yaml_path.read_text())
|
||||
|
||||
yaml_metadata = legacy_models_yaml.pop("__metadata__")
|
||||
yaml_version = yaml_metadata.get("version")
|
||||
@ -566,7 +564,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# The model is not in the models directory - we don't need to move it.
|
||||
return model
|
||||
|
||||
new_path = models_dir / model.base.value / model.type.value / old_path.name
|
||||
new_path = (models_dir / model.base.value / model.type.value / model.name).with_suffix(old_path.suffix)
|
||||
|
||||
if old_path == new_path or new_path.exists() and old_path == new_path.resolve():
|
||||
return model
|
||||
|
@ -11,7 +11,6 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@ -40,7 +39,6 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator.register_migration(build_migration_6())
|
||||
migrator.register_migration(build_migration_7())
|
||||
migrator.register_migration(build_migration_8(app_config=config))
|
||||
migrator.register_migration(build_migration_9())
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
@ -1,29 +0,0 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration9Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._empty_session_queue(cursor)
|
||||
|
||||
def _empty_session_queue(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Empties the session queue. This is done to prevent any lingering session queue items from causing pydantic errors due to changed schemas."""
|
||||
|
||||
cursor.execute("DELETE FROM session_queue;")
|
||||
|
||||
|
||||
def build_migration_9() -> Migration:
|
||||
"""
|
||||
Build the migration from database version 8 to 9.
|
||||
|
||||
This migration does the following:
|
||||
- Empties the session queue. This is done to prevent any lingering session queue items from causing pydantic errors due to changed schemas.
|
||||
"""
|
||||
migration_9 = Migration(
|
||||
from_version=8,
|
||||
to_version=9,
|
||||
callback=Migration9Callback(),
|
||||
)
|
||||
|
||||
return migration_9
|
@ -1,6 +1,4 @@
|
||||
import sqlite3
|
||||
from contextlib import closing
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@ -34,7 +32,6 @@ class SqliteMigrator:
|
||||
self._db = db
|
||||
self._logger = db.logger
|
||||
self._migration_set = MigrationSet()
|
||||
self._backup_path: Optional[Path] = None
|
||||
|
||||
def register_migration(self, migration: Migration) -> None:
|
||||
"""Registers a migration."""
|
||||
@ -58,18 +55,6 @@ class SqliteMigrator:
|
||||
return False
|
||||
|
||||
self._logger.info("Database update needed")
|
||||
|
||||
# Make a backup of the db if it needs to be updated and is a file db
|
||||
if self._db.db_path is not None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
|
||||
self._logger.info(f"Backing up database to {str(self._backup_path)}")
|
||||
# Use SQLite to do the backup
|
||||
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
|
||||
self._db.conn.backup(backup_conn)
|
||||
else:
|
||||
self._logger.info("Using in-memory database, no backup needed")
|
||||
|
||||
next_migration = self._migration_set.get(from_version=self._get_current_version(cursor))
|
||||
while next_migration is not None:
|
||||
self._run_migration(next_migration)
|
||||
|
@ -1,11 +1,8 @@
|
||||
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
|
||||
# and modified as needed
|
||||
|
||||
import pathlib
|
||||
from typing import List, Optional, TypedDict, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import safetensors
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
@ -16,17 +13,10 @@ from ..raw_model import RawModel
|
||||
from .resampler import Resampler
|
||||
|
||||
|
||||
class IPAdapterStateDict(TypedDict):
|
||||
ip_adapter: dict[str, torch.Tensor]
|
||||
image_proj: dict[str, torch.Tensor]
|
||||
|
||||
|
||||
class ImageProjModel(torch.nn.Module):
|
||||
"""Image Projection Model"""
|
||||
|
||||
def __init__(
|
||||
self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024, clip_extra_context_tokens: int = 4
|
||||
):
|
||||
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
|
||||
super().__init__()
|
||||
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
@ -35,7 +25,7 @@ class ImageProjModel(torch.nn.Module):
|
||||
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
||||
|
||||
@classmethod
|
||||
def from_state_dict(cls, state_dict: dict[str, torch.Tensor], clip_extra_context_tokens: int = 4):
|
||||
def from_state_dict(cls, state_dict: dict[torch.Tensor], clip_extra_context_tokens=4):
|
||||
"""Initialize an ImageProjModel from a state_dict.
|
||||
|
||||
The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
|
||||
@ -55,7 +45,7 @@ class ImageProjModel(torch.nn.Module):
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
def forward(self, image_embeds: torch.Tensor):
|
||||
def forward(self, image_embeds):
|
||||
embeds = image_embeds
|
||||
clip_extra_context_tokens = self.proj(embeds).reshape(
|
||||
-1, self.clip_extra_context_tokens, self.cross_attention_dim
|
||||
@ -67,7 +57,7 @@ class ImageProjModel(torch.nn.Module):
|
||||
class MLPProjModel(torch.nn.Module):
|
||||
"""SD model with image prompt"""
|
||||
|
||||
def __init__(self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024):
|
||||
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
|
||||
super().__init__()
|
||||
|
||||
self.proj = torch.nn.Sequential(
|
||||
@ -78,7 +68,7 @@ class MLPProjModel(torch.nn.Module):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
|
||||
def from_state_dict(cls, state_dict: dict[torch.Tensor]):
|
||||
"""Initialize an MLPProjModel from a state_dict.
|
||||
|
||||
The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
|
||||
@ -97,7 +87,7 @@ class MLPProjModel(torch.nn.Module):
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
def forward(self, image_embeds: torch.Tensor):
|
||||
def forward(self, image_embeds):
|
||||
clip_extra_context_tokens = self.proj(image_embeds)
|
||||
return clip_extra_context_tokens
|
||||
|
||||
@ -107,7 +97,7 @@ class IPAdapter(RawModel):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state_dict: IPAdapterStateDict,
|
||||
state_dict: dict[str, torch.Tensor],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
num_tokens: int = 4,
|
||||
@ -139,27 +129,24 @@ class IPAdapter(RawModel):
|
||||
|
||||
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
|
||||
|
||||
def _init_image_proj_model(
|
||||
self, state_dict: dict[str, torch.Tensor]
|
||||
) -> Union[ImageProjModel, Resampler, MLPProjModel]:
|
||||
def _init_image_proj_model(self, state_dict):
|
||||
return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_image_embeds(self, pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection):
|
||||
def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
|
||||
if isinstance(pil_image, Image.Image):
|
||||
pil_image = [pil_image]
|
||||
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
|
||||
clip_image_embeds = image_encoder(clip_image.to(self.device, dtype=self.dtype)).image_embeds
|
||||
try:
|
||||
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
|
||||
uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
|
||||
return image_prompt_embeds, uncond_image_prompt_embeds
|
||||
except RuntimeError as e:
|
||||
raise RuntimeError("Selected CLIP Vision Model is incompatible with the current IP Adapter") from e
|
||||
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
|
||||
uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
|
||||
return image_prompt_embeds, uncond_image_prompt_embeds
|
||||
|
||||
|
||||
class IPAdapterPlus(IPAdapter):
|
||||
"""IP-Adapter with fine-grained features"""
|
||||
|
||||
def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]) -> Union[Resampler, MLPProjModel]:
|
||||
def _init_image_proj_model(self, state_dict):
|
||||
return Resampler.from_state_dict(
|
||||
state_dict=state_dict,
|
||||
depth=4,
|
||||
@ -170,32 +157,31 @@ class IPAdapterPlus(IPAdapter):
|
||||
).to(self.device, dtype=self.dtype)
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_image_embeds(self, pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection):
|
||||
def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
|
||||
if isinstance(pil_image, Image.Image):
|
||||
pil_image = [pil_image]
|
||||
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
|
||||
clip_image = clip_image.to(self.device, dtype=self.dtype)
|
||||
clip_image_embeds = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
|
||||
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
|
||||
uncond_clip_image_embeds = image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[
|
||||
-2
|
||||
]
|
||||
try:
|
||||
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
|
||||
uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds)
|
||||
return image_prompt_embeds, uncond_image_prompt_embeds
|
||||
except RuntimeError as e:
|
||||
raise RuntimeError("Selected CLIP Vision Model is incompatible with the current IP Adapter") from e
|
||||
uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds)
|
||||
return image_prompt_embeds, uncond_image_prompt_embeds
|
||||
|
||||
|
||||
class IPAdapterFull(IPAdapterPlus):
|
||||
"""IP-Adapter Plus with full features."""
|
||||
|
||||
def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
|
||||
def _init_image_proj_model(self, state_dict: dict[torch.Tensor]):
|
||||
return MLPProjModel.from_state_dict(state_dict).to(self.device, dtype=self.dtype)
|
||||
|
||||
|
||||
class IPAdapterPlusXL(IPAdapterPlus):
|
||||
"""IP-Adapter Plus for SDXL."""
|
||||
|
||||
def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
|
||||
def _init_image_proj_model(self, state_dict):
|
||||
return Resampler.from_state_dict(
|
||||
state_dict=state_dict,
|
||||
depth=4,
|
||||
@ -206,48 +192,24 @@ class IPAdapterPlusXL(IPAdapterPlus):
|
||||
).to(self.device, dtype=self.dtype)
|
||||
|
||||
|
||||
def load_ip_adapter_tensors(ip_adapter_ckpt_path: pathlib.Path, device: str) -> IPAdapterStateDict:
|
||||
state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}
|
||||
|
||||
if ip_adapter_ckpt_path.suffix == ".safetensors":
|
||||
model = safetensors.torch.load_file(ip_adapter_ckpt_path, device=device)
|
||||
for key in model.keys():
|
||||
if key.startswith("image_proj."):
|
||||
state_dict["image_proj"][key.replace("image_proj.", "")] = model[key]
|
||||
elif key.startswith("ip_adapter."):
|
||||
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model[key]
|
||||
else:
|
||||
raise RuntimeError(f"Encountered unexpected IP Adapter state dict key: '{key}'.")
|
||||
else:
|
||||
ip_adapter_diffusers_checkpoint_path = ip_adapter_ckpt_path / "ip_adapter.bin"
|
||||
state_dict = torch.load(ip_adapter_diffusers_checkpoint_path, map_location="cpu")
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def build_ip_adapter(
|
||||
ip_adapter_ckpt_path: pathlib.Path, device: torch.device, dtype: torch.dtype = torch.float16
|
||||
) -> Union[IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterPlus]:
|
||||
state_dict = load_ip_adapter_tensors(ip_adapter_ckpt_path, device.type)
|
||||
ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
|
||||
) -> Union[IPAdapter, IPAdapterPlus]:
|
||||
state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
|
||||
|
||||
# IPAdapter (with ImageProjModel)
|
||||
if "proj.weight" in state_dict["image_proj"]:
|
||||
if "proj.weight" in state_dict["image_proj"]: # IPAdapter (with ImageProjModel).
|
||||
return IPAdapter(state_dict, device=device, dtype=dtype)
|
||||
|
||||
# IPAdaterPlus or IPAdapterPlusXL (with Resampler)
|
||||
elif "proj_in.weight" in state_dict["image_proj"]:
|
||||
elif "proj_in.weight" in state_dict["image_proj"]: # IPAdaterPlus or IPAdapterPlusXL (with Resampler).
|
||||
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
|
||||
if cross_attention_dim == 768:
|
||||
return IPAdapterPlus(state_dict, device=device, dtype=dtype) # SD1 IP-Adapter Plus
|
||||
# SD1 IP-Adapter Plus
|
||||
return IPAdapterPlus(state_dict, device=device, dtype=dtype)
|
||||
elif cross_attention_dim == 2048:
|
||||
return IPAdapterPlusXL(state_dict, device=device, dtype=dtype) # SDXL IP-Adapter Plus
|
||||
# SDXL IP-Adapter Plus
|
||||
return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)
|
||||
else:
|
||||
raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.")
|
||||
|
||||
# IPAdapterFull (with MLPProjModel)
|
||||
elif "proj.0.weight" in state_dict["image_proj"]:
|
||||
elif "proj.0.weight" in state_dict["image_proj"]: # IPAdapterFull (with MLPProjModel).
|
||||
return IPAdapterFull(state_dict, device=device, dtype=dtype)
|
||||
|
||||
# Unrecognized IP Adapter Architectures
|
||||
else:
|
||||
raise ValueError(f"'{ip_adapter_ckpt_path}' has an unrecognized IP-Adapter model architecture.")
|
||||
|
@ -9,8 +9,8 @@ import torch.nn as nn
|
||||
|
||||
|
||||
# FFN
|
||||
def FeedForward(dim: int, mult: int = 4):
|
||||
inner_dim = dim * mult
|
||||
def FeedForward(dim, mult=4):
|
||||
inner_dim = int(dim * mult)
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, inner_dim, bias=False),
|
||||
@ -19,8 +19,8 @@ def FeedForward(dim: int, mult: int = 4):
|
||||
)
|
||||
|
||||
|
||||
def reshape_tensor(x: torch.Tensor, heads: int):
|
||||
bs, length, _ = x.shape
|
||||
def reshape_tensor(x, heads):
|
||||
bs, length, width = x.shape
|
||||
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
||||
x = x.view(bs, length, heads, -1)
|
||||
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
||||
@ -31,7 +31,7 @@ def reshape_tensor(x: torch.Tensor, heads: int):
|
||||
|
||||
|
||||
class PerceiverAttention(nn.Module):
|
||||
def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8):
|
||||
def __init__(self, *, dim, dim_head=64, heads=8):
|
||||
super().__init__()
|
||||
self.scale = dim_head**-0.5
|
||||
self.dim_head = dim_head
|
||||
@ -45,7 +45,7 @@ class PerceiverAttention(nn.Module):
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor, latents: torch.Tensor):
|
||||
def forward(self, x, latents):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): image features
|
||||
@ -80,14 +80,14 @@ class PerceiverAttention(nn.Module):
|
||||
class Resampler(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 1024,
|
||||
depth: int = 8,
|
||||
dim_head: int = 64,
|
||||
heads: int = 16,
|
||||
num_queries: int = 8,
|
||||
embedding_dim: int = 768,
|
||||
output_dim: int = 1024,
|
||||
ff_mult: int = 4,
|
||||
dim=1024,
|
||||
depth=8,
|
||||
dim_head=64,
|
||||
heads=16,
|
||||
num_queries=8,
|
||||
embedding_dim=768,
|
||||
output_dim=1024,
|
||||
ff_mult=4,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -110,15 +110,7 @@ class Resampler(nn.Module):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_state_dict(
|
||||
cls,
|
||||
state_dict: dict[str, torch.Tensor],
|
||||
depth: int = 8,
|
||||
dim_head: int = 64,
|
||||
heads: int = 16,
|
||||
num_queries: int = 8,
|
||||
ff_mult: int = 4,
|
||||
):
|
||||
def from_state_dict(cls, state_dict: dict[torch.Tensor], depth=8, dim_head=64, heads=16, num_queries=8, ff_mult=4):
|
||||
"""A convenience function that initializes a Resampler from a state_dict.
|
||||
|
||||
Some of the shape parameters are inferred from the state_dict (e.g. dim, embedding_dim, etc.). At the time of
|
||||
@ -153,7 +145,7 @@ class Resampler(nn.Module):
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
def forward(self, x):
|
||||
latents = self.latents.repeat(x.size(0), 1, 1)
|
||||
|
||||
x = self.proj_in(x)
|
||||
|
@ -323,13 +323,10 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
|
||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class IPAdapterBaseConfig(ModelConfigBase):
|
||||
class IPAdapterConfig(ModelConfigBase):
|
||||
"""Model config for IP Adaptor format models."""
|
||||
|
||||
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
|
||||
|
||||
|
||||
class IPAdapterInvokeAIConfig(IPAdapterBaseConfig):
|
||||
"""Model config for IP Adapter diffusers format models."""
|
||||
|
||||
image_encoder_model_id: str
|
||||
format: Literal[ModelFormat.InvokeAI]
|
||||
|
||||
@ -338,16 +335,6 @@ class IPAdapterInvokeAIConfig(IPAdapterBaseConfig):
|
||||
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")
|
||||
|
||||
|
||||
class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
|
||||
"""Model config for IP Adapter checkpoint format models."""
|
||||
|
||||
format: Literal[ModelFormat.Checkpoint]
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
|
||||
class CLIPVisionDiffusersConfig(DiffusersConfigBase):
|
||||
"""Model config for CLIPVision."""
|
||||
|
||||
@ -403,8 +390,7 @@ AnyModelConfig = Annotated[
|
||||
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
|
||||
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
|
||||
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
|
||||
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
|
||||
Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
|
||||
Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
|
||||
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
|
||||
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
|
||||
],
|
||||
|
@ -429,8 +429,4 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
)
|
||||
free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device))
|
||||
if needed_size > free_mem:
|
||||
needed_gb = round(needed_size / GIG, 2)
|
||||
free_gb = round(free_mem / GIG, 2)
|
||||
raise torch.cuda.OutOfMemoryError(
|
||||
f"Insufficient VRAM to load model, requested {needed_gb}GB but only had {free_gb}GB free"
|
||||
)
|
||||
raise torch.cuda.OutOfMemoryError
|
||||
|
@ -7,13 +7,19 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.Checkpoint)
|
||||
class IPAdapterInvokeAILoader(ModelLoader):
|
||||
"""Class to load IP Adapter diffusers models."""
|
||||
|
||||
@ -26,7 +32,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
|
||||
raise ValueError("There are no submodels in an IP-Adapter model.")
|
||||
model_path = Path(config.path)
|
||||
model: RawModel = build_ip_adapter(
|
||||
ip_adapter_ckpt_path=model_path,
|
||||
ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
|
||||
device=torch.device("cpu"),
|
||||
dtype=self._torch_dtype,
|
||||
)
|
||||
|
@ -230,10 +230,9 @@ class ModelProbe(object):
|
||||
return ModelType.LoRA
|
||||
elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}):
|
||||
return ModelType.ControlNet
|
||||
elif any(key.startswith(v) for v in {"image_proj.", "ip_adapter."}):
|
||||
return ModelType.IPAdapter
|
||||
elif key in {"emb_params", "string_to_param"}:
|
||||
return ModelType.TextualInversion
|
||||
|
||||
else:
|
||||
# diffusers-ti
|
||||
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
||||
@ -528,25 +527,8 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
|
||||
|
||||
|
||||
class IPAdapterCheckpointProbe(CheckpointProbeBase):
|
||||
"""Class for probing IP Adapters"""
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
for key in checkpoint.keys():
|
||||
if not key.startswith(("image_proj.", "ip_adapter.")):
|
||||
continue
|
||||
cross_attention_dim = checkpoint["ip_adapter.1.to_k_ip.weight"].shape[-1]
|
||||
if cross_attention_dim == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif cross_attention_dim == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif cross_attention_dim == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelConfigException(
|
||||
f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}."
|
||||
)
|
||||
raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type")
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
|
||||
@ -786,7 +768,7 @@ class T2IAdapterFolderProbe(FolderProbeBase):
|
||||
)
|
||||
|
||||
|
||||
# Register probe classes
|
||||
############## register probe classes ######
|
||||
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
|
||||
|
@ -94,7 +94,6 @@
|
||||
"reactflow": "^11.10.4",
|
||||
"redux-dynamic-middlewares": "^2.2.0",
|
||||
"redux-remember": "^5.1.0",
|
||||
"rfdc": "^1.3.1",
|
||||
"roarr": "^7.21.1",
|
||||
"serialize-error": "^11.0.3",
|
||||
"socket.io-client": "^4.7.5",
|
||||
|
7
invokeai/frontend/web/pnpm-lock.yaml
generated
@ -137,9 +137,6 @@ dependencies:
|
||||
redux-remember:
|
||||
specifier: ^5.1.0
|
||||
version: 5.1.0(redux@5.0.1)
|
||||
rfdc:
|
||||
specifier: ^1.3.1
|
||||
version: 1.3.1
|
||||
roarr:
|
||||
specifier: ^7.21.1
|
||||
version: 7.21.1
|
||||
@ -12131,10 +12128,6 @@ packages:
|
||||
resolution: {integrity: sha512-/x8uIPdTafBqakK0TmPNJzgkLP+3H+yxpUJhCQHsLBg1rYEVNR2D8BRYNWQhVBjyOd7oo1dZRVzIkwMY2oqfYQ==}
|
||||
dev: true
|
||||
|
||||
/rfdc@1.3.1:
|
||||
resolution: {integrity: sha512-r5a3l5HzYlIC68TpmYKlxWjmOP6wiPJ1vWv2HeLhNsRZMrCkxeqxiHlQ21oXmQ4F3SiryXBHhAD7JZqvOJjFmg==}
|
||||
dev: false
|
||||
|
||||
/rimraf@2.6.3:
|
||||
resolution: {integrity: sha512-mwqeW5XsA2qAejG46gYdENaxXjx9onRNCfn7L0duuP4hCuTIi/QO7PDK07KJfp1d+izWPrzEJDcSqBa0OZQriA==}
|
||||
hasBin: true
|
||||
|
@ -4,7 +4,7 @@
|
||||
"reportBugLabel": "Fehler melden",
|
||||
"settingsLabel": "Einstellungen",
|
||||
"img2img": "Bild zu Bild",
|
||||
"nodes": "Arbeitsabläufe",
|
||||
"nodes": "Knoten Editor",
|
||||
"upload": "Hochladen",
|
||||
"load": "Laden",
|
||||
"statusDisconnected": "Getrennt",
|
||||
@ -74,8 +74,7 @@
|
||||
"updated": "Aktualisiert",
|
||||
"copy": "Kopieren",
|
||||
"aboutHeading": "Nutzen Sie Ihre kreative Energie",
|
||||
"toResolve": "Lösen",
|
||||
"add": "Hinzufügen"
|
||||
"toResolve": "Lösen"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Bildgröße",
|
||||
@ -105,16 +104,11 @@
|
||||
"dropToUpload": "$t(gallery.drop) zum hochladen",
|
||||
"dropOrUpload": "$t(gallery.drop) oder hochladen",
|
||||
"drop": "Ablegen",
|
||||
"problemDeletingImages": "Problem beim Löschen der Bilder",
|
||||
"bulkDownloadRequested": "Download vorbereiten",
|
||||
"bulkDownloadRequestedDesc": "Dein Download wird vorbereitet. Dies kann ein paar Momente dauern.",
|
||||
"bulkDownloadRequestFailed": "Problem beim Download vorbereiten",
|
||||
"bulkDownloadFailed": "Download fehlgeschlagen",
|
||||
"alwaysShowImageSizeBadge": "Zeige immer Bilder Größe Abzeichen"
|
||||
"problemDeletingImages": "Problem beim Löschen der Bilder"
|
||||
},
|
||||
"hotkeys": {
|
||||
"keyboardShortcuts": "Tastenkürzel",
|
||||
"appHotkeys": "App",
|
||||
"appHotkeys": "App-Tastenkombinationen",
|
||||
"generalHotkeys": "Allgemein",
|
||||
"galleryHotkeys": "Galerie",
|
||||
"unifiedCanvasHotkeys": "Leinwand",
|
||||
@ -763,9 +757,7 @@
|
||||
"scheduler": "Planer",
|
||||
"noRecallParameters": "Es wurden keine Parameter zum Abrufen gefunden",
|
||||
"recallParameters": "Parameter wiederherstellen",
|
||||
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
|
||||
"allPrompts": "Alle Prompts",
|
||||
"imageDimensions": "Bilder Auslösungen"
|
||||
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)"
|
||||
},
|
||||
"popovers": {
|
||||
"noiseUseCPU": {
|
||||
@ -1076,10 +1068,5 @@
|
||||
},
|
||||
"dynamicPrompts": {
|
||||
"showDynamicPrompts": "Dynamische Prompts anzeigen"
|
||||
},
|
||||
"prompt": {
|
||||
"noMatchingTriggers": "Keine passenden Auslöser",
|
||||
"addPromptTrigger": "Auslöse Text hinzufügen",
|
||||
"compatibleEmbeddings": "Kompatible Einbettungen"
|
||||
}
|
||||
}
|
||||
|
@ -217,7 +217,6 @@
|
||||
"saveControlImage": "Save Control Image",
|
||||
"scribble": "scribble",
|
||||
"selectModel": "Select a model",
|
||||
"selectCLIPVisionModel": "Select a CLIP Vision model",
|
||||
"setControlImageDimensions": "Set Control Image Dimensions To W/H",
|
||||
"showAdvanced": "Show Advanced",
|
||||
"small": "Small",
|
||||
@ -656,7 +655,6 @@
|
||||
"install": "Install",
|
||||
"installAll": "Install All",
|
||||
"installRepo": "Install Repo",
|
||||
"ipAdapters": "IP Adapters",
|
||||
"load": "Load",
|
||||
"localOnly": "local only",
|
||||
"manual": "Manual",
|
||||
|
@ -73,8 +73,7 @@
|
||||
"ai": "ia",
|
||||
"file": "File",
|
||||
"toResolve": "Da risolvere",
|
||||
"add": "Aggiungi",
|
||||
"loglevel": "Livello di log"
|
||||
"add": "Aggiungi"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Dimensione dell'immagine",
|
||||
@ -935,9 +934,7 @@
|
||||
"base": "Base",
|
||||
"lineart": "Linea",
|
||||
"controlnet": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.controlNet))",
|
||||
"mediapipeFace": "Mediapipe Volto",
|
||||
"ip_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.ipAdapter))",
|
||||
"t2i_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.t2iAdapter))"
|
||||
"mediapipeFace": "Mediapipe Volto"
|
||||
},
|
||||
"queue": {
|
||||
"queueFront": "Aggiungi all'inizio della coda",
|
||||
@ -1493,8 +1490,7 @@
|
||||
"title": "Generazione"
|
||||
},
|
||||
"advanced": {
|
||||
"title": "Avanzate",
|
||||
"options": "Opzioni $t(accordions.advanced.title)"
|
||||
"title": "Avanzate"
|
||||
},
|
||||
"image": {
|
||||
"title": "Immagine"
|
||||
|
@ -75,8 +75,7 @@
|
||||
"copy": "Копировать",
|
||||
"localSystem": "Локальная система",
|
||||
"aboutDesc": "Используя Invoke для работы? Проверьте это:",
|
||||
"add": "Добавить",
|
||||
"loglevel": "Уровень логов"
|
||||
"add": "Добавить"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Размер изображений",
|
||||
@ -1506,8 +1505,7 @@
|
||||
"title": "Генерация"
|
||||
},
|
||||
"advanced": {
|
||||
"title": "Расширенные",
|
||||
"options": "Опции $t(accordions.advanced.title)"
|
||||
"title": "Расширенные"
|
||||
},
|
||||
"image": {
|
||||
"title": "Изображение"
|
||||
|
@@ -1,7 +1,7 @@
import type { UnknownAction } from '@reduxjs/toolkit';
import { deepClone } from 'common/util/deepClone';
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { cloneDeep } from 'lodash-es';
import { appInfoApi } from 'services/api/endpoints/appInfo';
import type { Graph } from 'services/api/types';
import { socketGeneratorProgress } from 'services/events/actions';
@@ -33,7 +33,7 @@ export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
}

if (socketGeneratorProgress.match(action)) {
const sanitized = deepClone(action);
const sanitized = cloneDeep(action);
if (sanitized.payload.data.progress_image) {
sanitized.payload.data.progress_image.dataURL = '<Progress image omitted>';
}
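The hunk above swaps the project's `deepClone` helper for lodash's `cloneDeep` inside the devtools action sanitizer, whose job is to keep large base64 progress images out of logged actions. A minimal standalone sketch of that idea follows; the `ProgressAction` shape and the `sanitizeForLogging` name are hypothetical and not part of the codebase.

```ts
// Sketch: strip bulky base64 payloads before an action is logged, so devtools stay responsive.
import { cloneDeep } from 'lodash-es';

type ProgressAction = {
  type: string;
  payload: { data: { progress_image?: { dataURL: string } } };
};

export const sanitizeForLogging = (action: ProgressAction): ProgressAction => {
  const sanitized = cloneDeep(action);
  if (sanitized.payload.data.progress_image) {
    // The data URL can be hundreds of KB; replace it with a short marker.
    sanitized.payload.data.progress_image.dataURL = '<Progress image omitted>';
  }
  return sanitized;
};
```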
@@ -43,7 +43,6 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
})
);
dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }]));
dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }]));
},
});

@@ -1,5 +1,4 @@
import { deepClone } from 'common/util/deepClone';
import { merge } from 'lodash-es';
import { cloneDeep, merge } from 'lodash-es';
import { ClickScrollPlugin, OverlayScrollbars } from 'overlayscrollbars';
import type { UseOverlayScrollbarsParams } from 'overlayscrollbars-react';
@@ -23,7 +22,7 @@ export const getOverlayScrollbarsParams = (
overflowX: 'hidden' | 'scroll' = 'hidden',
overflowY: 'hidden' | 'scroll' = 'scroll'
) => {
const params = deepClone(overlayScrollbarsParams);
const params = cloneDeep(overlayScrollbarsParams);
merge(params, { options: { overflow: { y: overflowY, x: overflowX } } });
return params;
};

@@ -1,15 +0,0 @@
import rfdc from 'rfdc';
const _rfdc = rfdc();

/**
* Deep-clones an object using Really Fast Deep Clone.
* This is the fastest deep clone library on Chrome, but not the fastest on FF. Still, it's much faster than lodash
* and structuredClone, so it's the best all-around choice.
*
* Simple Benchmark: https://www.measurethat.net/Benchmarks/Show/30358/0/lodash-clonedeep-vs-jsonparsejsonstringify-vs-recursive
* Repo: https://github.com/davidmarkclements/rfdc
*
* @param obj The object to deep-clone
* @returns The cloned object
*/
export const deepClone = <T>(obj: T): T => _rfdc(obj);
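For reference, the file removed in the hunk above is a thin wrapper around rfdc. The sketch below restates the wrapper next to the lodash call it replaces; the usage at the bottom is illustrative only.

```ts
// A minimal sketch of the two clone helpers this diff swaps between.
// Assumes the 'rfdc' and 'lodash-es' packages are installed.
import { cloneDeep } from 'lodash-es';
import rfdc from 'rfdc';

const _rfdc = rfdc();

// rfdc-based wrapper, as in common/util/deepClone.ts
export const deepClone = <T>(obj: T): T => _rfdc(obj);

// Either call produces an independent copy; only performance differs.
const a = { nested: { value: 1 } };
const b = deepClone(a);
const c = cloneDeep(a);
b.nested.value = 2; // does not affect a or c
console.log(a.nested.value, c.nested.value); // 1 1
```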
@@ -1,7 +1,6 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple';
import calculateCoordinates from 'features/canvas/util/calculateCoordinates';
import calculateScale from 'features/canvas/util/calculateScale';
@@ -14,7 +13,7 @@ import { modelChanged } from 'features/parameters/store/generationSlice';
import type { PayloadActionWithOptimalDimension } from 'features/parameters/store/types';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import type { IRect, Vector2d } from 'konva/lib/types';
import { clamp } from 'lodash-es';
import { clamp, cloneDeep } from 'lodash-es';
import type { RgbaColor } from 'react-colorful';
import { queueApi } from 'services/api/endpoints/queue';
import type { ImageDTO } from 'services/api/types';
@@ -37,7 +36,7 @@ import { CANVAS_GRID_SIZE_FINE } from './constants';
/**
* The maximum history length to keep in the past/future layer states.
*/
const MAX_HISTORY = 100;
const MAX_HISTORY = 128;

const initialLayerState: CanvasLayerState = {
objects: [],
@@ -122,7 +121,7 @@ export const canvasSlice = createSlice({
state.brushSize = action.payload;
},
clearMask: (state) => {
pushToPrevLayerStates(state);
state.pastLayerStates.push(cloneDeep(state.layerState));
state.layerState.objects = state.layerState.objects.filter((obj) => !isCanvasMaskLine(obj));
state.futureLayerStates = [];
state.shouldPreserveMaskedArea = false;
@@ -164,10 +163,10 @@ export const canvasSlice = createSlice({
state.boundingBoxDimensions = newBoundingBoxDimensions;
state.boundingBoxCoordinates = newBoundingBoxCoordinates;

pushToPrevLayerStates(state);
state.pastLayerStates.push(cloneDeep(state.layerState));

state.layerState = {
...deepClone(initialLayerState),
...cloneDeep(initialLayerState),
objects: [
{
kind: 'image',
@@ -262,7 +261,11 @@ export const canvasSlice = createSlice({
return;
}

pushToPrevLayerStates(state);
state.pastLayerStates.push(cloneDeep(state.layerState));

if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}

state.layerState.stagingArea.images.push({
kind: 'image',
@@ -276,9 +279,13 @@ export const canvasSlice = createSlice({
state.futureLayerStates = [];
},
discardStagedImages: (state) => {
pushToPrevLayerStates(state);
state.pastLayerStates.push(cloneDeep(state.layerState));

state.layerState.stagingArea = deepClone(initialLayerState.stagingArea);
if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}

state.layerState.stagingArea = cloneDeep(cloneDeep(initialLayerState)).stagingArea;

state.futureLayerStates = [];
state.shouldShowStagingOutline = true;
@@ -287,7 +294,11 @@ export const canvasSlice = createSlice({
},
discardStagedImage: (state) => {
const { images, selectedImageIndex } = state.layerState.stagingArea;
pushToPrevLayerStates(state);
state.pastLayerStates.push(cloneDeep(state.layerState));

if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}

if (!images.length) {
return;
@@ -309,7 +320,11 @@ export const canvasSlice = createSlice({
addFillRect: (state) => {
const { boundingBoxCoordinates, boundingBoxDimensions, brushColor } = state;

pushToPrevLayerStates(state);
state.pastLayerStates.push(cloneDeep(state.layerState));

if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}

state.layerState.objects.push({
kind: 'fillRect',
@@ -324,7 +339,11 @@ export const canvasSlice = createSlice({
addEraseRect: (state) => {
const { boundingBoxCoordinates, boundingBoxDimensions } = state;

pushToPrevLayerStates(state);
state.pastLayerStates.push(cloneDeep(state.layerState));

if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}

state.layerState.objects.push({
kind: 'eraseRect',
@@ -348,7 +367,11 @@ export const canvasSlice = createSlice({
// set & then spread this to only conditionally add the "color" key
const newColor = layer === 'base' && tool === 'brush' ? { color: brushColor } : {};

pushToPrevLayerStates(state);
state.pastLayerStates.push(cloneDeep(state.layerState));

if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}

const newLine: CanvasMaskLine | CanvasBaseLine = {
kind: 'line',
@@ -386,7 +409,11 @@ export const canvasSlice = createSlice({
return;
}

pushToFutureLayerStates(state);
state.futureLayerStates.unshift(cloneDeep(state.layerState));

if (state.futureLayerStates.length > MAX_HISTORY) {
state.futureLayerStates.pop();
}

state.layerState = targetState;
},
@@ -397,7 +424,11 @@ export const canvasSlice = createSlice({
return;
}

pushToPrevLayerStates(state);
state.pastLayerStates.push(cloneDeep(state.layerState));

if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}

state.layerState = targetState;
},
@@ -414,8 +445,8 @@ export const canvasSlice = createSlice({
state.shouldShowIntermediates = action.payload;
},
resetCanvas: (state) => {
pushToPrevLayerStates(state);
state.layerState = deepClone(initialLayerState);
state.pastLayerStates.push(cloneDeep(state.layerState));
state.layerState = cloneDeep(initialLayerState);
state.futureLayerStates = [];
state.batchIds = [];
state.boundingBoxCoordinates = {
@@ -509,7 +540,11 @@ export const canvasSlice = createSlice({

const { images, selectedImageIndex } = state.layerState.stagingArea;

pushToPrevLayerStates(state);
state.pastLayerStates.push(cloneDeep(state.layerState));

if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}

const imageToCommit = images[selectedImageIndex];

@@ -518,7 +553,7 @@ export const canvasSlice = createSlice({
...imageToCommit,
});
}
state.layerState.stagingArea = deepClone(initialLayerState.stagingArea);
state.layerState.stagingArea = cloneDeep(initialLayerState).stagingArea;

state.futureLayerStates = [];
state.shouldShowStagingOutline = true;
@@ -588,7 +623,7 @@ export const canvasSlice = createSlice({
};
},
setMergedCanvas: (state, action: PayloadAction<CanvasImage>) => {
pushToPrevLayerStates(state);
state.pastLayerStates.push(cloneDeep(state.layerState));

state.futureLayerStates = [];

@@ -708,17 +743,3 @@ export const canvasPersistConfig: PersistConfig<CanvasState> = {
migrate: migrateCanvasState,
persistDenylist: [],
};

const pushToPrevLayerStates = (state: CanvasState) => {
state.pastLayerStates.push(deepClone(state.layerState));
if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates = state.pastLayerStates.slice(-MAX_HISTORY);
}
};

const pushToFutureLayerStates = (state: CanvasState) => {
state.futureLayerStates.unshift(deepClone(state.layerState));
if (state.futureLayerStates.length > MAX_HISTORY) {
state.futureLayerStates = state.futureLayerStates.slice(0, MAX_HISTORY);
}
};
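The two helpers at the end of the hunk above cap the undo/redo stacks at `MAX_HISTORY`. A self-contained sketch of the same pattern follows, using `cloneDeep` and hypothetical minimal types in place of `CanvasState`/`CanvasLayerState`.

```ts
// Sketch of the capped history helpers, under assumed minimal types.
import { cloneDeep } from 'lodash-es';

const MAX_HISTORY = 128;

type LayerState = { objects: unknown[] };
type HistoryState = {
  layerState: LayerState;
  pastLayerStates: LayerState[];
  futureLayerStates: LayerState[];
};

// Snapshot the current layer state for undo, trimming the oldest entries.
export const pushToPrevLayerStates = (state: HistoryState) => {
  state.pastLayerStates.push(cloneDeep(state.layerState));
  if (state.pastLayerStates.length > MAX_HISTORY) {
    state.pastLayerStates = state.pastLayerStates.slice(-MAX_HISTORY);
  }
};

// Snapshot the current layer state for redo, trimming overflow from the end.
export const pushToFutureLayerStates = (state: HistoryState) => {
  state.futureLayerStates.unshift(cloneDeep(state.layerState));
  if (state.futureLayerStates.length > MAX_HISTORY) {
    state.futureLayerStates = state.futureLayerStates.slice(0, MAX_HISTORY);
  }
};
```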
@@ -1,18 +1,12 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useControlAdapterCLIPVisionModel } from 'features/controlAdapters/hooks/useControlAdapterCLIPVisionModel';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
import {
controlAdapterCLIPVisionModelChanged,
controlAdapterModelChanged,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import type { CLIPVisionModel } from 'features/controlAdapters/store/types';
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -35,7 +29,6 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
const { modelConfig } = useControlAdapterModel(id);
const dispatch = useAppDispatch();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const currentCLIPVisionModel = useControlAdapterCLIPVisionModel(id);
const mainModel = useAppSelector(selectMainModel);
const { t } = useTranslation();

@@ -56,16 +49,6 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
[dispatch, id]
);

const onCLIPVisionModelChange = useCallback<ComboboxOnChange>(
(v) => {
if (!v?.value) {
return;
}
dispatch(controlAdapterCLIPVisionModelChanged({ id, clipVisionModel: v.value as CLIPVisionModel }));
},
[dispatch, id]
);

const selectedModel = useMemo(
() => (modelConfig && controlAdapterType ? { ...modelConfig, model_type: controlAdapterType } : null),
[controlAdapterType, modelConfig]
@@ -88,51 +71,18 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
isLoading,
});

const clipVisionOptions = useMemo<ComboboxOption[]>(
() => [
{ label: 'ViT-H', value: 'ViT-H' },
{ label: 'ViT-G', value: 'ViT-G' },
],
[]
);

const clipVisionModel = useMemo(
() => clipVisionOptions.find((o) => o.value === currentCLIPVisionModel),
[clipVisionOptions, currentCLIPVisionModel]
);

return (
<Flex sx={{ gap: 2 }}>
<Tooltip label={value?.description}>
<FormControl
isDisabled={!isEnabled}
isInvalid={!value || mainModel?.base !== modelConfig?.base}
sx={{ width: '100%' }}
>
<Combobox
options={options}
placeholder={t('controlnet.selectModel')}
value={value}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
{modelConfig?.type === 'ip_adapter' && modelConfig.format === 'checkpoint' && (
<FormControl
isDisabled={!isEnabled}
isInvalid={!value || mainModel?.base !== modelConfig?.base}
sx={{ width: 'max-content', minWidth: 28 }}
>
<Combobox
options={clipVisionOptions}
placeholder={t('controlnet.selectCLIPVisionModel')}
value={clipVisionModel}
onChange={onCLIPVisionModelChange}
/>
</FormControl>
)}
</Flex>
<Tooltip label={value?.description}>
<FormControl isDisabled={!isEnabled} isInvalid={!value || mainModel?.base !== modelConfig?.base}>
<Combobox
options={options}
placeholder={t('controlnet.selectModel')}
value={value}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
);
};

@@ -1,24 +0,0 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import {
selectControlAdapterById,
selectControlAdaptersSlice,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { useMemo } from 'react';

export const useControlAdapterCLIPVisionModel = (id: string) => {
const selector = useMemo(
() =>
createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => {
const cn = selectControlAdapterById(controlAdapters, id);
if (cn && cn?.type === 'ip_adapter') {
return cn.clipVisionModel;
}
}),
[id]
);

const clipVisionModel = useAppSelector(selector);

return clipVisionModel;
};

@@ -2,11 +2,10 @@ import type { PayloadAction, Update } from '@reduxjs/toolkit';
import { createEntityAdapter, createSlice, isAnyOf } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { merge, uniq } from 'lodash-es';
import { cloneDeep, merge, uniq } from 'lodash-es';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { socketInvocationError } from 'services/events/actions';
import { v4 as uuidv4 } from 'uuid';
@@ -14,7 +13,6 @@ import { v4 as uuidv4 } from 'uuid';
import { controlAdapterImageProcessed } from './actions';
import { CONTROLNET_PROCESSORS } from './constants';
import type {
CLIPVisionModel,
ControlAdapterConfig,
ControlAdapterProcessorType,
ControlAdaptersState,
@@ -116,7 +114,7 @@ export const controlAdaptersSlice = createSlice({
if (!controlAdapter) {
return;
}
const newControlAdapter = merge(deepClone(controlAdapter), {
const newControlAdapter = merge(cloneDeep(controlAdapter), {
id: newId,
isEnabled: true,
});
@@ -245,13 +243,6 @@ export const controlAdaptersSlice = createSlice({
}
caAdapter.updateOne(state, { id, changes: { controlMode } });
},
controlAdapterCLIPVisionModelChanged: (
state,
action: PayloadAction<{ id: string; clipVisionModel: CLIPVisionModel }>
) => {
const { id, clipVisionModel } = action.payload;
caAdapter.updateOne(state, { id, changes: { clipVisionModel } });
},
controlAdapterResizeModeChanged: (
state,
action: PayloadAction<{
@@ -279,7 +270,7 @@ export const controlAdaptersSlice = createSlice({
return;
}

const processorNode = merge(deepClone(cn.processorNode), params);
const processorNode = merge(cloneDeep(cn.processorNode), params);

caAdapter.updateOne(state, {
id,
@@ -302,7 +293,7 @@ export const controlAdaptersSlice = createSlice({
return;
}

const processorNode = deepClone(
const processorNode = cloneDeep(
CONTROLNET_PROCESSORS[processorType].buildDefaults(cn.model?.base)
) as RequiredControlAdapterProcessorNode;

@@ -342,7 +333,7 @@ export const controlAdaptersSlice = createSlice({
caAdapter.updateOne(state, update);
},
controlAdaptersReset: () => {
return deepClone(initialControlAdaptersState);
return cloneDeep(initialControlAdaptersState);
},
pendingControlImagesCleared: (state) => {
state.pendingControlImages = [];
@@ -389,7 +380,6 @@ export const {
controlAdapterProcessedImageChanged,
controlAdapterIsEnabledChanged,
controlAdapterModelChanged,
controlAdapterCLIPVisionModelChanged,
controlAdapterWeightChanged,
controlAdapterBeginStepPctChanged,
controlAdapterEndStepPctChanged,
@@ -416,7 +406,7 @@ const migrateControlAdaptersState = (state: any): any => {
state._version = 1;
}
if (state._version === 1) {
state = deepClone(initialControlAdaptersState);
state = cloneDeep(initialControlAdaptersState);
}
return state;
};
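The `migrateControlAdaptersState` function above shows the `_version`-based migration pattern these persisted slices share: stamp a version onto unversioned state, then reset to the initial state when patching an old shape would be too risky. A hedged sketch with a hypothetical `VersionedState` shape:

```ts
// Sketch of the _version-based persisted-state migration pattern; the state
// shape and initialState contents are assumptions for illustration only.
import { cloneDeep } from 'lodash-es';

type VersionedState = { _version?: number } & Record<string, unknown>;

const initialState: VersionedState = { _version: 2, adapters: [] };

export const migrateState = (state: VersionedState): VersionedState => {
  if (state._version === undefined) {
    // Stamp unversioned state so later checks are well-defined.
    state._version = 1;
  }
  if (state._version === 1) {
    // Shape changed between versions; resetting is safer than patching.
    state = cloneDeep(initialState);
  }
  return state;
};
```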
@@ -243,15 +243,12 @@ export type T2IAdapterConfig = {
shouldAutoConfig: boolean;
};

export type CLIPVisionModel = 'ViT-H' | 'ViT-G';

export type IPAdapterConfig = {
type: 'ip_adapter';
id: string;
isEnabled: boolean;
controlImage: string | null;
model: ParameterIPAdapterModel | null;
clipVisionModel: CLIPVisionModel;
weight: number;
beginStepPct: number;
endStepPct: number;

@@ -1,4 +1,3 @@
import { deepClone } from 'common/util/deepClone';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import type {
ControlAdapterConfig,
@@ -8,7 +7,7 @@ import type {
RequiredCannyImageProcessorInvocation,
T2IAdapterConfig,
} from 'features/controlAdapters/store/types';
import { merge } from 'lodash-es';
import { cloneDeep, merge } from 'lodash-es';

export const initialControlNet: Omit<ControlNetConfig, 'id'> = {
type: 'controlnet',
@@ -46,7 +45,6 @@ export const initialIPAdapter: Omit<IPAdapterConfig, 'id'> = {
isEnabled: true,
controlImage: null,
model: null,
clipVisionModel: 'ViT-H',
weight: 1,
beginStepPct: 0,
endStepPct: 1,
@@ -59,11 +57,11 @@ export const buildControlAdapter = (
): ControlAdapterConfig => {
switch (type) {
case 'controlnet':
return merge(deepClone(initialControlNet), { id, ...overrides });
return merge(cloneDeep(initialControlNet), { id, ...overrides });
case 't2i_adapter':
return merge(deepClone(initialT2IAdapter), { id, ...overrides });
return merge(cloneDeep(initialT2IAdapter), { id, ...overrides });
case 'ip_adapter':
return merge(deepClone(initialIPAdapter), { id, ...overrides });
return merge(cloneDeep(initialIPAdapter), { id, ...overrides });
default:
throw new Error(`Unknown control adapter type: ${type}`);
}

@@ -1,9 +1,9 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import { cloneDeep } from 'lodash-es';
import type { LoRAModelConfig } from 'services/api/types';

export type LoRA = {
@@ -58,7 +58,7 @@ export const loraSlice = createSlice({
}
lora.isEnabled = isEnabled;
},
lorasReset: () => deepClone(initialLoraState),
lorasReset: () => cloneDeep(initialLoraState),
},
});

@@ -74,7 +74,7 @@ const migrateLoRAState = (state: any): any => {
}
if (state._version === 1) {
// Model type has changed, so we need to reset the state - too risky to migrate
state = deepClone(initialLoraState);
state = cloneDeep(initialLoraState);
}
return state;
};

@@ -372,7 +372,6 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
type: 'ip_adapter',
isEnabled: true,
model: zModelIdentifierField.parse(ipAdapterModel),
clipVisionModel: 'ViT-H',
controlImage: image?.image_name ?? null,
weight: weight ?? initialIPAdapter.weight,
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,

@@ -87,10 +87,6 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
}, [installJob.source]);

const progressValue = useMemo(() => {
if (installJob.status === 'completed' || installJob.status === 'error' || installJob.status === 'cancelled') {
return 100;
}

if (isNil(installJob.bytes) || isNil(installJob.total_bytes)) {
return null;
}
@@ -100,7 +96,7 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
}

return (installJob.bytes / installJob.total_bytes) * 100;
}, [installJob.bytes, installJob.status, installJob.total_bytes]);
}, [installJob.bytes, installJob.total_bytes]);

return (
<Flex gap={3} w="full" alignItems="center">
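The `progressValue` memo above derives a percentage from `bytes`/`total_bytes` and falls back to `null` (an indeterminate bar) when either is missing. A small sketch of that calculation, with a hypothetical `InstallJob` shape:

```ts
// Sketch of the progress calculation; the InstallJob type is an assumption.
import { isNil } from 'lodash-es';

type InstallJob = { status: string; bytes?: number | null; total_bytes?: number | null };

// Returns a percentage for determinate progress, or null for an indeterminate bar.
export const getProgressValue = (job: InstallJob): number | null => {
  if (isNil(job.bytes) || isNil(job.total_bytes)) {
    return null;
  }
  if (job.total_bytes === 0) {
    return 0;
  }
  return (job.bytes / job.total_bytes) * 100;
};
```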
@@ -1,19 +1,48 @@
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
import type { ScanFolderResponse } from 'services/api/endpoints/models';
import { useInstallModelMutation } from 'services/api/endpoints/models';

type Props = {
result: ScanFolderResponse[number];
installModel: (source: string) => void;
};
export const ScanModelResultItem = ({ result, installModel }: Props) => {
export const ScanModelResultItem = ({ result }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();

const handleInstall = useCallback(() => {
installModel(result.path);
}, [installModel, result]);
const [installModel] = useInstallModelMutation();

const handleQuickAdd = useCallback(() => {
installModel({ source: result.path })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
}, [installModel, result, dispatch, t]);

return (
<Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>
@@ -25,7 +54,7 @@ export const ScanModelResultItem = ({ result, installModel }: Props) => {
{result.is_installed ? (
<Badge>{t('common.installed')}</Badge>
) : (
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={handleInstall} size="sm" />
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={handleQuickAdd} size="sm" />
)}
</Box>
</Flex>

@@ -1,10 +1,7 @@
import {
Button,
Checkbox,
Divider,
Flex,
FormControl,
FormLabel,
Heading,
IconButton,
Input,
@@ -15,7 +12,7 @@ import { useAppDispatch } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import type { ChangeEvent, ChangeEventHandler } from 'react';
import type { ChangeEventHandler } from 'react';
import { useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
@@ -31,7 +28,7 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
const { t } = useTranslation();
const [searchTerm, setSearchTerm] = useState('');
const dispatch = useAppDispatch();
const [inplace, setInplace] = useState(true);

const [installModel] = useInstallModelMutation();

const filteredResults = useMemo(() => {
@@ -45,10 +42,6 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
setSearchTerm(e.target.value.trim());
}, []);

const onChangeInplace = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setInplace(e.target.checked);
}, []);

const clearSearch = useCallback(() => {
setSearchTerm('');
}, []);
@@ -58,7 +51,7 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
if (result.is_installed) {
continue;
}
installModel({ source: result.path, inplace })
installModel({ source: result.path })
.unwrap()
.then((_) => {
dispatch(
@@ -83,37 +76,7 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
}
});
}
}, [filteredResults, installModel, inplace, dispatch, t]);

const handleInstallOne = useCallback(
(source: string) => {
installModel({ source, inplace })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
},
[installModel, inplace, dispatch, t]
);
}, [installModel, filteredResults, dispatch, t]);

return (
<>
@@ -122,10 +85,6 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
<Flex justifyContent="space-between" alignItems="center">
<Heading size="sm">{t('modelManager.scanResults')}</Heading>
<Flex alignItems="center" gap={3}>
<FormControl w="min-content">
<FormLabel m={0}>{t('modelManager.inplaceInstall')}</FormLabel>
<Checkbox isChecked={inplace} onChange={onChangeInplace} size="md" />
</FormControl>
<Button size="sm" onClick={handleAddAll} isDisabled={filteredResults.length === 0}>
{t('modelManager.installAll')}
</Button>
@@ -157,7 +116,7 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
<ScrollableContent>
<Flex flexDir="column" gap={3}>
{filteredResults.map((result) => (
<ScanModelResultItem key={result.path} result={result} installModel={handleInstallOne} />
<ScanModelResultItem key={result.path} result={result} />
))}
</Flex>
</ScrollableContent>
@@ -90,13 +90,11 @@ const ModelListItem = (props: ModelListItemProps) => {
cursor="pointer"
onClick={handleSelectModel}
>
<Flex gap={2} w="full" h="full" minW={0}>
<Flex gap={2} w="full" h="full">
<ModelImage image_url={model.cover_image} />
<Flex gap={1} alignItems="flex-start" flexDir="column" w="full" minW={0}>
<Flex gap={1} alignItems="flex-start" flexDir="column" w="full">
<Flex gap={2} w="full" alignItems="flex-start">
<Text fontWeight="semibold" noOfLines={1} wordBreak="break-all">
{model.name}
</Text>
<Text fontWeight="semibold">{model.name}</Text>
<Spacer />
</Flex>
<Text variant="subtext" noOfLines={1}>

@@ -87,9 +87,9 @@ export const Model = () => {
<Flex flexDir="column" gap={4}>
<Flex alignItems="flex-start" gap={4}>
<ModelImageUpload model_key={selectedModelKey} model_image={data.cover_image} />
<Flex flexDir="column" gap={1} flexGrow={1} minW={0}>
<Flex flexDir="column" gap={1} flexGrow={1}>
<Flex gap={2}>
<Heading as="h2" fontSize="lg" noOfLines={1} wordBreak="break-all">
<Heading as="h2" fontSize="lg">
{data.name}
</Heading>
<Spacer />
@@ -114,7 +114,7 @@ export const Model = () => {
)}
</Flex>
{data.source && (
<Text variant="subtext" noOfLines={1} wordBreak="break-all">
<Text variant="subtext">
{t('modelManager.source')}: {data?.source}
</Text>
)}

@@ -9,9 +9,7 @@ export const ModelAttrView = ({ label, value }: Props) => {
return (
<FormControl flexDir="column" alignItems="flex-start" gap={0}>
<FormLabel>{label}</FormLabel>
<Text fontSize="md" noOfLines={1} wordBreak="break-all">
{value || '-'}
</Text>
<Text fontSize="md">{value || '-'}</Text>
</FormControl>
);
};

@@ -53,7 +53,7 @@ export const ModelView = () => {
</>
)}

{data.type === 'ip_adapter' && data.format === 'invokeai' && (
{data.type === 'ip_adapter' && (
<Flex gap={2}>
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={data.image_encoder_model_id} />
</Flex>

@@ -1,7 +1,6 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { workflowLoaded } from 'features/nodes/store/actions';
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type {
@@ -45,7 +44,7 @@ import {
} from 'features/nodes/types/field';
import type { AnyNode, InvocationTemplate, NodeExecutionState } from 'features/nodes/types/invocation';
import { isInvocationNode, isNotesNode, zNodeStatus } from 'features/nodes/types/invocation';
import { forEach } from 'lodash-es';
import { cloneDeep, forEach } from 'lodash-es';
import type {
Connection,
Edge,
@@ -572,23 +571,8 @@ export const nodesSlice = createSlice({
);
},
selectionCopied: (state) => {
const nodesToCopy: AnyNode[] = [];
const edgesToCopy: Edge[] = [];

for (const node of state.nodes) {
if (node.selected) {
nodesToCopy.push(deepClone(node));
}
}

for (const edge of state.edges) {
if (edge.selected) {
edgesToCopy.push(deepClone(edge));
}
}

state.nodesToCopy = nodesToCopy;
state.edgesToCopy = edgesToCopy;
state.nodesToCopy = state.nodes.filter((n) => n.selected).map(cloneDeep);
state.edgesToCopy = state.edges.filter((e) => e.selected).map(cloneDeep);

if (state.nodesToCopy.length > 0) {
const averagePosition = { x: 0, y: 0 };
@@ -610,21 +594,11 @@ export const nodesSlice = createSlice({
},
selectionPasted: (state, action: PayloadAction<{ cursorPosition?: XYPosition }>) => {
const { cursorPosition } = action.payload;
const newNodes: AnyNode[] = [];

for (const node of state.nodesToCopy) {
newNodes.push(deepClone(node));
}

const newNodes = state.nodesToCopy.map(cloneDeep);
const oldNodeIds = newNodes.map((n) => n.data.id);

const newEdges: Edge[] = [];

for (const edge of state.edgesToCopy) {
if (oldNodeIds.includes(edge.source) && oldNodeIds.includes(edge.target)) {
newEdges.push(deepClone(edge));
}
}
const newEdges = state.edgesToCopy
.filter((e) => oldNodeIds.includes(e.source) && oldNodeIds.includes(e.target))
.map(cloneDeep);

newEdges.forEach((e) => (e.selected = true));
@@ -1,7 +1,6 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { workflowLoaded } from 'features/nodes/store/actions';
import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged, nodesDeleted } from 'features/nodes/store/nodesSlice';
import type {
@@ -12,7 +11,7 @@ import type {
import type { FieldIdentifier } from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import type { WorkflowCategory, WorkflowV3 } from 'features/nodes/types/workflow';
import { isEqual, omit, uniqBy } from 'lodash-es';
import { cloneDeep, isEqual, omit, uniqBy } from 'lodash-es';

const blankWorkflow: Omit<WorkflowV3, 'nodes' | 'edges'> = {
name: '',
@@ -132,8 +131,8 @@ export const workflowSlice = createSlice({
});

return {
...deepClone(initialWorkflowState),
...deepClone(workflowExtra),
...cloneDeep(initialWorkflowState),
...cloneDeep(workflowExtra),
originalExposedFieldValues,
mode: state.mode,
};
@@ -145,7 +144,7 @@ export const workflowSlice = createSlice({
});
});

builder.addCase(nodeEditorReset, () => deepClone(initialWorkflowState));
builder.addCase(nodeEditorReset, () => cloneDeep(initialWorkflowState));

builder.addCase(nodesChanged, (state, action) => {
// Not all changes to nodes should result in the workflow being marked touched

@@ -48,7 +48,7 @@ export const addIPAdapterToLinearGraph = async (
if (!ipAdapter.model) {
return;
}
const { id, weight, model, clipVisionModel, beginStepPct, endStepPct, controlImage } = ipAdapter;
const { id, weight, model, beginStepPct, endStepPct, controlImage } = ipAdapter;

assert(controlImage, 'IP Adapter image is required');

@@ -58,7 +58,6 @@ export const addIPAdapterToLinearGraph = async (
is_intermediate: true,
weight: weight,
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
image: {
@@ -84,7 +83,7 @@ export const addIPAdapterToLinearGraph = async (
};

const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => {
const { controlImage, beginStepPct, endStepPct, model, clipVisionModel, weight } = ipAdapter;
const { controlImage, beginStepPct, endStepPct, model, weight } = ipAdapter;

assert(model, 'IP Adapter model is required');

@@ -100,7 +99,6 @@ const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadat

return {
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
weight,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,

@@ -1,9 +1,8 @@
import { deepClone } from 'common/util/deepClone';
import { satisfies } from 'compare-versions';
import { NodeUpdateError } from 'features/nodes/types/error';
import type { InvocationNode, InvocationTemplate } from 'features/nodes/types/invocation';
import { zParsedSemver } from 'features/nodes/types/semver';
import { defaultsDeep, keys, pick } from 'lodash-es';
import { cloneDeep, defaultsDeep, keys, pick } from 'lodash-es';

import { buildInvocationNode } from './buildInvocationNode';

@@ -51,7 +50,7 @@ export const updateNode = (node: InvocationNode, template: InvocationTemplate):
// The updateability of a node, via semver comparison, relies on the this kind of recursive merge
// being valid. We rely on the template's major version to be majorly incremented if this kind of
// merge would result in an invalid node.
const clone = deepClone(node);
const clone = cloneDeep(node);
clone.data.version = template.version;
defaultsDeep(clone, defaults); // mutates!

@@ -1,12 +1,11 @@
import { logger } from 'app/logging/logger';
import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import type { NodesState, WorkflowsState } from 'features/nodes/store/types';
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { zWorkflowV3 } from 'features/nodes/types/workflow';
import i18n from 'i18n';
import { pick } from 'lodash-es';
import { cloneDeep, pick } from 'lodash-es';
import { fromZodError } from 'zod-validation-error';

export type BuildWorkflowArg = {
@@ -31,7 +30,7 @@ const workflowKeys = [
type BuildWorkflowFunction = (arg: BuildWorkflowArg) => WorkflowV3;

export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV3 => {
const clonedWorkflow = pick(deepClone(workflow), workflowKeys);
const clonedWorkflow = pick(cloneDeep(workflow), workflowKeys);

const newWorkflow: WorkflowV3 = {
...clonedWorkflow,
@@ -44,14 +43,14 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo
newWorkflow.nodes.push({
id: node.id,
type: node.type,
data: deepClone(node.data),
data: cloneDeep(node.data),
position: { ...node.position },
});
} else if (isNotesNode(node) && node.type) {
newWorkflow.nodes.push({
id: node.id,
type: node.type,
data: deepClone(node.data),
data: cloneDeep(node.data),
position: { ...node.position },
});
}

@@ -1,5 +1,4 @@
import { $store } from 'app/store/nanostores/store';
import { deepClone } from 'common/util/deepClone';
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
import type { FieldType } from 'features/nodes/types/field';
import type { InvocationNodeData } from 'features/nodes/types/invocation';
@@ -12,7 +11,7 @@ import { zWorkflowV2 } from 'features/nodes/types/v2/workflow';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { zWorkflowV3 } from 'features/nodes/types/workflow';
import { t } from 'i18next';
import { forEach } from 'lodash-es';
import { cloneDeep, forEach } from 'lodash-es';
import { z } from 'zod';

/**
@@ -90,7 +89,7 @@ export const parseAndMigrateWorkflow = (data: unknown): WorkflowV3 => {
throw new WorkflowVersionError(t('nodes.unableToGetWorkflowVersion'));
}

let workflow = deepClone(data) as WorkflowV1 | WorkflowV2 | WorkflowV3;
let workflow = cloneDeep(data) as WorkflowV1 | WorkflowV2 | WorkflowV3;

if (workflow.meta.version === '1.0.0') {
const v1 = zWorkflowV1.parse(workflow);

@@ -280,7 +280,6 @@ const migrateGenerationState = (state: any): GenerationState => {
// The signature of the model has changed, so we need to reset it
state._version = 2;
state.model = null;
state.canvasCoherenceMode = initialGenerationState.canvasCoherenceMode;
}
return state;
};

@@ -61,7 +61,7 @@ export const AdvancedSettingsAccordion = memo(() => {

return (
<StandaloneAccordion label={t('accordions.advanced.title')} badges={badges} isOpen={isOpen} onToggle={onToggle}>
<Flex gap={4} alignItems="center" p={4} flexDir="column" data-testid="advanced-settings-accordion">
<Flex gap={4} alignItems="center" p={4} flexDir="column">
<Flex gap={4} w="full">
<ParamVAEModelSelect />
<ParamVAEPrecision />

@@ -77,7 +77,7 @@ export const ControlSettingsAccordion: React.FC = memo(() => {

return (
<StandaloneAccordion label={t('accordions.control.title')} badges={badges} isOpen={isOpen} onToggle={onToggle}>
<Flex gap={2} p={4} flexDir="column" data-testid="control-accordion">
<Flex gap={2} p={4} flexDir="column">
<ButtonGroup size="sm" w="full" justifyContent="space-between" variant="ghost" isAttached={false}>
<Button
tooltip={t('controlnet.addControlNet')}

@@ -53,7 +53,7 @@ export const GenerationSettingsAccordion = memo(() => {
isOpen={isOpenAccordion}
onToggle={onToggleAccordion}
>
<Box px={4} pt={4} data-testid="generation-accordion">
<Box px={4} pt={4}>
<Flex gap={4} flexDir="column">
<Flex gap={4} alignItems="center">
<ParamMainModelSelect />

@@ -83,7 +83,7 @@ export const ImageSettingsAccordion = memo(() => {
isOpen={isOpenAccordion}
onToggle={onToggleAccordion}
>
<Flex px={4} pt={4} w="full" h="full" flexDir="column" data-testid="image-settings-accordion">
<Flex px={4} pt={4} w="full" h="full" flexDir="column">
{activeTabName === 'unifiedCanvas' ? <ImageSizeCanvas /> : <ImageSizeLinear />}
<Expander label={t('accordions.advanced.options')} isOpen={isOpenExpander} onToggle={onToggleExpander}>
<Flex gap={4} pb={4} flexDir="column">

@@ -195,7 +195,6 @@ export const modelsApi = api.injectEndpoints({
url: buildModelsUrl(`scan_folder?${folderQueryStr}`),
};
},
providesTags: [{ type: 'ModelScanFolderResults', id: LIST_TAG }],
}),
getHuggingFaceModels: build.query<GetHuggingFaceModelsResponse, string>({
query: (hugging_face_repo) => {

@@ -192,7 +192,7 @@ export const queueApi = api.injectEndpoints({
{ batch_id: string }
>({
query: ({ batch_id }) => ({
url: buildQueueUrl(`b/${batch_id}/status`),
url: buildQueueUrl(`/b/${batch_id}/status`),
method: 'GET',
}),
providesTags: (result) => {

@@ -29,7 +29,6 @@ const tagTypes = [
'InvocationCacheStatus',
'ModelConfig',
'ModelInstalls',
'ModelScanFolderResults',
'T2IAdapterModel',
'MainModel',
'VaeModel',
File diff suppressed because one or more lines are too long
@@ -46,7 +46,7 @@ export type LoRAModelConfig = S['LoRADiffusersConfig'] | S['LoRALyCORISConfig'];
// TODO(MM2): Can we rename this from Vae -> VAE
export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig'];
export type IPAdapterModelConfig = S['IPAdapterConfig'];
export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
type DiffusersModelConfig = S['MainDiffusersConfig'];

@@ -1 +1 @@
__version__ = "4.0.2"
__version__ = "4.0.0rc6"

@@ -87,11 +87,9 @@ def test_rename(
key = mm2_installer.install_path(embedding_file)
model_record = store.get_model(key)
assert model_record.path.endswith("sd-1/embedding/test_embedding.safetensors")
store.update_model(key, ModelRecordChanges(name="new model name", base=BaseModelType("sd-2")))
store.update_model(key, ModelRecordChanges(name="new_name.safetensors", base=BaseModelType("sd-2")))
new_model_record = mm2_installer.sync_model_path(key)
# Renaming the model record shouldn't rename the file
assert new_model_record.name == "new model name"
assert new_model_record.path.endswith("sd-2/embedding/test_embedding.safetensors")
assert new_model_record.path.endswith("sd-2/embedding/new_name.safetensors")


@pytest.mark.parametrize(

@@ -98,32 +98,6 @@ def test_migrate_v3_config_from_file(tmp_path: Path, patch_rootdir: None):
assert not hasattr(config, "esrgan")


@pytest.mark.parametrize(
"legacy_conf_dir,expected_value,expected_is_set",
[
# not set, expected value is the default value
("configs/stable-diffusion", Path("configs"), False),
# not set, expected value is the default value
("configs\\stable-diffusion", Path("configs"), False),
# set, best-effort resolution of the path
("partial_custom_path/stable-diffusion", Path("partial_custom_path"), True),
# set, exact path
("full/custom/path", Path("full/custom/path"), True),
],
)
def test_migrate_v3_legacy_conf_dir_defaults(
tmp_path: Path, patch_rootdir: None, legacy_conf_dir: str, expected_value: Path, expected_is_set: bool
):
"""Test reading configuration from a file."""
config_content = f"InvokeAI:\n Paths:\n legacy_conf_dir: {legacy_conf_dir}"
temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(config_content)

config = load_and_migrate_config(temp_config_file)
assert config.legacy_conf_dir == expected_value
assert ("legacy_conf_dir" in config.model_fields_set) is expected_is_set


def test_migrate_v3_backup(tmp_path: Path, patch_rootdir: None):
"""Test the backup of the config file."""
temp_config_file = tmp_path / "temp_invokeai.yaml"

@@ -250,32 +250,6 @@ def test_migrator_runs_all_migrations_file(logger: Logger) -> None:
db.conn.close()


def test_migrator_backs_up_db(logger: Logger) -> None:
with TemporaryDirectory() as tempdir:
original_db_path = Path(tempdir) / "invokeai.db"
db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False)
# Write some data to the db to test for successful backup
temp_cursor = db.conn.cursor()
temp_cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.conn.commit()
# Set up the migrator
migrator = SqliteMigrator(db=db)
migrations = [Migration(from_version=i, to_version=i + 1, callback=create_migrate(i)) for i in range(0, 3)]
for migration in migrations:
migrator.register_migration(migration)
migrator.run_migrations()
# Must manually close else we get an error on Windows
db.conn.close()
assert original_db_path.exists()
# We should have a backup file when we migrated a file db
assert migrator._backup_path
# Check that the test table exists as a proxy for successful backup
with closing(sqlite3.connect(migrator._backup_path)) as backup_db_conn:
backup_db_cursor = backup_db_conn.cursor()
backup_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
assert backup_db_cursor.fetchone() is not None


def test_migrator_makes_no_changes_on_failed_migration(
migrator: SqliteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback
) -> None: