mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into ryan/regional-conditioning-tuning
This commit is contained in:
@ -14,6 +14,7 @@ from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallJob
|
||||
from invokeai.app.services.model_metadata.metadata_store_base import ModelMetadataChanges
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
@ -32,6 +33,7 @@ from invokeai.backend.model_manager.config import (
|
||||
)
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import BaseMetadata
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
@ -243,6 +245,47 @@ async def get_model_metadata(
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.patch(
|
||||
"/i/{key}/metadata",
|
||||
operation_id="update_model_metadata",
|
||||
responses={
|
||||
201: {
|
||||
"description": "The model metadata was updated successfully",
|
||||
"content": {"application/json": {"example": example_model_metadata}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def update_model_metadata(
|
||||
key: str = Path(description="Key of the model repo metadata to fetch."),
|
||||
changes: ModelMetadataChanges = Body(description="The changes"),
|
||||
) -> Optional[AnyModelRepoMetadata]:
|
||||
"""Updates or creates a model metadata object."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store
|
||||
|
||||
try:
|
||||
original_metadata = record_store.get_metadata(key)
|
||||
if original_metadata:
|
||||
if changes.default_settings:
|
||||
original_metadata.default_settings = changes.default_settings
|
||||
|
||||
metadata_store.update_metadata(key, original_metadata)
|
||||
else:
|
||||
metadata_store.add_metadata(
|
||||
key, BaseMetadata(name="", author="", default_settings=changes.default_settings)
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"An error occurred while updating the model metadata: {e}",
|
||||
)
|
||||
|
||||
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/tags",
|
||||
operation_id="list_tags",
|
||||
@ -451,6 +494,7 @@ async def add_model_record(
|
||||
)
|
||||
async def install_model(
|
||||
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
|
||||
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
|
||||
# TODO(MM2): Can we type this?
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
@ -493,6 +537,7 @@ async def install_model(
|
||||
source=source,
|
||||
config=config,
|
||||
access_token=access_token,
|
||||
inplace=bool(inplace),
|
||||
)
|
||||
logger.info(f"Started installation of {source}")
|
||||
except UnknownModelException as e:
|
||||
|
@ -181,6 +181,16 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("gradient_mask_output")
|
||||
class GradientMaskOutput(BaseInvocationOutput):
|
||||
"""Outputs a denoise mask and an image representing the total gradient of the mask."""
|
||||
|
||||
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
|
||||
expanded_mask_area: ImageField = OutputField(
|
||||
description="Image representing the total gradient area of the mask. For paste-back purposes."
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"create_gradient_mask",
|
||||
title="Create Gradient Mask",
|
||||
@ -201,38 +211,42 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
|
||||
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
|
||||
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
|
||||
if self.coherence_mode == "Box Blur":
|
||||
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
|
||||
else: # Gaussian Blur OR Staged
|
||||
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
|
||||
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
|
||||
if self.edge_radius > 0:
|
||||
if self.coherence_mode == "Box Blur":
|
||||
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
|
||||
else: # Gaussian Blur OR Staged
|
||||
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
|
||||
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
|
||||
|
||||
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
|
||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
|
||||
|
||||
# redistribute blur so that the edges are 0 and blur out to 1
|
||||
blur_tensor = (blur_tensor - 0.5) * 2
|
||||
# redistribute blur so that the original edges are 0 and blur outwards to 1
|
||||
blur_tensor = (blur_tensor - 0.5) * 2
|
||||
|
||||
threshold = 1 - self.minimum_denoise
|
||||
threshold = 1 - self.minimum_denoise
|
||||
|
||||
if self.coherence_mode == "Staged":
|
||||
# wherever the blur_tensor is less than fully masked, convert it to threshold
|
||||
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
|
||||
else:
|
||||
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
|
||||
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
|
||||
|
||||
if self.coherence_mode == "Staged":
|
||||
# wherever the blur_tensor is masked to any degree, convert it to threshold
|
||||
blur_tensor = torch.where((blur_tensor < 1), threshold, blur_tensor)
|
||||
else:
|
||||
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
|
||||
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
|
||||
|
||||
# multiply original mask to force actually masked regions to 0
|
||||
blur_tensor = mask_tensor * blur_tensor
|
||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
|
||||
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
|
||||
|
||||
return DenoiseMaskOutput.build(
|
||||
mask_name=mask_name,
|
||||
masked_latents_name=None,
|
||||
gradient=True,
|
||||
# compute a [0, 1] mask from the blur_tensor
|
||||
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
|
||||
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
|
||||
expanded_image_dto = context.images.save(expanded_mask_image)
|
||||
|
||||
return GradientMaskOutput(
|
||||
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True),
|
||||
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
|
||||
)
|
||||
|
||||
|
||||
@ -518,7 +532,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
def get_conditioning_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
unet,
|
||||
unet: UNet2DConditionModel,
|
||||
latent_height: int,
|
||||
latent_width: int,
|
||||
) -> TextConditioningData:
|
||||
|
@ -7,7 +7,6 @@ import time
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
from random import randbytes
|
||||
from shutil import copyfile, copytree, move, rmtree
|
||||
from tempfile import mkdtemp
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
@ -21,6 +20,7 @@ from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
@ -150,7 +150,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
config = config or {}
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
config["key"] = config.get("key", self._create_key())
|
||||
config["key"] = config.get("key", uuid_string())
|
||||
|
||||
info: AnyModelConfig = self._probe_model(Path(model_path), config)
|
||||
|
||||
@ -178,13 +178,14 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: bool = False,
|
||||
) -> ModelInstallJob:
|
||||
variants = "|".join(ModelRepoVariant.__members__.values())
|
||||
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
||||
source_obj: Optional[StringLikeSource] = None
|
||||
|
||||
if Path(source).exists(): # A local file or directory
|
||||
source_obj = LocalModelSource(path=Path(source))
|
||||
source_obj = LocalModelSource(path=Path(source), inplace=inplace)
|
||||
elif match := re.match(hf_repoid_re, source):
|
||||
source_obj = HFModelSource(
|
||||
repo_id=match.group(1),
|
||||
@ -526,16 +527,17 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
setattr(info, key, value)
|
||||
return info
|
||||
|
||||
def _create_key(self) -> str:
|
||||
return sha256(randbytes(100)).hexdigest()[0:32]
|
||||
|
||||
def _register(
|
||||
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
||||
) -> str:
|
||||
# Note that we may be passed a pre-populated AnyModelConfig object,
|
||||
# in which case the key field should have been populated by the caller (e.g. in `install_path`).
|
||||
config["key"] = config.get("key", self._create_key())
|
||||
config["key"] = config.get("key", uuid_string())
|
||||
info = info or ModelProbe.probe(model_path, config)
|
||||
override_key: Optional[str] = config.get("key") if config else None
|
||||
|
||||
assert info.original_hash # always assigned by probe()
|
||||
info.key = override_key or info.original_hash
|
||||
|
||||
model_path = model_path.absolute()
|
||||
if model_path.is_relative_to(self.app_config.models_path):
|
||||
|
@ -4,9 +4,25 @@ Storage for Model Metadata
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Set, Tuple
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import ModelDefaultSettings
|
||||
|
||||
|
||||
class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"):
|
||||
"""A set of changes to apply to model metadata.
|
||||
Only limited changes are valid:
|
||||
- `default_settings`: the user-configured default settings for this model
|
||||
"""
|
||||
|
||||
default_settings: Optional[ModelDefaultSettings] = Field(
|
||||
default=None, description="The user-configured default settings for this model"
|
||||
)
|
||||
"""The user-configured default settings for this model"""
|
||||
|
||||
|
||||
class ModelMetadataStoreBase(ABC):
|
||||
|
@ -179,44 +179,45 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
||||
def _update_tags(self, model_key: str, tags: Optional[Set[str]]) -> None:
|
||||
"""Update tags for the model referenced by model_key."""
|
||||
# remove previous tags from this model
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
WHERE model_id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
if tags:
|
||||
# remove previous tags from this model
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
WHERE model_id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tags (
|
||||
model_id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(model_key, tag_id),
|
||||
)
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tags (
|
||||
model_id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(model_key, tag_id),
|
||||
)
|
||||
|
@ -200,6 +200,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._invoker.services.logger.error(
|
||||
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
|
||||
)
|
||||
self._invoker.services.logger.error(error)
|
||||
|
||||
# Send error event
|
||||
self._invoker.services.events.emit_invocation_error(
|
||||
|
@ -3,7 +3,6 @@
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from hashlib import sha1
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
@ -22,7 +21,7 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelConfigFactory,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.hash import FastModelHash
|
||||
from invokeai.backend.model_manager.hash import ModelHash
|
||||
|
||||
ModelsValidator = TypeAdapter(AnyModelConfig)
|
||||
|
||||
@ -73,19 +72,27 @@ class MigrateModelYamlToDb1:
|
||||
|
||||
base_type, model_type, model_name = str(model_key).split("/")
|
||||
try:
|
||||
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
||||
hash = ModelHash().hash(self.config.models_path / stanza.path)
|
||||
except OSError:
|
||||
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
|
||||
continue
|
||||
|
||||
assert isinstance(model_key, str)
|
||||
new_key = sha1(model_key.encode("utf-8")).hexdigest()
|
||||
|
||||
stanza["base"] = BaseModelType(base_type)
|
||||
stanza["type"] = ModelType(model_type)
|
||||
stanza["name"] = model_name
|
||||
stanza["original_hash"] = hash
|
||||
stanza["current_hash"] = hash
|
||||
new_key = hash # deterministic key assignment
|
||||
|
||||
# special case for ip adapters, which need the new `image_encoder_model_id` field
|
||||
if stanza["type"] == ModelType.IPAdapter:
|
||||
try:
|
||||
stanza["image_encoder_model_id"] = self._get_image_encoder_model_id(
|
||||
self.config.models_path / stanza.path
|
||||
)
|
||||
except OSError:
|
||||
self.logger.warning(f"Could not determine image encoder for {stanza.path}. Skipping.")
|
||||
continue
|
||||
|
||||
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
|
||||
|
||||
@ -95,7 +102,7 @@ class MigrateModelYamlToDb1:
|
||||
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
|
||||
self._update_model(key, new_config)
|
||||
else:
|
||||
self.logger.info(f"Adding model {model_name} with key {model_key}")
|
||||
self.logger.info(f"Adding model {model_name} with key {new_key}")
|
||||
self._add_model(new_key, new_config)
|
||||
except DuplicateModelException:
|
||||
self.logger.warning(f"Model {model_name} is already in the database")
|
||||
@ -149,3 +156,8 @@ class MigrateModelYamlToDb1:
|
||||
)
|
||||
except sqlite3.IntegrityError as exc:
|
||||
raise DuplicateModelException(f"{record.name}: model is already in database") from exc
|
||||
|
||||
def _get_image_encoder_model_id(self, model_path: Path) -> str:
|
||||
with open(model_path / "image_encoder.txt") as f:
|
||||
encoder = f.read()
|
||||
return encoder.strip()
|
||||
|
@ -11,56 +11,175 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
|
||||
import hashlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
from typing import Callable, Literal, Optional, Union
|
||||
|
||||
from imohash import hashfile
|
||||
from blake3 import blake3
|
||||
|
||||
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
|
||||
|
||||
ALGORITHM = Literal[
|
||||
"md5",
|
||||
"sha1",
|
||||
"sha224",
|
||||
"sha256",
|
||||
"sha384",
|
||||
"sha512",
|
||||
"blake2b",
|
||||
"blake2s",
|
||||
"sha3_224",
|
||||
"sha3_256",
|
||||
"sha3_384",
|
||||
"sha3_512",
|
||||
"shake_128",
|
||||
"shake_256",
|
||||
"blake3",
|
||||
]
|
||||
|
||||
|
||||
class FastModelHash(object):
|
||||
"""FastModelHash obect provides one public class method, hash()."""
|
||||
class ModelHash:
|
||||
"""
|
||||
Creates a hash of a model using a specified algorithm.
|
||||
|
||||
@classmethod
|
||||
def hash(cls, model_location: Union[str, Path]) -> str:
|
||||
"""
|
||||
Return hexdigest string for model located at model_location.
|
||||
Args:
|
||||
algorithm: Hashing algorithm to use. Defaults to BLAKE3.
|
||||
file_filter: A function that takes a file name and returns True if the file should be included in the hash.
|
||||
|
||||
:param model_location: Path to the model
|
||||
"""
|
||||
model_location = Path(model_location)
|
||||
if model_location.is_file():
|
||||
return cls._hash_file(model_location)
|
||||
elif model_location.is_dir():
|
||||
return cls._hash_dir(model_location)
|
||||
If the model is a single file, it is hashed directly using the provided algorithm.
|
||||
|
||||
If the model is a directory, each model weights file in the directory is hashed using the provided algorithm.
|
||||
|
||||
Only files with the following extensions are hashed: .ckpt, .safetensors, .bin, .pt, .pth
|
||||
|
||||
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
|
||||
that directory hashes are never weaker than the file hashes.
|
||||
|
||||
Usage:
|
||||
```py
|
||||
# BLAKE3 hash
|
||||
ModelHash().hash("path/to/some/model.safetensors")
|
||||
# MD5
|
||||
ModelHash("md5").hash("path/to/model/dir/")
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None:
|
||||
if algorithm == "blake3":
|
||||
self._hash_file = self._blake3
|
||||
elif algorithm in hashlib.algorithms_available:
|
||||
self._hash_file = self._get_hashlib(algorithm)
|
||||
else:
|
||||
raise OSError(f"Not a valid file or directory: {model_location}")
|
||||
raise ValueError(f"Algorithm {algorithm} not available")
|
||||
|
||||
@classmethod
|
||||
def _hash_file(cls, model_location: Union[str, Path]) -> str:
|
||||
self._file_filter = file_filter or self._default_file_filter
|
||||
|
||||
def hash(self, model_path: Union[str, Path]) -> str:
|
||||
"""
|
||||
Fasthash a single file and return its hexdigest.
|
||||
Return hexdigest of hash of model located at model_path using the algorithm provided at class instantiation.
|
||||
|
||||
:param model_location: Path to the model file
|
||||
If model_path is a directory, the hash is computed by hashing the hashes of all model files in the
|
||||
directory. The final composite hash is always computed using BLAKE3.
|
||||
|
||||
Args:
|
||||
model_path: Path to the model
|
||||
|
||||
Returns:
|
||||
str: Hexdigest of the hash of the model
|
||||
"""
|
||||
# we return md5 hash of the filehash to make it shorter
|
||||
# cryptographic security not needed here
|
||||
return hashlib.md5(hashfile(model_location)).hexdigest()
|
||||
|
||||
@classmethod
|
||||
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
|
||||
components: Dict[str, str] = {}
|
||||
model_path = Path(model_path)
|
||||
if model_path.is_file():
|
||||
return self._hash_file(model_path)
|
||||
elif model_path.is_dir():
|
||||
return self._hash_dir(model_path)
|
||||
else:
|
||||
raise OSError(f"Not a valid file or directory: {model_path}")
|
||||
|
||||
for root, _dirs, files in os.walk(model_location):
|
||||
for file in files:
|
||||
# only tally tensor files because diffusers config files change slightly
|
||||
# depending on how the model was downloaded/converted.
|
||||
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
|
||||
continue
|
||||
path = (Path(root) / file).as_posix()
|
||||
fast_hash = cls._hash_file(path)
|
||||
components.update({path: fast_hash})
|
||||
def _hash_dir(self, dir: Path) -> str:
|
||||
"""Compute the hash for all files in a directory and return a hexdigest.
|
||||
|
||||
# hash all the model hashes together, using alphabetic file order
|
||||
md5 = hashlib.md5()
|
||||
for _path, fast_hash in sorted(components.items()):
|
||||
md5.update(fast_hash.encode("utf-8"))
|
||||
return md5.hexdigest()
|
||||
Args:
|
||||
dir: Path to the directory
|
||||
|
||||
Returns:
|
||||
str: Hexdigest of the hash of the directory
|
||||
"""
|
||||
model_component_paths = self._get_file_paths(dir, self._file_filter)
|
||||
|
||||
component_hashes: list[str] = []
|
||||
for component in sorted(model_component_paths):
|
||||
component_hashes.append(self._hash_file(component))
|
||||
|
||||
# BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
|
||||
# for the composite hash
|
||||
composite_hasher = blake3()
|
||||
for h in component_hashes:
|
||||
composite_hasher.update(h.encode("utf-8"))
|
||||
return composite_hasher.hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _get_file_paths(model_path: Path, file_filter: Callable[[str], bool]) -> list[Path]:
|
||||
"""Return a list of all model files in the directory.
|
||||
|
||||
Args:
|
||||
model_path: Path to the model
|
||||
file_filter: Function that takes a file name and returns True if the file should be included in the list.
|
||||
|
||||
Returns:
|
||||
List of all model files in the directory
|
||||
"""
|
||||
|
||||
files: list[Path] = []
|
||||
for root, _dirs, _files in os.walk(model_path):
|
||||
for file in _files:
|
||||
if file_filter(file):
|
||||
files.append(Path(root, file))
|
||||
return files
|
||||
|
||||
@staticmethod
|
||||
def _blake3(file_path: Path) -> str:
|
||||
"""Hashes a file using BLAKE3
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to hash
|
||||
|
||||
Returns:
|
||||
Hexdigest of the hash of the file
|
||||
"""
|
||||
file_hasher = blake3(max_threads=blake3.AUTO)
|
||||
file_hasher.update_mmap(file_path)
|
||||
return file_hasher.hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
|
||||
"""Factory function that returns a function to hash a file with the given algorithm.
|
||||
|
||||
Args:
|
||||
algorithm: Hashing algorithm to use
|
||||
|
||||
Returns:
|
||||
A function that hashes a file using the given algorithm
|
||||
"""
|
||||
|
||||
def hashlib_hasher(file_path: Path) -> str:
|
||||
"""Hashes a file using a hashlib algorithm. Uses `memoryview` to avoid reading the entire file into memory."""
|
||||
hasher = hashlib.new(algorithm)
|
||||
buffer = bytearray(128 * 1024)
|
||||
mv = memoryview(buffer)
|
||||
with open(file_path, "rb", buffering=0) as f:
|
||||
while n := f.readinto(mv):
|
||||
hasher.update(mv[:n])
|
||||
return hasher.hexdigest()
|
||||
|
||||
return hashlib_hasher
|
||||
|
||||
@staticmethod
|
||||
def _default_file_filter(file_path: str) -> bool:
|
||||
"""A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
|
||||
Returns:
|
||||
True if the file matches the given extensions, otherwise False
|
||||
"""
|
||||
return file_path.endswith(MODEL_FILE_EXTENSIONS)
|
||||
|
@ -25,6 +25,7 @@ from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
|
||||
from invokeai.backend.model_manager import ModelRepoVariant
|
||||
|
||||
from ..util import select_hf_files
|
||||
@ -68,12 +69,24 @@ class RemoteModelFile(BaseModel):
|
||||
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
|
||||
|
||||
|
||||
class ModelDefaultSettings(BaseModel):
|
||||
vae: str | None
|
||||
vae_precision: str | None
|
||||
scheduler: SCHEDULER_NAME_VALUES | None
|
||||
steps: int | None
|
||||
cfg_scale: float | None
|
||||
cfg_rescale_multiplier: float | None
|
||||
|
||||
|
||||
class ModelMetadataBase(BaseModel):
|
||||
"""Base class for model metadata information."""
|
||||
|
||||
name: str = Field(description="model's name")
|
||||
author: str = Field(description="model's author")
|
||||
tags: Set[str] = Field(description="tags provided by model source")
|
||||
tags: Optional[Set[str]] = Field(description="tags provided by model source", default=None)
|
||||
default_settings: Optional[ModelDefaultSettings] = Field(
|
||||
description="default settings for this model", default=None
|
||||
)
|
||||
|
||||
|
||||
class BaseMetadata(ModelMetadataBase):
|
||||
|
@ -21,7 +21,7 @@ from .config import (
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
from .hash import FastModelHash
|
||||
from .hash import ModelHash
|
||||
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
|
||||
|
||||
CkptType = Dict[str, Any]
|
||||
@ -147,7 +147,7 @@ class ModelProbe(object):
|
||||
if not probe_class:
|
||||
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
|
||||
|
||||
hash = FastModelHash.hash(model_path)
|
||||
hash = ModelHash().hash(model_path)
|
||||
probe = probe_class(model_path)
|
||||
|
||||
fields["path"] = model_path.as_posix()
|
||||
|
@ -134,8 +134,6 @@
|
||||
"loadMore": "Mehr laden",
|
||||
"noImagesInGallery": "Keine Bilder in der Galerie",
|
||||
"loading": "Lade",
|
||||
"preparingDownload": "bereite Download vor",
|
||||
"preparingDownloadFailed": "Problem beim Download vorbereiten",
|
||||
"deleteImage": "Lösche Bild",
|
||||
"copy": "Kopieren",
|
||||
"download": "Runterladen",
|
||||
@ -967,7 +965,7 @@
|
||||
"resumeFailed": "Problem beim Fortsetzen des Prozesses",
|
||||
"pruneFailed": "Problem beim leeren der Warteschlange",
|
||||
"pauseTooltip": "Prozess anhalten",
|
||||
"back": "Hinten",
|
||||
"back": "Ende",
|
||||
"resumeSucceeded": "Prozess wird fortgesetzt",
|
||||
"resumeTooltip": "Prozess wieder aufnehmen",
|
||||
"time": "Zeit",
|
||||
|
@ -78,6 +78,7 @@
|
||||
"aboutDesc": "Using Invoke for work? Check out:",
|
||||
"aboutHeading": "Own Your Creative Power",
|
||||
"accept": "Accept",
|
||||
"add": "Add",
|
||||
"advanced": "Advanced",
|
||||
"advancedOptions": "Advanced Options",
|
||||
"ai": "ai",
|
||||
@ -734,6 +735,8 @@
|
||||
"customConfig": "Custom Config",
|
||||
"customConfigFileLocation": "Custom Config File Location",
|
||||
"customSaveLocation": "Custom Save Location",
|
||||
"defaultSettings": "Default Settings",
|
||||
"defaultSettingsSaved": "Default Settings Saved",
|
||||
"delete": "Delete",
|
||||
"deleteConfig": "Delete Config",
|
||||
"deleteModel": "Delete Model",
|
||||
@ -768,6 +771,7 @@
|
||||
"mergedModelName": "Merged Model Name",
|
||||
"mergedModelSaveLocation": "Save Location",
|
||||
"mergeModels": "Merge Models",
|
||||
"metadata": "Metadata",
|
||||
"model": "Model",
|
||||
"modelAdded": "Model Added",
|
||||
"modelConversionFailed": "Model Conversion Failed",
|
||||
@ -839,9 +843,12 @@
|
||||
"statusConverting": "Converting",
|
||||
"syncModels": "Sync Models",
|
||||
"syncModelsDesc": "If your models are out of sync with the backend, you can refresh them up using this option. This is generally handy in cases where you add models to the InvokeAI root folder or autoimport directory after the application has booted.",
|
||||
"triggerPhrases": "Trigger Phrases",
|
||||
"typePhraseHere": "Type phrase here",
|
||||
"upcastAttention": "Upcast Attention",
|
||||
"updateModel": "Update Model",
|
||||
"useCustomConfig": "Use Custom Config",
|
||||
"useDefaultSettings": "Use Default Settings",
|
||||
"v1": "v1",
|
||||
"v2_768": "v2 (768px)",
|
||||
"v2_base": "v2 (512px)",
|
||||
@ -860,6 +867,7 @@
|
||||
"models": {
|
||||
"addLora": "Add LoRA",
|
||||
"allLoRAsAdded": "All LoRAs added",
|
||||
"concepts": "Concepts",
|
||||
"loraAlreadyAdded": "LoRA already added",
|
||||
"esrganModel": "ESRGAN Model",
|
||||
"loading": "loading",
|
||||
|
@ -505,8 +505,6 @@
|
||||
"seamLowThreshold": "Bajo",
|
||||
"coherencePassHeader": "Parámetros de la coherencia",
|
||||
"compositingSettingsHeader": "Ajustes de la composición",
|
||||
"coherenceSteps": "Pasos",
|
||||
"coherenceStrength": "Fuerza",
|
||||
"patchmatchDownScaleSize": "Reducir a escala",
|
||||
"coherenceMode": "Modo"
|
||||
},
|
||||
|
@ -114,7 +114,8 @@
|
||||
"checkpoint": "Checkpoint",
|
||||
"safetensors": "Safetensors",
|
||||
"ai": "ia",
|
||||
"file": "File"
|
||||
"file": "File",
|
||||
"toResolve": "Da risolvere"
|
||||
},
|
||||
"gallery": {
|
||||
"generations": "Generazioni",
|
||||
@ -142,8 +143,6 @@
|
||||
"copy": "Copia",
|
||||
"download": "Scarica",
|
||||
"setCurrentImage": "Imposta come immagine corrente",
|
||||
"preparingDownload": "Preparazione del download",
|
||||
"preparingDownloadFailed": "Problema durante la preparazione del download",
|
||||
"downloadSelection": "Scarica gli elementi selezionati",
|
||||
"noImageSelected": "Nessuna immagine selezionata",
|
||||
"deleteSelection": "Elimina la selezione",
|
||||
@ -609,8 +608,6 @@
|
||||
"seamLowThreshold": "Basso",
|
||||
"seamHighThreshold": "Alto",
|
||||
"coherencePassHeader": "Passaggio di coerenza",
|
||||
"coherenceSteps": "Passi",
|
||||
"coherenceStrength": "Forza",
|
||||
"compositingSettingsHeader": "Impostazioni di composizione",
|
||||
"patchmatchDownScaleSize": "Ridimensiona",
|
||||
"coherenceMode": "Modalità",
|
||||
@ -1400,19 +1397,6 @@
|
||||
"Regola la maschera."
|
||||
]
|
||||
},
|
||||
"compositingCoherenceSteps": {
|
||||
"heading": "Passi",
|
||||
"paragraphs": [
|
||||
"Numero di passi utilizzati nel Passaggio di Coerenza.",
|
||||
"Simile ai passi di generazione."
|
||||
]
|
||||
},
|
||||
"compositingBlur": {
|
||||
"heading": "Sfocatura",
|
||||
"paragraphs": [
|
||||
"Il raggio di sfocatura della maschera."
|
||||
]
|
||||
},
|
||||
"compositingCoherenceMode": {
|
||||
"heading": "Modalità",
|
||||
"paragraphs": [
|
||||
@ -1431,13 +1415,6 @@
|
||||
"Un secondo ciclo di riduzione del rumore aiuta a comporre l'immagine Inpaint/Outpaint."
|
||||
]
|
||||
},
|
||||
"compositingStrength": {
|
||||
"heading": "Forza",
|
||||
"paragraphs": [
|
||||
"Quantità di rumore aggiunta per il Passaggio di Coerenza.",
|
||||
"Simile alla forza di riduzione del rumore."
|
||||
]
|
||||
},
|
||||
"paramNegativeConditioning": {
|
||||
"paragraphs": [
|
||||
"Il processo di generazione evita i concetti nel prompt negativo. Utilizzatelo per escludere qualità o oggetti dall'output.",
|
||||
|
@ -123,8 +123,6 @@
|
||||
"autoSwitchNewImages": "새로운 이미지로 자동 전환",
|
||||
"loading": "불러오는 중",
|
||||
"unableToLoad": "갤러리를 로드할 수 없음",
|
||||
"preparingDownload": "다운로드 준비",
|
||||
"preparingDownloadFailed": "다운로드 준비 중 발생한 문제",
|
||||
"singleColumnLayout": "단일 열 레이아웃",
|
||||
"image": "이미지",
|
||||
"loadMore": "더 불러오기",
|
||||
|
@ -97,8 +97,6 @@
|
||||
"featuresWillReset": "Als je deze afbeelding verwijdert, dan worden deze functies onmiddellijk teruggezet.",
|
||||
"loading": "Bezig met laden",
|
||||
"unableToLoad": "Kan galerij niet laden",
|
||||
"preparingDownload": "Bezig met voorbereiden van download",
|
||||
"preparingDownloadFailed": "Fout bij voorbereiden van download",
|
||||
"downloadSelection": "Download selectie",
|
||||
"currentlyInUse": "Deze afbeelding is momenteel in gebruik door de volgende functies:",
|
||||
"copy": "Kopieer",
|
||||
@ -535,8 +533,6 @@
|
||||
"coherencePassHeader": "Coherentiestap",
|
||||
"maskBlur": "Vervaag",
|
||||
"maskBlurMethod": "Vervagingsmethode",
|
||||
"coherenceSteps": "Stappen",
|
||||
"coherenceStrength": "Sterkte",
|
||||
"seamHighThreshold": "Hoog",
|
||||
"seamLowThreshold": "Laag",
|
||||
"invoke": {
|
||||
@ -1139,13 +1135,6 @@
|
||||
"Een afbeeldingsgrootte (in aantal pixels) equivalent aan 512x512 wordt aanbevolen voor SD1.5-modellen. Een grootte-equivalent van 1024x1024 wordt aanbevolen voor SDXL-modellen."
|
||||
]
|
||||
},
|
||||
"compositingCoherenceSteps": {
|
||||
"heading": "Stappen",
|
||||
"paragraphs": [
|
||||
"Het aantal te gebruiken ontruisingsstappen in de coherentiefase.",
|
||||
"Gelijk aan de hoofdparameter Stappen."
|
||||
]
|
||||
},
|
||||
"dynamicPrompts": {
|
||||
"paragraphs": [
|
||||
"Dynamische prompts vormt een enkele prompt om in vele.",
|
||||
@ -1160,12 +1149,6 @@
|
||||
],
|
||||
"heading": "VAE"
|
||||
},
|
||||
"compositingBlur": {
|
||||
"heading": "Vervaging",
|
||||
"paragraphs": [
|
||||
"De vervagingsstraal van het masker."
|
||||
]
|
||||
},
|
||||
"paramIterations": {
|
||||
"paragraphs": [
|
||||
"Het aantal te genereren afbeeldingen.",
|
||||
@ -1240,13 +1223,6 @@
|
||||
],
|
||||
"heading": "Ontruisingssterkte"
|
||||
},
|
||||
"compositingStrength": {
|
||||
"heading": "Sterkte",
|
||||
"paragraphs": [
|
||||
"Ontruisingssterkte voor de coherentiefase.",
|
||||
"Gelijk aan de parameter Ontruisingssterkte Afbeelding naar afbeelding."
|
||||
]
|
||||
},
|
||||
"paramNegativeConditioning": {
|
||||
"paragraphs": [
|
||||
"Het genereerproces voorkomt de gegeven begrippen in de negatieve prompt. Gebruik dit om bepaalde zaken of voorwerpen uit te sluiten van de uitvoerafbeelding.",
|
||||
|
@ -143,8 +143,6 @@
|
||||
"problemDeletingImagesDesc": "Не удалось удалить одно или несколько изображений",
|
||||
"loading": "Загрузка",
|
||||
"unableToLoad": "Невозможно загрузить галерею",
|
||||
"preparingDownload": "Подготовка к скачиванию",
|
||||
"preparingDownloadFailed": "Проблема с подготовкой к скачиванию",
|
||||
"image": "изображение",
|
||||
"drop": "перебросить",
|
||||
"problemDeletingImages": "Проблема с удалением изображений",
|
||||
@ -612,9 +610,7 @@
|
||||
"maskBlurMethod": "Метод размытия",
|
||||
"seamLowThreshold": "Низкий",
|
||||
"seamHighThreshold": "Высокий",
|
||||
"coherenceSteps": "Шагов",
|
||||
"coherencePassHeader": "Порог Coherence",
|
||||
"coherenceStrength": "Сила",
|
||||
"compositingSettingsHeader": "Настройки компоновки",
|
||||
"invoke": {
|
||||
"noNodesInGraph": "Нет узлов в графе",
|
||||
@ -1321,13 +1317,6 @@
|
||||
"Размер изображения (в пикселях), эквивалентный 512x512, рекомендуется для моделей SD1.5, а размер, эквивалентный 1024x1024, рекомендуется для моделей SDXL."
|
||||
]
|
||||
},
|
||||
"compositingCoherenceSteps": {
|
||||
"heading": "Шаги",
|
||||
"paragraphs": [
|
||||
"Количество шагов снижения шума, используемых при прохождении когерентности.",
|
||||
"То же, что и основной параметр «Шаги»."
|
||||
]
|
||||
},
|
||||
"dynamicPrompts": {
|
||||
"paragraphs": [
|
||||
"Динамические запросы превращают одно приглашение на множество.",
|
||||
@ -1342,12 +1331,6 @@
|
||||
],
|
||||
"heading": "VAE"
|
||||
},
|
||||
"compositingBlur": {
|
||||
"heading": "Размытие",
|
||||
"paragraphs": [
|
||||
"Радиус размытия маски."
|
||||
]
|
||||
},
|
||||
"paramIterations": {
|
||||
"paragraphs": [
|
||||
"Количество изображений, которые нужно сгенерировать.",
|
||||
@ -1422,13 +1405,6 @@
|
||||
],
|
||||
"heading": "Шумоподавление"
|
||||
},
|
||||
"compositingStrength": {
|
||||
"heading": "Сила",
|
||||
"paragraphs": [
|
||||
null,
|
||||
"То же, что параметр «Сила шумоподавления img2img»."
|
||||
]
|
||||
},
|
||||
"paramNegativeConditioning": {
|
||||
"paragraphs": [
|
||||
"Stable Diffusion пытается избежать указанных в отрицательном запросе концепций. Используйте это, чтобы исключить качества или объекты из вывода.",
|
||||
|
@ -355,7 +355,6 @@
|
||||
"starImage": "Yıldız Koy",
|
||||
"download": "İndir",
|
||||
"deleteSelection": "Seçileni Sil",
|
||||
"preparingDownloadFailed": "İndirme Hazırlanırken Sorun",
|
||||
"problemDeletingImages": "Görsel Silmede Sorun",
|
||||
"featuresWillReset": "Bu görseli silerseniz, o özellikler resetlenecektir.",
|
||||
"galleryImageResetSize": "Boyutu Resetle",
|
||||
@ -377,7 +376,6 @@
|
||||
"setCurrentImage": "Çalışma Görseli Yap",
|
||||
"unableToLoad": "Galeri Yüklenemedi",
|
||||
"downloadSelection": "Seçileni İndir",
|
||||
"preparingDownload": "İndirmeye Hazırlanıyor",
|
||||
"singleColumnLayout": "Tek Sütun Düzen",
|
||||
"generations": "Çıktılar",
|
||||
"showUploads": "Yüklenenleri Göster",
|
||||
@ -723,7 +721,6 @@
|
||||
"clipSkip": "CLIP Atlama",
|
||||
"randomizeSeed": "Rastgele Tohum",
|
||||
"cfgScale": "CFG Ölçeği",
|
||||
"coherenceStrength": "Etki",
|
||||
"controlNetControlMode": "Yönetim Kipi",
|
||||
"general": "Genel",
|
||||
"img2imgStrength": "Görselden Görsel Ölçüsü",
|
||||
@ -793,7 +790,6 @@
|
||||
"cfgRescaleMultiplier": "CFG Rescale Çarpanı",
|
||||
"cfgRescale": "CFG Rescale",
|
||||
"coherencePassHeader": "Uyum Geçişi",
|
||||
"coherenceSteps": "Adım",
|
||||
"infillMethod": "Doldurma Yöntemi",
|
||||
"maskBlurMethod": "Bulandırma Yöntemi",
|
||||
"steps": "Adım",
|
||||
|
@ -136,8 +136,6 @@
|
||||
"copy": "复制",
|
||||
"download": "下载",
|
||||
"setCurrentImage": "设为当前图像",
|
||||
"preparingDownload": "准备下载",
|
||||
"preparingDownloadFailed": "准备下载时出现问题",
|
||||
"downloadSelection": "下载所选内容",
|
||||
"noImageSelected": "无选中的图像",
|
||||
"deleteSelection": "删除所选内容",
|
||||
@ -616,11 +614,9 @@
|
||||
"incompatibleBaseModelForControlAdapter": "有 #{{number}} 个 Control Adapter 模型与主模型不兼容。"
|
||||
},
|
||||
"patchmatchDownScaleSize": "缩小",
|
||||
"coherenceSteps": "步数",
|
||||
"clipSkip": "CLIP 跳过层",
|
||||
"compositingSettingsHeader": "合成设置",
|
||||
"useCpuNoise": "使用 CPU 噪声",
|
||||
"coherenceStrength": "强度",
|
||||
"enableNoiseSettings": "启用噪声设置",
|
||||
"coherenceMode": "模式",
|
||||
"cpuNoise": "CPU 噪声",
|
||||
@ -1402,19 +1398,6 @@
|
||||
"图像尺寸(单位:像素)建议 SD 1.5 模型使用等效 512x512 的尺寸,SDXL 模型使用等效 1024x1024 的尺寸。"
|
||||
]
|
||||
},
|
||||
"compositingCoherenceSteps": {
|
||||
"heading": "步数",
|
||||
"paragraphs": [
|
||||
"一致性层中使用的去噪步数。",
|
||||
"与主参数中的步数相同。"
|
||||
]
|
||||
},
|
||||
"compositingBlur": {
|
||||
"heading": "模糊",
|
||||
"paragraphs": [
|
||||
"遮罩模糊半径。"
|
||||
]
|
||||
},
|
||||
"noiseUseCPU": {
|
||||
"heading": "使用 CPU 噪声",
|
||||
"paragraphs": [
|
||||
@ -1467,13 +1450,6 @@
|
||||
"第二轮去噪有助于合成内补/外扩图像。"
|
||||
]
|
||||
},
|
||||
"compositingStrength": {
|
||||
"heading": "强度",
|
||||
"paragraphs": [
|
||||
"一致性层使用的去噪强度。",
|
||||
"去噪强度与图生图的参数相同。"
|
||||
]
|
||||
},
|
||||
"paramNegativeConditioning": {
|
||||
"paragraphs": [
|
||||
"生成过程会避免生成负向提示词中的概念。使用此选项来使输出排除部分质量或对象。",
|
||||
|
@ -55,6 +55,8 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
|
||||
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
|
||||
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
|
||||
@ -153,3 +155,5 @@ addUpscaleRequestedListener(startAppListening);
|
||||
|
||||
// Dynamic prompts
|
||||
addDynamicPromptsListener(startAppListening);
|
||||
|
||||
addSetDefaultSettingsListener(startAppListening);
|
||||
|
@ -0,0 +1,96 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { setDefaultSettings } from 'features/parameters/store/actions';
|
||||
import {
|
||||
setCfgRescaleMultiplier,
|
||||
setCfgScale,
|
||||
setScheduler,
|
||||
setSteps,
|
||||
vaePrecisionChanged,
|
||||
vaeSelected,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import {
|
||||
isParameterCFGRescaleMultiplier,
|
||||
isParameterCFGScale,
|
||||
isParameterPrecision,
|
||||
isParameterScheduler,
|
||||
isParameterSteps,
|
||||
zParameterVAEModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { t } from 'i18next';
|
||||
import { map } from 'lodash-es';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
|
||||
export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: setDefaultSettings,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const state = getState();
|
||||
|
||||
const currentModel = state.generation.model;
|
||||
|
||||
if (!currentModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
const metadata = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(currentModel.key)).unwrap();
|
||||
|
||||
if (!metadata || !metadata.default_settings) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = metadata.default_settings;
|
||||
|
||||
if (vae) {
|
||||
// we store this as "default" within default settings
|
||||
// to distinguish it from no default set
|
||||
if (vae === 'default') {
|
||||
dispatch(vaeSelected(null));
|
||||
} else {
|
||||
const { data } = modelsApi.endpoints.getVaeModels.select()(state);
|
||||
const vaeArray = map(data?.entities);
|
||||
const validVae = vaeArray.find((model) => model.key === vae);
|
||||
|
||||
const result = zParameterVAEModel.safeParse(validVae);
|
||||
if (!result.success) {
|
||||
return;
|
||||
}
|
||||
dispatch(vaeSelected(result.data));
|
||||
}
|
||||
}
|
||||
|
||||
if (vae_precision) {
|
||||
if (isParameterPrecision(vae_precision)) {
|
||||
dispatch(vaePrecisionChanged(vae_precision));
|
||||
}
|
||||
}
|
||||
|
||||
if (cfg_scale) {
|
||||
if (isParameterCFGScale(cfg_scale)) {
|
||||
dispatch(setCfgScale(cfg_scale));
|
||||
}
|
||||
}
|
||||
|
||||
if (cfg_rescale_multiplier) {
|
||||
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
|
||||
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
|
||||
}
|
||||
}
|
||||
|
||||
if (steps) {
|
||||
if (isParameterSteps(steps)) {
|
||||
dispatch(setSteps(steps));
|
||||
}
|
||||
}
|
||||
|
||||
if (scheduler) {
|
||||
if (isParameterScheduler(scheduler)) {
|
||||
dispatch(setScheduler(scheduler));
|
||||
}
|
||||
}
|
||||
|
||||
dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) })));
|
||||
},
|
||||
});
|
||||
};
|
@ -1,4 +1,5 @@
|
||||
import type { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||
import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||
import type { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import type { O } from 'ts-toolbelt';
|
||||
|
||||
@ -82,6 +83,8 @@ export type AppConfig = {
|
||||
guidance: NumericalParameterConfig;
|
||||
cfgRescaleMultiplier: NumericalParameterConfig;
|
||||
img2imgStrength: NumericalParameterConfig;
|
||||
scheduler?: ParameterScheduler;
|
||||
vaePrecision?: ParameterPrecision;
|
||||
// Canvas
|
||||
boundingBoxHeight: NumericalParameterConfig; // initial value comes from model
|
||||
boundingBoxWidth: NumericalParameterConfig; // initial value comes from model
|
||||
|
@ -59,7 +59,7 @@ const LoRASelect = () => {
|
||||
return (
|
||||
<FormControl isDisabled={!options.length}>
|
||||
<InformationalPopover feature="lora">
|
||||
<FormLabel>{t('models.lora')} </FormLabel>
|
||||
<FormLabel>{t('models.concepts')} </FormLabel>
|
||||
</InformationalPopover>
|
||||
<Combobox
|
||||
placeholder={placeholder}
|
||||
|
@ -15,7 +15,7 @@ const STATUSES = {
|
||||
const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus; errorReason?: string | null }) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (!status) {
|
||||
if (!status || !Object.keys(STATUSES).includes(status)) {
|
||||
return <></>;
|
||||
}
|
||||
|
||||
|
@ -8,7 +8,7 @@ export const ModelPane = () => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
return (
|
||||
<Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full">
|
||||
{selectedModelKey ? <Model /> : <ImportModels />}
|
||||
{selectedModelKey ? <Model key={selectedModelKey} /> : <ImportModels />}
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
@ -0,0 +1,66 @@
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import Loading from 'common/components/Loading/Loading';
|
||||
import { selectConfigSlice } from 'features/system/store/configSlice';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { DefaultSettingsForm } from './DefaultSettings/DefaultSettingsForm';
|
||||
|
||||
const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => {
|
||||
const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision } = config.sd;
|
||||
|
||||
return {
|
||||
initialSteps: steps.initial,
|
||||
initialCfg: guidance.initial,
|
||||
initialScheduler: scheduler,
|
||||
initialCfgRescaleMultiplier: cfgRescaleMultiplier.initial,
|
||||
initialVaePrecision: vaePrecision,
|
||||
};
|
||||
});
|
||||
|
||||
export const DefaultSettings = () => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
|
||||
const { data, isLoading } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
|
||||
const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } =
|
||||
useAppSelector(initialStatesSelector);
|
||||
|
||||
const defaultSettingsDefaults = useMemo(() => {
|
||||
return {
|
||||
vae: { isEnabled: !isNil(data?.default_settings?.vae), value: data?.default_settings?.vae || 'default' },
|
||||
vaePrecision: {
|
||||
isEnabled: !isNil(data?.default_settings?.vae_precision),
|
||||
value: data?.default_settings?.vae_precision || initialVaePrecision || 'fp32',
|
||||
},
|
||||
scheduler: {
|
||||
isEnabled: !isNil(data?.default_settings?.scheduler),
|
||||
value: data?.default_settings?.scheduler || initialScheduler || 'euler',
|
||||
},
|
||||
steps: { isEnabled: !isNil(data?.default_settings?.steps), value: data?.default_settings?.steps || initialSteps },
|
||||
cfgScale: {
|
||||
isEnabled: !isNil(data?.default_settings?.cfg_scale),
|
||||
value: data?.default_settings?.cfg_scale || initialCfg,
|
||||
},
|
||||
cfgRescaleMultiplier: {
|
||||
isEnabled: !isNil(data?.default_settings?.cfg_rescale_multiplier),
|
||||
value: data?.default_settings?.cfg_rescale_multiplier || initialCfgRescaleMultiplier,
|
||||
},
|
||||
};
|
||||
}, [
|
||||
data?.default_settings,
|
||||
initialSteps,
|
||||
initialCfg,
|
||||
initialScheduler,
|
||||
initialCfgRescaleMultiplier,
|
||||
initialVaePrecision,
|
||||
]);
|
||||
|
||||
if (isLoading) {
|
||||
return <Loading />;
|
||||
}
|
||||
|
||||
return <DefaultSettingsForm defaultSettingsDefaults={defaultSettingsDefaults} />;
|
||||
};
|
@ -0,0 +1,72 @@
|
||||
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultCfgRescaleMultiplierType = DefaultSettingsFormData['cfgRescaleMultiplier'];
|
||||
|
||||
export function DefaultCfgRescaleMultiplier(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMin);
|
||||
const sliderMax = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMax);
|
||||
const numberInputMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.numberInputMin);
|
||||
const numberInputMax = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.numberInputMax);
|
||||
const coarseStep = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.coarseStep);
|
||||
const fineStep = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.fineStep);
|
||||
const { t } = useTranslation();
|
||||
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultCfgRescaleMultiplierType),
|
||||
value: v,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return (field.value as DefaultCfgRescaleMultiplierType).value;
|
||||
}, [field.value]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultCfgRescaleMultiplierType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramCFGRescaleMultiplier">
|
||||
<FormLabel>{t('parameters.cfgRescaleMultiplier')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Flex w="full" gap={1}>
|
||||
<CompositeSlider
|
||||
value={value}
|
||||
min={sliderMin}
|
||||
max={sliderMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
marks={marks}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={value}
|
||||
min={numberInputMin}
|
||||
max={numberInputMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
}
|
@ -0,0 +1,72 @@
|
||||
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultCfgType = DefaultSettingsFormData['cfgScale'];
|
||||
|
||||
export function DefaultCfgScale(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.guidance.sliderMin);
|
||||
const sliderMax = useAppSelector((s) => s.config.sd.guidance.sliderMax);
|
||||
const numberInputMin = useAppSelector((s) => s.config.sd.guidance.numberInputMin);
|
||||
const numberInputMax = useAppSelector((s) => s.config.sd.guidance.numberInputMax);
|
||||
const coarseStep = useAppSelector((s) => s.config.sd.guidance.coarseStep);
|
||||
const fineStep = useAppSelector((s) => s.config.sd.guidance.fineStep);
|
||||
const { t } = useTranslation();
|
||||
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultCfgType),
|
||||
value: v,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return (field.value as DefaultCfgType).value;
|
||||
}, [field.value]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultCfgType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramCFGScale">
|
||||
<FormLabel>{t('parameters.cfgScale')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Flex w="full" gap={1}>
|
||||
<CompositeSlider
|
||||
value={value}
|
||||
min={sliderMin}
|
||||
max={sliderMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
marks={marks}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={value}
|
||||
min={numberInputMin}
|
||||
max={numberInputMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
}
|
@ -0,0 +1,50 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { SCHEDULER_OPTIONS } from 'features/parameters/types/constants';
|
||||
import { isParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultSchedulerType = DefaultSettingsFormData['scheduler'];
|
||||
|
||||
export function DefaultScheduler(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!isParameterScheduler(v?.value)) {
|
||||
return;
|
||||
}
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultSchedulerType),
|
||||
value: v.value,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(
|
||||
() => SCHEDULER_OPTIONS.find((o) => o.value === (field.value as DefaultSchedulerType).value),
|
||||
[field]
|
||||
);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultSchedulerType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramScheduler">
|
||||
<FormLabel>{t('parameters.scheduler')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Combobox isDisabled={isDisabled} value={value} options={SCHEDULER_OPTIONS} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
}
|
@ -0,0 +1,147 @@
|
||||
import { Button, Flex, Heading } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { IoPencil } from 'react-icons/io5';
|
||||
import { useUpdateModelMetadataMutation } from 'services/api/endpoints/models';
|
||||
|
||||
import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
|
||||
import { DefaultCfgScale } from './DefaultCfgScale';
|
||||
import { DefaultScheduler } from './DefaultScheduler';
|
||||
import { DefaultSteps } from './DefaultSteps';
|
||||
import { DefaultVae } from './DefaultVae';
|
||||
import { DefaultVaePrecision } from './DefaultVaePrecision';
|
||||
import { SettingToggle } from './SettingToggle';
|
||||
|
||||
export interface FormField<T> {
|
||||
value: T;
|
||||
isEnabled: boolean;
|
||||
}
|
||||
|
||||
export type DefaultSettingsFormData = {
|
||||
vae: FormField<string>;
|
||||
vaePrecision: FormField<string>;
|
||||
scheduler: FormField<ParameterScheduler>;
|
||||
steps: FormField<number>;
|
||||
cfgScale: FormField<number>;
|
||||
cfgRescaleMultiplier: FormField<number>;
|
||||
};
|
||||
|
||||
export const DefaultSettingsForm = ({
|
||||
defaultSettingsDefaults,
|
||||
}: {
|
||||
defaultSettingsDefaults: DefaultSettingsFormData;
|
||||
}) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
|
||||
const [editModelMetadata, { isLoading }] = useUpdateModelMetadataMutation();
|
||||
|
||||
const { handleSubmit, control, formState } = useForm<DefaultSettingsFormData>({
|
||||
defaultValues: defaultSettingsDefaults,
|
||||
});
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<DefaultSettingsFormData>>(
|
||||
(data) => {
|
||||
if (!selectedModelKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
const body = {
|
||||
vae: data.vae.isEnabled ? data.vae.value : null,
|
||||
vae_precision: data.vaePrecision.isEnabled ? data.vaePrecision.value : null,
|
||||
cfg_scale: data.cfgScale.isEnabled ? data.cfgScale.value : null,
|
||||
cfg_rescale_multiplier: data.cfgRescaleMultiplier.isEnabled ? data.cfgRescaleMultiplier.value : null,
|
||||
steps: data.steps.isEnabled ? data.steps.value : null,
|
||||
scheduler: data.scheduler.isEnabled ? data.scheduler.value : null,
|
||||
};
|
||||
|
||||
editModelMetadata({
|
||||
key: selectedModelKey,
|
||||
body: { default_settings: body },
|
||||
})
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.defaultSettingsSaved'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${error.data.detail} `,
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
},
|
||||
[selectedModelKey, dispatch, editModelMetadata, t]
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex gap="2" justifyContent="space-between" w="full" mb={5}>
|
||||
<Heading fontSize="md">{t('modelManager.defaultSettings')}</Heading>
|
||||
<Button
|
||||
size="sm"
|
||||
leftIcon={<IoPencil />}
|
||||
colorScheme="invokeYellow"
|
||||
isDisabled={!formState.isDirty}
|
||||
onClick={handleSubmit(onSubmit)}
|
||||
type="submit"
|
||||
isLoading={isLoading}
|
||||
>
|
||||
{t('common.save')}
|
||||
</Button>
|
||||
</Flex>
|
||||
|
||||
<Flex flexDir="column" gap={8}>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="vae" />
|
||||
<DefaultVae control={control} name="vae" />
|
||||
</Flex>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="vaePrecision" />
|
||||
<DefaultVaePrecision control={control} name="vaePrecision" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="scheduler" />
|
||||
<DefaultScheduler control={control} name="scheduler" />
|
||||
</Flex>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="steps" />
|
||||
<DefaultSteps control={control} name="steps" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="cfgScale" />
|
||||
<DefaultCfgScale control={control} name="cfgScale" />
|
||||
</Flex>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="cfgRescaleMultiplier" />
|
||||
<DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</>
|
||||
);
|
||||
};
|
@ -0,0 +1,72 @@
|
||||
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultSteps = DefaultSettingsFormData['steps'];
|
||||
|
||||
export function DefaultSteps(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.steps.sliderMin);
|
||||
const sliderMax = useAppSelector((s) => s.config.sd.steps.sliderMax);
|
||||
const numberInputMin = useAppSelector((s) => s.config.sd.steps.numberInputMin);
|
||||
const numberInputMax = useAppSelector((s) => s.config.sd.steps.numberInputMax);
|
||||
const coarseStep = useAppSelector((s) => s.config.sd.steps.coarseStep);
|
||||
const fineStep = useAppSelector((s) => s.config.sd.steps.fineStep);
|
||||
const { t } = useTranslation();
|
||||
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultSteps),
|
||||
value: v,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return (field.value as DefaultSteps).value;
|
||||
}, [field.value]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultSteps).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramSteps">
|
||||
<FormLabel>{t('parameters.steps')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Flex w="full" gap={1}>
|
||||
<CompositeSlider
|
||||
value={value}
|
||||
min={sliderMin}
|
||||
max={sliderMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
marks={marks}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={value}
|
||||
min={numberInputMin}
|
||||
max={numberInputMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
}
|
@ -0,0 +1,65 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { map } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelConfigQuery, useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultVaeType = DefaultSettingsFormData['vae'];
|
||||
|
||||
export function DefaultVae(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data: modelData } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
|
||||
const { compatibleOptions } = useGetVaeModelsQuery(undefined, {
|
||||
selectFromResult: ({ data }) => {
|
||||
const modelArray = map(data?.entities);
|
||||
const compatibleOptions = modelArray
|
||||
.filter((vae) => vae.base === modelData?.base)
|
||||
.map((vae) => ({ label: vae.name, value: vae.key }));
|
||||
|
||||
const defaultOption = { label: 'Default VAE', value: 'default' };
|
||||
|
||||
return { compatibleOptions: [defaultOption, ...compatibleOptions] };
|
||||
},
|
||||
});
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
const newValue = !v?.value ? 'default' : v.value;
|
||||
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultVaeType),
|
||||
value: newValue,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return compatibleOptions.find((vae) => vae.value === (field.value as DefaultVaeType).value);
|
||||
}, [compatibleOptions, field.value]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultVaeType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramVAE">
|
||||
<FormLabel>{t('modelManager.vae')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Combobox isDisabled={isDisabled} value={value} options={compatibleOptions} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
}
|
@ -0,0 +1,51 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { isParameterPrecision } from 'features/parameters/types/parameterSchemas';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
const options = [
|
||||
{ label: 'FP16', value: 'fp16' },
|
||||
{ label: 'FP32', value: 'fp32' },
|
||||
];
|
||||
|
||||
type DefaultVaePrecisionType = DefaultSettingsFormData['vaePrecision'];
|
||||
|
||||
export function DefaultVaePrecision(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!isParameterPrecision(v?.value)) {
|
||||
return;
|
||||
}
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultVaePrecisionType),
|
||||
value: v.value,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => options.find((o) => o.value === (field.value as DefaultVaePrecisionType).value), [field]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultVaePrecisionType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramVAEPrecision">
|
||||
<FormLabel>{t('modelManager.vaePrecision')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Combobox isDisabled={isDisabled} value={value} options={options} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
}
|
@ -0,0 +1,28 @@
|
||||
import { Switch } from '@invoke-ai/ui-library';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
|
||||
import type { DefaultSettingsFormData, FormField } from './DefaultSettingsForm';
|
||||
|
||||
export function SettingToggle<T>(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return !!(field.value as FormField<T>).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
const updatedValue: FormField<T> = {
|
||||
...(field.value as FormField<T>),
|
||||
isEnabled: e.target.checked,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
return <Switch isChecked={value} onChange={onChange} />;
|
||||
}
|
@ -0,0 +1,18 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
|
||||
|
||||
export const ModelMetadata = () => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex flexDir="column" height="full" gap="3">
|
||||
<DataViewer label="metadata" data={metadata || {}} />
|
||||
</Flex>
|
||||
</>
|
||||
);
|
||||
};
|
@ -1,9 +1,58 @@
|
||||
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs, Text } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { ModelMetadata } from './Metadata/ModelMetadata';
|
||||
import { ModelAttrView } from './ModelAttrView';
|
||||
import { ModelEdit } from './ModelEdit';
|
||||
import { ModelView } from './ModelView';
|
||||
|
||||
export const Model = () => {
|
||||
const { t } = useTranslation();
|
||||
const selectedModelMode = useAppSelector((s) => s.modelmanagerV2.selectedModelMode);
|
||||
return selectedModelMode === 'view' ? <ModelView /> : <ModelEdit />;
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
|
||||
if (isLoading) {
|
||||
return <Text>{t('common.loading')}</Text>;
|
||||
}
|
||||
|
||||
if (!data) {
|
||||
return <Text>{t('common.somethingWentWrong')}</Text>;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex flexDir="column" gap={1} p={2}>
|
||||
<Heading as="h2" fontSize="lg">
|
||||
{data.name}
|
||||
</Heading>
|
||||
|
||||
{data.source && (
|
||||
<Text variant="subtext">
|
||||
{t('modelManager.source')}: {data?.source}
|
||||
</Text>
|
||||
)}
|
||||
<Box mt="4">
|
||||
<ModelAttrView label="Description" value={data.description} />
|
||||
</Box>
|
||||
</Flex>
|
||||
|
||||
<Tabs mt="4" h="100%">
|
||||
<TabList>
|
||||
<Tab>{t('modelManager.settings')}</Tab>
|
||||
<Tab>{t('modelManager.metadata')}</Tab>
|
||||
</TabList>
|
||||
|
||||
<TabPanels h="100%">
|
||||
<TabPanel>{selectedModelMode === 'view' ? <ModelView /> : <ModelEdit />}</TabPanel>
|
||||
<TabPanel h="full">
|
||||
<ModelMetadata />
|
||||
</TabPanel>
|
||||
</TabPanels>
|
||||
</Tabs>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
@ -1,12 +1,11 @@
|
||||
import { Box, Button, Flex, Heading, Text } from '@invoke-ai/ui-library';
|
||||
import { Box, Button, Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { IoPencil } from 'react-icons/io5';
|
||||
import { useGetModelConfigQuery, useGetModelMetadataQuery } from 'services/api/endpoints/models';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
import type {
|
||||
CheckpointModelConfig,
|
||||
ControlNetModelConfig,
|
||||
@ -18,6 +17,7 @@ import type {
|
||||
VAEModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
import { DefaultSettings } from './DefaultSettings';
|
||||
import { ModelAttrView } from './ModelAttrView';
|
||||
import { ModelConvert } from './ModelConvert';
|
||||
|
||||
@ -26,7 +26,6 @@ export const ModelView = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
|
||||
|
||||
const modelData = useMemo(() => {
|
||||
if (!data) {
|
||||
@ -73,85 +72,56 @@ export const ModelView = () => {
|
||||
return <Text>{t('common.somethingWentWrong')}</Text>;
|
||||
}
|
||||
return (
|
||||
<Flex flexDir="column" h="full">
|
||||
<Flex w="full" justifyContent="space-between">
|
||||
<Flex flexDir="column" gap={1} p={2}>
|
||||
<Heading as="h2" fontSize="lg">
|
||||
{modelData.name}
|
||||
</Heading>
|
||||
|
||||
{modelData.source && (
|
||||
<Text variant="subtext">
|
||||
{t('modelManager.source')}: {modelData.source}
|
||||
</Text>
|
||||
)}
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<Flex flexDir="column" h="full" gap="2">
|
||||
<Box layerStyle="second" borderRadius="base" p={3}>
|
||||
<Flex gap="2" justifyContent="flex-end" w="full">
|
||||
<Button size="sm" leftIcon={<IoPencil />} colorScheme="invokeYellow" onClick={handleEditModel}>
|
||||
{t('modelManager.edit')}
|
||||
</Button>
|
||||
|
||||
{modelData.type === 'main' && modelData.format === 'checkpoint' && <ModelConvert model={modelData} />}
|
||||
</Flex>
|
||||
</Flex>
|
||||
|
||||
<Flex flexDir="column" p={2} gap={3}>
|
||||
<Flex>
|
||||
<ModelAttrView label="Description" value={modelData.description} />
|
||||
</Flex>
|
||||
<Heading as="h3" fontSize="md" mt="4">
|
||||
{t('modelManager.modelSettings')}
|
||||
</Heading>
|
||||
<Box layerStyle="second" borderRadius="base" p={3}>
|
||||
<Flex flexDir="column" gap={3}>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.baseModel')} value={modelData.base} />
|
||||
<ModelAttrView label={t('modelManager.modelType')} value={modelData.type} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('common.format')} value={modelData.format} />
|
||||
<ModelAttrView label={t('modelManager.path')} value={modelData.path} />
|
||||
</Flex>
|
||||
{modelData.type === 'main' && (
|
||||
<>
|
||||
<Flex gap={2}>
|
||||
{modelData.format === 'diffusers' && (
|
||||
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
|
||||
)}
|
||||
{modelData.format === 'checkpoint' && (
|
||||
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config} />
|
||||
)}
|
||||
|
||||
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
|
||||
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.ztsnrTraining')} value={`${modelData.ztsnr_training}`} />
|
||||
<ModelAttrView label={t('modelManager.vae')} value={modelData.vae} />
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
{modelData.type === 'ip_adapter' && (
|
||||
<Flex flexDir="column" gap={3}>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.baseModel')} value={modelData.base} />
|
||||
<ModelAttrView label={t('modelManager.modelType')} value={modelData.type} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('common.format')} value={modelData.format} />
|
||||
<ModelAttrView label={t('modelManager.path')} value={modelData.path} />
|
||||
</Flex>
|
||||
{modelData.type === 'main' && (
|
||||
<>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelData.image_encoder_model_id} />
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
</Box>
|
||||
</Flex>
|
||||
{modelData.format === 'diffusers' && (
|
||||
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
|
||||
)}
|
||||
{modelData.format === 'checkpoint' && (
|
||||
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config} />
|
||||
)}
|
||||
|
||||
{metadata && (
|
||||
<>
|
||||
<Heading as="h3" fontSize="md" mt="4">
|
||||
{t('modelManager.modelMetadata')}
|
||||
</Heading>
|
||||
<Flex h="full" w="full" p={2}>
|
||||
<DataViewer label="metadata" data={metadata} />
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
|
||||
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.ztsnrTraining')} value={`${modelData.ztsnr_training}`} />
|
||||
<ModelAttrView label={t('modelManager.vae')} value={modelData.vae} />
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
{modelData.type === 'ip_adapter' && (
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelData.image_encoder_model_id} />
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
</Box>
|
||||
<Box layerStyle="second" borderRadius="base" p={3}>
|
||||
<DefaultSettings />
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -344,8 +344,8 @@ export const buildCanvasInpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'expanded_mask_area',
|
||||
},
|
||||
destination: {
|
||||
node_id: MASK_RESIZE_DOWN,
|
||||
|
@ -439,8 +439,8 @@ export const buildCanvasOutpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'expanded_mask_area',
|
||||
},
|
||||
destination: {
|
||||
node_id: MASK_RESIZE_DOWN,
|
||||
|
@ -355,8 +355,8 @@ export const buildCanvasSDXLInpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'expanded_mask_area',
|
||||
},
|
||||
destination: {
|
||||
node_id: MASK_RESIZE_DOWN,
|
||||
|
@ -448,8 +448,8 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'expanded_mask_area',
|
||||
},
|
||||
destination: {
|
||||
node_id: MASK_RESIZE_DOWN,
|
||||
|
@ -0,0 +1,36 @@
|
||||
import type { IconButtonProps } from '@invoke-ai/ui-library';
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiGearSixBold } from 'react-icons/pi';
|
||||
|
||||
export const NavigateToModelManagerButton = memo((props: Omit<IconButtonProps, 'aria-label'>) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
|
||||
const shouldShowButton = useMemo(() => !disabledTabs.includes('modelManager'), [disabledTabs]);
|
||||
|
||||
const handleClick = useCallback(() => {
|
||||
dispatch(setActiveTab('modelManager'));
|
||||
}, [dispatch]);
|
||||
|
||||
if (!shouldShowButton) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
icon={<PiGearSixBold />}
|
||||
tooltip={t('modelManager.modelManager')}
|
||||
aria-label={t('modelManager.modelManager')}
|
||||
onClick={handleClick}
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
NavigateToModelManagerButton.displayName = 'NavigateToModelManagerButton';
|
@ -0,0 +1,28 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { setDefaultSettings } from 'features/parameters/store/actions';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { RiSparklingFill } from 'react-icons/ri';
|
||||
|
||||
export const UseDefaultSettingsButton = () => {
|
||||
const model = useAppSelector((s) => s.generation.model);
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleClickDefaultSettings = useCallback(() => {
|
||||
dispatch(setDefaultSettings());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
icon={<RiSparklingFill />}
|
||||
tooltip={t('modelManager.useDefaultSettings')}
|
||||
aria-label={t('modelManager.useDefaultSettings')}
|
||||
isDisabled={!model}
|
||||
onClick={handleClickDefaultSettings}
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
/>
|
||||
);
|
||||
};
|
@ -5,3 +5,5 @@ import type { ImageDTO } from 'services/api/types';
|
||||
export const initialImageSelected = createAction<ImageDTO | undefined>('generation/initialImageSelected');
|
||||
|
||||
export const modelSelected = createAction<ParameterModel>('generation/modelSelected');
|
||||
|
||||
export const setDefaultSettings = createAction('generation/setDefaultSettings');
|
||||
|
@ -230,6 +230,12 @@ export const generationSlice = createSlice({
|
||||
state.height = optimalDimension;
|
||||
}
|
||||
}
|
||||
if (action.payload.sd?.scheduler) {
|
||||
state.scheduler = action.payload.sd.scheduler;
|
||||
}
|
||||
if (action.payload.sd?.vaePrecision) {
|
||||
state.vaePrecision = action.payload.sd.vaePrecision;
|
||||
}
|
||||
});
|
||||
|
||||
// TODO: This is a temp fix to reduce issues with T2I adapter having a different downscaling
|
||||
|
@ -1,15 +1,5 @@
|
||||
import type { FormLabelProps } from '@invoke-ai/ui-library';
|
||||
import {
|
||||
Expander,
|
||||
Flex,
|
||||
FormControlGroup,
|
||||
StandaloneAccordion,
|
||||
Tab,
|
||||
TabList,
|
||||
TabPanel,
|
||||
TabPanels,
|
||||
Tabs,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { Box, Expander, Flex, FormControlGroup, StandaloneAccordion } from '@invoke-ai/ui-library';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
@ -20,7 +10,9 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod
|
||||
import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
|
||||
import ParamScheduler from 'features/parameters/components/Core/ParamScheduler';
|
||||
import ParamSteps from 'features/parameters/components/Core/ParamSteps';
|
||||
import { NavigateToModelManagerButton } from 'features/parameters/components/MainModel/NavigateToModelManagerButton';
|
||||
import ParamMainModelSelect from 'features/parameters/components/MainModel/ParamMainModelSelect';
|
||||
import { UseDefaultSettingsButton } from 'features/parameters/components/MainModel/UseDefaultSettingsButton';
|
||||
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
|
||||
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
||||
import { filter } from 'lodash-es';
|
||||
@ -39,11 +31,11 @@ export const GenerationSettingsAccordion = memo(() => {
|
||||
() =>
|
||||
createMemoizedSelector(selectLoraSlice, (lora) => {
|
||||
const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length;
|
||||
const loraTabBadges = enabledLoRAsCount ? [enabledLoRAsCount] : EMPTY_ARRAY;
|
||||
const loraTabBadges = enabledLoRAsCount ? [`${enabledLoRAsCount} ${t('models.concepts')}`] : EMPTY_ARRAY;
|
||||
const accordionBadges = modelConfig ? [modelConfig.name, modelConfig.base] : EMPTY_ARRAY;
|
||||
return { loraTabBadges, accordionBadges };
|
||||
}),
|
||||
[modelConfig]
|
||||
[modelConfig, t]
|
||||
);
|
||||
const { loraTabBadges, accordionBadges } = useAppSelector(selectBadges);
|
||||
const { isOpen: isOpenExpander, onToggle: onToggleExpander } = useExpanderToggle({
|
||||
@ -58,39 +50,35 @@ export const GenerationSettingsAccordion = memo(() => {
|
||||
return (
|
||||
<StandaloneAccordion
|
||||
label={t('accordions.generation.title')}
|
||||
badges={accordionBadges}
|
||||
badges={[...accordionBadges, ...loraTabBadges]}
|
||||
isOpen={isOpenAccordion}
|
||||
onToggle={onToggleAccordion}
|
||||
>
|
||||
<Tabs variant="collapse">
|
||||
<TabList>
|
||||
<Tab>{t('accordions.generation.modelTab')}</Tab>
|
||||
<Tab badges={loraTabBadges}>{t('accordions.generation.conceptsTab')}</Tab>
|
||||
</TabList>
|
||||
<TabPanels>
|
||||
<TabPanel overflow="visible" px={4} pt={4}>
|
||||
<Flex gap={4} alignItems="center">
|
||||
<ParamMainModelSelect />
|
||||
<Box px={4} pt={4}>
|
||||
<Flex gap={4} flexDir="column">
|
||||
<Flex gap={4} alignItems="center">
|
||||
<ParamMainModelSelect />
|
||||
<Flex>
|
||||
<UseDefaultSettingsButton />
|
||||
<SyncModelsIconButton />
|
||||
<NavigateToModelManagerButton />
|
||||
</Flex>
|
||||
<Expander isOpen={isOpenExpander} onToggle={onToggleExpander}>
|
||||
<Flex gap={4} flexDir="column" pb={4}>
|
||||
<FormControlGroup formLabelProps={formLabelProps}>
|
||||
<ParamScheduler />
|
||||
<ParamSteps />
|
||||
<ParamCFGScale />
|
||||
</FormControlGroup>
|
||||
</Flex>
|
||||
</Expander>
|
||||
</TabPanel>
|
||||
<TabPanel>
|
||||
<Flex gap={4} p={4} flexDir="column">
|
||||
<LoRASelect />
|
||||
<LoRAList />
|
||||
</Flex>
|
||||
</TabPanel>
|
||||
</TabPanels>
|
||||
</Tabs>
|
||||
</Flex>
|
||||
<Flex gap={4} flexDir="column">
|
||||
<LoRASelect />
|
||||
<LoRAList />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Expander isOpen={isOpenExpander} onToggle={onToggleExpander}>
|
||||
<Flex gap={4} flexDir="column" pb={4}>
|
||||
<FormControlGroup formLabelProps={formLabelProps}>
|
||||
<ParamScheduler />
|
||||
<ParamSteps />
|
||||
<ParamCFGScale />
|
||||
</FormControlGroup>
|
||||
</Flex>
|
||||
</Expander>
|
||||
</Box>
|
||||
</StandaloneAccordion>
|
||||
);
|
||||
});
|
||||
|
@ -41,6 +41,8 @@ const initialConfigState: AppConfig = {
|
||||
boundingBoxHeight: { ...baseDimensionConfig },
|
||||
scaledBoundingBoxWidth: { ...baseDimensionConfig },
|
||||
scaledBoundingBoxHeight: { ...baseDimensionConfig },
|
||||
scheduler: 'euler',
|
||||
vaePrecision: 'fp32',
|
||||
steps: {
|
||||
initial: 30,
|
||||
sliderMin: 1,
|
||||
|
@ -24,7 +24,15 @@ export type UpdateModelArg = {
|
||||
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
|
||||
};
|
||||
|
||||
type UpdateModelMetadataArg = {
|
||||
key: paths['/api/v2/models/i/{key}/metadata']['patch']['parameters']['path']['key'];
|
||||
body: paths['/api/v2/models/i/{key}/metadata']['patch']['requestBody']['content']['application/json'];
|
||||
};
|
||||
|
||||
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
|
||||
type UpdateModelMetadataResponse =
|
||||
paths['/api/v2/models/i/{key}/metadata']['patch']['responses']['200']['content']['application/json'];
|
||||
|
||||
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
type GetModelMetadataResponse =
|
||||
@ -172,6 +180,16 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
updateModelMetadata: build.mutation<UpdateModelMetadataResponse, UpdateModelMetadataArg>({
|
||||
query: ({ key, body }) => {
|
||||
return {
|
||||
url: buildModelsUrl(`i/${key}/metadata`),
|
||||
method: 'PATCH',
|
||||
body: body,
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
|
||||
query: ({ source, config, access_token }) => {
|
||||
return {
|
||||
@ -351,6 +369,7 @@ export const {
|
||||
useGetModelMetadataQuery,
|
||||
useDeleteModelImportMutation,
|
||||
usePruneModelImportsMutation,
|
||||
useUpdateModelMetadataMutation,
|
||||
} = modelsApi;
|
||||
|
||||
const upsertModelConfigs = (
|
||||
|
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user