mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Merge remote-tracking branch 'origin/main' into maryhipp/trigger-phrases-main
@@ -19,6 +19,8 @@ their descriptions.
 | Conditioning Primitive | A conditioning tensor primitive value |
 | Content Shuffle Processor | Applies content shuffle processing to image |
 | ControlNet | Collects ControlNet info to pass to other nodes |
+| Create Denoise Mask | Converts a greyscale or transparency image into a mask for denoising. |
+| Create Gradient Mask | Creates a mask for Gradient ("soft", "differential") inpainting that gradually expands during denoising. Improves edge coherence. |
 | Denoise Latents | Denoises noisy latents to decodable images |
 | Divide Integers | Divides two numbers |
 | Dynamic Prompt | Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator |
@@ -283,6 +283,47 @@ async def update_model_metadata(
     return result


+@model_manager_router.patch(
+    "/i/{key}/metadata",
+    operation_id="update_model_metadata",
+    responses={
+        201: {
+            "description": "The model metadata was updated successfully",
+            "content": {"application/json": {"example": example_model_metadata}},
+        },
+        400: {"description": "Bad request"},
+    },
+)
+async def update_model_metadata(
+    key: str = Path(description="Key of the model repo metadata to fetch."),
+    changes: ModelMetadataChanges = Body(description="The changes"),
+) -> Optional[AnyModelRepoMetadata]:
+    """Updates or creates a model metadata object."""
+    record_store = ApiDependencies.invoker.services.model_manager.store
+    metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store
+
+    try:
+        original_metadata = record_store.get_metadata(key)
+        if original_metadata:
+            if changes.default_settings:
+                original_metadata.default_settings = changes.default_settings
+
+            metadata_store.update_metadata(key, original_metadata)
+        else:
+            metadata_store.add_metadata(
+                key, BaseMetadata(name="", author="", default_settings=changes.default_settings)
+            )
+    except Exception as e:
+        raise HTTPException(
+            status_code=500,
+            detail=f"An error occurred while updating the model metadata: {e}",
+        )
+
+    result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
+
+    return result
+
+
 @model_manager_router.get(
     "/tags",
     operation_id="list_tags",
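For context, here is a minimal sketch of exercising the new endpoint from a client. The host, port, base path, and model key are assumptions for illustration, not taken from this commit, and all six `default_settings` fields are passed because the settings model declares them as required-but-nullable:

```py
import requests

# Hypothetical local InvokeAI server and model key -- adjust to your installation.
BASE = "http://127.0.0.1:9090/api/v2/models"
key = "some-model-key"

resp = requests.patch(
    f"{BASE}/i/{key}/metadata",
    json={
        "default_settings": {
            "vae": None,
            "vae_precision": None,
            "scheduler": "euler",
            "steps": 30,
            "cfg_scale": 7.5,
            "cfg_rescale_multiplier": None,
        }
    },
)
print(resp.status_code, resp.json())
```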
@@ -491,6 +532,7 @@ async def add_model_record(
 )
 async def install_model(
     source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
+    inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
     # TODO(MM2): Can we type this?
     config: Optional[Dict[str, Any]] = Body(
         description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
@@ -533,6 +575,7 @@ async def install_model(
             source=source,
             config=config,
             access_token=access_token,
+            inplace=bool(inplace),
         )
         logger.info(f"Started installation of {source}")
     except UnknownModelException as e:
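A quick sketch of how a client might pass the new flag; the URL and file path are assumptions. With `inplace=true`, a local model is registered where it sits instead of being copied into the managed models directory:

```py
import requests

resp = requests.post(
    "http://127.0.0.1:9090/api/v2/models/install",  # hypothetical route prefix
    params={
        "source": "/home/me/models/juggernaut.safetensors",
        "inplace": "true",
    },
    json={},  # optional config overrides, e.g. {"name": "Juggernaut"}
)
print(resp.json())
```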
@@ -173,6 +173,16 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
     )


+@invocation_output("gradient_mask_output")
+class GradientMaskOutput(BaseInvocationOutput):
+    """Outputs a denoise mask and an image representing the total gradient of the mask."""
+
+    denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
+    expanded_mask_area: ImageField = OutputField(
+        description="Image representing the total gradient area of the mask. For paste-back purposes."
+    )
+
+
 @invocation(
     "create_gradient_mask",
     title="Create Gradient Mask",
@@ -193,38 +203,42 @@ class CreateGradientMaskInvocation(BaseInvocation):
     )

     @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
+    def invoke(self, context: InvocationContext) -> GradientMaskOutput:
         mask_image = context.images.get_pil(self.mask.image_name, mode="L")
-        if self.coherence_mode == "Box Blur":
-            blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
-        else:  # Gaussian Blur OR Staged
-            # Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
-            blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
+        if self.edge_radius > 0:
+            if self.coherence_mode == "Box Blur":
+                blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
+            else:  # Gaussian Blur OR Staged
+                # Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
+                blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))

-        mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
-        blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
+            blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)

-        # redistribute blur so that the edges are 0 and blur out to 1
-        blur_tensor = (blur_tensor - 0.5) * 2
+            # redistribute blur so that the original edges are 0 and blur outwards to 1
+            blur_tensor = (blur_tensor - 0.5) * 2

-        threshold = 1 - self.minimum_denoise
+            threshold = 1 - self.minimum_denoise

-        if self.coherence_mode == "Staged":
-            # wherever the blur_tensor is masked to any degree, convert it to threshold
-            blur_tensor = torch.where((blur_tensor < 1), threshold, blur_tensor)
-        else:
-            # wherever the blur_tensor is above threshold but less than 1, drop it to threshold
-            blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
+            if self.coherence_mode == "Staged":
+                # wherever the blur_tensor is less than fully masked, convert it to threshold
+                blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
+            else:
+                # wherever the blur_tensor is above threshold but less than 1, drop it to threshold
+                blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)

-        # multiply original mask to force actually masked regions to 0
-        blur_tensor = mask_tensor * blur_tensor
+        else:
+            blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)

         mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))

-        return DenoiseMaskOutput.build(
-            mask_name=mask_name,
-            masked_latents_name=None,
-            gradient=True,
+        # compute a [0, 1] mask from the blur_tensor
+        expanded_mask = torch.where((blur_tensor < 1), 0, 1)
+        expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
+        expanded_image_dto = context.images.save(expanded_mask_image)
+
+        return GradientMaskOutput(
+            denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True),
+            expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
         )
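To see what the redistribution and clamping above do numerically, here is a standalone sketch on a toy 1-D tensor (values chosen arbitrarily for illustration):

```py
import torch

# A blurred mask in [0, 1]: 0 = fully masked (inpaint), 1 = unmasked (keep).
blur = torch.tensor([0.40, 0.50, 0.75, 0.95, 1.00])

# Redistribute so the original mask edge (0.5) maps to 0 and the blur fades
# out to 1; fully masked regions go negative.
blur = (blur - 0.5) * 2  # -> [-0.2, 0.0, 0.5, 0.9, 1.0]

minimum_denoise = 0.2
threshold = 1 - minimum_denoise  # 0.8

# Non-Staged mode: partially blurred pixels above the threshold are clamped
# down to it, so every pixel in the gradient denoises at least `minimum_denoise`.
out = torch.where((blur > threshold) & (blur < 1), threshold, blur)
print(out)  # approx tensor([-0.2000, 0.0000, 0.5000, 0.8000, 1.0000])
```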
@@ -775,10 +789,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
             denoising_end=self.denoising_end,
         )

-        (
-            result_latents,
-            result_attention_map_saver,
-        ) = pipeline.latents_from_embeddings(
+        result_latents = pipeline.latents_from_embeddings(
             latents=latents,
             timesteps=timesteps,
             init_timestep=init_timestep,
|
@ -7,7 +7,6 @@ import time
|
|||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Empty, Queue
|
from queue import Empty, Queue
|
||||||
from random import randbytes
|
|
||||||
from shutil import copyfile, copytree, move, rmtree
|
from shutil import copyfile, copytree, move, rmtree
|
||||||
from tempfile import mkdtemp
|
from tempfile import mkdtemp
|
||||||
from typing import Any, Dict, List, Optional, Set, Union
|
from typing import Any, Dict, List, Optional, Set, Union
|
||||||
@@ -21,6 +20,7 @@ from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
 from invokeai.app.services.events.events_base import EventServiceBase
 from invokeai.app.services.invoker import Invoker
 from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
+from invokeai.app.util.misc import uuid_string
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
     BaseModelType,
@@ -150,7 +150,7 @@ class ModelInstallService(ModelInstallServiceBase):
         config = config or {}
         if not config.get("source"):
             config["source"] = model_path.resolve().as_posix()
-        config["key"] = config.get("key", self._create_key())
+        config["key"] = config.get("key", uuid_string())

         info: AnyModelConfig = self._probe_model(Path(model_path), config)

@@ -178,13 +178,14 @@ class ModelInstallService(ModelInstallServiceBase):
         source: str,
         config: Optional[Dict[str, Any]] = None,
         access_token: Optional[str] = None,
+        inplace: bool = False,
     ) -> ModelInstallJob:
         variants = "|".join(ModelRepoVariant.__members__.values())
         hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
         source_obj: Optional[StringLikeSource] = None

         if Path(source).exists():  # A local file or directory
-            source_obj = LocalModelSource(path=Path(source))
+            source_obj = LocalModelSource(path=Path(source), inplace=inplace)
         elif match := re.match(hf_repoid_re, source):
             source_obj = HFModelSource(
                 repo_id=match.group(1),
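The `hf_repoid_re` pattern accepts an optional variant and an optional subfolder after the repo id. A standalone sketch of what the three capture groups yield; the variant list is shortened here, whereas the service builds it from `ModelRepoVariant`:

```py
import re

variants = "|".join(["fp16", "fp32", "onnx"])  # stand-in for the enum values
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"

for source in [
    "stabilityai/sdxl-vae",            # bare repo id
    "stabilityai/sdxl-vae:fp16",       # repo id + variant
    "stabilityai/sdxl-vae:fp16:/vae",  # repo id + variant + subfolder
]:
    m = re.match(hf_repoid_re, source)
    print(m.group(1), m.group(2), m.group(3))
# stabilityai/sdxl-vae None None
# stabilityai/sdxl-vae fp16 None
# stabilityai/sdxl-vae fp16 vae
```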
@@ -526,16 +527,17 @@ class ModelInstallService(ModelInstallServiceBase):
             setattr(info, key, value)
         return info

-    def _create_key(self) -> str:
-        return sha256(randbytes(100)).hexdigest()[0:32]
-
     def _register(
         self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
     ) -> str:
         # Note that we may be passed a pre-populated AnyModelConfig object,
         # in which case the key field should have been populated by the caller (e.g. in `install_path`).
-        config["key"] = config.get("key", self._create_key())
+        config["key"] = config.get("key", uuid_string())
         info = info or ModelProbe.probe(model_path, config)
+        override_key: Optional[str] = config.get("key") if config else None
+
+        assert info.original_hash  # always assigned by probe()
+        info.key = override_key or info.original_hash

         model_path = model_path.absolute()
         if model_path.is_relative_to(self.app_config.models_path):
@@ -7,17 +7,23 @@ from abc import ABC, abstractmethod
 from typing import List, Optional, Set, Tuple

 from pydantic import Field
-from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
+
+from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
 from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
+from invokeai.backend.model_manager.metadata.metadata_base import ModelDefaultSettings


 class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"):
     """A set of changes to apply to model metadata.

     Only limited changes are valid:
+    - `default_settings`: the user-configured default settings for this model
     - `trigger_phrases`: the list of trigger phrases for this model
     """

+    default_settings: Optional[ModelDefaultSettings] = Field(
+        default=None, description="The user-configured default settings for this model"
+    )
+    """The user-configured default settings for this model"""
     trigger_phrases: Optional[List[str]] = Field(default=None, description="The model's list of trigger phrases")
     """The model's list of trigger phrases"""

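A sketch of a complete changes payload under the new schema; values are illustrative, and every `default_settings` field must be present (the fields are required but nullable):

```py
# Illustrative ModelMetadataChanges payload. Null fields are dropped on
# serialization because the model inherits from BaseModelExcludeNull.
changes = {
    "default_settings": {
        "vae": None,
        "vae_precision": "fp32",
        "scheduler": "euler",  # assumed to be a valid SCHEDULER_NAME_VALUES member
        "steps": 30,
        "cfg_scale": 7.5,
        "cfg_rescale_multiplier": None,
    },
    "trigger_phrases": ["my-trigger"],
}
```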
@@ -184,7 +184,7 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
     def _update_tags(self, model_key: str, tags: Optional[Set[str]]) -> None:
         """Update tags for the model referenced by model_key."""
         if tags:
             # remove previous tags from this model
             self._cursor.execute(
                 """--sql
                 DELETE FROM model_tags
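The tag update is a delete-then-reinsert. A self-contained sqlite3 sketch of the same pattern; the schema here is simplified and hypothetical, not the store's actual schema:

```py
import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE model_tags (model_id TEXT, tag_text TEXT)")

def update_tags(model_key: str, tags: set[str]) -> None:
    # remove previous tags from this model, then insert the new set
    cur.execute("DELETE FROM model_tags WHERE model_id = ?", (model_key,))
    cur.executemany(
        "INSERT INTO model_tags (model_id, tag_text) VALUES (?, ?)",
        [(model_key, t) for t in sorted(tags)],
    )

update_tags("abc123", {"anime", "sdxl"})
print(cur.execute("SELECT tag_text FROM model_tags").fetchall())
```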
@@ -200,6 +200,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
                 self._invoker.services.logger.error(
                     f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
                 )
+                self._invoker.services.logger.error(error)

                 # Send error event
                 self._invoker.services.events.emit_invocation_error(
@@ -3,7 +3,6 @@

 import json
 import sqlite3
-from hashlib import sha1
 from logging import Logger
 from pathlib import Path
 from typing import Optional
@@ -22,7 +21,7 @@ from invokeai.backend.model_manager.config import (
     ModelConfigFactory,
     ModelType,
 )
-from invokeai.backend.model_manager.hash import FastModelHash
+from invokeai.backend.model_manager.hash import ModelHash

 ModelsValidator = TypeAdapter(AnyModelConfig)

@@ -73,19 +72,27 @@ class MigrateModelYamlToDb1:

             base_type, model_type, model_name = str(model_key).split("/")
             try:
-                hash = FastModelHash.hash(self.config.models_path / stanza.path)
+                hash = ModelHash().hash(self.config.models_path / stanza.path)
             except OSError:
                 self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
                 continue

-            assert isinstance(model_key, str)
-            new_key = sha1(model_key.encode("utf-8")).hexdigest()
-
             stanza["base"] = BaseModelType(base_type)
             stanza["type"] = ModelType(model_type)
             stanza["name"] = model_name
             stanza["original_hash"] = hash
             stanza["current_hash"] = hash
+            new_key = hash  # deterministic key assignment
+
+            # special case for ip adapters, which need the new `image_encoder_model_id` field
+            if stanza["type"] == ModelType.IPAdapter:
+                try:
+                    stanza["image_encoder_model_id"] = self._get_image_encoder_model_id(
+                        self.config.models_path / stanza.path
+                    )
+                except OSError:
+                    self.logger.warning(f"Could not determine image encoder for {stanza.path}. Skipping.")
+                    continue

             new_config: AnyModelConfig = ModelsValidator.validate_python(stanza)  # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094

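The key scheme change matters for idempotency: the old key was a SHA-1 of the `models.yaml` entry name, so renaming a model changed its key; the new key is the content hash, so it is stable across renames and re-runs. A toy sketch with hashlib stand-ins (not the real `ModelHash`):

```py
import hashlib

weights = b"...model weights..."

# Old scheme: key depends on the yaml name, so renaming changes the key.
key_a = hashlib.sha1(b"sd-1/main/my-model").hexdigest()
key_b = hashlib.sha1(b"sd-1/main/renamed").hexdigest()
print(key_a == key_b)  # False

# New scheme: key depends only on content, so it survives renames.
print(hashlib.sha256(weights).hexdigest() == hashlib.sha256(weights).hexdigest())  # True
```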
@@ -95,7 +102,7 @@ class MigrateModelYamlToDb1:
                     self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
                     self._update_model(key, new_config)
                 else:
-                    self.logger.info(f"Adding model {model_name} with key {model_key}")
+                    self.logger.info(f"Adding model {model_name} with key {new_key}")
                     self._add_model(new_key, new_config)
             except DuplicateModelException:
                 self.logger.warning(f"Model {model_name} is already in the database")
@@ -149,3 +156,8 @@ class MigrateModelYamlToDb1:
             )
         except sqlite3.IntegrityError as exc:
             raise DuplicateModelException(f"{record.name}: model is already in database") from exc
+
+    def _get_image_encoder_model_id(self, model_path: Path) -> str:
+        with open(model_path / "image_encoder.txt") as f:
+            encoder = f.read()
+        return encoder.strip()
@@ -11,56 +11,175 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
 import hashlib
 import os
 from pathlib import Path
-from typing import Dict, Union
+from typing import Callable, Literal, Optional, Union

-from imohash import hashfile
+from blake3 import blake3

+MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
+
+ALGORITHM = Literal[
+    "md5",
+    "sha1",
+    "sha224",
+    "sha256",
+    "sha384",
+    "sha512",
+    "blake2b",
+    "blake2s",
+    "sha3_224",
+    "sha3_256",
+    "sha3_384",
+    "sha3_512",
+    "shake_128",
+    "shake_256",
+    "blake3",
+]
+

-class FastModelHash(object):
-    """FastModelHash obect provides one public class method, hash()."""
+class ModelHash:
+    """
+    Creates a hash of a model using a specified algorithm.
+
+    Args:
+        algorithm: Hashing algorithm to use. Defaults to BLAKE3.
+        file_filter: A function that takes a file name and returns True if the file should be included in the hash.
+
+    If the model is a single file, it is hashed directly using the provided algorithm.
+
+    If the model is a directory, each model weights file in the directory is hashed using the provided algorithm.
+
+    Only files with the following extensions are hashed: .ckpt, .safetensors, .bin, .pt, .pth
+
+    The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
+    that directory hashes are never weaker than the file hashes.
+
+    Usage:
+    ```py
+    # BLAKE3 hash
+    ModelHash().hash("path/to/some/model.safetensors")
+    # MD5
+    ModelHash("md5").hash("path/to/model/dir/")
+    ```
+    """
+
+    def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None:
+        if algorithm == "blake3":
+            self._hash_file = self._blake3
+        elif algorithm in hashlib.algorithms_available:
+            self._hash_file = self._get_hashlib(algorithm)
+        else:
+            raise ValueError(f"Algorithm {algorithm} not available")
+
+        self._file_filter = file_filter or self._default_file_filter

-    @classmethod
-    def hash(cls, model_location: Union[str, Path]) -> str:
+    def hash(self, model_path: Union[str, Path]) -> str:
         """
-        Return hexdigest string for model located at model_location.
+        Return hexdigest of hash of model located at model_path using the algorithm provided at class instantiation.
+
+        If model_path is a directory, the hash is computed by hashing the hashes of all model files in the
+        directory. The final composite hash is always computed using BLAKE3.

-        :param model_location: Path to the model
+        Args:
+            model_path: Path to the model
+
+        Returns:
+            str: Hexdigest of the hash of the model
         """
-        model_location = Path(model_location)
-        if model_location.is_file():
-            return cls._hash_file(model_location)
-        elif model_location.is_dir():
-            return cls._hash_dir(model_location)
+        model_path = Path(model_path)
+        if model_path.is_file():
+            return self._hash_file(model_path)
+        elif model_path.is_dir():
+            return self._hash_dir(model_path)
         else:
-            raise OSError(f"Not a valid file or directory: {model_location}")
+            raise OSError(f"Not a valid file or directory: {model_path}")

-    @classmethod
-    def _hash_file(cls, model_location: Union[str, Path]) -> str:
-        """
-        Fasthash a single file and return its hexdigest.
-
-        :param model_location: Path to the model file
-        """
-        # we return md5 hash of the filehash to make it shorter
-        # cryptographic security not needed here
-        return hashlib.md5(hashfile(model_location)).hexdigest()
-
-    @classmethod
-    def _hash_dir(cls, model_location: Union[str, Path]) -> str:
-        components: Dict[str, str] = {}
-
-        for root, _dirs, files in os.walk(model_location):
-            for file in files:
-                # only tally tensor files because diffusers config files change slightly
-                # depending on how the model was downloaded/converted.
-                if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
-                    continue
-                path = (Path(root) / file).as_posix()
-                fast_hash = cls._hash_file(path)
-                components.update({path: fast_hash})
-
-        # hash all the model hashes together, using alphabetic file order
-        md5 = hashlib.md5()
-        for _path, fast_hash in sorted(components.items()):
-            md5.update(fast_hash.encode("utf-8"))
-        return md5.hexdigest()
+    def _hash_dir(self, dir: Path) -> str:
+        """Compute the hash for all files in a directory and return a hexdigest.
+
+        Args:
+            dir: Path to the directory
+
+        Returns:
+            str: Hexdigest of the hash of the directory
+        """
+        model_component_paths = self._get_file_paths(dir, self._file_filter)
+
+        component_hashes: list[str] = []
+        for component in sorted(model_component_paths):
+            component_hashes.append(self._hash_file(component))
+
+        # BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
+        # for the composite hash
+        composite_hasher = blake3()
+        for h in component_hashes:
+            composite_hasher.update(h.encode("utf-8"))
+        return composite_hasher.hexdigest()
+
+    @staticmethod
+    def _get_file_paths(model_path: Path, file_filter: Callable[[str], bool]) -> list[Path]:
+        """Return a list of all model files in the directory.
+
+        Args:
+            model_path: Path to the model
+            file_filter: Function that takes a file name and returns True if the file should be included in the list.
+
+        Returns:
+            List of all model files in the directory
+        """
+        files: list[Path] = []
+        for root, _dirs, _files in os.walk(model_path):
+            for file in _files:
+                if file_filter(file):
+                    files.append(Path(root, file))
+        return files
+
+    @staticmethod
+    def _blake3(file_path: Path) -> str:
+        """Hashes a file using BLAKE3
+
+        Args:
+            file_path: Path to the file to hash
+
+        Returns:
+            Hexdigest of the hash of the file
+        """
+        file_hasher = blake3(max_threads=blake3.AUTO)
+        file_hasher.update_mmap(file_path)
+        return file_hasher.hexdigest()
+
+    @staticmethod
+    def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
+        """Factory function that returns a function to hash a file with the given algorithm.
+
+        Args:
+            algorithm: Hashing algorithm to use
+
+        Returns:
+            A function that hashes a file using the given algorithm
+        """
+
+        def hashlib_hasher(file_path: Path) -> str:
+            """Hashes a file using a hashlib algorithm. Uses `memoryview` to avoid reading the entire file into memory."""
+            hasher = hashlib.new(algorithm)
+            buffer = bytearray(128 * 1024)
+            mv = memoryview(buffer)
+            with open(file_path, "rb", buffering=0) as f:
+                while n := f.readinto(mv):
+                    hasher.update(mv[:n])
+            return hasher.hexdigest()
+
+        return hashlib_hasher
+
+    @staticmethod
+    def _default_file_filter(file_path: str) -> bool:
+        """A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth
+
+        Args:
+            file_path: Path to the file
+
+        Returns:
+            True if the file matches the given extensions, otherwise False
+        """
+        return file_path.endswith(MODEL_FILE_EXTENSIONS)
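As a standalone illustration of the directory-hash scheme (per-file hashes combined into a composite), here is a sketch using only the standard library. The real class uses blake3 and its mmap-based update; sha256 and a buffered `readinto` loop stand in for them here:

```py
import hashlib
import os
from pathlib import Path

MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")

def hash_file(path: Path) -> str:
    # Stream in 128 KiB chunks so large checkpoints never sit fully in memory.
    hasher = hashlib.sha256()
    buffer = bytearray(128 * 1024)
    mv = memoryview(buffer)
    with open(path, "rb", buffering=0) as f:
        while n := f.readinto(mv):
            hasher.update(mv[:n])
    return hasher.hexdigest()

def hash_dir(directory: Path) -> str:
    # Hash only weights files, in sorted order, then hash the hashes.
    files = sorted(
        Path(root, name)
        for root, _dirs, names in os.walk(directory)
        for name in names
        if name.endswith(MODEL_FILE_EXTENSIONS)
    )
    composite = hashlib.sha256()
    for p in files:
        composite.update(hash_file(p).encode("utf-8"))
    return composite.hexdigest()
```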
@@ -25,6 +25,7 @@ from pydantic.networks import AnyHttpUrl
 from requests.sessions import Session
 from typing_extensions import Annotated

+from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
 from invokeai.backend.model_manager import ModelRepoVariant

 from ..util import select_hf_files
@@ -68,6 +69,15 @@ class RemoteModelFile(BaseModel):
     sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)


+class ModelDefaultSettings(BaseModel):
+    vae: str | None
+    vae_precision: str | None
+    scheduler: SCHEDULER_NAME_VALUES | None
+    steps: int | None
+    cfg_scale: float | None
+    cfg_rescale_multiplier: float | None
+
+
 class ModelMetadataBase(BaseModel):
     """Base class for model metadata information."""

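A quick sketch of instantiating the new settings model directly. Note the fields carry no defaults, so all six must be supplied, though each may be None; the values below are illustrative:

```py
# Assumes the definitions from this file, e.g.:
# from invokeai.backend.model_manager.metadata.metadata_base import ModelDefaultSettings
settings = ModelDefaultSettings(
    vae=None,
    vae_precision="fp16",
    scheduler="euler",  # assumed member of SCHEDULER_NAME_VALUES
    steps=25,
    cfg_scale=7.0,
    cfg_rescale_multiplier=None,
)
```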
@@ -75,6 +85,9 @@ class ModelMetadataBase(BaseModel):
     author: str = Field(description="model's author")
     tags: Optional[Set[str]] = Field(description="tags provided by model source", default=None)
     trigger_phrases: Optional[List[str]] = Field(description="trigger phrases for this model", default=None)
+    default_settings: Optional[ModelDefaultSettings] = Field(
+        description="default settings for this model", default=None
+    )


 class BaseMetadata(ModelMetadataBase):
@@ -21,7 +21,7 @@ from .config import (
     ModelVariantType,
     SchedulerPredictionType,
 )
-from .hash import FastModelHash
+from .hash import ModelHash
 from .util.model_util import lora_token_vector_length, read_checkpoint_meta

 CkptType = Dict[str, Any]
@@ -147,7 +147,7 @@ class ModelProbe(object):
         if not probe_class:
             raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")

-        hash = FastModelHash.hash(model_path)
+        hash = ModelHash().hash(model_path)
         probe = probe_class(model_path)

         fields["path"] = model_path.as_posix()
@@ -4,13 +4,11 @@ Initialization file for the invokeai.backend.stable_diffusion package

 from .diffusers_pipeline import PipelineIntermediateState, StableDiffusionGeneratorPipeline  # noqa: F401
 from .diffusion import InvokeAIDiffuserComponent  # noqa: F401
-from .diffusion.cross_attention_map_saving import AttentionMapSaver  # noqa: F401
 from .seamless import set_seamless  # noqa: F401

 __all__ = [
     "PipelineIntermediateState",
     "StableDiffusionGeneratorPipeline",
     "InvokeAIDiffuserComponent",
-    "AttentionMapSaver",
     "set_seamless",
 ]
@@ -12,7 +12,6 @@ import torch
 import torchvision.transforms as T
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.models.controlnet import ControlNetModel
-from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import KarrasDiffusionSchedulers
@@ -26,9 +25,9 @@ from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
 from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
+from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent

 from ..util import auto_detect_slice_size, normalize_device
-from .diffusion import AttentionMapSaver, InvokeAIDiffuserComponent


 @dataclass
@@ -39,7 +38,6 @@ class PipelineIntermediateState:
     timestep: int
     latents: torch.Tensor
     predicted_original: Optional[torch.Tensor] = None
-    attention_map_saver: Optional[AttentionMapSaver] = None


 @dataclass
@@ -190,19 +188,6 @@ class T2IAdapterData:
     end_step_percent: float = Field(default=1.0)


-@dataclass
-class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
-    r"""
-    Output class for InvokeAI's Stable Diffusion pipeline.
-
-    Args:
-        attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
-        after generation completes. Optional.
-    """
-
-    attention_map_saver: Optional[AttentionMapSaver]
-
-
 class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.
@@ -343,9 +328,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         masked_latents: Optional[torch.Tensor] = None,
         gradient_mask: Optional[bool] = False,
         seed: Optional[int] = None,
-    ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
+    ) -> torch.Tensor:
         if init_timestep.shape[0] == 0:
-            return latents, None
+            return latents

         if additional_guidance is None:
             additional_guidance = []
@@ -385,7 +370,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))

         try:
-            latents, attention_map_saver = self.generate_latents_from_embeddings(
+            latents = self.generate_latents_from_embeddings(
                 latents,
                 timesteps,
                 conditioning_data,
@@ -402,7 +387,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if mask is not None and not gradient_mask:
             latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))

-        return latents, attention_map_saver
+        return latents

     def generate_latents_from_embeddings(
         self,
@@ -415,16 +400,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         ip_adapter_data: Optional[list[IPAdapterData]] = None,
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
         callback: Callable[[PipelineIntermediateState], None] = None,
-    ):
+    ) -> torch.Tensor:
         self._adjust_memory_efficient_attention(latents)
         if additional_guidance is None:
             additional_guidance = []

         batch_size = latents.shape[0]
-        attention_map_saver: Optional[AttentionMapSaver] = None

         if timesteps.shape[0] == 0:
-            return latents, attention_map_saver
+            return latents

         ip_adapter_unet_patcher = None
         extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning
@@ -432,7 +416,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             attn_ctx = self.invokeai_diffuser.custom_attention_context(
                 self.invokeai_diffuser.model,
                 extra_conditioning_info=extra_conditioning_info,
-                step_count=len(self.scheduler.timesteps),
             )
             self.use_ip_adapter = False
         elif ip_adapter_data is not None:
@@ -483,13 +466,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

             predicted_original = getattr(step_output, "pred_original_sample", None)

-            # TODO resuscitate attention map saving
-            # if i == len(timesteps)-1 and extra_conditioning_info is not None:
-            #     eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
-            #     attention_map_token_ids = range(1, eos_token_index)
-            #     attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
-            #     self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
-
             if callback is not None:
                 callback(
                     PipelineIntermediateState(
@@ -499,11 +475,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                         timestep=int(t),
                         latents=latents,
                         predicted_original=predicted_original,
-                        attention_map_saver=attention_map_saver,
                     )
                 )

-        return latents, attention_map_saver
+        return latents

     @torch.inference_mode()
     def step(
@@ -2,6 +2,4 @@
 Initialization file for invokeai.models.diffusion
 """

-from .cross_attention_control import InvokeAICrossAttentionMixin  # noqa: F401
-from .cross_attention_map_saving import AttentionMapSaver  # noqa: F401
 from .shared_invokeai_diffusion import InvokeAIDiffuserComponent  # noqa: F401
@@ -3,19 +3,13 @@


 import enum
-import math
 from dataclasses import dataclass, field
-from typing import Callable, Optional
+from typing import Optional

-import diffusers
-import psutil
 import torch
 from compel.cross_attention_control import Arguments
-from diffusers.models.attention_processor import Attention, AttentionProcessor, AttnProcessor, SlicedAttnProcessor
+from diffusers.models.attention_processor import Attention, SlicedAttnProcessor
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
-from torch import nn

-import invokeai.backend.util.logging as logger
-
 from ...util import torch_dtype

@@ -25,72 +19,14 @@ class CrossAttentionType(enum.Enum):
     TOKENS = 2


-class Context:
-    cross_attention_mask: Optional[torch.Tensor]
-    cross_attention_index_map: Optional[torch.Tensor]
-
-    class Action(enum.Enum):
-        NONE = 0
-        SAVE = (1,)
-        APPLY = 2
-
-    def __init__(self, arguments: Arguments, step_count: int):
+class CrossAttnControlContext:
+    def __init__(self, arguments: Arguments):
         """
         :param arguments: Arguments for the cross-attention control process
-        :param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
         """
-        self.cross_attention_mask = None
-        self.cross_attention_index_map = None
-        self.self_cross_attention_action = Context.Action.NONE
-        self.tokens_cross_attention_action = Context.Action.NONE
+        self.cross_attention_mask: Optional[torch.Tensor] = None
+        self.cross_attention_index_map: Optional[torch.Tensor] = None
         self.arguments = arguments
-        self.step_count = step_count
-
-        self.self_cross_attention_module_identifiers = []
-        self.tokens_cross_attention_module_identifiers = []
-
-        self.saved_cross_attention_maps = {}
-
-        self.clear_requests(cleanup=True)
-
-    def register_cross_attention_modules(self, model):
-        for name, _module in get_cross_attention_modules(model, CrossAttentionType.SELF):
-            if name in self.self_cross_attention_module_identifiers:
-                raise AssertionError(f"name {name} cannot appear more than once")
-            self.self_cross_attention_module_identifiers.append(name)
-        for name, _module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
-            if name in self.tokens_cross_attention_module_identifiers:
-                raise AssertionError(f"name {name} cannot appear more than once")
-            self.tokens_cross_attention_module_identifiers.append(name)
-
-    def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
-        if cross_attention_type == CrossAttentionType.SELF:
-            self.self_cross_attention_action = Context.Action.SAVE
-        else:
-            self.tokens_cross_attention_action = Context.Action.SAVE
-
-    def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType):
-        if cross_attention_type == CrossAttentionType.SELF:
-            self.self_cross_attention_action = Context.Action.APPLY
-        else:
-            self.tokens_cross_attention_action = Context.Action.APPLY
-
-    def is_tokens_cross_attention(self, module_identifier) -> bool:
-        return module_identifier in self.tokens_cross_attention_module_identifiers
-
-    def get_should_save_maps(self, module_identifier: str) -> bool:
-        if module_identifier in self.self_cross_attention_module_identifiers:
-            return self.self_cross_attention_action == Context.Action.SAVE
-        elif module_identifier in self.tokens_cross_attention_module_identifiers:
-            return self.tokens_cross_attention_action == Context.Action.SAVE
-        return False
-
-    def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
-        if module_identifier in self.self_cross_attention_module_identifiers:
-            return self.self_cross_attention_action == Context.Action.APPLY
-        elif module_identifier in self.tokens_cross_attention_module_identifiers:
-            return self.tokens_cross_attention_action == Context.Action.APPLY
-        return False

     def get_active_cross_attention_control_types_for_step(
         self, percent_through: float = None
@ -111,219 +47,8 @@ class Context:
|
|||||||
to_control.append(CrossAttentionType.TOKENS)
|
to_control.append(CrossAttentionType.TOKENS)
|
||||||
return to_control
|
return to_control
|
||||||
|
|
||||||
def save_slice(
|
|
||||||
self,
|
|
||||||
identifier: str,
|
|
||||||
slice: torch.Tensor,
|
|
||||||
dim: Optional[int],
|
|
||||||
offset: int,
|
|
||||||
slice_size: Optional[int],
|
|
||||||
):
|
|
||||||
if identifier not in self.saved_cross_attention_maps:
|
|
||||||
self.saved_cross_attention_maps[identifier] = {
|
|
||||||
"dim": dim,
|
|
||||||
"slice_size": slice_size,
|
|
||||||
"slices": {offset or 0: slice},
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
self.saved_cross_attention_maps[identifier]["slices"][offset or 0] = slice
|
|
||||||
|
|
||||||
def get_slice(
|
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: CrossAttnControlContext):
|
||||||
self,
|
|
||||||
identifier: str,
|
|
||||||
requested_dim: Optional[int],
|
|
||||||
requested_offset: int,
|
|
||||||
slice_size: int,
|
|
||||||
):
|
|
||||||
saved_attention_dict = self.saved_cross_attention_maps[identifier]
|
|
||||||
if requested_dim is None:
|
|
||||||
if saved_attention_dict["dim"] is not None:
|
|
||||||
raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}")
|
|
||||||
return saved_attention_dict["slices"][0]
|
|
||||||
|
|
||||||
if saved_attention_dict["dim"] == requested_dim:
|
|
||||||
if slice_size != saved_attention_dict["slice_size"]:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}"
|
|
||||||
)
|
|
||||||
return saved_attention_dict["slices"][requested_offset]
|
|
||||||
|
|
||||||
if saved_attention_dict["dim"] is None:
|
|
||||||
whole_saved_attention = saved_attention_dict["slices"][0]
|
|
||||||
if requested_dim == 0:
|
|
||||||
return whole_saved_attention[requested_offset : requested_offset + slice_size]
|
|
||||||
elif requested_dim == 1:
|
|
||||||
return whole_saved_attention[:, requested_offset : requested_offset + slice_size]
|
|
||||||
|
|
||||||
raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
|
|
||||||
|
|
||||||
def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
|
|
||||||
saved_attention = self.saved_cross_attention_maps.get(identifier, None)
|
|
||||||
if saved_attention is None:
|
|
||||||
return None, None
|
|
||||||
return saved_attention["dim"], saved_attention["slice_size"]
|
|
||||||
|
|
||||||
def clear_requests(self, cleanup=True):
|
|
||||||
self.tokens_cross_attention_action = Context.Action.NONE
|
|
||||||
self.self_cross_attention_action = Context.Action.NONE
|
|
||||||
if cleanup:
|
|
||||||
self.saved_cross_attention_maps = {}
|
|
||||||
|
|
||||||
def offload_saved_attention_slices_to_cpu(self):
|
|
||||||
for _key, map_dict in self.saved_cross_attention_maps.items():
|
|
||||||
for offset, slice in map_dict["slices"].items():
|
|
||||||
map_dict[offset] = slice.to("cpu")
|
|
||||||
|
|
||||||
|
|
||||||
class InvokeAICrossAttentionMixin:
|
|
||||||
"""
|
|
||||||
Enable InvokeAI-flavoured Attention calculation, which does aggressive low-memory slicing and calls
|
|
||||||
through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
|
|
||||||
and dymamic slicing strategy selection.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
|
||||||
self.attention_slice_wrangler = None
|
|
||||||
self.slicing_strategy_getter = None
|
|
||||||
self.attention_slice_calculated_callback = None
|
|
||||||
|
|
||||||
def set_attention_slice_wrangler(
|
|
||||||
self,
|
|
||||||
wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]],
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Set custom attention calculator to be called when attention is calculated
|
|
||||||
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
|
|
||||||
which returns either the suggested_attention_slice or an adjusted equivalent.
|
|
||||||
`module` is the current Attention module for which the callback is being invoked.
|
|
||||||
`suggested_attention_slice` is the default-calculated attention slice
|
|
||||||
`dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
|
|
||||||
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
|
|
||||||
|
|
||||||
Pass None to use the default attention calculation.
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
self.attention_slice_wrangler = wrangler
|
|
||||||
-
-    def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int, int]]]):
-        self.slicing_strategy_getter = getter
-
-    def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]):
-        self.attention_slice_calculated_callback = callback
-
-    def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
-        # calculate attention scores
-        # attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
-        attention_scores = torch.baddbmm(
-            torch.empty(
-                query.shape[0],
-                query.shape[1],
-                key.shape[1],
-                dtype=query.dtype,
-                device=query.device,
-            ),
-            query,
-            key.transpose(-1, -2),
-            beta=0,
-            alpha=self.scale,
-        )
-
-        # calculate attention slice by taking the best scores for each latent pixel
-        default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
-        attention_slice_wrangler = self.attention_slice_wrangler
-        if attention_slice_wrangler is not None:
-            attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
-        else:
-            attention_slice = default_attention_slice
-
-        if self.attention_slice_calculated_callback is not None:
-            self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)
-
-        hidden_states = torch.bmm(attention_slice, value)
-        return hidden_states
-
-    def einsum_op_slice_dim0(self, q, k, v, slice_size):
-        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-        for i in range(0, q.shape[0], slice_size):
-            end = i + slice_size
-            r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
-        return r
-
-    def einsum_op_slice_dim1(self, q, k, v, slice_size):
-        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-        for i in range(0, q.shape[1], slice_size):
-            end = i + slice_size
-            r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
-        return r
-
-    def einsum_op_mps_v1(self, q, k, v):
-        if q.shape[1] <= 4096:  # (512x512) max q.shape[1]: 4096
-            return self.einsum_lowest_level(q, k, v, None, None, None)
-        else:
-            slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
-            return self.einsum_op_slice_dim1(q, k, v, slice_size)
-
-    def einsum_op_mps_v2(self, q, k, v):
-        if self.mem_total_gb > 8 and q.shape[1] <= 4096:
-            return self.einsum_lowest_level(q, k, v, None, None, None)
-        else:
-            return self.einsum_op_slice_dim0(q, k, v, 1)
-
-    def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
-        size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
-        if size_mb <= max_tensor_mb:
-            return self.einsum_lowest_level(q, k, v, None, None, None)
-        div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
-        if div <= q.shape[0]:
-            return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
-        return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
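A worked example of the divisor arithmetic in einsum_op_tensor_mem, with hypothetical fp16 shapes:

    # q: [16, 4096, 40], k: [16, 4096, 40], element_size = 2 bytes
    # size_mb = 16 * 4096 * 4096 * 2 // (1 << 20) = 512
    # with max_tensor_mb = 32:
    #   div = 1 << int((512 - 1) / 32).bit_length()  ->  1 << 4  ->  16
    #   div (16) <= q.shape[0] (16), so slicing runs on dim 0 with
    #   slice_size = q.shape[0] // div = 1, i.e. one batch row per einsum call.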
-
-    def einsum_op_cuda(self, q, k, v):
-        # check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
-        slicing_strategy_getter = self.slicing_strategy_getter
-        if slicing_strategy_getter is not None:
-            (dim, slice_size) = slicing_strategy_getter(self)
-            if dim is not None:
-                # print("using saved slicing strategy with dim", dim, "slice size", slice_size)
-                if dim == 0:
-                    return self.einsum_op_slice_dim0(q, k, v, slice_size)
-                elif dim == 1:
-                    return self.einsum_op_slice_dim1(q, k, v, slice_size)
-
-        # fallback for when there is no saved strategy, or the saved strategy does not slice
-        mem_free_total = get_mem_free_total(q.device)
-        # divide by a safety factor, since there's copying and fragmentation
-        return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
-
-    def get_invokeai_attention_mem_efficient(self, q, k, v):
-        if q.device.type == "cuda":
-            # print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
-            return self.einsum_op_cuda(q, k, v)
-
-        if q.device.type == "mps" or q.device.type == "cpu":
-            if self.mem_total_gb >= 32:
-                return self.einsum_op_mps_v1(q, k, v)
-            return self.einsum_op_mps_v2(q, k, v)
-
-        # Smaller slices are faster due to L2/L3/SLC caches.
-        # Tested on i7 with 8MB L3 cache.
-        return self.einsum_op_tensor_mem(q, k, v, 32)
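Summarizing the dispatch above (thresholds as written in the code):

    # cuda             -> einsum_op_cuda: budget = free VRAM / 3.3, in MB
    # mps/cpu, >=32 GB -> einsum_op_mps_v1: slice dim 1 once q.shape[1] > 4096
    # mps/cpu, < 32 GB -> einsum_op_mps_v2: slice dim 0 with slice_size 1
    # anything else    -> einsum_op_tensor_mem with a fixed 32 MB budget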
-
-
-def restore_default_cross_attention(
-    model,
-    is_running_diffusers: bool,
-    restore_attention_processor: Optional[AttentionProcessor] = None,
-):
-    if is_running_diffusers:
-        unet = model
-        unet.set_attn_processor(restore_attention_processor or AttnProcessor())
-    else:
-        remove_attention_function(model)
-
-
-def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
     """
     Inject attention parameters and functions into the passed in model to enable cross attention editing.

@@ -362,170 +87,6 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionMode
     unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
-
-
-def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
-    cross_attention_class: type = InvokeAIDiffusersCrossAttention
-    which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
-    attention_module_tuples = [
-        (name, module)
-        for name, module in model.named_modules()
-        if isinstance(module, cross_attention_class) and which_attn in name
-    ]
-    cross_attention_modules_in_model_count = len(attention_module_tuples)
-    expected_count = 16
-    if cross_attention_modules_in_model_count != expected_count:
-        # non-fatal error but .swap() won't work.
-        logger.error(
-            f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
-            f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching "
-            "failed or some assumption has changed about the structure of the model itself. Please fix the "
-            f"monkey-patching, and/or update the {expected_count} above to an appropriate number, and/or find and "
-            "inform someone who knows what it means. This error is non-fatal, but it is likely that .swap() and "
-            "attention map display will not work properly until it is fixed."
-        )
-    return attention_module_tuples
-
-
-def inject_attention_function(unet, context: Context):
-    # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
-
-    def attention_slice_wrangler(module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size):
-        # memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()
-
-        attention_slice = suggested_attention_slice
-
-        if context.get_should_save_maps(module.identifier):
-            # print(module.identifier, "saving suggested_attention_slice of shape",
-            #       suggested_attention_slice.shape, "dim", dim, "offset", offset)
-            slice_to_save = attention_slice.to("cpu") if dim is not None else attention_slice
-            context.save_slice(
-                module.identifier,
-                slice_to_save,
-                dim=dim,
-                offset=offset,
-                slice_size=slice_size,
-            )
-        elif context.get_should_apply_saved_maps(module.identifier):
-            # print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
-            saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)
-
-            # slice may have been offloaded to CPU
-            saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)
-
-            if context.is_tokens_cross_attention(module.identifier):
-                index_map = context.cross_attention_index_map
-                remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
-                this_attention_slice = suggested_attention_slice
-
-                mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device))
-                saved_mask = mask
-                this_mask = 1 - mask
-                attention_slice = remapped_saved_attention_slice * saved_mask + this_attention_slice * this_mask
-            else:
-                # just use everything
-                attention_slice = saved_attention_slice
-
-        return attention_slice
-
-    cross_attention_modules = get_cross_attention_modules(
-        unet, CrossAttentionType.TOKENS
-    ) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
-    for identifier, module in cross_attention_modules:
-        module.identifier = identifier
-        try:
-            module.set_attention_slice_wrangler(attention_slice_wrangler)
-            module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier))  # noqa: B023
-        except AttributeError as e:
-            if is_attribute_error_about(e, "set_attention_slice_wrangler"):
-                print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")  # TODO
-            else:
-                raise
-
-
-def remove_attention_function(unet):
-    cross_attention_modules = get_cross_attention_modules(
-        unet, CrossAttentionType.TOKENS
-    ) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
-    for _identifier, module in cross_attention_modules:
-        try:
-            # clear wrangler callback
-            module.set_attention_slice_wrangler(None)
-            module.set_slicing_strategy_getter(None)
-        except AttributeError as e:
-            if is_attribute_error_about(e, "set_attention_slice_wrangler"):
-                print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")
-            else:
-                raise
-
-
-def is_attribute_error_about(error: AttributeError, attribute: str):
-    if hasattr(error, "name"):  # Python 3.10
-        return error.name == attribute
-    else:  # Python 3.9
-        return attribute in str(error)
-
-
-def get_mem_free_total(device):
-    # only on cuda
-    if not torch.cuda.is_available():
-        return None
-    stats = torch.cuda.memory_stats(device)
-    mem_active = stats["active_bytes.all.current"]
-    mem_reserved = stats["reserved_bytes.all.current"]
-    mem_free_cuda, _ = torch.cuda.mem_get_info(device)
-    mem_free_torch = mem_reserved - mem_active
-    mem_free_total = mem_free_cuda + mem_free_torch
-    return mem_free_total
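A short sketch of how the CUDA path consumes this value (assumes a CUDA-enabled torch build; the variable names are illustrative):

    free_bytes = get_mem_free_total(torch.device("cuda:0"))
    if free_bytes is not None:
        # the 3.3 safety factor leaves headroom for copies and fragmentation
        budget_mb = free_bytes / 3.3 / (1 << 20)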
-
-
-class InvokeAIDiffusersCrossAttention(diffusers.models.attention.Attention, InvokeAICrossAttentionMixin):
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        InvokeAICrossAttentionMixin.__init__(self)
-
-    def _attention(self, query, key, value, attention_mask=None):
-        # default_result = super()._attention(query, key, value)
-        if attention_mask is not None:
-            print(f"{type(self).__name__} ignoring passed-in attention_mask")
-        attention_result = self.get_invokeai_attention_mem_efficient(query, key, value)
-
-        hidden_states = self.reshape_batch_dim_to_heads(attention_result)
-        return hidden_states
-
-
-## 🧨diffusers implementation follows
-
-
-"""
-# base implementation
-
-class AttnProcessor:
-    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = hidden_states.shape
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
-
-        query = attn.to_q(hidden_states)
-        query = attn.head_to_batch_dim(query)
-
-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
-
-        attention_probs = attn.get_attention_scores(query, key, attention_mask)
-        hidden_states = torch.bmm(attention_probs, value)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        return hidden_states
-
-"""
-
 @dataclass
 class SwapCrossAttnContext:
     modified_text_embeddings: torch.Tensor
@@ -1,100 +0,0 @@
-import math
-from typing import Optional
-
-import torch
-from PIL import Image
-from torchvision.transforms.functional import InterpolationMode
-from torchvision.transforms.functional import resize as tv_resize
-
-
-class AttentionMapSaver:
-    def __init__(self, token_ids: range, latents_shape: torch.Size):
-        self.token_ids = token_ids
-        self.latents_shape = latents_shape
-        # self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
-        self.collated_maps: dict[str, torch.Tensor] = {}
-
-    def clear_maps(self):
-        self.collated_maps = {}
-
-    def add_attention_maps(self, maps: torch.Tensor, key: str):
-        """
-        Accumulate the given attention maps by summing them with any existing maps stored at the passed-in key.
-        :param maps: Attention maps to store. Expected shape [A, (H*W), N] where A is the attention-head count, H and W are the map size (fixed per key) and N is the number of tokens (typically 77).
-        :param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the map sizes (H and W) should match.
-        :return: None
-        """
-        key_and_size = f"{key}_{maps.shape[1]}"
-
-        # extract desired tokens
-        maps = maps[:, :, self.token_ids]
-
-        # merge attention heads to a single map per token
-        maps = torch.sum(maps, 0)
-
-        # store
-        if key_and_size not in self.collated_maps:
-            self.collated_maps[key_and_size] = torch.zeros_like(maps, device="cpu")
-        self.collated_maps[key_and_size] += maps.cpu()
-
-    def write_maps_to_disk(self, path: str):
-        pil_image = self.get_stacked_maps_image()
-        if pil_image is not None:
-            pil_image.save(path, "PNG")
-
-    def get_stacked_maps_image(self) -> Optional[Image.Image]:
-        """
-        Scale all collected attention maps to the same size, blend them together and return the result as an image.
-        :return: An image containing a vertical stack of blended attention maps, one for each requested token.
-        """
-        num_tokens = len(self.token_ids)
-        if num_tokens == 0:
-            return None
-
-        latents_height = self.latents_shape[0]
-        latents_width = self.latents_shape[1]
-
-        merged = None
-
-        for _key, maps in self.collated_maps.items():
-            # maps has shape [(H*W), N] for N tokens
-            # but we want [N, H, W]
-            this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
-            this_maps_height = int(float(latents_height) * this_scale_factor)
-            this_maps_width = int(float(latents_width) * this_scale_factor)
-            # and we need to do some dimension juggling
-            maps = torch.reshape(
-                torch.swapdims(maps, 0, 1),
-                [num_tokens, this_maps_height, this_maps_width],
-            )
-
-            # scale to output size if necessary
-            if this_scale_factor != 1:
-                maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)
-
-            # normalize
-            maps_min = torch.min(maps)
-            maps_range = torch.max(maps) - maps_min
-            # print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
-            maps_normalized = (maps - maps_min) / maps_range
-            # expand to (-0.1, 1.1) and clamp
-            maps_normalized_expanded = maps_normalized * 1.1 - 0.05
-            maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)
-
-            # merge together, producing a vertical stack
-            maps_stacked = torch.reshape(
-                maps_normalized_expanded_clamped,
-                [num_tokens * latents_height, latents_width],
-            )
-
-            if merged is None:
-                merged = maps_stacked
-            else:
-                # screen blend
-                merged = 1 - (1 - maps_stacked) * (1 - merged)
-
-        if merged is None:
-            return None
-
-        merged_bytes = merged.mul(0xFF).byte()
-        return Image.fromarray(merged_bytes.numpy(), mode="L")
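A minimal usage sketch for the removed class, to make the shape contract concrete (the `maps` tensor and the sizes are hypothetical; in practice maps would come from an attention callback):

    # 64x64 latents; collect maps for prompt token indices 1-3 of a 77-token prompt.
    saver = AttentionMapSaver(token_ids=range(1, 4), latents_shape=torch.Size([64, 64]))
    saver.add_attention_maps(maps, key="mid")  # maps: [heads, 64*64, 77]
    saver.write_maps_to_disk("attention_maps.png")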
@@ -17,13 +17,11 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
 )

 from .cross_attention_control import (
-    Context,
     CrossAttentionType,
+    CrossAttnControlContext,
     SwapCrossAttnContext,
-    get_cross_attention_modules,
     setup_cross_attention_control_attention_processors,
 )
-from .cross_attention_map_saving import AttentionMapSaver

 ModelForwardCallback: TypeAlias = Union[
     # x, t, conditioning, Optional[cross-attention kwargs]
@@ -69,14 +67,12 @@ class InvokeAIDiffuserComponent:
         self,
         unet: UNet2DConditionModel,
         extra_conditioning_info: Optional[ExtraConditioningInfo],
-        step_count: int,
     ):
         old_attn_processors = unet.attn_processors

         try:
-            self.cross_attention_control_context = Context(
+            self.cross_attention_control_context = CrossAttnControlContext(
                 arguments=extra_conditioning_info.cross_attention_control_args,
-                step_count=step_count,
             )
             setup_cross_attention_control_attention_processors(
                 unet,
@@ -87,27 +83,6 @@ class InvokeAIDiffuserComponent:
         finally:
             self.cross_attention_control_context = None
             unet.set_attn_processor(old_attn_processors)
-            # TODO resuscitate attention map saving
-            # self.remove_attention_map_saving()
-
-    def setup_attention_map_saving(self, saver: AttentionMapSaver):
-        def callback(slice, dim, offset, slice_size, key):
-            if dim is not None:
-                # sliced tokens attention map saving is not implemented
-                return
-            saver.add_attention_maps(slice, key)
-
-        tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
-        for identifier, module in tokens_cross_attention_modules:
-            key = "down" if identifier.startswith("down") else "up" if identifier.startswith("up") else "mid"
-            module.set_attention_slice_calculated_callback(
-                lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key)
-            )
-
-    def remove_attention_map_saving(self):
-        tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
-        for _, module in tokens_cross_attention_modules:
-            module.set_attention_slice_calculated_callback(None)

     def do_controlnet_step(
         self,
@@ -592,54 +567,3 @@ class InvokeAIDiffuserComponent:
         self.last_percent_through = percent_through
         return latents.to(device=dev)
-
-    # todo: make this work
-    @classmethod
-    def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
-        x_in = torch.cat([x] * 2)
-        t_in = torch.cat([t] * 2)  # aka sigmas
-
-        deltas = None
-        uncond_latents = None
-        weighted_cond_list = (
-            c_or_weighted_c_list if isinstance(c_or_weighted_c_list, list) else [(c_or_weighted_c_list, 1)]
-        )
-
-        # below is fugly omg
-        conditionings = [uc] + [c for c, weight in weighted_cond_list]
-        weights = [1] + [weight for c, weight in weighted_cond_list]
-        chunk_count = math.ceil(len(conditionings) / 2)
-        deltas = None
-        for chunk_index in range(chunk_count):
-            offset = chunk_index * 2
-            chunk_size = min(2, len(conditionings) - offset)
-
-            if chunk_size == 1:
-                c_in = conditionings[offset]
-                latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
-                latents_b = None
-            else:
-                c_in = torch.cat(conditionings[offset : offset + 2])
-                latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)
-
-            # the first chunk is guaranteed to be 2 entries: uncond_latents + the first conditioning
-            if chunk_index == 0:
-                uncond_latents = latents_a
-                deltas = latents_b - uncond_latents
-            else:
-                deltas = torch.cat((deltas, latents_a - uncond_latents))
-                if latents_b is not None:
-                    deltas = torch.cat((deltas, latents_b - uncond_latents))
-
-        # merge the weighted deltas together into a single merged delta
-        per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
-        normalize = False
-        if normalize:
-            per_delta_weights /= torch.sum(per_delta_weights)
-        reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
-        deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
-
-        # old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
-        # assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
-
-        return uncond_latents + deltas_merged * global_guidance_scale
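As a sanity check on the merge above, a hypothetical two-conditioning call with weights [1.0, 0.5] and normalize left False reduces to:

    # d1 = cond1_latents - uncond_latents, d2 = cond2_latents - uncond_latents
    # result = uncond_latents + (1.0 * d1 + 0.5 * d2) * global_guidance_scale
    # i.e. each conditioning pulls the latents in proportion to its weight.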
@@ -134,8 +134,6 @@
     "loadMore": "Mehr laden",
     "noImagesInGallery": "Keine Bilder in der Galerie",
     "loading": "Lade",
-    "preparingDownload": "bereite Download vor",
-    "preparingDownloadFailed": "Problem beim Download vorbereiten",
     "deleteImage": "Lösche Bild",
     "copy": "Kopieren",
     "download": "Runterladen",
@@ -967,7 +965,7 @@
     "resumeFailed": "Problem beim Fortsetzen des Prozesses",
     "pruneFailed": "Problem beim leeren der Warteschlange",
     "pauseTooltip": "Prozess anhalten",
-    "back": "Hinten",
+    "back": "Ende",
     "resumeSucceeded": "Prozess wird fortgesetzt",
     "resumeTooltip": "Prozess wieder aufnehmen",
     "time": "Zeit",
@@ -741,6 +741,8 @@
     "customConfig": "Custom Config",
     "customConfigFileLocation": "Custom Config File Location",
     "customSaveLocation": "Custom Save Location",
+    "defaultSettings": "Default Settings",
+    "defaultSettingsSaved": "Default Settings Saved",
     "delete": "Delete",
     "deleteConfig": "Delete Config",
     "deleteModel": "Delete Model",
@@ -852,6 +854,7 @@
     "upcastAttention": "Upcast Attention",
     "updateModel": "Update Model",
     "useCustomConfig": "Use Custom Config",
+    "useDefaultSettings": "Use Default Settings",
     "v1": "v1",
     "v2_768": "v2 (768px)",
     "v2_base": "v2 (512px)",
@@ -870,6 +873,7 @@
   "models": {
     "addLora": "Add LoRA",
     "allLoRAsAdded": "All LoRAs added",
+    "concepts": "Concepts",
     "loraAlreadyAdded": "LoRA already added",
     "esrganModel": "ESRGAN Model",
     "loading": "loading",
@@ -505,8 +505,6 @@
     "seamLowThreshold": "Bajo",
     "coherencePassHeader": "Parámetros de la coherencia",
     "compositingSettingsHeader": "Ajustes de la composición",
-    "coherenceSteps": "Pasos",
-    "coherenceStrength": "Fuerza",
     "patchmatchDownScaleSize": "Reducir a escala",
     "coherenceMode": "Modo"
   },
@@ -114,7 +114,8 @@
     "checkpoint": "Checkpoint",
     "safetensors": "Safetensors",
     "ai": "ia",
-    "file": "File"
+    "file": "File",
+    "toResolve": "Da risolvere"
   },
   "gallery": {
     "generations": "Generazioni",
@@ -142,8 +143,6 @@
     "copy": "Copia",
     "download": "Scarica",
     "setCurrentImage": "Imposta come immagine corrente",
-    "preparingDownload": "Preparazione del download",
-    "preparingDownloadFailed": "Problema durante la preparazione del download",
     "downloadSelection": "Scarica gli elementi selezionati",
     "noImageSelected": "Nessuna immagine selezionata",
     "deleteSelection": "Elimina la selezione",
@@ -609,8 +608,6 @@
     "seamLowThreshold": "Basso",
     "seamHighThreshold": "Alto",
     "coherencePassHeader": "Passaggio di coerenza",
-    "coherenceSteps": "Passi",
-    "coherenceStrength": "Forza",
     "compositingSettingsHeader": "Impostazioni di composizione",
     "patchmatchDownScaleSize": "Ridimensiona",
     "coherenceMode": "Modalità",
@@ -1400,19 +1397,6 @@
         "Regola la maschera."
       ]
     },
-    "compositingCoherenceSteps": {
-      "heading": "Passi",
-      "paragraphs": [
-        "Numero di passi utilizzati nel Passaggio di Coerenza.",
-        "Simile ai passi di generazione."
-      ]
-    },
-    "compositingBlur": {
-      "heading": "Sfocatura",
-      "paragraphs": [
-        "Il raggio di sfocatura della maschera."
-      ]
-    },
     "compositingCoherenceMode": {
       "heading": "Modalità",
      "paragraphs": [
@@ -1431,13 +1415,6 @@
        "Un secondo ciclo di riduzione del rumore aiuta a comporre l'immagine Inpaint/Outpaint."
      ]
    },
-    "compositingStrength": {
-      "heading": "Forza",
-      "paragraphs": [
-        "Quantità di rumore aggiunta per il Passaggio di Coerenza.",
-        "Simile alla forza di riduzione del rumore."
-      ]
-    },
     "paramNegativeConditioning": {
       "paragraphs": [
         "Il processo di generazione evita i concetti nel prompt negativo. Utilizzatelo per escludere qualità o oggetti dall'output.",
@@ -123,8 +123,6 @@
     "autoSwitchNewImages": "새로운 이미지로 자동 전환",
     "loading": "불러오는 중",
     "unableToLoad": "갤러리를 로드할 수 없음",
-    "preparingDownload": "다운로드 준비",
-    "preparingDownloadFailed": "다운로드 준비 중 발생한 문제",
     "singleColumnLayout": "단일 열 레이아웃",
     "image": "이미지",
     "loadMore": "더 불러오기",
@@ -97,8 +97,6 @@
     "featuresWillReset": "Als je deze afbeelding verwijdert, dan worden deze functies onmiddellijk teruggezet.",
     "loading": "Bezig met laden",
     "unableToLoad": "Kan galerij niet laden",
-    "preparingDownload": "Bezig met voorbereiden van download",
-    "preparingDownloadFailed": "Fout bij voorbereiden van download",
     "downloadSelection": "Download selectie",
     "currentlyInUse": "Deze afbeelding is momenteel in gebruik door de volgende functies:",
     "copy": "Kopieer",
@@ -535,8 +533,6 @@
     "coherencePassHeader": "Coherentiestap",
     "maskBlur": "Vervaag",
     "maskBlurMethod": "Vervagingsmethode",
-    "coherenceSteps": "Stappen",
-    "coherenceStrength": "Sterkte",
     "seamHighThreshold": "Hoog",
     "seamLowThreshold": "Laag",
     "invoke": {
@@ -1139,13 +1135,6 @@
         "Een afbeeldingsgrootte (in aantal pixels) equivalent aan 512x512 wordt aanbevolen voor SD1.5-modellen. Een grootte-equivalent van 1024x1024 wordt aanbevolen voor SDXL-modellen."
       ]
     },
-    "compositingCoherenceSteps": {
-      "heading": "Stappen",
-      "paragraphs": [
-        "Het aantal te gebruiken ontruisingsstappen in de coherentiefase.",
-        "Gelijk aan de hoofdparameter Stappen."
-      ]
-    },
     "dynamicPrompts": {
       "paragraphs": [
         "Dynamische prompts vormt een enkele prompt om in vele.",
@@ -1160,12 +1149,6 @@
       ],
       "heading": "VAE"
     },
-    "compositingBlur": {
-      "heading": "Vervaging",
-      "paragraphs": [
-        "De vervagingsstraal van het masker."
-      ]
-    },
     "paramIterations": {
       "paragraphs": [
         "Het aantal te genereren afbeeldingen.",
@@ -1240,13 +1223,6 @@
       ],
       "heading": "Ontruisingssterkte"
     },
-    "compositingStrength": {
-      "heading": "Sterkte",
-      "paragraphs": [
-        "Ontruisingssterkte voor de coherentiefase.",
-        "Gelijk aan de parameter Ontruisingssterkte Afbeelding naar afbeelding."
-      ]
-    },
     "paramNegativeConditioning": {
       "paragraphs": [
         "Het genereerproces voorkomt de gegeven begrippen in de negatieve prompt. Gebruik dit om bepaalde zaken of voorwerpen uit te sluiten van de uitvoerafbeelding.",
@@ -143,8 +143,6 @@
     "problemDeletingImagesDesc": "Не удалось удалить одно или несколько изображений",
     "loading": "Загрузка",
     "unableToLoad": "Невозможно загрузить галерею",
-    "preparingDownload": "Подготовка к скачиванию",
-    "preparingDownloadFailed": "Проблема с подготовкой к скачиванию",
     "image": "изображение",
     "drop": "перебросить",
     "problemDeletingImages": "Проблема с удалением изображений",
@@ -612,9 +610,7 @@
     "maskBlurMethod": "Метод размытия",
     "seamLowThreshold": "Низкий",
     "seamHighThreshold": "Высокий",
-    "coherenceSteps": "Шагов",
     "coherencePassHeader": "Порог Coherence",
-    "coherenceStrength": "Сила",
     "compositingSettingsHeader": "Настройки компоновки",
     "invoke": {
       "noNodesInGraph": "Нет узлов в графе",
@@ -1321,13 +1317,6 @@
         "Размер изображения (в пикселях), эквивалентный 512x512, рекомендуется для моделей SD1.5, а размер, эквивалентный 1024x1024, рекомендуется для моделей SDXL."
       ]
     },
-    "compositingCoherenceSteps": {
-      "heading": "Шаги",
-      "paragraphs": [
-        "Количество шагов снижения шума, используемых при прохождении когерентности.",
-        "То же, что и основной параметр «Шаги»."
-      ]
-    },
     "dynamicPrompts": {
       "paragraphs": [
         "Динамические запросы превращают одно приглашение на множество.",
@@ -1342,12 +1331,6 @@
       ],
       "heading": "VAE"
     },
-    "compositingBlur": {
-      "heading": "Размытие",
-      "paragraphs": [
-        "Радиус размытия маски."
-      ]
-    },
     "paramIterations": {
       "paragraphs": [
         "Количество изображений, которые нужно сгенерировать.",
@@ -1422,13 +1405,6 @@
       ],
       "heading": "Шумоподавление"
     },
-    "compositingStrength": {
-      "heading": "Сила",
-      "paragraphs": [
-        null,
-        "То же, что параметр «Сила шумоподавления img2img»."
-      ]
-    },
     "paramNegativeConditioning": {
       "paragraphs": [
         "Stable Diffusion пытается избежать указанных в отрицательном запросе концепций. Используйте это, чтобы исключить качества или объекты из вывода.",
@@ -355,7 +355,6 @@
     "starImage": "Yıldız Koy",
     "download": "İndir",
     "deleteSelection": "Seçileni Sil",
-    "preparingDownloadFailed": "İndirme Hazırlanırken Sorun",
     "problemDeletingImages": "Görsel Silmede Sorun",
     "featuresWillReset": "Bu görseli silerseniz, o özellikler resetlenecektir.",
     "galleryImageResetSize": "Boyutu Resetle",
@@ -377,7 +376,6 @@
     "setCurrentImage": "Çalışma Görseli Yap",
     "unableToLoad": "Galeri Yüklenemedi",
     "downloadSelection": "Seçileni İndir",
-    "preparingDownload": "İndirmeye Hazırlanıyor",
     "singleColumnLayout": "Tek Sütun Düzen",
     "generations": "Çıktılar",
     "showUploads": "Yüklenenleri Göster",
@@ -723,7 +721,6 @@
     "clipSkip": "CLIP Atlama",
     "randomizeSeed": "Rastgele Tohum",
    "cfgScale": "CFG Ölçeği",
-    "coherenceStrength": "Etki",
     "controlNetControlMode": "Yönetim Kipi",
     "general": "Genel",
     "img2imgStrength": "Görselden Görsel Ölçüsü",
@@ -793,7 +790,6 @@
     "cfgRescaleMultiplier": "CFG Rescale Çarpanı",
     "cfgRescale": "CFG Rescale",
     "coherencePassHeader": "Uyum Geçişi",
-    "coherenceSteps": "Adım",
     "infillMethod": "Doldurma Yöntemi",
     "maskBlurMethod": "Bulandırma Yöntemi",
     "steps": "Adım",
@@ -136,8 +136,6 @@
     "copy": "复制",
     "download": "下载",
     "setCurrentImage": "设为当前图像",
-    "preparingDownload": "准备下载",
-    "preparingDownloadFailed": "准备下载时出现问题",
     "downloadSelection": "下载所选内容",
     "noImageSelected": "无选中的图像",
     "deleteSelection": "删除所选内容",
@@ -616,11 +614,9 @@
       "incompatibleBaseModelForControlAdapter": "有 #{{number}} 个 Control Adapter 模型与主模型不兼容。"
     },
     "patchmatchDownScaleSize": "缩小",
-    "coherenceSteps": "步数",
     "clipSkip": "CLIP 跳过层",
     "compositingSettingsHeader": "合成设置",
     "useCpuNoise": "使用 CPU 噪声",
-    "coherenceStrength": "强度",
     "enableNoiseSettings": "启用噪声设置",
     "coherenceMode": "模式",
     "cpuNoise": "CPU 噪声",
@@ -1402,19 +1398,6 @@
         "图像尺寸(单位:像素)建议 SD 1.5 模型使用等效 512x512 的尺寸,SDXL 模型使用等效 1024x1024 的尺寸。"
       ]
     },
-    "compositingCoherenceSteps": {
-      "heading": "步数",
-      "paragraphs": [
-        "一致性层中使用的去噪步数。",
-        "与主参数中的步数相同。"
-      ]
-    },
-    "compositingBlur": {
-      "heading": "模糊",
-      "paragraphs": [
-        "遮罩模糊半径。"
-      ]
-    },
     "noiseUseCPU": {
       "heading": "使用 CPU 噪声",
       "paragraphs": [
@@ -1467,13 +1450,6 @@
         "第二轮去噪有助于合成内补/外扩图像。"
       ]
     },
-    "compositingStrength": {
-      "heading": "强度",
-      "paragraphs": [
-        "一致性层使用的去噪强度。",
-        "去噪强度与图生图的参数相同。"
-      ]
-    },
     "paramNegativeConditioning": {
       "paragraphs": [
         "生成过程会避免生成负向提示词中的概念。使用此选项来使输出排除部分质量或对象。",
@@ -55,6 +55,8 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
 import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
 import type { AppDispatch, RootState } from 'app/store/store';

+import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
+
 export const listenerMiddleware = createListenerMiddleware();

 export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
@@ -153,3 +155,5 @@ addUpscaleRequestedListener(startAppListening);

 // Dynamic prompts
 addDynamicPromptsListener(startAppListening);
+
+addSetDefaultSettingsListener(startAppListening);
@@ -0,0 +1,96 @@
+import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
+import { setDefaultSettings } from 'features/parameters/store/actions';
+import {
+  setCfgRescaleMultiplier,
+  setCfgScale,
+  setScheduler,
+  setSteps,
+  vaePrecisionChanged,
+  vaeSelected,
+} from 'features/parameters/store/generationSlice';
+import {
+  isParameterCFGRescaleMultiplier,
+  isParameterCFGScale,
+  isParameterPrecision,
+  isParameterScheduler,
+  isParameterSteps,
+  zParameterVAEModel,
+} from 'features/parameters/types/parameterSchemas';
+import { addToast } from 'features/system/store/systemSlice';
+import { makeToast } from 'features/system/util/makeToast';
+import { t } from 'i18next';
+import { map } from 'lodash-es';
+import { modelsApi } from 'services/api/endpoints/models';
+
+export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
+  startAppListening({
+    actionCreator: setDefaultSettings,
+    effect: async (action, { dispatch, getState }) => {
+      const state = getState();
+
+      const currentModel = state.generation.model;
+
+      if (!currentModel) {
+        return;
+      }
+
+      const metadata = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(currentModel.key)).unwrap();
+
+      if (!metadata || !metadata.default_settings) {
+        return;
+      }
+
+      const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = metadata.default_settings;
+
+      if (vae) {
+        // we store this as "default" within default settings
+        // to distinguish it from no default set
+        if (vae === 'default') {
+          dispatch(vaeSelected(null));
+        } else {
+          const { data } = modelsApi.endpoints.getVaeModels.select()(state);
+          const vaeArray = map(data?.entities);
+          const validVae = vaeArray.find((model) => model.key === vae);
+
+          const result = zParameterVAEModel.safeParse(validVae);
+          if (!result.success) {
+            return;
+          }
+          dispatch(vaeSelected(result.data));
+        }
+      }
+
+      if (vae_precision) {
+        if (isParameterPrecision(vae_precision)) {
+          dispatch(vaePrecisionChanged(vae_precision));
+        }
+      }
+
+      if (cfg_scale) {
+        if (isParameterCFGScale(cfg_scale)) {
+          dispatch(setCfgScale(cfg_scale));
+        }
+      }
+
+      if (cfg_rescale_multiplier) {
+        if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
+          dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
+        }
+      }
+
+      if (steps) {
+        if (isParameterSteps(steps)) {
+          dispatch(setSteps(steps));
+        }
+      }
+
+      if (scheduler) {
+        if (isParameterScheduler(scheduler)) {
+          dispatch(setScheduler(scheduler));
+        }
+      }
+
+      dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) })));
+    },
+  });
+};
@@ -1,4 +1,5 @@
 import type { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
+import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas';
 import type { InvokeTabName } from 'features/ui/store/tabMap';
 import type { O } from 'ts-toolbelt';

@@ -82,6 +83,8 @@ export type AppConfig = {
     guidance: NumericalParameterConfig;
     cfgRescaleMultiplier: NumericalParameterConfig;
     img2imgStrength: NumericalParameterConfig;
+    scheduler?: ParameterScheduler;
+    vaePrecision?: ParameterPrecision;
     // Canvas
     boundingBoxHeight: NumericalParameterConfig; // initial value comes from model
     boundingBoxWidth: NumericalParameterConfig; // initial value comes from model
@@ -59,7 +59,7 @@ const LoRASelect = () => {
   return (
     <FormControl isDisabled={!options.length}>
       <InformationalPopover feature="lora">
-        <FormLabel>{t('models.lora')} </FormLabel>
+        <FormLabel>{t('models.concepts')} </FormLabel>
       </InformationalPopover>
       <Combobox
         placeholder={placeholder}
@@ -15,7 +15,7 @@ const STATUSES = {
 const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus; errorReason?: string | null }) => {
   const { t } = useTranslation();

-  if (!status) {
+  if (!status || !Object.keys(STATUSES).includes(status)) {
     return <></>;
   }

@@ -0,0 +1,66 @@
+import { skipToken } from '@reduxjs/toolkit/query';
+import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
+import { useAppSelector } from 'app/store/storeHooks';
+import Loading from 'common/components/Loading/Loading';
+import { selectConfigSlice } from 'features/system/store/configSlice';
+import { isNil } from 'lodash-es';
+import { useMemo } from 'react';
+import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
+
+import { DefaultSettingsForm } from './DefaultSettings/DefaultSettingsForm';
+
+const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => {
+  const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision } = config.sd;
+
+  return {
+    initialSteps: steps.initial,
+    initialCfg: guidance.initial,
+    initialScheduler: scheduler,
+    initialCfgRescaleMultiplier: cfgRescaleMultiplier.initial,
+    initialVaePrecision: vaePrecision,
+  };
+});
+
+export const DefaultSettings = () => {
+  const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
+
+  const { data, isLoading } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
+  const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } =
+    useAppSelector(initialStatesSelector);
+
+  const defaultSettingsDefaults = useMemo(() => {
+    return {
+      vae: { isEnabled: !isNil(data?.default_settings?.vae), value: data?.default_settings?.vae || 'default' },
+      vaePrecision: {
+        isEnabled: !isNil(data?.default_settings?.vae_precision),
+        value: data?.default_settings?.vae_precision || initialVaePrecision || 'fp32',
+      },
+      scheduler: {
+        isEnabled: !isNil(data?.default_settings?.scheduler),
+        value: data?.default_settings?.scheduler || initialScheduler || 'euler',
+      },
+      steps: { isEnabled: !isNil(data?.default_settings?.steps), value: data?.default_settings?.steps || initialSteps },
+      cfgScale: {
+        isEnabled: !isNil(data?.default_settings?.cfg_scale),
+        value: data?.default_settings?.cfg_scale || initialCfg,
+      },
+      cfgRescaleMultiplier: {
+        isEnabled: !isNil(data?.default_settings?.cfg_rescale_multiplier),
+        value: data?.default_settings?.cfg_rescale_multiplier || initialCfgRescaleMultiplier,
+      },
+    };
+  }, [
+    data?.default_settings,
+    initialSteps,
+    initialCfg,
+    initialScheduler,
+    initialCfgRescaleMultiplier,
+    initialVaePrecision,
+  ]);
+
+  if (isLoading) {
+    return <Loading />;
+  }
+
+  return <DefaultSettingsForm defaultSettingsDefaults={defaultSettingsDefaults} />;
+};
@ -0,0 +1,72 @@
|
|||||||
|
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||||
|
import { useCallback, useMemo } from 'react';
|
||||||
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
|
import { useController } from 'react-hook-form';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||||
|
|
||||||
|
type DefaultCfgRescaleMultiplierType = DefaultSettingsFormData['cfgRescaleMultiplier'];
|
||||||
|
|
||||||
|
export function DefaultCfgRescaleMultiplier(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||||
|
const { field } = useController(props);
|
||||||
|
|
||||||
|
const sliderMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMin);
|
||||||
|
const sliderMax = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMax);
|
||||||
|
const numberInputMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.numberInputMin);
|
||||||
|
const numberInputMax = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.numberInputMax);
|
||||||
|
const coarseStep = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.coarseStep);
|
||||||
|
const fineStep = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.fineStep);
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
|
||||||
|
|
||||||
|
const onChange = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
const updatedValue = {
|
||||||
|
...(field.value as DefaultCfgRescaleMultiplierType),
|
||||||
|
value: v,
|
||||||
|
};
|
||||||
|
field.onChange(updatedValue);
|
||||||
|
},
|
||||||
|
[field]
|
||||||
|
);
|
||||||
|
|
||||||
|
const value = useMemo(() => {
|
||||||
|
return (field.value as DefaultCfgRescaleMultiplierType).value;
|
||||||
|
}, [field.value]);
|
||||||
|
|
||||||
|
const isDisabled = useMemo(() => {
|
||||||
|
return !(field.value as DefaultCfgRescaleMultiplierType).isEnabled;
|
||||||
|
}, [field.value]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||||
|
<InformationalPopover feature="paramCFGRescaleMultiplier">
|
||||||
|
<FormLabel>{t('parameters.cfgRescaleMultiplier')}</FormLabel>
|
||||||
|
</InformationalPopover>
|
||||||
|
<Flex w="full" gap={1}>
|
||||||
|
<CompositeSlider
|
||||||
|
value={value}
|
||||||
|
min={sliderMin}
|
||||||
|
max={sliderMax}
|
||||||
|
step={coarseStep}
|
||||||
|
fineStep={fineStep}
|
||||||
|
onChange={onChange}
|
||||||
|
marks={marks}
|
||||||
|
isDisabled={isDisabled}
|
||||||
|
/>
|
||||||
|
<CompositeNumberInput
|
||||||
|
value={value}
|
||||||
|
min={numberInputMin}
|
||||||
|
max={numberInputMax}
|
||||||
|
step={coarseStep}
|
||||||
|
fineStep={fineStep}
|
||||||
|
onChange={onChange}
|
||||||
|
isDisabled={isDisabled}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
</FormControl>
|
||||||
|
);
|
||||||
|
}
|
@ -0,0 +1,72 @@
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';

import type { DefaultSettingsFormData } from './DefaultSettingsForm';

type DefaultCfgType = DefaultSettingsFormData['cfgScale'];

export function DefaultCfgScale(props: UseControllerProps<DefaultSettingsFormData>) {
  const { field } = useController(props);

  const sliderMin = useAppSelector((s) => s.config.sd.guidance.sliderMin);
  const sliderMax = useAppSelector((s) => s.config.sd.guidance.sliderMax);
  const numberInputMin = useAppSelector((s) => s.config.sd.guidance.numberInputMin);
  const numberInputMax = useAppSelector((s) => s.config.sd.guidance.numberInputMax);
  const coarseStep = useAppSelector((s) => s.config.sd.guidance.coarseStep);
  const fineStep = useAppSelector((s) => s.config.sd.guidance.fineStep);
  const { t } = useTranslation();
  const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);

  const onChange = useCallback(
    (v: number) => {
      const updatedValue = {
        ...(field.value as DefaultCfgType),
        value: v,
      };
      field.onChange(updatedValue);
    },
    [field]
  );

  const value = useMemo(() => {
    return (field.value as DefaultCfgType).value;
  }, [field.value]);

  const isDisabled = useMemo(() => {
    return !(field.value as DefaultCfgType).isEnabled;
  }, [field.value]);

  return (
    <FormControl flexDir="column" gap={1} alignItems="flex-start">
      <InformationalPopover feature="paramCFGScale">
        <FormLabel>{t('parameters.cfgScale')}</FormLabel>
      </InformationalPopover>
      <Flex w="full" gap={1}>
        <CompositeSlider
          value={value}
          min={sliderMin}
          max={sliderMax}
          step={coarseStep}
          fineStep={fineStep}
          onChange={onChange}
          marks={marks}
          isDisabled={isDisabled}
        />
        <CompositeNumberInput
          value={value}
          min={numberInputMin}
          max={numberInputMax}
          step={coarseStep}
          fineStep={fineStep}
          onChange={onChange}
          isDisabled={isDisabled}
        />
      </Flex>
    </FormControl>
  );
}
@ -0,0 +1,50 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { SCHEDULER_OPTIONS } from 'features/parameters/types/constants';
import { isParameterScheduler } from 'features/parameters/types/parameterSchemas';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';

import type { DefaultSettingsFormData } from './DefaultSettingsForm';

type DefaultSchedulerType = DefaultSettingsFormData['scheduler'];

export function DefaultScheduler(props: UseControllerProps<DefaultSettingsFormData>) {
  const { t } = useTranslation();
  const { field } = useController(props);

  const onChange = useCallback<ComboboxOnChange>(
    (v) => {
      if (!isParameterScheduler(v?.value)) {
        return;
      }
      const updatedValue = {
        ...(field.value as DefaultSchedulerType),
        value: v.value,
      };
      field.onChange(updatedValue);
    },
    [field]
  );

  const value = useMemo(
    () => SCHEDULER_OPTIONS.find((o) => o.value === (field.value as DefaultSchedulerType).value),
    [field]
  );

  const isDisabled = useMemo(() => {
    return !(field.value as DefaultSchedulerType).isEnabled;
  }, [field.value]);

  return (
    <FormControl flexDir="column" gap={1} alignItems="flex-start">
      <InformationalPopover feature="paramScheduler">
        <FormLabel>{t('parameters.scheduler')}</FormLabel>
      </InformationalPopover>
      <Combobox isDisabled={isDisabled} value={value} options={SCHEDULER_OPTIONS} onChange={onChange} />
    </FormControl>
  );
}
@ -0,0 +1,147 @@
import { Button, Flex, Heading } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { IoPencil } from 'react-icons/io5';
import { useUpdateModelMetadataMutation } from 'services/api/endpoints/models';

import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
import { DefaultCfgScale } from './DefaultCfgScale';
import { DefaultScheduler } from './DefaultScheduler';
import { DefaultSteps } from './DefaultSteps';
import { DefaultVae } from './DefaultVae';
import { DefaultVaePrecision } from './DefaultVaePrecision';
import { SettingToggle } from './SettingToggle';

export interface FormField<T> {
  value: T;
  isEnabled: boolean;
}

export type DefaultSettingsFormData = {
  vae: FormField<string>;
  vaePrecision: FormField<string>;
  scheduler: FormField<ParameterScheduler>;
  steps: FormField<number>;
  cfgScale: FormField<number>;
  cfgRescaleMultiplier: FormField<number>;
};

export const DefaultSettingsForm = ({
  defaultSettingsDefaults,
}: {
  defaultSettingsDefaults: DefaultSettingsFormData;
}) => {
  const dispatch = useAppDispatch();
  const { t } = useTranslation();
  const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);

  const [editModelMetadata, { isLoading }] = useUpdateModelMetadataMutation();

  const { handleSubmit, control, formState } = useForm<DefaultSettingsFormData>({
    defaultValues: defaultSettingsDefaults,
  });

  const onSubmit = useCallback<SubmitHandler<DefaultSettingsFormData>>(
    (data) => {
      if (!selectedModelKey) {
        return;
      }

      const body = {
        vae: data.vae.isEnabled ? data.vae.value : null,
        vae_precision: data.vaePrecision.isEnabled ? data.vaePrecision.value : null,
        cfg_scale: data.cfgScale.isEnabled ? data.cfgScale.value : null,
        cfg_rescale_multiplier: data.cfgRescaleMultiplier.isEnabled ? data.cfgRescaleMultiplier.value : null,
        steps: data.steps.isEnabled ? data.steps.value : null,
        scheduler: data.scheduler.isEnabled ? data.scheduler.value : null,
      };

      editModelMetadata({
        key: selectedModelKey,
        body: { default_settings: body },
      })
        .unwrap()
        .then((_) => {
          dispatch(
            addToast(
              makeToast({
                title: t('modelManager.defaultSettingsSaved'),
                status: 'success',
              })
            )
          );
        })
        .catch((error) => {
          if (error) {
            dispatch(
              addToast(
                makeToast({
                  title: `${error.data.detail} `,
                  status: 'error',
                })
              )
            );
          }
        });
    },
    [selectedModelKey, dispatch, editModelMetadata, t]
  );

  return (
    <>
      <Flex gap="2" justifyContent="space-between" w="full" mb={5}>
        <Heading fontSize="md">{t('modelManager.defaultSettings')}</Heading>
        <Button
          size="sm"
          leftIcon={<IoPencil />}
          colorScheme="invokeYellow"
          isDisabled={!formState.isDirty}
          onClick={handleSubmit(onSubmit)}
          type="submit"
          isLoading={isLoading}
        >
          {t('common.save')}
        </Button>
      </Flex>

      <Flex flexDir="column" gap={8}>
        <Flex gap={8}>
          <Flex gap={4} w="full">
            <SettingToggle control={control} name="vae" />
            <DefaultVae control={control} name="vae" />
          </Flex>
          <Flex gap={4} w="full">
            <SettingToggle control={control} name="vaePrecision" />
            <DefaultVaePrecision control={control} name="vaePrecision" />
          </Flex>
        </Flex>
        <Flex gap={8}>
          <Flex gap={4} w="full">
            <SettingToggle control={control} name="scheduler" />
            <DefaultScheduler control={control} name="scheduler" />
          </Flex>
          <Flex gap={4} w="full">
            <SettingToggle control={control} name="steps" />
            <DefaultSteps control={control} name="steps" />
          </Flex>
        </Flex>
        <Flex gap={8}>
          <Flex gap={4} w="full">
            <SettingToggle control={control} name="cfgScale" />
            <DefaultCfgScale control={control} name="cfgScale" />
          </Flex>
          <Flex gap={4} w="full">
            <SettingToggle control={control} name="cfgRescaleMultiplier" />
            <DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />
          </Flex>
        </Flex>
      </Flex>
    </>
  );
};
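
Note: a minimal sketch (not part of this commit) of how a caller might build the
form's `defaultSettingsDefaults` from a model's stored default settings, assuming
the snake_case keys used in the onSubmit body above. `buildDefaultSettingsDefaults`
and the cfgScale/cfgRescaleMultiplier fallbacks are illustrative assumptions; the
steps/scheduler/vaePrecision fallbacks mirror the config defaults added later in
this commit.

import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas';

import type { DefaultSettingsFormData, FormField } from './DefaultSettingsForm';

// A field is enabled only when a value was actually saved on the model.
const toFormField = <T,>(value: T | null | undefined, fallback: T): FormField<T> => ({
  value: value ?? fallback,
  isEnabled: value !== null && value !== undefined,
});

// Hypothetical helper mapping stored (snake_case) settings to form state.
export const buildDefaultSettingsDefaults = (ds: {
  vae?: string | null;
  vae_precision?: string | null;
  scheduler?: ParameterScheduler | null;
  steps?: number | null;
  cfg_scale?: number | null;
  cfg_rescale_multiplier?: number | null;
}): DefaultSettingsFormData => ({
  vae: toFormField(ds.vae, 'default'),
  vaePrecision: toFormField(ds.vae_precision, 'fp32'),
  scheduler: toFormField(ds.scheduler, 'euler' as ParameterScheduler),
  steps: toFormField(ds.steps, 30),
  cfgScale: toFormField(ds.cfg_scale, 7), // assumed fallback
  cfgRescaleMultiplier: toFormField(ds.cfg_rescale_multiplier, 0), // assumed fallback
});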
@ -0,0 +1,72 @@
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';

import type { DefaultSettingsFormData } from './DefaultSettingsForm';

type DefaultSteps = DefaultSettingsFormData['steps'];

export function DefaultSteps(props: UseControllerProps<DefaultSettingsFormData>) {
  const { field } = useController(props);

  const sliderMin = useAppSelector((s) => s.config.sd.steps.sliderMin);
  const sliderMax = useAppSelector((s) => s.config.sd.steps.sliderMax);
  const numberInputMin = useAppSelector((s) => s.config.sd.steps.numberInputMin);
  const numberInputMax = useAppSelector((s) => s.config.sd.steps.numberInputMax);
  const coarseStep = useAppSelector((s) => s.config.sd.steps.coarseStep);
  const fineStep = useAppSelector((s) => s.config.sd.steps.fineStep);
  const { t } = useTranslation();
  const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);

  const onChange = useCallback(
    (v: number) => {
      const updatedValue = {
        ...(field.value as DefaultSteps),
        value: v,
      };
      field.onChange(updatedValue);
    },
    [field]
  );

  const value = useMemo(() => {
    return (field.value as DefaultSteps).value;
  }, [field.value]);

  const isDisabled = useMemo(() => {
    return !(field.value as DefaultSteps).isEnabled;
  }, [field.value]);

  return (
    <FormControl flexDir="column" gap={1} alignItems="flex-start">
      <InformationalPopover feature="paramSteps">
        <FormLabel>{t('parameters.steps')}</FormLabel>
      </InformationalPopover>
      <Flex w="full" gap={1}>
        <CompositeSlider
          value={value}
          min={sliderMin}
          max={sliderMax}
          step={coarseStep}
          fineStep={fineStep}
          onChange={onChange}
          marks={marks}
          isDisabled={isDisabled}
        />
        <CompositeNumberInput
          value={value}
          min={numberInputMin}
          max={numberInputMax}
          step={coarseStep}
          fineStep={fineStep}
          onChange={onChange}
          isDisabled={isDisabled}
        />
      </Flex>
    </FormControl>
  );
}
@ -0,0 +1,65 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { map } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useGetModelConfigQuery, useGetVaeModelsQuery } from 'services/api/endpoints/models';

import type { DefaultSettingsFormData } from './DefaultSettingsForm';

type DefaultVaeType = DefaultSettingsFormData['vae'];

export function DefaultVae(props: UseControllerProps<DefaultSettingsFormData>) {
  const { t } = useTranslation();
  const { field } = useController(props);
  const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
  const { data: modelData } = useGetModelConfigQuery(selectedModelKey ?? skipToken);

  const { compatibleOptions } = useGetVaeModelsQuery(undefined, {
    selectFromResult: ({ data }) => {
      const modelArray = map(data?.entities);
      const compatibleOptions = modelArray
        .filter((vae) => vae.base === modelData?.base)
        .map((vae) => ({ label: vae.name, value: vae.key }));

      const defaultOption = { label: 'Default VAE', value: 'default' };

      return { compatibleOptions: [defaultOption, ...compatibleOptions] };
    },
  });

  const onChange = useCallback<ComboboxOnChange>(
    (v) => {
      const newValue = !v?.value ? 'default' : v.value;

      const updatedValue = {
        ...(field.value as DefaultVaeType),
        value: newValue,
      };
      field.onChange(updatedValue);
    },
    [field]
  );

  const value = useMemo(() => {
    return compatibleOptions.find((vae) => vae.value === (field.value as DefaultVaeType).value);
  }, [compatibleOptions, field.value]);

  const isDisabled = useMemo(() => {
    return !(field.value as DefaultVaeType).isEnabled;
  }, [field.value]);

  return (
    <FormControl flexDir="column" gap={1} alignItems="flex-start">
      <InformationalPopover feature="paramVAE">
        <FormLabel>{t('modelManager.vae')}</FormLabel>
      </InformationalPopover>
      <Combobox isDisabled={isDisabled} value={value} options={compatibleOptions} onChange={onChange} />
    </FormControl>
  );
}
@ -0,0 +1,51 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { isParameterPrecision } from 'features/parameters/types/parameterSchemas';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';

import type { DefaultSettingsFormData } from './DefaultSettingsForm';

const options = [
  { label: 'FP16', value: 'fp16' },
  { label: 'FP32', value: 'fp32' },
];

type DefaultVaePrecisionType = DefaultSettingsFormData['vaePrecision'];

export function DefaultVaePrecision(props: UseControllerProps<DefaultSettingsFormData>) {
  const { t } = useTranslation();
  const { field } = useController(props);

  const onChange = useCallback<ComboboxOnChange>(
    (v) => {
      if (!isParameterPrecision(v?.value)) {
        return;
      }
      const updatedValue = {
        ...(field.value as DefaultVaePrecisionType),
        value: v.value,
      };
      field.onChange(updatedValue);
    },
    [field]
  );

  const value = useMemo(() => options.find((o) => o.value === (field.value as DefaultVaePrecisionType).value), [field]);

  const isDisabled = useMemo(() => {
    return !(field.value as DefaultVaePrecisionType).isEnabled;
  }, [field.value]);

  return (
    <FormControl flexDir="column" gap={1} alignItems="flex-start">
      <InformationalPopover feature="paramVAEPrecision">
        <FormLabel>{t('modelManager.vaePrecision')}</FormLabel>
      </InformationalPopover>
      <Combobox isDisabled={isDisabled} value={value} options={options} onChange={onChange} />
    </FormControl>
  );
}
@ -0,0 +1,28 @@
import { Switch } from '@invoke-ai/ui-library';
import type { ChangeEvent } from 'react';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';

import type { DefaultSettingsFormData, FormField } from './DefaultSettingsForm';

export function SettingToggle<T>(props: UseControllerProps<DefaultSettingsFormData>) {
  const { field } = useController(props);

  const value = useMemo(() => {
    return !!(field.value as FormField<T>).isEnabled;
  }, [field.value]);

  const onChange = useCallback(
    (e: ChangeEvent<HTMLInputElement>) => {
      const updatedValue: FormField<T> = {
        ...(field.value as FormField<T>),
        isEnabled: e.target.checked,
      };
      field.onChange(updatedValue);
    },
    [field]
  );

  return <Switch isChecked={value} onChange={onChange} />;
}
@ -17,6 +17,7 @@ import type {
   VAEModelConfig,
 } from 'services/api/types';
 
+import { DefaultSettings } from './DefaultSettings';
 import { ModelAttrView } from './ModelAttrView';
 import { ModelConvert } from './ModelConvert';
 
@ -71,7 +72,7 @@ export const ModelView = () => {
     return <Text>{t('common.somethingWentWrong')}</Text>;
   }
   return (
-    <Flex flexDir="column" h="full">
+    <Flex flexDir="column" h="full" gap="2">
       <Box layerStyle="second" borderRadius="base" p={3}>
         <Flex gap="2" justifyContent="flex-end" w="full">
           <Button size="sm" leftIcon={<IoPencil />} colorScheme="invokeYellow" onClick={handleEditModel}>
@ -118,6 +119,9 @@ export const ModelView = () => {
         )}
       </Flex>
     </Box>
+    <Box layerStyle="second" borderRadius="base" p={3}>
+      <DefaultSettings />
+    </Box>
   </Flex>
 );
};
@ -344,8 +344,8 @@ export const buildCanvasInpaintGraph = (
     },
     {
       source: {
-        node_id: MASK_RESIZE_UP,
-        field: 'image',
+        node_id: INPAINT_CREATE_MASK,
+        field: 'expanded_mask_area',
       },
       destination: {
         node_id: MASK_RESIZE_DOWN,
@ -439,8 +439,8 @@ export const buildCanvasOutpaintGraph = (
     },
     {
       source: {
-        node_id: MASK_RESIZE_UP,
-        field: 'image',
+        node_id: INPAINT_CREATE_MASK,
+        field: 'expanded_mask_area',
       },
       destination: {
         node_id: MASK_RESIZE_DOWN,
@ -355,8 +355,8 @@ export const buildCanvasSDXLInpaintGraph = (
     },
     {
       source: {
-        node_id: MASK_RESIZE_UP,
-        field: 'image',
+        node_id: INPAINT_CREATE_MASK,
+        field: 'expanded_mask_area',
       },
       destination: {
         node_id: MASK_RESIZE_DOWN,
@ -448,8 +448,8 @@ export const buildCanvasSDXLOutpaintGraph = (
     },
     {
       source: {
-        node_id: MASK_RESIZE_UP,
-        field: 'image',
+        node_id: INPAINT_CREATE_MASK,
+        field: 'expanded_mask_area',
       },
       destination: {
         node_id: MASK_RESIZE_DOWN,
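
Note: all four canvas graph builders above make the same rewiring — the mask
that feeds MASK_RESIZE_DOWN now comes from the gradient-mask creation node's
expanded_mask_area output instead of MASK_RESIZE_UP's image. A sketch of the
resulting edge; the destination field name is an assumption, since the hunks cut
off before it:

// Node IDs stand in for the canvas graph constants used in the builders above.
declare const INPAINT_CREATE_MASK: string;
declare const MASK_RESIZE_DOWN: string;

const edge = {
  source: { node_id: INPAINT_CREATE_MASK, field: 'expanded_mask_area' },
  destination: { node_id: MASK_RESIZE_DOWN, field: 'image' }, // 'image' assumed
};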
@ -0,0 +1,36 @@
import type { IconButtonProps } from '@invoke-ai/ui-library';
import { IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiGearSixBold } from 'react-icons/pi';

export const NavigateToModelManagerButton = memo((props: Omit<IconButtonProps, 'aria-label'>) => {
  const { t } = useTranslation();
  const dispatch = useAppDispatch();
  const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
  const shouldShowButton = useMemo(() => !disabledTabs.includes('modelManager'), [disabledTabs]);

  const handleClick = useCallback(() => {
    dispatch(setActiveTab('modelManager'));
  }, [dispatch]);

  if (!shouldShowButton) {
    return null;
  }

  return (
    <IconButton
      icon={<PiGearSixBold />}
      tooltip={t('modelManager.modelManager')}
      aria-label={t('modelManager.modelManager')}
      onClick={handleClick}
      size="sm"
      variant="ghost"
      {...props}
    />
  );
});

NavigateToModelManagerButton.displayName = 'NavigateToModelManagerButton';
@ -0,0 +1,28 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setDefaultSettings } from 'features/parameters/store/actions';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { RiSparklingFill } from 'react-icons/ri';

export const UseDefaultSettingsButton = () => {
  const model = useAppSelector((s) => s.generation.model);
  const { t } = useTranslation();
  const dispatch = useAppDispatch();

  const handleClickDefaultSettings = useCallback(() => {
    dispatch(setDefaultSettings());
  }, [dispatch]);

  return (
    <IconButton
      icon={<RiSparklingFill />}
      tooltip={t('modelManager.useDefaultSettings')}
      aria-label={t('modelManager.useDefaultSettings')}
      isDisabled={!model}
      onClick={handleClickDefaultSettings}
      size="sm"
      variant="ghost"
    />
  );
};
@ -5,3 +5,5 @@ import type { ImageDTO } from 'services/api/types';
 export const initialImageSelected = createAction<ImageDTO | undefined>('generation/initialImageSelected');
 
 export const modelSelected = createAction<ParameterModel>('generation/modelSelected');
+
+export const setDefaultSettings = createAction('generation/setDefaultSettings');
@ -230,6 +230,12 @@ export const generationSlice = createSlice({
         state.height = optimalDimension;
       }
     }
+    if (action.payload.sd?.scheduler) {
+      state.scheduler = action.payload.sd.scheduler;
+    }
+    if (action.payload.sd?.vaePrecision) {
+      state.vaePrecision = action.payload.sd.vaePrecision;
+    }
   });
 
 // TODO: This is a temp fix to reduce issues with T2I adapter having a different downscaling
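
Note: the two new branches fire when the app config carries the sd.scheduler /
sd.vaePrecision fields added to initialConfigState further down in this commit.
A minimal sketch of a partial config payload that would exercise them (the
surrounding action name is not shown in this hunk):

const partialConfig = {
  sd: {
    scheduler: 'euler' as const,
    vaePrecision: 'fp32' as const,
  },
};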
@ -1,15 +1,5 @@
 import type { FormLabelProps } from '@invoke-ai/ui-library';
-import {
-  Expander,
-  Flex,
-  FormControlGroup,
-  StandaloneAccordion,
-  Tab,
-  TabList,
-  TabPanel,
-  TabPanels,
-  Tabs,
-} from '@invoke-ai/ui-library';
+import { Box, Expander, Flex, FormControlGroup, StandaloneAccordion } from '@invoke-ai/ui-library';
 import { EMPTY_ARRAY } from 'app/store/constants';
 import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
 import { useAppSelector } from 'app/store/storeHooks';
@ -20,7 +10,9 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod
 import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
 import ParamScheduler from 'features/parameters/components/Core/ParamScheduler';
 import ParamSteps from 'features/parameters/components/Core/ParamSteps';
+import { NavigateToModelManagerButton } from 'features/parameters/components/MainModel/NavigateToModelManagerButton';
 import ParamMainModelSelect from 'features/parameters/components/MainModel/ParamMainModelSelect';
+import { UseDefaultSettingsButton } from 'features/parameters/components/MainModel/UseDefaultSettingsButton';
 import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
 import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
 import { filter } from 'lodash-es';
@ -39,11 +31,11 @@ export const GenerationSettingsAccordion = memo(() => {
     () =>
       createMemoizedSelector(selectLoraSlice, (lora) => {
         const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length;
-        const loraTabBadges = enabledLoRAsCount ? [enabledLoRAsCount] : EMPTY_ARRAY;
+        const loraTabBadges = enabledLoRAsCount ? [`${enabledLoRAsCount} ${t('models.concepts')}`] : EMPTY_ARRAY;
         const accordionBadges = modelConfig ? [modelConfig.name, modelConfig.base] : EMPTY_ARRAY;
         return { loraTabBadges, accordionBadges };
       }),
-    [modelConfig]
+    [modelConfig, t]
   );
   const { loraTabBadges, accordionBadges } = useAppSelector(selectBadges);
   const { isOpen: isOpenExpander, onToggle: onToggleExpander } = useExpanderToggle({
@ -58,39 +50,35 @@ export const GenerationSettingsAccordion = memo(() => {
   return (
     <StandaloneAccordion
       label={t('accordions.generation.title')}
-      badges={accordionBadges}
+      badges={[...accordionBadges, ...loraTabBadges]}
       isOpen={isOpenAccordion}
       onToggle={onToggleAccordion}
     >
-      <Tabs variant="collapse">
-        <TabList>
-          <Tab>{t('accordions.generation.modelTab')}</Tab>
-          <Tab badges={loraTabBadges}>{t('accordions.generation.conceptsTab')}</Tab>
-        </TabList>
-        <TabPanels>
-          <TabPanel overflow="visible" px={4} pt={4}>
-            <Flex gap={4} alignItems="center">
-              <ParamMainModelSelect />
-              <SyncModelsIconButton />
-            </Flex>
-            <Expander isOpen={isOpenExpander} onToggle={onToggleExpander}>
-              <Flex gap={4} flexDir="column" pb={4}>
-                <FormControlGroup formLabelProps={formLabelProps}>
-                  <ParamScheduler />
-                  <ParamSteps />
-                  <ParamCFGScale />
-                </FormControlGroup>
-              </Flex>
-            </Expander>
-          </TabPanel>
-          <TabPanel>
-            <Flex gap={4} p={4} flexDir="column">
-              <LoRASelect />
-              <LoRAList />
-            </Flex>
-          </TabPanel>
-        </TabPanels>
-      </Tabs>
+      <Box px={4} pt={4}>
+        <Flex gap={4} flexDir="column">
+          <Flex gap={4} alignItems="center">
+            <ParamMainModelSelect />
+            <Flex>
+              <UseDefaultSettingsButton />
+              <SyncModelsIconButton />
+              <NavigateToModelManagerButton />
+            </Flex>
+          </Flex>
+          <Flex gap={4} flexDir="column">
+            <LoRASelect />
+            <LoRAList />
+          </Flex>
+        </Flex>
+        <Expander isOpen={isOpenExpander} onToggle={onToggleExpander}>
+          <Flex gap={4} flexDir="column" pb={4}>
+            <FormControlGroup formLabelProps={formLabelProps}>
+              <ParamScheduler />
+              <ParamSteps />
+              <ParamCFGScale />
+            </FormControlGroup>
+          </Flex>
+        </Expander>
+      </Box>
     </StandaloneAccordion>
   );
 });
@ -41,6 +41,8 @@ const initialConfigState: AppConfig = {
     boundingBoxHeight: { ...baseDimensionConfig },
     scaledBoundingBoxWidth: { ...baseDimensionConfig },
     scaledBoundingBoxHeight: { ...baseDimensionConfig },
+    scheduler: 'euler',
+    vaePrecision: 'fp32',
     steps: {
       initial: 30,
       sliderMin: 1,
File diff suppressed because one or more lines are too long
@ -51,12 +51,12 @@ dependencies = [
     "torchmetrics==0.11.4",
     "torchsde==0.2.6",
     "torchvision==0.16.2",
-    "transformers==4.37.2",
+    "transformers==4.38.2",
 
     # Core application dependencies, pinned for reproducible builds.
     "fastapi-events==0.10.1",
     "fastapi==0.109.2",
-    "huggingface-hub==0.20.3",
+    "huggingface-hub==0.21.3",
     "pydantic-settings==2.1.0",
     "pydantic==2.6.1",
     "python-socketio==5.11.1",
@ -64,6 +64,7 @@ dependencies = [
 
     # Auxiliary dependencies, pinned only if necessary.
    "albumentations",
+    "blake3",
     "click",
     "datasets",
     "Deprecated",
@ -72,7 +73,6 @@ dependencies = [
     "easing-functions",
     "einops",
     "facexlib",
-    "imohash",
     "matplotlib", # needed for plotting of Penner easing functions
     "npyscreen",
     "omegaconf",
@ -3,6 +3,7 @@ Test the model installer
 """
 
 import platform
+import uuid
 from pathlib import Path
 
 import pytest
@ -30,9 +31,8 @@ def test_registration(mm2_installer: ModelInstallServiceBase, embedding_file: Pa
     matches = store.search_by_attr(model_name="test_embedding")
     assert len(matches) == 0
     key = mm2_installer.register_path(embedding_file)
-    assert key is not None
-    assert key != "<NOKEY>"
-    assert len(key) == 32
+    # Not raising here is sufficient - key should be UUIDv4
+    uuid.UUID(key, version=4)
 
 
 def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
96  tests/test_model_hash.py  Normal file
@ -0,0 +1,96 @@
# pyright:reportPrivateUsage=false

from pathlib import Path
from typing import Iterable

import pytest
from blake3 import blake3

from invokeai.backend.model_manager.hash import ALGORITHM, MODEL_FILE_EXTENSIONS, ModelHash

test_cases: list[tuple[ALGORITHM, str]] = [
    ("md5", "a0cd925fc063f98dbf029eee315060c3"),
    ("sha1", "9e362940e5603fdc60566ea100a288ba2fe48b8c"),
    ("sha256", "6dbdb6a147ad4d808455652bf5a10120161678395f6bfbd21eb6fe4e731aceeb"),
    (
        "sha512",
        "c4a10476b21e00042f638ad5755c561d91f2bb599d3504d25409495e1c7eda94543332a1a90fbb4efdaf9ee462c33e0336b5eae4acfb1fa0b186af452dd67dc6",
    ),
    ("blake3", "ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0"),
]


@pytest.mark.parametrize("algorithm,expected_hash", test_cases)
def test_model_hash_hashes_file(tmp_path: Path, algorithm: ALGORITHM, expected_hash: str):
    file = Path(tmp_path / "test")
    file.write_text("model data")
    md5 = ModelHash(algorithm).hash(file)
    assert md5 == expected_hash


@pytest.mark.parametrize("algorithm", ["md5", "sha1", "sha256", "sha512", "blake3"])
def test_model_hash_hashes_dir(tmp_path: Path, algorithm: ALGORITHM):
    model_hash = ModelHash(algorithm)
    files = [Path(tmp_path, f"{i}.bin") for i in range(5)]

    for f in files:
        f.write_text("data")

    md5 = model_hash.hash(tmp_path)

    # Manual implementation of composite hash - always uses BLAKE3
    composite_hasher = blake3()
    for f in files:
        h = model_hash.hash(f)
        composite_hasher.update(h.encode("utf-8"))

    assert md5 == composite_hasher.hexdigest()


def test_model_hash_raises_error_on_invalid_algorithm():
    with pytest.raises(ValueError, match="Algorithm invalid_algorithm not available"):
        ModelHash("invalid_algorithm")  # pyright: ignore [reportArgumentType]


def paths_to_str_set(paths: Iterable[Path]) -> set[str]:
    return {str(p) for p in paths}


def test_model_hash_filters_out_non_model_files(tmp_path: Path):
    model_files = {Path(tmp_path, f"{i}{ext}") for i, ext in enumerate(MODEL_FILE_EXTENSIONS)}

    for i, f in enumerate(model_files):
        f.write_text(f"data{i}")

    assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set(
        model_files
    )

    # Add file that should be ignored - hash should not change
    file = tmp_path / "test.icecream"
    file.write_text("data")

    assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set(
        model_files
    )

    # Add file that should not be ignored - hash should change
    file = tmp_path / "test.bin"
    file.write_text("more data")
    model_files.add(file)

    assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set(
        model_files
    )


def test_model_hash_uses_custom_filter(tmp_path: Path):
    model_files = {Path(tmp_path, f"file{ext}") for ext in [".pickme", ".ignoreme"]}

    for i, f in enumerate(model_files):
        f.write_text(f"data{i}")

    def file_filter(file_path: str) -> bool:
        return file_path.endswith(".pickme")

    assert {p.name for p in ModelHash._get_file_paths(tmp_path, file_filter)} == {"file.pickme"}
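
Note: test_model_hash_hashes_dir above exercises the composite-directory-hash
scheme: hash each file individually, then hash the concatenation of the per-file
hex digests (per the test's comment, ModelHash always uses BLAKE3 for that
second step). A minimal TypeScript sketch of the same scheme; sha256 stands in
for BLAKE3 only because node:crypto ships it:

import { createHash } from 'node:crypto';

// Hex digest of a single file's contents.
const hexDigest = (data: Buffer): string => createHash('sha256').update(data).digest('hex');

// Composite hash: feed each per-file hex digest (as utf-8) into one hasher.
const compositeHash = (fileContents: Buffer[]): string => {
  const hasher = createHash('sha256'); // ModelHash uses blake3 here
  for (const data of fileContents) {
    hasher.update(hexDigest(data));
  }
  return hasher.digest('hex');
};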