Merge branch 'main' into fix/detect-more-loras

commit f6522c8971
Author: Kent Keirsey (committed by GitHub)
Date: 2023-08-10 17:33:16 -04:00
22 changed files with 956 additions and 607 deletions

View File

@@ -161,7 +161,7 @@ the command `npm install -g yarn` if needed)
 _For Windows/Linux with an NVIDIA GPU:_
 ```terminal
-pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
+pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
 ```
 _For Linux with an AMD GPU:_
@@ -306,13 +306,30 @@ InvokeAI. The second will prepare the 2.3 directory for use with 3.0.
 You may now launch the WebUI in the usual way, by selecting option [1]
 from the launcher script
-#### Migration Caveats
+#### Migrating Images
 The migration script will migrate your invokeai settings and models,
 including textual inversion models, LoRAs and merges that you may have
 installed previously. However it does **not** migrate the generated
-images stored in your 2.3-format outputs directory. You will need to
-manually import selected images into the 3.0 gallery via drag-and-drop.
+images stored in your 2.3-format outputs directory. To do this, you
+need to run an additional step:
+1. From a working InvokeAI 3.0 root directory, start the launcher and
+   enter menu option [8] to open the "developer's console".
+2. At the developer's console command line, type the command:
+   ```bash
+   invokeai-import-images
+   ```
+3. This will lead you through the process of confirming the desired
+   source and destination for the imported images. The images will
+   appear in the gallery board of your choice, and contain the
+   original prompt, model name, and other parameters used to generate
+   the image.
+(Many kudos to **techjedi** for contributing this script.)
 ## Hardware Requirements

View File

@@ -264,7 +264,7 @@ experimental versions later.
 you can create several levels of subfolders and drop your models into
 whichever ones you want.
-- ***Autoimport FolderLICENSE***
+- ***LICENSE***
 At the bottom of the screen you will see a checkbox for accepting
 the CreativeML Responsible AI Licenses. You need to accept the license

@@ -471,7 +471,7 @@ Then type the following commands:
 === "NVIDIA System"
 ```bash
-pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu117
+pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu118
 pip install xformers
 ```

View File

@@ -148,7 +148,7 @@ manager, please follow these steps:
 === "CUDA (NVidia)"
 ```bash
-pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
+pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
 ```
 === "ROCm (AMD)"

@@ -312,7 +312,7 @@ installation protocol (important!)
 === "CUDA (NVidia)"
 ```bash
-pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
+pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
 ```
 === "ROCm (AMD)"

@@ -356,7 +356,7 @@ you can do so using this unsupported recipe:
 mkdir ~/invokeai
 conda create -n invokeai python=3.10
 conda activate invokeai
-pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
+pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
 invokeai-configure --root ~/invokeai
 invokeai --root ~/invokeai --web
 ```

View File

@@ -34,11 +34,11 @@ directly from NVIDIA. **Do not try to install Ubuntu's
 nvidia-cuda-toolkit package. It is out of date and will cause
 conflicts among the NVIDIA driver and binaries.**
-Go to [CUDA Toolkit 11.7
-Downloads](https://developer.nvidia.com/cuda-11-7-0-download-archive),
-and use the target selection wizard to choose your operating system,
-hardware platform, and preferred installation method (e.g. "local"
-versus "network").
+Go to [CUDA Toolkit
+Downloads](https://developer.nvidia.com/cuda-downloads), and use the
+target selection wizard to choose your operating system, hardware
+platform, and preferred installation method (e.g. "local" versus
+"network").
 This will provide you with a downloadable install file or, depending
 on your choices, a recipe for downloading and running a install shell

@@ -61,7 +61,7 @@ Runtime Site](https://developer.nvidia.com/nvidia-container-runtime)
 When installing torch and torchvision manually with `pip`, remember to provide
 the argument `--extra-index-url
-https://download.pytorch.org/whl/cu117` as described in the [Manual
+https://download.pytorch.org/whl/cu118` as described in the [Manual
 Installation Guide](020_INSTALL_MANUAL.md).
 ## :simple-amd: ROCm

View File

@@ -28,18 +28,21 @@ command line, then just be sure to activate it's virtual environment.
 Then run the following three commands:
 ```sh
-pip install xformers==0.0.16rc425
-pip install triton
+pip install xformers~=0.0.19
+pip install triton # WON'T WORK ON WINDOWS
 python -m xformers.info output
 ```
 The first command installs `xformers`, the second installs the
 `triton` training accelerator, and the third prints out the `xformers`
-installation status. If all goes well, you'll see a report like the
+installation status. On Windows, please omit the `triton` package,
+which is not available on that platform.
+If all goes well, you'll see a report like the
 following:
 ```sh
-xFormers 0.0.16rc425
+xFormers 0.0.20
 memory_efficient_attention.cutlassF: available
 memory_efficient_attention.cutlassB: available
 memory_efficient_attention.flshattF: available

@@ -48,22 +51,28 @@ memory_efficient_attention.smallkF: available
 memory_efficient_attention.smallkB: available
 memory_efficient_attention.tritonflashattF: available
 memory_efficient_attention.tritonflashattB: available
+indexing.scaled_index_addF: available
+indexing.scaled_index_addB: available
+indexing.index_select: available
+swiglu.dual_gemm_silu: available
+swiglu.gemm_fused_operand_sum: available
 swiglu.fused.p.cpp: available
 is_triton_available: True
 is_functorch_available: False
-pytorch.version: 1.13.1+cu117
+pytorch.version: 2.0.1+cu118
 pytorch.cuda: available
-gpu.compute_capability: 8.6
-gpu.name: NVIDIA RTX A2000 12GB
+gpu.compute_capability: 8.9
+gpu.name: NVIDIA GeForce RTX 4070
 build.info: available
-build.cuda_version: 1107
-build.python_version: 3.10.9
-build.torch_version: 1.13.1+cu117
+build.cuda_version: 1108
+build.python_version: 3.10.11
+build.torch_version: 2.0.1+cu118
 build.env.TORCH_CUDA_ARCH_LIST: 5.0+PTX 6.0 6.1 7.0 7.5 8.0 8.6
 build.env.XFORMERS_BUILD_TYPE: Release
 build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS: None
 build.env.NVCC_FLAGS: None
-build.env.XFORMERS_PACKAGE_FROM: wheel-v0.0.16rc425
+build.env.XFORMERS_PACKAGE_FROM: wheel-v0.0.20
+build.nvcc_version: 11.8.89
 source.privacy: open source
 ```

@@ -83,14 +92,14 @@ installed from source. These instructions were written for a system
 running Ubuntu 22.04, but other Linux distributions should be able to
 adapt this recipe.
-#### 1. Install CUDA Toolkit 11.7
+#### 1. Install CUDA Toolkit 11.8
 You will need the CUDA developer's toolkit in order to compile and
 install xFormers. **Do not try to install Ubuntu's nvidia-cuda-toolkit
 package.** It is out of date and will cause conflicts among the NVIDIA
 driver and binaries. Instead install the CUDA Toolkit package provided
-by NVIDIA itself. Go to [CUDA Toolkit 11.7
-Downloads](https://developer.nvidia.com/cuda-11-7-0-download-archive)
+by NVIDIA itself. Go to [CUDA Toolkit 11.8
+Downloads](https://developer.nvidia.com/cuda-11-8-0-download-archive)
 and use the target selection wizard to choose your platform and Linux
 distribution. Select an installer type of "runfile (local)" at the
 last step.

@@ -101,17 +110,17 @@ example, the install script recipe for Ubuntu 22.04 running on a
 x86_64 system is:
 ```
-wget https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run
-sudo sh cuda_11.7.0_515.43.04_linux.run
+wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
+sudo sh cuda_11.8.0_520.61.05_linux.run
 ```
 Rather than cut-and-paste this example, We recommend that you walk
 through the toolkit wizard in order to get the most up to date
 installer for your system.
-#### 2. Confirm/Install pyTorch 1.13 with CUDA 11.7 support
-If you are using InvokeAI 2.3 or higher, these will already be
+#### 2. Confirm/Install pyTorch 2.01 with CUDA 11.8 support
+If you are using InvokeAI 3.0.2 or higher, these will already be
 installed. If not, you can check whether you have the needed libraries
 using a quick command. Activate the invokeai virtual environment,
 either by entering the "developer's console", or manually with a

@@ -124,7 +133,7 @@ Then run the command:
 python -c 'exec("import torch\nprint(torch.__version__)")'
 ```
-If it prints __1.13.1+cu117__ you're good. If not, you can install the
+If it prints __1.13.1+cu118__ you're good. If not, you can install the
 most up to date libraries with this command:
 ```sh

View File

@@ -463,10 +463,10 @@ def get_torch_source() -> (Union[str, None], str):
         url = "https://download.pytorch.org/whl/cpu"
     if device == "cuda":
-        url = "https://download.pytorch.org/whl/cu117"
+        url = "https://download.pytorch.org/whl/cu118"
         optional_modules = "[xformers,onnx-cuda]"
     if device == "cuda_and_dml":
-        url = "https://download.pytorch.org/whl/cu117"
+        url = "https://download.pytorch.org/whl/cu118"
         optional_modules = "[xformers,onnx-directml]"
     # in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
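The hunk above selects the PyTorch wheel index by device type. A minimal standalone sketch of that selection logic, for readers skimming the diff (the `pick_torch_source` name and the `"[onnx]"` default are assumptions for illustration; only the URLs and extras strings come from the hunk):

```python
# Standalone sketch of the device-to-index-URL selection shown above.
from typing import Optional, Tuple


def pick_torch_source(device: str) -> Tuple[Optional[str], str]:
    url: Optional[str] = None
    optional_modules = "[onnx]"  # assumed fallback; other devices use plain PyPI
    if device == "cpu":
        url = "https://download.pytorch.org/whl/cpu"
    if device == "cuda":
        url = "https://download.pytorch.org/whl/cu118"
        optional_modules = "[xformers,onnx-cuda]"
    if device == "cuda_and_dml":
        url = "https://download.pytorch.org/whl/cu118"
        optional_modules = "[xformers,onnx-directml]"
    return url, optional_modules


print(pick_torch_source("cuda"))
# ('https://download.pytorch.org/whl/cu118', '[xformers,onnx-cuda]')
```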

View File

@@ -104,8 +104,12 @@ async def update_model(
         ):  # model manager moved model path during rename - don't overwrite it
             info.path = new_info.get("path")
+        # replace empty string values with None/null to avoid phenomenon of vae: ''
+        info_dict = info.dict()
+        info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
         ApiDependencies.invoker.services.model_manager.update_model(
-            model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info.dict()
+            model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info_dict
         )
         model_raw = ApiDependencies.invoker.services.model_manager.list_model(
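A quick illustration of what the added normalization does; note it maps every falsy value to None, not just empty strings (the sample dict below is invented):

```python
# Invented sample data; mirrors the comprehension added in the hunk above.
info_dict = {"vae": "", "path": "/models/sd-1.5", "description": None}
normalized = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
print(normalized)  # {'vae': None, 'path': '/models/sd-1.5', 'description': None}
# Caveat: other falsy values (0, False, []) would also become None.
```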

View File

@@ -1,26 +1,23 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
+from contextlib import contextmanager, ContextDecorator
 from functools import partial
 from typing import Literal, Optional, get_args
-import torch
 from pydantic import Field
 from invokeai.app.models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
 from invokeai.app.util.misc import SEED_MAX, get_random_seed
 from invokeai.backend.generator.inpaint import infill_methods
-from ...backend.generator import Inpaint, InvokeAIGenerator
-from ...backend.stable_diffusion import PipelineIntermediateState
-from ..util.step_callback import stable_diffusion_step_callback
 from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
-from .image import ImageOutput
-from ...backend.model_management.lora import ModelPatcher
-from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
-from .model import UNetField, VaeField
 from .compel import ConditioningField
-from contextlib import contextmanager, ExitStack, ContextDecorator
+from .image import ImageOutput
+from .model import UNetField, VaeField
+from ..util.step_callback import stable_diffusion_step_callback
+from ...backend.generator import Inpaint, InvokeAIGenerator
+from ...backend.model_management.lora import ModelPatcher
+from ...backend.stable_diffusion import PipelineIntermediateState
+from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
 SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
 INFILL_METHODS = Literal[tuple(infill_methods())]

@@ -193,8 +190,6 @@ class InpaintInvocation(BaseInvocation):
             safety_checker=None,
             feature_extractor=None,
             requires_safety_checker=False,
-            precision="float16" if dtype == torch.float16 else "float32",
-            execution_device=device,
         )
         yield OldModelInfo(

View File

@@ -501,7 +501,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
         image = context.services.images.get_pil_image(self.image.image_name)
         image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
-        image_arr = image_arr * (self.max - self.min) + self.max
+        image_arr = image_arr * (self.max - self.min) + self.min
         lerp_image = Image.fromarray(numpy.uint8(image_arr))
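This one-character change is a genuine bug fix: remapping a normalized array from [0, 1] to [min, max] requires adding the minimum, not the maximum. A quick numeric check (values are illustrative):

```python
# Linear remap of [0, 1] values to [lo, hi]; the old code added hi instead of lo.
import numpy

arr = numpy.array([0.0, 0.5, 1.0])
lo, hi = 64.0, 192.0
print(arr * (hi - lo) + lo)  # [ 64. 128. 192.]  -> correct range
print(arr * (hi - lo) + hi)  # [192. 256. 320.]  -> old bug: overflows uint8
```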

View File

@@ -5,15 +5,26 @@ from typing import List, Literal, Optional, Union
 import einops
 import torch
+from diffusers import ControlNetModel
 from diffusers.image_processor import VaeImageProcessor
+from diffusers.models.attention_processor import (
+    AttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
 from diffusers.schedulers import SchedulerMixin as Scheduler
 from pydantic import BaseModel, Field, validator
 from invokeai.app.invocations.metadata import CoreMetadata
+from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend.model_management.models import ModelType, SilenceWarnings
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
+from .compel import ConditioningField
+from .controlnet_image_processors import ControlField
+from .image import ImageOutput
+from .model import ModelInfo, UNetField, VaeField
+from ..models.image import ImageCategory, ImageField, ResourceOrigin
 from ...backend.model_management import ModelPatcher
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.stable_diffusion.diffusers_pipeline import (

@@ -24,23 +35,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
 )
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
-from ...backend.model_management import ModelPatcher
 from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
-from ..models.image import ImageCategory, ImageField, ResourceOrigin
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
-from .compel import ConditioningField
-from .controlnet_image_processors import ControlField
-from .image import ImageOutput
-from .model import ModelInfo, UNetField, VaeField
-from invokeai.app.util.controlnet_utils import prepare_control_image
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    XFormersAttnProcessor,
-)
 DEFAULT_PRECISION = choose_precision(choose_torch_device())

@@ -231,7 +226,6 @@ class TextToLatentsInvocation(BaseInvocation):
             safety_checker=None,
             feature_extractor=None,
             requires_safety_checker=False,
-            precision="float16" if unet.dtype == torch.float16 else "float32",
         )
     def prep_control_data(

View File

@@ -2,6 +2,7 @@ from typing import Literal, Optional, Union
 from pydantic import Field
+from ...version import __version__
 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,

@@ -23,6 +24,7 @@ class LoRAMetadataField(BaseModelExcludeNull):
 class CoreMetadata(BaseModelExcludeNull):
     """Core generation metadata for an image generated in InvokeAI."""
+    app_version: str = Field(default=__version__, description="The version of InvokeAI used to generate this image")
     generation_mode: str = Field(
         description="The generation mode that output this image",
     )

View File

@@ -1,25 +1,11 @@
 """
 invokeai.backend.generator.img2img descends from .generator
 """
-from typing import Optional
-import torch
-from accelerate.utils import set_seed
-from diffusers import logging
-from ..stable_diffusion import (
-    ConditioningData,
-    PostprocessingSettings,
-    StableDiffusionGeneratorPipeline,
-)
 from .base import Generator
 class Img2Img(Generator):
-    def __init__(self, model, precision):
-        super().__init__(model, precision)
-        self.init_latent = None  # by get_noise()
     def get_make_image(
         self,
         sampler,

@@ -42,51 +28,4 @@ class Img2Img(Generator):
         Returns a function returning an image derived from the prompt and the initial image
         Return value depends on the seed at the time you call it.
         """
-        self.perlin = perlin
-        # noinspection PyTypeChecker
-        pipeline: StableDiffusionGeneratorPipeline = self.model
-        pipeline.scheduler = sampler
-        uc, c, extra_conditioning_info = conditioning
-        conditioning_data = ConditioningData(
-            uc,
-            c,
-            cfg_scale,
-            extra_conditioning_info,
-            postprocessing_settings=PostprocessingSettings(
-                threshold=threshold,
-                warmup=warmup,
-                h_symmetry_time_pct=h_symmetry_time_pct,
-                v_symmetry_time_pct=v_symmetry_time_pct,
-            ),
-        ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
-        def make_image(x_T: torch.Tensor, seed: int):
-            # FIXME: use x_T for initial seeded noise
-            # We're not at the moment because the pipeline automatically resizes init_image if
-            # necessary, which the x_T input might not match.
-            # In the meantime, reset the seed prior to generating pipeline output so we at least get the same result.
-            logging.set_verbosity_error()  # quench safety check warnings
-            pipeline_output = pipeline.img2img_from_embeddings(
-                init_image,
-                strength,
-                steps,
-                conditioning_data,
-                noise_func=self.get_noise_like,
-                callback=step_callback,
-                seed=seed,
-            )
-            if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
-                attention_maps_callback(pipeline_output.attention_map_saver)
-            return pipeline.numpy_to_pil(pipeline_output.images)[0]
-        return make_image
-    def get_noise_like(self, like: torch.Tensor):
-        device = like.device
-        x = torch.randn_like(like, device=device)
-        if self.perlin > 0.0:
-            shape = like.shape
-            x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(shape[3], shape[2])
-        return x
+        raise NotImplementedError("replaced by invokeai.app.invocations.latent.LatentsToLatentsInvocation")

View File

@@ -377,3 +377,11 @@ class Inpaint(Img2Img):
         )
         return corrected_result
+    def get_noise_like(self, like: torch.Tensor):
+        device = like.device
+        x = torch.randn_like(like, device=device)
+        if self.perlin > 0.0:
+            shape = like.shape
+            x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(shape[3], shape[2])
+        return x

View File

@@ -526,7 +526,7 @@ class ModelManager(object):
         # Does the config explicitly override the submodel?
         if submodel_type is not None and hasattr(model_config, submodel_type):
             submodel_path = getattr(model_config, submodel_type)
-            if submodel_path is not None:
+            if submodel_path is not None and len(submodel_path) > 0:
                 model_path = getattr(model_config, submodel_type)
                 is_submodel_override = True
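This guard pairs with the empty-string normalization in the model-update route above: an empty string is not None, so the old check treated a blank `vae: ''` entry as a real submodel override. A small demonstration:

```python
# Why "is not None" alone is insufficient: '' passes it but is not a usable path.
for submodel_path in (None, "", "sdxl-vae.safetensors"):
    old_check = submodel_path is not None
    new_check = submodel_path is not None and len(submodel_path) > 0
    print(repr(submodel_path), old_check, new_check)
# None -> False False
# ''   -> True  False   (previously mis-detected as an override)
# 'sdxl-vae.safetensors' -> True True
```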

View File

@@ -4,25 +4,21 @@ import dataclasses
 import inspect
 import math
 import secrets
-from collections.abc import Sequence
 from dataclasses import dataclass, field
 from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
-from pydantic import Field
-import einops
 import PIL.Image
-import numpy as np
-from accelerate.utils import set_seed
+import einops
 import psutil
 import torch
 import torchvision.transforms as T
+from accelerate.utils import set_seed
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.models.controlnet import ControlNetModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
     StableDiffusionPipeline,
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
     StableDiffusionImg2ImgPipeline,
 )
@@ -31,21 +27,20 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
 )
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
-from diffusers.utils import PIL_INTERPOLATION
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.outputs import BaseOutput
+from pydantic import Field
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from typing_extensions import ParamSpec
 from invokeai.app.services.config import InvokeAIAppConfig
-from ..util import CPU_DEVICE, normalize_device
 from .diffusion import (
     AttentionMapSaver,
     InvokeAIDiffuserComponent,
     PostprocessingSettings,
 )
-from .offloading import FullyLoadedModelGroup, ModelGroup
+from ..util import normalize_device
 @dataclass
@@ -289,8 +284,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _model_group: ModelGroup
     ID_LENGTH = 8
     def __init__(
@@ -303,9 +296,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         safety_checker: Optional[StableDiffusionSafetyChecker],
         feature_extractor: Optional[CLIPFeatureExtractor],
         requires_safety_checker: bool = False,
-        precision: str = "float32",
         control_model: ControlNetModel = None,
-        execution_device: Optional[torch.device] = None,
     ):
         super().__init__(
             vae,
@@ -330,9 +321,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             # control_model=control_model,
         )
         self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
-        self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device)
-        self._model_group.install(*self._submodels)
         self.control_model = control_model
     def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
@@ -368,72 +356,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         else:
             self.disable_attention_slicing()
-    def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
-        # overridden method; types match the superclass.
-        if torch_device is None:
-            return self
-        self._model_group.set_device(torch.device(torch_device))
-        self._model_group.ready()
-    @property
-    def device(self) -> torch.device:
-        return self._model_group.execution_device
-    @property
-    def _submodels(self) -> Sequence[torch.nn.Module]:
-        module_names, _, _ = self.extract_init_dict(dict(self.config))
-        submodels = []
-        for name in module_names.keys():
-            if hasattr(self, name):
-                value = getattr(self, name)
-            else:
-                value = getattr(self.config, name)
-            if isinstance(value, torch.nn.Module):
-                submodels.append(value)
-        return submodels
-    def image_from_embeddings(
-        self,
-        latents: torch.Tensor,
-        num_inference_steps: int,
-        conditioning_data: ConditioningData,
-        *,
-        noise: torch.Tensor,
-        callback: Callable[[PipelineIntermediateState], None] = None,
-        run_id=None,
-    ) -> InvokeAIStableDiffusionPipelineOutput:
-        r"""
-        Function invoked when calling the pipeline for generation.
-        :param conditioning_data:
-        :param latents: Pre-generated un-noised latents, to be used as inputs for
-            image generation. Can be used to tweak the same generation with different prompts.
-        :param num_inference_steps: The number of denoising steps. More denoising steps usually lead to a higher quality
-            image at the expense of slower inference.
-        :param noise: Noise to add to the latents, sampled from a Gaussian distribution.
-        :param callback:
-        :param run_id:
-        """
-        result_latents, result_attention_map_saver = self.latents_from_embeddings(
-            latents,
-            num_inference_steps,
-            conditioning_data,
-            noise=noise,
-            run_id=run_id,
-            callback=callback,
-        )
-        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-        torch.cuda.empty_cache()
-        with torch.inference_mode():
-            image = self.decode_latents(result_latents)
-            output = InvokeAIStableDiffusionPipelineOutput(
-                images=image,
-                nsfw_content_detected=[],
-                attention_map_saver=result_attention_map_saver,
-            )
-            return self.check_for_safety(output, dtype=conditioning_data.dtype)
     def latents_from_embeddings(
         self,
         latents: torch.Tensor,
@@ -450,7 +372,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if self.scheduler.config.get("cpu_only", False):
             scheduler_device = torch.device("cpu")
         else:
-            scheduler_device = self._model_group.device_for(self.unet)
+            scheduler_device = self.unet.device
         if timesteps is None:
             self.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
@@ -504,7 +426,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             (batch_size,),
             timesteps[0],
             dtype=timesteps.dtype,
-            device=self._model_group.device_for(self.unet),
+            device=self.unet.device,
         )
         latents = self.scheduler.add_noise(latents, noise, batched_t)
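For readers unfamiliar with the `add_noise` call in this hunk: diffusers schedulers implement the standard forward diffusion step, blending clean latents with Gaussian noise according to the cumulative alpha at the chosen timestep. A standalone sketch of that formula (DDPM-style; not InvokeAI code, and some schedulers differ):

```python
# What scheduler.add_noise computes for DDPM-style schedulers:
# sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise
import torch

def add_noise(latents: torch.Tensor, noise: torch.Tensor,
              alphas_cumprod: torch.Tensor, t: int) -> torch.Tensor:
    a = alphas_cumprod[t]
    return a.sqrt() * latents + (1 - a).sqrt() * noise

alphas_cumprod = torch.linspace(0.9999, 0.0001, 1000)  # toy schedule
latents = torch.randn(1, 4, 64, 64)
noise = torch.randn_like(latents)
noisy = add_noise(latents, noise, alphas_cumprod, t=999)  # nearly pure noise
```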
@@ -700,79 +622,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             **kwargs,
         ).sample
-    def img2img_from_embeddings(
-        self,
-        init_image: Union[torch.FloatTensor, PIL.Image.Image],
-        strength: float,
-        num_inference_steps: int,
-        conditioning_data: ConditioningData,
-        *,
-        callback: Callable[[PipelineIntermediateState], None] = None,
-        run_id=None,
-        noise_func=None,
-        seed=None,
-    ) -> InvokeAIStableDiffusionPipelineOutput:
-        if isinstance(init_image, PIL.Image.Image):
-            init_image = image_resized_to_grid_as_tensor(init_image.convert("RGB"))
-        if init_image.dim() == 3:
-            init_image = einops.rearrange(init_image, "c h w -> 1 c h w")
-        # 6. Prepare latent variables
-        initial_latents = self.non_noised_latents_from_image(
-            init_image,
-            device=self._model_group.device_for(self.unet),
-            dtype=self.unet.dtype,
-        )
-        if seed is not None:
-            set_seed(seed)
-        noise = noise_func(initial_latents)
-        return self.img2img_from_latents_and_embeddings(
-            initial_latents,
-            num_inference_steps,
-            conditioning_data,
-            strength,
-            noise,
-            run_id,
-            callback,
-        )
-    def img2img_from_latents_and_embeddings(
-        self,
-        initial_latents,
-        num_inference_steps,
-        conditioning_data: ConditioningData,
-        strength,
-        noise: torch.Tensor,
-        run_id=None,
-        callback=None,
-    ) -> InvokeAIStableDiffusionPipelineOutput:
-        timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
-        result_latents, result_attention_maps = self.latents_from_embeddings(
-            latents=initial_latents
-            if strength < 1.0
-            else torch.zeros_like(initial_latents, device=initial_latents.device, dtype=initial_latents.dtype),
-            num_inference_steps=num_inference_steps,
-            conditioning_data=conditioning_data,
-            timesteps=timesteps,
-            noise=noise,
-            run_id=run_id,
-            callback=callback,
-        )
-        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-        torch.cuda.empty_cache()
-        with torch.inference_mode():
-            image = self.decode_latents(result_latents)
-            output = InvokeAIStableDiffusionPipelineOutput(
-                images=image,
-                nsfw_content_detected=[],
-                attention_map_saver=result_attention_maps,
-            )
-            return self.check_for_safety(output, dtype=conditioning_data.dtype)
     def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device=None) -> (torch.Tensor, int):
         img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
         assert img2img_pipeline.scheduler is self.scheduler
@@ -780,7 +629,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if self.scheduler.config.get("cpu_only", False):
             scheduler_device = torch.device("cpu")
         else:
-            scheduler_device = self._model_group.device_for(self.unet)
+            scheduler_device = self.unet.device
         img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
         timesteps, adjusted_steps = img2img_pipeline.get_timesteps(
@@ -806,7 +655,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         noise_func=None,
         seed=None,
     ) -> InvokeAIStableDiffusionPipelineOutput:
-        device = self._model_group.device_for(self.unet)
+        device = self.unet.device
         latents_dtype = self.unet.dtype
         if isinstance(init_image, PIL.Image.Image):
@@ -877,42 +726,17 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             nsfw_content_detected=[],
             attention_map_saver=result_attention_maps,
         )
-        return self.check_for_safety(output, dtype=conditioning_data.dtype)
+        return output
     def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype):
         init_image = init_image.to(device=device, dtype=dtype)
         with torch.inference_mode():
-            self._model_group.load(self.vae)
             init_latent_dist = self.vae.encode(init_image).latent_dist
             init_latents = init_latent_dist.sample().to(dtype=dtype)  # FIXME: uses torch.randn. make reproducible!
         init_latents = 0.18215 * init_latents
         return init_latents
-    def check_for_safety(self, output, dtype):
-        with torch.inference_mode():
-            screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype)
-        screened_attention_map_saver = None
-        if has_nsfw_concept is None or not has_nsfw_concept:
-            screened_attention_map_saver = output.attention_map_saver
-        return InvokeAIStableDiffusionPipelineOutput(
-            screened_images,
-            has_nsfw_concept,
-            # block the attention maps if NSFW content is detected
-            attention_map_saver=screened_attention_map_saver,
-        )
-    def run_safety_checker(self, image, device=None, dtype=None):
-        # overriding to use the model group for device info instead of requiring the caller to know.
-        if self.safety_checker is not None:
-            device = self._model_group.device_for(self.safety_checker)
-        return super().run_safety_checker(image, device, dtype)
-    def decode_latents(self, latents):
-        # Explicit call to get the vae loaded, since `decode` isn't the forward method.
-        self._model_group.load(self.vae)
-        return super().decode_latents(latents)
     def debug_latents(self, latents, msg):
         from invokeai.backend.image_util import debug_image

View File

@@ -1,253 +0,0 @@
from __future__ import annotations
import warnings
import weakref
from abc import ABCMeta, abstractmethod
from collections.abc import MutableMapping
from typing import Callable, Union
import torch
from accelerate.utils import send_to_device
from torch.utils.hooks import RemovableHandle
OFFLOAD_DEVICE = torch.device("cpu")
class _NoModel:
"""Symbol that indicates no model is loaded.
(We can't weakref.ref(None), so this was my best idea at the time to come up with something
type-checkable.)
"""
def __bool__(self):
return False
def to(self, device: torch.device):
pass
def __repr__(self):
return "<NO MODEL>"
NO_MODEL = _NoModel()
class ModelGroup(metaclass=ABCMeta):
"""
A group of models.
The use case I had in mind when writing this is the sub-models used by a DiffusionPipeline,
e.g. its text encoder, U-net, VAE, etc.
Those models are :py:class:`diffusers.ModelMixin`, but "model" is interchangeable with
:py:class:`torch.nn.Module` here.
"""
def __init__(self, execution_device: torch.device):
self.execution_device = execution_device
@abstractmethod
def install(self, *models: torch.nn.Module):
"""Add models to this group."""
pass
@abstractmethod
def uninstall(self, models: torch.nn.Module):
"""Remove models from this group."""
pass
@abstractmethod
def uninstall_all(self):
"""Remove all models from this group."""
@abstractmethod
def load(self, model: torch.nn.Module):
"""Load this model to the execution device."""
pass
@abstractmethod
def offload_current(self):
"""Offload the current model(s) from the execution device."""
pass
@abstractmethod
def ready(self):
"""Ready this group for use."""
pass
@abstractmethod
def set_device(self, device: torch.device):
"""Change which device models from this group will execute on."""
pass
@abstractmethod
def device_for(self, model) -> torch.device:
"""Get the device the given model will execute on.
The model should already be a member of this group.
"""
pass
@abstractmethod
def __contains__(self, model):
"""Check if the model is a member of this group."""
pass
def __repr__(self) -> str:
return f"<{self.__class__.__name__} object at {id(self):x}: " f"device={self.execution_device} >"
class LazilyLoadedModelGroup(ModelGroup):
"""
Only one model from this group is loaded on the GPU at a time.
Running the forward method of a model will displace the previously-loaded model,
offloading it to CPU.
If you call other methods on the model, e.g. ``model.encode(x)`` instead of ``model(x)``,
you will need to explicitly load it with :py:method:`.load(model)`.
This implementation relies on pytorch forward-pre-hooks, and it will copy forward arguments
to the appropriate execution device, as long as they are positional arguments and not keyword
arguments. (I didn't make the rules; that's the way the pytorch 1.13 API works for hooks.)
"""
_hooks: MutableMapping[torch.nn.Module, RemovableHandle]
_current_model_ref: Callable[[], Union[torch.nn.Module, _NoModel]]
def __init__(self, execution_device: torch.device):
super().__init__(execution_device)
self._hooks = weakref.WeakKeyDictionary()
self._current_model_ref = weakref.ref(NO_MODEL)
def install(self, *models: torch.nn.Module):
for model in models:
self._hooks[model] = model.register_forward_pre_hook(self._pre_hook)
def uninstall(self, *models: torch.nn.Module):
for model in models:
hook = self._hooks.pop(model)
hook.remove()
if self.is_current_model(model):
# no longer hooked by this object, so don't claim to manage it
self.clear_current_model()
def uninstall_all(self):
self.uninstall(*self._hooks.keys())
def _pre_hook(self, module: torch.nn.Module, forward_input):
self.load(module)
if len(forward_input) == 0:
warnings.warn(
f"Hook for {module.__class__.__name__} got no input. " f"Inputs must be positional, not keywords.",
stacklevel=3,
)
return send_to_device(forward_input, self.execution_device)
def load(self, module):
if not self.is_current_model(module):
self.offload_current()
self._load(module)
def offload_current(self):
module = self._current_model_ref()
if module is not NO_MODEL:
module.to(OFFLOAD_DEVICE)
self.clear_current_model()
def _load(self, module: torch.nn.Module) -> torch.nn.Module:
assert self.is_empty(), f"A model is already loaded: {self._current_model_ref()}"
module = module.to(self.execution_device)
self.set_current_model(module)
return module
def is_current_model(self, model: torch.nn.Module) -> bool:
"""Is the given model the one currently loaded on the execution device?"""
return self._current_model_ref() is model
def is_empty(self):
"""Are none of this group's models loaded on the execution device?"""
return self._current_model_ref() is NO_MODEL
def set_current_model(self, value):
self._current_model_ref = weakref.ref(value)
def clear_current_model(self):
self._current_model_ref = weakref.ref(NO_MODEL)
def set_device(self, device: torch.device):
if device == self.execution_device:
return
self.execution_device = device
current = self._current_model_ref()
if current is not NO_MODEL:
current.to(device)
def device_for(self, model):
if model not in self:
raise KeyError(f"This does not manage this model {type(model).__name__}", model)
return self.execution_device # this implementation only dispatches to one device
def ready(self):
pass # always ready to load on-demand
def __contains__(self, model):
return model in self._hooks
def __repr__(self) -> str:
return (
f"<{self.__class__.__name__} object at {id(self):x}: "
f"current_model={type(self._current_model_ref()).__name__} >"
)
class FullyLoadedModelGroup(ModelGroup):
"""
A group of models without any implicit loading or unloading.
:py:meth:`.ready` loads _all_ the models to the execution device at once.
"""
_models: weakref.WeakSet
def __init__(self, execution_device: torch.device):
super().__init__(execution_device)
self._models = weakref.WeakSet()
def install(self, *models: torch.nn.Module):
for model in models:
self._models.add(model)
model.to(self.execution_device)
def uninstall(self, *models: torch.nn.Module):
for model in models:
self._models.remove(model)
def uninstall_all(self):
self.uninstall(*self._models)
def load(self, model):
model.to(self.execution_device)
def offload_current(self):
for model in self._models:
model.to(OFFLOAD_DEVICE)
def ready(self):
for model in self._models:
self.load(model)
def set_device(self, device: torch.device):
self.execution_device = device
for model in self._models:
if model.device != OFFLOAD_DEVICE:
model.to(device)
def device_for(self, model):
if model not in self:
raise KeyError("This does not manage this model f{type(model).__name__}", model)
return self.execution_device # this implementation only dispatches to one device
def __contains__(self, model):
return model in self._models

View File

@@ -0,0 +1,795 @@
# Copyright (c) 2023 - The InvokeAI Team
# Primary Author: David Lovell (github @f412design, discord @techjedi)
# co-author, minor tweaks - Lincoln Stein
# pylint: disable=line-too-long
# pylint: disable=broad-exception-caught
"""Script to import images into the new database system for 3.0.0"""
import os
import datetime
import shutil
import locale
import sqlite3
import json
import glob
import re
import uuid
import yaml
import PIL
import PIL.ImageOps
import PIL.PngImagePlugin
from pathlib import Path
from prompt_toolkit import prompt
from prompt_toolkit.shortcuts import message_dialog
from prompt_toolkit.completion import PathCompleter
from prompt_toolkit.key_binding import KeyBindings
from invokeai.app.services.config import InvokeAIAppConfig
app_config = InvokeAIAppConfig.get_config()
bindings = KeyBindings()
@bindings.add("c-c")
def _(event):
raise KeyboardInterrupt
# release notes
# "Use All" with size dimensions not selectable in the UI will not load dimensions
class Config:
"""Configuration loader."""
def __init__(self):
pass
TIMESTAMP_STRING = datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
INVOKE_DIRNAME = "invokeai"
YAML_FILENAME = "invokeai.yaml"
DATABASE_FILENAME = "invokeai.db"
database_path = None
database_backup_dir = None
outputs_path = None
thumbnail_path = None
def find_and_load(self):
"""find the yaml config file and load"""
root = app_config.root_path
if not self.confirm_and_load(os.path.abspath(root)):
print("\r\nSpecify custom database and outputs paths:")
self.confirm_and_load_from_user()
self.database_backup_dir = os.path.join(os.path.dirname(self.database_path), "backup")
self.thumbnail_path = os.path.join(self.outputs_path, "thumbnails")
def confirm_and_load(self, invoke_root):
"""Validates a yaml path exists, confirms the user wants to use it and loads config."""
yaml_path = os.path.join(invoke_root, self.YAML_FILENAME)
if os.path.exists(yaml_path):
db_dir, outdir = self.load_paths_from_yaml(yaml_path)
if os.path.isabs(db_dir):
database_path = os.path.join(db_dir, self.DATABASE_FILENAME)
else:
database_path = os.path.join(invoke_root, db_dir, self.DATABASE_FILENAME)
if os.path.isabs(outdir):
outputs_path = os.path.join(outdir, "images")
else:
outputs_path = os.path.join(invoke_root, outdir, "images")
db_exists = os.path.exists(database_path)
outdir_exists = os.path.exists(outputs_path)
text = f"Found {self.YAML_FILENAME} file at {yaml_path}:"
text += f"\n Database : {database_path}"
text += f"\n Outputs : {outputs_path}"
text += "\n\nUse these paths for import (yes) or choose different ones (no) [Yn]: "
if db_exists and outdir_exists:
if (prompt(text).strip() or "Y").upper().startswith("Y"):
self.database_path = database_path
self.outputs_path = outputs_path
return True
else:
return False
else:
print(" Invalid: One or more paths in this config did not exist and cannot be used.")
else:
message_dialog(
title="Path not found",
text=f"Auto-discovery of configuration failed! Could not find ({yaml_path}), Custom paths can be specified.",
).run()
return False
def confirm_and_load_from_user(self):
default = ""
while True:
database_path = os.path.expanduser(
prompt(
"Database: Specify absolute path to the database to import into: ",
completer=PathCompleter(
expanduser=True, file_filter=lambda x: Path(x).is_dir() or x.endswith((".db"))
),
default=default,
)
)
if database_path.endswith(".db") and os.path.isabs(database_path) and os.path.exists(database_path):
break
default = database_path + "/" if Path(database_path).is_dir() else database_path
default = ""
while True:
outputs_path = os.path.expanduser(
prompt(
"Outputs: Specify absolute path to outputs/images directory to import into: ",
completer=PathCompleter(expanduser=True, only_directories=True),
default=default,
)
)
if outputs_path.endswith("images") and os.path.isabs(outputs_path) and os.path.exists(outputs_path):
break
default = outputs_path + "/" if Path(outputs_path).is_dir() else outputs_path
self.database_path = database_path
self.outputs_path = outputs_path
return
def load_paths_from_yaml(self, yaml_path):
"""Load an Invoke AI yaml file and get the database and outputs paths."""
try:
with open(yaml_path, "rt", encoding=locale.getpreferredencoding()) as file:
yamlinfo = yaml.safe_load(file)
db_dir = yamlinfo.get("InvokeAI", {}).get("Paths", {}).get("db_dir", None)
outdir = yamlinfo.get("InvokeAI", {}).get("Paths", {}).get("outdir", None)
return db_dir, outdir
except Exception:
print(f"Failed to load paths from yaml file! {yaml_path}!")
return None, None
class ImportStats:
"""DTO for tracking work progress."""
def __init__(self):
pass
time_start = datetime.datetime.utcnow()
count_source_files = 0
count_skipped_file_exists = 0
count_skipped_db_exists = 0
count_imported = 0
count_imported_by_version = {}
count_file_errors = 0
@staticmethod
def get_elapsed_time_string():
"""Get a friendly time string for the time elapsed since processing start."""
time_now = datetime.datetime.utcnow()
total_seconds = (time_now - ImportStats.time_start).total_seconds()
hours = int((total_seconds) / 3600)
minutes = int(((total_seconds) % 3600) / 60)
seconds = total_seconds % 60
out_str = f"{hours} hour(s) -" if hours > 0 else ""
out_str += f"{minutes} minute(s) -" if minutes > 0 else ""
out_str += f"{seconds:.2f} second(s)"
return out_str
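An editorial aside: a quick worked check of the hours/minutes/seconds arithmetic above, with an invented value:

```python
# Worked example of the time split above for 3725.5 elapsed seconds.
total_seconds = 3725.5
hours = int(total_seconds / 3600)           # 1
minutes = int((total_seconds % 3600) / 60)  # 2
seconds = total_seconds % 60                # 5.5
print(f"{hours} hour(s) -{minutes} minute(s) -{seconds:.2f} second(s)")
# -> 1 hour(s) -2 minute(s) -5.50 second(s)
```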
class InvokeAIMetadata:
"""DTO for core Invoke AI generation properties parsed from metadata."""
def __init__(self):
pass
def __str__(self):
formatted_str = f"{self.generation_mode}~{self.steps}~{self.cfg_scale}~{self.model_name}~{self.scheduler}~{self.seed}~{self.width}~{self.height}~{self.rand_device}~{self.strength}~{self.init_image}"
formatted_str += f"\r\npositive_prompt: {self.positive_prompt}"
formatted_str += f"\r\nnegative_prompt: {self.negative_prompt}"
return formatted_str
generation_mode = None
steps = None
cfg_scale = None
model_name = None
scheduler = None
seed = None
width = None
height = None
rand_device = None
strength = None
init_image = None
positive_prompt = None
negative_prompt = None
imported_app_version = None
def to_json(self):
"""Convert the active instance to json format."""
prop_dict = {}
prop_dict["generation_mode"] = self.generation_mode
# don't render prompt nodes if neither is set, to avoid the UI thinking it can set them
# if at least one exists, render them both, but use empty string instead of None if one of them is empty
# this allows the field that is empty to actually be cleared by the UI instead of leaving the previous value
if self.positive_prompt or self.negative_prompt:
prop_dict["positive_prompt"] = "" if self.positive_prompt is None else self.positive_prompt
prop_dict["negative_prompt"] = "" if self.negative_prompt is None else self.negative_prompt
prop_dict["width"] = self.width
prop_dict["height"] = self.height
# only render seed if it has a value to avoid ui thinking it can set this and then error
if self.seed:
prop_dict["seed"] = self.seed
prop_dict["rand_device"] = self.rand_device
prop_dict["cfg_scale"] = self.cfg_scale
prop_dict["steps"] = self.steps
prop_dict["scheduler"] = self.scheduler
prop_dict["clip_skip"] = 0
prop_dict["model"] = {}
prop_dict["model"]["model_name"] = self.model_name
prop_dict["model"]["base_model"] = None
prop_dict["controlnets"] = []
prop_dict["loras"] = []
prop_dict["vae"] = None
prop_dict["strength"] = self.strength
prop_dict["init_image"] = self.init_image
prop_dict["positive_style_prompt"] = None
prop_dict["negative_style_prompt"] = None
prop_dict["refiner_model"] = None
prop_dict["refiner_cfg_scale"] = None
prop_dict["refiner_steps"] = None
prop_dict["refiner_scheduler"] = None
prop_dict["refiner_aesthetic_store"] = None
prop_dict["refiner_start"] = None
prop_dict["imported_app_version"] = self.imported_app_version
return json.dumps(prop_dict)
class InvokeAIMetadataParser:
"""Parses strings with json data to find Invoke AI core metadata properties."""
def __init__(self):
pass
def parse_meta_tag_dream(self, dream_string):
"""Take as input an png metadata json node for the 'dream' field variant from prior to 1.15"""
props = InvokeAIMetadata()
props.imported_app_version = "pre1.15"
seed_match = re.search("-S\\s*(\\d+)", dream_string)
if seed_match is not None:
try:
props.seed = int(seed_match[1])
except ValueError:
props.seed = None
raw_prompt = re.sub("(-S\\s*\\d+)", "", dream_string)
else:
raw_prompt = dream_string
pos_prompt, neg_prompt = self.split_prompt(raw_prompt)
props.positive_prompt = pos_prompt
props.negative_prompt = neg_prompt
return props
def parse_meta_tag_sd_metadata(self, tag_value):
"""Take as input an png metadata json node for the 'sd-metadata' field variant from 1.15 through 2.3.5 post 2"""
props = InvokeAIMetadata()
props.imported_app_version = tag_value.get("app_version")
props.model_name = tag_value.get("model_weights")
img_node = tag_value.get("image")
if img_node is not None:
props.generation_mode = img_node.get("type")
props.width = img_node.get("width")
props.height = img_node.get("height")
props.seed = img_node.get("seed")
props.rand_device = "cuda" # hardcoded since all generations pre 3.0 used cuda random noise instead of cpu
props.cfg_scale = img_node.get("cfg_scale")
props.steps = img_node.get("steps")
props.scheduler = self.map_scheduler(img_node.get("sampler"))
props.strength = img_node.get("strength")
if props.strength is None:
props.strength = img_node.get("strength_steps") # try second name for this property
props.init_image = img_node.get("init_image_path")
if props.init_image is None: # try second name for this property
props.init_image = img_node.get("init_img")
# remove the path info from init_image so if we move the init image, it will be correctly relative in the new location
if props.init_image is not None:
props.init_image = os.path.basename(props.init_image)
raw_prompt = img_node.get("prompt")
if isinstance(raw_prompt, list):
raw_prompt = raw_prompt[0].get("prompt")
props.positive_prompt, props.negative_prompt = self.split_prompt(raw_prompt)
return props
def parse_meta_tag_invokeai(self, tag_value):
"""Take as input an png metadata json node for the 'invokeai' field variant from 3.0.0 beta 1 through 5"""
props = InvokeAIMetadata()
props.imported_app_version = "3.0.0 or later"
props.generation_mode = tag_value.get("type")
if props.generation_mode is not None:
props.generation_mode = props.generation_mode.replace("t2l", "txt2img").replace("l2l", "img2img")
props.width = tag_value.get("width")
props.height = tag_value.get("height")
props.seed = tag_value.get("seed")
props.cfg_scale = tag_value.get("cfg_scale")
props.steps = tag_value.get("steps")
props.scheduler = tag_value.get("scheduler")
props.strength = tag_value.get("strength")
props.positive_prompt = tag_value.get("positive_conditioning")
props.negative_prompt = tag_value.get("negative_conditioning")
return props
def map_scheduler(self, old_scheduler):
"""Convert the legacy sampler names to matching 3.0 schedulers"""
if old_scheduler is None:
return None
match (old_scheduler):
case "ddim":
return "ddim"
case "plms":
return "pnmd"
case "k_lms":
return "lms"
case "k_dpm_2":
return "kdpm_2"
case "k_dpm_2_a":
return "kdpm_2_a"
case "dpmpp_2":
return "dpmpp_2s"
case "k_dpmpp_2":
return "dpmpp_2m"
case "k_dpmpp_2_a":
                return None  # invalid; in 2.3.x, selecting this sampler just fell back to the last run, or to plms in a new session
case "k_euler":
return "euler"
case "k_euler_a":
return "euler_a"
case "k_heun":
return "heun"
return None
def split_prompt(self, raw_prompt: str):
"""Split the unified prompt strings by extracting all negative prompt blocks out into the negative prompt."""
if raw_prompt is None:
return "", ""
raw_prompt_search = raw_prompt.replace("\r", "").replace("\n", "")
matches = re.findall(r"\[(.+?)\]", raw_prompt_search)
if len(matches) > 0:
negative_prompt = ""
if len(matches) == 1:
negative_prompt = matches[0].strip().strip(",")
else:
for match in matches:
negative_prompt += f"({match.strip().strip(',')})"
positive_prompt = re.sub(r"(\[.+?\])", "", raw_prompt_search).strip()
else:
positive_prompt = raw_prompt_search.strip()
negative_prompt = ""
return positive_prompt, negative_prompt
class DatabaseMapper:
"""Class to abstract database functionality."""
def __init__(self, database_path, database_backup_dir):
self.database_path = database_path
self.database_backup_dir = database_backup_dir
self.connection = None
self.cursor = None
def connect(self):
"""Open connection to the database."""
self.connection = sqlite3.connect(self.database_path)
self.cursor = self.connection.cursor()
def get_board_names(self):
"""Get a list of the current board names from the database."""
sql_get_board_name = "SELECT board_name FROM boards"
self.cursor.execute(sql_get_board_name)
rows = self.cursor.fetchall()
return [row[0] for row in rows]
    def does_image_exist(self, image_name):
        """Check the database for an existing image with this name and return a boolean."""
        # use a bound parameter so quotes in the file name cannot break the statement
        sql_get_image_by_name = "SELECT image_name FROM images WHERE image_name=?"
        self.cursor.execute(sql_get_image_by_name, (image_name,))
        rows = self.cursor.fetchall()
        return len(rows) > 0
    def add_new_image_to_database(self, filename, width, height, metadata, modified_date_string):
        """Add an image to the database."""
        # bound parameters are essential here: the metadata json routinely contains
        # quotes and apostrophes from prompts, which would break an interpolated statement
        sql_add_image = """INSERT INTO images (image_name, image_origin, image_category, width, height, session_id, node_id, metadata, is_intermediate, created_at, updated_at)
                           VALUES (?, 'internal', 'general', ?, ?, null, null, ?, 0, ?, ?)"""
        self.cursor.execute(
            sql_add_image, (filename, width, height, metadata, modified_date_string, modified_date_string)
        )
        self.connection.commit()
    def get_board_id_with_create(self, board_name):
        """Get the board id for the supplied name, creating the board if it does not exist."""
        sql_find_board = "SELECT board_id FROM boards WHERE board_name=? COLLATE NOCASE"
        self.cursor.execute(sql_find_board, (board_name,))
        rows = self.cursor.fetchall()
        if len(rows) > 0:
            return rows[0][0]
        else:
            board_date_string = datetime.datetime.utcnow().date().isoformat()
            new_board_id = str(uuid.uuid4())
            sql_insert_board = "INSERT INTO boards (board_id, board_name, created_at, updated_at) VALUES (?, ?, ?, ?)"
            self.cursor.execute(sql_insert_board, (new_board_id, board_name, board_date_string, board_date_string))
            self.connection.commit()
            return new_board_id
    def add_image_to_board(self, filename, board_id):
        """Add an image mapping to a board."""
        add_datetime_str = datetime.datetime.utcnow().isoformat()
        sql_add_image_to_board = """INSERT INTO board_images (board_id, image_name, created_at, updated_at)
                                    VALUES (?, ?, ?, ?)"""
        self.cursor.execute(sql_add_image_to_board, (board_id, filename, add_datetime_str, add_datetime_str))
        self.connection.commit()
def disconnect(self):
"""Disconnect from the db, cleaning up connections and cursors."""
if self.cursor is not None:
self.cursor.close()
if self.connection is not None:
self.connection.close()
def backup(self, timestamp_string):
"""Take a backup of the database."""
if not os.path.exists(self.database_backup_dir):
print(f"Database backup directory {self.database_backup_dir} does not exist -> creating...", end="")
os.makedirs(self.database_backup_dir)
print("Done!")
database_backup_path = os.path.join(self.database_backup_dir, f"backup-{timestamp_string}-invokeai.db")
print(f"Making DB Backup at {database_backup_path}...", end="")
shutil.copy2(self.database_path, database_backup_path)
print("Done!")
class MediaImportProcessor:
"""Containing class for script functionality."""
    def __init__(self):
        # cache of board name -> board id, to avoid repeated database lookups
        self.board_name_id_map = {}
def get_import_file_list(self):
"""Ask the user for the import folder and scan for the list of files to return."""
while True:
default = ""
while True:
import_dir = os.path.expanduser(
prompt(
"Inputs: Specify absolute path containing InvokeAI .png images to import: ",
completer=PathCompleter(expanduser=True, only_directories=True),
default=default,
)
)
if len(import_dir) > 0 and Path(import_dir).is_dir():
break
default = import_dir
            is_recurse = not (
                (prompt("Include files from subfolders recursively [yN]? ").strip() or "N").upper().startswith("N")
            )
            if is_recurse:
                matching_file_list = glob.glob(import_dir + "/**/*.png", recursive=True)
            else:
                matching_file_list = glob.glob(import_dir + "/*.png", recursive=False)
if len(matching_file_list) > 0:
return import_dir, is_recurse, matching_file_list
else:
print(f"The specific path {import_dir} exists, but does not contain .png files!")
def get_file_details(self, filepath):
"""Retrieve the embedded metedata fields and dimensions from an image file."""
with PIL.Image.open(filepath) as img:
img.load()
png_width, png_height = img.size
img_info = img.info
return img_info, png_width, png_height
def select_board_option(self, board_names, timestamp_string):
"""Allow the user to choose how a board is selected for imported files."""
while True:
print("\r\nOptions for board selection for imported images:")
print(f"1) Select an existing board name. (found {len(board_names)})")
print("2) Specify a board name to create/add to.")
print("3) Create/add to board named 'IMPORT'.")
            print(
                f"4) Create/add to board named 'IMPORT' with the current datetime string appended (e.g. IMPORT_{timestamp_string})."
            )
            print(
                "5) Create/add to board named 'IMPORT' with the original file's app_version appended (e.g. IMPORT_2.2.5)."
            )
input_option = input("Specify desired board option: ")
match (input_option):
case "1":
if len(board_names) < 1:
print("\r\nThere are no existing board names to choose from. Select another option!")
continue
board_name = self.select_item_from_list(
board_names, "board name", True, "Cancel, go back and choose a different board option."
)
if board_name is not None:
return board_name
case "2":
while True:
board_name = input("Specify new/existing board name: ")
if board_name:
return board_name
case "3":
return "IMPORT"
case "4":
return f"IMPORT_{timestamp_string}"
case "5":
return "IMPORT_APPVERSION"
def select_item_from_list(self, items, entity_name, allow_cancel, cancel_string):
"""A general function to render a list of items to select in the console, prompt the user for a selection and ensure a valid entry is selected."""
print(f"Select a {entity_name.lower()} from the following list:")
index = 1
for item in items:
print(f"{index}) {item}")
index += 1
if allow_cancel:
print(f"{index}) {cancel_string}")
while True:
try:
option_number = int(input("Specify number of selection: "))
except ValueError:
continue
if allow_cancel and option_number == index:
return None
if option_number >= 1 and option_number <= len(items):
return items[option_number - 1]
def import_image(self, filepath: str, board_name_option: str, db_mapper: DatabaseMapper, config: Config):
"""Import a single file by its path"""
parser = InvokeAIMetadataParser()
file_name = os.path.basename(filepath)
file_destination_path = os.path.join(config.outputs_path, file_name)
print("===============================================================================")
print(f"Importing {filepath}")
# check destination to see if the file was previously imported
if os.path.exists(file_destination_path):
print("File already exists in the destination, skipping!")
ImportStats.count_skipped_file_exists += 1
return
# check if file name is already referenced in the database
if db_mapper.does_image_exist(file_name):
print("A reference to a file with this name already exists in the database, skipping!")
ImportStats.count_skipped_db_exists += 1
return
# load image info and dimensions
img_info, png_width, png_height = self.get_file_details(filepath)
# parse metadata
destination_needs_meta_update = True
log_version_note = "(Unknown)"
if "invokeai_metadata" in img_info:
# for the latest, we will just re-emit the same json, no need to parse/modify
converted_field = None
latest_json_string = img_info.get("invokeai_metadata")
log_version_note = "3.0.0+"
destination_needs_meta_update = False
else:
if "sd-metadata" in img_info:
converted_field = parser.parse_meta_tag_sd_metadata(json.loads(img_info.get("sd-metadata")))
elif "invokeai" in img_info:
converted_field = parser.parse_meta_tag_invokeai(json.loads(img_info.get("invokeai")))
elif "dream" in img_info:
converted_field = parser.parse_meta_tag_dream(img_info.get("dream"))
elif "Dream" in img_info:
converted_field = parser.parse_meta_tag_dream(img_info.get("Dream"))
else:
converted_field = InvokeAIMetadata()
destination_needs_meta_update = False
print("File does not have metadata from known Invoke AI versions, add only, no update!")
            # use the loaded image dimensions if the metadata didn't include them
if converted_field.width is None:
converted_field.width = png_width
if converted_field.height is None:
converted_field.height = png_height
log_version_note = converted_field.imported_app_version if converted_field else "NoVersion"
log_version_note = log_version_note or "NoVersion"
latest_json_string = converted_field.to_json()
print(f"From Invoke AI Version {log_version_note} with dimensions {png_width} x {png_height}.")
        # if the metadata needs an update, then update the metadata and copy in one shot
if destination_needs_meta_update:
print("Updating metadata while copying...", end="")
self.update_file_metadata_while_copying(
filepath, file_destination_path, "invokeai_metadata", latest_json_string
)
print("Done!")
else:
print("No metadata update necessary, copying only...", end="")
shutil.copy2(filepath, file_destination_path)
print("Done!")
# create thumbnail
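        # PIL's thumbnail() treats the size as a bounding box and preserves aspect
        # ratio, so 256x256 caps the longest side at 256 pixels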
print("Creating thumbnail...", end="")
thumbnail_path = os.path.join(config.thumbnail_path, os.path.splitext(file_name)[0]) + ".webp"
thumbnail_size = 256, 256
with PIL.Image.open(filepath) as source_image:
source_image.thumbnail(thumbnail_size)
source_image.save(thumbnail_path, "webp")
print("Done!")
# finalize the dynamic board name if there is an APPVERSION token in it.
if converted_field is not None:
board_name = board_name_option.replace("APPVERSION", converted_field.imported_app_version or "NoVersion")
else:
board_name = board_name_option.replace("APPVERSION", "Latest")
        # maintain a map of already created/looked-up ids to avoid DB queries
print("Finding/Creating board...", end="")
if board_name in self.board_name_id_map:
board_id = self.board_name_id_map[board_name]
else:
board_id = db_mapper.get_board_id_with_create(board_name)
self.board_name_id_map[board_name] = board_id
print("Done!")
# add image to db
print("Adding image to database......", end="")
modified_time = datetime.datetime.utcfromtimestamp(os.path.getmtime(filepath))
db_mapper.add_new_image_to_database(file_name, png_width, png_height, latest_json_string, modified_time)
print("Done!")
# add image to board
print("Adding image to board......", end="")
db_mapper.add_image_to_board(file_name, board_id)
print("Done!")
ImportStats.count_imported += 1
if log_version_note in ImportStats.count_imported_by_version:
ImportStats.count_imported_by_version[log_version_note] += 1
else:
ImportStats.count_imported_by_version[log_version_note] = 1
def update_file_metadata_while_copying(self, filepath, file_destination_path, tag_name, tag_value):
"""Perform a metadata update with save to a new destination which accomplishes a copy while updating metadata."""
with PIL.Image.open(filepath) as target_image:
existing_img_info = target_image.info
metadata = PIL.PngImagePlugin.PngInfo()
# re-add any existing invoke ai tags unless they are the one we are trying to add
for key in existing_img_info:
if key != tag_name and key in ("dream", "Dream", "sd-metadata", "invokeai", "invokeai_metadata"):
metadata.add_text(key, existing_img_info[key])
metadata.add_text(tag_name, tag_value)
target_image.save(file_destination_path, pnginfo=metadata)
def process(self):
"""Begin main processing."""
print("===============================================================================")
print("This script will import images generated by earlier versions of")
print("InvokeAI into the currently installed root directory:")
print(f" {app_config.root_path}")
print("If this is not what you want to do, type ctrl-C now to cancel.")
# load config
print("===============================================================================")
print("= Configuration & Settings")
config = Config()
config.find_and_load()
db_mapper = DatabaseMapper(config.database_path, config.database_backup_dir)
db_mapper.connect()
import_dir, is_recurse, import_file_list = self.get_import_file_list()
ImportStats.count_source_files = len(import_file_list)
board_names = db_mapper.get_board_names()
board_name_option = self.select_board_option(board_names, config.TIMESTAMP_STRING)
print("\r\n===============================================================================")
print("= Import Settings Confirmation")
print()
print(f"Database File Path : {config.database_path}")
print(f"Outputs/Images Directory : {config.outputs_path}")
print(f"Import Image Source Directory : {import_dir}")
print(f" Recurse Source SubDirectories : {'Yes' if is_recurse else 'No'}")
print(f"Count of .png file(s) found : {len(import_file_list)}")
print(f"Board name option specified : {board_name_option}")
print(f"Database backup will be taken at : {config.database_backup_dir}")
print("\r\nNotes about the import process:")
print("- Source image files will not be modified, only copied to the outputs directory.")
print("- If the same file name already exists in the destination, the file will be skipped.")
print("- If the same file name already has a record in the database, the file will be skipped.")
print("- Invoke AI metadata tags will be updated/written into the imported copy only.")
print(
"- On the imported copy, only Invoke AI known tags (latest and legacy) will be retained (dream, sd-metadata, invokeai, invokeai_metadata)"
)
print(
"- A property 'imported_app_version' will be added to metadata that can be viewed in the UI's metadata viewer."
)
print(
"- The new 3.x InvokeAI outputs folder structure is flat so recursively found source imges will all be placed into the single outputs/images folder."
)
while True:
            should_continue = prompt("\nDo you wish to continue with the import [Yn]? ").lower() or "y"
if should_continue == "n":
print("\r\nCancelling Import")
return
elif should_continue == "y":
print()
break
db_mapper.backup(config.TIMESTAMP_STRING)
print()
ImportStats.time_start = datetime.datetime.utcnow()
for filepath in import_file_list:
try:
self.import_image(filepath, board_name_option, db_mapper, config)
except sqlite3.Error as sql_ex:
print(f"A database related exception was found processing {filepath}, will continue to next file. ")
print("Exception detail:")
print(sql_ex)
ImportStats.count_file_errors += 1
except Exception as ex:
print(f"Exception processing {filepath}, will continue to next file. ")
print("Exception detail:")
print(ex)
ImportStats.count_file_errors += 1
print("\r\n===============================================================================")
print(f"= Import Complete - Elpased Time: {ImportStats.get_elapsed_time_string()}")
print()
print(f"Source File(s) : {ImportStats.count_source_files}")
print(f"Total Imported : {ImportStats.count_imported}")
print(f"Skipped b/c file already exists on disk : {ImportStats.count_skipped_file_exists}")
print(f"Skipped b/c file already exists in db : {ImportStats.count_skipped_db_exists}")
print(f"Errors during import : {ImportStats.count_file_errors}")
if ImportStats.count_imported > 0:
print("\r\nBreakdown of imported files by version:")
            for version, count in ImportStats.count_imported_by_version.items():
                print(f"  {version:20} : {count}")
def main():
try:
processor = MediaImportProcessor()
processor.process()
except KeyboardInterrupt:
print("\r\n\r\nUser cancelled execution.")
if __name__ == "__main__":
main()

View File

@@ -1,55 +1,58 @@
 import { modelChanged } from 'features/parameters/store/generationSlice';
 import { setActiveTab } from 'features/ui/store/uiSlice';
-import { forEach } from 'lodash-es';
 import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
-import {
-  MainModelConfigEntity,
-  modelsApi,
-} from 'services/api/endpoints/models';
+import { mainModelsAdapter, modelsApi } from 'services/api/endpoints/models';
 import { startAppListening } from '..';
 
 export const addTabChangedListener = () => {
   startAppListening({
     actionCreator: setActiveTab,
-    effect: (action, { getState, dispatch }) => {
+    effect: async (action, { getState, dispatch }) => {
       const activeTabName = action.payload;
 
       if (activeTabName === 'unifiedCanvas') {
-        // grab the models from RTK Query cache
-        const { data } = modelsApi.endpoints.getMainModels.select(
-          NON_REFINER_BASE_MODELS
-        )(getState());
+        const currentBaseModel = getState().generation.model?.base_model;
 
-        if (!data) {
-          // no models yet, so we can't do anything
-          dispatch(modelChanged(null));
+        if (currentBaseModel && ['sd-1', 'sd-2'].includes(currentBaseModel)) {
+          // if we're already on a valid model, no change needed
           return;
         }
 
-        // need to filter out all the invalid canvas models (currently, this is just sdxl)
-        const validCanvasModels: MainModelConfigEntity[] = [];
-
-        forEach(data.entities, (entity) => {
-          if (!entity) {
-            return;
-          }
-          if (['sd-1', 'sd-2'].includes(entity.base_model)) {
-            validCanvasModels.push(entity);
-          }
-        });
-
-        // this could still be undefined even tho TS doesn't say so
-        const firstValidCanvasModel = validCanvasModels[0];
-
-        if (!firstValidCanvasModel) {
-          // uh oh, we have no models that are valid for canvas
-          dispatch(modelChanged(null));
-          return;
-        }
-
-        // only store the model name and base model in redux
-        const { base_model, model_name, model_type } = firstValidCanvasModel;
-
-        dispatch(modelChanged({ base_model, model_name, model_type }));
+        try {
+          // just grab fresh models
+          const modelsRequest = dispatch(
+            modelsApi.endpoints.getMainModels.initiate(NON_REFINER_BASE_MODELS)
+          );
+          const models = await modelsRequest.unwrap();
+          // cancel this cache subscription
+          modelsRequest.unsubscribe();
+
+          if (!models.ids.length) {
+            // no valid canvas models
+            dispatch(modelChanged(null));
+            return;
+          }
+
+          // need to filter out all the invalid canvas models (currently sdxl & refiner)
+          const validCanvasModels = mainModelsAdapter
+            .getSelectors()
+            .selectAll(models)
+            .filter((model) => ['sd-1', 'sd-2'].includes(model.base_model));
+
+          const firstValidCanvasModel = validCanvasModels[0];
+
+          if (!firstValidCanvasModel) {
+            // no valid canvas models
+            dispatch(modelChanged(null));
+            return;
+          }
+
+          const { base_model, model_name, model_type } = firstValidCanvasModel;
+
+          dispatch(modelChanged({ base_model, model_name, model_type }));
+        } catch {
+          // network request failed, bail
+          dispatch(modelChanged(null));
+        }
       }
     },
   });

View File

@@ -54,12 +54,7 @@ const ParamLoRASelect = () => {
       });
     });
 
-    // Sort Alphabetically
-    data.sort((a, b) =>
-      a.label && b.label ? (a.label?.localeCompare(b.label) ? 1 : -1) : -1
-    );
-
-    return data.sort((a, b) => (a.disabled && !b.disabled ? -1 : 1));
+    return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
   }, [loras, loraModels, currentMainModel?.base_model]);
 
   const handleChange = useCallback(

View File

@@ -139,6 +139,7 @@ dependencies = [
 "invokeai-metadata" = "invokeai.frontend.CLI.sd_metadata:print_metadata"
 "invokeai-node-cli" = "invokeai.app.cli_app:invoke_cli"
 "invokeai-node-web" = "invokeai.app.api_app:invoke_api"
+"invokeai-import-images" = "invokeai.frontend.install.import_images:main"
 
 [project.urls]
 "Homepage" = "https://invoke-ai.github.io/InvokeAI/"

View File

@@ -7,6 +7,7 @@ from invokeai.backend import ModelManager, BaseModelType, ModelType, SubModelType
 
 BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main)
 VAE_OVERRIDE_MODEL_NAME = ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main)
+VAE_NULL_OVERRIDE_MODEL_NAME = ("SDXL with empty VAE", BaseModelType.StableDiffusionXL, ModelType.Main)
 
 
 @pytest.fixture
@@ -36,3 +37,11 @@ def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: Path):
     expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix"
     assert vae_model_path == expected_vae_path
     assert is_override
+
+
+def test_get_model_path_for_null_overridden_vae(model_manager: ModelManager, datadir: Path):
+    model_config = model_manager._get_model_config(
+        VAE_NULL_OVERRIDE_MODEL_NAME[1], VAE_NULL_OVERRIDE_MODEL_NAME[0], VAE_NULL_OVERRIDE_MODEL_NAME[2]
+    )
+    vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
+    assert not is_override

View File

@@ -13,3 +13,10 @@ sdxl/main/SDXL with VAE:
   vae: sdxl/vae/sdxl-vae-fp16-fix/
   variant: normal
   format: diffusers
+
+sdxl/main/SDXL with empty VAE:
+  path: sdxl/main/SDXL base 1_0
+  description: SDXL with customized VAE
+  vae: ''
+  variant: normal
+  format: diffusers