mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge remote-tracking branch 'origin/main' into maryhipp/trigger-phrases-main
This commit is contained in:
commit
79a9567119
@ -19,6 +19,8 @@ their descriptions.
|
||||
| Conditioning Primitive | A conditioning tensor primitive value |
|
||||
| Content Shuffle Processor | Applies content shuffle processing to image |
|
||||
| ControlNet | Collects ControlNet info to pass to other nodes |
|
||||
| Create Denoise Mask | Converts a greyscale or transparency image into a mask for denoising. |
|
||||
| Create Gradient Mask | Creates a mask for Gradient ("soft", "differential") inpainting that gradually expands during denoising. Improves edge coherence. |
|
||||
| Denoise Latents | Denoises noisy latents to decodable images |
|
||||
| Divide Integers | Divides two numbers |
|
||||
| Dynamic Prompt | Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator |
|
||||
|
@ -283,6 +283,47 @@ async def update_model_metadata(
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.patch(
|
||||
"/i/{key}/metadata",
|
||||
operation_id="update_model_metadata",
|
||||
responses={
|
||||
201: {
|
||||
"description": "The model metadata was updated successfully",
|
||||
"content": {"application/json": {"example": example_model_metadata}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def update_model_metadata(
|
||||
key: str = Path(description="Key of the model repo metadata to fetch."),
|
||||
changes: ModelMetadataChanges = Body(description="The changes"),
|
||||
) -> Optional[AnyModelRepoMetadata]:
|
||||
"""Updates or creates a model metadata object."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store
|
||||
|
||||
try:
|
||||
original_metadata = record_store.get_metadata(key)
|
||||
if original_metadata:
|
||||
if changes.default_settings:
|
||||
original_metadata.default_settings = changes.default_settings
|
||||
|
||||
metadata_store.update_metadata(key, original_metadata)
|
||||
else:
|
||||
metadata_store.add_metadata(
|
||||
key, BaseMetadata(name="", author="", default_settings=changes.default_settings)
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"An error occurred while updating the model metadata: {e}",
|
||||
)
|
||||
|
||||
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/tags",
|
||||
operation_id="list_tags",
|
||||
@ -491,6 +532,7 @@ async def add_model_record(
|
||||
)
|
||||
async def install_model(
|
||||
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
|
||||
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
|
||||
# TODO(MM2): Can we type this?
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
@ -533,6 +575,7 @@ async def install_model(
|
||||
source=source,
|
||||
config=config,
|
||||
access_token=access_token,
|
||||
inplace=bool(inplace),
|
||||
)
|
||||
logger.info(f"Started installation of {source}")
|
||||
except UnknownModelException as e:
|
||||
|
@ -173,6 +173,16 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("gradient_mask_output")
|
||||
class GradientMaskOutput(BaseInvocationOutput):
|
||||
"""Outputs a denoise mask and an image representing the total gradient of the mask."""
|
||||
|
||||
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
|
||||
expanded_mask_area: ImageField = OutputField(
|
||||
description="Image representing the total gradient area of the mask. For paste-back purposes."
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"create_gradient_mask",
|
||||
title="Create Gradient Mask",
|
||||
@ -193,38 +203,42 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
|
||||
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
|
||||
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
|
||||
if self.coherence_mode == "Box Blur":
|
||||
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
|
||||
else: # Gaussian Blur OR Staged
|
||||
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
|
||||
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
|
||||
if self.edge_radius > 0:
|
||||
if self.coherence_mode == "Box Blur":
|
||||
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
|
||||
else: # Gaussian Blur OR Staged
|
||||
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
|
||||
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
|
||||
|
||||
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
|
||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
|
||||
|
||||
# redistribute blur so that the edges are 0 and blur out to 1
|
||||
blur_tensor = (blur_tensor - 0.5) * 2
|
||||
# redistribute blur so that the original edges are 0 and blur outwards to 1
|
||||
blur_tensor = (blur_tensor - 0.5) * 2
|
||||
|
||||
threshold = 1 - self.minimum_denoise
|
||||
threshold = 1 - self.minimum_denoise
|
||||
|
||||
if self.coherence_mode == "Staged":
|
||||
# wherever the blur_tensor is less than fully masked, convert it to threshold
|
||||
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
|
||||
else:
|
||||
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
|
||||
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
|
||||
|
||||
if self.coherence_mode == "Staged":
|
||||
# wherever the blur_tensor is masked to any degree, convert it to threshold
|
||||
blur_tensor = torch.where((blur_tensor < 1), threshold, blur_tensor)
|
||||
else:
|
||||
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
|
||||
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
|
||||
|
||||
# multiply original mask to force actually masked regions to 0
|
||||
blur_tensor = mask_tensor * blur_tensor
|
||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
|
||||
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
|
||||
|
||||
return DenoiseMaskOutput.build(
|
||||
mask_name=mask_name,
|
||||
masked_latents_name=None,
|
||||
gradient=True,
|
||||
# compute a [0, 1] mask from the blur_tensor
|
||||
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
|
||||
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
|
||||
expanded_image_dto = context.images.save(expanded_mask_image)
|
||||
|
||||
return GradientMaskOutput(
|
||||
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True),
|
||||
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
|
||||
)
|
||||
|
||||
|
||||
@ -775,10 +789,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
denoising_end=self.denoising_end,
|
||||
)
|
||||
|
||||
(
|
||||
result_latents,
|
||||
result_attention_map_saver,
|
||||
) = pipeline.latents_from_embeddings(
|
||||
result_latents = pipeline.latents_from_embeddings(
|
||||
latents=latents,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
|
@ -7,7 +7,6 @@ import time
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
from random import randbytes
|
||||
from shutil import copyfile, copytree, move, rmtree
|
||||
from tempfile import mkdtemp
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
@ -21,6 +20,7 @@ from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
@ -150,7 +150,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
config = config or {}
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
config["key"] = config.get("key", self._create_key())
|
||||
config["key"] = config.get("key", uuid_string())
|
||||
|
||||
info: AnyModelConfig = self._probe_model(Path(model_path), config)
|
||||
|
||||
@ -178,13 +178,14 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: bool = False,
|
||||
) -> ModelInstallJob:
|
||||
variants = "|".join(ModelRepoVariant.__members__.values())
|
||||
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
||||
source_obj: Optional[StringLikeSource] = None
|
||||
|
||||
if Path(source).exists(): # A local file or directory
|
||||
source_obj = LocalModelSource(path=Path(source))
|
||||
source_obj = LocalModelSource(path=Path(source), inplace=inplace)
|
||||
elif match := re.match(hf_repoid_re, source):
|
||||
source_obj = HFModelSource(
|
||||
repo_id=match.group(1),
|
||||
@ -526,16 +527,17 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
setattr(info, key, value)
|
||||
return info
|
||||
|
||||
def _create_key(self) -> str:
|
||||
return sha256(randbytes(100)).hexdigest()[0:32]
|
||||
|
||||
def _register(
|
||||
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
||||
) -> str:
|
||||
# Note that we may be passed a pre-populated AnyModelConfig object,
|
||||
# in which case the key field should have been populated by the caller (e.g. in `install_path`).
|
||||
config["key"] = config.get("key", self._create_key())
|
||||
config["key"] = config.get("key", uuid_string())
|
||||
info = info or ModelProbe.probe(model_path, config)
|
||||
override_key: Optional[str] = config.get("key") if config else None
|
||||
|
||||
assert info.original_hash # always assigned by probe()
|
||||
info.key = override_key or info.original_hash
|
||||
|
||||
model_path = model_path.absolute()
|
||||
if model_path.is_relative_to(self.app_config.models_path):
|
||||
|
@ -7,17 +7,23 @@ from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from pydantic import Field
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import ModelDefaultSettings
|
||||
|
||||
class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"):
|
||||
"""A set of changes to apply to model metadata.
|
||||
|
||||
Only limited changes are valid:
|
||||
- `default_settings`: the user-configured default settings for this model
|
||||
- `trigger_phrases`: the list of trigger phrases for this model
|
||||
"""
|
||||
|
||||
default_settings: Optional[ModelDefaultSettings] = Field(
|
||||
default=None, description="The user-configured default settings for this model"
|
||||
)
|
||||
"""The user-configured default settings for this model"""
|
||||
trigger_phrases: Optional[List[str]] = Field(default=None, description="The model's list of trigger phrases")
|
||||
"""The model's list of trigger phrases"""
|
||||
|
||||
|
@ -184,7 +184,7 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
||||
def _update_tags(self, model_key: str, tags: Optional[Set[str]]) -> None:
|
||||
"""Update tags for the model referenced by model_key."""
|
||||
if tags:
|
||||
# remove previous tags from this model
|
||||
# remove previous tags from this model
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
|
@ -200,6 +200,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._invoker.services.logger.error(
|
||||
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
|
||||
)
|
||||
self._invoker.services.logger.error(error)
|
||||
|
||||
# Send error event
|
||||
self._invoker.services.events.emit_invocation_error(
|
||||
|
@ -3,7 +3,6 @@
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from hashlib import sha1
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
@ -22,7 +21,7 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelConfigFactory,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.hash import FastModelHash
|
||||
from invokeai.backend.model_manager.hash import ModelHash
|
||||
|
||||
ModelsValidator = TypeAdapter(AnyModelConfig)
|
||||
|
||||
@ -73,19 +72,27 @@ class MigrateModelYamlToDb1:
|
||||
|
||||
base_type, model_type, model_name = str(model_key).split("/")
|
||||
try:
|
||||
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
||||
hash = ModelHash().hash(self.config.models_path / stanza.path)
|
||||
except OSError:
|
||||
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
|
||||
continue
|
||||
|
||||
assert isinstance(model_key, str)
|
||||
new_key = sha1(model_key.encode("utf-8")).hexdigest()
|
||||
|
||||
stanza["base"] = BaseModelType(base_type)
|
||||
stanza["type"] = ModelType(model_type)
|
||||
stanza["name"] = model_name
|
||||
stanza["original_hash"] = hash
|
||||
stanza["current_hash"] = hash
|
||||
new_key = hash # deterministic key assignment
|
||||
|
||||
# special case for ip adapters, which need the new `image_encoder_model_id` field
|
||||
if stanza["type"] == ModelType.IPAdapter:
|
||||
try:
|
||||
stanza["image_encoder_model_id"] = self._get_image_encoder_model_id(
|
||||
self.config.models_path / stanza.path
|
||||
)
|
||||
except OSError:
|
||||
self.logger.warning(f"Could not determine image encoder for {stanza.path}. Skipping.")
|
||||
continue
|
||||
|
||||
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
|
||||
|
||||
@ -95,7 +102,7 @@ class MigrateModelYamlToDb1:
|
||||
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
|
||||
self._update_model(key, new_config)
|
||||
else:
|
||||
self.logger.info(f"Adding model {model_name} with key {model_key}")
|
||||
self.logger.info(f"Adding model {model_name} with key {new_key}")
|
||||
self._add_model(new_key, new_config)
|
||||
except DuplicateModelException:
|
||||
self.logger.warning(f"Model {model_name} is already in the database")
|
||||
@ -149,3 +156,8 @@ class MigrateModelYamlToDb1:
|
||||
)
|
||||
except sqlite3.IntegrityError as exc:
|
||||
raise DuplicateModelException(f"{record.name}: model is already in database") from exc
|
||||
|
||||
def _get_image_encoder_model_id(self, model_path: Path) -> str:
|
||||
with open(model_path / "image_encoder.txt") as f:
|
||||
encoder = f.read()
|
||||
return encoder.strip()
|
||||
|
@ -11,56 +11,175 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
|
||||
import hashlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
from typing import Callable, Literal, Optional, Union
|
||||
|
||||
from imohash import hashfile
|
||||
from blake3 import blake3
|
||||
|
||||
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
|
||||
|
||||
ALGORITHM = Literal[
|
||||
"md5",
|
||||
"sha1",
|
||||
"sha224",
|
||||
"sha256",
|
||||
"sha384",
|
||||
"sha512",
|
||||
"blake2b",
|
||||
"blake2s",
|
||||
"sha3_224",
|
||||
"sha3_256",
|
||||
"sha3_384",
|
||||
"sha3_512",
|
||||
"shake_128",
|
||||
"shake_256",
|
||||
"blake3",
|
||||
]
|
||||
|
||||
|
||||
class FastModelHash(object):
|
||||
"""FastModelHash obect provides one public class method, hash()."""
|
||||
class ModelHash:
|
||||
"""
|
||||
Creates a hash of a model using a specified algorithm.
|
||||
|
||||
@classmethod
|
||||
def hash(cls, model_location: Union[str, Path]) -> str:
|
||||
"""
|
||||
Return hexdigest string for model located at model_location.
|
||||
Args:
|
||||
algorithm: Hashing algorithm to use. Defaults to BLAKE3.
|
||||
file_filter: A function that takes a file name and returns True if the file should be included in the hash.
|
||||
|
||||
:param model_location: Path to the model
|
||||
"""
|
||||
model_location = Path(model_location)
|
||||
if model_location.is_file():
|
||||
return cls._hash_file(model_location)
|
||||
elif model_location.is_dir():
|
||||
return cls._hash_dir(model_location)
|
||||
If the model is a single file, it is hashed directly using the provided algorithm.
|
||||
|
||||
If the model is a directory, each model weights file in the directory is hashed using the provided algorithm.
|
||||
|
||||
Only files with the following extensions are hashed: .ckpt, .safetensors, .bin, .pt, .pth
|
||||
|
||||
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
|
||||
that directory hashes are never weaker than the file hashes.
|
||||
|
||||
Usage:
|
||||
```py
|
||||
# BLAKE3 hash
|
||||
ModelHash().hash("path/to/some/model.safetensors")
|
||||
# MD5
|
||||
ModelHash("md5").hash("path/to/model/dir/")
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None:
|
||||
if algorithm == "blake3":
|
||||
self._hash_file = self._blake3
|
||||
elif algorithm in hashlib.algorithms_available:
|
||||
self._hash_file = self._get_hashlib(algorithm)
|
||||
else:
|
||||
raise OSError(f"Not a valid file or directory: {model_location}")
|
||||
raise ValueError(f"Algorithm {algorithm} not available")
|
||||
|
||||
@classmethod
|
||||
def _hash_file(cls, model_location: Union[str, Path]) -> str:
|
||||
self._file_filter = file_filter or self._default_file_filter
|
||||
|
||||
def hash(self, model_path: Union[str, Path]) -> str:
|
||||
"""
|
||||
Fasthash a single file and return its hexdigest.
|
||||
Return hexdigest of hash of model located at model_path using the algorithm provided at class instantiation.
|
||||
|
||||
:param model_location: Path to the model file
|
||||
If model_path is a directory, the hash is computed by hashing the hashes of all model files in the
|
||||
directory. The final composite hash is always computed using BLAKE3.
|
||||
|
||||
Args:
|
||||
model_path: Path to the model
|
||||
|
||||
Returns:
|
||||
str: Hexdigest of the hash of the model
|
||||
"""
|
||||
# we return md5 hash of the filehash to make it shorter
|
||||
# cryptographic security not needed here
|
||||
return hashlib.md5(hashfile(model_location)).hexdigest()
|
||||
|
||||
@classmethod
|
||||
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
|
||||
components: Dict[str, str] = {}
|
||||
model_path = Path(model_path)
|
||||
if model_path.is_file():
|
||||
return self._hash_file(model_path)
|
||||
elif model_path.is_dir():
|
||||
return self._hash_dir(model_path)
|
||||
else:
|
||||
raise OSError(f"Not a valid file or directory: {model_path}")
|
||||
|
||||
for root, _dirs, files in os.walk(model_location):
|
||||
for file in files:
|
||||
# only tally tensor files because diffusers config files change slightly
|
||||
# depending on how the model was downloaded/converted.
|
||||
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
|
||||
continue
|
||||
path = (Path(root) / file).as_posix()
|
||||
fast_hash = cls._hash_file(path)
|
||||
components.update({path: fast_hash})
|
||||
def _hash_dir(self, dir: Path) -> str:
|
||||
"""Compute the hash for all files in a directory and return a hexdigest.
|
||||
|
||||
# hash all the model hashes together, using alphabetic file order
|
||||
md5 = hashlib.md5()
|
||||
for _path, fast_hash in sorted(components.items()):
|
||||
md5.update(fast_hash.encode("utf-8"))
|
||||
return md5.hexdigest()
|
||||
Args:
|
||||
dir: Path to the directory
|
||||
|
||||
Returns:
|
||||
str: Hexdigest of the hash of the directory
|
||||
"""
|
||||
model_component_paths = self._get_file_paths(dir, self._file_filter)
|
||||
|
||||
component_hashes: list[str] = []
|
||||
for component in sorted(model_component_paths):
|
||||
component_hashes.append(self._hash_file(component))
|
||||
|
||||
# BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
|
||||
# for the composite hash
|
||||
composite_hasher = blake3()
|
||||
for h in component_hashes:
|
||||
composite_hasher.update(h.encode("utf-8"))
|
||||
return composite_hasher.hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _get_file_paths(model_path: Path, file_filter: Callable[[str], bool]) -> list[Path]:
|
||||
"""Return a list of all model files in the directory.
|
||||
|
||||
Args:
|
||||
model_path: Path to the model
|
||||
file_filter: Function that takes a file name and returns True if the file should be included in the list.
|
||||
|
||||
Returns:
|
||||
List of all model files in the directory
|
||||
"""
|
||||
|
||||
files: list[Path] = []
|
||||
for root, _dirs, _files in os.walk(model_path):
|
||||
for file in _files:
|
||||
if file_filter(file):
|
||||
files.append(Path(root, file))
|
||||
return files
|
||||
|
||||
@staticmethod
|
||||
def _blake3(file_path: Path) -> str:
|
||||
"""Hashes a file using BLAKE3
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to hash
|
||||
|
||||
Returns:
|
||||
Hexdigest of the hash of the file
|
||||
"""
|
||||
file_hasher = blake3(max_threads=blake3.AUTO)
|
||||
file_hasher.update_mmap(file_path)
|
||||
return file_hasher.hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
|
||||
"""Factory function that returns a function to hash a file with the given algorithm.
|
||||
|
||||
Args:
|
||||
algorithm: Hashing algorithm to use
|
||||
|
||||
Returns:
|
||||
A function that hashes a file using the given algorithm
|
||||
"""
|
||||
|
||||
def hashlib_hasher(file_path: Path) -> str:
|
||||
"""Hashes a file using a hashlib algorithm. Uses `memoryview` to avoid reading the entire file into memory."""
|
||||
hasher = hashlib.new(algorithm)
|
||||
buffer = bytearray(128 * 1024)
|
||||
mv = memoryview(buffer)
|
||||
with open(file_path, "rb", buffering=0) as f:
|
||||
while n := f.readinto(mv):
|
||||
hasher.update(mv[:n])
|
||||
return hasher.hexdigest()
|
||||
|
||||
return hashlib_hasher
|
||||
|
||||
@staticmethod
|
||||
def _default_file_filter(file_path: str) -> bool:
|
||||
"""A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
|
||||
Returns:
|
||||
True if the file matches the given extensions, otherwise False
|
||||
"""
|
||||
return file_path.endswith(MODEL_FILE_EXTENSIONS)
|
||||
|
@ -25,6 +25,7 @@ from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
|
||||
from invokeai.backend.model_manager import ModelRepoVariant
|
||||
|
||||
from ..util import select_hf_files
|
||||
@ -68,6 +69,15 @@ class RemoteModelFile(BaseModel):
|
||||
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
|
||||
|
||||
|
||||
class ModelDefaultSettings(BaseModel):
|
||||
vae: str | None
|
||||
vae_precision: str | None
|
||||
scheduler: SCHEDULER_NAME_VALUES | None
|
||||
steps: int | None
|
||||
cfg_scale: float | None
|
||||
cfg_rescale_multiplier: float | None
|
||||
|
||||
|
||||
class ModelMetadataBase(BaseModel):
|
||||
"""Base class for model metadata information."""
|
||||
|
||||
@ -75,6 +85,9 @@ class ModelMetadataBase(BaseModel):
|
||||
author: str = Field(description="model's author")
|
||||
tags: Optional[Set[str]] = Field(description="tags provided by model source", default=None)
|
||||
trigger_phrases: Optional[List[str]] = Field(description="trigger phrases for this model", default=None)
|
||||
default_settings: Optional[ModelDefaultSettings] = Field(
|
||||
description="default settings for this model", default=None
|
||||
)
|
||||
|
||||
|
||||
class BaseMetadata(ModelMetadataBase):
|
||||
|
@ -21,7 +21,7 @@ from .config import (
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
from .hash import FastModelHash
|
||||
from .hash import ModelHash
|
||||
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
|
||||
|
||||
CkptType = Dict[str, Any]
|
||||
@ -147,7 +147,7 @@ class ModelProbe(object):
|
||||
if not probe_class:
|
||||
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
|
||||
|
||||
hash = FastModelHash.hash(model_path)
|
||||
hash = ModelHash().hash(model_path)
|
||||
probe = probe_class(model_path)
|
||||
|
||||
fields["path"] = model_path.as_posix()
|
||||
|
@ -4,13 +4,11 @@ Initialization file for the invokeai.backend.stable_diffusion package
|
||||
|
||||
from .diffusers_pipeline import PipelineIntermediateState, StableDiffusionGeneratorPipeline # noqa: F401
|
||||
from .diffusion import InvokeAIDiffuserComponent # noqa: F401
|
||||
from .diffusion.cross_attention_map_saving import AttentionMapSaver # noqa: F401
|
||||
from .seamless import set_seamless # noqa: F401
|
||||
|
||||
__all__ = [
|
||||
"PipelineIntermediateState",
|
||||
"StableDiffusionGeneratorPipeline",
|
||||
"InvokeAIDiffuserComponent",
|
||||
"AttentionMapSaver",
|
||||
"set_seamless",
|
||||
]
|
||||
|
@ -12,7 +12,6 @@ import torch
|
||||
import torchvision.transforms as T
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.controlnet import ControlNetModel
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
@ -26,9 +25,9 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
|
||||
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
|
||||
from ..util import auto_detect_slice_size, normalize_device
|
||||
from .diffusion import AttentionMapSaver, InvokeAIDiffuserComponent
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -39,7 +38,6 @@ class PipelineIntermediateState:
|
||||
timestep: int
|
||||
latents: torch.Tensor
|
||||
predicted_original: Optional[torch.Tensor] = None
|
||||
attention_map_saver: Optional[AttentionMapSaver] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -190,19 +188,6 @@ class T2IAdapterData:
|
||||
end_step_percent: float = Field(default=1.0)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
|
||||
r"""
|
||||
Output class for InvokeAI's Stable Diffusion pipeline.
|
||||
|
||||
Args:
|
||||
attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
|
||||
after generation completes. Optional.
|
||||
"""
|
||||
|
||||
attention_map_saver: Optional[AttentionMapSaver]
|
||||
|
||||
|
||||
class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion.
|
||||
@ -343,9 +328,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
masked_latents: Optional[torch.Tensor] = None,
|
||||
gradient_mask: Optional[bool] = False,
|
||||
seed: Optional[int] = None,
|
||||
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
||||
) -> torch.Tensor:
|
||||
if init_timestep.shape[0] == 0:
|
||||
return latents, None
|
||||
return latents
|
||||
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
@ -385,7 +370,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
|
||||
|
||||
try:
|
||||
latents, attention_map_saver = self.generate_latents_from_embeddings(
|
||||
latents = self.generate_latents_from_embeddings(
|
||||
latents,
|
||||
timesteps,
|
||||
conditioning_data,
|
||||
@ -402,7 +387,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if mask is not None and not gradient_mask:
|
||||
latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))
|
||||
|
||||
return latents, attention_map_saver
|
||||
return latents
|
||||
|
||||
def generate_latents_from_embeddings(
|
||||
self,
|
||||
@ -415,16 +400,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
):
|
||||
) -> torch.Tensor:
|
||||
self._adjust_memory_efficient_attention(latents)
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
|
||||
batch_size = latents.shape[0]
|
||||
attention_map_saver: Optional[AttentionMapSaver] = None
|
||||
|
||||
if timesteps.shape[0] == 0:
|
||||
return latents, attention_map_saver
|
||||
return latents
|
||||
|
||||
ip_adapter_unet_patcher = None
|
||||
extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning
|
||||
@ -432,7 +416,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
attn_ctx = self.invokeai_diffuser.custom_attention_context(
|
||||
self.invokeai_diffuser.model,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
step_count=len(self.scheduler.timesteps),
|
||||
)
|
||||
self.use_ip_adapter = False
|
||||
elif ip_adapter_data is not None:
|
||||
@ -483,13 +466,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
|
||||
predicted_original = getattr(step_output, "pred_original_sample", None)
|
||||
|
||||
# TODO resuscitate attention map saving
|
||||
# if i == len(timesteps)-1 and extra_conditioning_info is not None:
|
||||
# eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
|
||||
# attention_map_token_ids = range(1, eos_token_index)
|
||||
# attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
|
||||
# self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
|
||||
|
||||
if callback is not None:
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
@ -499,11 +475,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
timestep=int(t),
|
||||
latents=latents,
|
||||
predicted_original=predicted_original,
|
||||
attention_map_saver=attention_map_saver,
|
||||
)
|
||||
)
|
||||
|
||||
return latents, attention_map_saver
|
||||
return latents
|
||||
|
||||
@torch.inference_mode()
|
||||
def step(
|
||||
|
@ -2,6 +2,4 @@
|
||||
Initialization file for invokeai.models.diffusion
|
||||
"""
|
||||
|
||||
from .cross_attention_control import InvokeAICrossAttentionMixin # noqa: F401
|
||||
from .cross_attention_map_saving import AttentionMapSaver # noqa: F401
|
||||
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent # noqa: F401
|
||||
|
@ -3,19 +3,13 @@
|
||||
|
||||
|
||||
import enum
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Optional
|
||||
from typing import Optional
|
||||
|
||||
import diffusers
|
||||
import psutil
|
||||
import torch
|
||||
from compel.cross_attention_control import Arguments
|
||||
from diffusers.models.attention_processor import Attention, AttentionProcessor, AttnProcessor, SlicedAttnProcessor
|
||||
from diffusers.models.attention_processor import Attention, SlicedAttnProcessor
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from torch import nn
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from ...util import torch_dtype
|
||||
|
||||
@ -25,72 +19,14 @@ class CrossAttentionType(enum.Enum):
|
||||
TOKENS = 2
|
||||
|
||||
|
||||
class Context:
|
||||
cross_attention_mask: Optional[torch.Tensor]
|
||||
cross_attention_index_map: Optional[torch.Tensor]
|
||||
|
||||
class Action(enum.Enum):
|
||||
NONE = 0
|
||||
SAVE = (1,)
|
||||
APPLY = 2
|
||||
|
||||
def __init__(self, arguments: Arguments, step_count: int):
|
||||
class CrossAttnControlContext:
|
||||
def __init__(self, arguments: Arguments):
|
||||
"""
|
||||
:param arguments: Arguments for the cross-attention control process
|
||||
:param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
|
||||
"""
|
||||
self.cross_attention_mask = None
|
||||
self.cross_attention_index_map = None
|
||||
self.self_cross_attention_action = Context.Action.NONE
|
||||
self.tokens_cross_attention_action = Context.Action.NONE
|
||||
self.cross_attention_mask: Optional[torch.Tensor] = None
|
||||
self.cross_attention_index_map: Optional[torch.Tensor] = None
|
||||
self.arguments = arguments
|
||||
self.step_count = step_count
|
||||
|
||||
self.self_cross_attention_module_identifiers = []
|
||||
self.tokens_cross_attention_module_identifiers = []
|
||||
|
||||
self.saved_cross_attention_maps = {}
|
||||
|
||||
self.clear_requests(cleanup=True)
|
||||
|
||||
def register_cross_attention_modules(self, model):
|
||||
for name, _module in get_cross_attention_modules(model, CrossAttentionType.SELF):
|
||||
if name in self.self_cross_attention_module_identifiers:
|
||||
raise AssertionError(f"name {name} cannot appear more than once")
|
||||
self.self_cross_attention_module_identifiers.append(name)
|
||||
for name, _module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
|
||||
if name in self.tokens_cross_attention_module_identifiers:
|
||||
raise AssertionError(f"name {name} cannot appear more than once")
|
||||
self.tokens_cross_attention_module_identifiers.append(name)
|
||||
|
||||
def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
|
||||
if cross_attention_type == CrossAttentionType.SELF:
|
||||
self.self_cross_attention_action = Context.Action.SAVE
|
||||
else:
|
||||
self.tokens_cross_attention_action = Context.Action.SAVE
|
||||
|
||||
def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType):
|
||||
if cross_attention_type == CrossAttentionType.SELF:
|
||||
self.self_cross_attention_action = Context.Action.APPLY
|
||||
else:
|
||||
self.tokens_cross_attention_action = Context.Action.APPLY
|
||||
|
||||
def is_tokens_cross_attention(self, module_identifier) -> bool:
|
||||
return module_identifier in self.tokens_cross_attention_module_identifiers
|
||||
|
||||
def get_should_save_maps(self, module_identifier: str) -> bool:
|
||||
if module_identifier in self.self_cross_attention_module_identifiers:
|
||||
return self.self_cross_attention_action == Context.Action.SAVE
|
||||
elif module_identifier in self.tokens_cross_attention_module_identifiers:
|
||||
return self.tokens_cross_attention_action == Context.Action.SAVE
|
||||
return False
|
||||
|
||||
def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
|
||||
if module_identifier in self.self_cross_attention_module_identifiers:
|
||||
return self.self_cross_attention_action == Context.Action.APPLY
|
||||
elif module_identifier in self.tokens_cross_attention_module_identifiers:
|
||||
return self.tokens_cross_attention_action == Context.Action.APPLY
|
||||
return False
|
||||
|
||||
def get_active_cross_attention_control_types_for_step(
|
||||
self, percent_through: float = None
|
||||
@ -111,219 +47,8 @@ class Context:
|
||||
to_control.append(CrossAttentionType.TOKENS)
|
||||
return to_control
|
||||
|
||||
def save_slice(
|
||||
self,
|
||||
identifier: str,
|
||||
slice: torch.Tensor,
|
||||
dim: Optional[int],
|
||||
offset: int,
|
||||
slice_size: Optional[int],
|
||||
):
|
||||
if identifier not in self.saved_cross_attention_maps:
|
||||
self.saved_cross_attention_maps[identifier] = {
|
||||
"dim": dim,
|
||||
"slice_size": slice_size,
|
||||
"slices": {offset or 0: slice},
|
||||
}
|
||||
else:
|
||||
self.saved_cross_attention_maps[identifier]["slices"][offset or 0] = slice
|
||||
|
||||
def get_slice(
|
||||
self,
|
||||
identifier: str,
|
||||
requested_dim: Optional[int],
|
||||
requested_offset: int,
|
||||
slice_size: int,
|
||||
):
|
||||
saved_attention_dict = self.saved_cross_attention_maps[identifier]
|
||||
if requested_dim is None:
|
||||
if saved_attention_dict["dim"] is not None:
|
||||
raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}")
|
||||
return saved_attention_dict["slices"][0]
|
||||
|
||||
if saved_attention_dict["dim"] == requested_dim:
|
||||
if slice_size != saved_attention_dict["slice_size"]:
|
||||
raise RuntimeError(
|
||||
f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}"
|
||||
)
|
||||
return saved_attention_dict["slices"][requested_offset]
|
||||
|
||||
if saved_attention_dict["dim"] is None:
|
||||
whole_saved_attention = saved_attention_dict["slices"][0]
|
||||
if requested_dim == 0:
|
||||
return whole_saved_attention[requested_offset : requested_offset + slice_size]
|
||||
elif requested_dim == 1:
|
||||
return whole_saved_attention[:, requested_offset : requested_offset + slice_size]
|
||||
|
||||
raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
|
||||
|
||||
def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
|
||||
saved_attention = self.saved_cross_attention_maps.get(identifier, None)
|
||||
if saved_attention is None:
|
||||
return None, None
|
||||
return saved_attention["dim"], saved_attention["slice_size"]
|
||||
|
||||
def clear_requests(self, cleanup=True):
|
||||
self.tokens_cross_attention_action = Context.Action.NONE
|
||||
self.self_cross_attention_action = Context.Action.NONE
|
||||
if cleanup:
|
||||
self.saved_cross_attention_maps = {}
|
||||
|
||||
def offload_saved_attention_slices_to_cpu(self):
|
||||
for _key, map_dict in self.saved_cross_attention_maps.items():
|
||||
for offset, slice in map_dict["slices"].items():
|
||||
map_dict[offset] = slice.to("cpu")
|
||||
|
||||
|
||||
class InvokeAICrossAttentionMixin:
|
||||
"""
|
||||
Enable InvokeAI-flavoured Attention calculation, which does aggressive low-memory slicing and calls
|
||||
through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
|
||||
and dymamic slicing strategy selection.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||
self.attention_slice_wrangler = None
|
||||
self.slicing_strategy_getter = None
|
||||
self.attention_slice_calculated_callback = None
|
||||
|
||||
def set_attention_slice_wrangler(
|
||||
self,
|
||||
wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]],
|
||||
):
|
||||
"""
|
||||
Set custom attention calculator to be called when attention is calculated
|
||||
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
|
||||
which returns either the suggested_attention_slice or an adjusted equivalent.
|
||||
`module` is the current Attention module for which the callback is being invoked.
|
||||
`suggested_attention_slice` is the default-calculated attention slice
|
||||
`dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
|
||||
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
|
||||
|
||||
Pass None to use the default attention calculation.
|
||||
:return:
|
||||
"""
|
||||
self.attention_slice_wrangler = wrangler
|
||||
|
||||
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int, int]]]):
|
||||
self.slicing_strategy_getter = getter
|
||||
|
||||
def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]):
|
||||
self.attention_slice_calculated_callback = callback
|
||||
|
||||
def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
|
||||
# calculate attention scores
|
||||
# attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
|
||||
attention_scores = torch.baddbmm(
|
||||
torch.empty(
|
||||
query.shape[0],
|
||||
query.shape[1],
|
||||
key.shape[1],
|
||||
dtype=query.dtype,
|
||||
device=query.device,
|
||||
),
|
||||
query,
|
||||
key.transpose(-1, -2),
|
||||
beta=0,
|
||||
alpha=self.scale,
|
||||
)
|
||||
|
||||
# calculate attention slice by taking the best scores for each latent pixel
|
||||
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
|
||||
attention_slice_wrangler = self.attention_slice_wrangler
|
||||
if attention_slice_wrangler is not None:
|
||||
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
|
||||
else:
|
||||
attention_slice = default_attention_slice
|
||||
|
||||
if self.attention_slice_calculated_callback is not None:
|
||||
self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)
|
||||
|
||||
hidden_states = torch.bmm(attention_slice, value)
|
||||
return hidden_states
|
||||
|
||||
def einsum_op_slice_dim0(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
end = i + slice_size
|
||||
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
|
||||
return r
|
||||
|
||||
def einsum_op_slice_dim1(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
|
||||
return r
|
||||
|
||||
def einsum_op_mps_v1(self, q, k, v):
|
||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
else:
|
||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||
|
||||
def einsum_op_mps_v2(self, q, k, v):
|
||||
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
else:
|
||||
return self.einsum_op_slice_dim0(q, k, v, 1)
|
||||
|
||||
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
|
||||
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
||||
if size_mb <= max_tensor_mb:
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
|
||||
if div <= q.shape[0]:
|
||||
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
|
||||
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
|
||||
|
||||
def einsum_op_cuda(self, q, k, v):
|
||||
# check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
|
||||
slicing_strategy_getter = self.slicing_strategy_getter
|
||||
if slicing_strategy_getter is not None:
|
||||
(dim, slice_size) = slicing_strategy_getter(self)
|
||||
if dim is not None:
|
||||
# print("using saved slicing strategy with dim", dim, "slice size", slice_size)
|
||||
if dim == 0:
|
||||
return self.einsum_op_slice_dim0(q, k, v, slice_size)
|
||||
elif dim == 1:
|
||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||
|
||||
# fallback for when there is no saved strategy, or saved strategy does not slice
|
||||
mem_free_total = get_mem_free_total(q.device)
|
||||
# Divide factor of safety as there's copying and fragmentation
|
||||
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||
|
||||
def get_invokeai_attention_mem_efficient(self, q, k, v):
|
||||
if q.device.type == "cuda":
|
||||
# print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
|
||||
return self.einsum_op_cuda(q, k, v)
|
||||
|
||||
if q.device.type == "mps" or q.device.type == "cpu":
|
||||
if self.mem_total_gb >= 32:
|
||||
return self.einsum_op_mps_v1(q, k, v)
|
||||
return self.einsum_op_mps_v2(q, k, v)
|
||||
|
||||
# Smaller slices are faster due to L2/L3/SLC caches.
|
||||
# Tested on i7 with 8MB L3 cache.
|
||||
return self.einsum_op_tensor_mem(q, k, v, 32)
|
||||
|
||||
|
||||
def restore_default_cross_attention(
|
||||
model,
|
||||
is_running_diffusers: bool,
|
||||
restore_attention_processor: Optional[AttentionProcessor] = None,
|
||||
):
|
||||
if is_running_diffusers:
|
||||
unet = model
|
||||
unet.set_attn_processor(restore_attention_processor or AttnProcessor())
|
||||
else:
|
||||
remove_attention_function(model)
|
||||
|
||||
|
||||
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
|
||||
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: CrossAttnControlContext):
|
||||
"""
|
||||
Inject attention parameters and functions into the passed in model to enable cross attention editing.
|
||||
|
||||
@ -362,170 +87,6 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionMode
|
||||
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
|
||||
|
||||
|
||||
def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
|
||||
cross_attention_class: type = InvokeAIDiffusersCrossAttention
|
||||
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
|
||||
attention_module_tuples = [
|
||||
(name, module)
|
||||
for name, module in model.named_modules()
|
||||
if isinstance(module, cross_attention_class) and which_attn in name
|
||||
]
|
||||
cross_attention_modules_in_model_count = len(attention_module_tuples)
|
||||
expected_count = 16
|
||||
if cross_attention_modules_in_model_count != expected_count:
|
||||
# non-fatal error but .swap() won't work.
|
||||
logger.error(
|
||||
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
|
||||
f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching "
|
||||
"failed or some assumption has changed about the structure of the model itself. Please fix the "
|
||||
f"monkey-patching, and/or update the {expected_count} above to an appropriate number, and/or find and "
|
||||
"inform someone who knows what it means. This error is non-fatal, but it is likely that .swap() and "
|
||||
"attention map display will not work properly until it is fixed."
|
||||
)
|
||||
return attention_module_tuples
|
||||
|
||||
|
||||
def inject_attention_function(unet, context: Context):
|
||||
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
|
||||
|
||||
def attention_slice_wrangler(module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size):
|
||||
# memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()
|
||||
|
||||
attention_slice = suggested_attention_slice
|
||||
|
||||
if context.get_should_save_maps(module.identifier):
|
||||
# print(module.identifier, "saving suggested_attention_slice of shape",
|
||||
# suggested_attention_slice.shape, "dim", dim, "offset", offset)
|
||||
slice_to_save = attention_slice.to("cpu") if dim is not None else attention_slice
|
||||
context.save_slice(
|
||||
module.identifier,
|
||||
slice_to_save,
|
||||
dim=dim,
|
||||
offset=offset,
|
||||
slice_size=slice_size,
|
||||
)
|
||||
elif context.get_should_apply_saved_maps(module.identifier):
|
||||
# print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
|
||||
saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)
|
||||
|
||||
# slice may have been offloaded to CPU
|
||||
saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)
|
||||
|
||||
if context.is_tokens_cross_attention(module.identifier):
|
||||
index_map = context.cross_attention_index_map
|
||||
remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
|
||||
this_attention_slice = suggested_attention_slice
|
||||
|
||||
mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device))
|
||||
saved_mask = mask
|
||||
this_mask = 1 - mask
|
||||
attention_slice = remapped_saved_attention_slice * saved_mask + this_attention_slice * this_mask
|
||||
else:
|
||||
# just use everything
|
||||
attention_slice = saved_attention_slice
|
||||
|
||||
return attention_slice
|
||||
|
||||
cross_attention_modules = get_cross_attention_modules(
|
||||
unet, CrossAttentionType.TOKENS
|
||||
) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
|
||||
for identifier, module in cross_attention_modules:
|
||||
module.identifier = identifier
|
||||
try:
|
||||
module.set_attention_slice_wrangler(attention_slice_wrangler)
|
||||
module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier)) # noqa: B023
|
||||
except AttributeError as e:
|
||||
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
|
||||
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def remove_attention_function(unet):
|
||||
cross_attention_modules = get_cross_attention_modules(
|
||||
unet, CrossAttentionType.TOKENS
|
||||
) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
|
||||
for _identifier, module in cross_attention_modules:
|
||||
try:
|
||||
# clear wrangler callback
|
||||
module.set_attention_slice_wrangler(None)
|
||||
module.set_slicing_strategy_getter(None)
|
||||
except AttributeError as e:
|
||||
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
|
||||
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def is_attribute_error_about(error: AttributeError, attribute: str):
|
||||
if hasattr(error, "name"): # Python 3.10
|
||||
return error.name == attribute
|
||||
else: # Python 3.9
|
||||
return attribute in str(error)
|
||||
|
||||
|
||||
def get_mem_free_total(device):
|
||||
# only on cuda
|
||||
if not torch.cuda.is_available():
|
||||
return None
|
||||
stats = torch.cuda.memory_stats(device)
|
||||
mem_active = stats["active_bytes.all.current"]
|
||||
mem_reserved = stats["reserved_bytes.all.current"]
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(device)
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
return mem_free_total
|
||||
|
||||
|
||||
class InvokeAIDiffusersCrossAttention(diffusers.models.attention.Attention, InvokeAICrossAttentionMixin):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
InvokeAICrossAttentionMixin.__init__(self)
|
||||
|
||||
def _attention(self, query, key, value, attention_mask=None):
|
||||
# default_result = super()._attention(query, key, value)
|
||||
if attention_mask is not None:
|
||||
print(f"{type(self).__name__} ignoring passed-in attention_mask")
|
||||
attention_result = self.get_invokeai_attention_mem_efficient(query, key, value)
|
||||
|
||||
hidden_states = self.reshape_batch_dim_to_heads(attention_result)
|
||||
return hidden_states
|
||||
|
||||
|
||||
## 🧨diffusers implementation follows
|
||||
|
||||
|
||||
"""
|
||||
# base implementation
|
||||
|
||||
class AttnProcessor:
|
||||
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
query = attn.head_to_batch_dim(query)
|
||||
|
||||
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class SwapCrossAttnContext:
|
||||
modified_text_embeddings: torch.Tensor
|
||||
|
@ -1,100 +0,0 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
|
||||
class AttentionMapSaver:
|
||||
def __init__(self, token_ids: range, latents_shape: torch.Size):
|
||||
self.token_ids = token_ids
|
||||
self.latents_shape = latents_shape
|
||||
# self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
|
||||
self.collated_maps: dict[str, torch.Tensor] = {}
|
||||
|
||||
def clear_maps(self):
|
||||
self.collated_maps = {}
|
||||
|
||||
def add_attention_maps(self, maps: torch.Tensor, key: str):
|
||||
"""
|
||||
Accumulate the given attention maps and store by summing with existing maps at the passed-in key (if any).
|
||||
:param maps: Attention maps to store. Expected shape [A, (H*W), N] where A is attention heads count, H and W are the map size (fixed per-key) and N is the number of tokens (typically 77).
|
||||
:param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the maps sizes (H and W) should match.
|
||||
:return: None
|
||||
"""
|
||||
key_and_size = f"{key}_{maps.shape[1]}"
|
||||
|
||||
# extract desired tokens
|
||||
maps = maps[:, :, self.token_ids]
|
||||
|
||||
# merge attention heads to a single map per token
|
||||
maps = torch.sum(maps, 0)
|
||||
|
||||
# store
|
||||
if key_and_size not in self.collated_maps:
|
||||
self.collated_maps[key_and_size] = torch.zeros_like(maps, device="cpu")
|
||||
self.collated_maps[key_and_size] += maps.cpu()
|
||||
|
||||
def write_maps_to_disk(self, path: str):
|
||||
pil_image = self.get_stacked_maps_image()
|
||||
if pil_image is not None:
|
||||
pil_image.save(path, "PNG")
|
||||
|
||||
def get_stacked_maps_image(self) -> Optional[Image.Image]:
|
||||
"""
|
||||
Scale all collected attention maps to the same size, blend them together and return as an image.
|
||||
:return: An image containing a vertical stack of blended attention maps, one for each requested token.
|
||||
"""
|
||||
num_tokens = len(self.token_ids)
|
||||
if num_tokens == 0:
|
||||
return None
|
||||
|
||||
latents_height = self.latents_shape[0]
|
||||
latents_width = self.latents_shape[1]
|
||||
|
||||
merged = None
|
||||
|
||||
for _key, maps in self.collated_maps.items():
|
||||
# maps has shape [(H*W), N] for N tokens
|
||||
# but we want [N, H, W]
|
||||
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
|
||||
this_maps_height = int(float(latents_height) * this_scale_factor)
|
||||
this_maps_width = int(float(latents_width) * this_scale_factor)
|
||||
# and we need to do some dimension juggling
|
||||
maps = torch.reshape(
|
||||
torch.swapdims(maps, 0, 1),
|
||||
[num_tokens, this_maps_height, this_maps_width],
|
||||
)
|
||||
|
||||
# scale to output size if necessary
|
||||
if this_scale_factor != 1:
|
||||
maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)
|
||||
|
||||
# normalize
|
||||
maps_min = torch.min(maps)
|
||||
maps_range = torch.max(maps) - maps_min
|
||||
# print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
|
||||
maps_normalized = (maps - maps_min) / maps_range
|
||||
# expand to (-0.1, 1.1) and clamp
|
||||
maps_normalized_expanded = maps_normalized * 1.1 - 0.05
|
||||
maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)
|
||||
|
||||
# merge together, producing a vertical stack
|
||||
maps_stacked = torch.reshape(
|
||||
maps_normalized_expanded_clamped,
|
||||
[num_tokens * latents_height, latents_width],
|
||||
)
|
||||
|
||||
if merged is None:
|
||||
merged = maps_stacked
|
||||
else:
|
||||
# screen blend
|
||||
merged = 1 - (1 - maps_stacked) * (1 - merged)
|
||||
|
||||
if merged is None:
|
||||
return None
|
||||
|
||||
merged_bytes = merged.mul(0xFF).byte()
|
||||
return Image.fromarray(merged_bytes.numpy(), mode="L")
|
@ -17,13 +17,11 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
)
|
||||
|
||||
from .cross_attention_control import (
|
||||
Context,
|
||||
CrossAttentionType,
|
||||
CrossAttnControlContext,
|
||||
SwapCrossAttnContext,
|
||||
get_cross_attention_modules,
|
||||
setup_cross_attention_control_attention_processors,
|
||||
)
|
||||
from .cross_attention_map_saving import AttentionMapSaver
|
||||
|
||||
ModelForwardCallback: TypeAlias = Union[
|
||||
# x, t, conditioning, Optional[cross-attention kwargs]
|
||||
@ -69,14 +67,12 @@ class InvokeAIDiffuserComponent:
|
||||
self,
|
||||
unet: UNet2DConditionModel,
|
||||
extra_conditioning_info: Optional[ExtraConditioningInfo],
|
||||
step_count: int,
|
||||
):
|
||||
old_attn_processors = unet.attn_processors
|
||||
|
||||
try:
|
||||
self.cross_attention_control_context = Context(
|
||||
self.cross_attention_control_context = CrossAttnControlContext(
|
||||
arguments=extra_conditioning_info.cross_attention_control_args,
|
||||
step_count=step_count,
|
||||
)
|
||||
setup_cross_attention_control_attention_processors(
|
||||
unet,
|
||||
@ -87,27 +83,6 @@ class InvokeAIDiffuserComponent:
|
||||
finally:
|
||||
self.cross_attention_control_context = None
|
||||
unet.set_attn_processor(old_attn_processors)
|
||||
# TODO resuscitate attention map saving
|
||||
# self.remove_attention_map_saving()
|
||||
|
||||
def setup_attention_map_saving(self, saver: AttentionMapSaver):
|
||||
def callback(slice, dim, offset, slice_size, key):
|
||||
if dim is not None:
|
||||
# sliced tokens attention map saving is not implemented
|
||||
return
|
||||
saver.add_attention_maps(slice, key)
|
||||
|
||||
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
|
||||
for identifier, module in tokens_cross_attention_modules:
|
||||
key = "down" if identifier.startswith("down") else "up" if identifier.startswith("up") else "mid"
|
||||
module.set_attention_slice_calculated_callback(
|
||||
lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key)
|
||||
)
|
||||
|
||||
def remove_attention_map_saving(self):
|
||||
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
|
||||
for _, module in tokens_cross_attention_modules:
|
||||
module.set_attention_slice_calculated_callback(None)
|
||||
|
||||
def do_controlnet_step(
|
||||
self,
|
||||
@ -592,54 +567,3 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
self.last_percent_through = percent_through
|
||||
return latents.to(device=dev)
|
||||
|
||||
# todo: make this work
|
||||
@classmethod
|
||||
def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2) # aka sigmas
|
||||
|
||||
deltas = None
|
||||
uncond_latents = None
|
||||
weighted_cond_list = (
|
||||
c_or_weighted_c_list if isinstance(c_or_weighted_c_list, list) else [(c_or_weighted_c_list, 1)]
|
||||
)
|
||||
|
||||
# below is fugly omg
|
||||
conditionings = [uc] + [c for c, weight in weighted_cond_list]
|
||||
weights = [1] + [weight for c, weight in weighted_cond_list]
|
||||
chunk_count = math.ceil(len(conditionings) / 2)
|
||||
deltas = None
|
||||
for chunk_index in range(chunk_count):
|
||||
offset = chunk_index * 2
|
||||
chunk_size = min(2, len(conditionings) - offset)
|
||||
|
||||
if chunk_size == 1:
|
||||
c_in = conditionings[offset]
|
||||
latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
|
||||
latents_b = None
|
||||
else:
|
||||
c_in = torch.cat(conditionings[offset : offset + 2])
|
||||
latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)
|
||||
|
||||
# first chunk is guaranteed to be 2 entries: uncond_latents + first conditioining
|
||||
if chunk_index == 0:
|
||||
uncond_latents = latents_a
|
||||
deltas = latents_b - uncond_latents
|
||||
else:
|
||||
deltas = torch.cat((deltas, latents_a - uncond_latents))
|
||||
if latents_b is not None:
|
||||
deltas = torch.cat((deltas, latents_b - uncond_latents))
|
||||
|
||||
# merge the weighted deltas together into a single merged delta
|
||||
per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
|
||||
normalize = False
|
||||
if normalize:
|
||||
per_delta_weights /= torch.sum(per_delta_weights)
|
||||
reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
|
||||
deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
|
||||
|
||||
# old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
|
||||
# assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
|
||||
|
||||
return uncond_latents + deltas_merged * global_guidance_scale
|
||||
|
@ -134,8 +134,6 @@
|
||||
"loadMore": "Mehr laden",
|
||||
"noImagesInGallery": "Keine Bilder in der Galerie",
|
||||
"loading": "Lade",
|
||||
"preparingDownload": "bereite Download vor",
|
||||
"preparingDownloadFailed": "Problem beim Download vorbereiten",
|
||||
"deleteImage": "Lösche Bild",
|
||||
"copy": "Kopieren",
|
||||
"download": "Runterladen",
|
||||
@ -967,7 +965,7 @@
|
||||
"resumeFailed": "Problem beim Fortsetzen des Prozesses",
|
||||
"pruneFailed": "Problem beim leeren der Warteschlange",
|
||||
"pauseTooltip": "Prozess anhalten",
|
||||
"back": "Hinten",
|
||||
"back": "Ende",
|
||||
"resumeSucceeded": "Prozess wird fortgesetzt",
|
||||
"resumeTooltip": "Prozess wieder aufnehmen",
|
||||
"time": "Zeit",
|
||||
|
@ -741,6 +741,8 @@
|
||||
"customConfig": "Custom Config",
|
||||
"customConfigFileLocation": "Custom Config File Location",
|
||||
"customSaveLocation": "Custom Save Location",
|
||||
"defaultSettings": "Default Settings",
|
||||
"defaultSettingsSaved": "Default Settings Saved",
|
||||
"delete": "Delete",
|
||||
"deleteConfig": "Delete Config",
|
||||
"deleteModel": "Delete Model",
|
||||
@ -852,6 +854,7 @@
|
||||
"upcastAttention": "Upcast Attention",
|
||||
"updateModel": "Update Model",
|
||||
"useCustomConfig": "Use Custom Config",
|
||||
"useDefaultSettings": "Use Default Settings",
|
||||
"v1": "v1",
|
||||
"v2_768": "v2 (768px)",
|
||||
"v2_base": "v2 (512px)",
|
||||
@ -870,6 +873,7 @@
|
||||
"models": {
|
||||
"addLora": "Add LoRA",
|
||||
"allLoRAsAdded": "All LoRAs added",
|
||||
"concepts": "Concepts",
|
||||
"loraAlreadyAdded": "LoRA already added",
|
||||
"esrganModel": "ESRGAN Model",
|
||||
"loading": "loading",
|
||||
|
@ -505,8 +505,6 @@
|
||||
"seamLowThreshold": "Bajo",
|
||||
"coherencePassHeader": "Parámetros de la coherencia",
|
||||
"compositingSettingsHeader": "Ajustes de la composición",
|
||||
"coherenceSteps": "Pasos",
|
||||
"coherenceStrength": "Fuerza",
|
||||
"patchmatchDownScaleSize": "Reducir a escala",
|
||||
"coherenceMode": "Modo"
|
||||
},
|
||||
|
@ -114,7 +114,8 @@
|
||||
"checkpoint": "Checkpoint",
|
||||
"safetensors": "Safetensors",
|
||||
"ai": "ia",
|
||||
"file": "File"
|
||||
"file": "File",
|
||||
"toResolve": "Da risolvere"
|
||||
},
|
||||
"gallery": {
|
||||
"generations": "Generazioni",
|
||||
@ -142,8 +143,6 @@
|
||||
"copy": "Copia",
|
||||
"download": "Scarica",
|
||||
"setCurrentImage": "Imposta come immagine corrente",
|
||||
"preparingDownload": "Preparazione del download",
|
||||
"preparingDownloadFailed": "Problema durante la preparazione del download",
|
||||
"downloadSelection": "Scarica gli elementi selezionati",
|
||||
"noImageSelected": "Nessuna immagine selezionata",
|
||||
"deleteSelection": "Elimina la selezione",
|
||||
@ -609,8 +608,6 @@
|
||||
"seamLowThreshold": "Basso",
|
||||
"seamHighThreshold": "Alto",
|
||||
"coherencePassHeader": "Passaggio di coerenza",
|
||||
"coherenceSteps": "Passi",
|
||||
"coherenceStrength": "Forza",
|
||||
"compositingSettingsHeader": "Impostazioni di composizione",
|
||||
"patchmatchDownScaleSize": "Ridimensiona",
|
||||
"coherenceMode": "Modalità",
|
||||
@ -1400,19 +1397,6 @@
|
||||
"Regola la maschera."
|
||||
]
|
||||
},
|
||||
"compositingCoherenceSteps": {
|
||||
"heading": "Passi",
|
||||
"paragraphs": [
|
||||
"Numero di passi utilizzati nel Passaggio di Coerenza.",
|
||||
"Simile ai passi di generazione."
|
||||
]
|
||||
},
|
||||
"compositingBlur": {
|
||||
"heading": "Sfocatura",
|
||||
"paragraphs": [
|
||||
"Il raggio di sfocatura della maschera."
|
||||
]
|
||||
},
|
||||
"compositingCoherenceMode": {
|
||||
"heading": "Modalità",
|
||||
"paragraphs": [
|
||||
@ -1431,13 +1415,6 @@
|
||||
"Un secondo ciclo di riduzione del rumore aiuta a comporre l'immagine Inpaint/Outpaint."
|
||||
]
|
||||
},
|
||||
"compositingStrength": {
|
||||
"heading": "Forza",
|
||||
"paragraphs": [
|
||||
"Quantità di rumore aggiunta per il Passaggio di Coerenza.",
|
||||
"Simile alla forza di riduzione del rumore."
|
||||
]
|
||||
},
|
||||
"paramNegativeConditioning": {
|
||||
"paragraphs": [
|
||||
"Il processo di generazione evita i concetti nel prompt negativo. Utilizzatelo per escludere qualità o oggetti dall'output.",
|
||||
|
@ -123,8 +123,6 @@
|
||||
"autoSwitchNewImages": "새로운 이미지로 자동 전환",
|
||||
"loading": "불러오는 중",
|
||||
"unableToLoad": "갤러리를 로드할 수 없음",
|
||||
"preparingDownload": "다운로드 준비",
|
||||
"preparingDownloadFailed": "다운로드 준비 중 발생한 문제",
|
||||
"singleColumnLayout": "단일 열 레이아웃",
|
||||
"image": "이미지",
|
||||
"loadMore": "더 불러오기",
|
||||
|
@ -97,8 +97,6 @@
|
||||
"featuresWillReset": "Als je deze afbeelding verwijdert, dan worden deze functies onmiddellijk teruggezet.",
|
||||
"loading": "Bezig met laden",
|
||||
"unableToLoad": "Kan galerij niet laden",
|
||||
"preparingDownload": "Bezig met voorbereiden van download",
|
||||
"preparingDownloadFailed": "Fout bij voorbereiden van download",
|
||||
"downloadSelection": "Download selectie",
|
||||
"currentlyInUse": "Deze afbeelding is momenteel in gebruik door de volgende functies:",
|
||||
"copy": "Kopieer",
|
||||
@ -535,8 +533,6 @@
|
||||
"coherencePassHeader": "Coherentiestap",
|
||||
"maskBlur": "Vervaag",
|
||||
"maskBlurMethod": "Vervagingsmethode",
|
||||
"coherenceSteps": "Stappen",
|
||||
"coherenceStrength": "Sterkte",
|
||||
"seamHighThreshold": "Hoog",
|
||||
"seamLowThreshold": "Laag",
|
||||
"invoke": {
|
||||
@ -1139,13 +1135,6 @@
|
||||
"Een afbeeldingsgrootte (in aantal pixels) equivalent aan 512x512 wordt aanbevolen voor SD1.5-modellen. Een grootte-equivalent van 1024x1024 wordt aanbevolen voor SDXL-modellen."
|
||||
]
|
||||
},
|
||||
"compositingCoherenceSteps": {
|
||||
"heading": "Stappen",
|
||||
"paragraphs": [
|
||||
"Het aantal te gebruiken ontruisingsstappen in de coherentiefase.",
|
||||
"Gelijk aan de hoofdparameter Stappen."
|
||||
]
|
||||
},
|
||||
"dynamicPrompts": {
|
||||
"paragraphs": [
|
||||
"Dynamische prompts vormt een enkele prompt om in vele.",
|
||||
@ -1160,12 +1149,6 @@
|
||||
],
|
||||
"heading": "VAE"
|
||||
},
|
||||
"compositingBlur": {
|
||||
"heading": "Vervaging",
|
||||
"paragraphs": [
|
||||
"De vervagingsstraal van het masker."
|
||||
]
|
||||
},
|
||||
"paramIterations": {
|
||||
"paragraphs": [
|
||||
"Het aantal te genereren afbeeldingen.",
|
||||
@ -1240,13 +1223,6 @@
|
||||
],
|
||||
"heading": "Ontruisingssterkte"
|
||||
},
|
||||
"compositingStrength": {
|
||||
"heading": "Sterkte",
|
||||
"paragraphs": [
|
||||
"Ontruisingssterkte voor de coherentiefase.",
|
||||
"Gelijk aan de parameter Ontruisingssterkte Afbeelding naar afbeelding."
|
||||
]
|
||||
},
|
||||
"paramNegativeConditioning": {
|
||||
"paragraphs": [
|
||||
"Het genereerproces voorkomt de gegeven begrippen in de negatieve prompt. Gebruik dit om bepaalde zaken of voorwerpen uit te sluiten van de uitvoerafbeelding.",
|
||||
|
@ -143,8 +143,6 @@
|
||||
"problemDeletingImagesDesc": "Не удалось удалить одно или несколько изображений",
|
||||
"loading": "Загрузка",
|
||||
"unableToLoad": "Невозможно загрузить галерею",
|
||||
"preparingDownload": "Подготовка к скачиванию",
|
||||
"preparingDownloadFailed": "Проблема с подготовкой к скачиванию",
|
||||
"image": "изображение",
|
||||
"drop": "перебросить",
|
||||
"problemDeletingImages": "Проблема с удалением изображений",
|
||||
@ -612,9 +610,7 @@
|
||||
"maskBlurMethod": "Метод размытия",
|
||||
"seamLowThreshold": "Низкий",
|
||||
"seamHighThreshold": "Высокий",
|
||||
"coherenceSteps": "Шагов",
|
||||
"coherencePassHeader": "Порог Coherence",
|
||||
"coherenceStrength": "Сила",
|
||||
"compositingSettingsHeader": "Настройки компоновки",
|
||||
"invoke": {
|
||||
"noNodesInGraph": "Нет узлов в графе",
|
||||
@ -1321,13 +1317,6 @@
|
||||
"Размер изображения (в пикселях), эквивалентный 512x512, рекомендуется для моделей SD1.5, а размер, эквивалентный 1024x1024, рекомендуется для моделей SDXL."
|
||||
]
|
||||
},
|
||||
"compositingCoherenceSteps": {
|
||||
"heading": "Шаги",
|
||||
"paragraphs": [
|
||||
"Количество шагов снижения шума, используемых при прохождении когерентности.",
|
||||
"То же, что и основной параметр «Шаги»."
|
||||
]
|
||||
},
|
||||
"dynamicPrompts": {
|
||||
"paragraphs": [
|
||||
"Динамические запросы превращают одно приглашение на множество.",
|
||||
@ -1342,12 +1331,6 @@
|
||||
],
|
||||
"heading": "VAE"
|
||||
},
|
||||
"compositingBlur": {
|
||||
"heading": "Размытие",
|
||||
"paragraphs": [
|
||||
"Радиус размытия маски."
|
||||
]
|
||||
},
|
||||
"paramIterations": {
|
||||
"paragraphs": [
|
||||
"Количество изображений, которые нужно сгенерировать.",
|
||||
@ -1422,13 +1405,6 @@
|
||||
],
|
||||
"heading": "Шумоподавление"
|
||||
},
|
||||
"compositingStrength": {
|
||||
"heading": "Сила",
|
||||
"paragraphs": [
|
||||
null,
|
||||
"То же, что параметр «Сила шумоподавления img2img»."
|
||||
]
|
||||
},
|
||||
"paramNegativeConditioning": {
|
||||
"paragraphs": [
|
||||
"Stable Diffusion пытается избежать указанных в отрицательном запросе концепций. Используйте это, чтобы исключить качества или объекты из вывода.",
|
||||
|
@ -355,7 +355,6 @@
|
||||
"starImage": "Yıldız Koy",
|
||||
"download": "İndir",
|
||||
"deleteSelection": "Seçileni Sil",
|
||||
"preparingDownloadFailed": "İndirme Hazırlanırken Sorun",
|
||||
"problemDeletingImages": "Görsel Silmede Sorun",
|
||||
"featuresWillReset": "Bu görseli silerseniz, o özellikler resetlenecektir.",
|
||||
"galleryImageResetSize": "Boyutu Resetle",
|
||||
@ -377,7 +376,6 @@
|
||||
"setCurrentImage": "Çalışma Görseli Yap",
|
||||
"unableToLoad": "Galeri Yüklenemedi",
|
||||
"downloadSelection": "Seçileni İndir",
|
||||
"preparingDownload": "İndirmeye Hazırlanıyor",
|
||||
"singleColumnLayout": "Tek Sütun Düzen",
|
||||
"generations": "Çıktılar",
|
||||
"showUploads": "Yüklenenleri Göster",
|
||||
@ -723,7 +721,6 @@
|
||||
"clipSkip": "CLIP Atlama",
|
||||
"randomizeSeed": "Rastgele Tohum",
|
||||
"cfgScale": "CFG Ölçeği",
|
||||
"coherenceStrength": "Etki",
|
||||
"controlNetControlMode": "Yönetim Kipi",
|
||||
"general": "Genel",
|
||||
"img2imgStrength": "Görselden Görsel Ölçüsü",
|
||||
@ -793,7 +790,6 @@
|
||||
"cfgRescaleMultiplier": "CFG Rescale Çarpanı",
|
||||
"cfgRescale": "CFG Rescale",
|
||||
"coherencePassHeader": "Uyum Geçişi",
|
||||
"coherenceSteps": "Adım",
|
||||
"infillMethod": "Doldurma Yöntemi",
|
||||
"maskBlurMethod": "Bulandırma Yöntemi",
|
||||
"steps": "Adım",
|
||||
|
@ -136,8 +136,6 @@
|
||||
"copy": "复制",
|
||||
"download": "下载",
|
||||
"setCurrentImage": "设为当前图像",
|
||||
"preparingDownload": "准备下载",
|
||||
"preparingDownloadFailed": "准备下载时出现问题",
|
||||
"downloadSelection": "下载所选内容",
|
||||
"noImageSelected": "无选中的图像",
|
||||
"deleteSelection": "删除所选内容",
|
||||
@ -616,11 +614,9 @@
|
||||
"incompatibleBaseModelForControlAdapter": "有 #{{number}} 个 Control Adapter 模型与主模型不兼容。"
|
||||
},
|
||||
"patchmatchDownScaleSize": "缩小",
|
||||
"coherenceSteps": "步数",
|
||||
"clipSkip": "CLIP 跳过层",
|
||||
"compositingSettingsHeader": "合成设置",
|
||||
"useCpuNoise": "使用 CPU 噪声",
|
||||
"coherenceStrength": "强度",
|
||||
"enableNoiseSettings": "启用噪声设置",
|
||||
"coherenceMode": "模式",
|
||||
"cpuNoise": "CPU 噪声",
|
||||
@ -1402,19 +1398,6 @@
|
||||
"图像尺寸(单位:像素)建议 SD 1.5 模型使用等效 512x512 的尺寸,SDXL 模型使用等效 1024x1024 的尺寸。"
|
||||
]
|
||||
},
|
||||
"compositingCoherenceSteps": {
|
||||
"heading": "步数",
|
||||
"paragraphs": [
|
||||
"一致性层中使用的去噪步数。",
|
||||
"与主参数中的步数相同。"
|
||||
]
|
||||
},
|
||||
"compositingBlur": {
|
||||
"heading": "模糊",
|
||||
"paragraphs": [
|
||||
"遮罩模糊半径。"
|
||||
]
|
||||
},
|
||||
"noiseUseCPU": {
|
||||
"heading": "使用 CPU 噪声",
|
||||
"paragraphs": [
|
||||
@ -1467,13 +1450,6 @@
|
||||
"第二轮去噪有助于合成内补/外扩图像。"
|
||||
]
|
||||
},
|
||||
"compositingStrength": {
|
||||
"heading": "强度",
|
||||
"paragraphs": [
|
||||
"一致性层使用的去噪强度。",
|
||||
"去噪强度与图生图的参数相同。"
|
||||
]
|
||||
},
|
||||
"paramNegativeConditioning": {
|
||||
"paragraphs": [
|
||||
"生成过程会避免生成负向提示词中的概念。使用此选项来使输出排除部分质量或对象。",
|
||||
|
@ -55,6 +55,8 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
|
||||
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
|
||||
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
|
||||
@ -153,3 +155,5 @@ addUpscaleRequestedListener(startAppListening);
|
||||
|
||||
// Dynamic prompts
|
||||
addDynamicPromptsListener(startAppListening);
|
||||
|
||||
addSetDefaultSettingsListener(startAppListening);
|
||||
|
@ -0,0 +1,96 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { setDefaultSettings } from 'features/parameters/store/actions';
|
||||
import {
|
||||
setCfgRescaleMultiplier,
|
||||
setCfgScale,
|
||||
setScheduler,
|
||||
setSteps,
|
||||
vaePrecisionChanged,
|
||||
vaeSelected,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import {
|
||||
isParameterCFGRescaleMultiplier,
|
||||
isParameterCFGScale,
|
||||
isParameterPrecision,
|
||||
isParameterScheduler,
|
||||
isParameterSteps,
|
||||
zParameterVAEModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { t } from 'i18next';
|
||||
import { map } from 'lodash-es';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
|
||||
export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: setDefaultSettings,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const state = getState();
|
||||
|
||||
const currentModel = state.generation.model;
|
||||
|
||||
if (!currentModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
const metadata = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(currentModel.key)).unwrap();
|
||||
|
||||
if (!metadata || !metadata.default_settings) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = metadata.default_settings;
|
||||
|
||||
if (vae) {
|
||||
// we store this as "default" within default settings
|
||||
// to distinguish it from no default set
|
||||
if (vae === 'default') {
|
||||
dispatch(vaeSelected(null));
|
||||
} else {
|
||||
const { data } = modelsApi.endpoints.getVaeModels.select()(state);
|
||||
const vaeArray = map(data?.entities);
|
||||
const validVae = vaeArray.find((model) => model.key === vae);
|
||||
|
||||
const result = zParameterVAEModel.safeParse(validVae);
|
||||
if (!result.success) {
|
||||
return;
|
||||
}
|
||||
dispatch(vaeSelected(result.data));
|
||||
}
|
||||
}
|
||||
|
||||
if (vae_precision) {
|
||||
if (isParameterPrecision(vae_precision)) {
|
||||
dispatch(vaePrecisionChanged(vae_precision));
|
||||
}
|
||||
}
|
||||
|
||||
if (cfg_scale) {
|
||||
if (isParameterCFGScale(cfg_scale)) {
|
||||
dispatch(setCfgScale(cfg_scale));
|
||||
}
|
||||
}
|
||||
|
||||
if (cfg_rescale_multiplier) {
|
||||
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
|
||||
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
|
||||
}
|
||||
}
|
||||
|
||||
if (steps) {
|
||||
if (isParameterSteps(steps)) {
|
||||
dispatch(setSteps(steps));
|
||||
}
|
||||
}
|
||||
|
||||
if (scheduler) {
|
||||
if (isParameterScheduler(scheduler)) {
|
||||
dispatch(setScheduler(scheduler));
|
||||
}
|
||||
}
|
||||
|
||||
dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) })));
|
||||
},
|
||||
});
|
||||
};
|
@ -1,4 +1,5 @@
|
||||
import type { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||
import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||
import type { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import type { O } from 'ts-toolbelt';
|
||||
|
||||
@ -82,6 +83,8 @@ export type AppConfig = {
|
||||
guidance: NumericalParameterConfig;
|
||||
cfgRescaleMultiplier: NumericalParameterConfig;
|
||||
img2imgStrength: NumericalParameterConfig;
|
||||
scheduler?: ParameterScheduler;
|
||||
vaePrecision?: ParameterPrecision;
|
||||
// Canvas
|
||||
boundingBoxHeight: NumericalParameterConfig; // initial value comes from model
|
||||
boundingBoxWidth: NumericalParameterConfig; // initial value comes from model
|
||||
|
@ -59,7 +59,7 @@ const LoRASelect = () => {
|
||||
return (
|
||||
<FormControl isDisabled={!options.length}>
|
||||
<InformationalPopover feature="lora">
|
||||
<FormLabel>{t('models.lora')} </FormLabel>
|
||||
<FormLabel>{t('models.concepts')} </FormLabel>
|
||||
</InformationalPopover>
|
||||
<Combobox
|
||||
placeholder={placeholder}
|
||||
|
@ -15,7 +15,7 @@ const STATUSES = {
|
||||
const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus; errorReason?: string | null }) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (!status) {
|
||||
if (!status || !Object.keys(STATUSES).includes(status)) {
|
||||
return <></>;
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,66 @@
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import Loading from 'common/components/Loading/Loading';
|
||||
import { selectConfigSlice } from 'features/system/store/configSlice';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { DefaultSettingsForm } from './DefaultSettings/DefaultSettingsForm';
|
||||
|
||||
const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => {
|
||||
const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision } = config.sd;
|
||||
|
||||
return {
|
||||
initialSteps: steps.initial,
|
||||
initialCfg: guidance.initial,
|
||||
initialScheduler: scheduler,
|
||||
initialCfgRescaleMultiplier: cfgRescaleMultiplier.initial,
|
||||
initialVaePrecision: vaePrecision,
|
||||
};
|
||||
});
|
||||
|
||||
export const DefaultSettings = () => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
|
||||
const { data, isLoading } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
|
||||
const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } =
|
||||
useAppSelector(initialStatesSelector);
|
||||
|
||||
const defaultSettingsDefaults = useMemo(() => {
|
||||
return {
|
||||
vae: { isEnabled: !isNil(data?.default_settings?.vae), value: data?.default_settings?.vae || 'default' },
|
||||
vaePrecision: {
|
||||
isEnabled: !isNil(data?.default_settings?.vae_precision),
|
||||
value: data?.default_settings?.vae_precision || initialVaePrecision || 'fp32',
|
||||
},
|
||||
scheduler: {
|
||||
isEnabled: !isNil(data?.default_settings?.scheduler),
|
||||
value: data?.default_settings?.scheduler || initialScheduler || 'euler',
|
||||
},
|
||||
steps: { isEnabled: !isNil(data?.default_settings?.steps), value: data?.default_settings?.steps || initialSteps },
|
||||
cfgScale: {
|
||||
isEnabled: !isNil(data?.default_settings?.cfg_scale),
|
||||
value: data?.default_settings?.cfg_scale || initialCfg,
|
||||
},
|
||||
cfgRescaleMultiplier: {
|
||||
isEnabled: !isNil(data?.default_settings?.cfg_rescale_multiplier),
|
||||
value: data?.default_settings?.cfg_rescale_multiplier || initialCfgRescaleMultiplier,
|
||||
},
|
||||
};
|
||||
}, [
|
||||
data?.default_settings,
|
||||
initialSteps,
|
||||
initialCfg,
|
||||
initialScheduler,
|
||||
initialCfgRescaleMultiplier,
|
||||
initialVaePrecision,
|
||||
]);
|
||||
|
||||
if (isLoading) {
|
||||
return <Loading />;
|
||||
}
|
||||
|
||||
return <DefaultSettingsForm defaultSettingsDefaults={defaultSettingsDefaults} />;
|
||||
};
|
@ -0,0 +1,72 @@
|
||||
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultCfgRescaleMultiplierType = DefaultSettingsFormData['cfgRescaleMultiplier'];
|
||||
|
||||
export function DefaultCfgRescaleMultiplier(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMin);
|
||||
const sliderMax = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMax);
|
||||
const numberInputMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.numberInputMin);
|
||||
const numberInputMax = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.numberInputMax);
|
||||
const coarseStep = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.coarseStep);
|
||||
const fineStep = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.fineStep);
|
||||
const { t } = useTranslation();
|
||||
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultCfgRescaleMultiplierType),
|
||||
value: v,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return (field.value as DefaultCfgRescaleMultiplierType).value;
|
||||
}, [field.value]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultCfgRescaleMultiplierType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramCFGRescaleMultiplier">
|
||||
<FormLabel>{t('parameters.cfgRescaleMultiplier')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Flex w="full" gap={1}>
|
||||
<CompositeSlider
|
||||
value={value}
|
||||
min={sliderMin}
|
||||
max={sliderMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
marks={marks}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={value}
|
||||
min={numberInputMin}
|
||||
max={numberInputMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
}
|
@ -0,0 +1,72 @@
|
||||
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultCfgType = DefaultSettingsFormData['cfgScale'];
|
||||
|
||||
export function DefaultCfgScale(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.guidance.sliderMin);
|
||||
const sliderMax = useAppSelector((s) => s.config.sd.guidance.sliderMax);
|
||||
const numberInputMin = useAppSelector((s) => s.config.sd.guidance.numberInputMin);
|
||||
const numberInputMax = useAppSelector((s) => s.config.sd.guidance.numberInputMax);
|
||||
const coarseStep = useAppSelector((s) => s.config.sd.guidance.coarseStep);
|
||||
const fineStep = useAppSelector((s) => s.config.sd.guidance.fineStep);
|
||||
const { t } = useTranslation();
|
||||
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultCfgType),
|
||||
value: v,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return (field.value as DefaultCfgType).value;
|
||||
}, [field.value]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultCfgType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramCFGScale">
|
||||
<FormLabel>{t('parameters.cfgScale')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Flex w="full" gap={1}>
|
||||
<CompositeSlider
|
||||
value={value}
|
||||
min={sliderMin}
|
||||
max={sliderMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
marks={marks}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={value}
|
||||
min={numberInputMin}
|
||||
max={numberInputMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
}
|
@ -0,0 +1,50 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { SCHEDULER_OPTIONS } from 'features/parameters/types/constants';
|
||||
import { isParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultSchedulerType = DefaultSettingsFormData['scheduler'];
|
||||
|
||||
export function DefaultScheduler(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!isParameterScheduler(v?.value)) {
|
||||
return;
|
||||
}
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultSchedulerType),
|
||||
value: v.value,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(
|
||||
() => SCHEDULER_OPTIONS.find((o) => o.value === (field.value as DefaultSchedulerType).value),
|
||||
[field]
|
||||
);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultSchedulerType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramScheduler">
|
||||
<FormLabel>{t('parameters.scheduler')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Combobox isDisabled={isDisabled} value={value} options={SCHEDULER_OPTIONS} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
}
|
@ -0,0 +1,147 @@
|
||||
import { Button, Flex, Heading } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { IoPencil } from 'react-icons/io5';
|
||||
import { useUpdateModelMetadataMutation } from 'services/api/endpoints/models';
|
||||
|
||||
import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
|
||||
import { DefaultCfgScale } from './DefaultCfgScale';
|
||||
import { DefaultScheduler } from './DefaultScheduler';
|
||||
import { DefaultSteps } from './DefaultSteps';
|
||||
import { DefaultVae } from './DefaultVae';
|
||||
import { DefaultVaePrecision } from './DefaultVaePrecision';
|
||||
import { SettingToggle } from './SettingToggle';
|
||||
|
||||
export interface FormField<T> {
|
||||
value: T;
|
||||
isEnabled: boolean;
|
||||
}
|
||||
|
||||
export type DefaultSettingsFormData = {
|
||||
vae: FormField<string>;
|
||||
vaePrecision: FormField<string>;
|
||||
scheduler: FormField<ParameterScheduler>;
|
||||
steps: FormField<number>;
|
||||
cfgScale: FormField<number>;
|
||||
cfgRescaleMultiplier: FormField<number>;
|
||||
};
|
||||
|
||||
export const DefaultSettingsForm = ({
|
||||
defaultSettingsDefaults,
|
||||
}: {
|
||||
defaultSettingsDefaults: DefaultSettingsFormData;
|
||||
}) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
|
||||
const [editModelMetadata, { isLoading }] = useUpdateModelMetadataMutation();
|
||||
|
||||
const { handleSubmit, control, formState } = useForm<DefaultSettingsFormData>({
|
||||
defaultValues: defaultSettingsDefaults,
|
||||
});
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<DefaultSettingsFormData>>(
|
||||
(data) => {
|
||||
if (!selectedModelKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
const body = {
|
||||
vae: data.vae.isEnabled ? data.vae.value : null,
|
||||
vae_precision: data.vaePrecision.isEnabled ? data.vaePrecision.value : null,
|
||||
cfg_scale: data.cfgScale.isEnabled ? data.cfgScale.value : null,
|
||||
cfg_rescale_multiplier: data.cfgRescaleMultiplier.isEnabled ? data.cfgRescaleMultiplier.value : null,
|
||||
steps: data.steps.isEnabled ? data.steps.value : null,
|
||||
scheduler: data.scheduler.isEnabled ? data.scheduler.value : null,
|
||||
};
|
||||
|
||||
editModelMetadata({
|
||||
key: selectedModelKey,
|
||||
body: { default_settings: body },
|
||||
})
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.defaultSettingsSaved'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${error.data.detail} `,
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
},
|
||||
[selectedModelKey, dispatch, editModelMetadata, t]
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex gap="2" justifyContent="space-between" w="full" mb={5}>
|
||||
<Heading fontSize="md">{t('modelManager.defaultSettings')}</Heading>
|
||||
<Button
|
||||
size="sm"
|
||||
leftIcon={<IoPencil />}
|
||||
colorScheme="invokeYellow"
|
||||
isDisabled={!formState.isDirty}
|
||||
onClick={handleSubmit(onSubmit)}
|
||||
type="submit"
|
||||
isLoading={isLoading}
|
||||
>
|
||||
{t('common.save')}
|
||||
</Button>
|
||||
</Flex>
|
||||
|
||||
<Flex flexDir="column" gap={8}>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="vae" />
|
||||
<DefaultVae control={control} name="vae" />
|
||||
</Flex>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="vaePrecision" />
|
||||
<DefaultVaePrecision control={control} name="vaePrecision" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="scheduler" />
|
||||
<DefaultScheduler control={control} name="scheduler" />
|
||||
</Flex>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="steps" />
|
||||
<DefaultSteps control={control} name="steps" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="cfgScale" />
|
||||
<DefaultCfgScale control={control} name="cfgScale" />
|
||||
</Flex>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="cfgRescaleMultiplier" />
|
||||
<DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</>
|
||||
);
|
||||
};
|
@ -0,0 +1,72 @@
|
||||
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultSteps = DefaultSettingsFormData['steps'];
|
||||
|
||||
export function DefaultSteps(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.steps.sliderMin);
|
||||
const sliderMax = useAppSelector((s) => s.config.sd.steps.sliderMax);
|
||||
const numberInputMin = useAppSelector((s) => s.config.sd.steps.numberInputMin);
|
||||
const numberInputMax = useAppSelector((s) => s.config.sd.steps.numberInputMax);
|
||||
const coarseStep = useAppSelector((s) => s.config.sd.steps.coarseStep);
|
||||
const fineStep = useAppSelector((s) => s.config.sd.steps.fineStep);
|
||||
const { t } = useTranslation();
|
||||
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultSteps),
|
||||
value: v,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return (field.value as DefaultSteps).value;
|
||||
}, [field.value]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultSteps).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramSteps">
|
||||
<FormLabel>{t('parameters.steps')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Flex w="full" gap={1}>
|
||||
<CompositeSlider
|
||||
value={value}
|
||||
min={sliderMin}
|
||||
max={sliderMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
marks={marks}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={value}
|
||||
min={numberInputMin}
|
||||
max={numberInputMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
}
|
@ -0,0 +1,65 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { map } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelConfigQuery, useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultVaeType = DefaultSettingsFormData['vae'];
|
||||
|
||||
export function DefaultVae(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data: modelData } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
|
||||
const { compatibleOptions } = useGetVaeModelsQuery(undefined, {
|
||||
selectFromResult: ({ data }) => {
|
||||
const modelArray = map(data?.entities);
|
||||
const compatibleOptions = modelArray
|
||||
.filter((vae) => vae.base === modelData?.base)
|
||||
.map((vae) => ({ label: vae.name, value: vae.key }));
|
||||
|
||||
const defaultOption = { label: 'Default VAE', value: 'default' };
|
||||
|
||||
return { compatibleOptions: [defaultOption, ...compatibleOptions] };
|
||||
},
|
||||
});
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
const newValue = !v?.value ? 'default' : v.value;
|
||||
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultVaeType),
|
||||
value: newValue,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return compatibleOptions.find((vae) => vae.value === (field.value as DefaultVaeType).value);
|
||||
}, [compatibleOptions, field.value]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultVaeType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramVAE">
|
||||
<FormLabel>{t('modelManager.vae')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Combobox isDisabled={isDisabled} value={value} options={compatibleOptions} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
}
|
@ -0,0 +1,51 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { isParameterPrecision } from 'features/parameters/types/parameterSchemas';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
const options = [
|
||||
{ label: 'FP16', value: 'fp16' },
|
||||
{ label: 'FP32', value: 'fp32' },
|
||||
];
|
||||
|
||||
type DefaultVaePrecisionType = DefaultSettingsFormData['vaePrecision'];
|
||||
|
||||
export function DefaultVaePrecision(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!isParameterPrecision(v?.value)) {
|
||||
return;
|
||||
}
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultVaePrecisionType),
|
||||
value: v.value,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => options.find((o) => o.value === (field.value as DefaultVaePrecisionType).value), [field]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultVaePrecisionType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramVAEPrecision">
|
||||
<FormLabel>{t('modelManager.vaePrecision')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Combobox isDisabled={isDisabled} value={value} options={options} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
}
|
@ -0,0 +1,28 @@
|
||||
import { Switch } from '@invoke-ai/ui-library';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
|
||||
import type { DefaultSettingsFormData, FormField } from './DefaultSettingsForm';
|
||||
|
||||
export function SettingToggle<T>(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return !!(field.value as FormField<T>).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
const updatedValue: FormField<T> = {
|
||||
...(field.value as FormField<T>),
|
||||
isEnabled: e.target.checked,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
return <Switch isChecked={value} onChange={onChange} />;
|
||||
}
|
@ -17,6 +17,7 @@ import type {
|
||||
VAEModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
import { DefaultSettings } from './DefaultSettings';
|
||||
import { ModelAttrView } from './ModelAttrView';
|
||||
import { ModelConvert } from './ModelConvert';
|
||||
|
||||
@ -71,7 +72,7 @@ export const ModelView = () => {
|
||||
return <Text>{t('common.somethingWentWrong')}</Text>;
|
||||
}
|
||||
return (
|
||||
<Flex flexDir="column" h="full">
|
||||
<Flex flexDir="column" h="full" gap="2">
|
||||
<Box layerStyle="second" borderRadius="base" p={3}>
|
||||
<Flex gap="2" justifyContent="flex-end" w="full">
|
||||
<Button size="sm" leftIcon={<IoPencil />} colorScheme="invokeYellow" onClick={handleEditModel}>
|
||||
@ -118,6 +119,9 @@ export const ModelView = () => {
|
||||
)}
|
||||
</Flex>
|
||||
</Box>
|
||||
<Box layerStyle="second" borderRadius="base" p={3}>
|
||||
<DefaultSettings />
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -344,8 +344,8 @@ export const buildCanvasInpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'expanded_mask_area',
|
||||
},
|
||||
destination: {
|
||||
node_id: MASK_RESIZE_DOWN,
|
||||
|
@ -439,8 +439,8 @@ export const buildCanvasOutpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'expanded_mask_area',
|
||||
},
|
||||
destination: {
|
||||
node_id: MASK_RESIZE_DOWN,
|
||||
|
@ -355,8 +355,8 @@ export const buildCanvasSDXLInpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'expanded_mask_area',
|
||||
},
|
||||
destination: {
|
||||
node_id: MASK_RESIZE_DOWN,
|
||||
|
@ -448,8 +448,8 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'expanded_mask_area',
|
||||
},
|
||||
destination: {
|
||||
node_id: MASK_RESIZE_DOWN,
|
||||
|
@ -0,0 +1,36 @@
|
||||
import type { IconButtonProps } from '@invoke-ai/ui-library';
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiGearSixBold } from 'react-icons/pi';
|
||||
|
||||
export const NavigateToModelManagerButton = memo((props: Omit<IconButtonProps, 'aria-label'>) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
|
||||
const shouldShowButton = useMemo(() => !disabledTabs.includes('modelManager'), [disabledTabs]);
|
||||
|
||||
const handleClick = useCallback(() => {
|
||||
dispatch(setActiveTab('modelManager'));
|
||||
}, [dispatch]);
|
||||
|
||||
if (!shouldShowButton) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
icon={<PiGearSixBold />}
|
||||
tooltip={t('modelManager.modelManager')}
|
||||
aria-label={t('modelManager.modelManager')}
|
||||
onClick={handleClick}
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
NavigateToModelManagerButton.displayName = 'NavigateToModelManagerButton';
|
@ -0,0 +1,28 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { setDefaultSettings } from 'features/parameters/store/actions';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { RiSparklingFill } from 'react-icons/ri';
|
||||
|
||||
export const UseDefaultSettingsButton = () => {
|
||||
const model = useAppSelector((s) => s.generation.model);
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleClickDefaultSettings = useCallback(() => {
|
||||
dispatch(setDefaultSettings());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
icon={<RiSparklingFill />}
|
||||
tooltip={t('modelManager.useDefaultSettings')}
|
||||
aria-label={t('modelManager.useDefaultSettings')}
|
||||
isDisabled={!model}
|
||||
onClick={handleClickDefaultSettings}
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
/>
|
||||
);
|
||||
};
|
@ -5,3 +5,5 @@ import type { ImageDTO } from 'services/api/types';
|
||||
export const initialImageSelected = createAction<ImageDTO | undefined>('generation/initialImageSelected');
|
||||
|
||||
export const modelSelected = createAction<ParameterModel>('generation/modelSelected');
|
||||
|
||||
export const setDefaultSettings = createAction('generation/setDefaultSettings');
|
||||
|
@ -230,6 +230,12 @@ export const generationSlice = createSlice({
|
||||
state.height = optimalDimension;
|
||||
}
|
||||
}
|
||||
if (action.payload.sd?.scheduler) {
|
||||
state.scheduler = action.payload.sd.scheduler;
|
||||
}
|
||||
if (action.payload.sd?.vaePrecision) {
|
||||
state.vaePrecision = action.payload.sd.vaePrecision;
|
||||
}
|
||||
});
|
||||
|
||||
// TODO: This is a temp fix to reduce issues with T2I adapter having a different downscaling
|
||||
|
@ -1,15 +1,5 @@
|
||||
import type { FormLabelProps } from '@invoke-ai/ui-library';
|
||||
import {
|
||||
Expander,
|
||||
Flex,
|
||||
FormControlGroup,
|
||||
StandaloneAccordion,
|
||||
Tab,
|
||||
TabList,
|
||||
TabPanel,
|
||||
TabPanels,
|
||||
Tabs,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { Box, Expander, Flex, FormControlGroup, StandaloneAccordion } from '@invoke-ai/ui-library';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
@ -20,7 +10,9 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod
|
||||
import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
|
||||
import ParamScheduler from 'features/parameters/components/Core/ParamScheduler';
|
||||
import ParamSteps from 'features/parameters/components/Core/ParamSteps';
|
||||
import { NavigateToModelManagerButton } from 'features/parameters/components/MainModel/NavigateToModelManagerButton';
|
||||
import ParamMainModelSelect from 'features/parameters/components/MainModel/ParamMainModelSelect';
|
||||
import { UseDefaultSettingsButton } from 'features/parameters/components/MainModel/UseDefaultSettingsButton';
|
||||
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
|
||||
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
||||
import { filter } from 'lodash-es';
|
||||
@ -39,11 +31,11 @@ export const GenerationSettingsAccordion = memo(() => {
|
||||
() =>
|
||||
createMemoizedSelector(selectLoraSlice, (lora) => {
|
||||
const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length;
|
||||
const loraTabBadges = enabledLoRAsCount ? [enabledLoRAsCount] : EMPTY_ARRAY;
|
||||
const loraTabBadges = enabledLoRAsCount ? [`${enabledLoRAsCount} ${t('models.concepts')}`] : EMPTY_ARRAY;
|
||||
const accordionBadges = modelConfig ? [modelConfig.name, modelConfig.base] : EMPTY_ARRAY;
|
||||
return { loraTabBadges, accordionBadges };
|
||||
}),
|
||||
[modelConfig]
|
||||
[modelConfig, t]
|
||||
);
|
||||
const { loraTabBadges, accordionBadges } = useAppSelector(selectBadges);
|
||||
const { isOpen: isOpenExpander, onToggle: onToggleExpander } = useExpanderToggle({
|
||||
@ -58,39 +50,35 @@ export const GenerationSettingsAccordion = memo(() => {
|
||||
return (
|
||||
<StandaloneAccordion
|
||||
label={t('accordions.generation.title')}
|
||||
badges={accordionBadges}
|
||||
badges={[...accordionBadges, ...loraTabBadges]}
|
||||
isOpen={isOpenAccordion}
|
||||
onToggle={onToggleAccordion}
|
||||
>
|
||||
<Tabs variant="collapse">
|
||||
<TabList>
|
||||
<Tab>{t('accordions.generation.modelTab')}</Tab>
|
||||
<Tab badges={loraTabBadges}>{t('accordions.generation.conceptsTab')}</Tab>
|
||||
</TabList>
|
||||
<TabPanels>
|
||||
<TabPanel overflow="visible" px={4} pt={4}>
|
||||
<Flex gap={4} alignItems="center">
|
||||
<ParamMainModelSelect />
|
||||
<Box px={4} pt={4}>
|
||||
<Flex gap={4} flexDir="column">
|
||||
<Flex gap={4} alignItems="center">
|
||||
<ParamMainModelSelect />
|
||||
<Flex>
|
||||
<UseDefaultSettingsButton />
|
||||
<SyncModelsIconButton />
|
||||
<NavigateToModelManagerButton />
|
||||
</Flex>
|
||||
<Expander isOpen={isOpenExpander} onToggle={onToggleExpander}>
|
||||
<Flex gap={4} flexDir="column" pb={4}>
|
||||
<FormControlGroup formLabelProps={formLabelProps}>
|
||||
<ParamScheduler />
|
||||
<ParamSteps />
|
||||
<ParamCFGScale />
|
||||
</FormControlGroup>
|
||||
</Flex>
|
||||
</Expander>
|
||||
</TabPanel>
|
||||
<TabPanel>
|
||||
<Flex gap={4} p={4} flexDir="column">
|
||||
<LoRASelect />
|
||||
<LoRAList />
|
||||
</Flex>
|
||||
</TabPanel>
|
||||
</TabPanels>
|
||||
</Tabs>
|
||||
</Flex>
|
||||
<Flex gap={4} flexDir="column">
|
||||
<LoRASelect />
|
||||
<LoRAList />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Expander isOpen={isOpenExpander} onToggle={onToggleExpander}>
|
||||
<Flex gap={4} flexDir="column" pb={4}>
|
||||
<FormControlGroup formLabelProps={formLabelProps}>
|
||||
<ParamScheduler />
|
||||
<ParamSteps />
|
||||
<ParamCFGScale />
|
||||
</FormControlGroup>
|
||||
</Flex>
|
||||
</Expander>
|
||||
</Box>
|
||||
</StandaloneAccordion>
|
||||
);
|
||||
});
|
||||
|
@ -41,6 +41,8 @@ const initialConfigState: AppConfig = {
|
||||
boundingBoxHeight: { ...baseDimensionConfig },
|
||||
scaledBoundingBoxWidth: { ...baseDimensionConfig },
|
||||
scaledBoundingBoxHeight: { ...baseDimensionConfig },
|
||||
scheduler: 'euler',
|
||||
vaePrecision: 'fp32',
|
||||
steps: {
|
||||
initial: 30,
|
||||
sliderMin: 1,
|
||||
|
File diff suppressed because one or more lines are too long
@ -51,12 +51,12 @@ dependencies = [
|
||||
"torchmetrics==0.11.4",
|
||||
"torchsde==0.2.6",
|
||||
"torchvision==0.16.2",
|
||||
"transformers==4.37.2",
|
||||
"transformers==4.38.2",
|
||||
|
||||
# Core application dependencies, pinned for reproducible builds.
|
||||
"fastapi-events==0.10.1",
|
||||
"fastapi==0.109.2",
|
||||
"huggingface-hub==0.20.3",
|
||||
"huggingface-hub==0.21.3",
|
||||
"pydantic-settings==2.1.0",
|
||||
"pydantic==2.6.1",
|
||||
"python-socketio==5.11.1",
|
||||
@ -64,6 +64,7 @@ dependencies = [
|
||||
|
||||
# Auxiliary dependencies, pinned only if necessary.
|
||||
"albumentations",
|
||||
"blake3",
|
||||
"click",
|
||||
"datasets",
|
||||
"Deprecated",
|
||||
@ -72,7 +73,6 @@ dependencies = [
|
||||
"easing-functions",
|
||||
"einops",
|
||||
"facexlib",
|
||||
"imohash",
|
||||
"matplotlib", # needed for plotting of Penner easing functions
|
||||
"npyscreen",
|
||||
"omegaconf",
|
||||
|
@ -3,6 +3,7 @@ Test the model installer
|
||||
"""
|
||||
|
||||
import platform
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@ -30,9 +31,8 @@ def test_registration(mm2_installer: ModelInstallServiceBase, embedding_file: Pa
|
||||
matches = store.search_by_attr(model_name="test_embedding")
|
||||
assert len(matches) == 0
|
||||
key = mm2_installer.register_path(embedding_file)
|
||||
assert key is not None
|
||||
assert key != "<NOKEY>"
|
||||
assert len(key) == 32
|
||||
# Not raising here is sufficient - key should be UUIDv4
|
||||
uuid.UUID(key, version=4)
|
||||
|
||||
|
||||
def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
||||
|
96
tests/test_model_hash.py
Normal file
96
tests/test_model_hash.py
Normal file
@ -0,0 +1,96 @@
|
||||
# pyright:reportPrivateUsage=false
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
import pytest
|
||||
from blake3 import blake3
|
||||
|
||||
from invokeai.backend.model_manager.hash import ALGORITHM, MODEL_FILE_EXTENSIONS, ModelHash
|
||||
|
||||
test_cases: list[tuple[ALGORITHM, str]] = [
|
||||
("md5", "a0cd925fc063f98dbf029eee315060c3"),
|
||||
("sha1", "9e362940e5603fdc60566ea100a288ba2fe48b8c"),
|
||||
("sha256", "6dbdb6a147ad4d808455652bf5a10120161678395f6bfbd21eb6fe4e731aceeb"),
|
||||
(
|
||||
"sha512",
|
||||
"c4a10476b21e00042f638ad5755c561d91f2bb599d3504d25409495e1c7eda94543332a1a90fbb4efdaf9ee462c33e0336b5eae4acfb1fa0b186af452dd67dc6",
|
||||
),
|
||||
("blake3", "ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("algorithm,expected_hash", test_cases)
|
||||
def test_model_hash_hashes_file(tmp_path: Path, algorithm: ALGORITHM, expected_hash: str):
|
||||
file = Path(tmp_path / "test")
|
||||
file.write_text("model data")
|
||||
md5 = ModelHash(algorithm).hash(file)
|
||||
assert md5 == expected_hash
|
||||
|
||||
|
||||
@pytest.mark.parametrize("algorithm", ["md5", "sha1", "sha256", "sha512", "blake3"])
|
||||
def test_model_hash_hashes_dir(tmp_path: Path, algorithm: ALGORITHM):
|
||||
model_hash = ModelHash(algorithm)
|
||||
files = [Path(tmp_path, f"{i}.bin") for i in range(5)]
|
||||
|
||||
for f in files:
|
||||
f.write_text("data")
|
||||
|
||||
md5 = model_hash.hash(tmp_path)
|
||||
|
||||
# Manual implementation of composite hash - always uses BLAKE3
|
||||
composite_hasher = blake3()
|
||||
for f in files:
|
||||
h = model_hash.hash(f)
|
||||
composite_hasher.update(h.encode("utf-8"))
|
||||
|
||||
assert md5 == composite_hasher.hexdigest()
|
||||
|
||||
|
||||
def test_model_hash_raises_error_on_invalid_algorithm():
|
||||
with pytest.raises(ValueError, match="Algorithm invalid_algorithm not available"):
|
||||
ModelHash("invalid_algorithm") # pyright: ignore [reportArgumentType]
|
||||
|
||||
|
||||
def paths_to_str_set(paths: Iterable[Path]) -> set[str]:
|
||||
return {str(p) for p in paths}
|
||||
|
||||
|
||||
def test_model_hash_filters_out_non_model_files(tmp_path: Path):
|
||||
model_files = {Path(tmp_path, f"{i}{ext}") for i, ext in enumerate(MODEL_FILE_EXTENSIONS)}
|
||||
|
||||
for i, f in enumerate(model_files):
|
||||
f.write_text(f"data{i}")
|
||||
|
||||
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set(
|
||||
model_files
|
||||
)
|
||||
|
||||
# Add file that should be ignored - hash should not change
|
||||
file = tmp_path / "test.icecream"
|
||||
file.write_text("data")
|
||||
|
||||
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set(
|
||||
model_files
|
||||
)
|
||||
|
||||
# Add file that should not be ignored - hash should change
|
||||
file = tmp_path / "test.bin"
|
||||
file.write_text("more data")
|
||||
model_files.add(file)
|
||||
|
||||
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set(
|
||||
model_files
|
||||
)
|
||||
|
||||
|
||||
def test_model_hash_uses_custom_filter(tmp_path: Path):
|
||||
model_files = {Path(tmp_path, f"file{ext}") for ext in [".pickme", ".ignoreme"]}
|
||||
|
||||
for i, f in enumerate(model_files):
|
||||
f.write_text(f"data{i}")
|
||||
|
||||
def file_filter(file_path: str) -> bool:
|
||||
return file_path.endswith(".pickme")
|
||||
|
||||
assert {p.name for p in ModelHash._get_file_paths(tmp_path, file_filter)} == {"file.pickme"}
|
Loading…
Reference in New Issue
Block a user